[Misc] Enable vLLM to Dynamically Load LoRA from a Remote Server (#10546)
Signed-off-by: Angky William <angkywilliam@Angkys-MacBook-Pro.local> Co-authored-by: Angky William <angkywilliam@Angkys-MacBook-Pro.local>
This commit is contained in:
parent
54a66e5fee
commit
fdcb850f14
@ -106,19 +106,18 @@ curl http://localhost:8000/v1/completions \
|
||||
|
||||
## Dynamically serving LoRA Adapters
|
||||
|
||||
In addition to serving LoRA adapters at server startup, the vLLM server now supports dynamically loading and unloading
|
||||
LoRA adapters at runtime through dedicated API endpoints. This feature can be particularly useful when the flexibility
|
||||
to change models on-the-fly is needed.
|
||||
In addition to serving LoRA adapters at server startup, the vLLM server supports dynamically configuring LoRA adapters at runtime through dedicated API endpoints and plugins. This feature can be particularly useful when the flexibility to change models on-the-fly is needed.
|
||||
|
||||
Note: Enabling this feature in production environments is risky as users may participate in model adapter management.
|
||||
|
||||
To enable dynamic LoRA loading and unloading, ensure that the environment variable `VLLM_ALLOW_RUNTIME_LORA_UPDATING`
|
||||
is set to `True`. When this option is enabled, the API server will log a warning to indicate that dynamic loading is active.
|
||||
To enable dynamic LoRA configuration, ensure that the environment variable `VLLM_ALLOW_RUNTIME_LORA_UPDATING`
|
||||
is set to `True`.
|
||||
|
||||
```bash
|
||||
export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True
|
||||
```
|
||||
|
||||
### Using API Endpoints
|
||||
Loading a LoRA Adapter:
|
||||
|
||||
To dynamically load a LoRA adapter, send a POST request to the `/v1/load_lora_adapter` endpoint with the necessary
|
||||
@ -153,6 +152,58 @@ curl -X POST http://localhost:8000/v1/unload_lora_adapter \
|
||||
}'
|
||||
```
|
||||
|
||||
### Using Plugins
|
||||
Alternatively, you can use the LoRAResolver plugin to dynamically load LoRA adapters. LoRAResolver plugins enable you to load LoRA adapters from both local and remote sources such as local file system and S3. On every request, when there's a new model name that hasn't been loaded yet, the LoRAResolver will try to resolve and load the corresponding LoRA adapter.
|
||||
|
||||
You can set up multiple LoRAResolver plugins if you want to load LoRA adapters from different sources. For example, you might have one resolver for local files and another for S3 storage. vLLM will load the first LoRA adapter that it finds.
|
||||
|
||||
You can either install existing plugins or implement your own.
|
||||
|
||||
Steps to implement your own LoRAResolver plugin:
|
||||
1. Implement the LoRAResolver interface.
|
||||
|
||||
Example of a simple S3 LoRAResolver implementation:
|
||||
|
||||
```python
|
||||
import os
|
||||
import s3fs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.resolver import LoRAResolver
|
||||
|
||||
class S3LoRAResolver(LoRAResolver):
|
||||
def __init__(self):
|
||||
self.s3 = s3fs.S3FileSystem()
|
||||
self.s3_path_format = os.getenv("S3_PATH_TEMPLATE")
|
||||
self.local_path_format = os.getenv("LOCAL_PATH_TEMPLATE")
|
||||
|
||||
async def resolve_lora(self, base_model_name, lora_name):
|
||||
s3_path = self.s3_path_format.format(base_model_name=base_model_name, lora_name=lora_name)
|
||||
local_path = self.local_path_format.format(base_model_name=base_model_name, lora_name=lora_name)
|
||||
|
||||
# Download the LoRA from S3 to the local path
|
||||
await self.s3._get(
|
||||
s3_path, local_path, recursive=True, maxdepth=1
|
||||
)
|
||||
|
||||
lora_request = LoRARequest(
|
||||
lora_name=lora_name,
|
||||
lora_path=local_path,
|
||||
lora_int_id=abs(hash(lora_name))
|
||||
)
|
||||
return lora_request
|
||||
```
|
||||
|
||||
2. Register LoRAResolver plugin.
|
||||
|
||||
```python
|
||||
from vllm.lora.resolver import LoRAResolverRegistry
|
||||
|
||||
s3_resolver = S3LoRAResolver()
|
||||
LoRAResolverRegistry.register_resolver("s3_resolver", s3_resolver)
|
||||
```
|
||||
|
||||
For more details, refer to the [vLLM's Plugins System](../design/plugin_system.md).
|
||||
|
||||
## New format for `--lora-modules`
|
||||
|
||||
In the previous version, users would provide LoRA modules via the following format, either as a key-value pair or in JSON format. For example:
|
||||
|
209
tests/entrypoints/openai/test_lora_resolvers.py
Normal file
209
tests/entrypoints/openai/test_lora_resolvers.py
Normal file
@ -0,0 +1,209 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from http import HTTPStatus
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import MultiModalConfig
|
||||
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||
from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse
|
||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||
OpenAIServingModels)
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
MODEL_NAME = "openai-community/gpt2"
|
||||
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
|
||||
|
||||
MOCK_RESOLVER_NAME = "mock_test_resolver"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockHFConfig:
|
||||
model_type: str = "any"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockModelConfig:
|
||||
"""Minimal mock ModelConfig for testing."""
|
||||
model: str = MODEL_NAME
|
||||
tokenizer: str = MODEL_NAME
|
||||
trust_remote_code: bool = False
|
||||
tokenizer_mode: str = "auto"
|
||||
max_model_len: int = 100
|
||||
tokenizer_revision: Optional[str] = None
|
||||
multimodal_config: MultiModalConfig = field(
|
||||
default_factory=MultiModalConfig)
|
||||
hf_config: MockHFConfig = field(default_factory=MockHFConfig)
|
||||
logits_processor_pattern: Optional[str] = None
|
||||
diff_sampling_param: Optional[dict] = None
|
||||
allowed_local_media_path: str = ""
|
||||
encoder_config = None
|
||||
generation_config: str = "auto"
|
||||
|
||||
def get_diff_sampling_param(self):
|
||||
return self.diff_sampling_param or {}
|
||||
|
||||
|
||||
class MockLoRAResolver(LoRAResolver):
|
||||
|
||||
async def resolve_lora(self, base_model_name: str,
|
||||
lora_name: str) -> Optional[LoRARequest]:
|
||||
if lora_name == "test-lora":
|
||||
return LoRARequest(lora_name="test-lora",
|
||||
lora_int_id=1,
|
||||
lora_local_path="/fake/path/test-lora")
|
||||
elif lora_name == "invalid-lora":
|
||||
return LoRARequest(lora_name="invalid-lora",
|
||||
lora_int_id=2,
|
||||
lora_local_path="/fake/path/invalid-lora")
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def register_mock_resolver():
|
||||
"""Fixture to register and unregister the mock LoRA resolver."""
|
||||
resolver = MockLoRAResolver()
|
||||
LoRAResolverRegistry.register_resolver(MOCK_RESOLVER_NAME, resolver)
|
||||
yield
|
||||
# Cleanup: remove the resolver after the test runs
|
||||
if MOCK_RESOLVER_NAME in LoRAResolverRegistry.resolvers:
|
||||
del LoRAResolverRegistry.resolvers[MOCK_RESOLVER_NAME]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_serving_setup():
|
||||
"""Provides a mocked engine and serving completion instance."""
|
||||
mock_engine = MagicMock(spec=MQLLMEngineClient)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
|
||||
def mock_add_lora_side_effect(lora_request: LoRARequest):
|
||||
"""Simulate engine behavior when adding LoRAs."""
|
||||
if lora_request.lora_name == "test-lora":
|
||||
# Simulate successful addition
|
||||
return
|
||||
elif lora_request.lora_name == "invalid-lora":
|
||||
# Simulate failure during addition (e.g. invalid format)
|
||||
raise ValueError(f"Simulated failure adding LoRA: "
|
||||
f"{lora_request.lora_name}")
|
||||
|
||||
mock_engine.add_lora.side_effect = mock_add_lora_side_effect
|
||||
mock_engine.generate.reset_mock()
|
||||
mock_engine.add_lora.reset_mock()
|
||||
|
||||
mock_model_config = MockModelConfig()
|
||||
models = OpenAIServingModels(engine_client=mock_engine,
|
||||
base_model_paths=BASE_MODEL_PATHS,
|
||||
model_config=mock_model_config)
|
||||
|
||||
serving_completion = OpenAIServingCompletion(mock_engine,
|
||||
mock_model_config,
|
||||
models,
|
||||
request_logger=None)
|
||||
|
||||
return mock_engine, serving_completion
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serving_completion_with_lora_resolver(mock_serving_setup,
|
||||
monkeypatch):
|
||||
monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")
|
||||
|
||||
mock_engine, serving_completion = mock_serving_setup
|
||||
|
||||
lora_model_name = "test-lora"
|
||||
req_found = CompletionRequest(
|
||||
model=lora_model_name,
|
||||
prompt="Generate with LoRA",
|
||||
)
|
||||
|
||||
# Suppress potential errors during the mocked generate call,
|
||||
# as we are primarily checking for add_lora and generate calls
|
||||
with suppress(Exception):
|
||||
await serving_completion.create_completion(req_found)
|
||||
|
||||
mock_engine.add_lora.assert_called_once()
|
||||
called_lora_request = mock_engine.add_lora.call_args[0][0]
|
||||
assert isinstance(called_lora_request, LoRARequest)
|
||||
assert called_lora_request.lora_name == lora_model_name
|
||||
|
||||
mock_engine.generate.assert_called_once()
|
||||
called_lora_request = mock_engine.generate.call_args[1]['lora_request']
|
||||
assert isinstance(called_lora_request, LoRARequest)
|
||||
assert called_lora_request.lora_name == lora_model_name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serving_completion_resolver_not_found(mock_serving_setup,
|
||||
monkeypatch):
|
||||
monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")
|
||||
|
||||
mock_engine, serving_completion = mock_serving_setup
|
||||
|
||||
non_existent_model = "non-existent-lora-adapter"
|
||||
req = CompletionRequest(
|
||||
model=non_existent_model,
|
||||
prompt="what is 1+1?",
|
||||
)
|
||||
|
||||
response = await serving_completion.create_completion(req)
|
||||
|
||||
mock_engine.add_lora.assert_not_called()
|
||||
mock_engine.generate.assert_not_called()
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert response.code == HTTPStatus.NOT_FOUND.value
|
||||
assert non_existent_model in response.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serving_completion_resolver_add_lora_fails(
|
||||
mock_serving_setup, monkeypatch):
|
||||
monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")
|
||||
|
||||
mock_engine, serving_completion = mock_serving_setup
|
||||
|
||||
invalid_model = "invalid-lora"
|
||||
req = CompletionRequest(
|
||||
model=invalid_model,
|
||||
prompt="what is 1+1?",
|
||||
)
|
||||
|
||||
response = await serving_completion.create_completion(req)
|
||||
|
||||
# Assert add_lora was called before the failure
|
||||
mock_engine.add_lora.assert_called_once()
|
||||
called_lora_request = mock_engine.add_lora.call_args[0][0]
|
||||
assert isinstance(called_lora_request, LoRARequest)
|
||||
assert called_lora_request.lora_name == invalid_model
|
||||
|
||||
# Assert generate was *not* called due to the failure
|
||||
mock_engine.generate.assert_not_called()
|
||||
|
||||
# Assert the correct error response
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert response.code == HTTPStatus.BAD_REQUEST.value
|
||||
assert invalid_model in response.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_serving_completion_flag_not_set(mock_serving_setup):
|
||||
mock_engine, serving_completion = mock_serving_setup
|
||||
|
||||
lora_model_name = "test-lora"
|
||||
req_found = CompletionRequest(
|
||||
model=lora_model_name,
|
||||
prompt="Generate with LoRA",
|
||||
)
|
||||
|
||||
await serving_completion.create_completion(req_found)
|
||||
|
||||
mock_engine.add_lora.assert_not_called()
|
||||
mock_engine.generate.assert_not_called()
|
74
tests/lora/test_resolver.py
Normal file
74
tests/lora/test_resolver.py
Normal file
@ -0,0 +1,74 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
|
||||
|
||||
|
||||
class DummyLoRAResolver(LoRAResolver):
|
||||
"""A dummy LoRA resolver for testing."""
|
||||
|
||||
async def resolve_lora(self, base_model_name: str,
|
||||
lora_name: str) -> Optional[LoRARequest]:
|
||||
if lora_name == "test_lora":
|
||||
return LoRARequest(
|
||||
lora_name=lora_name,
|
||||
lora_path=f"/dummy/path/{base_model_name}/{lora_name}",
|
||||
lora_int_id=abs(hash(lora_name)))
|
||||
return None
|
||||
|
||||
|
||||
def test_resolver_registry_registration():
|
||||
"""Test basic resolver registration functionality."""
|
||||
registry = LoRAResolverRegistry
|
||||
resolver = DummyLoRAResolver()
|
||||
|
||||
# Register a new resolver
|
||||
registry.register_resolver("dummy", resolver)
|
||||
assert "dummy" in registry.get_supported_resolvers()
|
||||
|
||||
# Get registered resolver
|
||||
retrieved_resolver = registry.get_resolver("dummy")
|
||||
assert retrieved_resolver is resolver
|
||||
|
||||
|
||||
def test_resolver_registry_duplicate_registration():
|
||||
"""Test registering a resolver with an existing name."""
|
||||
registry = LoRAResolverRegistry
|
||||
resolver1 = DummyLoRAResolver()
|
||||
resolver2 = DummyLoRAResolver()
|
||||
|
||||
registry.register_resolver("dummy", resolver1)
|
||||
registry.register_resolver("dummy", resolver2)
|
||||
|
||||
assert registry.get_resolver("dummy") is resolver2
|
||||
|
||||
|
||||
def test_resolver_registry_unknown_resolver():
|
||||
"""Test getting a non-existent resolver."""
|
||||
registry = LoRAResolverRegistry
|
||||
|
||||
with pytest.raises(KeyError, match="not found"):
|
||||
registry.get_resolver("unknown_resolver")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dummy_resolver_resolve():
|
||||
"""Test the dummy resolver's resolve functionality."""
|
||||
dummy_resolver = DummyLoRAResolver()
|
||||
base_model_name = "base_model_test"
|
||||
lora_name = "test_lora"
|
||||
|
||||
# Test successful resolution
|
||||
result = await dummy_resolver.resolve_lora(base_model_name, lora_name)
|
||||
assert isinstance(result, LoRARequest)
|
||||
assert result.lora_name == lora_name
|
||||
assert result.lora_path == f"/dummy/path/{base_model_name}/{lora_name}"
|
||||
|
||||
# Test failed resolution
|
||||
result = await dummy_resolver.resolve_lora(base_model_name,
|
||||
"nonexistent_lora")
|
||||
assert result is None
|
@ -10,6 +10,7 @@ from fastapi import Request
|
||||
from pydantic import Field
|
||||
from starlette.datastructures import Headers
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
# yapf conflicts with isort for this block
|
||||
@ -125,18 +126,29 @@ class OpenAIServing:
|
||||
self,
|
||||
request: AnyRequest,
|
||||
) -> Optional[ErrorResponse]:
|
||||
|
||||
error_response = None
|
||||
|
||||
if self._is_model_supported(request.model):
|
||||
return None
|
||||
if request.model in [
|
||||
lora.lora_name for lora in self.models.lora_requests
|
||||
]:
|
||||
return None
|
||||
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and (
|
||||
load_result := await self.models.resolve_lora(request.model)):
|
||||
if isinstance(load_result, LoRARequest):
|
||||
return None
|
||||
if isinstance(load_result, ErrorResponse) and \
|
||||
load_result.code == HTTPStatus.BAD_REQUEST.value:
|
||||
error_response = load_result
|
||||
if request.model in [
|
||||
prompt_adapter.prompt_adapter_name
|
||||
for prompt_adapter in self.models.prompt_adapter_requests
|
||||
]:
|
||||
return None
|
||||
return self.create_error_response(
|
||||
|
||||
return error_response or self.create_error_response(
|
||||
message=f"The model `{request.model}` does not exist.",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND)
|
||||
|
@ -2,6 +2,8 @@
|
||||
|
||||
import json
|
||||
import pathlib
|
||||
from asyncio import Lock
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import Optional, Union
|
||||
@ -15,6 +17,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
|
||||
UnloadLoRAAdapterRequest)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.utils import AtomicCounter
|
||||
|
||||
@ -63,11 +66,19 @@ class OpenAIServingModels:
|
||||
self.base_model_paths = base_model_paths
|
||||
self.max_model_len = model_config.max_model_len
|
||||
self.engine_client = engine_client
|
||||
self.model_config = model_config
|
||||
|
||||
self.static_lora_modules = lora_modules
|
||||
self.lora_requests: list[LoRARequest] = []
|
||||
self.lora_id_counter = AtomicCounter(0)
|
||||
|
||||
self.lora_resolvers: list[LoRAResolver] = []
|
||||
for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers(
|
||||
):
|
||||
self.lora_resolvers.append(
|
||||
LoRAResolverRegistry.get_resolver(lora_resolver_name))
|
||||
self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)
|
||||
|
||||
self.prompt_adapter_requests = []
|
||||
if prompt_adapters is not None:
|
||||
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
|
||||
@ -234,6 +245,65 @@ class OpenAIServingModels:
|
||||
|
||||
return None
|
||||
|
||||
async def resolve_lora(
|
||||
self, lora_name: str) -> Union[LoRARequest, ErrorResponse]:
|
||||
"""Attempt to resolve a LoRA adapter using available resolvers.
|
||||
|
||||
Args:
|
||||
lora_name: Name/identifier of the LoRA adapter
|
||||
|
||||
Returns:
|
||||
LoRARequest if found and loaded successfully.
|
||||
ErrorResponse (404) if no resolver finds the adapter.
|
||||
ErrorResponse (400) if adapter(s) are found but none load.
|
||||
"""
|
||||
async with self.lora_resolver_lock[lora_name]:
|
||||
# First check if this LoRA is already loaded
|
||||
for existing in self.lora_requests:
|
||||
if existing.lora_name == lora_name:
|
||||
return existing
|
||||
|
||||
base_model_name = self.model_config.model
|
||||
unique_id = self.lora_id_counter.inc(1)
|
||||
found_adapter = False
|
||||
|
||||
# Try to resolve using available resolvers
|
||||
for resolver in self.lora_resolvers:
|
||||
lora_request = await resolver.resolve_lora(
|
||||
base_model_name, lora_name)
|
||||
|
||||
if lora_request is not None:
|
||||
found_adapter = True
|
||||
lora_request.lora_int_id = unique_id
|
||||
|
||||
try:
|
||||
await self.engine_client.add_lora(lora_request)
|
||||
self.lora_requests.append(lora_request)
|
||||
logger.info(
|
||||
"Resolved and loaded LoRA adapter '%s' using %s",
|
||||
lora_name, resolver.__class__.__name__)
|
||||
return lora_request
|
||||
except BaseException as e:
|
||||
logger.warning(
|
||||
"Failed to load LoRA '%s' resolved by %s: %s. "
|
||||
"Trying next resolver.", lora_name,
|
||||
resolver.__class__.__name__, e)
|
||||
continue
|
||||
|
||||
if found_adapter:
|
||||
# An adapter was found, but all attempts to load it failed.
|
||||
return create_error_response(
|
||||
message=(f"LoRA adapter '{lora_name}' was found "
|
||||
"but could not be loaded."),
|
||||
err_type="BadRequestError",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
else:
|
||||
# No adapter was found
|
||||
return create_error_response(
|
||||
message=f"LoRA adapter {lora_name} does not exist",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND)
|
||||
|
||||
|
||||
def create_error_response(
|
||||
message: str,
|
||||
|
83
vllm/lora/resolver.py
Normal file
83
vllm/lora/resolver.py
Normal file
@ -0,0 +1,83 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import AbstractSet, Dict, Optional
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LoRAResolver(ABC):
|
||||
"""Base class for LoRA adapter resolvers.
|
||||
|
||||
This class defines the interface for resolving and fetching LoRA adapters.
|
||||
Implementations of this class should handle the logic for locating and
|
||||
downloading LoRA adapters from various sources (e.g. S3, cloud storage,
|
||||
etc.).
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def resolve_lora(self, base_model_name: str,
|
||||
lora_name: str) -> Optional[LoRARequest]:
|
||||
"""Abstract method to resolve and fetch a LoRA model adapter.
|
||||
|
||||
Implements logic to locate and download LoRA adapter based on the name.
|
||||
Implementations might fetch from a blob storage or other sources.
|
||||
|
||||
Args:
|
||||
base_model_name: The name/identifier of the base model to resolve.
|
||||
lora_name: The name/identifier of the LoRA model to resolve.
|
||||
|
||||
Returns:
|
||||
Optional[LoRARequest]: The resolved LoRA model information, or None
|
||||
if the LoRA model cannot be found.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class _LoRAResolverRegistry:
|
||||
resolvers: Dict[str, LoRAResolver] = field(default_factory=dict)
|
||||
|
||||
def get_supported_resolvers(self) -> AbstractSet[str]:
|
||||
"""Get all registered resolver names."""
|
||||
return self.resolvers.keys()
|
||||
|
||||
def register_resolver(
|
||||
self,
|
||||
resolver_name: str,
|
||||
resolver: LoRAResolver,
|
||||
) -> None:
|
||||
"""Register a LoRA resolver.
|
||||
Args:
|
||||
resolver_name: Name to register the resolver under.
|
||||
resolver: The LoRA resolver instance to register.
|
||||
"""
|
||||
if resolver_name in self.resolvers:
|
||||
logger.warning(
|
||||
"LoRA resolver %s is already registered, and will be "
|
||||
"overwritten by the new resolver instance %s.", resolver_name,
|
||||
resolver)
|
||||
|
||||
self.resolvers[resolver_name] = resolver
|
||||
|
||||
def get_resolver(self, resolver_name: str) -> LoRAResolver:
|
||||
"""Get a registered resolver instance by name.
|
||||
Args:
|
||||
resolver_name: Name of the resolver to get.
|
||||
Returns:
|
||||
The resolver instance.
|
||||
Raises:
|
||||
KeyError: If the resolver is not found in the registry.
|
||||
"""
|
||||
if resolver_name not in self.resolvers:
|
||||
raise KeyError(
|
||||
f"LoRA resolver '{resolver_name}' not found. "
|
||||
f"Available resolvers: {list(self.resolvers.keys())}")
|
||||
return self.resolvers[resolver_name]
|
||||
|
||||
|
||||
LoRAResolverRegistry = _LoRAResolverRegistry()
|
Loading…
x
Reference in New Issue
Block a user