[Misc] Delete unused LoRA modules (#13151)
parent 314cfade02
commit 82cabf53a3
@@ -606,20 +606,26 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
     assert isinstance(model.get_submodule("gate_up_proj"),
                       MergedColumnParallelLinearWithLoRA)
     # Verify packed lora is correct
+    model_lora_clone = model_lora.clone(1)
+    model_lora_clone1 = model_lora1.clone(1)
     assert manager.add_adapter(model_lora)
     assert manager.add_adapter(model_lora1)

+    assert model_lora.get_lora("gate_proj") is None
+    assert model_lora.get_lora("up_proj") is None
+    assert model_lora1.get_lora("up_proj") is None
+
     packed_lora = model_lora.get_lora("gate_up_proj")
     assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)

     torch.testing.assert_close(packed_lora.lora_a[0],
-                               model_lora.get_lora("gate_proj").lora_a)
+                               model_lora_clone.get_lora("gate_proj").lora_a)
     torch.testing.assert_close(packed_lora.lora_b[0],
-                               model_lora.get_lora("gate_proj").lora_b)
+                               model_lora_clone.get_lora("gate_proj").lora_b)
     torch.testing.assert_close(packed_lora.lora_a[1],
-                               model_lora.get_lora("up_proj").lora_a)
+                               model_lora_clone.get_lora("up_proj").lora_a)
     torch.testing.assert_close(packed_lora.lora_b[1],
-                               model_lora.get_lora("up_proj").lora_b)
+                               model_lora_clone.get_lora("up_proj").lora_b)

     packed_lora1 = model_lora1.get_lora("gate_up_proj")
     assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)
@@ -627,6 +633,6 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
     assert packed_lora1.lora_a[0] is None
     assert packed_lora1.lora_b[0] is None
     torch.testing.assert_close(packed_lora1.lora_a[1],
-                               model_lora1.get_lora("up_proj").lora_a)
+                               model_lora_clone1.get_lora("up_proj").lora_a)
     torch.testing.assert_close(packed_lora1.lora_b[1],
-                               model_lora1.get_lora("up_proj").lora_b)
+                               model_lora_clone1.get_lora("up_proj").lora_b)
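The clone-before-add pattern above is needed because, after this change, registering the adapter packs gate_proj/up_proj into gate_up_proj and deletes the per-module LoRAs from the original LoRAModel, so the clone is the only place the unpacked weights survive for comparison. Below is a standalone sketch of that behaviour using plain dicts and torch tensors rather than the vLLM classes; the names loras and clone are illustrative stand-ins, not code from the repository.

import copy

import torch

# Toy stand-in for a LoRAModel's per-module weights (not the vLLM API).
loras = {
    "gate_proj": {"lora_a": torch.randn(8, 16), "lora_b": torch.randn(16, 4)},
    "up_proj": {"lora_a": torch.randn(8, 16), "lora_b": torch.randn(16, 4)},
}
clone = copy.deepcopy(loras)  # analogue of model_lora.clone(1)

# Pack the two sub-module LoRAs into a single "gate_up_proj" entry ...
loras["gate_up_proj"] = {
    "lora_a": [loras["gate_proj"]["lora_a"], loras["up_proj"]["lora_a"]],
    "lora_b": [loras["gate_proj"]["lora_b"], loras["up_proj"]["lora_b"]],
}
# ... and delete the now-unused originals, as the manager does after this commit.
for name in ("gate_proj", "up_proj"):
    loras.pop(name, None)

assert loras.get("gate_proj") is None  # mirrors the new is-None asserts
torch.testing.assert_close(loras["gate_up_proj"]["lora_a"][0],
                           clone["gate_proj"]["lora_a"])  # compare via the clone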
@@ -5,7 +5,8 @@ import math
 import os
 import re
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
+from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Type,
+                    Union)

 import safetensors.torch
 import torch
@@ -619,12 +620,14 @@ class LoRAModelManager(AdapterModelManager):
     def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
         for module_name, new_module_names in self.packed_modules.items():
             replacement_loras: List[Optional[LoRALayerWeights]] = []
+            replaced_module: Set[str] = set()
             has_replacement = False
             for r in new_module_names:
                 lora = lora_model.get_lora(r)
                 replacement_loras.append(lora)
                 if lora:
                     has_replacement = True
+                    replaced_module.add(r)
             if not has_replacement:
                 continue
             for i in range(len(replacement_loras)):
@@ -633,6 +636,9 @@ class LoRAModelManager(AdapterModelManager):
                 replacement_loras[i] = None
             lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                 replacement_loras)
+            # Remove the modules that have been replaced.
+            for module in replaced_module:
+                lora_model.loras.pop(module, None)

     def deactivate_adapter(self, adapter_id: int) -> bool:
         return deactivate_adapter(adapter_id, self._active_adapters,
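For readers skimming the hunk above: the new replaced_module set records which sub-module LoRAs were merged into the packed entry so they can be dropped from the model afterwards. A simplified, dict-based sketch of that flow follows; merge_packed and its arguments are stand-ins for exposition, not the real LoRAModelManager or PackedLoRALayerWeights.

from typing import Dict, List, Optional, Set


def merge_packed(loras: Dict[str, object],
                 packed_modules: Dict[str, List[str]]) -> None:
    for packed_name, sub_names in packed_modules.items():
        replacements: List[Optional[object]] = []
        replaced: Set[str] = set()
        has_replacement = False
        for name in sub_names:
            lora = loras.get(name)
            replacements.append(lora)
            if lora is not None:
                has_replacement = True
                replaced.add(name)
        if not has_replacement:
            continue
        # Stand-in for PackedLoRALayerWeights.pack(replacements).
        loras[packed_name] = tuple(replacements)
        # Remove the sub-module LoRAs that are now packed.
        for name in replaced:
            loras.pop(name, None)


loras = {"gate_proj": "A", "up_proj": "B", "q_proj": "C"}
merge_packed(loras, {"gate_up_proj": ["gate_proj", "up_proj"]})
assert "gate_proj" not in loras and loras["gate_up_proj"] == ("A", "B")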
@@ -147,7 +147,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
                                              dtype=torch.long,
                                              device=device)

-        # 5 is the number of indicies tensors.
+        # 5 is the number of indices tensors.
         # base_indices, sampler_indices, sampler_indices_padded,
         # embeddings_indices,long_lora_indices
         self.indices_len: List[Optional[int]] = [None] * 5
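As the corrected comment says, indices_len holds one length per indices tensor, in the order listed in the comment. A small, hedged illustration of that bookkeeping (the tensor names come from the comment above; the tuple and variable below are only for exposition, not code from the repository):

from typing import List, Optional

# The five indices tensors named in the comment, in order.
INDICES_TENSORS = (
    "base_indices",
    "sampler_indices",
    "sampler_indices_padded",
    "embeddings_indices",
    "long_lora_indices",
)
# One length slot per tensor, filled in lazily; 5 == len(INDICES_TENSORS).
indices_len: List[Optional[int]] = [None] * len(INDICES_TENSORS)
assert len(indices_len) == 5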