[misc][plugin] add plugin system implementation (#7426)
Parent: 373538f973
Commit: 16422ea76f
.buildkite/test-pipeline.yaml

@@ -77,11 +77,13 @@ steps:
   - pytest -v -s core

 - label: Entrypoints Test # 20min
+  working_dir: "/vllm-workspace/tests"
   fast_check: true
   mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/
   commands:
+  - pip install -e ./plugins/vllm_add_dummy_model
   - pytest -v -s entrypoints/llm
   - pytest -v -s entrypoints/openai

@@ -154,6 +156,7 @@ steps:
   - vllm/
   - tests/models
   commands:
+  - pip install -e ./plugins/vllm_add_dummy_model
   - pytest -v -s models -m \"not vlm\"

 - label: Vision Language Models Test # 42min

@@ -289,6 +292,7 @@ steps:
   - pytest -v -s distributed/test_chunked_prefill_distributed.py
   - pytest -v -s distributed/test_multimodal_broadcast.py
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
+  - pytest -v -s distributed/test_distributed_oot.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
requirements-common.txt

@@ -23,4 +23,5 @@ pyzmq
 librosa # Required for audio processing
 soundfile # Required for audio processing
 gguf == 0.9.1
+importlib_metadata
 compressed-tensors == 0.5.0
tests/conftest.py

@@ -1,7 +1,9 @@
 import contextlib
 import gc
+import json
 import os
 import sys
+import tempfile
 from collections import UserList
 from enum import Enum
 from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,

@@ -11,6 +13,7 @@ import pytest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from huggingface_hub import snapshot_download
 from PIL import Image
 from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
                           AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,

@@ -757,3 +760,26 @@ def num_gpus_available():
     in current process."""

     return cuda_device_count_stateless()
+
+
+temp_dir = tempfile.gettempdir()
+_dummy_path = os.path.join(temp_dir, "dummy_opt")
+
+
+@pytest.fixture
+def dummy_opt_path():
+    json_path = os.path.join(_dummy_path, "config.json")
+    if not os.path.exists(_dummy_path):
+        snapshot_download(repo_id="facebook/opt-125m",
+                          local_dir=_dummy_path,
+                          ignore_patterns=[
+                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
+                              "*.msgpack"
+                          ])
+    assert os.path.exists(json_path)
+    with open(json_path, "r") as f:
+        config = json.load(f)
+        config["architectures"] = ["MyOPTForCausalLM"]
+    with open(json_path, "w") as f:
+        json.dump(config, f)
+    return _dummy_path
tests/distributed/test_distributed_oot.py (new file)

@@ -0,0 +1,6 @@
+from ..entrypoints.openai.test_oot_registration import (
+    run_and_test_dummy_opt_api_server)
+
+
+def test_distributed_oot(dummy_opt_path: str):
+    run_and_test_dummy_opt_api_server(dummy_opt_path, tp=2)
tests/entrypoints/openai/test_oot_registration.py

@@ -1,94 +1,42 @@
-import sys
-import time
-
-import torch
-from openai import OpenAI, OpenAIError
-
-from vllm import ModelRegistry
-from vllm.model_executor.models.opt import OPTForCausalLM
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.utils import get_open_port
-
 from ...utils import VLLM_PATH, RemoteOpenAIServer

 chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
 assert chatml_jinja_path.exists()


-class MyOPTForCausalLM(OPTForCausalLM):
-
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        # this dummy model always predicts the first token
-        logits = super().compute_logits(hidden_states, sampling_metadata)
-        logits.zero_()
-        logits[:, 0] += 1.0
-        return logits
-
-
-def server_function(port: int):
-    # register our dummy model
-    ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
-
-    sys.argv = ["placeholder.py"] + [
-        "--model",
-        "facebook/opt-125m",
+def run_and_test_dummy_opt_api_server(model, tp=1):
+    # the model is registered through the plugin
+    server_args = [
         "--gpu-memory-utilization",
         "0.10",
         "--dtype",
         "float32",
-        "--api-key",
-        "token-abc123",
-        "--port",
-        str(port),
         "--chat-template",
         str(chatml_jinja_path),
+        "--load-format",
+        "dummy",
+        "-tp",
+        f"{tp}",
     ]
-    import runpy
-    runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
-
-
-def test_oot_registration_for_api_server():
-    port = get_open_port()
-    ctx = torch.multiprocessing.get_context()
-    server = ctx.Process(target=server_function, args=(port, ))
-    server.start()
-
-    try:
-        client = OpenAI(
-            base_url=f"http://localhost:{port}/v1",
-            api_key="token-abc123",
-        )
-        now = time.time()
-        while True:
-            try:
-                completion = client.chat.completions.create(
-                    model="facebook/opt-125m",
-                    messages=[{
-                        "role": "system",
-                        "content": "You are a helpful assistant."
-                    }, {
-                        "role": "user",
-                        "content": "Hello!"
-                    }],
-                    temperature=0,
-                )
-                break
-            except OpenAIError as e:
-                if "Connection error" in str(e):
-                    time.sleep(3)
-                    if time.time() - now > RemoteOpenAIServer.MAX_START_WAIT_S:
-                        msg = "Server did not start in time"
-                        raise RuntimeError(msg) from e
-                else:
-                    raise e
-    finally:
-        server.terminate()
-
-    generated_text = completion.choices[0].message.content
-    assert generated_text is not None
-    # make sure only the first token is generated
-    # TODO(youkaichao): Fix the test with plugin
-    rest = generated_text.replace("<s>", "")  # noqa
-    # assert rest == ""
+    with RemoteOpenAIServer(model, server_args) as server:
+        client = server.get_client()
+        completion = client.chat.completions.create(
+            model=model,
+            messages=[{
+                "role": "system",
+                "content": "You are a helpful assistant."
+            }, {
+                "role": "user",
+                "content": "Hello!"
+            }],
+            temperature=0,
+        )
+        generated_text = completion.choices[0].message.content
+        assert generated_text is not None
+        # make sure only the first token is generated
+        rest = generated_text.replace("<s>", "")
+        assert rest == ""
+
+
+def test_oot_registration_for_api_server(dummy_opt_path: str):
+    run_and_test_dummy_opt_api_server(dummy_opt_path)
tests/models/test_oot_registration.py

@@ -1,32 +1,27 @@
-from typing import Optional
+import os

-import torch
+import pytest

-from vllm import LLM, ModelRegistry, SamplingParams
-from vllm.model_executor.models.opt import OPTForCausalLM
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-
-
-class MyOPTForCausalLM(OPTForCausalLM):
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        # this dummy model always predicts the first token
-        logits = super().compute_logits(hidden_states, sampling_metadata)
-        logits.zero_()
-        logits[:, 0] += 1.0
-        return logits
+from vllm import LLM, SamplingParams
+
+# NOTE: the order of the tests is important
+# the first test does not load any plugins
+# the second test loads the plugin
+# they share the same process, so the plugin is loaded for the second test
+
+
+def test_plugin(dummy_opt_path):
+    os.environ["VLLM_PLUGINS"] = ""
+    with pytest.raises(Exception) as excinfo:
+        LLM(model=dummy_opt_path, load_format="dummy")
+    assert "are not supported for now" in str(excinfo.value)


-def test_oot_registration():
-    # register our dummy model
-    ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
+def test_oot_registration(dummy_opt_path):
+    os.environ["VLLM_PLUGINS"] = "register_dummy_model"
     prompts = ["Hello, my name is", "The text does not matter"]
     sampling_params = SamplingParams(temperature=0)
-    llm = LLM(model="facebook/opt-125m")
+    llm = LLM(model=dummy_opt_path, load_format="dummy")
     first_token = llm.get_tokenizer().decode(0)
     outputs = llm.generate(prompts, sampling_params)
tests/plugins/vllm_add_dummy_model/setup.py (new file)

@@ -0,0 +1,9 @@
+from setuptools import setup
+
+setup(name='vllm_add_dummy_model',
+      version='0.1',
+      packages=['vllm_add_dummy_model'],
+      entry_points={
+          'vllm.general_plugins':
+          ["register_dummy_model = vllm_add_dummy_model:register"]
+      })
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py (new file)

@@ -0,0 +1,26 @@
+from typing import Optional
+
+import torch
+
+from vllm import ModelRegistry
+from vllm.model_executor.models.opt import OPTForCausalLM
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+
+
+class MyOPTForCausalLM(OPTForCausalLM):
+
+    def compute_logits(
+            self, hidden_states: torch.Tensor,
+            sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
+        # this dummy model always predicts the first token
+        logits = super().compute_logits(hidden_states, sampling_metadata)
+        if logits is not None:
+            logits.zero_()
+            logits[:, 0] += 1.0
+        return logits
+
+
+def register():
+    # register our dummy model
+    if "MyOPTForCausalLM" not in ModelRegistry.get_supported_archs():
+        ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM)
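Taken together, setup.py and the package above form a complete out-of-tree plugin: pip installing it exposes a `vllm.general_plugins` entry point named `register_dummy_model` whose callable registers `MyOPTForCausalLM`. A rough usage sketch (not part of the commit), mirroring what tests/models/test_oot_registration.py does; the model path is the one written by the `dummy_opt_path` fixture in tests/conftest.py:

import os
import tempfile

from vllm import LLM, SamplingParams

# Allow only the dummy plugin to load; leaving VLLM_PLUGINS unset loads all
# discovered plugins, and an empty string disables plugin loading entirely.
os.environ["VLLM_PLUGINS"] = "register_dummy_model"

# Path produced by the dummy_opt_path fixture above (OPT-125m config with its
# "architectures" entry rewritten to "MyOPTForCausalLM").
dummy_opt_path = os.path.join(tempfile.gettempdir(), "dummy_opt")

# load_format="dummy" skips downloading real weights; the plugin-provided
# architecture is resolved when the engine starts and loads general plugins.
llm = LLM(model=dummy_opt_path, load_format="dummy")
outputs = llm.generate(["Hello, my name is"], SamplingParams(temperature=0))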
vllm/engine/llm_engine.py

@@ -227,6 +227,9 @@ class LLMEngine:
         )
         # TODO(woosuk): Print more configs in debug mode.

+        from vllm.plugins import load_general_plugins
+        load_general_plugins()
+
         self.model_config = model_config
         self.cache_config = cache_config
         self.lora_config = lora_config
vllm/envs.py

@@ -1,6 +1,6 @@
 import os
 import tempfile
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

 if TYPE_CHECKING:
     VLLM_HOST_IP: str = ""

@@ -55,6 +55,7 @@ if TYPE_CHECKING:
     VERBOSE: bool = False
     VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
     VLLM_TEST_FORCE_FP8_MARLIN: bool = False
+    VLLM_PLUGINS: Optional[List[str]] = None


 def get_default_cache_root():

@@ -362,6 +363,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     lambda:
     (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in
      ("1", "true")),
+
+    # a list of plugin names to load, separated by commas.
+    # if this is not set, it means all plugins will be loaded
+    # if this is set to an empty string, no plugins will be loaded
+    "VLLM_PLUGINS":
+    lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[
+        "VLLM_PLUGINS"].split(","),
 }

 # end-env-vars-definition
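A small illustration (not part of the diff) of how the three cases described in the comments above fall out of that lambda; the helper name and the plugin names "a"/"b" are made up for the example:

def _parse_vllm_plugins(raw):
    # Same expression as the lambda above, with the os.environ lookup replaced
    # by an explicit argument: None stands for "variable not set".
    return None if raw is None else raw.split(",")

assert _parse_vllm_plugins(None) is None          # unset: load every plugin found
assert _parse_vllm_plugins("") == [""]            # empty string: no name matches, so nothing loads
assert _parse_vllm_plugins("a,b") == ["a", "b"]   # comma-separated allow-list
# Note that whitespace is not stripped, so "a, b" would yield ["a", " b"].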
vllm/model_executor/models/__init__.py

@@ -166,7 +166,7 @@ class ModelRegistry:

     @staticmethod
     def get_supported_archs() -> List[str]:
-        return list(_MODELS.keys())
+        return list(_MODELS.keys()) + list(_OOT_MODELS.keys())

     @staticmethod
     def register_model(model_arch: str, model_cls: Type[nn.Module]):
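With this one-line change, architectures added at runtime through `ModelRegistry.register_model` are reported alongside the built-in ones, which is what lets the dummy plugin's `register()` above check for prior registration. A minimal sketch (the subclass name here is illustrative, not from the commit):

from vllm import ModelRegistry
from vllm.model_executor.models.opt import OPTForCausalLM


class MyTinyOPTForCausalLM(OPTForCausalLM):
    """Illustrative out-of-tree architecture; behaves exactly like OPT."""


ModelRegistry.register_model("MyTinyOPTForCausalLM", MyTinyOPTForCausalLM)
# After this commit the out-of-tree entry is visible here as well:
assert "MyTinyOPTForCausalLM" in ModelRegistry.get_supported_archs()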
vllm/plugins/__init__.py (new file)

@@ -0,0 +1,31 @@
+import logging
+
+import vllm.envs as envs
+
+logger = logging.getLogger(__name__)
+
+
+def load_general_plugins():
+    """WARNING: plugins can be loaded for multiple times in different
+    processes. They should be designed in a way that they can be loaded
+    multiple times without causing issues.
+    """
+    import sys
+    if sys.version_info < (3, 10):
+        from importlib_metadata import entry_points
+    else:
+        from importlib.metadata import entry_points
+
+    allowed_plugins = envs.VLLM_PLUGINS
+
+    discovered_plugins = entry_points(group='vllm.general_plugins')
+    for plugin in discovered_plugins:
+        logger.info("Found general plugin: %s", plugin.name)
+        if allowed_plugins is None or plugin.name in allowed_plugins:
+            try:
+                func = plugin.load()
+                func()
+                logger.info("Loaded general plugin: %s", plugin.name)
+            except Exception:
+                logger.exception("Failed to load general plugin: %s",
+                                 plugin.name)
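For debugging, the same entry-point group can be inspected directly to see which plugins the loader above would find; this sketch reuses the loader's own version guard and only prints what is discovered, without calling any plugin:

import sys

if sys.version_info < (3, 10):
    from importlib_metadata import entry_points
else:
    from importlib.metadata import entry_points

# With the dummy test plugin installed this prints, e.g.:
#   register_dummy_model -> vllm_add_dummy_model:register
for ep in entry_points(group='vllm.general_plugins'):
    print(f"{ep.name} -> {ep.value}")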
vllm/worker/worker_base.py

@@ -411,6 +411,9 @@ class WorkerWrapperBase:
         # see https://github.com/NVIDIA/nccl/issues/1234
         os.environ['NCCL_CUMEM_ENABLE'] = '0'

+        from vllm.plugins import load_general_plugins
+        load_general_plugins()
+
         if self.worker_class_fn:
             worker_class = self.worker_class_fn()
         else: