[Misc] Delete unused LoRA modules (#13151)

parent 314cfade02
commit 82cabf53a3
@@ -606,20 +606,26 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
     assert isinstance(model.get_submodule("gate_up_proj"),
                       MergedColumnParallelLinearWithLoRA)
+    # Verify packed lora is correct
+    model_lora_clone = model_lora.clone(1)
+    model_lora_clone1 = model_lora1.clone(1)
     assert manager.add_adapter(model_lora)
     assert manager.add_adapter(model_lora1)
 
+    assert model_lora.get_lora("gate_proj") is None
+    assert model_lora.get_lora("up_proj") is None
+    assert model_lora1.get_lora("up_proj") is None
     packed_lora = model_lora.get_lora("gate_up_proj")
     assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)
 
     torch.testing.assert_close(packed_lora.lora_a[0],
-                               model_lora.get_lora("gate_proj").lora_a)
+                               model_lora_clone.get_lora("gate_proj").lora_a)
     torch.testing.assert_close(packed_lora.lora_b[0],
-                               model_lora.get_lora("gate_proj").lora_b)
+                               model_lora_clone.get_lora("gate_proj").lora_b)
     torch.testing.assert_close(packed_lora.lora_a[1],
-                               model_lora.get_lora("up_proj").lora_a)
+                               model_lora_clone.get_lora("up_proj").lora_a)
     torch.testing.assert_close(packed_lora.lora_b[1],
-                               model_lora.get_lora("up_proj").lora_b)
+                               model_lora_clone.get_lora("up_proj").lora_b)
 
     packed_lora1 = model_lora1.get_lora("gate_up_proj")
     assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)
@@ -627,6 +633,6 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
     assert packed_lora1.lora_a[0] is None
     assert packed_lora1.lora_b[0] is None
     torch.testing.assert_close(packed_lora1.lora_a[1],
-                               model_lora1.get_lora("up_proj").lora_a)
+                               model_lora_clone1.get_lora("up_proj").lora_a)
     torch.testing.assert_close(packed_lora1.lora_b[1],
-                               model_lora1.get_lora("up_proj").lora_b)
+                               model_lora_clone1.get_lora("up_proj").lora_b)
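The clone calls above exist because, with this change, registering the adapter packs gate_proj and up_proj into gate_up_proj and drops the originals from the LoRAModel, so the test has to compare the packed weights against a pre-registration clone. A minimal, self-contained sketch of that interaction (the plain dict and the inline packing are toy stand-ins, not the vLLM API):

import copy

import torch

# Toy stand-in for LoRAModel.loras: module name -> (lora_a, lora_b).
loras = {
    "gate_proj": (torch.randn(16, 8), torch.randn(8, 16)),
    "up_proj": (torch.randn(16, 8), torch.randn(8, 16)),
}

# Snapshot before packing, mirroring model_lora.clone(1) in the test.
loras_clone = copy.deepcopy(loras)

# Pack the two projections into one entry and drop the originals, which is
# the shape of behavior the manager now has after add_adapter().
loras["gate_up_proj"] = (
    [loras["gate_proj"][0], loras["up_proj"][0]],
    [loras["gate_proj"][1], loras["up_proj"][1]],
)
for name in ("gate_proj", "up_proj"):
    loras.pop(name, None)

# The unpacked entries are gone, so checks go through the clone.
assert "gate_proj" not in loras and "up_proj" not in loras
torch.testing.assert_close(loras["gate_up_proj"][0][0],
                           loras_clone["gate_proj"][0])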
@@ -5,7 +5,8 @@ import math
 import os
 import re
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
+from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Type,
+                    Union)
 
 import safetensors.torch
 import torch
@@ -619,12 +620,14 @@ class LoRAModelManager(AdapterModelManager):
     def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
         for module_name, new_module_names in self.packed_modules.items():
             replacement_loras: List[Optional[LoRALayerWeights]] = []
+            replaced_module: Set[str] = set()
             has_replacement = False
             for r in new_module_names:
                 lora = lora_model.get_lora(r)
                 replacement_loras.append(lora)
                 if lora:
                     has_replacement = True
+                    replaced_module.add(r)
             if not has_replacement:
                 continue
             for i in range(len(replacement_loras)):
@@ -633,6 +636,9 @@ class LoRAModelManager(AdapterModelManager):
                 replacement_loras[i] = None
             lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                 replacement_loras)
+            # Remove the modules that have been replaced.
+            for module in replaced_module:
+                lora_model.loras.pop(module, None)
 
     def deactivate_adapter(self, adapter_id: int) -> bool:
         return deactivate_adapter(adapter_id, self._active_adapters,
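The new replaced_module set records which per-module LoRAs were folded into the packed entry so they can be popped from lora_model.loras afterwards; only the packed weights stay resident. A minimal sketch of this pack-then-prune pattern, using plain strings and a toy pack() in place of LoRALayerWeights and PackedLoRALayerWeights.pack():

from typing import List, Optional, Set

def pack(parts: List[Optional[str]]) -> List[Optional[str]]:
    # Toy stand-in: vLLM's PackedLoRALayerWeights.pack() merges real weights.
    return parts

loras = {"gate_proj": "A", "up_proj": "B", "down_proj": "C"}
packed_modules = {"gate_up_proj": ["gate_proj", "up_proj"]}

for module_name, new_module_names in packed_modules.items():
    replacement_loras: List[Optional[str]] = []
    replaced_module: Set[str] = set()
    has_replacement = False
    for r in new_module_names:
        lora = loras.get(r)
        replacement_loras.append(lora)
        if lora:
            has_replacement = True
            replaced_module.add(r)
    if not has_replacement:
        continue
    loras[module_name] = pack(replacement_loras)
    # The step this commit adds: drop the now-redundant sub-module entries.
    for module in replaced_module:
        loras.pop(module, None)

print(loras)  # {'down_proj': 'C', 'gate_up_proj': ['A', 'B']}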
@@ -147,7 +147,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
                                             dtype=torch.long,
                                             device=device)
 
-        # 5 is the number of indicies tensors.
+        # 5 is the number of indices tensors.
         # base_indices, sampler_indices, sampler_indices_padded,
         # embeddings_indices,long_lora_indices
         self.indices_len: List[Optional[int]] = [None] * 5
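Beyond the spelling fix, the comment documents that indices_len reserves one slot per index tensor: base_indices, sampler_indices, sampler_indices_padded, embeddings_indices and long_lora_indices. A small illustrative sketch of that bookkeeping (the tensor shapes and the valid-length update are assumptions for illustration, not the wrapper's actual update path):

from typing import List, Optional

import torch

# One preallocated index tensor per name in the comment (sizes are arbitrary).
index_tensors = {
    "base_indices": torch.zeros(8, dtype=torch.long),
    "sampler_indices": torch.zeros(8, dtype=torch.long),
    "sampler_indices_padded": torch.zeros(8, dtype=torch.long),
    "embeddings_indices": torch.zeros(2, 8, dtype=torch.long),
    "long_lora_indices": torch.zeros(8, dtype=torch.long),
}

# 5 slots, one per tensor above; None until a batch has been mapped.
indices_len: List[Optional[int]] = [None] * len(index_tensors)

# Assumed update: record how many leading entries of each tensor are valid
# for the current batch (pretend the batch fills 4 positions everywhere).
indices_len = [4] * len(index_tensors)
print(indices_len)  # [4, 4, 4, 4, 4]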