[Misc][LoRA] Support loading LoRA weights for target_modules in reg format (#9275)

parent: e808156f30
commit: 36ea79079b
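Background for the diff below: PEFT allows `target_modules` in an adapter's config to be either an explicit list of module names or a single regular-expression string (as the docstring added in vllm/lora/utils.py notes). vLLM's checkpoint validation previously assumed the list form, so adapters saved with a regex were rejected at load time. A minimal sketch of the two shapes — the regex value mirrors the one quoted in the tests below:

    # List form, already supported before this commit:
    target_modules = ["W_pack", "o_proj"]

    # Regex (str) form that PEFT also accepts and that this commit teaches
    # the LoRA loader to validate instead of rejecting outright:
    target_modules = "model\\..*(W_pack|o_proj)"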
--- a/tests/lora/conftest.py
+++ b/tests/lora/conftest.py
@@ -199,6 +199,11 @@ def baichuan_zero_lora_files():
     return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")
 
 
+@pytest.fixture(scope="session")
+def baichuan_regex_lora_files():
+    return snapshot_download(repo_id="jeeejeee/baichuan-7b-lora-zero-regex")
+
+
 @pytest.fixture(scope="session")
 def minicpmv_lora_files():
     return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")
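The adapter downloaded by the new fixture presumably stores its `target_modules` as a regex string. The excerpt below is a hypothetical reconstruction of that adapter's adapter_config.json, shown as a Python dict — the keys follow standard PEFT fields and the regex is inferred from the test comment later in this diff; none of it is a file included in this change:

    # Hypothetical adapter_config.json excerpt (assumed, for illustration):
    adapter_config = {
        "peft_type": "LORA",
        "target_modules": "model\\..*(W_pack|o_proj)",  # str regex, not a list
    }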
--- a/tests/lora/test_lora_checkpoints.py
+++ b/tests/lora/test_lora_checkpoints.py
@@ -5,7 +5,9 @@ import pytest
 from vllm.lora.models import LoRAModel
 from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
 
-lora_lst = ["baichuan7B", "baichuan7B-zero", "chatglm3-6b"]
+lora_lst = [
+    "baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"
+]
 
 
 @pytest.mark.parametrize("lora_name", lora_lst)
@@ -13,6 +15,7 @@ def test_load_checkpoints(
     lora_name,
     baichuan_lora_files,
     baichuan_zero_lora_files,
+    baichuan_regex_lora_files,
     chatglm3_lora_files,
 ):
     supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
@@ -36,7 +39,7 @@ def test_load_checkpoints(
             embedding_modules=embedding_modules,
             embedding_padding_modules=embed_padding_modules)
     elif lora_name == "baichuan7B-zero":
-        #Test that the target_modules contain prefix
+        # Test that the target_modules contain prefix
         # such as "model.layers.0.self_atten.W_pack", and
         # the test should pass.
         LoRAModel.from_local_checkpoint(
@@ -46,6 +49,16 @@ def test_load_checkpoints(
             device="cpu",
             embedding_modules=embedding_modules,
             embedding_padding_modules=embed_padding_modules)
+    elif lora_name == "baichuan7B-zero-regex":
+        # Test that `target_modules` can be given as a regular expression,
+        # such as `model\\..*(W_pack|o_proj)`; the test should pass.
+        LoRAModel.from_local_checkpoint(
+            baichuan_regex_lora_files,
+            expected_lora_modules,
+            lora_model_id=1,
+            device="cpu",
+            embedding_modules=embedding_modules,
+            embedding_padding_modules=embed_padding_modules)
     else:
         # For the baichuan7B model, load chatglm3-6b's LoRA,
         # and the test should raise the following error.
@ -23,6 +23,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
|
|||||||
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
||||||
from vllm.lora.punica import PunicaWrapper
|
from vllm.lora.punica import PunicaWrapper
|
||||||
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
|
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
|
||||||
|
is_regex_target_modules,
|
||||||
parse_fine_tuned_lora_name, replace_submodule)
|
parse_fine_tuned_lora_name, replace_submodule)
|
||||||
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
|
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
|
||||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||||
@@ -233,6 +234,8 @@ class LoRAModel(AdapterModel):
             # modules.
             unexpected_modules = []
             target_modules = config["target_modules"]
+            if not isinstance(target_modules, list):
+                target_modules = [target_modules]
             for module in target_modules:
                 # Compatible with more modules,
                 # such as:layers.11.self_attn.k_proj
@@ -243,8 +246,8 @@ class LoRAModel(AdapterModel):
             # expected_lora_modules. It is not reliable. See
             # https://github.com/vllm-project/vllm/pull/5909. But there's no
             # other better mechanism.
-            if unexpected_modules:
-                print(unexpected_modules, "modules")
+            if unexpected_modules and not is_regex_target_modules(
+                    config["target_modules"], expected_lora_modules):
                 raise ValueError(
                     f"While loading {lora_dir}, expected"
                     f" target modules in {expected_lora_modules}"
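Combined with the `isinstance` normalization above, the validation in `LoRAModel.from_local_checkpoint` now behaves roughly as follows. This is a condensed sketch of the surrounding code, not the verbatim implementation, and it relies on names (`config`, `expected_lora_modules`, `lora_dir`) that are in scope at that point in the method:

    # Condensed sketch of the validation flow after this change.
    target_modules = config["target_modules"]
    if not isinstance(target_modules, list):
        target_modules = [target_modules]  # wrap a regex str for the scan below

    # A regex like `model\..*(W_pack|o_proj)` never equals a plain module
    # suffix, so it always ends up in unexpected_modules here...
    unexpected_modules = [
        module for module in target_modules
        if module.split(".")[-1] not in expected_lora_modules
    ]

    # ...which is why the new helper gets the final say before rejecting:
    if unexpected_modules and not is_regex_target_modules(
            config["target_modules"], expected_lora_modules):
        raise ValueError(f"While loading {lora_dir}, expected "
                         f"target modules in {expected_lora_modules}")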
--- a/vllm/lora/utils.py
+++ b/vllm/lora/utils.py
@@ -1,5 +1,6 @@
 import os
-from typing import List, Optional, Set, Tuple, Type
+import re
+from typing import List, Optional, Set, Tuple, Type, Union
 
 import huggingface_hub
 from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
@@ -113,6 +114,38 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
     raise ValueError(f"{name} is unsupported LoRA weight")
 
 
+def is_regex_target_modules(load_modules: Union[str, List[str]],
+                            expected_lora_modules: List[str]) -> bool:
+    """
+    PEFT supports passing `target_modules` as a regular expression,
+    such as `model.*(q_proj|k_proj|v_proj)$`. This function checks whether
+    every suffix alternative in that regular expression is present in
+    `expected_lora_modules`.
+    """
+
+    def is_valid_regex(pattern):
+        try:
+            re.compile(pattern)
+            return True
+        except re.error:
+            return False
+
+    def is_subset(sub_list, full_list):
+        return set(sub_list).issubset(set(full_list))
+
+    # Similar to PEFT's processing logic, regex-related operations are only
+    # executed when load_modules is a `str`.
+    if not isinstance(load_modules, str):
+        return False
+
+    if is_valid_regex(load_modules):
+        match = re.search(r"\((.*?)\)\$?$", load_modules)
+        if match:
+            suffix = match.group(1).split("|")
+            return is_subset(suffix, expected_lora_modules)
+    return False
+
+
 def get_adapter_absolute_path(lora_path: str) -> str:
     """
     Resolves the given lora_path to an absolute local path.
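A quick illustration of the new helper's behavior. The `expected` list below is an assumption modeled on Baichuan-style module names from the tests above, not something defined in this diff:

    from vllm.lora.utils import is_regex_target_modules

    expected = ["W_pack", "o_proj", "gate_up_proj", "down_proj"]

    # A str pattern whose alternation group is a subset of expected -> True
    assert is_regex_target_modules("model\\..*(W_pack|o_proj)", expected)

    # The alternation names a module the model does not expect -> False
    assert not is_regex_target_modules("model\\..*(W_pack|lm_head)", expected)

    # List-form target_modules never takes the regex path -> False
    assert not is_regex_target_modules(["W_pack", "o_proj"], expected)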