[misc][plugin] add plugin system implementation (#7426)

youkaichao 2024-08-13 16:24:17 -07:00 committed by GitHub
parent 373538f973
commit 16422ea76f
13 changed files with 162 additions and 102 deletions


@@ -77,11 +77,13 @@ steps:
- pytest -v -s core
- label: Entrypoints Test # 20min
working_dir: "/vllm-workspace/tests"
fast_check: true
mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
commands:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s entrypoints/llm
- pytest -v -s entrypoints/openai
@@ -154,6 +156,7 @@ steps:
- vllm/
- tests/models
commands:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s models -m \"not vlm\"
- label: Vision Language Models Test # 42min
@@ -289,6 +292,7 @@ steps:
- pytest -v -s distributed/test_chunked_prefill_distributed.py
- pytest -v -s distributed/test_multimodal_broadcast.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- pytest -v -s distributed/test_distributed_oot.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py


@@ -23,4 +23,5 @@ pyzmq
librosa # Required for audio processing
soundfile # Required for audio processing
gguf == 0.9.1
importlib_metadata
compressed-tensors == 0.5.0


@@ -1,7 +1,9 @@
import contextlib
import gc
import json
import os
import sys
import tempfile
from collections import UserList
from enum import Enum
from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,
@@ -11,6 +13,7 @@ import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
@@ -757,3 +760,26 @@ def num_gpus_available():
    in current process."""
    return cuda_device_count_stateless()


temp_dir = tempfile.gettempdir()
_dummy_path = os.path.join(temp_dir, "dummy_opt")


@pytest.fixture
def dummy_opt_path():
    json_path = os.path.join(_dummy_path, "config.json")
    if not os.path.exists(_dummy_path):
        snapshot_download(repo_id="facebook/opt-125m",
                          local_dir=_dummy_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
    assert os.path.exists(json_path)
    with open(json_path, "r") as f:
        config = json.load(f)
    config["architectures"] = ["MyOPTForCausalLM"]
    with open(json_path, "w") as f:
        json.dump(config, f)
    return _dummy_path


@@ -0,0 +1,6 @@
from ..entrypoints.openai.test_oot_registration import (
    run_and_test_dummy_opt_api_server)


def test_distributed_oot(dummy_opt_path: str):
    run_and_test_dummy_opt_api_server(dummy_opt_path, tp=2)


@@ -1,70 +1,27 @@
import sys
import time
import torch
from openai import OpenAI, OpenAIError
from vllm import ModelRegistry
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.utils import get_open_port
from ...utils import VLLM_PATH, RemoteOpenAIServer
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()
class MyOPTForCausalLM(OPTForCausalLM):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata)
logits.zero_()
logits[:, 0] += 1.0
return logits
def server_function(port: int):
# register our dummy model
ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
sys.argv = ["placeholder.py"] + [
"--model",
"facebook/opt-125m",
def run_and_test_dummy_opt_api_server(model, tp=1):
# the model is registered through the plugin
server_args = [
"--gpu-memory-utilization",
"0.10",
"--dtype",
"float32",
"--api-key",
"token-abc123",
"--port",
str(port),
"--chat-template",
str(chatml_jinja_path),
"--load-format",
"dummy",
"-tp",
f"{tp}",
]
import runpy
runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
def test_oot_registration_for_api_server():
port = get_open_port()
ctx = torch.multiprocessing.get_context()
server = ctx.Process(target=server_function, args=(port, ))
server.start()
try:
client = OpenAI(
base_url=f"http://localhost:{port}/v1",
api_key="token-abc123",
)
now = time.time()
while True:
try:
with RemoteOpenAIServer(model, server_args) as server:
client = server.get_client()
completion = client.chat.completions.create(
model="facebook/opt-125m",
model=model,
messages=[{
"role": "system",
"content": "You are a helpful assistant."
@@ -74,21 +31,12 @@ def test_oot_registration_for_api_server():
}],
temperature=0,
)
break
except OpenAIError as e:
if "Connection error" in str(e):
time.sleep(3)
if time.time() - now > RemoteOpenAIServer.MAX_START_WAIT_S:
msg = "Server did not start in time"
raise RuntimeError(msg) from e
else:
raise e
finally:
server.terminate()
generated_text = completion.choices[0].message.content
assert generated_text is not None
# make sure only the first token is generated
# TODO(youkaichao): Fix the test with plugin
rest = generated_text.replace("<s>", "") # noqa
# assert rest == ""
rest = generated_text.replace("<s>", "")
assert rest == ""
def test_oot_registration_for_api_server(dummy_opt_path: str):
run_and_test_dummy_opt_api_server(dummy_opt_path)


@@ -1,32 +1,27 @@
from typing import Optional
import os
import torch
import pytest
from vllm import LLM, ModelRegistry, SamplingParams
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm import LLM, SamplingParams
# NOTE: the order of the tests is important
# the first test does not load any plugins
# the second test loads the plugin
# they share the same process, so the plugin is loaded for the second test
class MyOPTForCausalLM(OPTForCausalLM):
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata)
logits.zero_()
logits[:, 0] += 1.0
return logits
def test_plugin(dummy_opt_path):
os.environ["VLLM_PLUGINS"] = ""
with pytest.raises(Exception) as excinfo:
LLM(model=dummy_opt_path, load_format="dummy")
assert "are not supported for now" in str(excinfo.value)
def test_oot_registration():
# register our dummy model
ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
def test_oot_registration(dummy_opt_path):
os.environ["VLLM_PLUGINS"] = "register_dummy_model"
prompts = ["Hello, my name is", "The text does not matter"]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model="facebook/opt-125m")
llm = LLM(model=dummy_opt_path, load_format="dummy")
first_token = llm.get_tokenizer().decode(0)
outputs = llm.generate(prompts, sampling_params)


@@ -0,0 +1,9 @@
from setuptools import setup

setup(name='vllm_add_dummy_model',
      version='0.1',
      packages=['vllm_add_dummy_model'],
      entry_points={
          'vllm.general_plugins':
          ["register_dummy_model = vllm_add_dummy_model:register"]
      })
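Note: the entry_points declaration above is what lets vLLM discover the package by the group name vllm.general_plugins. As a quick sanity check (a sketch, not part of this diff, assuming Python 3.10+; older interpreters would use the importlib_metadata backport listed in requirements-common.txt), the installed entry point can be listed like this after running pip install -e ./plugins/vllm_add_dummy_model:

import importlib.metadata

# List every plugin registered under the group that vLLM scans.
# For this package the expected output is:
#   register_dummy_model -> vllm_add_dummy_model:register
for ep in importlib.metadata.entry_points(group="vllm.general_plugins"):
    print(ep.name, "->", ep.value)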


@@ -0,0 +1,26 @@
from typing import Optional

import torch

from vllm import ModelRegistry
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata


class MyOPTForCausalLM(OPTForCausalLM):

    def compute_logits(
            self, hidden_states: torch.Tensor,
            sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
        # this dummy model always predicts the first token
        logits = super().compute_logits(hidden_states, sampling_metadata)
        if logits is not None:
            logits.zero_()
            logits[:, 0] += 1.0
        return logits


def register():
    # register our dummy model
    if "MyOPTForCausalLM" not in ModelRegistry.get_supported_archs():
        ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM)


@@ -227,6 +227,9 @@ class LLMEngine:
)
# TODO(woosuk): Print more configs in debug mode.
from vllm.plugins import load_general_plugins
load_general_plugins()
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config


@@ -1,6 +1,6 @@
import os
import tempfile
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
if TYPE_CHECKING:
VLLM_HOST_IP: str = ""
@@ -55,6 +55,7 @@ if TYPE_CHECKING:
VERBOSE: bool = False
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
VLLM_TEST_FORCE_FP8_MARLIN: bool = False
VLLM_PLUGINS: Optional[List[str]] = None
def get_default_cache_root():
@@ -362,6 +363,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
    lambda:
    (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
     ("1", "true")),

    # a list of plugin names to load, separated by commas.
    # if this is not set, it means all plugins will be loaded
    # if this is set to an empty string, no plugins will be loaded
    "VLLM_PLUGINS":
    lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[
        "VLLM_PLUGINS"].split(","),
}
# end-env-vars-definition
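To spell out the selection rule that the VLLM_PLUGINS lambda above implements, here is an equivalent standalone sketch (the helper name selected_plugins is made up for illustration):

import os

def selected_plugins():
    # Variable unset -> None, which load_general_plugins treats as "load all".
    if "VLLM_PLUGINS" not in os.environ:
        return None
    # Note that "".split(",") == [""], so VLLM_PLUGINS="" matches no plugin
    # name and effectively disables plugin loading.
    return os.environ["VLLM_PLUGINS"].split(",")

The tests above rely on exactly this behavior: test_plugin sets VLLM_PLUGINS to an empty string so no plugin runs and loading the dummy model fails, while test_oot_registration sets it to register_dummy_model before constructing the LLM.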


@@ -166,7 +166,7 @@ class ModelRegistry:
@staticmethod
def get_supported_archs() -> List[str]:
return list(_MODELS.keys())
return list(_MODELS.keys()) + list(_OOT_MODELS.keys())
@staticmethod
def register_model(model_arch: str, model_cls: Type[nn.Module]):

vllm/plugins/__init__.py Normal file

@@ -0,0 +1,31 @@
import logging

import vllm.envs as envs

logger = logging.getLogger(__name__)


def load_general_plugins():
    """WARNING: plugins can be loaded for multiple times in different
    processes. They should be designed in a way that they can be loaded
    multiple times without causing issues.
    """
    import sys
    if sys.version_info < (3, 10):
        from importlib_metadata import entry_points
    else:
        from importlib.metadata import entry_points
    allowed_plugins = envs.VLLM_PLUGINS

    discovered_plugins = entry_points(group='vllm.general_plugins')
    for plugin in discovered_plugins:
        logger.info("Found general plugin: %s", plugin.name)
        if allowed_plugins is None or plugin.name in allowed_plugins:
            try:
                func = plugin.load()
                func()
                logger.info("Loaded general plugin: %s", plugin.name)
            except Exception:
                logger.exception("Failed to load general plugin: %s",
                                 plugin.name)
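Putting the pieces together, a usage sketch mirroring tests/models/test_oot_registration.py above, assuming the vllm_add_dummy_model package is installed; the path /tmp/dummy_opt is illustrative and stands in for whatever directory the dummy_opt_path fixture prepares:

import os

from vllm import LLM, SamplingParams

# Restrict loading to the dummy plugin; leaving VLLM_PLUGINS unset would allow
# every discovered plugin to load. The variable is read lazily, so setting it
# before constructing the LLM (as the tests do) is sufficient.
os.environ["VLLM_PLUGINS"] = "register_dummy_model"

# The directory's config.json advertises the MyOPTForCausalLM architecture,
# which only resolves once the plugin has registered it.
llm = LLM(model="/tmp/dummy_opt", load_format="dummy")
outputs = llm.generate(["Hello, my name is"], SamplingParams(temperature=0))
print(outputs[0].outputs[0].text)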


@@ -411,6 +411,9 @@ class WorkerWrapperBase:
# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'
from vllm.plugins import load_general_plugins
load_general_plugins()
if self.worker_class_fn:
worker_class = self.worker_class_fn()
else: