[platforms] enable platform plugins (#11602)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2024-12-30 20:24:45 +08:00 committed by GitHub
parent 5dbf854553
commit b12e87f942
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 354 additions and 175 deletions

View File

@ -106,14 +106,12 @@ steps:
source_file_dependencies:
- vllm/
commands:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
- pytest -v -s entrypoints/test_chat_utils.py
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
@ -333,8 +331,6 @@ steps:
- vllm/
- tests/models
commands:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s models/test_registry.py
- pytest -v -s models/test_initialization.py
@ -469,11 +465,28 @@ steps:
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s distributed/test_distributed_oot.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py
- label: Plugin Tests (2 GPUs) # 40min
working_dir: "/vllm-workspace/tests"
num_gpus: 2
fast_check: true
source_file_dependencies:
- vllm/plugins/
- tests/plugins/
commands:
# begin platform plugin tests, all the code in-between runs on dummy platform
- pip install -e ./plugins/vllm_add_dummy_platform
- pytest -v -s plugins_tests/test_platform_plugins.py
- pip uninstall vllm_add_dummy_platform -y
# end platform plugin tests
# other tests continue here:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s distributed/test_distributed_oot.py
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- label: Multi-step Tests (4 GPUs) # 36min
working_dir: "/vllm-workspace/tests"
num_gpus: 4

View File

@ -41,9 +41,11 @@ Every plugin has three parts:
2. **Plugin name**: The name of the plugin. This is the value in the dictionary of the `entry_points` dictionary. In the example above, the plugin name is `register_dummy_model`. Plugins can be filtered by their names using the `VLLM_PLUGINS` environment variable. To load only a specific plugin, set `VLLM_PLUGINS` to the plugin name.
3. **Plugin value**: The fully qualified name of the function to register in the plugin system. In the example above, the plugin value is `vllm_add_dummy_model:register`, which refers to a function named `register` in the `vllm_add_dummy_model` module.
## What Can Plugins Do?
## Types of supported plugins
Currently, the primary use case for plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling `ModelRegistry.register_model` to register the model. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.
- **General plugins** (with group name `vllm.general_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling `ModelRegistry.register_model` to register the model inside the plugin function.
- **Platform plugins** (with group name `vllm.platform_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree platforms into vLLM. The plugin function should return `None` when the platform is not supported in the current environment, or the platform class's fully qualified name when the platform is supported.
## Guidelines for Writing Plugins

View File

@ -31,7 +31,6 @@ from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.sampling_params import BeamSearchParams
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
identity)
@ -242,6 +241,7 @@ _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
class HfRunner:
def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
from vllm.platforms import current_platform
if x is None or isinstance(x, (bool, )):
return x

View File

@ -5,7 +5,10 @@ import torch
from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import which_attn_to_use
from vllm.platforms import cpu, cuda, openvino, rocm
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.openvino import OpenVinoPlatform
from vllm.platforms.rocm import RocmPlatform
from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
@ -20,26 +23,23 @@ def test_env(name: str, device: str, monkeypatch):
override_backend_env_variable(monkeypatch, name)
if device == "cpu":
with patch("vllm.attention.selector.current_platform",
cpu.CpuPlatform()):
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.current_platform",
rocm.RocmPlatform()):
with patch("vllm.attention.selector.current_platform", RocmPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "ROCM_FLASH"
elif device == "openvino":
with patch("vllm.attention.selector.current_platform",
openvino.OpenVinoPlatform()):
OpenVinoPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "OPENVINO"
else:
with patch("vllm.attention.selector.current_platform",
cuda.CudaPlatform()):
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == name

View File

@ -0,0 +1,11 @@
from setuptools import setup
setup(
name='vllm_add_dummy_platform',
version='0.1',
packages=['vllm_add_dummy_platform'],
entry_points={
'vllm.platform_plugins': [
"dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin" # noqa
]
})

View File

@ -0,0 +1,5 @@
from typing import Optional
def dummy_platform_plugin() -> Optional[str]:
return "vllm_add_dummy_platform.dummy_platform.DummyPlatform"

View File

@ -0,0 +1,5 @@
from vllm.platforms.cuda import CudaPlatform
class DummyPlatform(CudaPlatform):
device_name = "DummyDevice"

View File

@ -0,0 +1,16 @@
def test_platform_plugins():
# simulate workload by running an example
import runpy
current_file = __file__
import os
example_file = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(current_file))),
"examples", "offline_inference.py")
runpy.run_path(example_file)
# check if the plugin is loaded correctly
from vllm.platforms import _init_trace, current_platform
assert current_platform.device_name == "DummyDevice", (
f"Expected DummyDevice, got {current_platform.device_name}, "
"possibly because current_platform is imported before the plugin"
f" is loaded. The first import:\n{_init_trace}")

View File

@ -22,7 +22,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
get_quantization_config)
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform, interface
from vllm.platforms import CpuArchEnum
from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
@ -349,6 +349,7 @@ class ModelConfig:
self.is_hybrid = self._init_is_hybrid()
self.has_inner_state = self._init_has_inner_state()
from vllm.platforms import current_platform
if current_platform.is_neuron():
self.override_neuron_config = override_neuron_config
else:
@ -589,6 +590,7 @@ class ModelConfig:
raise ValueError(
f"Unknown quantization method: {self.quantization}. Must "
f"be one of {supported_quantization}.")
from vllm.platforms import current_platform
current_platform.verify_quantization(self.quantization)
if self.quantization not in optimized_quantization_methods:
logger.warning(
@ -644,6 +646,7 @@ class ModelConfig:
# Reminder: Please update docs/source/usage/compatibility_matrix.md
# If the feature combo become valid
from vllm.platforms import current_platform
if not current_platform.is_async_output_supported(self.enforce_eager):
logger.warning(
"Async output processing is not supported on the "
@ -1012,6 +1015,7 @@ class CacheConfig:
raise ValueError(
"GPU memory utilization must be less than 1.0. Got "
f"{self.gpu_memory_utilization}.")
from vllm.platforms import current_platform
if (current_platform.is_cuda() and self.block_size is not None
and self.block_size > 32):
raise ValueError("CUDA Paged Attention kernel only supports "
@ -1279,6 +1283,7 @@ class ParallelConfig:
f"distributed executor backend "
f"'{self.distributed_executor_backend}'.")
ray_only_devices = ["tpu", "hpu"]
from vllm.platforms import current_platform
if (current_platform.device_type in ray_only_devices
and self.world_size > 1):
if self.distributed_executor_backend is None:
@ -1327,7 +1332,7 @@ class ParallelConfig:
def _verify_args(self) -> None:
# Lazy import to avoid circular import
from vllm.executor.executor_base import ExecutorBase
from vllm.platforms import current_platform
if self.distributed_executor_backend not in (
"ray", "mp", None) and not (isinstance(
self.distributed_executor_backend, type) and issubclass(
@ -1528,6 +1533,7 @@ class DeviceConfig:
def __init__(self, device: str = "auto") -> None:
if device == "auto":
# Automated device type detection
from vllm.platforms import current_platform
self.device_type = current_platform.device_type
if not self.device_type:
raise RuntimeError("Failed to infer device type")
@ -2241,9 +2247,10 @@ def _get_and_verify_dtype(
else:
torch_dtype = config_dtype
from vllm.platforms import current_platform
if (current_platform.is_cpu()
and current_platform.get_cpu_architecture()
== interface.CpuArchEnum.POWERPC
== CpuArchEnum.POWERPC
and (config_dtype == torch.float16
or config_dtype == torch.float32)):
logger.info(
@ -3083,6 +3090,7 @@ class VllmConfig:
model_config: ModelConfig,
load_config: LoadConfig) -> Optional[QuantizationConfig]:
"""Get the quantization config."""
from vllm.platforms import current_platform
if model_config.quantization is not None:
from vllm.model_executor.model_loader.weight_utils import (
get_quant_config)
@ -3145,6 +3153,7 @@ class VllmConfig:
self.quant_config = VllmConfig._get_quantization_config(
self.model_config, self.load_config)
from vllm.platforms import current_platform
if self.scheduler_config is not None and \
self.model_config is not None and \
self.scheduler_config.chunked_prefill_enabled and \

View File

@ -39,7 +39,6 @@ import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
import vllm.envs as envs
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op, supports_custom_op
if TYPE_CHECKING:
@ -194,6 +193,7 @@ class GroupCoordinator:
assert self.cpu_group is not None
assert self.device_group is not None
from vllm.platforms import current_platform
if current_platform.is_cuda_alike():
self.device = torch.device(f"cuda:{local_rank}")
else:
@ -1188,6 +1188,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
import ray # Lazy import Ray
ray.shutdown()
gc.collect()
from vllm.platforms import current_platform
if not current_platform.is_cpu():
torch.cuda.empty_cache()

View File

@ -18,7 +18,6 @@ from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.platforms import current_platform
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, StoreBoolean
@ -1094,6 +1093,7 @@ class EngineArgs:
use_sliding_window = (model_config.get_sliding_window()
is not None)
use_spec_decode = self.speculative_model is not None
from vllm.platforms import current_platform
if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora
and not self.enable_prompt_adapter

View File

@ -8,7 +8,6 @@ import msgspec
from vllm.config import ParallelConfig
from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import get_ip
from vllm.worker.worker_base import WorkerWrapperBase
@ -229,6 +228,7 @@ def initialize_ray_cluster(
the default Ray cluster address.
"""
assert_ray_available()
from vllm.platforms import current_platform
# Connect to a ray cluster.
if current_platform.is_rocm() or current_platform.is_xpu():

View File

@ -6,7 +6,7 @@ from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.utils import (
convert_lark_to_gbnf, grammar_is_likely_lark,
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
from vllm.platforms import CpuArchEnum, current_platform
from vllm.platforms import CpuArchEnum
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
@ -39,6 +39,7 @@ def maybe_backend_fallback(
if guided_params.backend == "xgrammar":
# xgrammar only has x86 wheels for linux, fallback to outlines
from vllm.platforms import current_platform
if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
logger.warning("xgrammar is only supported on x86 CPUs. "
"Falling back to use outlines instead.")

View File

@ -18,7 +18,6 @@ import cloudpickle
import torch.nn as nn
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
supports_cross_encoding, supports_multimodal,
@ -273,6 +272,7 @@ def _try_load_model_cls(
model_arch: str,
model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
from vllm.platforms import current_platform
current_platform.verify_model_arch(model_arch)
try:
return model.load_model_cls()

View File

@ -3,10 +3,9 @@ from typing import Any, Dict, Optional
import torch
from vllm.platforms import current_platform
def set_random_seed(seed: int) -> None:
from vllm.platforms import current_platform
current_platform.seed_everything(seed)
@ -38,6 +37,7 @@ def set_weight_attrs(
# This sometimes causes OOM errors during model loading. To avoid this,
# we sync the param tensor after its weight loader is called.
# TODO(woosuk): Remove this hack once we have a better solution.
from vllm.platforms import current_platform
if current_platform.is_tpu() and key == "weight_loader":
value = _make_synced_weight_loader(value)
setattr(weight, key, value)

View File

@ -1,22 +1,33 @@
import logging
import traceback
from itertools import chain
from typing import TYPE_CHECKING, Optional
from vllm.plugins import load_plugins_by_group
from vllm.utils import resolve_obj_by_qualname
from .interface import _Backend # noqa: F401
from .interface import CpuArchEnum, Platform, PlatformEnum, UnspecifiedPlatform
from .interface import CpuArchEnum, Platform, PlatformEnum
current_platform: Platform
logger = logging.getLogger(__name__)
# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because
# they only indicate the build configuration, not the runtime environment.
# For example, people can install a cuda build of pytorch but run on tpu.
def tpu_platform_plugin() -> Optional[str]:
is_tpu = False
try:
# While it's technically possible to install libtpu on a non-TPU machine,
# this is a very uncommon scenario. Therefore, we assume that libtpu is
# installed if and only if the machine has TPUs.
# While it's technically possible to install libtpu on a
# non-TPU machine, this is a very uncommon scenario. Therefore,
# we assume that libtpu is installed if and only if the machine
# has TPUs.
import libtpu # noqa: F401
is_tpu = True
except Exception:
pass
return "vllm.platforms.tpu.TpuPlatform" if is_tpu else None
def cuda_platform_plugin() -> Optional[str]:
is_cuda = False
try:
@ -38,6 +49,10 @@ except Exception:
if cuda_is_jetson():
is_cuda = True
return "vllm.platforms.cuda.CudaPlatform" if is_cuda else None
def rocm_platform_plugin() -> Optional[str]:
is_rocm = False
try:
@ -51,6 +66,10 @@ try:
except Exception:
pass
return "vllm.platforms.rocm.RocmPlatform" if is_rocm else None
def hpu_platform_plugin() -> Optional[str]:
is_hpu = False
try:
from importlib import util
@ -58,6 +77,10 @@ try:
except Exception:
pass
return "vllm.platforms.hpu.HpuPlatform" if is_hpu else None
def xpu_platform_plugin() -> Optional[str]:
is_xpu = False
try:
@ -70,6 +93,10 @@ try:
except Exception:
pass
return "vllm.platforms.xpu.XPUPlatform" if is_xpu else None
def cpu_platform_plugin() -> Optional[str]:
is_cpu = False
try:
from importlib.metadata import version
@ -77,6 +104,10 @@ try:
except Exception:
pass
return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None
def neuron_platform_plugin() -> Optional[str]:
is_neuron = False
try:
import transformers_neuronx # noqa: F401
@ -84,6 +115,10 @@ try:
except ImportError:
pass
return "vllm.platforms.neuron.NeuronPlatform" if is_neuron else None
def openvino_platform_plugin() -> Optional[str]:
is_openvino = False
try:
from importlib.metadata import version
@ -91,33 +126,98 @@ try:
except Exception:
pass
if is_tpu:
# people might install pytorch built with cuda but run on tpu
# so we need to check tpu first
from .tpu import TpuPlatform
current_platform = TpuPlatform()
elif is_cuda:
from .cuda import CudaPlatform
current_platform = CudaPlatform()
elif is_rocm:
from .rocm import RocmPlatform
current_platform = RocmPlatform()
elif is_hpu:
from .hpu import HpuPlatform
current_platform = HpuPlatform()
elif is_xpu:
from .xpu import XPUPlatform
current_platform = XPUPlatform()
elif is_cpu:
from .cpu import CpuPlatform
current_platform = CpuPlatform()
elif is_neuron:
from .neuron import NeuronPlatform
current_platform = NeuronPlatform()
elif is_openvino:
from .openvino import OpenVinoPlatform
current_platform = OpenVinoPlatform()
else:
current_platform = UnspecifiedPlatform()
return "vllm.platforms.openvino.OpenVinoPlatform" if is_openvino else None
__all__ = ['Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum']
builtin_platform_plugins = {
'tpu': tpu_platform_plugin,
'cuda': cuda_platform_plugin,
'rocm': rocm_platform_plugin,
'hpu': hpu_platform_plugin,
'xpu': xpu_platform_plugin,
'cpu': cpu_platform_plugin,
'neuron': neuron_platform_plugin,
'openvino': openvino_platform_plugin,
}
def resolve_current_platform_cls_qualname() -> str:
platform_plugins = load_plugins_by_group('vllm.platform_plugins')
activated_plugins = []
for name, func in chain(builtin_platform_plugins.items(),
platform_plugins.items()):
try:
assert callable(func)
platform_cls_qualname = func()
if platform_cls_qualname is not None:
activated_plugins.append(name)
except Exception:
pass
activated_builtin_plugins = list(
set(activated_plugins) & set(builtin_platform_plugins.keys()))
activated_oot_plugins = list(
set(activated_plugins) & set(platform_plugins.keys()))
if len(activated_oot_plugins) >= 2:
raise RuntimeError(
"Only one platform plugin can be activated, but got: "
f"{activated_oot_plugins}")
elif len(activated_oot_plugins) == 1:
platform_cls_qualname = platform_plugins[activated_oot_plugins[0]]()
logger.info("Platform plugin %s is activated",
activated_oot_plugins[0])
elif len(activated_builtin_plugins) >= 2:
raise RuntimeError(
"Only one platform plugin can be activated, but got: "
f"{activated_builtin_plugins}")
elif len(activated_builtin_plugins) == 1:
platform_cls_qualname = builtin_platform_plugins[
activated_builtin_plugins[0]]()
logger.info("Automatically detected platform %s.",
activated_builtin_plugins[0])
else:
platform_cls_qualname = "vllm.interface.UnspecifiedPlatform"
logger.info(
"No platform detected, vLLM is running on UnspecifiedPlatform")
return platform_cls_qualname
_current_platform = None
_init_trace: str = ''
if TYPE_CHECKING:
current_platform: Platform
def __getattr__(name: str):
if name == 'current_platform':
# lazy init current_platform.
# 1. out-of-tree platform plugins need `from vllm.platforms import
# Platform` so that they can inherit `Platform` class. Therefore,
# we cannot resolve `current_platform` during the import of
# `vllm.platforms`.
# 2. when users use out-of-tree platform plugins, they might run
# `import vllm`, some vllm internal code might access
# `current_platform` during the import, and we need to make sure
# `current_platform` is only resolved after the plugins are loaded
# (we have tests for this, if any developer violate this, they will
# see the test failures).
global _current_platform
if _current_platform is None:
platform_cls_qualname = resolve_current_platform_cls_qualname()
_current_platform = resolve_obj_by_qualname(
platform_cls_qualname)()
global _init_trace
_init_trace = "".join(traceback.format_stack())
return _current_platform
else:
return globals()[name]
__all__ = [
'Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum',
"_init_trace"
]

View File

@ -1,10 +1,10 @@
import logging
import os
from typing import Callable, Dict
import torch
import vllm.envs as envs
from vllm.platforms import current_platform
logger = logging.getLogger(__name__)
@ -12,6 +12,39 @@ logger = logging.getLogger(__name__)
plugins_loaded = False
def load_plugins_by_group(group: str) -> Dict[str, Callable]:
import sys
if sys.version_info < (3, 10):
from importlib_metadata import entry_points
else:
from importlib.metadata import entry_points
allowed_plugins = envs.VLLM_PLUGINS
discovered_plugins = entry_points(group=group)
if len(discovered_plugins) == 0:
logger.debug("No plugins for group %s found.", group)
return {}
logger.info("Available plugins for group %s:", group)
for plugin in discovered_plugins:
logger.info("name=%s, value=%s", plugin.name, plugin.value)
if allowed_plugins is None:
logger.info("all available plugins for group %s will be loaded.",
group)
logger.info("set environment variable VLLM_PLUGINS to control"
" which plugins to load.")
plugins = {}
for plugin in discovered_plugins:
if allowed_plugins is None or plugin.name in allowed_plugins:
try:
func = plugin.load()
plugins[plugin.name] = func
logger.info("plugin %s loaded.", plugin.name)
except Exception:
logger.exception("Failed to load plugin %s", plugin.name)
return plugins
def load_general_plugins():
"""WARNING: plugins can be loaded for multiple times in different
processes. They should be designed in a way that they can be loaded
@ -26,6 +59,9 @@ def load_general_plugins():
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
# see https://github.com/vllm-project/vllm/issues/10619
torch._inductor.config.compile_threads = 1
from vllm.platforms import current_platform
if current_platform.is_xpu():
# see https://github.com/pytorch/pytorch/blob/8cada5cbe5450e17c26fb8b358116785324537b2/torch/_dynamo/config.py#L158 # noqa
os.environ['TORCH_COMPILE_DISABLE'] = 'True'
@ -47,33 +83,7 @@ def load_general_plugins():
if plugins_loaded:
return
plugins_loaded = True
import sys
if sys.version_info < (3, 10):
from importlib_metadata import entry_points
else:
from importlib.metadata import entry_points
allowed_plugins = envs.VLLM_PLUGINS
discovered_plugins = entry_points(group='vllm.general_plugins')
if len(discovered_plugins) == 0:
logger.debug("No plugins found.")
return
logger.info("Available plugins:")
for plugin in discovered_plugins:
logger.info("name=%s, value=%s, group=%s", plugin.name, plugin.value,
plugin.group)
if allowed_plugins is None:
logger.info("all available plugins will be loaded.")
logger.info("set environment variable VLLM_PLUGINS to control"
" which plugins to load.")
else:
logger.info("plugins to load: %s", allowed_plugins)
for plugin in discovered_plugins:
if allowed_plugins is None or plugin.name in allowed_plugins:
try:
func = plugin.load()
plugins = load_plugins_by_group(group='vllm.general_plugins')
# general plugins, we only need to execute the loaded functions
for func in plugins.values():
func()
logger.info("plugin %s loaded.", plugin.name)
except Exception:
logger.exception("Failed to load plugin %s", plugin.name)

View File

@ -6,7 +6,6 @@ import torch
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available
@ -94,6 +93,7 @@ class AsyncMetricsCollector:
def maybe_collect_rejsample_metrics(
self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
# currently using cuda.Event, skip for any non_cuda_alike platform
from vllm.platforms import current_platform
if not current_platform.is_cuda_alike():
return None

View File

@ -17,7 +17,6 @@ import torch
import vllm.envs as envs
from vllm.connections import global_http_connection
from vllm.platforms import current_platform
from vllm.version import __version__ as VLLM_VERSION
_config_home = envs.VLLM_CONFIG_ROOT
@ -152,6 +151,7 @@ class UsageMessage:
usage_context: UsageContext,
extra_kvs: Dict[str, Any]) -> None:
# Platform information
from vllm.platforms import current_platform
if current_platform.is_cuda_alike():
device_property = torch.cuda.get_device_properties(0)
self.gpu_count = torch.cuda.device_count()

View File

@ -50,7 +50,6 @@ from typing_extensions import ParamSpec, TypeIs, assert_never
import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
from vllm.platforms import current_platform
if TYPE_CHECKING:
from vllm.config import VllmConfig
@ -609,6 +608,7 @@ def create_kv_caches_with_random_flash(
seed: int = 0,
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
from vllm.platforms import current_platform
current_platform.seed_everything(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
@ -650,7 +650,7 @@ def create_kv_caches_with_random(
raise ValueError(
f"Does not support key cache of type fp8 with head_size {head_size}"
)
from vllm.platforms import current_platform
current_platform.seed_everything(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
@ -703,6 +703,7 @@ def print_warning_once(msg: str) -> None:
@lru_cache(maxsize=None)
def is_pin_memory_available() -> bool:
from vllm.platforms import current_platform
return current_platform.is_pin_memory_available()
@ -713,6 +714,7 @@ class DeviceMemoryProfiler:
def current_memory_usage(self) -> float:
# Return the memory usage in bytes.
from vllm.platforms import current_platform
if current_platform.is_cuda_alike():
torch.cuda.reset_peak_memory_stats(self.device)
mem = torch.cuda.max_memory_allocated(self.device)
@ -1066,6 +1068,7 @@ def _cuda_device_count_stateless(
import torch.cuda
import torch.version
from vllm.platforms import current_platform
if not torch.cuda._is_compiled():
return 0
if current_platform.is_rocm():
@ -1673,6 +1676,7 @@ def direct_register_custom_op(
return
if not supports_custom_op():
from vllm.platforms import current_platform
assert not current_platform.is_cuda_alike(), (
"cuda platform needs torch>=2.4 to support custom op, "
"chances are you are using an old version of pytorch "

View File

@ -12,7 +12,6 @@ from torch import is_tensor
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
if TYPE_CHECKING:
@ -265,13 +264,13 @@ class ModelRunnerBase(ABC, Generic[T]):
"""
raise NotImplementedError
@current_platform.inference_mode()
def execute_model(
self,
model_input: T,
kv_caches: Optional[List[torch.Tensor]],
intermediate_tensors: Optional[IntermediateTensors],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
**kwargs,
) -> Optional[List[SamplerOutput]]:
"""
Execute the model on the given input.

View File

@ -544,6 +544,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
model_input.record_step_event(current_stream)
if get_pp_group().is_last_rank and self.is_driver_worker:
assert isinstance(output, list)
assert len(
output
) == 1, "MultiStepModelRunner requires single-step base_models"

View File

@ -11,7 +11,6 @@ from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import (enable_trace_function_call_for_thread,
resolve_obj_by_qualname, update_environment_variables)
@ -44,6 +43,8 @@ class WorkerBase(ABC):
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self.kv_transfer_config = vllm_config.kv_transfer_config
from vllm.platforms import current_platform
self.current_platform = current_platform
@abstractmethod
def init_device(self) -> None:
@ -74,13 +75,13 @@ class WorkerBase(ABC):
"""
raise NotImplementedError
@current_platform.inference_mode()
def start_worker_execution_loop(self) -> None:
"""Execute model loop in parallel worker.
You can stop the loop by executing a driver worker with an empty output.
See `stop_remote_worker_execution_loop` for more details.
"""
with self.current_platform.inference_mode():
while True:
output = self.execute_model(execute_model_req=None)
if output is None:
@ -352,6 +353,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
model_execute_time = time.perf_counter() - start_time
if not get_pp_group().is_last_rank:
# output is IntermediateTensors
assert isinstance(output, IntermediateTensors)
if (self.observability_config is not None
and self.observability_config.collect_model_execute_time):
output.tensors["model_execute_time"] = torch.tensor(