[Misc][LoRA] Improve the readability of LoRA error messages (#12102)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent 69d765f5a5
commit 07934cc237
@@ -17,6 +17,33 @@ MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
 # generation quality here
 LORA_NAME = "typeof/zephyr-7b-beta-lora"
 
+BADREQUEST_CASES = [
+    (
+        "test_rank",
+        {
+            "r": 1024
+        },
+        "is greater than max_lora_rank",
+    ),
+    (
+        "test_bias",
+        {
+            "bias": "all"
+        },
+        "Adapter bias cannot be used without bias_enabled",
+    ),
+    ("test_dora", {
+        "use_dora": True
+    }, "does not yet support DoRA"),
+    (
+        "test_modules_to_save",
+        {
+            "modules_to_save": ["lm_head"]
+        },
+        "only supports modules_to_save being None",
+    ),
+]
+
 
 @pytest.fixture(scope="module")
 def zephyr_lora_files():
@@ -138,32 +165,36 @@ async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-async def test_dynamic_lora_invalid_lora_rank(client: openai.AsyncOpenAI,
-                                              tmp_path, zephyr_lora_files):
-    invalid_rank = tmp_path / "invalid_rank"
+@pytest.mark.parametrize("test_name,config_change,expected_error",
+                         BADREQUEST_CASES)
+async def test_dynamic_lora_badrequests(client: openai.AsyncOpenAI, tmp_path,
+                                        zephyr_lora_files, test_name: str,
+                                        config_change: dict,
+                                        expected_error: str):
+    # Create test directory
+    test_dir = tmp_path / test_name
 
-    # Copy adapter from zephyr_lora_files to invalid_rank
-    shutil.copytree(zephyr_lora_files, invalid_rank)
+    # Copy adapter files
+    shutil.copytree(zephyr_lora_files, test_dir)
 
-    with open(invalid_rank / "adapter_config.json") as f:
+    # Load and modify configuration
+    config_path = test_dir / "adapter_config.json"
+    with open(config_path) as f:
         adapter_config = json.load(f)
+    # Apply configuration changes
+    adapter_config.update(config_change)
 
-    print(adapter_config)
-
-    # assert False
-
-    # Change rank to invalid value
-    adapter_config["r"] = 1024
-    with open(invalid_rank / "adapter_config.json", "w") as f:
+    # Save modified configuration
+    with open(config_path, "w") as f:
         json.dump(adapter_config, f)
 
-    with pytest.raises(openai.BadRequestError,
-                       match="is greater than max_lora_rank"):
+    # Test loading the adapter
+    with pytest.raises(openai.BadRequestError, match=expected_error):
         await client.post("load_lora_adapter",
                           cast_to=str,
                           body={
-                              "lora_name": "invalid-json",
-                              "lora_path": str(invalid_rank)
+                              "lora_name": test_name,
+                              "lora_path": str(test_dir)
                           })
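The cases above exercise the dynamic load_lora_adapter route end to end: each broken adapter_config.json now comes back as an HTTP 400 whose message names the offending setting. As a rough client-side sketch (not part of this change), assuming a server running with runtime LoRA updating enabled; the base URL, adapter name, and path are placeholders, and the exact message is whatever PEFTHelper reports:

# Illustrative sketch only: how one of the messages above surfaces to a client.
import asyncio

import openai


async def show_load_error() -> None:
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    try:
        await client.post("load_lora_adapter",
                          cast_to=str,
                          body={
                              "lora_name": "my-bad-adapter",
                              "lora_path": "/path/to/adapter-with-rank-1024",
                          })
    except openai.BadRequestError as e:
        # Expected to contain e.g. "is greater than max_lora_rank".
        print(e)


if __name__ == "__main__":
    asyncio.run(show_load_error())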
@@ -3,6 +3,7 @@ from typing import List
 import pytest
 
 from vllm.lora.models import LoRAModel
+from vllm.lora.peft_helper import PEFTHelper
 from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
 from vllm.model_executor.models.utils import WeightsMapper
@@ -30,11 +31,14 @@ def test_load_checkpoints(
         else:
             expected_lora_modules.append(module)
     if lora_name == "baichuan7B":
+        peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files,
+                                                max_position_embeddings=4096)
         # For the baichuan7B model, load it's LoRA,
         # and the test should pass.
         LoRAModel.from_local_checkpoint(
             baichuan_lora_files,
             expected_lora_modules,
+            peft_helper=peft_helper,
             lora_model_id=1,
             device="cpu",
             embedding_modules=embedding_modules,
@@ -43,9 +47,12 @@ def test_load_checkpoints(
         # Test that the target_modules contain prefix
         # such as "model.layers.0.self_atten.W_pack", and
         # the test should pass.
+        peft_helper = PEFTHelper.from_local_dir(baichuan_zero_lora_files,
+                                                max_position_embeddings=4096)
         LoRAModel.from_local_checkpoint(
             baichuan_zero_lora_files,
             expected_lora_modules,
+            peft_helper=peft_helper,
             lora_model_id=1,
             device="cpu",
             embedding_modules=embedding_modules,
@@ -53,9 +60,12 @@
     elif lora_name == "baichuan7B-zero-regex":
         # Test that the `target_modules` in the form of regular expressions,
         # such as `model\\..*(W_pack|o_proj)`, and the test should pass.
+        peft_helper = PEFTHelper.from_local_dir(baichuan_regex_lora_files,
+                                                max_position_embeddings=4096)
         LoRAModel.from_local_checkpoint(
             baichuan_regex_lora_files,
             expected_lora_modules,
+            peft_helper=peft_helper,
             lora_model_id=1,
             device="cpu",
             embedding_modules=embedding_modules,
@@ -64,10 +74,13 @@
         # For the baichuan7B model, load chatglm3-6b's LoRA,
         # and the test should raise the following error.
         expected_error = "Please verify that the loaded LoRA module is correct"  # noqa: E501
+        peft_helper = PEFTHelper.from_local_dir(chatglm3_lora_files,
+                                                max_position_embeddings=4096)
         with pytest.raises(ValueError, match=expected_error):
             LoRAModel.from_local_checkpoint(
                 chatglm3_lora_files,
                 expected_lora_modules,
+                peft_helper=peft_helper,
                 lora_model_id=1,
                 device="cpu",
                 embedding_modules=embedding_modules,
@@ -94,9 +107,12 @@ def test_lora_weights_mapping(baichuan_lora_files):
             ".layers.": ".baichuan_layers.",
         },
     )
+    peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files,
+                                            max_position_embeddings=4096)
    lora_model = LoRAModel.from_local_checkpoint(
         baichuan_lora_files,
         expected_lora_modules,
+        peft_helper=peft_helper,
         lora_model_id=1,
         device="cpu",
         embedding_modules=embedding_modules,
@@ -3,6 +3,7 @@ from typing import List
 import pytest
 
 from vllm.lora.models import LoRAModel
+from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.utils import get_adapter_absolute_path
 from vllm.model_executor.models.llama import LlamaForCausalLM
@@ -27,9 +28,11 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
     lora_path = get_adapter_absolute_path(lora_name)
 
     # lora loading should work for either absolute path and hugggingface id.
+    peft_helper = PEFTHelper.from_local_dir(lora_path, 4096)
     lora_model = LoRAModel.from_local_checkpoint(
         lora_path,
         expected_lora_modules,
+        peft_helper=peft_helper,
         lora_model_id=1,
         device="cpu",
         embedding_modules=embedding_modules,
@@ -1,5 +1,3 @@
-import json
-import math
 import os
 from typing import Dict, List
 
@@ -34,56 +32,6 @@ DEVICES = ([
 ] if current_platform.is_cuda_alike() else ["cpu"])
 
 
-def test_peft_helper(sql_lora_files):
-    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
-    with open(lora_config_path) as f:
-        config = json.load(f)
-    peft_helper = PEFTHelper.from_dict(config)
-    assert peft_helper.r == 8
-    assert peft_helper.lora_alpha == 16
-    assert peft_helper.target_modules == [
-        "q_proj",
-        "v_proj",
-        "k_proj",
-        "o_proj",
-        "gate_proj",
-        "up_proj",
-        "down_proj",
-        "embed_tokens",
-        "lm_head",
-    ]
-    scaling = peft_helper.lora_alpha / peft_helper.r
-    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
-
-    # test RSLoRA
-    config = dict(r=8,
-                  lora_alpha=16,
-                  target_modules=["gate_proj"],
-                  use_rslora=True)
-    peft_helper = PEFTHelper.from_dict(config)
-
-    scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
-    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
-
-    expected_error = "vLLM only supports modules_to_save being None."
-    with pytest.raises(ValueError, match=expected_error):
-        config = dict(
-            r=8,
-            lora_alpha=16,
-            target_modules=["gate_proj"],
-            modules_to_save=["lm_head"],
-        )
-        PEFTHelper.from_dict(config)
-
-    expected_error = "vLLM does not yet support DoRA."
-    with pytest.raises(ValueError, match=expected_error):
-        config = dict(r=8,
-                      lora_alpha=16,
-                      target_modules=["gate_proj"],
-                      use_dora=True)
-        PEFTHelper.from_dict(config)
-
-
 @pytest.mark.parametrize("device", DEVICES)
 def test_from_lora_tensors(sql_lora_files, device):
     tensors = load_file(
@@ -91,11 +39,8 @@ def test_from_lora_tensors(sql_lora_files, device):
     new_embeddings = load_file(
         os.path.join(sql_lora_files, "new_embeddings.safetensors"))
 
-    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
-    with open(lora_config_path) as f:
-        config = json.load(f)
-
-    peft_helper = PEFTHelper.from_dict(config)
+    peft_helper = PEFTHelper.from_local_dir(sql_lora_files,
+                                            max_position_embeddings=4096)
     lora_model = LoRAModel.from_lora_tensors(
         1,
         tensors,
tests/lora/test_peft_helper.py (new file, 109 lines)
@@ -0,0 +1,109 @@
+import json
+import math
+import shutil
+
+import pytest
+
+from vllm.config import LoRAConfig
+from vllm.lora.peft_helper import PEFTHelper
+
+ERROR_CASES = [
+    (
+        "test_rank",
+        {
+            "r": 1024
+        },
+        "is greater than max_lora_rank",
+    ),
+    (
+        "test_bias",
+        {
+            "bias": "all"
+        },
+        "Adapter bias cannot be used without bias_enabled",
+    ),
+    ("test_dora", {
+        "use_dora": True
+    }, "does not yet support DoRA"),
+    (
+        "test_modules_to_save",
+        {
+            "modules_to_save": ["lm_head"]
+        },
+        "only supports modules_to_save being None",
+    ),
+]
+
+
+def test_peft_helper_pass(long_context_lora_files_16k_1, tmp_path):
+    peft_helper = PEFTHelper.from_local_dir(long_context_lora_files_16k_1,
+                                            max_position_embeddings=4096)
+    lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)
+    peft_helper.validate_legal(lora_config)
+    assert peft_helper.r == 8
+    assert peft_helper.lora_alpha == 16
+    assert peft_helper.target_modules == [
+        "q_proj",
+        "v_proj",
+        "k_proj",
+        "o_proj",
+        "gate_proj",
+        "up_proj",
+        "down_proj",
+        "embed_tokens",
+        "lm_head",
+    ]
+    assert peft_helper.context_length == 16384
+    assert peft_helper.vllm_max_position_embeddings == 4096
+    assert peft_helper.vllm_long_context_scaling_factor == float(
+        math.ceil(peft_helper.context_length /
+                  peft_helper.vllm_max_position_embeddings))
+    # test RSLoRA
+    rslora_config = dict(use_rslora=True)
+    test_dir = tmp_path / "test_rslora"
+    shutil.copytree(long_context_lora_files_16k_1, test_dir)
+
+    # Load and modify configuration
+    config_path = test_dir / "adapter_config.json"
+    with open(config_path) as f:
+        adapter_config = json.load(f)
+    # Apply configuration changes
+    adapter_config.update(rslora_config)
+
+    # Save modified configuration
+    with open(config_path, "w") as f:
+        json.dump(adapter_config, f)
+
+    peft_helper = PEFTHelper.from_local_dir(test_dir,
+                                            max_position_embeddings=4096)
+    peft_helper.validate_legal(lora_config)
+    scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
+    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
+
+
+@pytest.mark.parametrize("test_name,config_change,expected_error", ERROR_CASES)
+def test_peft_helper_error(
+    sql_lora_files,
+    tmp_path,
+    test_name: str,
+    config_change: dict,
+    expected_error: str,
+):
+    test_dir = tmp_path / test_name
+    shutil.copytree(sql_lora_files, test_dir)
+
+    # Load and modify configuration
+    config_path = test_dir / "adapter_config.json"
+    with open(config_path) as f:
+        adapter_config = json.load(f)
+    # Apply configuration changes
+    adapter_config.update(config_change)
+
+    # Save modified configuration
+    with open(config_path, "w") as f:
+        json.dump(adapter_config, f)
+    lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)
+    # Test loading the adapter
+    with pytest.raises(ValueError, match=expected_error):
+        PEFTHelper.from_local_dir(
+            test_dir, max_position_embeddings=4096).validate_legal(lora_config)
@@ -296,6 +296,7 @@ class MQLLMEngine:
                                is_engine_errored=False,
                                exception=e)
             self._send_outputs(rpc_err)
+            return
         # Otherwise, send back the successful load message
         self._send_outputs(
             RPCAdapterLoadedResponse(request_id=request.request_id))
@@ -157,24 +157,16 @@ class OpenAIServingModels:
         # This will also pre-load it for incoming requests
         try:
             await self.engine_client.add_lora(lora_request)
-        except ValueError as e:
-            # Adapter not found or lora configuration errors
-            if "No adapter found" in str(e):
-                return create_error_response(message=str(e),
-                                             err_type="NotFoundError",
-                                             status_code=HTTPStatus.NOT_FOUND)
-            else:
-                return create_error_response(
-                    message=str(e),
-                    err_type="BadRequestError",
-                    status_code=HTTPStatus.BAD_REQUEST)
         except BaseException as e:
-            # Some other unexpected problem loading the adapter, e.g. malformed
-            # input files.
-            # More detailed error messages for the user would be nicer here
+            error_type = "BadRequestError"
+            status_code = HTTPStatus.BAD_REQUEST
+            if isinstance(e, ValueError) and "No adapter found" in str(e):
+                error_type = "NotFoundError"
+                status_code = HTTPStatus.NOT_FOUND
+
             return create_error_response(message=str(e),
-                                         err_type="BadRequestError",
-                                         status_code=HTTPStatus.BAD_REQUEST)
+                                         err_type=error_type,
+                                         status_code=status_code)
 
         self.lora_requests.append(lora_request)
         logger.info("Loaded new LoRA adapter: name '%s', path '%s'", lora_name,
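The handler above now funnels every load failure through a single branch: a "No adapter found" ValueError becomes a 404 NotFoundError, and anything else, including the new PEFTHelper validation errors, becomes a 400 BadRequestError that carries the original message. A minimal standalone sketch of the same rule (the helper name is illustrative, not vLLM API):

# Illustrative sketch of the classification done in the except block above.
from http import HTTPStatus
from typing import Tuple


def classify_lora_load_error(e: BaseException) -> Tuple[str, HTTPStatus]:
    # Missing adapters map to 404; all other failures, including config
    # validation errors, map to 400 with the original error text.
    if isinstance(e, ValueError) and "No adapter found" in str(e):
        return "NotFoundError", HTTPStatus.NOT_FOUND
    return "BadRequestError", HTTPStatus.BAD_REQUEST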
@@ -1,5 +1,4 @@
 import copy
-import json
 import math
 import os
 import re
@@ -180,8 +179,8 @@ class LoRAModel(AdapterModel):
         cls,
         lora_dir: str,
         expected_lora_modules: List[str],
+        peft_helper: PEFTHelper,
         *,
-        max_position_embeddings: Optional[int] = None,
         lora_model_id: Optional[int] = None,
         device: str = "cuda",
         dtype: Optional[torch.dtype] = None,
@@ -196,9 +195,7 @@
             lora_dir: The local path that has lora data.
             expected_lora_modules: Name of modules that are expected to be
                 replaced by lora.
-            max_position_embeddings: Max position embedding length. Used to
-                scaling the largest context length. If None, the lora model's
-                context length is not scaled.
+            peft_helper: Loaded lora configuration information.
             lora_model_id: Lora model id. If not given, automatically set by
                 a global counter.
             device: Device where the lora model is loaded.
@@ -207,18 +204,13 @@
         Returns:
             Loaded LoRA Model.
         """
-        lora_config_path = os.path.join(lora_dir, "adapter_config.json")
         lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
         lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
         new_embeddings_tensor_path = os.path.join(
             lora_dir, "new_embeddings.safetensors")
         new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                     "new_embeddings.bin")
-        with open(lora_config_path) as f:
-            config = json.load(f)
-
-        config["vllm_max_position_embeddings"] = max_position_embeddings
-        peft_helper = PEFTHelper.from_dict(config)
+
         unexpected_modules: List[Union[list[str], str]]
         if os.path.isfile(lora_tensor_path):
             tensors: Dict[str, torch.Tensor] = {}
@@ -1,9 +1,12 @@
 # Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py
 
 import json
 import math
+import os
 from dataclasses import MISSING, dataclass, field, fields
-from typing import Literal, Optional, Union
+from typing import List, Literal, Optional, Union
 
+from vllm.config import LoRAConfig
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
@@ -11,6 +14,12 @@ logger = init_logger(__name__)
 
 
 @dataclass
 class PEFTHelper:
+    """
+    A helper class for PEFT configurations, specifically designed for LoRA.
+    This class handles configuration validation, compatibility checks for
+    various LoRA implementations.
+    """
+
     # Required fields
     r: int
     lora_alpha: int
@@ -29,20 +38,18 @@ class PEFTHelper:
     vllm_max_position_embeddings: Optional[int] = field(default=False)
     vllm_long_context_scaling_factor: Optional[float] = field(default=None)
 
-    def _validate_features(self):
+    def _validate_features(self) -> List[str]:
         """
         Check if there are any unsupported Lora features.
         """
         error_msg = []
 
         if self.modules_to_save:
             error_msg.append("vLLM only supports modules_to_save being None.")
 
         if self.use_dora:
             error_msg.append("vLLM does not yet support DoRA.")
 
-        if error_msg:
-            raise ValueError(f"{', '.join(error_msg)}")
+        return error_msg
 
     def __post_init__(self):
-        self._validate_features()
         if self.use_rslora:
             logger.info_once("Loading LoRA weights trained with rsLoRA.")
             self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
@@ -78,3 +85,29 @@ class PEFTHelper:
             for k, v in config_dict.items() if k in class_fields
         }
         return cls(**filtered_dict)
+
+    @classmethod
+    def from_local_dir(cls, lora_path: str,
+                       max_position_embeddings: Optional[int]) -> "PEFTHelper":
+        lora_config_path = os.path.join(lora_path, "adapter_config.json")
+
+        with open(lora_config_path) as f:
+            config = json.load(f)
+        config["vllm_max_position_embeddings"] = max_position_embeddings
+        return cls.from_dict(config)
+
+    def validate_legal(self, lora_config: LoRAConfig) -> None:
+        """
+        Validates the LoRA configuration settings against application
+        constraints and requirements.
+        """
+        error_msg = self._validate_features()
+        if self.r > lora_config.max_lora_rank:
+            error_msg.append(
+                f"LoRA rank {self.r} is greater than max_lora_rank"
+                f" {lora_config.max_lora_rank}.")
+        if self.bias != "none" and not lora_config.bias_enabled:
+            error_msg.append(
+                "Adapter bias cannot be used without bias_enabled.")
+        if error_msg:
+            raise ValueError(f"{' '.join(error_msg)}")
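With from_local_dir and validate_legal in place, every configuration problem for an adapter is collected and reported in a single ValueError rather than one at a time. A rough sketch of what the aggregated message looks like, using illustrative config values and the same LoRAConfig arguments as the tests:

# Rough sketch, not part of the change: validate_legal gathers the results of
# _validate_features plus the rank and bias checks into one ValueError.
from vllm.config import LoRAConfig
from vllm.lora.peft_helper import PEFTHelper

peft_helper = PEFTHelper.from_dict(
    dict(r=1024, lora_alpha=16, target_modules=["gate_proj"], use_dora=True))
try:
    peft_helper.validate_legal(
        LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2))
except ValueError as e:
    # Expected to mention both "does not yet support DoRA" and
    # "is greater than max_lora_rank 16".
    print(e)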
@@ -12,6 +12,7 @@ from vllm.config import LoRAConfig
 from vllm.logger import init_logger
 from vllm.lora.models import (LoRAModel, LoRAModelManager,
                               LRUCacheLoRAModelManager, create_lora_manager)
+from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.request import LoRARequest
 from vllm.lora.utils import get_adapter_absolute_path
 
@@ -95,6 +96,13 @@ class WorkerLoRAManager(AbstractWorkerManager):
             expected_lora_modules = list(set(expected_lora_modules))
             lora_path = get_adapter_absolute_path(lora_request.lora_path)
 
+            peft_helper = PEFTHelper.from_local_dir(
+                lora_path, self.max_position_embeddings)
+
+            # Validates the LoRA configuration against requirements before
+            # loading weights, throwing an exception if validation fails.
+            peft_helper.validate_legal(self.lora_config)
+
             # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
             # to ensure correct loading of lora weights.
             hf_to_vllm_mapper = None
@@ -105,7 +113,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
             lora = self._lora_model_cls.from_local_checkpoint(
                 lora_path,
                 expected_lora_modules,
-                max_position_embeddings=self.max_position_embeddings,
+                peft_helper=peft_helper,
                 lora_model_id=lora_request.lora_int_id,
                 device="cpu",
                 dtype=self.lora_config.lora_dtype,
@@ -120,15 +128,14 @@ class WorkerLoRAManager(AbstractWorkerManager):
             # - No adapter found to download from huggingface (or in
             #   offline mode)
             # - No local adapter files found at `lora_request.lora_path`
+            # For NotFoundError
             raise ValueError(
                 f"Loading lora {lora_request.lora_name} failed: No adapter "
                 f"found for {lora_path}") from e
         except Exception as e:
-            raise RuntimeError(f"Loading lora {lora_path} failed") from e
-        if lora.rank > self.lora_config.max_lora_rank:
-            raise ValueError(
-                f"LoRA rank {lora.rank} is greater than max_lora_rank "
-                f"{self.lora_config.max_lora_rank}.")
+            # For BadRequestError
+            raise e
 
         if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
             raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
                              f"is greater than lora_extra_vocab_size "