[Bugfix] Fix LoRA loading check (#4138)

Co-authored-by: simon-mo <simon.mo@hey.com>
Jee Li 2024-04-19 15:59:54 +08:00 committed by GitHub
parent a134ef6f5e
commit d17c8477f1
3 changed files with 29 additions and 3 deletions


@@ -143,6 +143,12 @@ def baichuan_lora_files():
     return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider")
 
 
+@pytest.fixture(scope="session")
+def baichuan_zero_lora_files():
+    # all the lora_B weights are initialized to zero.
+    return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")
+
+
 @pytest.fixture(scope="session")
 def tinyllama_lora_files():
     return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")


@@ -3,9 +3,16 @@ import pytest
 from vllm.lora.models import LoRAModel
 from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
 
+lora_lst = ["baichuan7B", "baichuan7B-zero", "chatglm3-6b"]
+
 
-@pytest.mark.parametrize("lora_name", ["baichuan7B", "chatglm3-6b"])
-def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files):
+@pytest.mark.parametrize("lora_name", lora_lst)
+def test_load_checkpoints(
+    lora_name,
+    baichuan_lora_files,
+    baichuan_zero_lora_files,
+    chatglm3_lora_files,
+):
     supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
     packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
     embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
@@ -26,6 +33,17 @@ def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files):
             device="cpu",
             embedding_modules=embedding_modules,
             embedding_padding_modules=embed_padding_modules)
+    elif lora_name == "baichuan7B-zero":
+        # Test that the load succeeds when target_modules carry a full
+        # prefix such as "model.layers.0.self_attn.W_pack"; no error
+        # should be raised.
+        LoRAModel.from_local_checkpoint(
+            baichuan_zero_lora_files,
+            expected_lora_modules,
+            lora_model_id=1,
+            device="cpu",
+            embedding_modules=embedding_modules,
+            embedding_padding_modules=embed_padding_modules)
     else:
         # For the baichuan7B model, load chatglm3-6b's LoRA,
         # and the test should raise the following error.
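The new "baichuan7B-zero" case matters because some adapters, including the one pulled by baichuan_zero_lora_files, list fully qualified names in their config's target_modules. A short sketch of why the previous exact-membership check rejected such checkpoints (module lists are illustrative):

# Expected modules are bare layer names taken from the model definition.
expected_lora_modules = ["W_pack", "o_proj", "gate_up_proj", "down_proj"]

# Some adapter configs list bare names, others the fully qualified form.
bare = ["W_pack", "o_proj"]
prefixed = ["model.layers.0.self_attn.W_pack", "model.layers.0.mlp.down_proj"]

# The old check compared raw entries, so every prefixed name was flagged
# as unexpected and the checkpoint was refused.
assert all(m in expected_lora_modules for m in bare)
assert not any(m in expected_lora_modules for m in prefixed)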


@@ -212,7 +212,9 @@ class LoRAModel:
         target_modules = config["target_modules"]
         unexpected_modules = []
         for module in target_modules:
-            if module not in expected_lora_modules:
+            # Accept prefixed module names such as "layers.11.self_attn.k_proj".
+            part_name = module.split(".")[-1]
+            if part_name not in expected_lora_modules:
                 unexpected_modules.append(module)
         # loaded lora's target modules must be a subset of expected_lora_modules
         if unexpected_modules:
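For clarity, the relaxed check pulled out as a standalone sketch: only the last dotted component of each target module is compared, so prefixed names match while genuinely unsupported modules are still collected (example module names are illustrative):

from typing import List


def find_unexpected_modules(target_modules: List[str],
                            expected_lora_modules: List[str]) -> List[str]:
    # Mirror of the updated loop in LoRAModel.from_local_checkpoint.
    unexpected_modules = []
    for module in target_modules:
        # "model.layers.0.self_attn.W_pack" -> "W_pack"
        part_name = module.split(".")[-1]
        if part_name not in expected_lora_modules:
            unexpected_modules.append(module)
    return unexpected_modules


expected = ["W_pack", "o_proj"]
assert find_unexpected_modules(["W_pack"], expected) == []
assert find_unexpected_modules(["model.layers.0.self_attn.W_pack"],
                               expected) == []
# A truly unknown module is still reported, with or without a prefix.
assert find_unexpected_modules(["model.layers.0.mlp.unknown_proj"],
                               expected) == ["model.layers.0.mlp.unknown_proj"]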