[V1] LoRA - Add triton kernels for V1 (#13096)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>

parent 0967110e42
commit 5ff0d32580
@ -23,6 +23,7 @@ from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.lora.ops.triton_ops.v1 import V1KernelMeta, v1_expand, v1_shrink
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
@ -171,6 +172,8 @@ class OpType(Enum):
SGMV_EXPAND = auto()
BGMV_EXPAND = auto()
BGMV_EXPAND_SLICE = auto()
V1_SHRINK = auto()
V1_EXPAND = auto()

@staticmethod
def from_str(s: str) -> "OpType":
@ -184,28 +187,43 @@ class OpType(Enum):
return OpType.BGMV_EXPAND
if s.lower() == "bgmv_expand_slice":
return OpType.BGMV_EXPAND_SLICE
if s.lower() == "v1_shrink":
return OpType.V1_SHRINK
if s.lower() == "v1_expand":
return OpType.V1_EXPAND
raise ValueError(f"Unrecognized str {s} to convert to OpType")

def is_shrink_fn(self) -> bool:
return self in [OpType.SGMV_SHRINK, OpType.BGMV_SHRINK]
return self in [
OpType.SGMV_SHRINK, OpType.BGMV_SHRINK, OpType.V1_SHRINK
]

def is_expand_fn(self) -> bool:
return self in [OpType.SGMV_EXPAND, OpType.BGMV_EXPAND]
return self in [
OpType.SGMV_EXPAND, OpType.BGMV_EXPAND, OpType.V1_EXPAND
]

def is_prefill_op(self) -> bool:
return self in [OpType.SGMV_SHRINK, OpType.SGMV_EXPAND]
return self in [
OpType.SGMV_SHRINK, OpType.SGMV_EXPAND, OpType.V1_SHRINK,
OpType.V1_EXPAND
]

def is_decode_op(self) -> bool:
return self in [
OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE
OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE,
OpType.V1_SHRINK, OpType.V1_EXPAND
]

def is_expand_slice_fn(self) -> bool:
return self in [OpType.BGMV_EXPAND_SLICE]

def num_slices(self) -> list[int]:
if self in [OpType.SGMV_EXPAND, OpType.SGMV_SHRINK]:
# SGMV kernels supports slices
if self in [
OpType.SGMV_EXPAND, OpType.SGMV_SHRINK, OpType.V1_SHRINK,
OpType.V1_EXPAND
]:
# SGMV kernels and v1 kernels support slices
return [1, 2, 3]
if self in [OpType.BGMV_SHRINK, OpType.BGMV_EXPAND]:
return [1]
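For orientation, here is a small illustration (editor's example, not part of this commit) of how the new V1 op types behave with the helpers above; it assumes the OpType enum exactly as shown in this hunk.

# Editor's example -- not part of this commit.
assert OpType.from_str("v1_shrink") is OpType.V1_SHRINK
assert OpType.V1_SHRINK.is_shrink_fn() and OpType.V1_EXPAND.is_expand_fn()
# V1 kernels serve both stages, so they are benchmarked for prefill and decode.
assert OpType.V1_EXPAND.is_prefill_op() and OpType.V1_EXPAND.is_decode_op()
# Like the SGMV kernels, the V1 kernels handle multiple slices natively.
assert OpType.V1_SHRINK.num_slices() == [1, 2, 3]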
@ -250,11 +268,13 @@ class OpType(Enum):
m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)

b_shape = (num_loras, n, k) # col-major
if self == OpType.SGMV_SHRINK:
# SGMV shrink supports num_slices inherently in the kernel
if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]:
# SGMV shrink and V1 shrink kernels support num_slices inherently
# in the kernel.
return ((m, k), b_shape, (num_slices, m, n))
if self == OpType.SGMV_EXPAND:
# SGMV expand supports num_slices inherently in the kernel
if self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]:
# SGMV expand and V1 expand kernels support num_slices inherently
# in the kernel
return ((num_slices, m, k), b_shape, (m, n * num_slices))
if self == OpType.BGMV_SHRINK:
return ((m, k), b_shape, (m, n))
@ -281,25 +301,30 @@ class OpType(Enum):
return bgmv_expand
if self == OpType.BGMV_EXPAND_SLICE:
return emulate_bgmv_expand_slice
if self == OpType.V1_SHRINK:
return v1_shrink
if self == OpType.V1_EXPAND:
return v1_expand

raise ValueError(f"Unrecognized optype {self}")

def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
lora_weights: list[torch.Tensor],
**kwargs) -> Callable:
"""Each benchmark operation expected the input, lora_weights and outputs
"""Each benchmark operation expects the input, lora_weights and outputs
in a slightly different format. Refer to self.matmul_shapes().
run_ref_group_gemm accounts for those differences in executing a
reference group gemm for correctness testing.
"""
w_dtype = lora_weights[0].dtype
num_slices = len(lora_weights)
if self == OpType.SGMV_SHRINK:
if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]:
for slice_idx in range(num_slices):
ref_group_gemm(ref_out=output[slice_idx, :],
input=input,
lora_weights=lora_weights[slice_idx],
**kwargs)
if self == OpType.SGMV_EXPAND:
elif self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]:
hidden_size = lora_weights[0].shape[1]
for slice_idx in range(num_slices):
slice_offset = slice_idx * hidden_size
@ -308,19 +333,19 @@ class OpType(Enum):
input=input[slice_idx].clone().to(dtype=w_dtype),
lora_weights=lora_weights[slice_idx],
**kwargs)
if self == OpType.BGMV_SHRINK:
elif self == OpType.BGMV_SHRINK:
assert num_slices == 1
ref_group_gemm(ref_out=output,
input=input,
lora_weights=lora_weights[0],
**kwargs)
if self == OpType.BGMV_EXPAND:
elif self == OpType.BGMV_EXPAND:
assert num_slices == 1
ref_group_gemm(ref_out=output,
input=input.clone().to(dtype=w_dtype),
lora_weights=lora_weights[0],
**kwargs)
if self == OpType.BGMV_EXPAND_SLICE:
elif self == OpType.BGMV_EXPAND_SLICE:
hidden_size = lora_weights[0].shape[1]
for slice_idx in range(num_slices):
slice_offset = slice_idx * hidden_size
@ -329,7 +354,8 @@ class OpType(Enum):
input=input[slice_idx].clone().to(dtype=w_dtype),
lora_weights=lora_weights[slice_idx],
**kwargs)
raise ValueError(f"Unrecognized optype {self}")
else:
raise ValueError(f"Unrecognized optype {self}")

@dataclass
@ -390,6 +416,8 @@ class BenchmarkTensors:
seq_start_loc: torch.Tensor
prompt_lora_mapping: torch.Tensor
token_lora_mapping: torch.Tensor
# v1 kernel metadata
v1_kernel_meta: Optional[V1KernelMeta] = None

def io_types(self) -> str:
return (f"{dtype_to_str(self.input.dtype)}x"
@ -432,10 +460,19 @@ class BenchmarkTensors:
total_tokens, ctx.batch_size, prompt_lora_indices_tensor,
seq_len_tensor, "cpu")

v1_kernel_meta = None
if op_type in [OpType.V1_SHRINK, OpType.V1_EXPAND]:
v1_kernel_meta = V1KernelMeta.make(
max_loras=ctx.num_loras,
max_num_tokens=token_lora_indices_tensor.size(0),
device="cpu")
v1_kernel_meta.prepare_tensors(
token_lora_mapping=token_lora_indices_tensor)

return BenchmarkTensors(input_tensor, lora_weights, output_tensor,
seq_len_tensor, seq_start_loc_tensor,
prompt_lora_indices_tensor,
token_lora_indices_tensor)
token_lora_indices_tensor, v1_kernel_meta)

def sanity_check(self) -> None:
"""
@ -468,6 +505,13 @@ class BenchmarkTensors:
for i in range(len(self.lora_weights_lst)):
self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])

# v1 meta
if self.v1_kernel_meta:
for field_name in V1KernelMeta.__dataclass_fields__:
field = getattr(self.v1_kernel_meta, field_name)
assert isinstance(field, torch.Tensor)
setattr(self.v1_kernel_meta, field_name, to_device(field))

def metadata(self) -> tuple[int, int, int]:
"""
Return num_seqs, num_tokens and max_seq_len
@ -667,6 +711,78 @@ class BenchmarkTensors:
})
return {'kwargs_list': kwargs_list}

def as_v1_shrink_kwargs(self) -> dict[str, Any]:
assert self.v1_kernel_meta is not None
self.sanity_check()
self.to_device(self.input.device)

_, num_tokens, _, num_slices = self.metadata()

# Sanity check matrix shapes.
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
0].shape, self.output.shape
# Expected input shape [num_tokens, hidden_size]
assert len(i_shape) == 2
assert i_shape[0] == num_tokens
hidden_size = i_shape[1]
# Expected lora weight shape [num_loras, lora_rank, hidden_size]
assert len(lw_shape) == 3
assert lw_shape[2] == hidden_size
lora_rank = lw_shape[1]
# Expected output shape [num_slices, num_tokens, lora_rank]
assert len(o_shape) == 3
assert o_shape == (num_slices, num_tokens, lora_rank)

return {
'inputs': self.input,
'lora_a_weights': self.lora_weights_lst,
'output_tensor': self.output,
'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
'token_indices_sorted_by_lora_ids':
self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
'lora_ids': self.v1_kernel_meta.active_lora_ids,
'scaling': 1.0,
}

def as_v1_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
assert self.v1_kernel_meta is not None
self.sanity_check()
self.to_device(self.input.device)

_, num_tokens, _, num_slices = self.metadata()

# Sanity check matrix shapes.
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
0].shape, self.output.shape
# Expected input shape : [num_slices, num_tokens, lora_rank]
assert len(i_shape) == 3
assert i_shape[0] == num_slices
assert i_shape[1] == num_tokens
lora_rank = i_shape[2]
# Expected lora weight shape : [num_lora, hidden_size, lora_rank]
assert len(lw_shape) == 3
assert lw_shape[2] == lora_rank
hidden_size = lw_shape[1]
# Expected output shape : [num_tokens, hidden_size * num_slices]
assert len(o_shape) == 2
assert o_shape == (num_tokens, hidden_size * num_slices)

return {
'inputs': self.input,
'lora_b_weights': self.lora_weights_lst,
'output_tensor': self.output,
'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
'token_indices_sorted_by_lora_ids':
self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
'lora_ids': self.v1_kernel_meta.active_lora_ids,
'offset_start': 0,
'add_inputs': add_inputs,
}

def bench_fn_kwargs(self,
op_type: OpType,
add_inputs: Optional[bool] = None) -> dict[str, Any]:
@ -685,6 +801,10 @@ class BenchmarkTensors:
return self.as_bgmv_expand_kwargs(add_inputs)
if op_type == OpType.BGMV_EXPAND_SLICE:
return self.as_bgmv_expand_slice_kwargs(add_inputs)
if op_type == OpType.V1_SHRINK:
return self.as_v1_shrink_kwargs()
if op_type == OpType.V1_EXPAND:
return self.as_v1_expand_kwargs(add_inputs)
raise ValueError(f"Unrecognized optype {self}")

def test_correctness(self, op_type: OpType,
@ -872,12 +992,9 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
timers = []
for bench_ctx in bench_ctxs:
for seq_len in args.seq_lengths:
bench_ops: list[OpType] = []
if seq_len == 1:
# bench all decode ops
bench_ops = [op for op in args.op_types if op.is_decode_op()]
else:
# bench all prefill ops
bench_ops: list[OpType] = args.op_types
if seq_len > 1:
# bench only prefill ops
bench_ops = [op for op in args.op_types if op.is_prefill_op()]

seq_len_timers = []
@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

import importlib
import random
from copy import deepcopy
from dataclasses import dataclass
@ -63,6 +64,36 @@ DEVICES = ([
# stages, so we need to verify this. prefill stage(True) or decode stage(False)
STAGES = [True, False]

# With the inclusion of V1 tests (look at the run_with_both_engines_lora),
# the tests in this file run twice, once with the V0 engine and then with
# the V1 engine.
# The NUM_RANDOM_SEEDS value was set to 10 before. It is cut in half
# with the inclusion of V1 tests to maintain the CI test times.
NUM_RANDOM_SEEDS = 5
# The VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS value was set to
# 256 before. It is cut in half with the inclusion of V1 tests to maintain
# the CI test times.
VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package

# Reload punica_gpu as the kernels used are tied to engine type.
from vllm.lora.punica_wrapper import punica_gpu
importlib.reload(punica_gpu)

# Release any memory we might be holding on to. CI runs OOM otherwise.
from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT,
_LORA_B_PTR_DICT)
_LORA_B_PTR_DICT.clear()
_LORA_A_PTR_DICT.clear()

yield

def get_random_id_to_index(num_loras: int,
|
||||
num_slots: int,
|
||||
@ -226,7 +257,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device)
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
|
||||
assert check_punica_wrapper(punica_wrapper)
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
@ -241,7 +272,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
|
||||
|
||||
return embedding, lora_embedding
|
||||
|
||||
for i in range(10):
|
||||
for i in range(NUM_RANDOM_SEEDS):
|
||||
set_random_seed(i)
|
||||
|
||||
id_to_index = get_random_id_to_index(num_loras, max_loras)
|
||||
@ -329,7 +360,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device)
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
|
||||
assert check_punica_wrapper(punica_wrapper)
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
@ -353,7 +384,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
|
||||
|
||||
return expanded_embedding, lora_embedding
|
||||
|
||||
for i in range(10):
|
||||
for i in range(NUM_RANDOM_SEEDS):
|
||||
set_random_seed(i)
|
||||
|
||||
id_to_index = get_random_id_to_index(num_loras, max_loras)
|
||||
@ -468,7 +499,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device)
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
|
||||
assert check_punica_wrapper(punica_wrapper)
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
@ -490,7 +521,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
|
||||
|
||||
return linear, logits_processor, lora_logits_processor
|
||||
|
||||
for i in range(10):
|
||||
for i in range(NUM_RANDOM_SEEDS):
|
||||
set_random_seed(i)
|
||||
|
||||
id_to_index = get_random_id_to_index(num_loras, max_loras)
|
||||
@ -600,10 +631,10 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
|
||||
if current_platform.is_cuda_alike():
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
torch.set_default_device(device)
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device)
|
||||
assert check_punica_wrapper(punica_wrapper)
|
||||
max_loras = 8
|
||||
torch.set_default_device(device)
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
|
||||
assert check_punica_wrapper(punica_wrapper)
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
lora_dtype=torch.float16,
|
||||
@ -627,7 +658,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
|
||||
assert lora_linear.lora_bias_stacked is None
|
||||
return linear, lora_linear
|
||||
|
||||
for i in range(10):
|
||||
for i in range(NUM_RANDOM_SEEDS):
|
||||
set_random_seed(i)
|
||||
|
||||
id_to_index = get_random_id_to_index(num_loras, max_loras)
|
||||
@ -716,10 +747,10 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
|
||||
if current_platform.is_cuda_alike():
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
torch.set_default_device(device)
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device)
|
||||
assert check_punica_wrapper(punica_wrapper)
|
||||
max_loras = 8
|
||||
torch.set_default_device(device)
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
|
||||
assert check_punica_wrapper(punica_wrapper)
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
fully_sharded_loras=fully_shard,
|
||||
@ -753,7 +784,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
|
||||
assert lora_linear.lora_bias_stacked is None
|
||||
return linear, lora_linear
|
||||
|
||||
for i in range(10):
|
||||
for i in range(NUM_RANDOM_SEEDS):
|
||||
set_random_seed(i)
|
||||
|
||||
id_to_index = get_random_id_to_index(num_loras, max_loras)
|
||||
@ -842,10 +873,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
|
||||
if current_platform.is_cuda_alike():
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
torch.set_default_device(device)
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device)
|
||||
assert check_punica_wrapper(punica_wrapper)
|
||||
max_loras = 8
|
||||
torch.set_default_device(device)
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
|
||||
assert check_punica_wrapper(punica_wrapper)
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
fully_sharded_loras=fully_shard,
|
||||
@ -900,7 +931,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
|
||||
assert lora_linear.lora_bias_stacked is None
|
||||
return linear, lora_linear
|
||||
|
||||
for i in range(10):
|
||||
for i in range(NUM_RANDOM_SEEDS):
|
||||
set_random_seed(i)
|
||||
|
||||
id_to_index = get_random_id_to_index(num_loras, max_loras)
|
||||
@ -1002,12 +1033,12 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
|
||||
is_neox_style, rotary_dim, head_size,
|
||||
seq_len) -> None:
|
||||
dtype = torch.float16
|
||||
max_loras = 8
|
||||
seed = 0
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device)
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
|
||||
assert check_punica_wrapper(punica_wrapper)
|
||||
max_loras = 8
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
long_lora_scaling_factors=scaling_factors,
|
||||
@ -1083,7 +1114,8 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("seed", list(range(256)))
|
||||
@pytest.mark.parametrize(
|
||||
"seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS)))
|
||||
def test_vocab_parallel_embedding_indices(tp_size, seed):
|
||||
random.seed(seed)
|
||||
vocab_size = random.randint(4000, 64000)
|
||||
|
@ -5,10 +5,12 @@ import pytest
|
||||
import torch
|
||||
|
||||
import vllm.lora.ops.triton_ops # noqa: F401
|
||||
import vllm.lora.ops.triton_ops.v1 # noqa: F401
|
||||
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
|
||||
bgmv_shrink, sgmv_expand,
|
||||
sgmv_expand_slice, sgmv_shrink)
|
||||
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
|
||||
from vllm.lora.ops.triton_ops.v1 import V1KernelMeta
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .utils import (PunicaTensors, assert_close, generate_data,
|
||||
@ -91,12 +93,12 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
|
||||
_dict_lock = Lock()
|
||||
|
||||
|
||||
def check_sgmv_shrink(batches: int, num_loras: int, rank: int,
|
||||
hidden_size: int, nslices: int, dtype: torch.dtype,
|
||||
device: str, seq_length: int, scaling: float):
|
||||
def check_shrink_kernels(batches: int, num_loras: int, rank: int,
|
||||
hidden_size: int, nslices: int, dtype: torch.dtype,
|
||||
device: str, seq_length: int, scaling: float):
|
||||
"""
|
||||
Compare outputs of vllm.sgmv_shrink kernel against a reference
|
||||
implementation.
|
||||
Compare outputs of vllm.sgmv_shrink and vllm.v1_shrink kernel against a
|
||||
reference implementation.
|
||||
"""
|
||||
data: PunicaTensors = generate_data_for_nslices(
|
||||
batches,
|
||||
@ -111,44 +113,63 @@ def check_sgmv_shrink(batches: int, num_loras: int, rank: int,
|
||||
)
|
||||
max_seq_length, token_nums = data.meta()
|
||||
|
||||
# Setup metadata information for SGMV and reference kernels
|
||||
sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor,
|
||||
data.prompt_lora_mapping, batches, max_seq_length,
|
||||
token_nums)
|
||||
|
||||
# Setup metadata information for the V1 kernel.
|
||||
v1_meta = V1KernelMeta.make(max_loras=num_loras,
|
||||
max_num_tokens=token_nums,
|
||||
device='cuda')
|
||||
v1_meta.prepare_tensors(data.token_lora_mapping)
|
||||
|
||||
ref_out_tensor = data.ref_out_tensor
|
||||
sgmv_out_tensor = data.our_out_tensor
|
||||
v1_out_tensor = data.our_out_tensor.clone()
|
||||
|
||||
# Preventing cache error pointer.
|
||||
with _dict_lock:
|
||||
# SGMV shrink kernel
|
||||
_LORA_A_PTR_DICT.clear()
|
||||
torch.ops.vllm.sgmv_shrink(
|
||||
data.inputs_tensor,
|
||||
data.lora_weights,
|
||||
data.our_out_tensor,
|
||||
data.b_seq_start_loc,
|
||||
data.seq_len_tensor,
|
||||
data.prompt_lora_mapping,
|
||||
batches,
|
||||
max_seq_length,
|
||||
token_nums,
|
||||
sgmv_out_tensor,
|
||||
*sgmv_meta_args,
|
||||
scaling,
|
||||
)
|
||||
|
||||
sgmv_shrink_for_nslices(
|
||||
nslices,
|
||||
# V1 shrink kernel
|
||||
_LORA_A_PTR_DICT.clear()
|
||||
torch.ops.vllm.v1_shrink(
|
||||
data.inputs_tensor,
|
||||
data.lora_weights,
|
||||
data.ref_out_tensor,
|
||||
data.b_seq_start_loc,
|
||||
data.seq_len_tensor,
|
||||
data.prompt_lora_mapping,
|
||||
batches,
|
||||
max_seq_length,
|
||||
token_nums,
|
||||
v1_out_tensor,
|
||||
*v1_meta.meta_args(token_nums=token_nums),
|
||||
scaling,
|
||||
)
|
||||
assert_close(data.our_out_tensor, data.ref_out_tensor)
|
||||
|
||||
# Reference
|
||||
sgmv_shrink_for_nslices(
|
||||
nslices,
|
||||
data.inputs_tensor,
|
||||
data.lora_weights,
|
||||
ref_out_tensor,
|
||||
*sgmv_meta_args,
|
||||
scaling,
|
||||
)
|
||||
|
||||
assert_close(sgmv_out_tensor, ref_out_tensor)
|
||||
assert_close(v1_out_tensor, ref_out_tensor)
|
||||
|
||||
|
||||
def check_sgmv_expand(batches: int, num_loras: int, rank: int,
|
||||
hidden_size: int, nslices: int, dtype: torch.dtype,
|
||||
device: str, seq_length: int, add_inputs: bool):
|
||||
def check_expand_kernels(batches: int, num_loras: int, rank: int,
|
||||
hidden_size: int, nslices: int, dtype: torch.dtype,
|
||||
device: str, seq_length: int, add_inputs: bool):
|
||||
"""
|
||||
Compare outputs of vllm.sgmv_expand kernel against a reference
|
||||
implementation.
|
||||
Compare outputs of vllm.sgmv_expand and vllm.v1_expand kernels against a
|
||||
reference implementation.
|
||||
"""
|
||||
data: PunicaTensors = generate_data_for_nslices(
|
||||
batches,
|
||||
@ -164,36 +185,54 @@ def check_sgmv_expand(batches: int, num_loras: int, rank: int,
|
||||
|
||||
max_seq_length, token_nums = data.meta()
|
||||
|
||||
# Setup metadata information for SGMV and reference kernels
|
||||
sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor,
|
||||
data.prompt_lora_mapping, batches, max_seq_length,
|
||||
token_nums)
|
||||
|
||||
# Setup metadata information for the V1 kernel.
|
||||
v1_meta = V1KernelMeta.make(max_loras=num_loras,
|
||||
max_num_tokens=token_nums,
|
||||
device='cuda')
|
||||
v1_meta.prepare_tensors(data.token_lora_mapping)
|
||||
|
||||
# Setup output tensors
|
||||
ref_out_tensor = data.ref_out_tensor
|
||||
sgmv_out_tensor = data.our_out_tensor
|
||||
v1_out_tensor = data.our_out_tensor.clone()
|
||||
|
||||
with _dict_lock:
|
||||
# SGMV expand kernel
|
||||
_LORA_B_PTR_DICT.clear()
|
||||
torch.ops.vllm.sgmv_expand(
|
||||
data.inputs_tensor,
|
||||
data.lora_weights,
|
||||
data.our_out_tensor,
|
||||
data.b_seq_start_loc,
|
||||
data.seq_len_tensor,
|
||||
data.prompt_lora_mapping,
|
||||
batches,
|
||||
max_seq_length,
|
||||
token_nums,
|
||||
sgmv_out_tensor,
|
||||
*sgmv_meta_args,
|
||||
offset_start=0,
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
|
||||
# V1 expand kernel
|
||||
_LORA_B_PTR_DICT.clear()
|
||||
torch.ops.vllm.v1_expand(data.inputs_tensor,
|
||||
data.lora_weights,
|
||||
v1_out_tensor,
|
||||
*v1_meta.meta_args(token_nums=token_nums),
|
||||
offset_start=0,
|
||||
add_inputs=add_inputs)
|
||||
|
||||
# Reference
|
||||
sgmv_expand_for_nslices(nslices,
|
||||
hidden_size,
|
||||
data.inputs_tensor,
|
||||
data.lora_weights,
|
||||
data.ref_out_tensor,
|
||||
data.b_seq_start_loc,
|
||||
data.seq_len_tensor,
|
||||
data.prompt_lora_mapping,
|
||||
batches,
|
||||
max_seq_length,
|
||||
token_nums,
|
||||
ref_out_tensor,
|
||||
*sgmv_meta_args,
|
||||
add_inputs=add_inputs)
|
||||
|
||||
assert_close(data.our_out_tensor, data.ref_out_tensor)
|
||||
assert_close(sgmv_out_tensor, ref_out_tensor)
|
||||
assert_close(v1_out_tensor, ref_out_tensor)
|
||||
|
||||
|
||||
def check_bgmv_shrink(batches: int, num_loras: int, rank: int,
|
||||
@ -439,7 +478,7 @@ SEED = [0]
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
||||
def test_punica_sgmv(
|
||||
def test_kernels(
|
||||
batches: int,
|
||||
num_loras: int,
|
||||
rank: int,
|
||||
@ -450,29 +489,32 @@ def test_punica_sgmv(
|
||||
seed: int,
|
||||
op_type: str,
|
||||
):
|
||||
"""
|
||||
Tests SGMV and V1 kernels.
|
||||
"""
|
||||
torch.set_default_device(device)
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
if op_type == "shrink":
|
||||
check_sgmv_shrink(batches=batches,
|
||||
num_loras=num_loras,
|
||||
rank=rank,
|
||||
hidden_size=hidden_size,
|
||||
nslices=nslices,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
seq_length=128,
|
||||
scaling=0.5)
|
||||
check_shrink_kernels(batches=batches,
|
||||
num_loras=num_loras,
|
||||
rank=rank,
|
||||
hidden_size=hidden_size,
|
||||
nslices=nslices,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
seq_length=128,
|
||||
scaling=0.5)
|
||||
else:
|
||||
check_sgmv_expand(batches=batches,
|
||||
num_loras=num_loras,
|
||||
rank=rank,
|
||||
hidden_size=hidden_size,
|
||||
nslices=nslices,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
seq_length=128,
|
||||
add_inputs=True)
|
||||
check_expand_kernels(batches=batches,
|
||||
num_loras=num_loras,
|
||||
rank=rank,
|
||||
hidden_size=hidden_size,
|
||||
nslices=nslices,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
seq_length=128,
|
||||
add_inputs=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", hs_test_params['batches'])
|
||||
@ -484,7 +526,7 @@ def test_punica_sgmv(
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
||||
def test_punica_sgmv_hidden_size(
|
||||
def test_kernels_hidden_size(
|
||||
batches: int,
|
||||
num_loras: int,
|
||||
rank: int,
|
||||
@ -495,29 +537,32 @@ def test_punica_sgmv_hidden_size(
|
||||
seed: int,
|
||||
op_type: str,
|
||||
):
|
||||
"""
|
||||
Tests SGMV and V1 kernels.
|
||||
"""
|
||||
torch.set_default_device(device)
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
if op_type == "shrink":
|
||||
check_sgmv_shrink(batches=batches,
|
||||
num_loras=num_loras,
|
||||
rank=rank,
|
||||
hidden_size=hidden_size,
|
||||
nslices=nslices,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
seq_length=128,
|
||||
scaling=0.5)
|
||||
check_shrink_kernels(batches=batches,
|
||||
num_loras=num_loras,
|
||||
rank=rank,
|
||||
hidden_size=hidden_size,
|
||||
nslices=nslices,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
seq_length=128,
|
||||
scaling=0.5)
|
||||
else:
|
||||
check_sgmv_expand(batches=batches,
|
||||
num_loras=num_loras,
|
||||
rank=rank,
|
||||
hidden_size=hidden_size,
|
||||
nslices=nslices,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
seq_length=128,
|
||||
add_inputs=True)
|
||||
check_expand_kernels(batches=batches,
|
||||
num_loras=num_loras,
|
||||
rank=rank,
|
||||
hidden_size=hidden_size,
|
||||
nslices=nslices,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
seq_length=128,
|
||||
add_inputs=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", test_params['batches'])
|
||||
|
@ -326,9 +326,11 @@ class LoRAModelManager(AdapterModelManager):
self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
self.vocab_size = vocab_size
self.long_lora_context: Optional[LongContextLoRAContext] = None
self.punica_wrapper = get_punica_wrapper(max_num_batched_tokens,
max_batches=self.max_num_seqs,
device=self.device)
self.punica_wrapper = get_punica_wrapper(
max_num_batched_tokens,
max_batches=self.max_num_seqs,
device=self.device,
max_loras=self.lora_config.max_loras)
# Scaling factor -> offset to the sin_cos_cache to it.
# Used for long context lora.
self.scaling_factor_to_offset: Dict[float, int] = {}

@ -54,7 +54,7 @@ _LORA_A_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {}
_LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {}


def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: str):
def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: torch.device):
"""
`_LORA_A_PTR_DICT` collects the required information during `profile_run`.
After this, it remains constant and subsequent usage is through LUT.
@ -100,7 +100,7 @@ def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: str):


def _get_lora_b_ptr(lora_weights: List[torch.Tensor], offset_start: int,
device: str):
device: torch.device):
"""
`_LORA_B_PTR_DICT` collects the required information during `profile_run`.
After this, it remains constant and subsequent usage is through LUT.

vllm/lora/ops/triton_ops/v1/__init__.py (new file, 11 lines)
@ -0,0 +1,11 @@
# SPDX-License-Identifier: Apache-2.0

from vllm.lora.ops.triton_ops.v1.v1_expand import v1_expand
from vllm.lora.ops.triton_ops.v1.v1_kernel_metadata import V1KernelMeta
from vllm.lora.ops.triton_ops.v1.v1_shrink import v1_shrink

__all__ = [
"v1_expand",
"v1_shrink",
"V1KernelMeta",
]
vllm/lora/ops/triton_ops/v1/v1_expand.py (new file, 282 lines)
@ -0,0 +1,282 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
|
||||
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _v1_expand_kernel(
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
token_indices_sorted_by_lora_ids,
|
||||
num_tokens_per_lora,
|
||||
lora_token_start_loc,
|
||||
lora_ids,
|
||||
slice_start_loc,
|
||||
input_d0_stride,
|
||||
input_d1_stride,
|
||||
input_d2_stride, # 1
|
||||
ls_d0_ptr,
|
||||
ls_d1_ptr,
|
||||
ls_d2_ptr, # 1
|
||||
output_d0_stride,
|
||||
output_d1_stride, # 1
|
||||
output_hs_ptr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
ADD_INPUTS: tl.constexpr,
|
||||
CAST_TYPE: tl.constexpr,
|
||||
SLICE_NUM: tl.constexpr,
|
||||
SAME_STRIDE: tl.constexpr):
|
||||
|
||||
cta_n_num = tl.cdiv(N, BLOCK_N)
|
||||
cta_m_num = tl.cdiv(M, BLOCK_M)
|
||||
|
||||
pid_mn = tl.program_id(axis=0)
|
||||
pid_m = pid_mn % cta_m_num
|
||||
pid_n = (pid_mn // cta_m_num) % cta_n_num
|
||||
|
||||
slice_id = tl.program_id(axis=1)
|
||||
lora_idx = tl.program_id(axis=2)
|
||||
|
||||
lora_id = tl.load(lora_ids + lora_idx)
|
||||
if lora_id == -1:
|
||||
# Early exit for the no-lora case.
|
||||
return
|
||||
|
||||
lora_m_size = tl.load(num_tokens_per_lora + lora_idx)
|
||||
|
||||
cta_m_offset = pid_m * BLOCK_M
|
||||
if cta_m_offset >= lora_m_size:
|
||||
# Early exit CTA.
|
||||
return
|
||||
|
||||
# When the output dimensions of each slice are the same, curr_N = N.
# Otherwise curr_N = tl.load(output_hs_ptr + slice_id); this situation
# arises in GQA's qkv linear.
|
||||
curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
|
||||
if pid_n * BLOCK_N >= curr_N:
|
||||
# Early exit CTA.
|
||||
return
|
||||
|
||||
# num rows this CTA should process.
|
||||
cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)
|
||||
|
||||
# Identify all rows that this CTA should process.
|
||||
lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
|
||||
cta_lora_seq_indices = (token_indices_sorted_by_lora_ids +
|
||||
lora_m_indices_start + cta_m_offset)
|
||||
|
||||
# Load all relevant row indices.
|
||||
offset_m = tl.arange(0, BLOCK_M) % cta_m_len
|
||||
ram = tl.load(cta_lora_seq_indices + offset_m)
|
||||
|
||||
do_expand_kernel(
|
||||
pid_n,
|
||||
lora_id,
|
||||
slice_id,
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
curr_N,
|
||||
K,
|
||||
cta_m_len,
|
||||
ram, # array identifying the rows of Input ptr to operate on
|
||||
slice_start_loc,
|
||||
# input ptr strides
|
||||
input_d0_stride,
|
||||
input_d1_stride,
|
||||
input_d2_stride,
|
||||
# lora ptr strides
|
||||
ls_d0_ptr,
|
||||
ls_d1_ptr,
|
||||
ls_d2_ptr,
|
||||
# out ptr strides
|
||||
output_d0_stride,
|
||||
output_d1_stride,
|
||||
# constants
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
SAME_STRIDE,
|
||||
SLICE_NUM,
|
||||
EVEN_K,
|
||||
CAST_TYPE,
|
||||
ADD_INPUTS)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _v1_expand(
|
||||
inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
|
||||
lora_b_weights: List[
|
||||
torch.Tensor], # shape [num_lora, hidden_size, lora_rank]
|
||||
output_tensor: torch.
|
||||
Tensor, # shape [num_tokens, hidden_size * num_slices]
|
||||
token_lora_mapping: torch.Tensor, # shape [num_tokens]
|
||||
token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens]
|
||||
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
|
||||
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
|
||||
lora_ids: torch.Tensor, # shape [max-loras + 1]
|
||||
offset_start: int = 0,
|
||||
add_inputs: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): input tensor
|
||||
lora_b_weights (List[torch.Tensor]): lora'b weight
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
token_lora_mapping (torch.Tensor): A tensor mapping each input token
|
||||
to the lora-id related to that token. A value of -1 indicates that
|
||||
LoRA doesn't apply to that token.
|
||||
token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from
|
||||
the A matrix grouped by LoRA IDs.
|
||||
num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number
|
||||
of tokens that are to be processed by LoRA ID lora_ids[i]
|
||||
lora_token_start_loc (torch.Tensor): A cumulative sum of
|
||||
num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
|
||||
lora_token_start_loc[i], along with num_tokens_per_lora[i]
|
||||
identifies the region in token_indices_sorted_by_lora_ids that
|
||||
LoRA lora_ids[i] should process.
|
||||
lora_ids (torch.Tensor): LoRA ids to process.
|
||||
offset_start (int, optional): Offset start for output_tensor.
|
||||
Defaults to 0.
|
||||
add_inputs (bool, optional): Whether to add the input tensor to the
|
||||
output tensor. Defaults to False.
|
||||
"""
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
||||
for weight in lora_b_weights:
|
||||
assert weight.dtype in [torch.float16, torch.bfloat16]
|
||||
|
||||
assert inputs.size(0) == len(lora_b_weights)
|
||||
assert output_tensor.is_contiguous()
|
||||
|
||||
# metadata sanity check.
|
||||
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
|
||||
0)
|
||||
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
|
||||
assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1
|
||||
|
||||
(slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor,
|
||||
lora_strides_d1_tensor, lora_strides_d2_tensor, hidden_sizes_tensor,
|
||||
same_stride, MAX_N) = _get_lora_b_ptr(lora_b_weights, offset_start,
|
||||
inputs.device)
|
||||
|
||||
K = lora_b_weights[0].shape[-1] # K= rank
|
||||
M = inputs.size(1)
|
||||
ADD_INPUTS = add_inputs
|
||||
MAX_LORAS = lora_ids.size(0)
|
||||
CAST_TYPE = False
|
||||
NUM_SLICES = len(lora_b_weights)
|
||||
|
||||
# Triton kernel configs.
|
||||
BLOCK_M = 64
|
||||
BLOCK_N = 128
|
||||
BLOCK_K = 16
|
||||
NUM_WARPS = 4
|
||||
NUM_CTAS = 1
|
||||
NUM_STAGES = 2
|
||||
MAX_NREG = None
|
||||
|
||||
EVEN_K = K % BLOCK_K == 0 # type: ignore
|
||||
|
||||
if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]:
|
||||
CAST_TYPE = True
|
||||
|
||||
# TODO (varun): This grid formulation maximizes parallelization at the
|
||||
# cost of wasteful thread block launch when only a few input tokens require
|
||||
# LoRA. This might not be the best in all cases.
|
||||
grid = (
|
||||
triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
|
||||
NUM_SLICES,
|
||||
# Each LoRA receives its own set of thread blocks for output
|
||||
# computation. If some LoRA doesn't have any tokens to process, its
|
||||
# thread blocks simply exit.
|
||||
MAX_LORAS,
|
||||
)
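# Illustrative sizing note (editor's example, not part of this commit):
# with the defaults above (BLOCK_M=64, BLOCK_N=128) and, say, M=512 tokens,
# MAX_N=4096, 3 slices and 8 max loras, the launch grid is
# (cdiv(512, 64) * cdiv(4096, 128), 3, 8) = (8 * 32, 3, 8) = (256, 3, 8),
# i.e. 6144 thread blocks, even if only a few tokens actually use LoRA.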
|
||||
|
||||
_v1_expand_kernel[grid](
|
||||
inputs,
|
||||
lora_ptr_tensor,
|
||||
output_tensor,
|
||||
M,
|
||||
MAX_N,
|
||||
K,
|
||||
token_indices_sorted_by_lora_ids,
|
||||
num_tokens_per_lora,
|
||||
lora_token_start_loc,
|
||||
lora_ids,
|
||||
slice_start_tensor,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
inputs.stride(2),
|
||||
lora_strides_d0_tensor,
|
||||
lora_strides_d1_tensor,
|
||||
lora_strides_d2_tensor,
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
hidden_sizes_tensor,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
EVEN_K,
|
||||
ADD_INPUTS,
|
||||
CAST_TYPE,
|
||||
NUM_SLICES,
|
||||
same_stride,
|
||||
num_warps=NUM_WARPS,
|
||||
num_ctas=NUM_CTAS,
|
||||
num_stages=NUM_STAGES,
|
||||
maxnreg=MAX_NREG,
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def _v1_expand_fake(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: List[torch.Tensor],
|
||||
output_tensor: torch.Tensor,
|
||||
token_lora_mapping: torch.Tensor,
|
||||
token_indices_sorted_by_lora_ids: torch.Tensor,
|
||||
num_tokens_per_lora: torch.Tensor,
|
||||
lora_token_start_loc: torch.Tensor,
|
||||
lora_ids: torch.Tensor,
|
||||
offset_start: int = 0,
|
||||
add_inputs: bool = False,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="v1_expand",
|
||||
op_func=_v1_expand,
|
||||
mutates_args=["output_tensor"],
|
||||
fake_impl=_v1_expand_fake,
|
||||
)
|
||||
v1_expand = torch.ops.vllm.v1_expand
|
||||
|
||||
except AttributeError:
|
||||
v1_expand = _v1_expand
|
vllm/lora/ops/triton_ops/v1/v1_kernel_metadata.py (new file, 117 lines)
@ -0,0 +1,117 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
V1 LoRA kernels metadata preparation utilities.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class V1KernelMeta:
|
||||
token_lora_mapping: torch.Tensor
|
||||
token_indices_sorted_by_lora_ids: torch.Tensor
|
||||
active_lora_ids: torch.Tensor
|
||||
num_tokens_per_lora: torch.Tensor
|
||||
lora_token_start_loc: torch.Tensor
|
||||
|
||||
@staticmethod
|
||||
def make(max_loras: int, max_num_tokens: int,
|
||||
device: Union[torch.device, str]) -> "V1KernelMeta":
|
||||
|
||||
token_lora_mapping = torch.empty(max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
token_indices_sorted_by_lora_ids = torch.empty(max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
# +1 because "no-lora" is also a possibility
|
||||
# example: let max_loras be 3, active_lora_ids of [-1, 0, 2, 1]
|
||||
# is a possibility.
|
||||
active_lora_ids = torch.empty(max_loras + 1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
# using running example, [3, 10, 5, 2] is a possibility.
|
||||
num_tokens_per_lora = torch.zeros(max_loras + 1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
# +2 for this because, the first index is always 0.
|
||||
# using running example, lora_token_start_loc
|
||||
# is [0, 3, 13, 18, 20].
|
||||
lora_token_start_loc = torch.zeros(max_loras + 2,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
return V1KernelMeta(
|
||||
token_lora_mapping=token_lora_mapping,
|
||||
token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
|
||||
active_lora_ids=active_lora_ids,
|
||||
num_tokens_per_lora=num_tokens_per_lora,
|
||||
lora_token_start_loc=lora_token_start_loc)
|
||||
|
||||
def _reset(self):
|
||||
self.active_lora_ids.fill_(-1)
|
||||
self.num_tokens_per_lora.fill_(0)
|
||||
self.lora_token_start_loc.fill_(0)
|
||||
|
||||
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
|
||||
"""
|
||||
Prepare kernel metadata tensors for the current forward pass.
|
||||
|
||||
Args:
|
||||
token_lora_mapping (torch.Tensor): Tensor containing lora indices
|
||||
for each input token.
|
||||
"""
|
||||
|
||||
self._reset()
|
||||
|
||||
num_tokens = token_lora_mapping.size(0)
|
||||
|
||||
# copy token lora mapping
|
||||
self.token_lora_mapping[:num_tokens].copy_(token_lora_mapping,
|
||||
non_blocking=True)
|
||||
|
||||
# token_indices_sorted_by_lora_ids
|
||||
_, token_indices_sorted_by_lora_ids = torch.sort(token_lora_mapping,
|
||||
stable=True)
|
||||
# start gpu transfer
|
||||
self.token_indices_sorted_by_lora_ids[:num_tokens].copy_(
|
||||
token_indices_sorted_by_lora_ids, non_blocking=True)
|
||||
|
||||
# active_lora_ids, num_tokens_per_lora
|
||||
lora_ids, num_tokens_per_lora = torch.unique(token_lora_mapping,
|
||||
sorted=False,
|
||||
return_counts=True)
|
||||
self.active_lora_ids[:lora_ids.size(0)].copy_(lora_ids,
|
||||
non_blocking=True)
|
||||
self.num_tokens_per_lora[:num_tokens_per_lora.size(0)].copy_(
|
||||
num_tokens_per_lora, non_blocking=True)
|
||||
|
||||
# lora_token_start_loc
|
||||
lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
|
||||
self.lora_token_start_loc[1:1 + lora_token_start_loc.size(0)].copy_(
|
||||
lora_token_start_loc, non_blocking=True)
|
||||
|
||||
def meta_args(
|
||||
self, token_nums: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
|
||||
torch.Tensor]:
|
||||
"""
|
||||
This function returns the kernel metadata required for the current
|
||||
forward pass execution of the kernel. The function returns all the
|
||||
metadata required by the kernel, in order, as a tuple, so it can be
|
||||
unpacked directly during the v1_shrink/v1_expand function call.
|
||||
|
||||
Args:
|
||||
token_nums (int): Number of input tokens in the current forward
|
||||
pass.
|
||||
"""
|
||||
return (self.token_lora_mapping[:token_nums],
|
||||
self.token_indices_sorted_by_lora_ids[:token_nums],
|
||||
self.num_tokens_per_lora, self.lora_token_start_loc,
|
||||
self.active_lora_ids)
|
vllm/lora/ops/triton_ops/v1/v1_shrink.py (new file, 236 lines)
@ -0,0 +1,236 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
|
||||
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _v1_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K,
|
||||
token_indices_sorted_by_lora_ids, num_tokens_per_lora,
|
||||
lora_token_start_loc, lora_ids, scaling, input_d0_stride,
|
||||
input_d1_stride, lora_d0_stride, lora_d1_stride,
|
||||
lora_d2_stride, output_d0_stride, output_d1_stride,
|
||||
output_d2_stride, BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr,
|
||||
SLICE_NUM: tl.constexpr):
|
||||
|
||||
cta_n_num = tl.cdiv(N, BLOCK_N)
|
||||
cta_m_num = tl.cdiv(M, BLOCK_M)
|
||||
|
||||
pid_sk_m_n = tl.program_id(axis=0)
|
||||
pid_sk = pid_sk_m_n % SPLIT_K
|
||||
pid_m = (pid_sk_m_n // SPLIT_K) % cta_m_num
|
||||
pid_n = pid_sk_m_n // (SPLIT_K * cta_m_num) % cta_n_num
|
||||
|
||||
slice_id = tl.program_id(axis=1)
|
||||
lora_idx = tl.program_id(axis=2)
|
||||
|
||||
lora_id = tl.load(lora_ids + lora_idx)
|
||||
if lora_id == -1:
|
||||
# Early exit for the no-lora case.
|
||||
return
|
||||
|
||||
lora_m_size = tl.load(num_tokens_per_lora + lora_idx)
|
||||
|
||||
cta_m_offset = pid_m * BLOCK_M
|
||||
if cta_m_offset >= lora_m_size:
|
||||
# Early exit CTA.
|
||||
return
|
||||
|
||||
# num rows this CTA should process.
|
||||
cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)
|
||||
|
||||
# Identify all rows that this CTA should process.
|
||||
lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
|
||||
cta_lora_seq_indices = (token_indices_sorted_by_lora_ids +
|
||||
lora_m_indices_start + cta_m_offset)
|
||||
|
||||
# Load all relevant row indices.
|
||||
offset_m = tl.arange(0, BLOCK_M) % cta_m_len
|
||||
ram = tl.load(cta_lora_seq_indices + offset_m)
|
||||
|
||||
do_shrink_kernel(
|
||||
pid_n,
|
||||
pid_sk,
|
||||
slice_id,
|
||||
lora_id,
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
cta_m_len,
|
||||
ram, # array identifying the rows of Input ptr to operate on
|
||||
# input strides
|
||||
input_d0_stride,
|
||||
input_d1_stride,
|
||||
# lora strides
|
||||
lora_d0_stride,
|
||||
lora_d1_stride,
|
||||
lora_d2_stride,
|
||||
# output strides
|
||||
output_d0_stride,
|
||||
output_d1_stride,
|
||||
output_d2_stride,
|
||||
scaling,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
EVEN_K,
|
||||
SPLIT_K,
|
||||
SLICE_NUM)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _v1_shrink(
|
||||
inputs: torch.Tensor, # shape [num_tokens, hidden_size]
|
||||
lora_a_weights: List[
|
||||
torch.Tensor], # shape [num_loras, lora_rank, hidden_size]
|
||||
output_tensor: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
|
||||
token_lora_mapping: torch.Tensor, # shape [num_tokens]
|
||||
token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens]
|
||||
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
|
||||
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
|
||||
lora_ids: torch.Tensor, # shape [max-loras + 1]
|
||||
scaling: float,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): Input tensor
|
||||
lora_a_weights (List[torch.Tensor]): LoRA weights
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
token_lora_mapping (torch.Tensor): A tensor mapping each input token
|
||||
to the lora-id related to that token. A value of -1 indicates that
|
||||
LoRA doesn't apply to that token.
|
||||
token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from
|
||||
the A matrix grouped by LoRA IDs.
|
||||
num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number
|
||||
of tokens that are to be processed by LoRA ID lora_ids[i]
|
||||
lora_token_start_loc (torch.Tensor): A cumulative sum of
|
||||
num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
|
||||
lora_token_start_loc[i], along with num_tokens_per_lora[i]
|
||||
identifies the region in token_indices_sorted_by_lora_ids that
|
||||
LoRA lora_ids[i] should process.
|
||||
lora_ids (torch.Tensor): LoRA ids to process.
|
||||
scaling (float): Scaling factor.
|
||||
"""
|
||||
assert inputs.dtype == lora_a_weights[0].dtype
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16]
|
||||
for weight in lora_a_weights:
|
||||
assert weight.dtype in [torch.float16, torch.bfloat16]
|
||||
|
||||
assert inputs.size(1) == lora_a_weights[0].size(-1)
|
||||
assert inputs.is_contiguous()
|
||||
assert output_tensor.is_contiguous()
|
||||
|
||||
# metadata sanity check
|
||||
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
|
||||
0)
|
||||
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
|
||||
assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1
|
||||
|
||||
(lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
|
||||
lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device)
|
||||
N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank
|
||||
M = inputs.size(0)
|
||||
NUM_SLICES = len(lora_a_weights)
|
||||
MAX_LORAS = lora_ids.size(0)
|
||||
|
||||
# Triton kernel configs
|
||||
BLOCK_M = 32
|
||||
BLOCK_N = 16
|
||||
BLOCK_K = 256 if M < 128 else 32
|
||||
SPLIT_K = 64 if M < 128 else 8
|
||||
NUM_WARPS = 4
|
||||
NUM_CTAS = 1
|
||||
NUM_STAGES = 2
|
||||
MAX_NREG = None
|
||||
|
||||
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 # type: ignore
|
||||
|
||||
# TODO (varun): This grid formulation maximizes parallelization at the
|
||||
# cost of wasteful thread block launch when only a few of the input tokens
|
||||
# require LoRA. This might not be the best in all cases.
|
||||
grid = (
|
||||
SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
|
||||
NUM_SLICES,
|
||||
# Each LoRA receives its own set of thread blocks for output
|
||||
# computation. If some LoRA doesn't have any tokens to process, its
|
||||
# thread blocks exit early.
|
||||
MAX_LORAS,
|
||||
)
|
||||
|
||||
_v1_shrink_kernel[grid](
|
||||
inputs,
|
||||
lora_ptr_tensor,
|
||||
output_tensor,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
token_indices_sorted_by_lora_ids,
|
||||
num_tokens_per_lora,
|
||||
lora_token_start_loc,
|
||||
lora_ids,
|
||||
scaling,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
lora_strides_d0,
|
||||
lora_strides_d1,
|
||||
lora_strides_d2,
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
output_tensor.stride(2),
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
EVEN_K,
|
||||
SPLIT_K,
|
||||
NUM_SLICES,
|
||||
num_warps=NUM_WARPS,
|
||||
num_ctas=NUM_CTAS,
|
||||
num_stages=NUM_STAGES,
|
||||
maxnreg=MAX_NREG,
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def _v1_shrink_fake(
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: List[torch.Tensor],
|
||||
output_tensor: torch.Tensor,
|
||||
token_lora_mapping: torch.Tensor,
|
||||
token_indices_sorted_by_lora_ids: torch.Tensor,
|
||||
num_tokens_per_lora: torch.Tensor,
|
||||
lora_token_start_loc: torch.Tensor,
|
||||
lora_ids: torch.Tensor,
|
||||
scaling: float,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="v1_shrink",
|
||||
op_func=_v1_shrink,
|
||||
mutates_args=["output_tensor"],
|
||||
fake_impl=_v1_shrink_fake,
|
||||
)
|
||||
v1_shrink = torch.ops.vllm.v1_shrink
|
||||
|
||||
except AttributeError:
|
||||
v1_shrink = _v1_shrink
|
@ -6,24 +6,83 @@ Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple, Union, final
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as env
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.lora.ops.triton_ops import bgmv_expand
|
||||
from vllm.lora.ops.triton_ops import bgmv_expand_slice
|
||||
from vllm.lora.ops.triton_ops import bgmv_shrink
|
||||
from vllm.lora.ops.triton_ops import sgmv_expand
|
||||
from vllm.lora.ops.triton_ops import sgmv_shrink
|
||||
if env.VLLM_USE_V1:
|
||||
from vllm.lora.ops.triton_ops.v1 import (V1KernelMeta, v1_expand,
|
||||
v1_shrink)
|
||||
else:
|
||||
from vllm.lora.ops.triton_ops import bgmv_expand
|
||||
from vllm.lora.ops.triton_ops import bgmv_expand_slice
|
||||
from vllm.lora.ops.triton_ops import bgmv_shrink
|
||||
from vllm.lora.ops.triton_ops import sgmv_expand
|
||||
from vllm.lora.ops.triton_ops import sgmv_shrink
|
||||
|
||||
from .punica_base import PunicaWrapperBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# avoid circular import
|
||||
from vllm.lora.models import LongContextLoRAContext
|
||||
|
||||
|
||||
class V1KernelMixin:
|
||||
|
||||
def _v1_make_metadata(self, max_loras: int, max_num_batched_tokens: int,
|
||||
max_batches: int, device: Union[torch.device, str]):
|
||||
self.token_mapping_v1_meta = V1KernelMeta.make(max_loras,
|
||||
max_num_batched_tokens,
|
||||
device=device)
|
||||
self.prompt_mapping_v1_meta = V1KernelMeta.make(max_loras,
|
||||
max_batches,
|
||||
device=device)
|
||||
|
||||
def _v1_prepare_metadata_tensors(self, token_lora_indices: torch.Tensor,
|
||||
sampler_indices: torch.Tensor):
|
||||
self.token_mapping_v1_meta.prepare_tensors(token_lora_indices)
|
||||
self.prompt_mapping_v1_meta.prepare_tensors(sampler_indices)
|
||||
|
||||
def _v1_apply_shrink(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: Tuple[torch.Tensor, ...],
|
||||
scale: float,
|
||||
):
|
||||
v1_shrink(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
*self.token_mapping_v1_meta.meta_args(x.size(0)),
|
||||
scale,
|
||||
)
|
||||
|
||||
def _v1_apply_expand(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: Tuple[torch.Tensor, ...],
|
||||
offset_start: int,
|
||||
add_inputs: bool,
|
||||
):
|
||||
v1_expand(
|
||||
x,
|
||||
w_t_all,
|
||||
y,
|
||||
*self.token_mapping_v1_meta.meta_args(x.size(0)),
|
||||
offset_start=offset_start,
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
class PunicaWrapperGPU(PunicaWrapperBase):
|
||||
class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
|
||||
"""
|
||||
PunicaWrapperGPU is designed to manage and provide metadata for the punica
|
||||
kernel. The main function is to maintain the state information for
|
||||
@ -35,6 +94,36 @@ class PunicaWrapperGPU(PunicaWrapperBase):
        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
                                   device)

        self.max_loras = kwargs['max_loras']

        if env.VLLM_USE_V1:
            self._v1_make_metadata(self.max_loras, max_num_batched_tokens,
                                   max_batches, device)

    def update_metadata(
            self,
            mapping: LoRAMapping,
            lora_index_to_id: List[Optional[int]],
            max_loras: int,
            vocab_size: int,
            extra_vocab_size: int,
            long_lora_context: Optional["LongContextLoRAContext"] = None,
            **kwargs):

        if env.VLLM_USE_V1:
            self.is_prefill = mapping.is_prefill
            self._update_base_metadata(mapping, lora_index_to_id, max_loras,
                                       vocab_size, extra_vocab_size,
                                       long_lora_context)
            self._v1_prepare_metadata_tensors(self.token_lora_indices,
                                              self.sampler_indices)
        else:
            # Forward to base class update_metadata
            PunicaWrapperBase.update_metadata(self, mapping, lora_index_to_id,
                                              max_loras, vocab_size,
                                              extra_vocab_size,
                                              long_lora_context, **kwargs)

    def _apply_shrink_prefill(
        self,
        y: torch.Tensor,
@ -66,7 +155,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        w_t_all: Tuple[torch.Tensor, ...],
        offset_start: int,
        add_inputs: bool,
    ):
@ -118,14 +207,21 @@ class PunicaWrapperGPU(PunicaWrapperBase):

        x = x.view(-1, x.shape[-1])

        if self.is_prefill:
            # NOTE fused kernel
            self._apply_shrink_prefill(y, x, lora_a_stacked, scale)
        if env.VLLM_USE_V1:
            self._v1_apply_shrink(y, x, lora_a_stacked, scale)  # type: ignore
        else:
            # TODO fuse these kernels
            for slice_idx in range(len(lora_a_stacked)):
                self._apply_shrink_decode(y[slice_idx], x,
                                          lora_a_stacked[slice_idx], scale)
            if self.is_prefill:
                # NOTE fused kernel
                self._apply_shrink_prefill(
                    y,  # type: ignore
                    x,
                    lora_a_stacked,
                    scale)
            else:
                # TODO fuse these kernels
                for slice_idx in range(len(lora_a_stacked)):
                    self._apply_shrink_decode(y[slice_idx], x,
                                              lora_a_stacked[slice_idx], scale)
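Whichever branch runs, the math of the shrink step is the per-token LoRA "A" projection. Below is a dense, loop-based reference for a single slice, assuming a hypothetical `lora_a` laid out as `(num_loras, rank, hidden)` and the usual convention that tokens without a LoRA carry index -1; it is a readability sketch, not the kernel:

def shrink_reference(x, lora_a, token_lora_idx, scale):
    # x: (num_tokens, hidden), lora_a: (num_loras, rank, hidden)
    num_tokens = x.size(0)
    out = x.new_zeros((num_tokens, lora_a.size(1)), dtype=torch.float32)
    for t in range(num_tokens):
        idx = int(token_lora_idx[t])
        if idx >= 0:  # skip tokens that have no LoRA attached
            out[t] = scale * (x[t].float() @ lora_a[idx].float().T)
    return out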

    def add_expand(self,
                   y: torch.Tensor,
@ -160,25 +256,38 @@ class PunicaWrapperGPU(PunicaWrapperBase):
        if lora_bias_stacked is not None:
            self._apply_bias(self.token_lora_indices, y, output_slices,
                             lora_bias_stacked)
        if self.is_prefill:
            # NOTE fused kernel
            self._apply_expand_prefill(y,
                                       x,
                                       lora_b_stacked,
                                       offset_start,
                                       add_inputs=True)

        if env.VLLM_USE_V1:
            # TODO (varun): Profile with add_inputs = False. i.e. move the
            # addition out of the kernel
            self._v1_apply_expand(
                y,
                x,  # type: ignore
                lora_b_stacked,
                offset_start,
                add_inputs=True)
        else:
            # TODO fuse these kernels
            for slice_idx in range(len(lora_b_stacked)):
                self._apply_expand_decode(

            if self.is_prefill:
                # NOTE fused kernel
                self._apply_expand_prefill(
                    y,
                    x[slice_idx],
                    lora_b_stacked[slice_idx],
                    x,  # type: ignore
                    lora_b_stacked,
                    offset_start,
                    output_slices[slice_idx],
                    add_inputs=add_inputs,
                )
                offset_start += output_slices[slice_idx]
                    add_inputs=True)
            else:
                # TODO fuse these kernels
                for slice_idx in range(len(lora_b_stacked)):
                    self._apply_expand_decode(
                        y,
                        x[slice_idx],
                        lora_b_stacked[slice_idx],
                        offset_start,
                        output_slices[slice_idx],
                        add_inputs=add_inputs,
                    )
                    offset_start += output_slices[slice_idx]
        y = y.view_as(y_org)
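The expand step is the mirror image: each slice's low-rank activations are projected back up with the LoRA "B" weights and written (or accumulated) into a column window of `y`, advancing the window by `output_slices[slice_idx]` per slice exactly as the decode loop above does explicitly. A dense reference for one slice, with a hypothetical `lora_b` of shape `(num_loras, out_features, rank)`:

def expand_reference(y, buf, lora_b, token_lora_idx, offset, out_size, add_inputs):
    # buf: (num_tokens, rank); writes columns [offset, offset + out_size) of y
    for t in range(buf.size(0)):
        idx = int(token_lora_idx[t])
        if idx < 0:
            continue
        delta = (buf[t].float() @ lora_b[idx].float().T).to(y.dtype)
        if add_inputs:
            y[t, offset:offset + out_size] += delta
        else:
            y[t, offset:offset + out_size] = delta
    return y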

    def add_lora_embedding(self,
@ -200,18 +309,24 @@ class PunicaWrapperGPU(PunicaWrapperBase):
            add_inputs (bool): Default to True.
        """

        if self.is_prefill:
            sgmv_expand(
                x.unsqueeze(dim=0),
                [lora_b_stacked],
                y,
                *self.prefill_metadata,
                offset_start=0,
                add_inputs=add_inputs,
            )
        if env.VLLM_USE_V1:
            self._v1_apply_expand(y,
                                  x.unsqueeze(dim=0), (lora_b_stacked, ),
                                  offset_start=0,
                                  add_inputs=add_inputs)
        else:
            bgmv_expand(x, lora_b_stacked, y, self.token_lora_indices,
                        add_inputs)
            if self.is_prefill:
                sgmv_expand(
                    x.unsqueeze(dim=0),
                    (lora_b_stacked, ),
                    y,
                    *self.prefill_metadata,
                    offset_start=0,
                    add_inputs=add_inputs,
                )
            else:
                bgmv_expand(x, lora_b_stacked, y, self.token_lora_indices,
                            add_inputs)

    def add_lora_linear(self,
                        y: torch.Tensor,
@ -257,19 +372,25 @@ class PunicaWrapperGPU(PunicaWrapperBase):
            r = lora_b_stacked[0].size(-1)
            # We set the buffer to be float32 by default, refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros(
            buffer = torch.zeros(  # type: ignore
                (len(output_slices), x.size(0), r),
                dtype=torch.float32,
                device=x.device,
            )
        self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
        self.add_expand(y,
                        buffer,
                        lora_b_stacked,
                        None,
                        output_slices,
                        add_inputs=True,
                        **kwargs)
        self.add_shrink(
            buffer,  # type: ignore
            x,
            lora_a_stacked,
            scale,
            **kwargs)
        self.add_expand(
            y,
            buffer,  # type: ignore
            lora_b_stacked,
            None,
            output_slices,
            add_inputs=True,
            **kwargs)
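Functionally, both the old and the new call sequences compute the standard LoRA delta `y += scale * (x @ A^T) @ B^T`, staged through the float32 `buffer` above (the intermediate stays in fp32 per the Triton issue referenced in the comment). A worked dense equivalent for one token and one slice, with hypothetical dense weights `A: (rank, hidden)` and `B: (out_features, rank)`:

buffer_t = (x_token.float() @ A.float().T) * scale       # shrink: (hidden,) -> (rank,)
y_token += (buffer_t @ B.float().T).to(y_token.dtype)    # expand: (rank,) -> (out_features,)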

    def add_lora_logits(self,
                        y: torch.Tensor,
@ -305,11 +426,22 @@ class PunicaWrapperGPU(PunicaWrapperBase):
        buffer = torch.zeros((x.size(0), r),
                             dtype=torch.float32,
                             device=x.device)
        # LogitsProcessorWithLoRA always using bgmv.
        bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
        bgmv_expand(buffer,
                    lora_b_stacked,
                    y,
                    self.sampler_indices,
                    add_inputs=True)

        if env.VLLM_USE_V1:
            v1_shrink(x, [lora_a_stacked], buffer.unsqueeze(dim=0),
                      *self.prompt_mapping_v1_meta.meta_args(x.size(0)), scale)

            v1_expand(buffer.unsqueeze(dim=0), [lora_b_stacked],
                      y,
                      *self.prompt_mapping_v1_meta.meta_args(buffer.size(0)),
                      add_inputs=True)
        else:

            # V0 LogitsProcessorWithLoRA always using bgmv.
            bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
            bgmv_expand(buffer,
                        lora_b_stacked,
                        y,
                        self.sampler_indices,
                        add_inputs=True)
        y = y.view_as(y_org)

@ -62,9 +62,9 @@ class LoRAModelRunnerMixin:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")

        # We don't make any distinction between prefills and decodes in the
        # scheduler. To that effect, set is_prefill to True so we use the
        # sgmv punica kernels always.
        # Set is_prefill to True, so we always use the SGMV kernels.
        # For cuda platforms, we have specialized triton kernels, and
        # the cuda path ignores `is_prefill`.
        lora_mapping = LoRAMapping(token_lora_mapping,
                                   prompt_lora_mapping,
                                   is_prefill=True)