[Misc][LoRA] Support loading LoRA weights for target_modules in reg format (#9275)

Jee Jee Li 2024-10-11 20:31:21 +08:00 committed by GitHub
parent e808156f30
commit 36ea79079b
4 changed files with 59 additions and 5 deletions

tests/lora/conftest.py

@@ -199,6 +199,11 @@ def baichuan_zero_lora_files():
     return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")
 
 
+@pytest.fixture(scope="session")
+def baichuan_regex_lora_files():
+    return snapshot_download(repo_id="jeeejeee/baichuan-7b-lora-zero-regex")
+
+
 @pytest.fixture(scope="session")
 def minicpmv_lora_files():
     return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")

tests/lora/test_lora_checkpoints.py

@@ -5,7 +5,9 @@ import pytest
 from vllm.lora.models import LoRAModel
 from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
 
-lora_lst = ["baichuan7B", "baichuan7B-zero", "chatglm3-6b"]
+lora_lst = [
+    "baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"
+]
 
 
 @pytest.mark.parametrize("lora_name", lora_lst)
@@ -13,6 +15,7 @@ def test_load_checkpoints(
     lora_name,
     baichuan_lora_files,
     baichuan_zero_lora_files,
+    baichuan_regex_lora_files,
     chatglm3_lora_files,
 ):
     supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
@@ -36,7 +39,7 @@
             embedding_modules=embedding_modules,
             embedding_padding_modules=embed_padding_modules)
     elif lora_name == "baichuan7B-zero":
-        #Test that the target_modules contain prefix
+        # Test that the target_modules contain prefix
         # such as "model.layers.0.self_atten.W_pack", and
         # the test should pass.
         LoRAModel.from_local_checkpoint(
@@ -46,6 +49,16 @@
             device="cpu",
             embedding_modules=embedding_modules,
             embedding_padding_modules=embed_padding_modules)
+    elif lora_name == "baichuan7B-zero-regex":
+        # Test that the `target_modules` in the form of regular expressions,
+        # such as `model\\..*(W_pack|o_proj)`, and the test should pass.
+        LoRAModel.from_local_checkpoint(
+            baichuan_regex_lora_files,
+            expected_lora_modules,
+            lora_model_id=1,
+            device="cpu",
+            embedding_modules=embedding_modules,
+            embedding_padding_modules=embed_padding_modules)
     else:
         # For the baichuan7B model, load chatglm3-6b's LoRA,
         # and the test should raise the following error.
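With this change, such an adapter can be served like any other LoRA. A hypothetical end-to-end usage sketch (the adapter path and prompt are illustrative; enable_lora and LoRARequest are the existing vLLM LoRA APIs):

# Sketch: loading a regex-style LoRA adapter at inference time.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="baichuan-inc/Baichuan-7B",
          enable_lora=True,
          trust_remote_code=True)
outputs = llm.generate(
    ["Give me a short introduction to large language models."],
    SamplingParams(temperature=0.0, max_tokens=64),
    # Third argument is the local path of the downloaded adapter.
    lora_request=LoRARequest("baichuan-regex-lora", 1,
                             "/path/to/baichuan-7b-lora-zero-regex"),
)
print(outputs[0].outputs[0].text)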

vllm/lora/models.py

@@ -23,6 +23,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.punica import PunicaWrapper
 from vllm.lora.utils import (from_layer, from_layer_logits_processor,
+                             is_regex_target_modules,
                              parse_fine_tuned_lora_name, replace_submodule)
 from vllm.model_executor.models import SupportsLoRA, supports_multimodal
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -233,6 +234,8 @@ class LoRAModel(AdapterModel):
             # modules.
             unexpected_modules = []
             target_modules = config["target_modules"]
+            if not isinstance(target_modules, list):
+                target_modules = [target_modules]
             for module in target_modules:
                 # Compatible with more modules,
                 # such as:layers.11.self_attn.k_proj
@@ -243,8 +246,8 @@ class LoRAModel(AdapterModel):
         # expected_lora_modules. It is not reliable. See
         # https://github.com/vllm-project/vllm/pull/5909. But there's no
         # other better mechanism.
-        if unexpected_modules:
-            print(unexpected_modules, "modules")
+        if unexpected_modules and not is_regex_target_modules(
+                config["target_modules"], expected_lora_modules):
             raise ValueError(
                 f"While loading {lora_dir}, expected"
                 f" target modules in {expected_lora_modules}"

vllm/lora/utils.py

@@ -1,5 +1,6 @@
 import os
-from typing import List, Optional, Set, Tuple, Type
+import re
+from typing import List, Optional, Set, Tuple, Type, Union
 
 import huggingface_hub
 from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
@@ -113,6 +114,38 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
     raise ValueError(f"{name} is unsupported LoRA weight")
 
 
+def is_regex_target_modules(load_modules: Union[str, List[str]],
+                            expected_lora_modules: List[str]) -> bool:
+    """
+    PEFT supports passing `target_modules` in the form of regular expressions,
+    such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to
+    determine whether the suffix in the regular expression is present in the
+    `expected_lora_modules`.
+    """
+
+    def is_valid_regex(pattern):
+        try:
+            re.compile(pattern)
+            return True
+        except re.error:
+            return False
+
+    def is_subset(sub_list, full_list):
+        return set(sub_list).issubset(set(full_list))
+
+    # Similar to PEFT's processing logic, regex-related operations are only
+    # executed when the load_modules is a `str`.
+    if not isinstance(load_modules, str):
+        return False
+
+    if is_valid_regex(load_modules):
+        match = re.search(r"\((.*?)\)\$?$", load_modules)
+        if match:
+            suffix = match.group(1).split("|")
+            return is_subset(suffix, expected_lora_modules)
+    return False
+
+
 def get_adapter_absolute_path(lora_path: str) -> str:
     """
     Resolves the given lora_path to an absolute local path.
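A quick behavioral sketch of the new helper, using the regex from the test above (the expected list is illustrative of Baichuan's supported LoRA modules, not copied from the PR):

from vllm.lora.utils import is_regex_target_modules

expected = ["W_pack", "o_proj", "gate_up_proj", "down_proj"]  # illustrative

# Regex whose alternation suffixes are all supported -> treated as valid.
assert is_regex_target_modules(r"model\..*(W_pack|o_proj)", expected)

# A plain list of module names is handled by the normal path, so the helper
# simply returns False.
assert not is_regex_target_modules(["W_pack", "o_proj"], expected)

# Regex mentioning a suffix the model does not support -> not accepted.
assert not is_regex_target_modules(r"model\..*(W_pack|lm_head)", expected)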