[Misc] Enable vLLM to Dynamically Load LoRA from a Remote Server (#10546)

Signed-off-by: Angky William <angkywilliam@Angkys-MacBook-Pro.local>
Co-authored-by: Angky William <angkywilliam@Angkys-MacBook-Pro.local>
Angky William 2025-04-15 15:31:38 -07:00 committed by GitHub
parent 54a66e5fee
commit fdcb850f14
6 changed files with 505 additions and 6 deletions


@ -106,19 +106,18 @@ curl http://localhost:8000/v1/completions \
## Dynamically serving LoRA Adapters
In addition to serving LoRA adapters at server startup, the vLLM server supports dynamically configuring LoRA adapters at runtime through dedicated API endpoints and plugins. This feature can be particularly useful when the flexibility to change models on-the-fly is needed.

Note: Enabling this feature in production environments is risky, as it allows users to take part in model adapter management.

To enable dynamic LoRA configuration, ensure that the environment variable `VLLM_ALLOW_RUNTIME_LORA_UPDATING`
is set to `True`.
```bash
export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True
```
### Using API Endpoints
Loading a LoRA Adapter:
To dynamically load a LoRA adapter, send a POST request to the `/v1/load_lora_adapter` endpoint with the necessary
@ -153,6 +152,58 @@ curl -X POST http://localhost:8000/v1/unload_lora_adapter \
}'
```
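For comparison, loading an adapter uses the same request shape; a minimal illustrative sketch (the adapter name and path below are placeholders, not part of this diff):

```bash
# Illustrative only: replace the adapter name and path with your own.
curl -X POST http://localhost:8000/v1/load_lora_adapter \
-H "Content-Type: application/json" \
-d '{
    "lora_name": "sql_adapter",
    "lora_path": "/path/to/sql-lora-adapter"
}'
```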
### Using Plugins
Alternatively, you can use the LoRAResolver plugin to dynamically load LoRA adapters. LoRAResolver plugins allow you to load LoRA adapters from both local and remote sources, such as the local file system or S3. Whenever a request names a model that has not been loaded yet, the LoRAResolver will try to resolve and load the corresponding LoRA adapter.

You can set up multiple LoRAResolver plugins if you want to load LoRA adapters from different sources. For example, you might have one resolver for local files and another for S3 storage. vLLM will load the first LoRA adapter that it finds.

You can either install existing plugins or implement your own.

Steps to implement your own LoRAResolver plugin:

1. Implement the LoRAResolver interface.

Example of a simple S3 LoRAResolver implementation:
```python
import os

import s3fs

from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver


class S3LoRAResolver(LoRAResolver):

    def __init__(self):
        self.s3 = s3fs.S3FileSystem()
        self.s3_path_format = os.getenv("S3_PATH_TEMPLATE")
        self.local_path_format = os.getenv("LOCAL_PATH_TEMPLATE")

    async def resolve_lora(self, base_model_name, lora_name):
        s3_path = self.s3_path_format.format(base_model_name=base_model_name,
                                             lora_name=lora_name)
        local_path = self.local_path_format.format(
            base_model_name=base_model_name, lora_name=lora_name)

        # Download the LoRA from S3 to the local path
        await self.s3._get(s3_path, local_path, recursive=True, maxdepth=1)

        lora_request = LoRARequest(lora_name=lora_name,
                                   lora_path=local_path,
                                   lora_int_id=abs(hash(lora_name)))
        return lora_request
```
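The two environment variables read in `__init__` above are assumed to hold `str.format` templates; illustrative values (the bucket and directory names are placeholders):

```bash
# Placeholder templates consumed by the S3LoRAResolver sketch above.
export S3_PATH_TEMPLATE="s3://my-bucket/loras/{base_model_name}/{lora_name}/"
export LOCAL_PATH_TEMPLATE="/tmp/loras/{base_model_name}/{lora_name}"
```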
2. Register the LoRAResolver plugin.

```python
from vllm.lora.resolver import LoRAResolverRegistry

s3_resolver = S3LoRAResolver()
LoRAResolverRegistry.register_resolver("s3_resolver", s3_resolver)
```
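If the resolver ships as an installable package, the registration call can run automatically at startup via vLLM's general plugin mechanism; a minimal sketch assuming the `vllm.general_plugins` entry-point group described in the plugin system docs (the package and module names here are hypothetical):

```python
# my_lora_plugin/__init__.py -- hypothetical plugin package.
# Expose this function as an entry point in the "vllm.general_plugins"
# group so that vLLM calls it once during startup.

def register_s3_resolver():
    from vllm.lora.resolver import LoRAResolverRegistry

    # S3LoRAResolver as defined in the example above.
    from my_lora_plugin.s3_resolver import S3LoRAResolver

    LoRAResolverRegistry.register_resolver("s3_resolver", S3LoRAResolver())
```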
For more details, refer to [vLLM's Plugin System](../design/plugin_system.md).
## New format for `--lora-modules`
In the previous version, users provided LoRA modules either as a key-value pair or in JSON format. For example:
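A sketch of the key-value form (the adapter name and path are placeholders):

```bash
--lora-modules sql-lora=/path/to/sql-lora-adapter
```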


@ -0,0 +1,209 @@
# SPDX-License-Identifier: Apache-2.0

from contextlib import suppress
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Optional
from unittest.mock import MagicMock

import pytest

from vllm.config import MultiModalConfig
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
from vllm.transformers_utils.tokenizer import get_tokenizer

MODEL_NAME = "openai-community/gpt2"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
MOCK_RESOLVER_NAME = "mock_test_resolver"


@dataclass
class MockHFConfig:
    model_type: str = "any"


@dataclass
class MockModelConfig:
    """Minimal mock ModelConfig for testing."""
    model: str = MODEL_NAME
    tokenizer: str = MODEL_NAME
    trust_remote_code: bool = False
    tokenizer_mode: str = "auto"
    max_model_len: int = 100
    tokenizer_revision: Optional[str] = None
    multimodal_config: MultiModalConfig = field(
        default_factory=MultiModalConfig)
    hf_config: MockHFConfig = field(default_factory=MockHFConfig)
    logits_processor_pattern: Optional[str] = None
    diff_sampling_param: Optional[dict] = None
    allowed_local_media_path: str = ""
    encoder_config = None
    generation_config: str = "auto"

    def get_diff_sampling_param(self):
        return self.diff_sampling_param or {}


class MockLoRAResolver(LoRAResolver):

    async def resolve_lora(self, base_model_name: str,
                           lora_name: str) -> Optional[LoRARequest]:
        if lora_name == "test-lora":
            return LoRARequest(lora_name="test-lora",
                               lora_int_id=1,
                               lora_local_path="/fake/path/test-lora")
        elif lora_name == "invalid-lora":
            return LoRARequest(lora_name="invalid-lora",
                               lora_int_id=2,
                               lora_local_path="/fake/path/invalid-lora")
        return None


@pytest.fixture(autouse=True)
def register_mock_resolver():
    """Fixture to register and unregister the mock LoRA resolver."""
    resolver = MockLoRAResolver()
    LoRAResolverRegistry.register_resolver(MOCK_RESOLVER_NAME, resolver)
    yield
    # Cleanup: remove the resolver after the test runs
    if MOCK_RESOLVER_NAME in LoRAResolverRegistry.resolvers:
        del LoRAResolverRegistry.resolvers[MOCK_RESOLVER_NAME]


@pytest.fixture
def mock_serving_setup():
    """Provides a mocked engine and serving completion instance."""
    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

    def mock_add_lora_side_effect(lora_request: LoRARequest):
        """Simulate engine behavior when adding LoRAs."""
        if lora_request.lora_name == "test-lora":
            # Simulate successful addition
            return
        elif lora_request.lora_name == "invalid-lora":
            # Simulate failure during addition (e.g. invalid format)
            raise ValueError(f"Simulated failure adding LoRA: "
                             f"{lora_request.lora_name}")

    mock_engine.add_lora.side_effect = mock_add_lora_side_effect
    mock_engine.generate.reset_mock()
    mock_engine.add_lora.reset_mock()

    mock_model_config = MockModelConfig()
    models = OpenAIServingModels(engine_client=mock_engine,
                                 base_model_paths=BASE_MODEL_PATHS,
                                 model_config=mock_model_config)
    serving_completion = OpenAIServingCompletion(mock_engine,
                                                 mock_model_config,
                                                 models,
                                                 request_logger=None)

    return mock_engine, serving_completion


@pytest.mark.asyncio
async def test_serving_completion_with_lora_resolver(mock_serving_setup,
                                                     monkeypatch):
    monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")
    mock_engine, serving_completion = mock_serving_setup

    lora_model_name = "test-lora"
    req_found = CompletionRequest(
        model=lora_model_name,
        prompt="Generate with LoRA",
    )

    # Suppress potential errors during the mocked generate call,
    # as we are primarily checking for add_lora and generate calls
    with suppress(Exception):
        await serving_completion.create_completion(req_found)

    mock_engine.add_lora.assert_called_once()
    called_lora_request = mock_engine.add_lora.call_args[0][0]
    assert isinstance(called_lora_request, LoRARequest)
    assert called_lora_request.lora_name == lora_model_name

    mock_engine.generate.assert_called_once()
    called_lora_request = mock_engine.generate.call_args[1]['lora_request']
    assert isinstance(called_lora_request, LoRARequest)
    assert called_lora_request.lora_name == lora_model_name


@pytest.mark.asyncio
async def test_serving_completion_resolver_not_found(mock_serving_setup,
                                                     monkeypatch):
    monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")
    mock_engine, serving_completion = mock_serving_setup

    non_existent_model = "non-existent-lora-adapter"
    req = CompletionRequest(
        model=non_existent_model,
        prompt="what is 1+1?",
    )

    response = await serving_completion.create_completion(req)

    mock_engine.add_lora.assert_not_called()
    mock_engine.generate.assert_not_called()
    assert isinstance(response, ErrorResponse)
    assert response.code == HTTPStatus.NOT_FOUND.value
    assert non_existent_model in response.message


@pytest.mark.asyncio
async def test_serving_completion_resolver_add_lora_fails(
        mock_serving_setup, monkeypatch):
    monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")
    mock_engine, serving_completion = mock_serving_setup

    invalid_model = "invalid-lora"
    req = CompletionRequest(
        model=invalid_model,
        prompt="what is 1+1?",
    )

    response = await serving_completion.create_completion(req)

    # Assert add_lora was called before the failure
    mock_engine.add_lora.assert_called_once()
    called_lora_request = mock_engine.add_lora.call_args[0][0]
    assert isinstance(called_lora_request, LoRARequest)
    assert called_lora_request.lora_name == invalid_model

    # Assert generate was *not* called due to the failure
    mock_engine.generate.assert_not_called()

    # Assert the correct error response
    assert isinstance(response, ErrorResponse)
    assert response.code == HTTPStatus.BAD_REQUEST.value
    assert invalid_model in response.message


@pytest.mark.asyncio
async def test_serving_completion_flag_not_set(mock_serving_setup):
    mock_engine, serving_completion = mock_serving_setup

    lora_model_name = "test-lora"
    req_found = CompletionRequest(
        model=lora_model_name,
        prompt="Generate with LoRA",
    )

    await serving_completion.create_completion(req_found)

    mock_engine.add_lora.assert_not_called()
    mock_engine.generate.assert_not_called()


@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import pytest

from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry


class DummyLoRAResolver(LoRAResolver):
    """A dummy LoRA resolver for testing."""

    async def resolve_lora(self, base_model_name: str,
                           lora_name: str) -> Optional[LoRARequest]:
        if lora_name == "test_lora":
            return LoRARequest(
                lora_name=lora_name,
                lora_path=f"/dummy/path/{base_model_name}/{lora_name}",
                lora_int_id=abs(hash(lora_name)))
        return None


def test_resolver_registry_registration():
    """Test basic resolver registration functionality."""
    registry = LoRAResolverRegistry
    resolver = DummyLoRAResolver()

    # Register a new resolver
    registry.register_resolver("dummy", resolver)
    assert "dummy" in registry.get_supported_resolvers()

    # Get registered resolver
    retrieved_resolver = registry.get_resolver("dummy")
    assert retrieved_resolver is resolver


def test_resolver_registry_duplicate_registration():
    """Test registering a resolver with an existing name."""
    registry = LoRAResolverRegistry
    resolver1 = DummyLoRAResolver()
    resolver2 = DummyLoRAResolver()

    registry.register_resolver("dummy", resolver1)
    registry.register_resolver("dummy", resolver2)

    assert registry.get_resolver("dummy") is resolver2


def test_resolver_registry_unknown_resolver():
    """Test getting a non-existent resolver."""
    registry = LoRAResolverRegistry

    with pytest.raises(KeyError, match="not found"):
        registry.get_resolver("unknown_resolver")


@pytest.mark.asyncio
async def test_dummy_resolver_resolve():
    """Test the dummy resolver's resolve functionality."""
    dummy_resolver = DummyLoRAResolver()
    base_model_name = "base_model_test"
    lora_name = "test_lora"

    # Test successful resolution
    result = await dummy_resolver.resolve_lora(base_model_name, lora_name)
    assert isinstance(result, LoRARequest)
    assert result.lora_name == lora_name
    assert result.lora_path == f"/dummy/path/{base_model_name}/{lora_name}"

    # Test failed resolution
    result = await dummy_resolver.resolve_lora(base_model_name,
                                               "nonexistent_lora")
    assert result is None


@ -10,6 +10,7 @@ from fastapi import Request
from pydantic import Field
from starlette.datastructures import Headers
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
# yapf conflicts with isort for this block
@ -125,18 +126,29 @@ class OpenAIServing:
        self,
        request: AnyRequest,
    ) -> Optional[ErrorResponse]:
        error_response = None

        if self._is_model_supported(request.model):
            return None
        if request.model in [
                lora.lora_name for lora in self.models.lora_requests
        ]:
            return None
        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and (
                load_result := await self.models.resolve_lora(request.model)):
            if isinstance(load_result, LoRARequest):
                return None
            if isinstance(load_result, ErrorResponse) and \
                    load_result.code == HTTPStatus.BAD_REQUEST.value:
                error_response = load_result
        if request.model in [
                prompt_adapter.prompt_adapter_name
                for prompt_adapter in self.models.prompt_adapter_requests
        ]:
            return None

        return error_response or self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)


@ -2,6 +2,8 @@
import json
import pathlib
from asyncio import Lock
from collections import defaultdict
from dataclasses import dataclass
from http import HTTPStatus
from typing import Optional, Union
@ -15,6 +17,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                              UnloadLoRAAdapterRequest)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.utils import AtomicCounter
@ -63,11 +66,19 @@ class OpenAIServingModels:
        self.base_model_paths = base_model_paths
        self.max_model_len = model_config.max_model_len
        self.engine_client = engine_client
        self.model_config = model_config

        self.static_lora_modules = lora_modules
        self.lora_requests: list[LoRARequest] = []
        self.lora_id_counter = AtomicCounter(0)

        self.lora_resolvers: list[LoRAResolver] = []
        for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers(
        ):
            self.lora_resolvers.append(
                LoRAResolverRegistry.get_resolver(lora_resolver_name))
        self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)

        self.prompt_adapter_requests = []
        if prompt_adapters is not None:
            for i, prompt_adapter in enumerate(prompt_adapters, start=1):
@ -234,6 +245,65 @@ class OpenAIServingModels:
            return None

    async def resolve_lora(
            self, lora_name: str) -> Union[LoRARequest, ErrorResponse]:
        """Attempt to resolve a LoRA adapter using available resolvers.

        Args:
            lora_name: Name/identifier of the LoRA adapter.

        Returns:
            LoRARequest if found and loaded successfully.
            ErrorResponse (404) if no resolver finds the adapter.
            ErrorResponse (400) if adapter(s) are found but none load.
        """
        async with self.lora_resolver_lock[lora_name]:
            # First check if this LoRA is already loaded
            for existing in self.lora_requests:
                if existing.lora_name == lora_name:
                    return existing

            base_model_name = self.model_config.model
            unique_id = self.lora_id_counter.inc(1)
            found_adapter = False

            # Try to resolve using available resolvers
            for resolver in self.lora_resolvers:
                lora_request = await resolver.resolve_lora(
                    base_model_name, lora_name)

                if lora_request is not None:
                    found_adapter = True
                    lora_request.lora_int_id = unique_id
                    try:
                        await self.engine_client.add_lora(lora_request)
                        self.lora_requests.append(lora_request)
                        logger.info(
                            "Resolved and loaded LoRA adapter '%s' using %s",
                            lora_name, resolver.__class__.__name__)
                        return lora_request
                    except BaseException as e:
                        logger.warning(
                            "Failed to load LoRA '%s' resolved by %s: %s. "
                            "Trying next resolver.", lora_name,
                            resolver.__class__.__name__, e)
                        continue

            if found_adapter:
                # An adapter was found, but all attempts to load it failed.
                return create_error_response(
                    message=(f"LoRA adapter '{lora_name}' was found "
                             "but could not be loaded."),
                    err_type="BadRequestError",
                    status_code=HTTPStatus.BAD_REQUEST)
            else:
                # No adapter was found
                return create_error_response(
                    message=f"LoRA adapter {lora_name} does not exist",
                    err_type="NotFoundError",
                    status_code=HTTPStatus.NOT_FOUND)


def create_error_response(
        message: str,

vllm/lora/resolver.py

@ -0,0 +1,83 @@
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import AbstractSet, Dict, Optional

from vllm.logger import init_logger
from vllm.lora.request import LoRARequest

logger = init_logger(__name__)


class LoRAResolver(ABC):
    """Base class for LoRA adapter resolvers.

    This class defines the interface for resolving and fetching LoRA adapters.
    Implementations of this class should handle the logic for locating and
    downloading LoRA adapters from various sources (e.g. S3, cloud storage,
    etc.).
    """

    @abstractmethod
    async def resolve_lora(self, base_model_name: str,
                           lora_name: str) -> Optional[LoRARequest]:
        """Abstract method to resolve and fetch a LoRA model adapter.

        Implements logic to locate and download a LoRA adapter based on the
        name. Implementations might fetch from blob storage or other sources.

        Args:
            base_model_name: The name/identifier of the base model to resolve.
            lora_name: The name/identifier of the LoRA model to resolve.

        Returns:
            Optional[LoRARequest]: The resolved LoRA model information, or None
            if the LoRA model cannot be found.
        """
        pass


@dataclass
class _LoRAResolverRegistry:
    resolvers: Dict[str, LoRAResolver] = field(default_factory=dict)

    def get_supported_resolvers(self) -> AbstractSet[str]:
        """Get all registered resolver names."""
        return self.resolvers.keys()

    def register_resolver(
        self,
        resolver_name: str,
        resolver: LoRAResolver,
    ) -> None:
        """Register a LoRA resolver.

        Args:
            resolver_name: Name to register the resolver under.
            resolver: The LoRA resolver instance to register.
        """
        if resolver_name in self.resolvers:
            logger.warning(
                "LoRA resolver %s is already registered, and will be "
                "overwritten by the new resolver instance %s.", resolver_name,
                resolver)
        self.resolvers[resolver_name] = resolver

    def get_resolver(self, resolver_name: str) -> LoRAResolver:
        """Get a registered resolver instance by name.

        Args:
            resolver_name: Name of the resolver to get.

        Returns:
            The resolver instance.

        Raises:
            KeyError: If the resolver is not found in the registry.
        """
        if resolver_name not in self.resolvers:
            raise KeyError(
                f"LoRA resolver '{resolver_name}' not found. "
                f"Available resolvers: {list(self.resolvers.keys())}")
        return self.resolvers[resolver_name]


LoRAResolverRegistry = _LoRAResolverRegistry()