[Misc][LoRA] Move the implementation of lora bias to punica.py (#10829)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2024-12-03 01:53:36 +08:00 committed by GitHub
parent a4c4daf364
commit b45f0d7946
4 changed files with 156 additions and 175 deletions
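In short: the per-token LoRA-bias lookup that used to live as module-level helpers in lora/layers.py now lives on PunicaWrapper, and add_lora / add_expand (plus their packed-nslice variants) take the stacked bias directly. The indexing trick being moved is small; a minimal, self-contained sketch of it (names here are illustrative, not the vLLM API):

import torch

def apply_bias_sketch(indices: torch.Tensor, output: torch.Tensor,
                      bias_stacked: torch.Tensor) -> torch.Tensor:
    # Token i receives the bias row of its adapter; index -1 means "no LoRA".
    flat_idx = indices.view(-1)
    flat_out = output.view(-1, output.shape[-1])
    picked = bias_stacked.view(-1, bias_stacked.shape[-1])[flat_idx]
    picked[flat_idx == -1] = 0  # tokens running without an adapter get no bias
    flat_out += picked
    return flat_out.view_as(output)

# Two adapters; the middle token runs without LoRA and stays at zero.
biases = torch.stack([torch.full((4,), 0.1), torch.full((4,), 0.2)])
print(apply_bias_sketch(torch.tensor([0, -1, 1]), torch.zeros(3, 4), biases))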

View File

@@ -55,15 +55,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
     return generated_texts
-@fork_new_process_for_each_test
-def test_llama_lora(sql_lora_files):
-    llm = vllm.LLM(MODEL_PATH,
-                   enable_lora=True,
-                   max_num_seqs=16,
-                   max_loras=4,
-                   tensor_parallel_size=1)
+def generate_and_test(llm, sql_lora_files):
     print("lora adapter created")
     assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
@@ -79,6 +71,17 @@ def test_llama_lora(sql_lora_files):
     print("removing lora")
+@fork_new_process_for_each_test
+def test_llama_lora(sql_lora_files):
+    llm = vllm.LLM(MODEL_PATH,
+                   enable_lora=True,
+                   max_num_seqs=16,
+                   max_loras=4,
+                   tensor_parallel_size=1)
+    generate_and_test(llm, sql_lora_files)
 @fork_new_process_for_each_test
 def test_llama_lora_warmup(sql_lora_files):
     """Test that the LLM initialization works with a warmup LORA path and
@@ -118,20 +121,7 @@ def test_llama_lora_tp4(sql_lora_files):
         max_loras=4,
         tensor_parallel_size=4,
     )
-    print("lora adapter created")
-    assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
-    print("lora 1")
-    assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT
-    print("no lora")
-    assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
-    print("lora 2")
-    assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT
-    print("removing lora")
+    generate_and_test(llm, sql_lora_files)
 @multi_gpu_test(num_gpus=4)
@ -146,16 +136,20 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
tensor_parallel_size=4, tensor_parallel_size=4,
fully_sharded_loras=True, fully_sharded_loras=True,
) )
print("lora adapter created") generate_and_test(llm, sql_lora_files)
assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
print("lora 1")
assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT
print("no lora") @multi_gpu_test(num_gpus=4)
assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT @fork_new_process_for_each_test
def test_llama_lora_tp4_fully_sharded_enable_bias(sql_lora_files):
print("lora 2") llm = vllm.LLM(
assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT MODEL_PATH,
enable_lora=True,
print("removing lora") max_num_seqs=16,
max_loras=4,
tensor_parallel_size=4,
fully_sharded_loras=True,
enable_lora_bias=True,
)
generate_and_test(llm, sql_lora_files)
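The new test_llama_lora_tp4_fully_sharded_enable_bias covers enable_lora_bias end to end. For reference, a minimal offline-inference sketch using the same flags; the model name and adapter path are placeholders (the test uses MODEL_PATH and the sql_lora_files fixture), and the LoRARequest call assumes the usual vLLM entry point:

import vllm
from vllm.lora.request import LoRARequest

llm = vllm.LLM("meta-llama/Llama-2-7b-hf",  # placeholder model
               enable_lora=True,
               max_num_seqs=16,
               max_loras=4,
               enable_lora_bias=True)
outputs = llm.generate(
    ["Write a SQL query that lists all users."],
    lora_request=LoRARequest("sql_adapter", 1, "/path/to/lora_adapter"))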

View File

@@ -73,6 +73,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
         self.punica_wrapper.add_expand(output,
                                        buffer,
                                        self.lora_b_stacked,
+                                       self.bias_stacked,
                                        add_input=True)
         # now have column partitioned output
@@ -131,27 +132,14 @@ def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
                                        layer.lora_a_stacked[idx], 1.0)
     buffers = tensor_model_parallel_all_gather(buffers)
-    left_offset = 0
-    for idx in range(n):
-        shard_size = layer.lora_b_stacked[idx].shape[2]
-        if layer.bias_stacked is not None:
-            bias = layer.bias_stacked[idx]
-            if bias is not None:
-                bias = bias.view(-1, bias.shape[-1])
-                bias = bias[layer.punica_wrapper.token_lora_indices]
-                bias[layer.punica_wrapper.token_lora_indices == -1] = 0
-                output[:, left_offset:left_offset + shard_size] += bias
-        layer.punica_wrapper.add_expand_slice(
-            output,
-            buffers[idx],
-            layer.lora_b_stacked[idx],
-            left_offset,
-            shard_size,
-            add_input=True,
-        )
-        left_offset += shard_size
+    layer.punica_wrapper.add_expand_packed_nslice(
+        output,
+        buffers,
+        layer.lora_b_stacked,
+        layer.bias_stacked,
+        1.0,
+        layer.output_slices,
+    )
     output = output.view(*out_orig_shape)
     # now have column partitioned and packed output
@@ -234,6 +222,7 @@ class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
         self.punica_wrapper.add_expand(output,
                                        buffer,
                                        self.lora_b_stacked,
+                                       self.bias_all,
                                        add_input=True)
         # now have column partitioned output
         output = output.view(*out_orig_shape)
@@ -350,15 +339,9 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
         # reduced before being used
         shard_size = self.lora_b_stacked.shape[2]
         start_idx = self.tp_rank * shard_size
-        if self.bias_stacked is not None:
-            bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1])
-            bias = bias[self.punica_wrapper.token_lora_indices]
-            bias[self.punica_wrapper.token_lora_indices == -1] = 0
-            output += bias
         self.punica_wrapper.add_expand_slice(output, buffer,
-                                             self.lora_b_stacked, start_idx,
+                                             self.lora_b_stacked,
+                                             self.bias_stacked, start_idx,
                                              shard_size)
         output = output.view(*out_orig_shape)
         return output
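_mcp_apply and the sharded row-parallel path above now hand the per-slice bias and expand work to the wrapper's packed-nslice helpers instead of walking the column offsets themselves. A rough pure-PyTorch reference of what that walk computes; tensor layouts are deliberately simplified here (lora_b as (num_loras, rank, out)) and do not match vLLM's stacked shapes:

import torch

def expand_packed_nslice_reference(y, buffers, lora_b_stacked, bias_stacked,
                                   indices, output_slices):
    # Slice k of the packed output gets buffers[k] @ lora_b[k] plus its bias,
    # written at a running column offset, mirroring add_expand_packed_nslice.
    offset = 0
    for k, size in enumerate(output_slices):
        if bias_stacked is not None and bias_stacked[k] is not None:
            picked = bias_stacked[k][indices]
            picked[indices == -1] = 0
            y[:, offset:offset + size] += picked
        for i in range(y.size(0)):
            if int(indices[i]) != -1:
                y[i, offset:offset + size] += (
                    buffers[k][i] @ lora_b_stacked[k][int(indices[i])])
        offset += size

rank, q, kv = 4, 8, 4
indices = torch.tensor([0, -1])          # token 1 runs without an adapter
y = torch.zeros(2, q + 2 * kv)
buffers = [torch.randn(2, rank) for _ in range(3)]
lora_b = [torch.randn(1, rank, q), torch.randn(1, rank, kv),
          torch.randn(1, rank, kv)]
bias = (torch.randn(1, q), torch.randn(1, kv), None)
expand_packed_nslice_reference(y, buffers, lora_b, bias, indices, (q, kv, kv))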

View File

@@ -67,63 +67,6 @@ def _not_fully_sharded_can_replace(can_replace):
     return dec
-def apply_bias(
-    indices: torch.Tensor,
-    output: torch.Tensor,
-    bias_stacked: torch.Tensor,
-):
-    """Applies bias to output
-    Input shapes:
-        bias_stacked:    (num_loras, output_dim)
-        indices:         (batch_size)
-        output:          (batch_size, output_dim)
-    """
-    org_output = output
-    output = output.view(-1, output.shape[-1])
-    indices = indices.view(-1)
-    bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1])
-    bias_stacked = bias_stacked[indices]
-    bias_stacked[indices == -1] = 0
-    output += bias_stacked
-    return output.view_as(org_output)
-def apply_bias_packed_nslice(
-    indices: torch.Tensor,
-    output: torch.Tensor,
-    output_slices: Tuple[int, ...],
-    bias_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
-):
-    """Applies bias to output
-    Input shapes:
-        bias_stacked:      3 element tuple of (num_loras, output_dim)
-        indices:           (batch_size)
-        output:            (batch_size, q_slice_size + 2*kv_slice_size)
-        output_slices:     n-1 element tuple of (slice_size...),
-                           where n is number of slices
-    """
-    org_output = output
-    output = output.view(-1, output.shape[-1])
-    indices = indices.view(-1)
-    offset_left = 0
-    for slice_idx, slice in enumerate(output_slices):
-        bias = bias_stacked[slice_idx]
-        if bias is not None:
-            bias = bias.view(-1, bias.shape[-1])
-            bias = bias[indices]
-            bias[indices == -1] = 0
-            output[:, offset_left:offset_left + slice] += bias
-        offset_left += slice
-    return output.view_as(org_output)
 @dataclass
 class LoRAMapping(AdapterMapping):
     is_prefill: bool = False
@@ -311,6 +254,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
             self.punica_wrapper.add_expand(full_output,
                                            full_lora_a_embeddings,
                                            self.lora_b_stacked,
+                                           bias_all=None,
                                            add_input=True)
         return full_output.view_as(full_output_org)
@@ -399,15 +343,9 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
     def apply(self, x: torch.Tensor,
               bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
-        if self.bias_stacked is not None:
-            self.indices = self.punica_wrapper.token_lora_indices
-            output = apply_bias(
-                self.indices,
-                output,
-                self.bias_stacked,
-            )
         self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
-                                     self.lora_b_stacked, 1.0)
+                                     self.lora_b_stacked, self.bias_stacked,
+                                     1.0)
         return output
     def forward(self, input_):
@@ -576,15 +514,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
     def apply(self, x: torch.Tensor,
               bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
-        if self.bias_stacked is not None:
-            self.indices = self.punica_wrapper.token_lora_indices
-            output = apply_bias(
-                self.indices,
-                output,
-                self.bias_stacked,
-            )
         self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
-                                     self.lora_b_stacked, 1.0)
+                                     self.lora_b_stacked, self.bias_stacked,
+                                     1.0)
         return output
     def forward(self, input_):
@ -687,8 +619,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
) for _ in range(n_slices)) ) for _ in range(n_slices))
else: else:
self.bias_stacked = None self.bias_stacked = None
self.output_dim = self.lora_b_stacked[0].shape[2] self.output_dim = self.lora_b_stacked[0].shape[2]
self.output_slices = (self.output_dim, self.output_dim)
def reset_lora(self, index: int): def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0 self.lora_a_stacked[0][index] = 0
@@ -772,17 +704,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def apply(self, x: torch.Tensor,
               bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
-        if self.bias_stacked is not None:
-            self.indices = self.punica_wrapper.token_lora_indices
-            output = apply_bias_packed_nslice(
-                self.indices,
-                output,
-                (self.output_dim, self.output_dim),
-                self.bias_stacked,
-            )
         self.punica_wrapper.add_lora_packed_nslice(
-            output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0,
-            (self.output_dim, self.output_dim))
+            output, x, self.lora_a_stacked, self.lora_b_stacked,
+            self.bias_stacked, 1.0, (self.output_dim, self.output_dim))
         return output
     @classmethod
@@ -1129,17 +1053,10 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
     def apply(self, x: torch.Tensor,
               bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
-        if self.bias_stacked is not None:
-            self.indices = self.punica_wrapper.token_lora_indices
-            output = apply_bias_packed_nslice(
-                self.indices,
-                output,
-                self.output_slices,
-                self.bias_stacked,
-            )
         self.punica_wrapper.add_lora_packed_nslice(output, x,
                                                    self.lora_a_stacked,
-                                                   self.lora_b_stacked, 1.0,
+                                                   self.lora_b_stacked,
+                                                   self.bias_stacked, 1.0,
                                                    self.output_slices)
         return output
@@ -1264,15 +1181,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
     def apply(self, x: torch.Tensor) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x)
-        if self.bias_stacked is not None:
-            self.indices = self.punica_wrapper.token_lora_indices
-            output = apply_bias(
-                self.indices,
-                output,
-                self.bias_stacked,
-            )
         self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
-                                     self.lora_b_stacked, 1.0)
+                                     self.lora_b_stacked, self.bias_stacked,
+                                     1.0)
         return output
     def forward(self, input_):
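Net effect on the layer classes in this file: each apply() drops its local bias preamble and simply threads self.bias_stacked into the wrapper, with the bias argument sitting between the lora_b weights and the scale. A stub-based sketch of the new call shape; the classes below are toy stand-ins, not vLLM's PunicaWrapper or layer implementations:

import torch
from typing import Optional

class _StubWrapper:
    """Stand-in that only illustrates the new argument order."""

    def add_lora(self, y: torch.Tensor, x: torch.Tensor, wa: torch.Tensor,
                 wb: torch.Tensor, bias_all: Optional[torch.Tensor],
                 scale: float) -> None:
        # The bias now arrives here; the layer no longer touches it.
        print("add_lora received bias:", bias_all is not None, "scale:", scale)

class _ToyLoRALinear:
    def __init__(self, wrapper: _StubWrapper):
        self.punica_wrapper = wrapper
        self.lora_a_stacked = torch.zeros(1, 1, 4, 8)
        self.lora_b_stacked = torch.zeros(1, 1, 8, 4)
        self.bias_stacked: Optional[torch.Tensor] = torch.zeros(1, 1, 8)

    def apply(self, x: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
        # Mirrors the refactored layers: one wrapper call, bias passed through.
        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                     self.lora_b_stacked, self.bias_stacked,
                                     1.0)
        return output

_ToyLoRALinear(_StubWrapper()).apply(torch.zeros(2, 8), torch.zeros(2, 8))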

View File

@@ -450,6 +450,62 @@ class PunicaWrapper:
         bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
                           y_slice_size, add_input)
+    def apply_bias(
+        self,
+        indices: torch.Tensor,
+        output: torch.Tensor,
+        bias_stacked: torch.Tensor,
+    ):
+        """Applies bias to output
+        Input shapes:
+            bias_stacked:    (num_loras, output_dim)
+            indices:         (batch_size)
+            output:          (batch_size, output_dim)
+        """
+        org_output = output
+        output = output.view(-1, output.shape[-1])
+        indices = indices.view(-1)
+        bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1])
+        bias_stacked = bias_stacked[indices]
+        bias_stacked[indices == -1] = 0
+        output += bias_stacked
+        return output.view_as(org_output)
+    def apply_bias_packed_nslice(
+        self,
+        indices: torch.Tensor,
+        output: torch.Tensor,
+        output_slices: Tuple[int, ...],
+        bias_stacked: Tuple[Optional[torch.Tensor], ...],
+    ):
+        """Applies bias to output
+        Input shapes:
+            bias_stacked:      3 element tuple of (num_loras, output_dim)
+            indices:           (batch_size)
+            output:            (batch_size, q_slice_size + 2*kv_slice_size)
+            output_slices:     n-1 element tuple of (slice_size...),
+                               where n is number of slices
+        """
+        org_output = output
+        output = output.view(-1, output.shape[-1])
+        indices = indices.view(-1)
+        offset_left = 0
+        for slice_idx, slice in enumerate(output_slices):
+            bias = bias_stacked[slice_idx]
+            if bias is not None:
+                bias = bias.view(-1, bias.shape[-1])
+                bias = bias[indices]
+                bias[indices == -1] = 0
+                output[:, offset_left:offset_left + slice] += bias
+            offset_left += slice
+        return output.view_as(org_output)
     def add_shrink(
         self,
         y: torch.Tensor,
@@ -474,16 +530,19 @@ class PunicaWrapper:
         y: torch.Tensor,
         x: torch.Tensor,
         w_t_all: torch.Tensor,
+        bias_all: Optional[torch.Tensor],
         add_input: bool = True,
     ):
         """
-        Perform the ` y+=x@w_t_all` computation, which is suitable for the
+        Perform the ` y+=x@w_t_all+bias` computation, which is suitable for the
         GEMM of lora'b.
         When `is_prefill` is true, it indicates that it is currently the
         prefill stage, and the `expand_prefill` function should be called.
         Otherwise, it is the decode stage, and the expand_decode function
         should be called.
         """
+        if bias_all is not None:
+            y = self.apply_bias(self.token_lora_indices, y, bias_all)
         expand_fun: Callable = (self.expand_prefill
                                 if self.is_prefill else self.expand_decode)
@@ -493,23 +552,54 @@ class PunicaWrapper:
                          y: torch.Tensor,
                          x: torch.Tensor,
                          w_t_all: torch.Tensor,
+                         bias_all: Optional[torch.Tensor],
                          y_offset: Optional[int],
                          y_slice_size: Optional[int],
                          add_input: bool = True):
         """
         Similar to `add_expand`
         """
+        if bias_all is not None:
+            y = self.apply_bias(self.token_lora_indices, y, bias_all)
         expand_slice_fun: Callable = (self.expand_slice_prefill
                                       if self.is_prefill else
                                       self.expand_slice_decode)
         expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)
+    def add_expand_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
+                                 lora_b_stacked: Tuple[torch.Tensor, ...],
+                                 bias_stacked: Optional[Tuple[torch.Tensor,
+                                                              ...]],
+                                 scale: float,
+                                 output_slices: Tuple[int, ...]) -> None:
+        """
+        Similar to `add_expand`
+        """
+        y_org = y
+        y = y.view(-1, y.shape[-1])
+        offset_left = 0
+        if bias_stacked is not None:
+            self.apply_bias_packed_nslice(self.token_lora_indices, y,
+                                          output_slices, bias_stacked)
+        for slice_idx in range(len(lora_b_stacked)):
+            self.add_expand_slice(y,
+                                  x[slice_idx],
+                                  lora_b_stacked[slice_idx],
+                                  None,
+                                  offset_left,
+                                  output_slices[slice_idx],
+                                  add_input=True)
+            offset_left += output_slices[slice_idx]
+        y = y.view_as(y_org)
     def add_lora(self,
                  y: torch.Tensor,
                  x: torch.Tensor,
                  wa_t_all: torch.Tensor,
                  wb_t_all: torch.Tensor,
+                 bias_all: Optional[torch.Tensor],
                  scale: float,
                  y_offset: Optional[int] = None,
                  y_slice_size: Optional[int] = None,
@@ -522,12 +612,13 @@ class PunicaWrapper:
                 @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
                 @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
                 * scale
-            ).squeeze(0)
+            ).squeeze(0)+bias[i]
         Args:
             y (torch.Tensor): Output tensor. Will be changed in-place.
             x (torch.Tensor): Input tensor
             wa_t_all (torch.Tensor): lora_a's weight
             wb_t_all (torch.Tensor): lora_b's weight
+            bias_all: (torch.Tensor): lora's bias
             scale (float): Scaling factor.
             y_offset (Optional[int], optional): Offset to apply to the starting
                 column of y.
@@ -544,27 +635,26 @@ class PunicaWrapper:
             buffer = torch.zeros((x.size(0), r),
                                  dtype=torch.float32,
                                  device=x.device)
+        if bias_all is not None:
+            y = self.apply_bias(self.token_lora_indices, y, bias_all)
         self.add_shrink(buffer, x, wa_t_all, scale)
         if y_offset is None and y_slice_size is None:
-            self.add_expand(y, buffer, wb_t_all, add_input=True)
+            self.add_expand(y, buffer, wb_t_all, bias_all=None, add_input=True)
         else:
             self.add_expand_slice(y,
                                   buffer,
                                   wb_t_all,
+                                  None,
                                   y_offset,
                                   y_slice_size,
                                   add_input=True)
         y = y.view_as(y_org)
     def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
-                               lora_a_stacked: Tuple[torch.Tensor,
-                                                     torch.Tensor,
-                                                     torch.Tensor],
-                               lora_b_stacked: Tuple[torch.Tensor,
-                                                     torch.Tensor,
-                                                     torch.Tensor],
-                               scale: float,
+                               lora_a_stacked: Tuple[torch.Tensor, ...],
+                               lora_b_stacked: Tuple[torch.Tensor, ...],
+                               bias_all: Tuple[Optional[torch.Tensor],
+                                               ...], scale: float,
                                output_slices: Tuple[int, ...]) -> None:
         """
         Applies lora to each input. Similar to add_lora, This method is
@@ -575,10 +665,13 @@ class PunicaWrapper:
         x = x.view(-1, x.shape[-1])
         y = y.view(-1, y.shape[-1])
         offset_left = 0
+        if bias_all is not None:
+            y = self.apply_bias_packed_nslice(self.token_lora_indices, y,
+                                              output_slices, bias_all)
         # TODO fuse these kernels
         for slice_idx in range(len(output_slices)):
             self.add_lora(y, x, lora_a_stacked[slice_idx],
-                          lora_b_stacked[slice_idx], scale, offset_left,
+                          lora_b_stacked[slice_idx], None, scale, offset_left,
                           output_slices[slice_idx])
             offset_left += output_slices[slice_idx]
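To tie the punica.py changes together: the add_lora docstring above now describes y[i] += (x[i] @ A_i.T @ B_i.T * scale) + bias[i]. A dense per-token reference of that semantics, illustrative only (simplified layer_idx handling, no bgmv shrink/expand kernels):

import torch

def add_lora_reference(y, x, wa_t_all, wb_t_all, bias_all, indices, scale,
                       layer_idx=0):
    # Dense restatement of the add_lora docstring; indices[i] == -1 means
    # token i runs without LoRA and is left untouched.
    for i in range(x.size(0)):
        idx = int(indices[i])
        if idx == -1:
            continue
        a = wa_t_all[idx, layer_idx].transpose(-1, -2)  # (hidden, rank)
        b = wb_t_all[idx, layer_idx].transpose(-1, -2)  # (rank, out)
        y[i] += (x[i].unsqueeze(0) @ a @ b * scale).squeeze(0)
        if bias_all is not None:
            y[i] += bias_all[idx]

num_loras, hidden, rank, out = 2, 8, 4, 8
wa = torch.randn(num_loras, 1, rank, hidden)  # indexed as in the docstring
wb = torch.randn(num_loras, 1, out, rank)
bias = torch.randn(num_loras, out)
y, x = torch.zeros(3, out), torch.randn(3, hidden)
add_lora_reference(y, x, wa, wb, bias, torch.tensor([0, -1, 1]), 1.0)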