[Misc][LoRA] Move the implementation of lora bias to punica.py (#10829)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent a4c4daf364
commit b45f0d7946
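This change moves LoRA bias handling out of the individual LoRA layer classes and into PunicaWrapper: the module-level apply_bias / apply_bias_packed_nslice helpers are deleted from the layer code, equivalent methods are added to PunicaWrapper, and the layers now pass their stacked bias tensors straight into add_lora, add_expand, add_expand_slice and the packed-nslice variants. The tensor-parallel LoRA tests are refactored around a shared generate_and_test helper, and a new test covers fully_sharded_loras=True together with enable_lora_bias=True.

As a reading aid, the following is a minimal, self-contained sketch of the indexing pattern the relocated apply_bias helper is built on; the function name and tensor shapes here are illustrative, not vLLM API.

    import torch


    def apply_bias_sketch(token_lora_indices: torch.Tensor,
                          output: torch.Tensor,
                          bias_stacked: torch.Tensor) -> torch.Tensor:
        """Add each token's LoRA bias row to its output row.

        token_lora_indices: (num_tokens, ), LoRA id per token, -1 means no LoRA
        output:             (num_tokens, output_dim)
        bias_stacked:       (num_loras, output_dim)
        """
        org_output = output
        output = output.view(-1, output.shape[-1])
        indices = token_lora_indices.view(-1)

        # Gather one bias row per token (advanced indexing copies, so the
        # stacked biases stay untouched), zero the rows of tokens without an
        # active LoRA, then accumulate in place.
        bias = bias_stacked.view(-1, bias_stacked.shape[-1])[indices]
        bias[indices == -1] = 0
        output += bias
        return output.view_as(org_output)


    # Example: 3 LoRAs, 4 tokens, the second token runs without a LoRA.
    out = apply_bias_sketch(torch.tensor([0, -1, 2, 1]),
                            torch.zeros(4, 8),
                            torch.randn(3, 8))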
@@ -55,15 +55,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
     return generated_texts


-@fork_new_process_for_each_test
-def test_llama_lora(sql_lora_files):
-
-    llm = vllm.LLM(MODEL_PATH,
-                   enable_lora=True,
-                   max_num_seqs=16,
-                   max_loras=4,
-                   tensor_parallel_size=1)
-
+def generate_and_test(llm, sql_lora_files):
     print("lora adapter created")
     assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT

@@ -79,6 +71,17 @@ def test_llama_lora(sql_lora_files):
     print("removing lora")


+@fork_new_process_for_each_test
+def test_llama_lora(sql_lora_files):
+
+    llm = vllm.LLM(MODEL_PATH,
+                   enable_lora=True,
+                   max_num_seqs=16,
+                   max_loras=4,
+                   tensor_parallel_size=1)
+    generate_and_test(llm, sql_lora_files)
+
+
 @fork_new_process_for_each_test
 def test_llama_lora_warmup(sql_lora_files):
     """Test that the LLM initialization works with a warmup LORA path and
@@ -118,20 +121,7 @@ def test_llama_lora_tp4(sql_lora_files):
         max_loras=4,
         tensor_parallel_size=4,
     )
-
-    print("lora adapter created")
-    assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
-
-    print("lora 1")
-    assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT
-
-    print("no lora")
-    assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
-
-    print("lora 2")
-    assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT
-
-    print("removing lora")
+    generate_and_test(llm, sql_lora_files)


 @multi_gpu_test(num_gpus=4)
@@ -146,16 +136,20 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
         tensor_parallel_size=4,
         fully_sharded_loras=True,
     )
-    print("lora adapter created")
-    assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
-
-    print("lora 1")
-    assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT
-
-    print("no lora")
-    assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
-
-    print("lora 2")
-    assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT
-
-    print("removing lora")
+    generate_and_test(llm, sql_lora_files)
+
+
+@multi_gpu_test(num_gpus=4)
+@fork_new_process_for_each_test
+def test_llama_lora_tp4_fully_sharded_enable_bias(sql_lora_files):
+
+    llm = vllm.LLM(
+        MODEL_PATH,
+        enable_lora=True,
+        max_num_seqs=16,
+        max_loras=4,
+        tensor_parallel_size=4,
+        fully_sharded_loras=True,
+        enable_lora_bias=True,
+    )
+    generate_and_test(llm, sql_lora_files)
@@ -73,6 +73,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
         self.punica_wrapper.add_expand(output,
                                        buffer,
                                        self.lora_b_stacked,
+                                       self.bias_stacked,
                                        add_input=True)
         # now have column partitioned output

@@ -131,27 +132,14 @@ def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
                                          layer.lora_a_stacked[idx], 1.0)

    buffers = tensor_model_parallel_all_gather(buffers)
-    left_offset = 0
-    for idx in range(n):
-        shard_size = layer.lora_b_stacked[idx].shape[2]
-
-        if layer.bias_stacked is not None:
-            bias = layer.bias_stacked[idx]
-            if bias is not None:
-                bias = bias.view(-1, bias.shape[-1])
-                bias = bias[layer.punica_wrapper.token_lora_indices]
-                bias[layer.punica_wrapper.token_lora_indices == -1] = 0
-                output[:, left_offset:left_offset + shard_size] += bias
-
-        layer.punica_wrapper.add_expand_slice(
-            output,
-            buffers[idx],
-            layer.lora_b_stacked[idx],
-            left_offset,
-            shard_size,
-            add_input=True,
-        )
-        left_offset += shard_size
+    layer.punica_wrapper.add_expand_packed_nslice(
+        output,
+        buffers,
+        layer.lora_b_stacked,
+        layer.bias_stacked,
+        1.0,
+        layer.output_slices,
+    )

    output = output.view(*out_orig_shape)
    # now have column partitioned and packed output
@@ -234,6 +222,7 @@ class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
         self.punica_wrapper.add_expand(output,
                                        buffer,
                                        self.lora_b_stacked,
+                                       self.bias_all,
                                        add_input=True)
         # now have column partitioned output
         output = output.view(*out_orig_shape)
@@ -350,15 +339,9 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
         # reduced before being used
         shard_size = self.lora_b_stacked.shape[2]
         start_idx = self.tp_rank * shard_size
-
-        if self.bias_stacked is not None:
-            bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1])
-            bias = bias[self.punica_wrapper.token_lora_indices]
-            bias[self.punica_wrapper.token_lora_indices == -1] = 0
-            output += bias
-
         self.punica_wrapper.add_expand_slice(output, buffer,
-                                             self.lora_b_stacked, start_idx,
+                                             self.lora_b_stacked,
+                                             self.bias_stacked, start_idx,
                                              shard_size)
         output = output.view(*out_orig_shape)
         return output
@@ -67,63 +67,6 @@ def _not_fully_sharded_can_replace(can_replace):
     return dec


-def apply_bias(
-    indices: torch.Tensor,
-    output: torch.Tensor,
-    bias_stacked: torch.Tensor,
-):
-    """Applies bias to output
-
-    Input shapes:
-        bias_stacked: (num_loras, output_dim)
-        indices: (batch_size)
-        output: (batch_size, output_dim)
-    """
-    org_output = output
-    output = output.view(-1, output.shape[-1])
-    indices = indices.view(-1)
-
-    bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1])
-    bias_stacked = bias_stacked[indices]
-    bias_stacked[indices == -1] = 0
-    output += bias_stacked
-
-    return output.view_as(org_output)
-
-
-def apply_bias_packed_nslice(
-    indices: torch.Tensor,
-    output: torch.Tensor,
-    output_slices: Tuple[int, ...],
-    bias_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
-):
-    """Applies bias to output
-
-    Input shapes:
-        bias_stacked: 3 element tuple of (num_loras, output_dim)
-        indices: (batch_size)
-        output: (batch_size, q_slice_size + 2*kv_slice_size)
-        output_slices: n-1 element tuple of (slice_size...),
-                       where n is number of slices
-    """
-    org_output = output
-    output = output.view(-1, output.shape[-1])
-    indices = indices.view(-1)
-
-    offset_left = 0
-    for slice_idx, slice in enumerate(output_slices):
-        bias = bias_stacked[slice_idx]
-        if bias is not None:
-            bias = bias.view(-1, bias.shape[-1])
-            bias = bias[indices]
-            bias[indices == -1] = 0
-            output[:, offset_left:offset_left + slice] += bias
-
-        offset_left += slice
-
-    return output.view_as(org_output)
-
-
 @dataclass
 class LoRAMapping(AdapterMapping):
     is_prefill: bool = False
@@ -311,6 +254,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
         self.punica_wrapper.add_expand(full_output,
                                        full_lora_a_embeddings,
                                        self.lora_b_stacked,
+                                       bias_all=None,
                                        add_input=True)
         return full_output.view_as(full_output_org)

@@ -399,15 +343,9 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
     def apply(self, x: torch.Tensor,
               bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
-        if self.bias_stacked is not None:
-            self.indices = self.punica_wrapper.token_lora_indices
-            output = apply_bias(
-                self.indices,
-                output,
-                self.bias_stacked,
-            )
         self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
-                                     self.lora_b_stacked, 1.0)
+                                     self.lora_b_stacked, self.bias_stacked,
+                                     1.0)
         return output

     def forward(self, input_):
@@ -576,15 +514,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
     def apply(self, x: torch.Tensor,
               bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
-        if self.bias_stacked is not None:
-            self.indices = self.punica_wrapper.token_lora_indices
-            output = apply_bias(
-                self.indices,
-                output,
-                self.bias_stacked,
-            )
         self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
-                                     self.lora_b_stacked, 1.0)
+                                     self.lora_b_stacked, self.bias_stacked,
+                                     1.0)
         return output

     def forward(self, input_):
@@ -687,8 +619,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             ) for _ in range(n_slices))
         else:
             self.bias_stacked = None

         self.output_dim = self.lora_b_stacked[0].shape[2]
+        self.output_slices = (self.output_dim, self.output_dim)

     def reset_lora(self, index: int):
         self.lora_a_stacked[0][index] = 0
@@ -772,17 +704,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def apply(self, x: torch.Tensor,
               bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
-        if self.bias_stacked is not None:
-            self.indices = self.punica_wrapper.token_lora_indices
-            output = apply_bias_packed_nslice(
-                self.indices,
-                output,
-                (self.output_dim, self.output_dim),
-                self.bias_stacked,
-            )
         self.punica_wrapper.add_lora_packed_nslice(
-            output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0,
-            (self.output_dim, self.output_dim))
+            output, x, self.lora_a_stacked, self.lora_b_stacked,
+            self.bias_stacked, 1.0, (self.output_dim, self.output_dim))
         return output

     @classmethod
@@ -1129,17 +1053,10 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
     def apply(self, x: torch.Tensor,
               bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
-        if self.bias_stacked is not None:
-            self.indices = self.punica_wrapper.token_lora_indices
-            output = apply_bias_packed_nslice(
-                self.indices,
-                output,
-                self.output_slices,
-                self.bias_stacked,
-            )
         self.punica_wrapper.add_lora_packed_nslice(output, x,
                                                    self.lora_a_stacked,
-                                                   self.lora_b_stacked, 1.0,
+                                                   self.lora_b_stacked,
+                                                   self.bias_stacked, 1.0,
                                                    self.output_slices)
         return output

@@ -1264,15 +1181,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):

     def apply(self, x: torch.Tensor) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x)
-        if self.bias_stacked is not None:
-            self.indices = self.punica_wrapper.token_lora_indices
-            output = apply_bias(
-                self.indices,
-                output,
-                self.bias_stacked,
-            )
         self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
-                                     self.lora_b_stacked, 1.0)
+                                     self.lora_b_stacked, self.bias_stacked,
+                                     1.0)
         return output

     def forward(self, input_):
@@ -450,6 +450,62 @@ class PunicaWrapper:
         bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
                           y_slice_size, add_input)

+    def apply_bias(
+        self,
+        indices: torch.Tensor,
+        output: torch.Tensor,
+        bias_stacked: torch.Tensor,
+    ):
+        """Applies bias to output
+
+        Input shapes:
+            bias_stacked: (num_loras, output_dim)
+            indices: (batch_size)
+            output: (batch_size, output_dim)
+        """
+        org_output = output
+        output = output.view(-1, output.shape[-1])
+        indices = indices.view(-1)
+
+        bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1])
+        bias_stacked = bias_stacked[indices]
+        bias_stacked[indices == -1] = 0
+        output += bias_stacked
+
+        return output.view_as(org_output)
+
+    def apply_bias_packed_nslice(
+        self,
+        indices: torch.Tensor,
+        output: torch.Tensor,
+        output_slices: Tuple[int, ...],
+        bias_stacked: Tuple[Optional[torch.Tensor], ...],
+    ):
+        """Applies bias to output
+
+        Input shapes:
+            bias_stacked: 3 element tuple of (num_loras, output_dim)
+            indices: (batch_size)
+            output: (batch_size, q_slice_size + 2*kv_slice_size)
+            output_slices: n-1 element tuple of (slice_size...),
+                           where n is number of slices
+        """
+        org_output = output
+        output = output.view(-1, output.shape[-1])
+        indices = indices.view(-1)
+
+        offset_left = 0
+        for slice_idx, slice in enumerate(output_slices):
+            bias = bias_stacked[slice_idx]
+            if bias is not None:
+                bias = bias.view(-1, bias.shape[-1])
+                bias = bias[indices]
+                bias[indices == -1] = 0
+                output[:, offset_left:offset_left + slice] += bias
+            offset_left += slice
+
+        return output.view_as(org_output)
+
     def add_shrink(
         self,
         y: torch.Tensor,
@@ -474,16 +530,19 @@ class PunicaWrapper:
         y: torch.Tensor,
         x: torch.Tensor,
         w_t_all: torch.Tensor,
+        bias_all: Optional[torch.Tensor],
         add_input: bool = True,
     ):
         """
-        Perform the ` y+=x@w_t_all` computation, which is suitable for the
+        Perform the ` y+=x@w_t_all+bias` computation, which is suitable for the
         GEMM of lora'b.
         When `is_prefill` is true, it indicates that it is currently the
         prefill stage, and the `expand_prefill` function should be called.
         Otherwise, it is the decode stage, and the expand_decode function
         should be called.
         """
+        if bias_all is not None:
+            y = self.apply_bias(self.token_lora_indices, y, bias_all)

         expand_fun: Callable = (self.expand_prefill
                                 if self.is_prefill else self.expand_decode)
@@ -493,23 +552,54 @@ class PunicaWrapper:
                          y: torch.Tensor,
                          x: torch.Tensor,
                          w_t_all: torch.Tensor,
+                         bias_all: Optional[torch.Tensor],
                          y_offset: Optional[int],
                          y_slice_size: Optional[int],
                          add_input: bool = True):
         """
         Similar to `add_expand`
         """
+        if bias_all is not None:
+            y = self.apply_bias(self.token_lora_indices, y, bias_all)

         expand_slice_fun: Callable = (self.expand_slice_prefill
                                       if self.is_prefill else
                                       self.expand_slice_decode)
         expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)

+    def add_expand_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
+                                 lora_b_stacked: Tuple[torch.Tensor, ...],
+                                 bias_stacked: Optional[Tuple[torch.Tensor,
+                                                              ...]],
+                                 scale: float,
+                                 output_slices: Tuple[int, ...]) -> None:
+        """
+        Similar to `add_expand`
+        """
+        y_org = y
+        y = y.view(-1, y.shape[-1])
+        offset_left = 0
+        if bias_stacked is not None:
+            self.apply_bias_packed_nslice(self.token_lora_indices, y,
+                                          output_slices, bias_stacked)
+        for slice_idx in range(len(lora_b_stacked)):
+            self.add_expand_slice(y,
+                                  x[slice_idx],
+                                  lora_b_stacked[slice_idx],
+                                  None,
+                                  offset_left,
+                                  output_slices[slice_idx],
+                                  add_input=True)
+            offset_left += output_slices[slice_idx]
+
+        y = y.view_as(y_org)
+
     def add_lora(self,
                  y: torch.Tensor,
                  x: torch.Tensor,
                  wa_t_all: torch.Tensor,
                  wb_t_all: torch.Tensor,
+                 bias_all: Optional[torch.Tensor],
                  scale: float,
                  y_offset: Optional[int] = None,
                  y_slice_size: Optional[int] = None,
@@ -522,12 +612,13 @@ class PunicaWrapper:
                 @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
                 @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
                 * scale
-                ).squeeze(0)
+                ).squeeze(0)+bias[i]
         Args:
             y (torch.Tensor): Output tensor. Will be changed in-place.
             x (torch.Tensor): Input tensor
             wa_t_all (torch.Tensor): lora_a's weight
             wb_t_all (torch.Tensor): lora_b's weight
+            bias_all: (torch.Tensor): lora's bias
             scale (float): Scaling factor.
             y_offset (Optional[int], optional): Offset to apply to the starting
                 column of y.
@@ -544,27 +635,26 @@ class PunicaWrapper:
         buffer = torch.zeros((x.size(0), r),
                              dtype=torch.float32,
                              device=x.device)
+        if bias_all is not None:
+            y = self.apply_bias(self.token_lora_indices, y, bias_all)
         self.add_shrink(buffer, x, wa_t_all, scale)
         if y_offset is None and y_slice_size is None:
-            self.add_expand(y, buffer, wb_t_all, add_input=True)
+            self.add_expand(y, buffer, wb_t_all, bias_all=None, add_input=True)
         else:
             self.add_expand_slice(y,
                                   buffer,
                                   wb_t_all,
+                                  None,
                                   y_offset,
                                   y_slice_size,
                                   add_input=True)
         y = y.view_as(y_org)

     def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
-                               lora_a_stacked: Tuple[torch.Tensor,
-                                                     torch.Tensor,
-                                                     torch.Tensor],
-                               lora_b_stacked: Tuple[torch.Tensor,
-                                                     torch.Tensor,
-                                                     torch.Tensor],
-                               scale: float,
+                               lora_a_stacked: Tuple[torch.Tensor, ...],
+                               lora_b_stacked: Tuple[torch.Tensor, ...],
+                               bias_all: Tuple[Optional[torch.Tensor],
+                                               ...], scale: float,
                                output_slices: Tuple[int, ...]) -> None:
         """
         Applies lora to each input. Similar to add_lora, This method is
@@ -575,10 +665,13 @@ class PunicaWrapper:
         x = x.view(-1, x.shape[-1])
         y = y.view(-1, y.shape[-1])
         offset_left = 0
+        if bias_all is not None:
+            y = self.apply_bias_packed_nslice(self.token_lora_indices, y,
+                                              output_slices, bias_all)
         # TODO fuse these kernels
         for slice_idx in range(len(output_slices)):
             self.add_lora(y, x, lora_a_stacked[slice_idx],
-                          lora_b_stacked[slice_idx], scale, offset_left,
+                          lora_b_stacked[slice_idx], None, scale, offset_left,
                           output_slices[slice_idx])
             offset_left += output_slices[slice_idx]

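For reference, the bias path exercised above is enabled purely through engine flags; the new test builds the engine roughly as follows (the model name stands in for the test's MODEL_PATH, and tensor_parallel_size=4 mirrors its 4-GPU setup):

    import vllm

    llm = vllm.LLM("meta-llama/Llama-2-7b-hf",  # placeholder for MODEL_PATH
                   enable_lora=True,
                   max_num_seqs=16,
                   max_loras=4,
                   tensor_parallel_size=4,  # the new test runs on 4 GPUs
                   fully_sharded_loras=True,
                   enable_lora_bias=True)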