[Misc][LoRA] Support loading LoRA weights for target_modules in reg format (#9275)
This commit is contained in:
parent
e808156f30
commit
36ea79079b
@ -199,6 +199,11 @@ def baichuan_zero_lora_files():
|
||||
return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def baichuan_regex_lora_files():
|
||||
return snapshot_download(repo_id="jeeejeee/baichuan-7b-lora-zero-regex")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def minicpmv_lora_files():
|
||||
return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")
|
||||
|
@ -5,7 +5,9 @@ import pytest
|
||||
from vllm.lora.models import LoRAModel
|
||||
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
|
||||
|
||||
lora_lst = ["baichuan7B", "baichuan7B-zero", "chatglm3-6b"]
|
||||
lora_lst = [
|
||||
"baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("lora_name", lora_lst)
|
||||
@ -13,6 +15,7 @@ def test_load_checkpoints(
|
||||
lora_name,
|
||||
baichuan_lora_files,
|
||||
baichuan_zero_lora_files,
|
||||
baichuan_regex_lora_files,
|
||||
chatglm3_lora_files,
|
||||
):
|
||||
supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
|
||||
@ -36,7 +39,7 @@ def test_load_checkpoints(
|
||||
embedding_modules=embedding_modules,
|
||||
embedding_padding_modules=embed_padding_modules)
|
||||
elif lora_name == "baichuan7B-zero":
|
||||
#Test that the target_modules contain prefix
|
||||
# Test that the target_modules contain prefix
|
||||
# such as "model.layers.0.self_atten.W_pack", and
|
||||
# the test should pass.
|
||||
LoRAModel.from_local_checkpoint(
|
||||
@ -46,6 +49,16 @@ def test_load_checkpoints(
|
||||
device="cpu",
|
||||
embedding_modules=embedding_modules,
|
||||
embedding_padding_modules=embed_padding_modules)
|
||||
elif lora_name == "baichuan7B-zero-regex":
|
||||
# Test that the `target_modules` in the form of regular expressions,
|
||||
# such as `model\\..*(W_pack|o_proj)`, and the test should pass.
|
||||
LoRAModel.from_local_checkpoint(
|
||||
baichuan_regex_lora_files,
|
||||
expected_lora_modules,
|
||||
lora_model_id=1,
|
||||
device="cpu",
|
||||
embedding_modules=embedding_modules,
|
||||
embedding_padding_modules=embed_padding_modules)
|
||||
else:
|
||||
# For the baichuan7B model, load chatglm3-6b's LoRA,
|
||||
# and the test should raise the following error.
|
||||
|
@ -23,6 +23,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
|
||||
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
||||
from vllm.lora.punica import PunicaWrapper
|
||||
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
|
||||
is_regex_target_modules,
|
||||
parse_fine_tuned_lora_name, replace_submodule)
|
||||
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
@ -233,6 +234,8 @@ class LoRAModel(AdapterModel):
|
||||
# modules.
|
||||
unexpected_modules = []
|
||||
target_modules = config["target_modules"]
|
||||
if not isinstance(target_modules, list):
|
||||
target_modules = [target_modules]
|
||||
for module in target_modules:
|
||||
# Compatible with more modules,
|
||||
# such as:layers.11.self_attn.k_proj
|
||||
@ -243,8 +246,8 @@ class LoRAModel(AdapterModel):
|
||||
# expected_lora_modules. It is not reliable. See
|
||||
# https://github.com/vllm-project/vllm/pull/5909. But there's no
|
||||
# other better mechanism.
|
||||
if unexpected_modules:
|
||||
print(unexpected_modules, "modules")
|
||||
if unexpected_modules and not is_regex_target_modules(
|
||||
config["target_modules"], expected_lora_modules):
|
||||
raise ValueError(
|
||||
f"While loading {lora_dir}, expected"
|
||||
f" target modules in {expected_lora_modules}"
|
||||
|
@ -1,5 +1,6 @@
|
||||
import os
|
||||
from typing import List, Optional, Set, Tuple, Type
|
||||
import re
|
||||
from typing import List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import huggingface_hub
|
||||
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
|
||||
@ -113,6 +114,38 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
|
||||
raise ValueError(f"{name} is unsupported LoRA weight")
|
||||
|
||||
|
||||
def is_regex_target_modules(load_modules: Union[str, List[str]],
|
||||
expected_lora_modules: List[str]) -> bool:
|
||||
"""
|
||||
PEFT supports passing `target_modules` in the form of regular expressions,
|
||||
such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to
|
||||
determine whether the suffix in the regular expression is present in the
|
||||
`expected_lora_modules`.
|
||||
"""
|
||||
|
||||
def is_valid_regex(pattern):
|
||||
try:
|
||||
re.compile(pattern)
|
||||
return True
|
||||
except re.error:
|
||||
return False
|
||||
|
||||
def is_subset(sub_list, full_list):
|
||||
return set(sub_list).issubset(set(full_list))
|
||||
|
||||
# Similar to PEFT's processing logic, regex-related operations are only
|
||||
# executed when the load_modules is a `str`.
|
||||
if not isinstance(load_modules, str):
|
||||
return False
|
||||
|
||||
if is_valid_regex(load_modules):
|
||||
match = re.search(r"\((.*?)\)\$?$", load_modules)
|
||||
if match:
|
||||
suffix = match.group(1).split("|")
|
||||
return is_subset(suffix, expected_lora_modules)
|
||||
return False
|
||||
|
||||
|
||||
def get_adapter_absolute_path(lora_path: str) -> str:
|
||||
"""
|
||||
Resolves the given lora_path to an absolute local path.
|
||||
|
Loading…
x
Reference in New Issue
Block a user