2025-02-02 14:58:18 -05:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
2024-09-05 18:10:33 -07:00
|
|
|
from http import HTTPStatus
|
|
|
|
from unittest.mock import MagicMock
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
from vllm.config import ModelConfig
|
2025-01-10 00:56:36 -07:00
|
|
|
from vllm.engine.protocol import EngineClient
|
2024-09-05 18:10:33 -07:00
|
|
|
from vllm.entrypoints.openai.protocol import (ErrorResponse,
|
|
|
|
LoadLoraAdapterRequest,
|
|
|
|
UnloadLoraAdapterRequest)
|
2024-12-31 18:21:51 -08:00
|
|
|
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
|
|
|
OpenAIServingModels)
|
2024-12-12 01:25:16 -08:00
|
|
|
from vllm.lora.request import LoRARequest
|
2024-09-05 18:10:33 -07:00
|
|
|
|
|
|
|
# Base model used throughout this module; the same string is passed as both
# the served name and the model path.
MODEL_NAME = "meta-llama/Llama-2-7b"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]

# Templates the tests compare load/unload responses against. Fill in the
# ``lora_name`` placeholder with ``str.format``.
LORA_LOADING_SUCCESS_MESSAGE = (
    "Success: LoRA adapter '{lora_name}' added successfully.")
LORA_UNLOADING_SUCCESS_MESSAGE = (
    "Success: LoRA adapter '{lora_name}' removed successfully.")
|
|
|
|
|
|
|
|
|
2024-12-31 18:21:51 -08:00
|
|
|
async def _async_serving_models_init() -> OpenAIServingModels:
    """Create an ``OpenAIServingModels`` wired to mocked dependencies.

    The engine client and model config are ``MagicMock`` objects built with
    ``spec`` set to their real classes. No LoRA modules or prompt adapters
    are passed in, and ``init_static_loras`` is awaited before returning.

    Returns:
        The initialized ``OpenAIServingModels`` instance.
    """
    engine_client = MagicMock(spec=EngineClient)
    model_config = MagicMock(spec=ModelConfig)
    # Set the max_model_len attribute to avoid missing attribute
    model_config.max_model_len = 2048

    models = OpenAIServingModels(
        engine_client=engine_client,
        base_model_paths=BASE_MODEL_PATHS,
        model_config=model_config,
        lora_modules=None,
        prompt_adapters=None,
    )
    await models.init_static_loras()
    return models
|
2024-09-05 18:10:33 -07:00
|
|
|
|
|
|
|
|
2024-12-12 01:25:16 -08:00
|
|
|
@pytest.mark.asyncio
async def test_serving_model_name():
    """Check ``model_name`` with no request and with a LoRA request.

    Expects the base model name when ``None`` is passed, and the request's
    own ``lora_name`` when a ``LoRARequest`` is passed.
    """
    models = await _async_serving_models_init()

    # No LoRA request: expect the base model name.
    assert models.model_name(None) == MODEL_NAME

    # With a LoRA request: expect the adapter's name.
    lora_request = LoRARequest(
        lora_name="adapter",
        lora_path="/path/to/adapter2",
        lora_int_id=1,
    )
    assert models.model_name(lora_request) == lora_request.lora_name
|
2024-12-12 01:25:16 -08:00
|
|
|
|
|
|
|
|
2024-09-05 18:10:33 -07:00
|
|
|
@pytest.mark.asyncio
async def test_load_lora_adapter_success():
    """Check that loading one adapter succeeds and registers it.

    Expects the success message for the adapter name, and exactly one entry
    in ``lora_requests`` carrying that name.
    """
    models = await _async_serving_models_init()
    load_request = LoadLoraAdapterRequest(
        lora_name="adapter",
        lora_path="/path/to/adapter2",
    )
    expected_message = LORA_LOADING_SUCCESS_MESSAGE.format(
        lora_name="adapter")

    response = await models.load_lora_adapter(load_request)

    assert response == expected_message
    assert len(models.lora_requests) == 1
    assert models.lora_requests[0].lora_name == "adapter"
|
2024-09-05 18:10:33 -07:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_load_lora_adapter_missing_fields():
    """Check that a load request with empty name and path is rejected.

    Expects an ``ErrorResponse`` of type ``InvalidUserInput`` with a
    400 (bad request) status code.
    """
    models = await _async_serving_models_init()
    empty_request = LoadLoraAdapterRequest(lora_name="", lora_path="")

    response = await models.load_lora_adapter(empty_request)

    assert isinstance(response, ErrorResponse)
    assert response.type == "InvalidUserInput"
    assert response.code == HTTPStatus.BAD_REQUEST
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_load_lora_adapter_duplicate():
    """Check that loading the same adapter twice fails the second time.

    The first load is expected to succeed. The second, identical load is
    expected to return an ``InvalidUserInput`` error with a 400 status,
    leaving the number of registered adapters at one.
    """
    models = await _async_serving_models_init()

    # First load: should succeed and register the adapter.
    first_request = LoadLoraAdapterRequest(
        lora_name="adapter1",
        lora_path="/path/to/adapter1",
    )
    first_response = await models.load_lora_adapter(first_request)
    assert first_response == LORA_LOADING_SUCCESS_MESSAGE.format(
        lora_name="adapter1")
    assert len(models.lora_requests) == 1

    # Second load with the same name and path: should be rejected.
    duplicate_request = LoadLoraAdapterRequest(
        lora_name="adapter1",
        lora_path="/path/to/adapter1",
    )
    duplicate_response = await models.load_lora_adapter(duplicate_request)
    assert isinstance(duplicate_response, ErrorResponse)
    assert duplicate_response.type == "InvalidUserInput"
    assert duplicate_response.code == HTTPStatus.BAD_REQUEST
    # The failed attempt must not add a second entry.
    assert len(models.lora_requests) == 1
|
2024-09-05 18:10:33 -07:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_unload_lora_adapter_success():
    """Check that a loaded adapter can be unloaded again.

    Loads ``adapter1``, then unloads it by name. Expects the unload success
    message and an empty ``lora_requests`` afterwards.
    """
    models = await _async_serving_models_init()

    # Arrange: load the adapter. Only the resulting registry size is checked
    # here, not the load response itself.
    load_request = LoadLoraAdapterRequest(
        lora_name="adapter1",
        lora_path="/path/to/adapter1",
    )
    await models.load_lora_adapter(load_request)
    assert len(models.lora_requests) == 1

    # Act: unload it by name.
    unload_request = UnloadLoraAdapterRequest(lora_name="adapter1")
    response = await models.unload_lora_adapter(unload_request)

    expected_message = LORA_UNLOADING_SUCCESS_MESSAGE.format(
        lora_name="adapter1")
    assert response == expected_message
    assert len(models.lora_requests) == 0
|
2024-09-05 18:10:33 -07:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_unload_lora_adapter_missing_fields():
    """Check that an unload request with no name and no id is rejected.

    Expects an ``ErrorResponse`` of type ``InvalidUserInput`` with a
    400 (bad request) status code.
    """
    models = await _async_serving_models_init()
    empty_request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None)

    response = await models.unload_lora_adapter(empty_request)

    assert isinstance(response, ErrorResponse)
    assert response.type == "InvalidUserInput"
    assert response.code == HTTPStatus.BAD_REQUEST
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_unload_lora_adapter_not_found():
    """Check that unloading an adapter that was never loaded is an error.

    Expects an ``ErrorResponse`` of type ``NotFoundError`` with a
    404 (not found) status code.
    """
    models = await _async_serving_models_init()
    unknown_request = UnloadLoraAdapterRequest(
        lora_name="nonexistent_adapter")

    response = await models.unload_lora_adapter(unknown_request)

    assert isinstance(response, ErrorResponse)
    assert response.type == "NotFoundError"
    assert response.code == HTTPStatus.NOT_FOUND
|