[misc][plugin] add plugin system implementation (#7426)

youkaichao 2024-08-13 16:24:17 -07:00 committed by GitHub
parent 373538f973
commit 16422ea76f
13 changed files with 162 additions and 102 deletions

View File

@@ -77,11 +77,13 @@ steps:
   - pytest -v -s core
 
 - label: Entrypoints Test # 20min
+  working_dir: "/vllm-workspace/tests"
   fast_check: true
   mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/
   commands:
+  - pip install -e ./plugins/vllm_add_dummy_model
   - pytest -v -s entrypoints/llm
   - pytest -v -s entrypoints/openai
@@ -154,6 +156,7 @@ steps:
   - vllm/
   - tests/models
   commands:
+  - pip install -e ./plugins/vllm_add_dummy_model
   - pytest -v -s models -m \"not vlm\"
 
 - label: Vision Language Models Test # 42min
@@ -289,6 +292,7 @@ steps:
   - pytest -v -s distributed/test_chunked_prefill_distributed.py
   - pytest -v -s distributed/test_multimodal_broadcast.py
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
+  - pytest -v -s distributed/test_distributed_oot.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py

View File

@@ -23,4 +23,5 @@ pyzmq
 librosa # Required for audio processing
 soundfile # Required for audio processing
 gguf == 0.9.1
+importlib_metadata
 compressed-tensors == 0.5.0

View File

@@ -1,7 +1,9 @@
 import contextlib
 import gc
+import json
 import os
 import sys
+import tempfile
 from collections import UserList
 from enum import Enum
 from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,
@@ -11,6 +13,7 @@ import pytest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from huggingface_hub import snapshot_download
 from PIL import Image
 from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
                           AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
@@ -757,3 +760,26 @@ def num_gpus_available():
     in current process."""
     return cuda_device_count_stateless()
+
+
+temp_dir = tempfile.gettempdir()
+_dummy_path = os.path.join(temp_dir, "dummy_opt")
+
+
+@pytest.fixture
+def dummy_opt_path():
+    json_path = os.path.join(_dummy_path, "config.json")
+    if not os.path.exists(_dummy_path):
+        snapshot_download(repo_id="facebook/opt-125m",
+                          local_dir=_dummy_path,
+                          ignore_patterns=[
+                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
+                              "*.msgpack"
+                          ])
+        assert os.path.exists(json_path)
+        with open(json_path, "r") as f:
+            config = json.load(f)
+        config["architectures"] = ["MyOPTForCausalLM"]
+        with open(json_path, "w") as f:
+            json.dump(config, f)
+    return _dummy_path

View File

@@ -0,0 +1,6 @@
+from ..entrypoints.openai.test_oot_registration import (
+    run_and_test_dummy_opt_api_server)
+
+
+def test_distributed_oot(dummy_opt_path: str):
+    run_and_test_dummy_opt_api_server(dummy_opt_path, tp=2)

View File

@@ -1,70 +1,27 @@
-import sys
-import time
-
-import torch
-from openai import OpenAI, OpenAIError
-
-from vllm import ModelRegistry
-from vllm.model_executor.models.opt import OPTForCausalLM
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.utils import get_open_port
-
 from ...utils import VLLM_PATH, RemoteOpenAIServer
 
 chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
 assert chatml_jinja_path.exists()
 
 
-class MyOPTForCausalLM(OPTForCausalLM):
-
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        # this dummy model always predicts the first token
-        logits = super().compute_logits(hidden_states, sampling_metadata)
-        logits.zero_()
-        logits[:, 0] += 1.0
-        return logits
-
-
-def server_function(port: int):
-    # register our dummy model
-    ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
-    sys.argv = ["placeholder.py"] + [
-        "--model",
-        "facebook/opt-125m",
+def run_and_test_dummy_opt_api_server(model, tp=1):
+    # the model is registered through the plugin
+    server_args = [
         "--gpu-memory-utilization",
         "0.10",
         "--dtype",
         "float32",
-        "--api-key",
-        "token-abc123",
-        "--port",
-        str(port),
         "--chat-template",
         str(chatml_jinja_path),
+        "--load-format",
+        "dummy",
+        "-tp",
+        f"{tp}",
     ]
-
-    import runpy
-    runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
-
-
-def test_oot_registration_for_api_server():
-    port = get_open_port()
-    ctx = torch.multiprocessing.get_context()
-    server = ctx.Process(target=server_function, args=(port, ))
-    server.start()
-    try:
-        client = OpenAI(
-            base_url=f"http://localhost:{port}/v1",
-            api_key="token-abc123",
-        )
-        now = time.time()
-        while True:
-            try:
-                completion = client.chat.completions.create(
-                    model="facebook/opt-125m",
-                    messages=[{
-                        "role": "system",
-                        "content": "You are a helpful assistant."
+    with RemoteOpenAIServer(model, server_args) as server:
+        client = server.get_client()
+        completion = client.chat.completions.create(
+            model=model,
+            messages=[{
+                "role": "system",
+                "content": "You are a helpful assistant."
@@ -74,21 +31,12 @@ def test_oot_registration_for_api_server():
-                    }],
-                    temperature=0,
-                )
-                break
-            except OpenAIError as e:
-                if "Connection error" in str(e):
-                    time.sleep(3)
-                    if time.time() - now > RemoteOpenAIServer.MAX_START_WAIT_S:
-                        msg = "Server did not start in time"
-                        raise RuntimeError(msg) from e
-                else:
-                    raise e
-    finally:
-        server.terminate()
-
-    generated_text = completion.choices[0].message.content
-    assert generated_text is not None
-    # make sure only the first token is generated
-    # TODO(youkaichao): Fix the test with plugin
-    rest = generated_text.replace("<s>", "")  # noqa
-    # assert rest == ""
+            }],
+            temperature=0,
+        )
+        generated_text = completion.choices[0].message.content
+        assert generated_text is not None
+        # make sure only the first token is generated
+        rest = generated_text.replace("<s>", "")
+        assert rest == ""
+
+
+def test_oot_registration_for_api_server(dummy_opt_path: str):
+    run_and_test_dummy_opt_api_server(dummy_opt_path)

View File

@ -1,32 +1,27 @@
from typing import Optional import os
import torch import pytest
from vllm import LLM, ModelRegistry, SamplingParams from vllm import LLM, SamplingParams
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata # NOTE: the order of the tests is important
# the first test does not load any plugins
# the second test loads the plugin
# they share the same process, so the plugin is loaded for the second test
class MyOPTForCausalLM(OPTForCausalLM): def test_plugin(dummy_opt_path):
os.environ["VLLM_PLUGINS"] = ""
def compute_logits( with pytest.raises(Exception) as excinfo:
self, LLM(model=dummy_opt_path, load_format="dummy")
hidden_states: torch.Tensor, assert "are not supported for now" in str(excinfo.value)
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
# this dummy model always predicts the first token
logits = super().compute_logits(hidden_states, sampling_metadata)
logits.zero_()
logits[:, 0] += 1.0
return logits
def test_oot_registration(): def test_oot_registration(dummy_opt_path):
# register our dummy model os.environ["VLLM_PLUGINS"] = "register_dummy_model"
ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
prompts = ["Hello, my name is", "The text does not matter"] prompts = ["Hello, my name is", "The text does not matter"]
sampling_params = SamplingParams(temperature=0) sampling_params = SamplingParams(temperature=0)
llm = LLM(model="facebook/opt-125m") llm = LLM(model=dummy_opt_path, load_format="dummy")
first_token = llm.get_tokenizer().decode(0) first_token = llm.get_tokenizer().decode(0)
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)

View File

@@ -0,0 +1,9 @@
+from setuptools import setup
+
+setup(name='vllm_add_dummy_model',
+      version='0.1',
+      packages=['vllm_add_dummy_model'],
+      entry_points={
+          'vllm.general_plugins':
+          ["register_dummy_model = vllm_add_dummy_model:register"]
+      })
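For reference, the 'vllm.general_plugins' entry point declared above can be discovered and invoked with the standard importlib.metadata machinery once the package is installed (pip install -e ./plugins/vllm_add_dummy_model). A minimal sketch, mirroring what the loader added in vllm/plugins/__init__.py does:

import sys

if sys.version_info < (3, 10):
    from importlib_metadata import entry_points  # backport, added to requirements above
else:
    from importlib.metadata import entry_points

for ep in entry_points(group="vllm.general_plugins"):
    # e.g. name="register_dummy_model", value="vllm_add_dummy_model:register"
    print(ep.name, "->", ep.value)
    register_fn = ep.load()  # resolves vllm_add_dummy_model.register
    register_fn()            # registers MyOPTForCausalLM with the ModelRegistry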

View File

@@ -0,0 +1,26 @@
+from typing import Optional
+
+import torch
+
+from vllm import ModelRegistry
+from vllm.model_executor.models.opt import OPTForCausalLM
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+
+
+class MyOPTForCausalLM(OPTForCausalLM):
+
+    def compute_logits(
+            self, hidden_states: torch.Tensor,
+            sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
+        # this dummy model always predicts the first token
+        logits = super().compute_logits(hidden_states, sampling_metadata)
+        if logits is not None:
+            logits.zero_()
+            logits[:, 0] += 1.0
+        return logits
+
+
+def register():
+    # register our dummy model
+    if "MyOPTForCausalLM" not in ModelRegistry.get_supported_archs():
+        ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM)

View File

@@ -227,6 +227,9 @@ class LLMEngine:
         )
         # TODO(woosuk): Print more configs in debug mode.
 
+        from vllm.plugins import load_general_plugins
+        load_general_plugins()
+
         self.model_config = model_config
         self.cache_config = cache_config
         self.lora_config = lora_config

View File

@ -1,6 +1,6 @@
import os import os
import tempfile import tempfile
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
VLLM_HOST_IP: str = "" VLLM_HOST_IP: str = ""
@@ -55,6 +55,7 @@ if TYPE_CHECKING:
     VERBOSE: bool = False
     VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
     VLLM_TEST_FORCE_FP8_MARLIN: bool = False
+    VLLM_PLUGINS: Optional[List[str]] = None
 
 
 def get_default_cache_root():
@@ -362,6 +363,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     lambda:
     (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
      ("1", "true")),
+
+    # a list of plugin names to load, separated by commas.
+    # if this is not set, it means all plugins will be loaded
+    # if this is set to an empty string, no plugins will be loaded
+    "VLLM_PLUGINS":
+    lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[
+        "VLLM_PLUGINS"].split(","),
 }
 
 # end-env-vars-definition
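The lambda above gives VLLM_PLUGINS three distinct states, which the tests below rely on. A small illustrative sketch of the parsing semantics (not part of the commit; the helper name is made up):

def parse_vllm_plugins(env):
    # unset -> None: every discovered plugin is loaded
    # ""    -> [""]: no plugin name can match, so nothing is loaded
    # "a,b" -> ["a", "b"]: only the named plugins are loaded
    return None if "VLLM_PLUGINS" not in env else env["VLLM_PLUGINS"].split(",")


assert parse_vllm_plugins({}) is None
assert parse_vllm_plugins({"VLLM_PLUGINS": ""}) == [""]
assert parse_vllm_plugins({"VLLM_PLUGINS": "register_dummy_model"}) == ["register_dummy_model"]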

View File

@@ -166,7 +166,7 @@ class ModelRegistry:
     @staticmethod
     def get_supported_archs() -> List[str]:
-        return list(_MODELS.keys())
+        return list(_MODELS.keys()) + list(_OOT_MODELS.keys())
 
     @staticmethod
     def register_model(model_arch: str, model_cls: Type[nn.Module]):

vllm/plugins/__init__.py Normal file
View File

@@ -0,0 +1,31 @@
+import logging
+
+import vllm.envs as envs
+
+logger = logging.getLogger(__name__)
+
+
+def load_general_plugins():
+    """WARNING: plugins can be loaded for multiple times in different
+    processes. They should be designed in a way that they can be loaded
+    multiple times without causing issues.
+    """
+    import sys
+    if sys.version_info < (3, 10):
+        from importlib_metadata import entry_points
+    else:
+        from importlib.metadata import entry_points
+    allowed_plugins = envs.VLLM_PLUGINS
+
+    discovered_plugins = entry_points(group='vllm.general_plugins')
+    for plugin in discovered_plugins:
+        logger.info("Found general plugin: %s", plugin.name)
+        if allowed_plugins is None or plugin.name in allowed_plugins:
+            try:
+                func = plugin.load()
+                func()
+                logger.info("Loaded general plugin: %s", plugin.name)
+            except Exception:
+                logger.exception("Failed to load general plugin: %s",
+                                 plugin.name)
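Because load_general_plugins() runs in the engine process and again in every worker process (see the LLMEngine and WorkerWrapperBase hunks), a plugin hook should be idempotent, as the docstring above warns. A hedged sketch of that pattern for a hypothetical third-party plugin, following the guard used by vllm_add_dummy_model:

from vllm import ModelRegistry


def register():
    # Hypothetical plugin hook; may be called once per process, so guard
    # against double registration the same way the dummy plugin does.
    from my_plugin.models import MyCustomForCausalLM  # hypothetical module
    if "MyCustomForCausalLM" not in ModelRegistry.get_supported_archs():
        ModelRegistry.register_model("MyCustomForCausalLM", MyCustomForCausalLM)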

View File

@@ -411,6 +411,9 @@ class WorkerWrapperBase:
         # see https://github.com/NVIDIA/nccl/issues/1234
         os.environ['NCCL_CUMEM_ENABLE'] = '0'
 
+        from vllm.plugins import load_general_plugins
+        load_general_plugins()
+
         if self.worker_class_fn:
             worker_class = self.worker_class_fn()
         else: