diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 3c4d6a6a..115b9253 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -23,6 +23,7 @@ from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT +from vllm.lora.ops.triton_ops.v1 import V1KernelMeta, v1_expand, v1_shrink from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) @@ -171,6 +172,8 @@ class OpType(Enum): SGMV_EXPAND = auto() BGMV_EXPAND = auto() BGMV_EXPAND_SLICE = auto() + V1_SHRINK = auto() + V1_EXPAND = auto() @staticmethod def from_str(s: str) -> "OpType": @@ -184,28 +187,43 @@ class OpType(Enum): return OpType.BGMV_EXPAND if s.lower() == "bgmv_expand_slice": return OpType.BGMV_EXPAND_SLICE + if s.lower() == "v1_shrink": + return OpType.V1_SHRINK + if s.lower() == "v1_expand": + return OpType.V1_EXPAND raise ValueError(f"Unrecognized str {s} to convert to OpType") def is_shrink_fn(self) -> bool: - return self in [OpType.SGMV_SHRINK, OpType.BGMV_SHRINK] + return self in [ + OpType.SGMV_SHRINK, OpType.BGMV_SHRINK, OpType.V1_SHRINK + ] def is_expand_fn(self) -> bool: - return self in [OpType.SGMV_EXPAND, OpType.BGMV_EXPAND] + return self in [ + OpType.SGMV_EXPAND, OpType.BGMV_EXPAND, OpType.V1_EXPAND + ] def is_prefill_op(self) -> bool: - return self in [OpType.SGMV_SHRINK, OpType.SGMV_EXPAND] + return self in [ + OpType.SGMV_SHRINK, OpType.SGMV_EXPAND, OpType.V1_SHRINK, + OpType.V1_EXPAND + ] def is_decode_op(self) -> bool: return self in [ - OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE + OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE, + OpType.V1_SHRINK, OpType.V1_EXPAND ] def is_expand_slice_fn(self) -> bool: return self in [OpType.BGMV_EXPAND_SLICE] def num_slices(self) -> list[int]: - if self in [OpType.SGMV_EXPAND, OpType.SGMV_SHRINK]: - # SGMV kernels supports slices + if self in [ + OpType.SGMV_EXPAND, OpType.SGMV_SHRINK, OpType.V1_SHRINK, + OpType.V1_EXPAND + ]: + # SGMV kernels and v1 kernels supports slices return [1, 2, 3] if self in [OpType.BGMV_SHRINK, OpType.BGMV_EXPAND]: return [1] @@ -250,11 +268,13 @@ class OpType(Enum): m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank) b_shape = (num_loras, n, k) # col-major - if self == OpType.SGMV_SHRINK: - # SGMV shrink supports num_slices inherently in the kernel + if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]: + # SGMV shrink and V1 shrink kernels support num_slices inherently + # in the kernel. 
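            # (For the shrink ops, mkn() yields k = hidden_size and n = lora_rank,
            # so the shapes returned below are: input [m, hidden_size], weights
            # [num_loras, lora_rank, hidden_size] and one stacked [m, lora_rank]
            # output per slice.)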
return ((m, k), b_shape, (num_slices, m, n)) - if self == OpType.SGMV_EXPAND: - # SGMV expand supports num_slices inherently in the kernel + if self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]: + # SGMV expand and V1 expand kernels support num_slices inherently + # in the kernel return ((num_slices, m, k), b_shape, (m, n * num_slices)) if self == OpType.BGMV_SHRINK: return ((m, k), b_shape, (m, n)) @@ -281,25 +301,30 @@ class OpType(Enum): return bgmv_expand if self == OpType.BGMV_EXPAND_SLICE: return emulate_bgmv_expand_slice + if self == OpType.V1_SHRINK: + return v1_shrink + if self == OpType.V1_EXPAND: + return v1_expand + raise ValueError(f"Unrecognized optype {self}") def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor, lora_weights: list[torch.Tensor], **kwargs) -> Callable: - """Each benchmark operation expected the input, lora_weights and outputs + """Each benchmark operation expects the input, lora_weights and outputs in a slightly different format. Refer to self.matmul_shapes(). run_ref_group_gemm accounts for those differences in executing a reference group gemm for correctness testing. """ w_dtype = lora_weights[0].dtype num_slices = len(lora_weights) - if self == OpType.SGMV_SHRINK: + if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]: for slice_idx in range(num_slices): ref_group_gemm(ref_out=output[slice_idx, :], input=input, lora_weights=lora_weights[slice_idx], **kwargs) - if self == OpType.SGMV_EXPAND: + elif self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]: hidden_size = lora_weights[0].shape[1] for slice_idx in range(num_slices): slice_offset = slice_idx * hidden_size @@ -308,19 +333,19 @@ class OpType(Enum): input=input[slice_idx].clone().to(dtype=w_dtype), lora_weights=lora_weights[slice_idx], **kwargs) - if self == OpType.BGMV_SHRINK: + elif self == OpType.BGMV_SHRINK: assert num_slices == 1 ref_group_gemm(ref_out=output, input=input, lora_weights=lora_weights[0], **kwargs) - if self == OpType.BGMV_EXPAND: + elif self == OpType.BGMV_EXPAND: assert num_slices == 1 ref_group_gemm(ref_out=output, input=input.clone().to(dtype=w_dtype), lora_weights=lora_weights[0], **kwargs) - if self == OpType.BGMV_EXPAND_SLICE: + elif self == OpType.BGMV_EXPAND_SLICE: hidden_size = lora_weights[0].shape[1] for slice_idx in range(num_slices): slice_offset = slice_idx * hidden_size @@ -329,7 +354,8 @@ class OpType(Enum): input=input[slice_idx].clone().to(dtype=w_dtype), lora_weights=lora_weights[slice_idx], **kwargs) - raise ValueError(f"Unrecognized optype {self}") + else: + raise ValueError(f"Unrecognized optype {self}") @dataclass @@ -390,6 +416,8 @@ class BenchmarkTensors: seq_start_loc: torch.Tensor prompt_lora_mapping: torch.Tensor token_lora_mapping: torch.Tensor + # v1 kernel metadata + v1_kernel_meta: Optional[V1KernelMeta] = None def io_types(self) -> str: return (f"{dtype_to_str(self.input.dtype)}x" @@ -432,10 +460,19 @@ class BenchmarkTensors: total_tokens, ctx.batch_size, prompt_lora_indices_tensor, seq_len_tensor, "cpu") + v1_kernel_meta = None + if op_type in [OpType.V1_SHRINK, OpType.V1_EXPAND]: + v1_kernel_meta = V1KernelMeta.make( + max_loras=ctx.num_loras, + max_num_tokens=token_lora_indices_tensor.size(0), + device="cpu") + v1_kernel_meta.prepare_tensors( + token_lora_mapping=token_lora_indices_tensor) + return BenchmarkTensors(input_tensor, lora_weights, output_tensor, seq_len_tensor, seq_start_loc_tensor, prompt_lora_indices_tensor, - token_lora_indices_tensor) + token_lora_indices_tensor, v1_kernel_meta) def sanity_check(self) -> None: """ @@ -468,6 
+505,13 @@ class BenchmarkTensors: for i in range(len(self.lora_weights_lst)): self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i]) + # v1 meta + if self.v1_kernel_meta: + for field_name in V1KernelMeta.__dataclass_fields__: + field = getattr(self.v1_kernel_meta, field_name) + assert isinstance(field, torch.Tensor) + setattr(self.v1_kernel_meta, field_name, to_device(field)) + def metadata(self) -> tuple[int, int, int]: """ Return num_seqs, num_tokens and max_seq_len @@ -667,6 +711,78 @@ class BenchmarkTensors: }) return {'kwargs_list': kwargs_list} + def as_v1_shrink_kwargs(self) -> dict[str, Any]: + assert self.v1_kernel_meta is not None + self.sanity_check() + self.to_device(self.input.device) + + _, num_tokens, _, num_slices = self.metadata() + + # Sanity check matrix shapes. + i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ + 0].shape, self.output.shape + # Expected input shape [num_tokens, hidden_size] + assert len(i_shape) == 2 + assert i_shape[0] == num_tokens + hidden_size = i_shape[1] + # Expected lora weight shape [num_loras, lora_rank, hidden_size] + assert len(lw_shape) == 3 + assert lw_shape[2] == hidden_size + lora_rank = lw_shape[1] + # Expected output shape [num_slices, num_tokens, lora_rank] + assert len(o_shape) == 3 + assert o_shape == (num_slices, num_tokens, lora_rank) + + return { + 'inputs': self.input, + 'lora_a_weights': self.lora_weights_lst, + 'output_tensor': self.output, + 'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping, + 'token_indices_sorted_by_lora_ids': + self.v1_kernel_meta.token_indices_sorted_by_lora_ids, + 'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora, + 'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc, + 'lora_ids': self.v1_kernel_meta.active_lora_ids, + 'scaling': 1.0, + } + + def as_v1_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: + assert self.v1_kernel_meta is not None + self.sanity_check() + self.to_device(self.input.device) + + _, num_tokens, _, num_slices = self.metadata() + + # Sanity check matrix shapes. 
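        # (The expand input is the stacked shrink output, one
        # [num_tokens, lora_rank] slice per LoRA slice; the expand output
        # concatenates the per-slice hidden dims along its last dimension.)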
+ i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ + 0].shape, self.output.shape + # Expected input shape : [num_slices, num_tokens, lora_rank] + assert len(i_shape) == 3 + assert i_shape[0] == num_slices + assert i_shape[1] == num_tokens + lora_rank = i_shape[2] + # Expected lora weight shape : [num_lora, hidden_size, lora_rank] + assert len(lw_shape) == 3 + assert lw_shape[2] == lora_rank + hidden_size = lw_shape[1] + # Expected output shape : [num_tokens, hidden_size * num_slices] + assert len(o_shape) == 2 + assert o_shape == (num_tokens, hidden_size * num_slices) + + return { + 'inputs': self.input, + 'lora_b_weights': self.lora_weights_lst, + 'output_tensor': self.output, + 'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping, + 'token_indices_sorted_by_lora_ids': + self.v1_kernel_meta.token_indices_sorted_by_lora_ids, + 'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora, + 'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc, + 'lora_ids': self.v1_kernel_meta.active_lora_ids, + 'offset_start': 0, + 'add_inputs': add_inputs, + } + def bench_fn_kwargs(self, op_type: OpType, add_inputs: Optional[bool] = None) -> dict[str, Any]: @@ -685,6 +801,10 @@ class BenchmarkTensors: return self.as_bgmv_expand_kwargs(add_inputs) if op_type == OpType.BGMV_EXPAND_SLICE: return self.as_bgmv_expand_slice_kwargs(add_inputs) + if op_type == OpType.V1_SHRINK: + return self.as_v1_shrink_kwargs() + if op_type == OpType.V1_EXPAND: + return self.as_v1_expand_kwargs(add_inputs) raise ValueError(f"Unrecognized optype {self}") def test_correctness(self, op_type: OpType, @@ -872,12 +992,9 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]): timers = [] for bench_ctx in bench_ctxs: for seq_len in args.seq_lengths: - bench_ops: list[OpType] = [] - if seq_len == 1: - # bench all decode ops - bench_ops = [op for op in args.op_types if op.is_decode_op()] - else: - # bench all prefill ops + bench_ops: list[OpType] = args.op_types + if seq_len > 1: + # bench only prefill ops bench_ops = [op for op in args.op_types if op.is_prefill_op()] seq_len_timers = [] diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 428a1c71..8c8e55ed 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import importlib import random from copy import deepcopy from dataclasses import dataclass @@ -63,6 +64,36 @@ DEVICES = ([ # stages, so we need to verify this. prefill stage(True) or decode stage(False) STAGES = [True, False] +# With the inclusion of V1 tests (look at the run_with_both_engines_lora), +# the tests in this file run twice, once with the V0 engine and then with +# the V1 engine. +# The NUM_RANDOM_SEEDS value was set to 10 before. It is cut to half +# with the inclusion of V1 tests to maintain the CI test times. +NUM_RANDOM_SEEDS = 5 +# The VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS value was set to +# 256 before. It is cut to half with the inclusion of V1 tests to maintain +# the CI test times. +VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128 + + +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + + # Reload punica_gpu as the kernels used are tied to engine type. + from vllm.lora.punica_wrapper import punica_gpu + importlib.reload(punica_gpu) + + # Release any memory we might be holding on to. 
CI runs OOMs otherwise. + from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT, + _LORA_B_PTR_DICT) + _LORA_B_PTR_DICT.clear() + _LORA_A_PTR_DICT.clear() + + yield + def get_random_id_to_index(num_loras: int, num_slots: int, @@ -226,7 +257,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: torch.set_default_device(device) max_loras = 8 - punica_wrapper = get_punica_wrapper(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -241,7 +272,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: return embedding, lora_embedding - for i in range(10): + for i in range(NUM_RANDOM_SEEDS): set_random_seed(i) id_to_index = get_random_id_to_index(num_loras, max_loras) @@ -329,7 +360,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, torch.set_default_device(device) max_loras = 8 - punica_wrapper = get_punica_wrapper(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -353,7 +384,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, return expanded_embedding, lora_embedding - for i in range(10): + for i in range(NUM_RANDOM_SEEDS): set_random_seed(i) id_to_index = get_random_id_to_index(num_loras, max_loras) @@ -468,7 +499,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, torch.set_default_device(device) max_loras = 8 - punica_wrapper = get_punica_wrapper(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -490,7 +521,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, return linear, logits_processor, lora_logits_processor - for i in range(10): + for i in range(NUM_RANDOM_SEEDS): set_random_seed(i) id_to_index = get_random_id_to_index(num_loras, max_loras) @@ -600,10 +631,10 @@ def test_linear_replicated(dist_init, num_loras, device, stage, if current_platform.is_cuda_alike(): torch.cuda.set_device(device) - torch.set_default_device(device) - punica_wrapper = get_punica_wrapper(8192, 256, device) - assert check_punica_wrapper(punica_wrapper) max_loras = 8 + torch.set_default_device(device) + punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) + assert check_punica_wrapper(punica_wrapper) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=torch.float16, @@ -627,7 +658,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage, assert lora_linear.lora_bias_stacked is None return linear, lora_linear - for i in range(10): + for i in range(NUM_RANDOM_SEEDS): set_random_seed(i) id_to_index = get_random_id_to_index(num_loras, max_loras) @@ -716,10 +747,10 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, if current_platform.is_cuda_alike(): torch.cuda.set_device(device) - torch.set_default_device(device) - punica_wrapper = get_punica_wrapper(8192, 256, device) - assert check_punica_wrapper(punica_wrapper) max_loras = 8 + torch.set_default_device(device) + punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) + assert check_punica_wrapper(punica_wrapper) lora_config = 
LoRAConfig(max_loras=max_loras, max_lora_rank=8, fully_sharded_loras=fully_shard, @@ -753,7 +784,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, assert lora_linear.lora_bias_stacked is None return linear, lora_linear - for i in range(10): + for i in range(NUM_RANDOM_SEEDS): set_random_seed(i) id_to_index = get_random_id_to_index(num_loras, max_loras) @@ -842,10 +873,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, if current_platform.is_cuda_alike(): torch.cuda.set_device(device) - torch.set_default_device(device) - punica_wrapper = get_punica_wrapper(8192, 256, device) - assert check_punica_wrapper(punica_wrapper) max_loras = 8 + torch.set_default_device(device) + punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) + assert check_punica_wrapper(punica_wrapper) lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, fully_sharded_loras=fully_shard, @@ -900,7 +931,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, assert lora_linear.lora_bias_stacked is None return linear, lora_linear - for i in range(10): + for i in range(NUM_RANDOM_SEEDS): set_random_seed(i) id_to_index = get_random_id_to_index(num_loras, max_loras) @@ -1002,12 +1033,12 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, is_neox_style, rotary_dim, head_size, seq_len) -> None: dtype = torch.float16 + max_loras = 8 seed = 0 current_platform.seed_everything(seed) torch.set_default_device(device) - punica_wrapper = get_punica_wrapper(8192, 256, device) + punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, long_lora_scaling_factors=scaling_factors, @@ -1083,7 +1114,8 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, @pytest.mark.parametrize("tp_size", [1, 2, 4, 8]) -@pytest.mark.parametrize("seed", list(range(256))) +@pytest.mark.parametrize( + "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS))) def test_vocab_parallel_embedding_indices(tp_size, seed): random.seed(seed) vocab_size = random.randint(4000, 64000) diff --git a/tests/lora/test_punica_ops.py b/tests/lora/test_punica_ops.py index c75e8661..a412a80d 100644 --- a/tests/lora/test_punica_ops.py +++ b/tests/lora/test_punica_ops.py @@ -5,10 +5,12 @@ import pytest import torch import vllm.lora.ops.triton_ops # noqa: F401 +import vllm.lora.ops.triton_ops.v1 # noqa: F401 from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand, sgmv_expand_slice, sgmv_shrink) from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT +from vllm.lora.ops.triton_ops.v1 import V1KernelMeta from vllm.platforms import current_platform from .utils import (PunicaTensors, assert_close, generate_data, @@ -91,12 +93,12 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int, _dict_lock = Lock() -def check_sgmv_shrink(batches: int, num_loras: int, rank: int, - hidden_size: int, nslices: int, dtype: torch.dtype, - device: str, seq_length: int, scaling: float): +def check_shrink_kernels(batches: int, num_loras: int, rank: int, + hidden_size: int, nslices: int, dtype: torch.dtype, + device: str, seq_length: int, scaling: float): """ - Compare outputs of vllm.sgmv_shrink kernel against a reference - implementation. + Compare outputs of vllm.sgmv_shrink and vllm.v1_shrink kernel against a + reference implementation. 
""" data: PunicaTensors = generate_data_for_nslices( batches, @@ -111,44 +113,63 @@ def check_sgmv_shrink(batches: int, num_loras: int, rank: int, ) max_seq_length, token_nums = data.meta() + # Setup metadata information for SGMV and reference kernels + sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor, + data.prompt_lora_mapping, batches, max_seq_length, + token_nums) + + # Setup metadata information for the V1 kernel. + v1_meta = V1KernelMeta.make(max_loras=num_loras, + max_num_tokens=token_nums, + device='cuda') + v1_meta.prepare_tensors(data.token_lora_mapping) + + ref_out_tensor = data.ref_out_tensor + sgmv_out_tensor = data.our_out_tensor + v1_out_tensor = data.our_out_tensor.clone() + # Preventing cache error pointer. with _dict_lock: + # SGMV shrink kernel _LORA_A_PTR_DICT.clear() torch.ops.vllm.sgmv_shrink( data.inputs_tensor, data.lora_weights, - data.our_out_tensor, - data.b_seq_start_loc, - data.seq_len_tensor, - data.prompt_lora_mapping, - batches, - max_seq_length, - token_nums, + sgmv_out_tensor, + *sgmv_meta_args, scaling, ) - sgmv_shrink_for_nslices( - nslices, + # V1 shrink kernel + _LORA_A_PTR_DICT.clear() + torch.ops.vllm.v1_shrink( data.inputs_tensor, data.lora_weights, - data.ref_out_tensor, - data.b_seq_start_loc, - data.seq_len_tensor, - data.prompt_lora_mapping, - batches, - max_seq_length, - token_nums, + v1_out_tensor, + *v1_meta.meta_args(token_nums=token_nums), scaling, ) - assert_close(data.our_out_tensor, data.ref_out_tensor) + + # Reference + sgmv_shrink_for_nslices( + nslices, + data.inputs_tensor, + data.lora_weights, + ref_out_tensor, + *sgmv_meta_args, + scaling, + ) + + assert_close(sgmv_out_tensor, ref_out_tensor) + assert_close(v1_out_tensor, ref_out_tensor) -def check_sgmv_expand(batches: int, num_loras: int, rank: int, - hidden_size: int, nslices: int, dtype: torch.dtype, - device: str, seq_length: int, add_inputs: bool): +def check_expand_kernels(batches: int, num_loras: int, rank: int, + hidden_size: int, nslices: int, dtype: torch.dtype, + device: str, seq_length: int, add_inputs: bool): """ - Compare outputs of vllm.sgmv_expand kernel against a reference - implementation. + Compare outputs of vllm.sgmv_expand and vllm.v1_expand kernels against a + reference implementation. """ data: PunicaTensors = generate_data_for_nslices( batches, @@ -164,36 +185,54 @@ def check_sgmv_expand(batches: int, num_loras: int, rank: int, max_seq_length, token_nums = data.meta() + # Setup metadata information for SGMV and reference kernels + sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor, + data.prompt_lora_mapping, batches, max_seq_length, + token_nums) + + # Setup metadata information for the V1 kernel. 
+ v1_meta = V1KernelMeta.make(max_loras=num_loras, + max_num_tokens=token_nums, + device='cuda') + v1_meta.prepare_tensors(data.token_lora_mapping) + + # Setup output tensors + ref_out_tensor = data.ref_out_tensor + sgmv_out_tensor = data.our_out_tensor + v1_out_tensor = data.our_out_tensor.clone() + with _dict_lock: + # SGMV expand kernel _LORA_B_PTR_DICT.clear() torch.ops.vllm.sgmv_expand( data.inputs_tensor, data.lora_weights, - data.our_out_tensor, - data.b_seq_start_loc, - data.seq_len_tensor, - data.prompt_lora_mapping, - batches, - max_seq_length, - token_nums, + sgmv_out_tensor, + *sgmv_meta_args, offset_start=0, add_inputs=add_inputs, ) + # V1 expand kernel + _LORA_B_PTR_DICT.clear() + torch.ops.vllm.v1_expand(data.inputs_tensor, + data.lora_weights, + v1_out_tensor, + *v1_meta.meta_args(token_nums=token_nums), + offset_start=0, + add_inputs=add_inputs) + + # Reference sgmv_expand_for_nslices(nslices, hidden_size, data.inputs_tensor, data.lora_weights, - data.ref_out_tensor, - data.b_seq_start_loc, - data.seq_len_tensor, - data.prompt_lora_mapping, - batches, - max_seq_length, - token_nums, + ref_out_tensor, + *sgmv_meta_args, add_inputs=add_inputs) - assert_close(data.our_out_tensor, data.ref_out_tensor) + assert_close(sgmv_out_tensor, ref_out_tensor) + assert_close(v1_out_tensor, ref_out_tensor) def check_bgmv_shrink(batches: int, num_loras: int, rank: int, @@ -439,7 +478,7 @@ SEED = [0] @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("seed", SEED) @pytest.mark.parametrize("op_type", ["shrink", "expand"]) -def test_punica_sgmv( +def test_kernels( batches: int, num_loras: int, rank: int, @@ -450,29 +489,32 @@ def test_punica_sgmv( seed: int, op_type: str, ): + """ + Tests SGMV and V1 kernels. + """ torch.set_default_device(device) current_platform.seed_everything(seed) if op_type == "shrink": - check_sgmv_shrink(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - nslices=nslices, - dtype=dtype, - device=device, - seq_length=128, - scaling=0.5) + check_shrink_kernels(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + scaling=0.5) else: - check_sgmv_expand(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - nslices=nslices, - dtype=dtype, - device=device, - seq_length=128, - add_inputs=True) + check_expand_kernels(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + add_inputs=True) @pytest.mark.parametrize("batches", hs_test_params['batches']) @@ -484,7 +526,7 @@ def test_punica_sgmv( @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("seed", SEED) @pytest.mark.parametrize("op_type", ["shrink", "expand"]) -def test_punica_sgmv_hidden_size( +def test_kernels_hidden_size( batches: int, num_loras: int, rank: int, @@ -495,29 +537,32 @@ def test_punica_sgmv_hidden_size( seed: int, op_type: str, ): + """ + Tests SGMV and V1 kernels. 
+ """ torch.set_default_device(device) current_platform.seed_everything(seed) if op_type == "shrink": - check_sgmv_shrink(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - nslices=nslices, - dtype=dtype, - device=device, - seq_length=128, - scaling=0.5) + check_shrink_kernels(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + scaling=0.5) else: - check_sgmv_expand(batches=batches, - num_loras=num_loras, - rank=rank, - hidden_size=hidden_size, - nslices=nslices, - dtype=dtype, - device=device, - seq_length=128, - add_inputs=True) + check_expand_kernels(batches=batches, + num_loras=num_loras, + rank=rank, + hidden_size=hidden_size, + nslices=nslices, + dtype=dtype, + device=device, + seq_length=128, + add_inputs=True) @pytest.mark.parametrize("batches", test_params['batches']) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index e1294884..174b9f0b 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -326,9 +326,11 @@ class LoRAModelManager(AdapterModelManager): self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots self.vocab_size = vocab_size self.long_lora_context: Optional[LongContextLoRAContext] = None - self.punica_wrapper = get_punica_wrapper(max_num_batched_tokens, - max_batches=self.max_num_seqs, - device=self.device) + self.punica_wrapper = get_punica_wrapper( + max_num_batched_tokens, + max_batches=self.max_num_seqs, + device=self.device, + max_loras=self.lora_config.max_loras) # Scaling factor -> offset to the sin_cos_cache to it. # Used for long context lora. self.scaling_factor_to_offset: Dict[float, int] = {} diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index 78409b91..b52a842c 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -54,7 +54,7 @@ _LORA_A_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} _LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} -def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: str): +def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: torch.device): """ `_LORA_A_PTR_DICT` collects the required information during `profile_run`, After this, it remains constant and subsequent usage is through LUT. @@ -100,7 +100,7 @@ def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: str): def _get_lora_b_ptr(lora_weights: List[torch.Tensor], offset_start: int, - device: str): + device: torch.device): """ `_LORA_B_PTR_DICT` collects the required information during `profile_run`, After this, it remains constant and subsequent usage is through LUT. 
diff --git a/vllm/lora/ops/triton_ops/v1/__init__.py b/vllm/lora/ops/triton_ops/v1/__init__.py new file mode 100644 index 00000000..1d2c46f4 --- /dev/null +++ b/vllm/lora/ops/triton_ops/v1/__init__.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.lora.ops.triton_ops.v1.v1_expand import v1_expand +from vllm.lora.ops.triton_ops.v1.v1_kernel_metadata import V1KernelMeta +from vllm.lora.ops.triton_ops.v1.v1_shrink import v1_shrink + +__all__ = [ + "v1_expand", + "v1_shrink", + "V1KernelMeta", +] \ No newline at end of file diff --git a/vllm/lora/ops/triton_ops/v1/v1_expand.py b/vllm/lora/ops/triton_ops/v1/v1_expand.py new file mode 100644 index 00000000..20c7f8f4 --- /dev/null +++ b/vllm/lora/ops/triton_ops/v1/v1_expand.py @@ -0,0 +1,282 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +from typing import List + +import torch +import triton +import triton.language as tl + +from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel +from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr +from vllm.utils import direct_register_custom_op + + +@triton.jit +def _v1_expand_kernel( + input_ptr, + lora_ptr, + out_ptr, + M, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + slice_start_loc, + input_d0_stride, + input_d1_stride, + input_d2_stride, # 1 + ls_d0_ptr, + ls_d1_ptr, + ls_d2_ptr, # 1 + output_d0_stride, + output_d1_stride, # 1 + output_hs_ptr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, + SLICE_NUM: tl.constexpr, + SAME_STRIDE: tl.constexpr): + + cta_n_num = tl.cdiv(N, BLOCK_N) + cta_m_num = tl.cdiv(M, BLOCK_M) + + pid_mn = tl.program_id(axis=0) + pid_m = pid_mn % cta_m_num + pid_n = (pid_mn // cta_m_num) % cta_n_num + + slice_id = tl.program_id(axis=1) + lora_idx = tl.program_id(axis=2) + + lora_id = tl.load(lora_ids + lora_idx) + if lora_id == -1: + # Early exit for the no-lora case. + return + + lora_m_size = tl.load(num_tokens_per_lora + lora_idx) + + cta_m_offset = pid_m * BLOCK_M + if cta_m_offset >= lora_m_size: + # Early exit CTA. + return + + # When the output dimensions of each slice are the same,cur_n=N, otherwise + # cur_n=tl.load(output_hs_ptr + slice_id), this situation exists in GQA's + # qkv linear. + curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id) + if pid_n * BLOCK_N >= curr_N: + # Early exit CTA. + return + + # num rows this CTA should process. + cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset) + + # Identify all rows that this CTA should process. + lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx) + cta_lora_seq_indices = (token_indices_sorted_by_lora_ids + + lora_m_indices_start + cta_m_offset) + + # Load all relevant row indices. 
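    # The modulo wraps lanes past cta_m_len back onto valid token indices so
    # every lane loads a real row; the cta_m_len-based store masking inside
    # do_expand_kernel keeps those wrapped rows from being written twice.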
+ offset_m = tl.arange(0, BLOCK_M) % cta_m_len + ram = tl.load(cta_lora_seq_indices + offset_m) + + do_expand_kernel( + pid_n, + lora_id, + slice_id, + input_ptr, + lora_ptr, + out_ptr, + curr_N, + K, + cta_m_len, + ram, # array identifying the rows of Input ptr to operate on + slice_start_loc, + # input ptr strides + input_d0_stride, + input_d1_stride, + input_d2_stride, + # lora ptr strides + ls_d0_ptr, + ls_d1_ptr, + ls_d2_ptr, + # out ptr strides + output_d0_stride, + output_d1_stride, + # constants + BLOCK_M, + BLOCK_N, + BLOCK_K, + SAME_STRIDE, + SLICE_NUM, + EVEN_K, + CAST_TYPE, + ADD_INPUTS) + + +@torch.inference_mode() +def _v1_expand( + inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank] + lora_b_weights: List[ + torch.Tensor], # shape [num_lora, hidden_size, lora_rank] + output_tensor: torch. + Tensor, # shape [num_tokens, hidden_size * num_slices] + token_lora_mapping: torch.Tensor, # shape [num_tokens] + token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens] + num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1] + lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] + lora_ids: torch.Tensor, # shape [max-loras + 1] + offset_start: int = 0, + add_inputs: bool = False, +) -> None: + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (List[torch.Tensor]): lora'b weight + output_tensor (torch.Tensor): output tensor + token_lora_mapping (torch.Tensor): A tensor mapping each input token + to the lora-id related to that token. A value of -1 indicates that + LoRA doesn't apply to that token. + token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from + the A matrix grouped by LoRA IDs. + num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number + of tokens that are to be processed by LoRA ID lora_ids[i] + lora_token_start_loc (torch.Tensor): A cumulative sum of + num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that + lora_token_start_loc[i], along with num_tokens_per_lora[i] + identifies the the region in token_indices_sorted_by_lora_ids that + LoRA lora_ids[i] should process. + lora_ids (torch.Tensor): LoRA ids to process. + offset_start (int, optional): Offset start for output_tensor. + Defaults to 0. + add_inputs (bool, optional): Whether to add the input tensor to the + output tensor. Defaults to False. + """ + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + for weight in lora_b_weights: + assert weight.dtype in [torch.float16, torch.bfloat16] + + assert inputs.size(0) == len(lora_b_weights) + assert output_tensor.is_contiguous() + + # metadata sanity check. + assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size( + 0) + assert lora_ids.size(0) == num_tokens_per_lora.size(0) + assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1 + + (slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor, + lora_strides_d1_tensor, lora_strides_d2_tensor, hidden_sizes_tensor, + same_stride, MAX_N) = _get_lora_b_ptr(lora_b_weights, offset_start, + inputs.device) + + K = lora_b_weights[0].shape[-1] # K= rank + M = inputs.size(1) + ADD_INPUTS = add_inputs + MAX_LORAS = lora_ids.size(0) + CAST_TYPE = False + NUM_SLICES = len(lora_b_weights) + + # Triton kernel configs. 
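    # Fixed launch parameters (no autotuning); EVEN_K below records whether the
    # rank is a multiple of BLOCK_K so the K loop can skip boundary masking.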
+ BLOCK_M = 64 + BLOCK_N = 128 + BLOCK_K = 16 + NUM_WARPS = 4 + NUM_CTAS = 1 + NUM_STAGES = 2 + MAX_NREG = None + + EVEN_K = K % BLOCK_K == 0 # type: ignore + + if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + + # TODO (varun): This grid formulation maximizes parallelization at the + # cost of wasteful thread block launch when only a few input tokens require + # LoRA. This might not be the best in all cases. + grid = ( + triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N), + NUM_SLICES, + # Each LoRA receives its own set of thread blocks for output + # computation. If some LoRA doesn't have any tokens to process, its + # thread blocks simply exit. + MAX_LORAS, + ) + + _v1_expand_kernel[grid]( + inputs, + lora_ptr_tensor, + output_tensor, + M, + MAX_N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + slice_start_tensor, + inputs.stride(0), + inputs.stride(1), + inputs.stride(2), + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, + output_tensor.stride(0), + output_tensor.stride(1), + hidden_sizes_tensor, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + NUM_SLICES, + same_stride, + num_warps=NUM_WARPS, + num_ctas=NUM_CTAS, + num_stages=NUM_STAGES, + maxnreg=MAX_NREG, + ) + + return + + +def _v1_expand_fake( + inputs: torch.Tensor, + lora_b_weights: List[torch.Tensor], + output_tensor: torch.Tensor, + token_lora_mapping: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, + num_tokens_per_lora: torch.Tensor, + lora_token_start_loc: torch.Tensor, + lora_ids: torch.Tensor, + offset_start: int = 0, + add_inputs: bool = False, +) -> None: + return + + +try: + direct_register_custom_op( + op_name="v1_expand", + op_func=_v1_expand, + mutates_args=["output_tensor"], + fake_impl=_v1_expand_fake, + ) + v1_expand = torch.ops.vllm.v1_expand + +except AttributeError: + v1_expand = _v1_expand diff --git a/vllm/lora/ops/triton_ops/v1/v1_kernel_metadata.py b/vllm/lora/ops/triton_ops/v1/v1_kernel_metadata.py new file mode 100644 index 00000000..57b4dd7a --- /dev/null +++ b/vllm/lora/ops/triton_ops/v1/v1_kernel_metadata.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +V1 LoRA kernels metadata preparation utilities. +""" + +from dataclasses import dataclass +from typing import Tuple, Union + +import torch + + +@dataclass +class V1KernelMeta: + token_lora_mapping: torch.Tensor + token_indices_sorted_by_lora_ids: torch.Tensor + active_lora_ids: torch.Tensor + num_tokens_per_lora: torch.Tensor + lora_token_start_loc: torch.Tensor + + @staticmethod + def make(max_loras: int, max_num_tokens: int, + device: Union[torch.device, str]) -> "V1KernelMeta": + + token_lora_mapping = torch.empty(max_num_tokens, + dtype=torch.int32, + device=device) + + token_indices_sorted_by_lora_ids = torch.empty(max_num_tokens, + dtype=torch.int32, + device=device) + + # +1 because "no-lora" is also a possibility + # example: let max_loras be 3, active_lora_ids of [-1, 0, 2, 1] + # is a possibility. + active_lora_ids = torch.empty(max_loras + 1, + dtype=torch.int32, + device=device) + + # using running example, [3, 10, 5, 2] is a possibility. + num_tokens_per_lora = torch.zeros(max_loras + 1, + dtype=torch.int32, + device=device) + + # +2 for this because, the first index is always 0. + # using running example, lora_token_start_loc + # is [0, 3, 13, 18, 20]. 
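        # i.e. a leading 0 followed by the cumulative sum of num_tokens_per_lora;
        # lora_token_start_loc[i] and num_tokens_per_lora[i] together delimit the
        # span of token_indices_sorted_by_lora_ids owned by active_lora_ids[i].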
+ lora_token_start_loc = torch.zeros(max_loras + 2, + dtype=torch.int32, + device=device) + return V1KernelMeta( + token_lora_mapping=token_lora_mapping, + token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids, + active_lora_ids=active_lora_ids, + num_tokens_per_lora=num_tokens_per_lora, + lora_token_start_loc=lora_token_start_loc) + + def _reset(self): + self.active_lora_ids.fill_(-1) + self.num_tokens_per_lora.fill_(0) + self.lora_token_start_loc.fill_(0) + + def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: + """ + Prepare kernel metadata tensors for the current forward pass. + + Args: + token_lora_tensor (torch.Tensor): Tensor containing lora indices + for each input token. + """ + + self._reset() + + num_tokens = token_lora_mapping.size(0) + + # copy token lora mapping + self.token_lora_mapping[:num_tokens].copy_(token_lora_mapping, + non_blocking=True) + + # token_indices_sorted_by_lora_ids + _, token_indices_sorted_by_lora_ids = torch.sort(token_lora_mapping, + stable=True) + # start gpu transfer + self.token_indices_sorted_by_lora_ids[:num_tokens].copy_( + token_indices_sorted_by_lora_ids, non_blocking=True) + + # active_lora_ids, num_tokens_per_lora + lora_ids, num_tokens_per_lora = torch.unique(token_lora_mapping, + sorted=False, + return_counts=True) + self.active_lora_ids[:lora_ids.size(0)].copy_(lora_ids, + non_blocking=True) + self.num_tokens_per_lora[:num_tokens_per_lora.size(0)].copy_( + num_tokens_per_lora, non_blocking=True) + + # lora_token_start_loc + lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0) + self.lora_token_start_loc[1:1 + lora_token_start_loc.size(0)].copy_( + lora_token_start_loc, non_blocking=True) + + def meta_args( + self, token_nums: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor]: + """ + This function returns the kernel metadata required for the current + forward pass execution of the kernel. The function returns all the + metadata required by the kernel, in order, as a tuple, so it can be + unpacked directly during the v1_shrink/v1_expand function call. + + Args: + token_nums (int): Number of input tokens in the current forward + pass. + """ + return (self.token_lora_mapping[:token_nums], + self.token_indices_sorted_by_lora_ids[:token_nums], + self.num_tokens_per_lora, self.lora_token_start_loc, + self.active_lora_ids) diff --git a/vllm/lora/ops/triton_ops/v1/v1_shrink.py b/vllm/lora/ops/triton_ops/v1/v1_shrink.py new file mode 100644 index 00000000..39affd18 --- /dev/null +++ b/vllm/lora/ops/triton_ops/v1/v1_shrink.py @@ -0,0 +1,236 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547 +""" + +from typing import List + +import torch +import triton +import triton.language as tl + +from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel +from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr +from vllm.utils import direct_register_custom_op + + +@triton.jit +def _v1_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K, + token_indices_sorted_by_lora_ids, num_tokens_per_lora, + lora_token_start_loc, lora_ids, scaling, input_d0_stride, + input_d1_stride, lora_d0_stride, lora_d1_stride, + lora_d2_stride, output_d0_stride, output_d1_stride, + output_d2_stride, BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, + SLICE_NUM: tl.constexpr): + + cta_n_num = tl.cdiv(N, BLOCK_N) + cta_m_num = tl.cdiv(M, BLOCK_M) + + pid_sk_m_n = tl.program_id(axis=0) + pid_sk = pid_sk_m_n % SPLIT_K + pid_m = (pid_sk_m_n // SPLIT_K) % cta_m_num + pid_n = pid_sk_m_n // (SPLIT_K * cta_m_num) % cta_n_num + + slice_id = tl.program_id(axis=1) + lora_idx = tl.program_id(axis=2) + + lora_id = tl.load(lora_ids + lora_idx) + if lora_id == -1: + # Early exit for the no-lora case. + return + + lora_m_size = tl.load(num_tokens_per_lora + lora_idx) + + cta_m_offset = pid_m * BLOCK_M + if cta_m_offset >= lora_m_size: + # Early exit CTA. + return + + # num rows this CTA should process. + cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset) + + # Identify all rows that this CTA should process. + lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx) + cta_lora_seq_indices = (token_indices_sorted_by_lora_ids + + lora_m_indices_start + cta_m_offset) + + # Load all relevant row indices. + offset_m = tl.arange(0, BLOCK_M) % cta_m_len + ram = tl.load(cta_lora_seq_indices + offset_m) + + do_shrink_kernel( + pid_n, + pid_sk, + slice_id, + lora_id, + input_ptr, + lora_ptr, + out_ptr, + N, + K, + cta_m_len, + ram, # array identifying the rows of Input ptr to operate on + # input strides + input_d0_stride, + input_d1_stride, + # lora strides + lora_d0_stride, + lora_d1_stride, + lora_d2_stride, + # output strides + output_d0_stride, + output_d1_stride, + output_d2_stride, + scaling, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + SLICE_NUM) + + +@torch.inference_mode() +def _v1_shrink( + inputs: torch.Tensor, # shape [num_tokens, hidden_size] + lora_a_weights: List[ + torch.Tensor], # shape [num_loras, lora_rank, hidden_size] + output_tensor: torch.Tensor, # shape [num_slices, num_tokens, lora_rank] + token_lora_mapping: torch.Tensor, # shape [num_tokens] + token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens] + num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1] + lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] + lora_ids: torch.Tensor, # shape [max-loras + 1] + scaling: float, +) -> None: + """ + Args: + inputs (torch.Tensor): Input tensor + lora_a_weights (List[torch.Tensor]): LoRA weights + output_tensor (torch.Tensor): output tensor + token_lora_mapping (torch.Tensor): A tensor mapping each input token + to the lora-id related to that token. A value of -1 indicates that + LoRA doesn't apply to that token. + token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from + the A matrix grouped by LoRA IDs. + num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number + of tokens that are to be processed by LoRA ID lora_ids[i] + lora_token_start_loc (torch.Tensor): A cumulative sum of + num_tokens_per_lora. 
lora_token_start_loc[0] is always 0 so that + lora_token_start_loc[i], along with num_tokens_per_lora[i] + identifies the region in token_indices_sorted_by_lora_ids that + LoRA lora_ids[i] should process. + lora_ids (torch.Tensor): LoRA ids to process. + scaling (float): Scaling factor. + """ + assert inputs.dtype == lora_a_weights[0].dtype + assert inputs.dtype in [torch.float16, torch.bfloat16] + for weight in lora_a_weights: + assert weight.dtype in [torch.float16, torch.bfloat16] + + assert inputs.size(1) == lora_a_weights[0].size(-1) + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + # metadata sanity check + assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size( + 0) + assert lora_ids.size(0) == num_tokens_per_lora.size(0) + assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1 + + (lora_ptr_tensor, lora_strides_d0, lora_strides_d1, + lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device) + N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank + M = inputs.size(0) + NUM_SLICES = len(lora_a_weights) + MAX_LORAS = lora_ids.size(0) + + # Triton kernel configs + BLOCK_M = 32 + BLOCK_N = 16 + BLOCK_K = 256 if M < 128 else 32 + SPLIT_K = 64 if M < 128 else 8 + NUM_WARPS = 4 + NUM_CTAS = 1 + NUM_STAGES = 2 + MAX_NREG = None + + EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 # type: ignore + + # TODO (varun): This grid formulation maximizes parallelization at the + # cost of wasteful thread block launch when only few of the input tokens + # require LoRA. This might not be the best in all cases. + grid = ( + SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), + NUM_SLICES, + # Each LoRA receives its own set of thread blocks for output + # computation. If some LoRA doesn't have any tokens to process, its + # thread blocks exit early. + MAX_LORAS, + ) + + _v1_shrink_kernel[grid]( + inputs, + lora_ptr_tensor, + output_tensor, + M, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + scaling, + inputs.stride(0), + inputs.stride(1), + lora_strides_d0, + lora_strides_d1, + lora_strides_d2, + output_tensor.stride(0), + output_tensor.stride(1), + output_tensor.stride(2), + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + NUM_SLICES, + num_warps=NUM_WARPS, + num_ctas=NUM_CTAS, + num_stages=NUM_STAGES, + maxnreg=MAX_NREG, + ) + + return + + +def _v1_shrink_fake( + inputs: torch.Tensor, + lora_a_weights: List[torch.Tensor], + output_tensor: torch.Tensor, + token_lora_mapping: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, + num_tokens_per_lora: torch.Tensor, + lora_token_start_loc: torch.Tensor, + lora_ids: torch.Tensor, + scaling: float, +) -> None: + return + + +try: + direct_register_custom_op( + op_name="v1_shrink", + op_func=_v1_shrink, + mutates_args=["output_tensor"], + fake_impl=_v1_shrink_fake, + ) + v1_shrink = torch.ops.vllm.v1_shrink + +except AttributeError: + v1_shrink = _v1_shrink diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 9ccd9c36..3a4fcd04 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -6,24 +6,83 @@ Punica: Multi-Tenant LoRA Serving. 
https://arxiv.org/abs/2310.18547 """ -from typing import Optional, Tuple, Union, final +from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final import torch +import vllm.envs as env +from vllm.lora.layers import LoRAMapping from vllm.triton_utils import HAS_TRITON if HAS_TRITON: - from vllm.lora.ops.triton_ops import bgmv_expand - from vllm.lora.ops.triton_ops import bgmv_expand_slice - from vllm.lora.ops.triton_ops import bgmv_shrink - from vllm.lora.ops.triton_ops import sgmv_expand - from vllm.lora.ops.triton_ops import sgmv_shrink + if env.VLLM_USE_V1: + from vllm.lora.ops.triton_ops.v1 import (V1KernelMeta, v1_expand, + v1_shrink) + else: + from vllm.lora.ops.triton_ops import bgmv_expand + from vllm.lora.ops.triton_ops import bgmv_expand_slice + from vllm.lora.ops.triton_ops import bgmv_shrink + from vllm.lora.ops.triton_ops import sgmv_expand + from vllm.lora.ops.triton_ops import sgmv_shrink from .punica_base import PunicaWrapperBase +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.models import LongContextLoRAContext + + +class V1KernelMixin: + + def _v1_make_metadata(self, max_loras: int, max_num_batched_tokens: int, + max_batches: int, device: Union[torch.device, str]): + self.token_mapping_v1_meta = V1KernelMeta.make(max_loras, + max_num_batched_tokens, + device=device) + self.prompt_mapping_v1_meta = V1KernelMeta.make(max_loras, + max_batches, + device=device) + + def _v1_prepare_metadata_tensors(self, token_lora_indices: torch.Tensor, + sampler_indices: torch.Tensor): + self.token_mapping_v1_meta.prepare_tensors(token_lora_indices) + self.prompt_mapping_v1_meta.prepare_tensors(sampler_indices) + + def _v1_apply_shrink( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: Tuple[torch.Tensor, ...], + scale: float, + ): + v1_shrink( + x, + w_t_all, + y, + *self.token_mapping_v1_meta.meta_args(x.size(0)), + scale, + ) + + def _v1_apply_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: Tuple[torch.Tensor, ...], + offset_start: int, + add_inputs: bool, + ): + v1_expand( + x, + w_t_all, + y, + *self.token_mapping_v1_meta.meta_args(x.size(0)), + offset_start=offset_start, + add_inputs=add_inputs, + ) + @final -class PunicaWrapperGPU(PunicaWrapperBase): +class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin): """ PunicaWrapperGPU is designed to manage and provide metadata for the punica kernel. 
The main function is to maintain the state information for @@ -35,6 +94,36 @@ class PunicaWrapperGPU(PunicaWrapperBase): PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) + self.max_loras = kwargs['max_loras'] + + if env.VLLM_USE_V1: + self._v1_make_metadata(self.max_loras, max_num_batched_tokens, + max_batches, device) + + def update_metadata( + self, + mapping: LoRAMapping, + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + **kwargs): + + if env.VLLM_USE_V1: + self.is_prefill = mapping.is_prefill + self._update_base_metadata(mapping, lora_index_to_id, max_loras, + vocab_size, extra_vocab_size, + long_lora_context) + self._v1_prepare_metadata_tensors(self.token_lora_indices, + self.sampler_indices) + else: + # Forward to base class update_metadata + PunicaWrapperBase.update_metadata(self, mapping, lora_index_to_id, + max_loras, vocab_size, + extra_vocab_size, + long_lora_context, **kwargs) + def _apply_shrink_prefill( self, y: torch.Tensor, @@ -66,7 +155,7 @@ class PunicaWrapperGPU(PunicaWrapperBase): self, y: torch.Tensor, x: torch.Tensor, - w_t_all: torch.Tensor, + w_t_all: Tuple[torch.Tensor, ...], offset_start: int, add_inputs: bool, ): @@ -118,14 +207,21 @@ class PunicaWrapperGPU(PunicaWrapperBase): x = x.view(-1, x.shape[-1]) - if self.is_prefill: - # NOTE fused kernel - self._apply_shrink_prefill(y, x, lora_a_stacked, scale) + if env.VLLM_USE_V1: + self._v1_apply_shrink(y, x, lora_a_stacked, scale) # type: ignore else: - # TODO fuse these kernels - for slice_idx in range(len(lora_a_stacked)): - self._apply_shrink_decode(y[slice_idx], x, - lora_a_stacked[slice_idx], scale) + if self.is_prefill: + # NOTE fused kernel + self._apply_shrink_prefill( + y, # type: ignore + x, + lora_a_stacked, + scale) + else: + # TODO fuse these kernels + for slice_idx in range(len(lora_a_stacked)): + self._apply_shrink_decode(y[slice_idx], x, + lora_a_stacked[slice_idx], scale) def add_expand(self, y: torch.Tensor, @@ -160,25 +256,38 @@ class PunicaWrapperGPU(PunicaWrapperBase): if lora_bias_stacked is not None: self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked) - if self.is_prefill: - # NOTE fused kernel - self._apply_expand_prefill(y, - x, - lora_b_stacked, - offset_start, - add_inputs=True) + + if env.VLLM_USE_V1: + # TODO (varun): Profile with add_inputs = False. i.e. move the + # addition out of the kernel + self._v1_apply_expand( + y, + x, # type: ignore + lora_b_stacked, + offset_start, + add_inputs=True) else: - # TODO fuse these kernels - for slice_idx in range(len(lora_b_stacked)): - self._apply_expand_decode( + + if self.is_prefill: + # NOTE fused kernel + self._apply_expand_prefill( y, - x[slice_idx], - lora_b_stacked[slice_idx], + x, # type: ignore + lora_b_stacked, offset_start, - output_slices[slice_idx], - add_inputs=add_inputs, - ) - offset_start += output_slices[slice_idx] + add_inputs=True) + else: + # TODO fuse these kernels + for slice_idx in range(len(lora_b_stacked)): + self._apply_expand_decode( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_start, + output_slices[slice_idx], + add_inputs=add_inputs, + ) + offset_start += output_slices[slice_idx] y = y.view_as(y_org) def add_lora_embedding(self, @@ -200,18 +309,24 @@ class PunicaWrapperGPU(PunicaWrapperBase): add_inputs (bool): Default to True. 
""" - if self.is_prefill: - sgmv_expand( - x.unsqueeze(dim=0), - [lora_b_stacked], - y, - *self.prefill_metadata, - offset_start=0, - add_inputs=add_inputs, - ) + if env.VLLM_USE_V1: + self._v1_apply_expand(y, + x.unsqueeze(dim=0), (lora_b_stacked, ), + offset_start=0, + add_inputs=add_inputs) else: - bgmv_expand(x, lora_b_stacked, y, self.token_lora_indices, - add_inputs) + if self.is_prefill: + sgmv_expand( + x.unsqueeze(dim=0), + (lora_b_stacked, ), + y, + *self.prefill_metadata, + offset_start=0, + add_inputs=add_inputs, + ) + else: + bgmv_expand(x, lora_b_stacked, y, self.token_lora_indices, + add_inputs) def add_lora_linear(self, y: torch.Tensor, @@ -257,19 +372,25 @@ class PunicaWrapperGPU(PunicaWrapperBase): r = lora_b_stacked[0].size(-1) # We set the buffer to be float32 by default ,refer to: # https://github.com/triton-lang/triton/issues/1387 - buffer = torch.zeros( + buffer = torch.zeros( # type: ignore (len(output_slices), x.size(0), r), dtype=torch.float32, device=x.device, ) - self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) - self.add_expand(y, - buffer, - lora_b_stacked, - None, - output_slices, - add_inputs=True, - **kwargs) + self.add_shrink( + buffer, # type: ignore + x, + lora_a_stacked, + scale, + **kwargs) + self.add_expand( + y, + buffer, # type: ignore + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) def add_lora_logits(self, y: torch.Tensor, @@ -305,11 +426,22 @@ class PunicaWrapperGPU(PunicaWrapperBase): buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) - # LogitsProcessorWithLoRA always using bgmv. - bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) - bgmv_expand(buffer, - lora_b_stacked, - y, - self.sampler_indices, - add_inputs=True) + + if env.VLLM_USE_V1: + v1_shrink(x, [lora_a_stacked], buffer.unsqueeze(dim=0), + *self.prompt_mapping_v1_meta.meta_args(x.size(0)), scale) + + v1_expand(buffer.unsqueeze(dim=0), [lora_b_stacked], + y, + *self.prompt_mapping_v1_meta.meta_args(buffer.size(0)), + add_inputs=True) + else: + + # V0 LogitsProcessorWithLoRA always using bgmv. + bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) + bgmv_expand(buffer, + lora_b_stacked, + y, + self.sampler_indices, + add_inputs=True) y = y.view_as(y_org) diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index f34aacac..0b30a467 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -62,9 +62,9 @@ class LoRAModelRunnerMixin: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - # We dont make any distinction between prefills and decodes in the - # scheduler. To that effect, set is_prefill to True so we use the - # sgmv punica kernels always. + # Set is_prefill to True, so we always use the SGMV kernels. + # For cuda platforms, we have specialized triton kernels, and + # the cuda path ignores `is_prefill`. lora_mapping = LoRAMapping(token_lora_mapping, prompt_lora_mapping, is_prefill=True)