[Bugfix] Fix LoRA loading check (#4138)
Co-authored-by: simon-mo <simon.mo@hey.com>
This commit is contained in:
parent a134ef6f5e
commit d17c8477f1
@@ -143,6 +143,12 @@ def baichuan_lora_files():
     return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider")
 
 
+@pytest.fixture(scope="session")
+def baichuan_zero_lora_files():
+    # all the lora_B weights are initialized to zero.
+    return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")
+
+
 @pytest.fixture(scope="session")
 def tinyllama_lora_files():
     return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
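
For context, the fixtures above resolve Hugging Face repos to local directories. A minimal sketch of what that call does, assuming huggingface_hub is installed (the printed path depends on the local cache):

# Hedged sketch: snapshot_download (from huggingface_hub, as used in the
# fixtures above) downloads the adapter repo once per session and returns the
# local directory that LoRAModel.from_local_checkpoint later reads
# (adapter_config.json plus the adapter weights).
from huggingface_hub import snapshot_download

lora_dir = snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")
print(lora_dir)  # local cache path containing the downloaded adapter files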
@@ -3,9 +3,16 @@ import pytest
 from vllm.lora.models import LoRAModel
 from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
 
+lora_lst = ["baichuan7B", "baichuan7B-zero", "chatglm3-6b"]
+
 
-@pytest.mark.parametrize("lora_name", ["baichuan7B", "chatglm3-6b"])
-def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files):
+@pytest.mark.parametrize("lora_name", lora_lst)
+def test_load_checkpoints(
+    lora_name,
+    baichuan_lora_files,
+    baichuan_zero_lora_files,
+    chatglm3_lora_files,
+):
     supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
     packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
     embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
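
The test lines between this hunk and the next (not shown in the diff) build expected_lora_modules from the class attributes fetched above. A hedged sketch of that assembly, with the attribute values written out as illustrative stand-ins rather than imported from vllm:

# Hedged sketch: expand packed modules into the per-module names the loader
# expects. The literal lists below stand in for
# BaiChuanBaseForCausalLM.supported_lora_modules / packed_modules_mapping.
supported_lora_modules = ["W_pack", "o_proj", "gate_up_proj", "down_proj"]
packed_modules_mapping = {
    "W_pack": ["W_pack"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

expected_lora_modules = []
for module in supported_lora_modules:
    if module in packed_modules_mapping:
        expected_lora_modules.extend(packed_modules_mapping[module])
    else:
        expected_lora_modules.append(module)

print(expected_lora_modules)
# ['W_pack', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']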
@@ -26,6 +33,17 @@ def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files):
             device="cpu",
             embedding_modules=embedding_modules,
             embedding_padding_modules=embed_padding_modules)
+    elif lora_name == "baichuan7B-zero":
+        #Test that the target_modules contain prefix
+        # such as "model.layers.0.self_atten.W_pack", and
+        # the test should pass.
+        LoRAModel.from_local_checkpoint(
+            baichuan_zero_lora_files,
+            expected_lora_modules,
+            lora_model_id=1,
+            device="cpu",
+            embedding_modules=embedding_modules,
+            embedding_padding_modules=embed_padding_modules)
     else:
         # For the baichuan7B model, load chatglm3-6b's LoRA,
         # and the test should raise the following error.
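
The else branch above relies on the subset check in the LoRAModel hunk below raising when the loaded adapter's target modules do not belong to the expected set. A self-contained sketch of that expectation, with the helper function and error message invented for illustration (not the actual vllm code or error text):

# Hedged sketch: a stand-in for the subset check that the mismatch branch
# expects to fail when a chatglm3-style adapter is loaded into Baichuan.
import pytest


def check_subset(target_modules, expected_lora_modules):
    # Compare only the last dot-separated component, as in the fix below.
    unexpected = [
        m for m in target_modules
        if m.split(".")[-1] not in expected_lora_modules
    ]
    if unexpected:
        raise ValueError(f"Unexpected modules: {unexpected}")


def test_mismatched_adapter_is_rejected():
    # "query_key_value" (chatglm3's attention projection) is not a valid
    # Baichuan LoRA module, so the check should raise.
    with pytest.raises(ValueError):
        check_subset(["query_key_value"], ["W_pack", "o_proj"])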
@@ -212,7 +212,9 @@ class LoRAModel:
             target_modules = config["target_modules"]
             unexpected_modules = []
             for module in target_modules:
-                if module not in expected_lora_modules:
+                # Compatible with more modules, such as:layers.11.self_attn.k_proj
+                part_name = module.split(".")[-1]
+                if part_name not in expected_lora_modules:
                     unexpected_modules.append(module)
             # loaded lora's target modules must be a subset of expected_lora_modules
             if unexpected_modules:
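
For clarity, a minimal standalone sketch of the relaxed check introduced above: only the last dot-separated component of each target module is compared against the expected module names, so fully qualified entries such as "model.layers.0.self_attn.W_pack" no longer trip the subset check. The module names below are chosen for illustration, not read from a real adapter_config.json:

# Hedged sketch of the relaxed module check from the hunk above.
expected_lora_modules = ["W_pack", "o_proj", "gate_up_proj", "down_proj"]
target_modules = [
    "model.layers.0.self_attn.W_pack",  # fully qualified: previously rejected
    "o_proj",                           # bare name: accepted before and after
    "lm_head",                          # genuinely unexpected: still rejected
]

unexpected_modules = []
for module in target_modules:
    part_name = module.split(".")[-1]   # keep only the last path component
    if part_name not in expected_lora_modules:
        unexpected_modules.append(module)

assert unexpected_modules == ["lm_head"]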