# SPDX-License-Identifier: Apache-2.0

import pytest

from vllm.lora.models import LoRAModel
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.utils import get_adapter_absolute_path
from vllm.model_executor.models.llama import LlamaForCausalLM

# Provide both an absolute path and a Hugging Face LoRA id.
lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"]
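# Both names are pytest fixtures resolved below via request.getfixturevalue():
# one yields a local adapter directory, the other a Hugging Face repo id.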
LLAMA_LORA_MODULES = [
    "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
    "lm_head"
]


@pytest.mark.parametrize("lora_fixture_name", lora_fixture_name)
def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
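    # Depending on the parametrized fixture, lora_name is either a local
    # adapter directory or a Hugging Face repo id.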
    lora_name = request.getfixturevalue(lora_fixture_name)
    packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
    embedding_modules = LlamaForCausalLM.embedding_modules
    embed_padding_modules = LlamaForCausalLM.embedding_padding_modules
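    # Expand packed modules (e.g. "qkv_proj", "gate_up_proj") into their
    # constituent projections so the checkpoint loader sees every expected
    # target module; unpacked modules are kept as-is.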
    expected_lora_modules: list[str] = []
    for module in LLAMA_LORA_MODULES:
        if module in packed_modules_mapping:
            expected_lora_modules.extend(packed_modules_mapping[module])
        else:
            expected_lora_modules.append(module)

    lora_path = get_adapter_absolute_path(lora_name)
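    # get_adapter_absolute_path resolves an existing local path to its absolute
    # form; otherwise the name is treated as a Hugging Face repo id and the
    # adapter is downloaded first.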

    # LoRA loading should work for either an absolute path or a Hugging Face id.
    peft_helper = PEFTHelper.from_local_dir(lora_path, 4096)
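    # The 4096 passed above is the base model's max position embeddings, which
    # PEFTHelper uses when handling long-context (context-length-scaled) LoRAs.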
    lora_model = LoRAModel.from_local_checkpoint(
        lora_path,
        expected_lora_modules,
        peft_helper=peft_helper,
        lora_model_id=1,
        device="cpu",
        embedding_modules=embedding_modules,
        embedding_padding_modules=embed_padding_modules)

    # Assertion to ensure the model loaded correctly.
    assert lora_model is not None, "LoRAModel is not loaded correctly"