[Core] Support dynamically loading Lora adapter from HuggingFace (#6234)

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
This commit is contained in:
Jiaxin Shan 2024-07-22 15:42:40 -07:00 committed by GitHub
parent 69d5ae38dc
commit 42c7f66a38
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 201 additions and 18 deletions

View File

@ -462,7 +462,7 @@ def test_prefill_schedule_max_lora():
lora_request=LoRARequest(
lora_name=str(i),
lora_int_id=i + 1,
lora_local_path="abc"))
lora_path="abc"))
waiting.append(seq_group)
# Add two more requests to verify lora is prioritized.
# 0: Lora, 1: Lora, 2: regular, 3: regular
@ -760,7 +760,7 @@ def test_schedule_swapped_max_loras():
lora_request=LoRARequest(
lora_name=str(i),
lora_int_id=i + 1,
lora_local_path="abc"))
lora_path="abc"))
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out)

View File

@ -159,8 +159,14 @@ def dummy_model_gate_up() -> nn.Module:
@pytest.fixture(scope="session")
def sql_lora_files():
return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
def sql_lora_huggingface_id():
# huggingface repo id is used to test lora runtime downloading.
return "yard1/llama-2-7b-sql-lora-test"
@pytest.fixture(scope="session")
def sql_lora_files(sql_lora_huggingface_id):
return snapshot_download(repo_id=sql_lora_huggingface_id)
@pytest.fixture(scope="session")

View File

@ -29,7 +29,7 @@ def _create_lora_request(lora_id, long_context_infos):
context_len = long_context_infos[lora_id]["context_length"]
scaling_factor = context_len_to_scaling_factor[context_len]
return LoRARequest(context_len, lora_id,
long_context_infos[lora_id]["lora"],
long_context_infos[lora_id]["lora"], None,
4096 * scaling_factor)

View File

@ -0,0 +1,39 @@
from typing import List
import pytest
from vllm.lora.models import LoRAModel
from vllm.lora.utils import get_adapter_absolute_path
from vllm.model_executor.models.llama import LlamaForCausalLM
# Provide absolute path and huggingface lora ids
lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"]
@pytest.mark.parametrize("lora_fixture_name", lora_fixture_name)
def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
lora_name = request.getfixturevalue(lora_fixture_name)
supported_lora_modules = LlamaForCausalLM.supported_lora_modules
packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
embedding_modules = LlamaForCausalLM.embedding_modules
embed_padding_modules = LlamaForCausalLM.embedding_padding_modules
expected_lora_modules: List[str] = []
for module in supported_lora_modules:
if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module])
else:
expected_lora_modules.append(module)
lora_path = get_adapter_absolute_path(lora_name)
# lora loading should work for either absolute path and hugggingface id.
lora_model = LoRAModel.from_local_checkpoint(
lora_path,
expected_lora_modules,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules)
# Assertions to ensure the model is loaded correctly
assert lora_model is not None, "LoRAModel is not loaded correctly"

View File

@ -1,9 +1,12 @@
from collections import OrderedDict
from unittest.mock import patch
import pytest
from huggingface_hub.utils import HfHubHTTPError
from torch import nn
from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
from vllm.lora.utils import (get_adapter_absolute_path,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.utils import LRUCache
@ -182,3 +185,55 @@ def test_lru_cache():
assert 2 in cache
assert 4 in cache
assert 6 in cache
# Unit tests for get_adapter_absolute_path
@patch('os.path.isabs')
def test_get_adapter_absolute_path_absolute(mock_isabs):
path = '/absolute/path/to/lora'
mock_isabs.return_value = True
assert get_adapter_absolute_path(path) == path
@patch('os.path.expanduser')
def test_get_adapter_absolute_path_expanduser(mock_expanduser):
# Path with ~ that needs to be expanded
path = '~/relative/path/to/lora'
absolute_path = '/home/user/relative/path/to/lora'
mock_expanduser.return_value = absolute_path
assert get_adapter_absolute_path(path) == absolute_path
@patch('os.path.exists')
@patch('os.path.abspath')
def test_get_adapter_absolute_path_local_existing(mock_abspath, mock_exist):
# Relative path that exists locally
path = 'relative/path/to/lora'
absolute_path = '/absolute/path/to/lora'
mock_exist.return_value = True
mock_abspath.return_value = absolute_path
assert get_adapter_absolute_path(path) == absolute_path
@patch('huggingface_hub.snapshot_download')
@patch('os.path.exists')
def test_get_adapter_absolute_path_huggingface(mock_exist,
mock_snapshot_download):
# Hugging Face model identifier
path = 'org/repo'
absolute_path = '/mock/snapshot/path'
mock_exist.return_value = False
mock_snapshot_download.return_value = absolute_path
assert get_adapter_absolute_path(path) == absolute_path
@patch('huggingface_hub.snapshot_download')
@patch('os.path.exists')
def test_get_adapter_absolute_path_huggingface_error(mock_exist,
mock_snapshot_download):
# Hugging Face model identifier with download error
path = 'org/repo'
mock_exist.return_value = False
mock_snapshot_download.side_effect = HfHubHTTPError(
"failed to query model info")
assert get_adapter_absolute_path(path) == path

View File

@ -43,7 +43,7 @@ class PromptAdapterPath:
@dataclass
class LoRAModulePath:
name: str
local_path: str
path: str
AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
@ -83,7 +83,7 @@ class OpenAIServing:
LoRARequest(
lora_name=lora.name,
lora_int_id=i,
lora_local_path=lora.local_path,
lora_path=lora.path,
) for i, lora in enumerate(lora_modules, start=1)
]

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass
import warnings
from dataclasses import dataclass, field
from typing import Optional
from vllm.adapter_commons.request import AdapterRequest
@ -20,10 +21,25 @@ class LoRARequest(AdapterRequest):
lora_name: str
lora_int_id: int
lora_local_path: str
lora_path: str = ""
lora_local_path: Optional[str] = field(default=None, repr=False)
long_lora_max_len: Optional[int] = None
__hash__ = AdapterRequest.__hash__
def __post_init__(self):
if 'lora_local_path' in self.__dict__:
warnings.warn(
"The 'lora_local_path' attribute is deprecated "
"and will be removed in a future version. "
"Please use 'lora_path' instead.",
DeprecationWarning,
stacklevel=2)
if not self.lora_path:
self.lora_path = self.lora_local_path or ""
# Ensure lora_path is not empty
assert self.lora_path, "lora_path cannot be empty"
@property
def adapter_id(self):
return self.lora_int_id
@ -32,6 +48,26 @@ class LoRARequest(AdapterRequest):
def name(self):
return self.lora_name
@property
def path(self):
return self.lora_path
@property
def local_path(self):
return self.lora_local_path
warnings.warn(
"The 'local_path' attribute is deprecated "
"and will be removed in a future version. "
"Please use 'path' instead.",
DeprecationWarning,
stacklevel=2)
return self.lora_path
@local_path.setter
def local_path(self, value):
warnings.warn(
"The 'local_path' attribute is deprecated "
"and will be removed in a future version. "
"Please use 'path' instead.",
DeprecationWarning,
stacklevel=2)
self.lora_path = value

View File

@ -1,5 +1,9 @@
import os
from typing import List, Optional, Set, Tuple, Type
import huggingface_hub
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
HFValidationError, RepositoryNotFoundError)
from torch import nn
from transformers import PretrainedConfig
@ -105,3 +109,46 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
raise ValueError(f"{name} is unsupported LoRA weight")
def get_adapter_absolute_path(lora_path: str) -> str:
"""
Resolves the given lora_path to an absolute local path.
If the lora_path is identified as a Hugging Face model identifier,
it will download the model and return the local snapshot path.
Otherwise, it treats the lora_path as a local file path and
converts it to an absolute path.
Parameters:
lora_path (str): The path to the lora model, which can be an absolute path,
a relative path, or a Hugging Face model identifier.
Returns:
str: The resolved absolute local path to the lora model.
"""
# Check if the path is an absolute path. Return it no matter exists or not.
if os.path.isabs(lora_path):
return lora_path
# If the path starts with ~, expand the user home directory.
if lora_path.startswith('~'):
return os.path.expanduser(lora_path)
# Check if the expanded relative path exists locally.
if os.path.exists(lora_path):
return os.path.abspath(lora_path)
# If the path does not exist locally, assume it's a Hugging Face repo.
try:
local_snapshot_path = huggingface_hub.snapshot_download(
repo_id=lora_path)
except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError,
HFValidationError):
# Handle errors that may occur during the download
# Return original path instead instead of throwing error here
logger.exception("Error downloading the HuggingFace model")
return lora_path
return local_snapshot_path

View File

@ -13,6 +13,7 @@ from vllm.logger import init_logger
from vllm.lora.models import (LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager, create_lora_manager)
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
logger = init_logger(__name__)
@ -89,8 +90,9 @@ class WorkerLoRAManager(AbstractWorkerManager):
packed_modules_mapping[module])
else:
expected_lora_modules.append(module)
lora_path = get_adapter_absolute_path(lora_request.lora_path)
lora = self._lora_model_cls.from_local_checkpoint(
lora_request.lora_local_path,
lora_path,
expected_lora_modules,
max_position_embeddings=self.max_position_embeddings,
lora_model_id=lora_request.lora_int_id,
@ -102,8 +104,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
embedding_padding_modules=self.embedding_padding_modules,
)
except Exception as e:
raise RuntimeError(
f"Loading lora {lora_request.lora_local_path} failed") from e
raise RuntimeError(f"Loading lora {lora_path} failed") from e
if lora.rank > self.lora_config.max_lora_rank:
raise ValueError(
f"LoRA rank {lora.rank} is greater than max_lora_rank "

View File

@ -137,14 +137,13 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
if lora_request is None:
return None
try:
tokenizer = get_tokenizer(lora_request.lora_local_path, *args,
**kwargs)
tokenizer = get_tokenizer(lora_request.lora_path, *args, **kwargs)
except OSError as e:
# No tokenizer was found in the LoRA folder,
# use base model tokenizer
logger.warning(
"No tokenizer found in %s, using base model tokenizer instead. "
"(Exception: %s)", lora_request.lora_local_path, e)
"(Exception: %s)", lora_request.lora_path, e)
tokenizer = None
return tokenizer

View File

@ -691,7 +691,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
dummy_lora_request = LoRARequest(
lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id,
lora_local_path="/not/a/real/path",
lora_path="/not/a/real/path",
)
self.lora_manager.add_dummy_lora(dummy_lora_request,
rank=LORA_WARMUP_RANK)