[Misc] Delete unused LoRA modules (#13151)

commit 82cabf53a3 (parent 314cfade02)
Author: Jee Jee Li
Date: 2025-02-13 00:58:24 +08:00
Committed by: GitHub
3 changed files with 20 additions and 8 deletions

File: tests/lora/test_lora_manager.py

@@ -606,20 +606,26 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
     assert isinstance(model.get_submodule("gate_up_proj"),
                       MergedColumnParallelLinearWithLoRA)
+    # Verify packed lora is correct
+    model_lora_clone = model_lora.clone(1)
+    model_lora_clone1 = model_lora1.clone(1)
     assert manager.add_adapter(model_lora)
     assert manager.add_adapter(model_lora1)
+    assert model_lora.get_lora("gate_proj") is None
+    assert model_lora.get_lora("up_proj") is None
+    assert model_lora1.get_lora("up_proj") is None
     packed_lora = model_lora.get_lora("gate_up_proj")
     assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)
     torch.testing.assert_close(packed_lora.lora_a[0],
-                               model_lora.get_lora("gate_proj").lora_a)
+                               model_lora_clone.get_lora("gate_proj").lora_a)
     torch.testing.assert_close(packed_lora.lora_b[0],
-                               model_lora.get_lora("gate_proj").lora_b)
+                               model_lora_clone.get_lora("gate_proj").lora_b)
     torch.testing.assert_close(packed_lora.lora_a[1],
-                               model_lora.get_lora("up_proj").lora_a)
+                               model_lora_clone.get_lora("up_proj").lora_a)
     torch.testing.assert_close(packed_lora.lora_b[1],
-                               model_lora.get_lora("up_proj").lora_b)
+                               model_lora_clone.get_lora("up_proj").lora_b)
     packed_lora1 = model_lora1.get_lora("gate_up_proj")
     assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)

@@ -627,6 +633,6 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
     assert packed_lora1.lora_a[0] is None
     assert packed_lora1.lora_b[0] is None
     torch.testing.assert_close(packed_lora1.lora_a[1],
-                               model_lora1.get_lora("up_proj").lora_a)
+                               model_lora_clone1.get_lora("up_proj").lora_a)
     torch.testing.assert_close(packed_lora1.lora_b[1],
-                               model_lora1.get_lora("up_proj").lora_b)
+                               model_lora_clone1.get_lora("up_proj").lora_b)
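
Why the test changes: once gate_proj and up_proj are packed into gate_up_proj, the manager now also removes the per-projection LoRAs from the adapter, so the old comparisons against model_lora.get_lora("gate_proj") would dereference None. The test therefore snapshots the adapter with clone(1) before add_adapter and compares the packed weights against that clone. Below is a minimal, self-contained sketch of the same check pattern using plain dicts and tensors rather than vLLM's LoRAModel/LoRAModelManager; every name in it is illustrative, not the real API.

import torch

# Hypothetical stand-in for a LoRA adapter: module name -> (lora_a, lora_b).
adapter = {
    "gate_proj": (torch.randn(16, 8), torch.randn(8, 32)),
    "up_proj": (torch.randn(16, 8), torch.randn(8, 32)),
}

# Snapshot taken before merging, mirroring model_lora.clone(1) in the test.
snapshot = {name: (a.clone(), b.clone()) for name, (a, b) in adapter.items()}

# Pack the two sub-module LoRAs into one fused entry and delete the originals,
# which is roughly what the manager now does while adding the adapter.
adapter["gate_up_proj"] = [adapter.pop("gate_proj"), adapter.pop("up_proj")]

# The per-projection entries are gone from the adapter itself ...
assert "gate_proj" not in adapter and "up_proj" not in adapter

# ... so the packed weights have to be checked against the earlier snapshot.
packed = adapter["gate_up_proj"]
torch.testing.assert_close(packed[0][0], snapshot["gate_proj"][0])  # gate_proj lora_a
torch.testing.assert_close(packed[1][1], snapshot["up_proj"][1])    # up_proj lora_b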

File: vllm/lora/models.py

@@ -5,7 +5,8 @@ import math
 import os
 import re
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
+from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Type,
+                    Union)

 import safetensors.torch
 import torch

@@ -619,12 +620,14 @@ class LoRAModelManager(AdapterModelManager):
     def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
         for module_name, new_module_names in self.packed_modules.items():
             replacement_loras: List[Optional[LoRALayerWeights]] = []
+            replaced_module: Set[str] = set()
             has_replacement = False
             for r in new_module_names:
                 lora = lora_model.get_lora(r)
                 replacement_loras.append(lora)
                 if lora:
                     has_replacement = True
+                    replaced_module.add(r)
             if not has_replacement:
                 continue
             for i in range(len(replacement_loras)):

@@ -633,6 +636,9 @@ class LoRAModelManager(AdapterModelManager):
                 replacement_loras[i] = None
             lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                 replacement_loras)
+            # Remove the modules that have been replaced.
+            for module in replaced_module:
+                lora_model.loras.pop(module, None)

     def deactivate_adapter(self, adapter_id: int) -> bool:
         return deactivate_adapter(adapter_id, self._active_adapters,
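
The substantive change is in _create_merged_loras_inplace: it already packed the sub-module LoRAs (e.g. gate_proj and up_proj) into a single entry for the fused module, but it left the now-unused per-module entries sitting in lora_model.loras. The new replaced_module set records which sub-modules actually contributed a LoRA, and those entries are popped once the packed weights are stored; that is the "unused LoRA modules" the commit title refers to. Here is a rough sketch of the same logic over a plain dict, with hypothetical names and a stand-in pack() instead of vLLM's PackedLoRALayerWeights.pack:

from typing import Any, Dict, List, Optional, Set

# Stand-in for PackedLoRALayerWeights.pack: here it just keeps the list.
def pack(sub_loras: List[Optional[str]]) -> List[Optional[str]]:
    return sub_loras

def create_merged_loras_inplace(loras: Dict[str, Any],
                                packed_modules: Dict[str, List[str]]) -> None:
    """Pack sub-module LoRAs into their fused module and drop the originals."""
    for module_name, sub_names in packed_modules.items():
        replacement: List[Optional[str]] = []
        replaced: Set[str] = set()  # sub-modules that contributed a LoRA
        has_replacement = False
        for name in sub_names:
            lora = loras.get(name)
            replacement.append(lora)
            if lora is not None:
                has_replacement = True
                replaced.add(name)
        if not has_replacement:
            continue
        loras[module_name] = pack(replacement)
        # Behaviour added by this commit: remove the per-sub-module entries
        # once their weights live in the packed entry.
        for name in replaced:
            loras.pop(name, None)

loras = {"gate_proj": "A", "up_proj": "B", "o_proj": "C"}
create_merged_loras_inplace(loras, {"gate_up_proj": ["gate_proj", "up_proj"]})
assert set(loras) == {"gate_up_proj", "o_proj"}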

File: vllm/lora/punica_wrapper/punica_base.py

@@ -147,7 +147,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
                                              dtype=torch.long,
                                              device=device)
-        # 5 is the number of indicies tensors.
+        # 5 is the number of indices tensors.
         # base_indices, sampler_indices, sampler_indices_padded,
         # embeddings_indices,long_lora_indices
         self.indices_len: List[Optional[int]] = [None] * 5