[Misc][LoRA] Move the implementation of lora bias to punica.py (#10829)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2024-12-03 01:53:36 +08:00 committed by GitHub
parent a4c4daf364
commit b45f0d7946
4 changed files with 156 additions and 175 deletions

View File

@@ -55,15 +55,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts
@fork_new_process_for_each_test
def test_llama_lora(sql_lora_files):
llm = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=1)
def generate_and_test(llm, sql_lora_files):
print("lora adapter created")
assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
@@ -79,6 +71,17 @@ def test_llama_lora(sql_lora_files):
print("removing lora")
@fork_new_process_for_each_test
def test_llama_lora(sql_lora_files):
llm = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=1)
generate_and_test(llm, sql_lora_files)
@fork_new_process_for_each_test
def test_llama_lora_warmup(sql_lora_files):
"""Test that the LLM initialization works with a warmup LORA path and
@@ -118,20 +121,7 @@ def test_llama_lora_tp4(sql_lora_files):
max_loras=4,
tensor_parallel_size=4,
)
print("lora adapter created")
assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
print("lora 1")
assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT
print("no lora")
assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
print("lora 2")
assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT
print("removing lora")
generate_and_test(llm, sql_lora_files)
@multi_gpu_test(num_gpus=4)
@@ -146,16 +136,20 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
tensor_parallel_size=4,
fully_sharded_loras=True,
)
print("lora adapter created")
assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
generate_and_test(llm, sql_lora_files)
print("lora 1")
assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT
print("no lora")
assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_llama_lora_tp4_fully_sharded_enable_bias(sql_lora_files):
print("lora 2")
assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT
print("removing lora")
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=4,
fully_sharded_loras=True,
enable_lora_bias=True,
)
generate_and_test(llm, sql_lora_files)
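
The test hunks above fold the repeated sample-and-assert sequence into a shared generate_and_test helper. Reconstructed from the deleted lines, the helper plausibly looks like the sketch below (do_sample, EXPECTED_NO_LORA_OUTPUT and EXPECTED_LORA_OUTPUT are the existing module-level names; the exact committed body may differ):

def generate_and_test(llm, sql_lora_files):
    print("lora adapter created")
    assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT

    print("lora 1")
    assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT

    print("no lora")
    assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT

    print("lora 2")
    assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT

    print("removing lora")

Each test then differs only in how it constructs vllm.LLM (tensor parallelism, fully sharded LoRAs, enable_lora_bias).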

View File

@@ -73,6 +73,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
self.punica_wrapper.add_expand(output,
buffer,
self.lora_b_stacked,
self.bias_stacked,
add_input=True)
# now have column partitioned output
@@ -131,27 +132,14 @@ def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
layer.lora_a_stacked[idx], 1.0)
buffers = tensor_model_parallel_all_gather(buffers)
left_offset = 0
for idx in range(n):
shard_size = layer.lora_b_stacked[idx].shape[2]
if layer.bias_stacked is not None:
bias = layer.bias_stacked[idx]
if bias is not None:
bias = bias.view(-1, bias.shape[-1])
bias = bias[layer.punica_wrapper.token_lora_indices]
bias[layer.punica_wrapper.token_lora_indices == -1] = 0
output[:, left_offset:left_offset + shard_size] += bias
layer.punica_wrapper.add_expand_slice(
layer.punica_wrapper.add_expand_packed_nslice(
output,
buffers[idx],
layer.lora_b_stacked[idx],
left_offset,
shard_size,
add_input=True,
buffers,
layer.lora_b_stacked,
layer.bias_stacked,
1.0,
layer.output_slices,
)
left_offset += shard_size
output = output.view(*out_orig_shape)
# now have column partitioned and packed output
@@ -234,6 +222,7 @@ class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
self.punica_wrapper.add_expand(output,
buffer,
self.lora_b_stacked,
self.bias_all,
add_input=True)
# now have column partitioned output
output = output.view(*out_orig_shape)
@@ -350,15 +339,9 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
# reduced before being used
shard_size = self.lora_b_stacked.shape[2]
start_idx = self.tp_rank * shard_size
if self.bias_stacked is not None:
bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1])
bias = bias[self.punica_wrapper.token_lora_indices]
bias[self.punica_wrapper.token_lora_indices == -1] = 0
output += bias
self.punica_wrapper.add_expand_slice(output, buffer,
self.lora_b_stacked, start_idx,
self.lora_b_stacked,
self.bias_stacked, start_idx,
shard_size)
output = output.view(*out_orig_shape)
return output

View File

@@ -67,63 +67,6 @@ def _not_fully_sharded_can_replace(can_replace):
return dec
def apply_bias(
indices: torch.Tensor,
output: torch.Tensor,
bias_stacked: torch.Tensor,
):
"""Applies bias to output
Input shapes:
bias_stacked: (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, output_dim)
"""
org_output = output
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1])
bias_stacked = bias_stacked[indices]
bias_stacked[indices == -1] = 0
output += bias_stacked
return output.view_as(org_output)
def apply_bias_packed_nslice(
indices: torch.Tensor,
output: torch.Tensor,
output_slices: Tuple[int, ...],
bias_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
):
"""Applies bias to output
Input shapes:
bias_stacked: 3 element tuple of (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
"""
org_output = output
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
offset_left = 0
for slice_idx, slice in enumerate(output_slices):
bias = bias_stacked[slice_idx]
if bias is not None:
bias = bias.view(-1, bias.shape[-1])
bias = bias[indices]
bias[indices == -1] = 0
output[:, offset_left:offset_left + slice] += bias
offset_left += slice
return output.view_as(org_output)
@dataclass
class LoRAMapping(AdapterMapping):
is_prefill: bool = False
@@ -311,6 +254,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.punica_wrapper.add_expand(full_output,
full_lora_a_embeddings,
self.lora_b_stacked,
bias_all=None,
add_input=True)
return full_output.view_as(full_output_org)
@@ -399,15 +343,9 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias(
self.indices,
output,
self.bias_stacked,
)
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
self.lora_b_stacked, 1.0)
self.lora_b_stacked, self.bias_stacked,
1.0)
return output
def forward(self, input_):
@@ -576,15 +514,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias(
self.indices,
output,
self.bias_stacked,
)
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
self.lora_b_stacked, 1.0)
self.lora_b_stacked, self.bias_stacked,
1.0)
return output
def forward(self, input_):
@@ -687,8 +619,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
) for _ in range(n_slices))
else:
self.bias_stacked = None
self.output_dim = self.lora_b_stacked[0].shape[2]
self.output_slices = (self.output_dim, self.output_dim)
def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0
@@ -772,17 +704,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias_packed_nslice(
self.indices,
output,
(self.output_dim, self.output_dim),
self.bias_stacked,
)
self.punica_wrapper.add_lora_packed_nslice(
output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0,
(self.output_dim, self.output_dim))
output, x, self.lora_a_stacked, self.lora_b_stacked,
self.bias_stacked, 1.0, (self.output_dim, self.output_dim))
return output
@classmethod
@@ -1129,17 +1053,10 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias_packed_nslice(
self.indices,
output,
self.output_slices,
self.bias_stacked,
)
self.punica_wrapper.add_lora_packed_nslice(output, x,
self.lora_a_stacked,
self.lora_b_stacked, 1.0,
self.lora_b_stacked,
self.bias_stacked, 1.0,
self.output_slices)
return output
@@ -1264,15 +1181,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def apply(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias(
self.indices,
output,
self.bias_stacked,
)
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
self.lora_b_stacked, 1.0)
self.lora_b_stacked, self.bias_stacked,
1.0)
return output
def forward(self, input_):

View File

@@ -450,6 +450,62 @@ class PunicaWrapper:
bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
y_slice_size, add_input)
def apply_bias(
self,
indices: torch.Tensor,
output: torch.Tensor,
bias_stacked: torch.Tensor,
):
"""Applies bias to output
Input shapes:
bias_stacked: (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, output_dim)
"""
org_output = output
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1])
bias_stacked = bias_stacked[indices]
bias_stacked[indices == -1] = 0
output += bias_stacked
return output.view_as(org_output)
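
The gather-and-mask pattern in apply_bias can be exercised in isolation; below is a minimal standalone PyTorch sketch with invented shapes (the real tensors are views of PunicaWrapper's stacked metadata):

import torch

num_loras, output_dim, num_tokens = 3, 4, 5
bias_stacked = torch.randn(num_loras, output_dim)
output = torch.zeros(num_tokens, output_dim)
indices = torch.tensor([0, 2, -1, 1, 0])  # per-token LoRA id; -1 means no LoRA

bias = bias_stacked[indices]   # gather one bias row per token
bias[indices == -1] = 0        # tokens without an active LoRA get no bias
output += bias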
def apply_bias_packed_nslice(
self,
indices: torch.Tensor,
output: torch.Tensor,
output_slices: Tuple[int, ...],
bias_stacked: Tuple[Optional[torch.Tensor], ...],
):
"""Applies bias to output
Input shapes:
bias_stacked: 3 element tuple of (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
"""
org_output = output
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
offset_left = 0
for slice_idx, slice in enumerate(output_slices):
bias = bias_stacked[slice_idx]
if bias is not None:
bias = bias.view(-1, bias.shape[-1])
bias = bias[indices]
bias[indices == -1] = 0
output[:, offset_left:offset_left + slice] += bias
offset_left += slice
return output.view_as(org_output)
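
apply_bias_packed_nslice repeats the same gather for each slice of a packed output (e.g. fused QKV), shifting by the slice width each time. A standalone sketch, again with invented shapes:

import torch

num_loras, num_tokens = 2, 4
output_slices = (6, 3, 3)                 # e.g. q, k and v slice widths
output = torch.zeros(num_tokens, sum(output_slices))
indices = torch.tensor([0, -1, 1, 0])
bias_stacked = tuple(torch.randn(num_loras, w) for w in output_slices)

offset = 0
for width, stacked in zip(output_slices, bias_stacked):
    b = stacked[indices]
    b[indices == -1] = 0
    output[:, offset:offset + width] += b
    offset += width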
def add_shrink(
self,
y: torch.Tensor,
@@ -474,16 +530,19 @@ class PunicaWrapper:
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
bias_all: Optional[torch.Tensor],
add_input: bool = True,
):
"""
Perform the ` y+=x@w_t_all` computation, which is suitable for the
Perform the ` y+=x@w_t_all+bias` computation, which is suitable for the
GEMM of lora'b.
When `is_prefill` is true, it indicates that it is currently the
prefill stage, and the `expand_prefill` function should be called.
Otherwise, it is the decode stage, and the expand_decode function
should be called.
"""
if bias_all is not None:
y = self.apply_bias(self.token_lora_indices, y, bias_all)
expand_fun: Callable = (self.expand_prefill
if self.is_prefill else self.expand_decode)
@@ -493,23 +552,54 @@ class PunicaWrapper:
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
bias_all: Optional[torch.Tensor],
y_offset: Optional[int],
y_slice_size: Optional[int],
add_input: bool = True):
"""
Similar to `add_expand`
"""
if bias_all is not None:
y = self.apply_bias(self.token_lora_indices, y, bias_all)
expand_slice_fun: Callable = (self.expand_slice_prefill
if self.is_prefill else
self.expand_slice_decode)
expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)
def add_expand_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
lora_b_stacked: Tuple[torch.Tensor, ...],
bias_stacked: Optional[Tuple[torch.Tensor,
...]],
scale: float,
output_slices: Tuple[int, ...]) -> None:
"""
Similar to `add_expand`
"""
y_org = y
y = y.view(-1, y.shape[-1])
offset_left = 0
if bias_stacked is not None:
self.apply_bias_packed_nslice(self.token_lora_indices, y,
output_slices, bias_stacked)
for slice_idx in range(len(lora_b_stacked)):
self.add_expand_slice(y,
x[slice_idx],
lora_b_stacked[slice_idx],
None,
offset_left,
output_slices[slice_idx],
add_input=True)
offset_left += output_slices[slice_idx]
y = y.view_as(y_org)
def add_lora(self,
y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
wb_t_all: torch.Tensor,
bias_all: Optional[torch.Tensor],
scale: float,
y_offset: Optional[int] = None,
y_slice_size: Optional[int] = None,
@@ -522,12 +612,13 @@ class PunicaWrapper:
@ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
@ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
).squeeze(0)+bias[i]
Args:
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
wa_t_all (torch.Tensor): lora_a's weight
wb_t_all (torch.Tensor): lora_b's weight
bias_all: (torch.Tensor): lora's bias
scale (float): Scaling factor.
y_offset (Optional[int], optional): Offset to apply to the starting
column of y.
@@ -544,27 +635,26 @@ class PunicaWrapper:
buffer = torch.zeros((x.size(0), r),
dtype=torch.float32,
device=x.device)
if bias_all is not None:
y = self.apply_bias(self.token_lora_indices, y, bias_all)
self.add_shrink(buffer, x, wa_t_all, scale)
if y_offset is None and y_slice_size is None:
self.add_expand(y, buffer, wb_t_all, add_input=True)
self.add_expand(y, buffer, wb_t_all, bias_all=None, add_input=True)
else:
self.add_expand_slice(y,
buffer,
wb_t_all,
None,
y_offset,
y_slice_size,
add_input=True)
y = y.view_as(y_org)
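
The docstring formula for add_lora can be restated as a pure-PyTorch reference loop (no Punica kernels; the simplified shapes and the bias_all indexing, taken here as (num_loras, output_dim), are assumptions for illustration):

import torch

def add_lora_reference(y, x, wa_t_all, wb_t_all, bias_all, scale,
                       token_lora_indices, layer_idx=0):
    # y: (num_tokens, out), x: (num_tokens, in)
    # wa_t_all: (num_loras, layers, rank, in), wb_t_all: (num_loras, layers, out, rank)
    for i, lora_id in enumerate(token_lora_indices.tolist()):
        if lora_id == -1:
            continue                                          # no active LoRA for this token
        a = wa_t_all[lora_id, layer_idx].transpose(-1, -2)    # (in, rank)
        b = wb_t_all[lora_id, layer_idx].transpose(-1, -2)    # (rank, out)
        y[i] += (x[i].unsqueeze(0) @ a @ b).squeeze(0) * scale
        if bias_all is not None:
            y[i] += bias_all[lora_id]
    return y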
def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor,
torch.Tensor,
torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor,
torch.Tensor,
torch.Tensor],
scale: float,
lora_a_stacked: Tuple[torch.Tensor, ...],
lora_b_stacked: Tuple[torch.Tensor, ...],
bias_all: Tuple[Optional[torch.Tensor],
...], scale: float,
output_slices: Tuple[int, ...]) -> None:
"""
Applies lora to each input. Similar to add_lora, This method is
@@ -575,10 +665,13 @@ class PunicaWrapper:
x = x.view(-1, x.shape[-1])
y = y.view(-1, y.shape[-1])
offset_left = 0
if bias_all is not None:
y = self.apply_bias_packed_nslice(self.token_lora_indices, y,
output_slices, bias_all)
# TODO fuse these kernels
for slice_idx in range(len(output_slices)):
self.add_lora(y, x, lora_a_stacked[slice_idx],
lora_b_stacked[slice_idx], scale, offset_left,
lora_b_stacked[slice_idx], None, scale, offset_left,
output_slices[slice_idx])
offset_left += output_slices[slice_idx]
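
Putting the pieces together, add_lora_packed_nslice amounts to the following pure-PyTorch reference for a packed output (shapes simplified and invented; the real path dispatches to the Punica kernels such as bgmv_expand_slice above):

import torch

num_loras, rank, hidden, num_tokens, scale = 2, 8, 16, 4, 1.0
output_slices = (16, 16)
indices = torch.tensor([0, 1, -1, 0])          # per-token LoRA id; -1 = none
x = torch.randn(num_tokens, hidden)
y = torch.zeros(num_tokens, sum(output_slices))
lora_a = [torch.randn(num_loras, rank, hidden) for _ in output_slices]
lora_b = [torch.randn(num_loras, w, rank) for w in output_slices]
bias = [torch.randn(num_loras, w) for w in output_slices]

offset = 0
for a_s, b_s, bias_s, width in zip(lora_a, lora_b, bias, output_slices):
    for i, lora_id in enumerate(indices.tolist()):
        if lora_id == -1:
            continue
        delta = (x[i] @ a_s[lora_id].T @ b_s[lora_id].T) * scale + bias_s[lora_id]
        y[i, offset:offset + width] += delta
    offset += width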