[misc][plugin] add plugin system implementation (#7426)
This commit is contained in:
parent
373538f973
commit
16422ea76f
@ -77,11 +77,13 @@ steps:
|
||||
- pytest -v -s core
|
||||
|
||||
- label: Entrypoints Test # 20min
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
fast_check: true
|
||||
mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
commands:
|
||||
- pip install -e ./plugins/vllm_add_dummy_model
|
||||
- pytest -v -s entrypoints/llm
|
||||
- pytest -v -s entrypoints/openai
|
||||
|
||||
@ -154,6 +156,7 @@ steps:
|
||||
- vllm/
|
||||
- tests/models
|
||||
commands:
|
||||
- pip install -e ./plugins/vllm_add_dummy_model
|
||||
- pytest -v -s models -m \"not vlm\"
|
||||
|
||||
- label: Vision Language Models Test # 42min
|
||||
@ -289,6 +292,7 @@ steps:
|
||||
- pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||
- pytest -v -s distributed/test_multimodal_broadcast.py
|
||||
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
||||
- pytest -v -s distributed/test_distributed_oot.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
|
||||
|
||||
|
@ -23,4 +23,5 @@ pyzmq
|
||||
librosa # Required for audio processing
|
||||
soundfile # Required for audio processing
|
||||
gguf == 0.9.1
|
||||
importlib_metadata
|
||||
compressed-tensors == 0.5.0
|
||||
|
@ -1,7 +1,9 @@
|
||||
import contextlib
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from collections import UserList
|
||||
from enum import Enum
|
||||
from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,
|
||||
@ -11,6 +13,7 @@ import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL import Image
|
||||
from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
|
||||
AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
|
||||
@ -757,3 +760,26 @@ def num_gpus_available():
|
||||
in current process."""
|
||||
|
||||
return cuda_device_count_stateless()
|
||||
|
||||
|
||||
temp_dir = tempfile.gettempdir()
|
||||
_dummy_path = os.path.join(temp_dir, "dummy_opt")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_opt_path():
|
||||
json_path = os.path.join(_dummy_path, "config.json")
|
||||
if not os.path.exists(_dummy_path):
|
||||
snapshot_download(repo_id="facebook/opt-125m",
|
||||
local_dir=_dummy_path,
|
||||
ignore_patterns=[
|
||||
"*.bin", "*.bin.index.json", "*.pt", "*.h5",
|
||||
"*.msgpack"
|
||||
])
|
||||
assert os.path.exists(json_path)
|
||||
with open(json_path, "r") as f:
|
||||
config = json.load(f)
|
||||
config["architectures"] = ["MyOPTForCausalLM"]
|
||||
with open(json_path, "w") as f:
|
||||
json.dump(config, f)
|
||||
return _dummy_path
|
||||
|
6
tests/distributed/test_distributed_oot.py
Normal file
6
tests/distributed/test_distributed_oot.py
Normal file
@ -0,0 +1,6 @@
|
||||
from ..entrypoints.openai.test_oot_registration import (
|
||||
run_and_test_dummy_opt_api_server)
|
||||
|
||||
|
||||
def test_distributed_oot(dummy_opt_path: str):
|
||||
run_and_test_dummy_opt_api_server(dummy_opt_path, tp=2)
|
@ -1,70 +1,27 @@
|
||||
import sys
|
||||
import time
|
||||
|
||||
import torch
|
||||
from openai import OpenAI, OpenAIError
|
||||
|
||||
from vllm import ModelRegistry
|
||||
from vllm.model_executor.models.opt import OPTForCausalLM
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.utils import get_open_port
|
||||
|
||||
from ...utils import VLLM_PATH, RemoteOpenAIServer
|
||||
|
||||
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
|
||||
assert chatml_jinja_path.exists()
|
||||
|
||||
|
||||
class MyOPTForCausalLM(OPTForCausalLM):
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
# this dummy model always predicts the first token
|
||||
logits = super().compute_logits(hidden_states, sampling_metadata)
|
||||
logits.zero_()
|
||||
logits[:, 0] += 1.0
|
||||
return logits
|
||||
|
||||
|
||||
def server_function(port: int):
|
||||
# register our dummy model
|
||||
ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
|
||||
|
||||
sys.argv = ["placeholder.py"] + [
|
||||
"--model",
|
||||
"facebook/opt-125m",
|
||||
def run_and_test_dummy_opt_api_server(model, tp=1):
|
||||
# the model is registered through the plugin
|
||||
server_args = [
|
||||
"--gpu-memory-utilization",
|
||||
"0.10",
|
||||
"--dtype",
|
||||
"float32",
|
||||
"--api-key",
|
||||
"token-abc123",
|
||||
"--port",
|
||||
str(port),
|
||||
"--chat-template",
|
||||
str(chatml_jinja_path),
|
||||
"--load-format",
|
||||
"dummy",
|
||||
"-tp",
|
||||
f"{tp}",
|
||||
]
|
||||
|
||||
import runpy
|
||||
runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
|
||||
|
||||
|
||||
def test_oot_registration_for_api_server():
|
||||
port = get_open_port()
|
||||
ctx = torch.multiprocessing.get_context()
|
||||
server = ctx.Process(target=server_function, args=(port, ))
|
||||
server.start()
|
||||
|
||||
try:
|
||||
client = OpenAI(
|
||||
base_url=f"http://localhost:{port}/v1",
|
||||
api_key="token-abc123",
|
||||
)
|
||||
now = time.time()
|
||||
while True:
|
||||
try:
|
||||
with RemoteOpenAIServer(model, server_args) as server:
|
||||
client = server.get_client()
|
||||
completion = client.chat.completions.create(
|
||||
model="facebook/opt-125m",
|
||||
model=model,
|
||||
messages=[{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
@ -74,21 +31,12 @@ def test_oot_registration_for_api_server():
|
||||
}],
|
||||
temperature=0,
|
||||
)
|
||||
break
|
||||
except OpenAIError as e:
|
||||
if "Connection error" in str(e):
|
||||
time.sleep(3)
|
||||
if time.time() - now > RemoteOpenAIServer.MAX_START_WAIT_S:
|
||||
msg = "Server did not start in time"
|
||||
raise RuntimeError(msg) from e
|
||||
else:
|
||||
raise e
|
||||
finally:
|
||||
server.terminate()
|
||||
|
||||
generated_text = completion.choices[0].message.content
|
||||
assert generated_text is not None
|
||||
# make sure only the first token is generated
|
||||
# TODO(youkaichao): Fix the test with plugin
|
||||
rest = generated_text.replace("<s>", "") # noqa
|
||||
# assert rest == ""
|
||||
rest = generated_text.replace("<s>", "")
|
||||
assert rest == ""
|
||||
|
||||
|
||||
def test_oot_registration_for_api_server(dummy_opt_path: str):
|
||||
run_and_test_dummy_opt_api_server(dummy_opt_path)
|
||||
|
@ -1,32 +1,27 @@
|
||||
from typing import Optional
|
||||
import os
|
||||
|
||||
import torch
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, ModelRegistry, SamplingParams
|
||||
from vllm.model_executor.models.opt import OPTForCausalLM
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# NOTE: the order of the tests is important
|
||||
# the first test does not load any plugins
|
||||
# the second test loads the plugin
|
||||
# they share the same process, so the plugin is loaded for the second test
|
||||
|
||||
|
||||
class MyOPTForCausalLM(OPTForCausalLM):
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
# this dummy model always predicts the first token
|
||||
logits = super().compute_logits(hidden_states, sampling_metadata)
|
||||
logits.zero_()
|
||||
logits[:, 0] += 1.0
|
||||
return logits
|
||||
def test_plugin(dummy_opt_path):
|
||||
os.environ["VLLM_PLUGINS"] = ""
|
||||
with pytest.raises(Exception) as excinfo:
|
||||
LLM(model=dummy_opt_path, load_format="dummy")
|
||||
assert "are not supported for now" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_oot_registration():
|
||||
# register our dummy model
|
||||
ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
|
||||
def test_oot_registration(dummy_opt_path):
|
||||
os.environ["VLLM_PLUGINS"] = "register_dummy_model"
|
||||
prompts = ["Hello, my name is", "The text does not matter"]
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
llm = LLM(model="facebook/opt-125m")
|
||||
llm = LLM(model=dummy_opt_path, load_format="dummy")
|
||||
first_token = llm.get_tokenizer().decode(0)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
|
9
tests/plugins/vllm_add_dummy_model/setup.py
Normal file
9
tests/plugins/vllm_add_dummy_model/setup.py
Normal file
@ -0,0 +1,9 @@
|
||||
from setuptools import setup
|
||||
|
||||
setup(name='vllm_add_dummy_model',
|
||||
version='0.1',
|
||||
packages=['vllm_add_dummy_model'],
|
||||
entry_points={
|
||||
'vllm.general_plugins':
|
||||
["register_dummy_model = vllm_add_dummy_model:register"]
|
||||
})
|
@ -0,0 +1,26 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import ModelRegistry
|
||||
from vllm.model_executor.models.opt import OPTForCausalLM
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
|
||||
|
||||
class MyOPTForCausalLM(OPTForCausalLM):
|
||||
|
||||
def compute_logits(
|
||||
self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
|
||||
# this dummy model always predicts the first token
|
||||
logits = super().compute_logits(hidden_states, sampling_metadata)
|
||||
if logits is not None:
|
||||
logits.zero_()
|
||||
logits[:, 0] += 1.0
|
||||
return logits
|
||||
|
||||
|
||||
def register():
|
||||
# register our dummy model
|
||||
if "MyOPTForCausalLM" not in ModelRegistry.get_supported_archs():
|
||||
ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM)
|
@ -227,6 +227,9 @@ class LLMEngine:
|
||||
)
|
||||
# TODO(woosuk): Print more configs in debug mode.
|
||||
|
||||
from vllm.plugins import load_general_plugins
|
||||
load_general_plugins()
|
||||
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
self.lora_config = lora_config
|
||||
|
10
vllm/envs.py
10
vllm/envs.py
@ -1,6 +1,6 @@
|
||||
import os
|
||||
import tempfile
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
VLLM_HOST_IP: str = ""
|
||||
@ -55,6 +55,7 @@ if TYPE_CHECKING:
|
||||
VERBOSE: bool = False
|
||||
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
|
||||
VLLM_TEST_FORCE_FP8_MARLIN: bool = False
|
||||
VLLM_PLUGINS: Optional[List[str]] = None
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
@ -362,6 +363,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
|
||||
lambda:
|
||||
(os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
|
||||
("1", "true")),
|
||||
|
||||
# a list of plugin names to load, separated by commas.
|
||||
# if this is not set, it means all plugins will be loaded
|
||||
# if this is set to an empty string, no plugins will be loaded
|
||||
"VLLM_PLUGINS":
|
||||
lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[
|
||||
"VLLM_PLUGINS"].split(","),
|
||||
}
|
||||
|
||||
# end-env-vars-definition
|
||||
|
@ -166,7 +166,7 @@ class ModelRegistry:
|
||||
|
||||
@staticmethod
|
||||
def get_supported_archs() -> List[str]:
|
||||
return list(_MODELS.keys())
|
||||
return list(_MODELS.keys()) + list(_OOT_MODELS.keys())
|
||||
|
||||
@staticmethod
|
||||
def register_model(model_arch: str, model_cls: Type[nn.Module]):
|
||||
|
31
vllm/plugins/__init__.py
Normal file
31
vllm/plugins/__init__.py
Normal file
@ -0,0 +1,31 @@
|
||||
import logging
|
||||
|
||||
import vllm.envs as envs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_general_plugins():
|
||||
"""WARNING: plugins can be loaded for multiple times in different
|
||||
processes. They should be designed in a way that they can be loaded
|
||||
multiple times without causing issues.
|
||||
"""
|
||||
import sys
|
||||
if sys.version_info < (3, 10):
|
||||
from importlib_metadata import entry_points
|
||||
else:
|
||||
from importlib.metadata import entry_points
|
||||
|
||||
allowed_plugins = envs.VLLM_PLUGINS
|
||||
|
||||
discovered_plugins = entry_points(group='vllm.general_plugins')
|
||||
for plugin in discovered_plugins:
|
||||
logger.info("Found general plugin: %s", plugin.name)
|
||||
if allowed_plugins is None or plugin.name in allowed_plugins:
|
||||
try:
|
||||
func = plugin.load()
|
||||
func()
|
||||
logger.info("Loaded general plugin: %s", plugin.name)
|
||||
except Exception:
|
||||
logger.exception("Failed to load general plugin: %s",
|
||||
plugin.name)
|
@ -411,6 +411,9 @@ class WorkerWrapperBase:
|
||||
# see https://github.com/NVIDIA/nccl/issues/1234
|
||||
os.environ['NCCL_CUMEM_ENABLE'] = '0'
|
||||
|
||||
from vllm.plugins import load_general_plugins
|
||||
load_general_plugins()
|
||||
|
||||
if self.worker_class_fn:
|
||||
worker_class = self.worker_class_fn()
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user