[platforms] enable platform plugins (#11602)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
5dbf854553
commit
b12e87f942
@ -106,14 +106,12 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
commands:
|
||||
- pip install -e ./plugins/vllm_add_dummy_model
|
||||
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
|
||||
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py
|
||||
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||
|
||||
@ -333,8 +331,6 @@ steps:
|
||||
- vllm/
|
||||
- tests/models
|
||||
commands:
|
||||
- pip install -e ./plugins/vllm_add_dummy_model
|
||||
- pytest -v -s models/test_oot_registration.py # it needs a clean process
|
||||
- pytest -v -s models/test_registry.py
|
||||
- pytest -v -s models/test_initialization.py
|
||||
|
||||
@ -469,11 +465,28 @@ steps:
|
||||
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
||||
- pip install -e ./plugins/vllm_add_dummy_model
|
||||
- pytest -v -s distributed/test_distributed_oot.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py
|
||||
|
||||
- label: Plugin Tests (2 GPUs) # 40min
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 2
|
||||
fast_check: true
|
||||
source_file_dependencies:
|
||||
- vllm/plugins/
|
||||
- tests/plugins/
|
||||
commands:
|
||||
# begin platform plugin tests, all the code in-between runs on dummy platform
|
||||
- pip install -e ./plugins/vllm_add_dummy_platform
|
||||
- pytest -v -s plugins_tests/test_platform_plugins.py
|
||||
- pip uninstall vllm_add_dummy_platform -y
|
||||
# end platform plugin tests
|
||||
# other tests continue here:
|
||||
- pip install -e ./plugins/vllm_add_dummy_model
|
||||
- pytest -v -s distributed/test_distributed_oot.py
|
||||
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
|
||||
- pytest -v -s models/test_oot_registration.py # it needs a clean process
|
||||
|
||||
- label: Multi-step Tests (4 GPUs) # 36min
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 4
|
||||
|
@ -41,9 +41,11 @@ Every plugin has three parts:
|
||||
2. **Plugin name**: The name of the plugin. This is the value in the dictionary of the `entry_points` dictionary. In the example above, the plugin name is `register_dummy_model`. Plugins can be filtered by their names using the `VLLM_PLUGINS` environment variable. To load only a specific plugin, set `VLLM_PLUGINS` to the plugin name.
|
||||
3. **Plugin value**: The fully qualified name of the function to register in the plugin system. In the example above, the plugin value is `vllm_add_dummy_model:register`, which refers to a function named `register` in the `vllm_add_dummy_model` module.
|
||||
|
||||
## What Can Plugins Do?
|
||||
## Types of supported plugins
|
||||
|
||||
Currently, the primary use case for plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling `ModelRegistry.register_model` to register the model. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.
|
||||
- **General plugins** (with group name `vllm.general_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling `ModelRegistry.register_model` to register the model inside the plugin function.
|
||||
|
||||
- **Platform plugins** (with group name `vllm.platform_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree platforms into vLLM. The plugin function should return `None` when the platform is not supported in the current environment, or the platform class's fully qualified name when the platform is supported.
|
||||
|
||||
## Guidelines for Writing Plugins
|
||||
|
||||
|
@ -31,7 +31,6 @@ from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
|
||||
to_enc_dec_tuple_list, zip_enc_dec_prompts)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
|
||||
identity)
|
||||
@ -242,6 +241,7 @@ _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
|
||||
class HfRunner:
|
||||
|
||||
def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
|
||||
from vllm.platforms import current_platform
|
||||
if x is None or isinstance(x, (bool, )):
|
||||
return x
|
||||
|
||||
|
@ -5,7 +5,10 @@ import torch
|
||||
|
||||
from tests.kernels.utils import override_backend_env_variable
|
||||
from vllm.attention.selector import which_attn_to_use
|
||||
from vllm.platforms import cpu, cuda, openvino, rocm
|
||||
from vllm.platforms.cpu import CpuPlatform
|
||||
from vllm.platforms.cuda import CudaPlatform
|
||||
from vllm.platforms.openvino import OpenVinoPlatform
|
||||
from vllm.platforms.rocm import RocmPlatform
|
||||
from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
|
||||
|
||||
|
||||
@ -20,26 +23,23 @@ def test_env(name: str, device: str, monkeypatch):
|
||||
override_backend_env_variable(monkeypatch, name)
|
||||
|
||||
if device == "cpu":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
cpu.CpuPlatform()):
|
||||
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
|
||||
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
|
||||
False)
|
||||
assert backend.name == "TORCH_SDPA"
|
||||
elif device == "hip":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
rocm.RocmPlatform()):
|
||||
with patch("vllm.attention.selector.current_platform", RocmPlatform()):
|
||||
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
|
||||
False)
|
||||
assert backend.name == "ROCM_FLASH"
|
||||
elif device == "openvino":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
openvino.OpenVinoPlatform()):
|
||||
OpenVinoPlatform()):
|
||||
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
|
||||
False)
|
||||
assert backend.name == "OPENVINO"
|
||||
else:
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
cuda.CudaPlatform()):
|
||||
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
|
||||
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
|
||||
False)
|
||||
assert backend.name == name
|
||||
|
11
tests/plugins/vllm_add_dummy_platform/setup.py
Normal file
11
tests/plugins/vllm_add_dummy_platform/setup.py
Normal file
@ -0,0 +1,11 @@
|
||||
from setuptools import setup
|
||||
|
||||
setup(
|
||||
name='vllm_add_dummy_platform',
|
||||
version='0.1',
|
||||
packages=['vllm_add_dummy_platform'],
|
||||
entry_points={
|
||||
'vllm.platform_plugins': [
|
||||
"dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin" # noqa
|
||||
]
|
||||
})
|
@ -0,0 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def dummy_platform_plugin() -> Optional[str]:
|
||||
return "vllm_add_dummy_platform.dummy_platform.DummyPlatform"
|
@ -0,0 +1,5 @@
|
||||
from vllm.platforms.cuda import CudaPlatform
|
||||
|
||||
|
||||
class DummyPlatform(CudaPlatform):
|
||||
device_name = "DummyDevice"
|
16
tests/plugins_tests/test_platform_plugins.py
Normal file
16
tests/plugins_tests/test_platform_plugins.py
Normal file
@ -0,0 +1,16 @@
|
||||
def test_platform_plugins():
|
||||
# simulate workload by running an example
|
||||
import runpy
|
||||
current_file = __file__
|
||||
import os
|
||||
example_file = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(current_file))),
|
||||
"examples", "offline_inference.py")
|
||||
runpy.run_path(example_file)
|
||||
|
||||
# check if the plugin is loaded correctly
|
||||
from vllm.platforms import _init_trace, current_platform
|
||||
assert current_platform.device_name == "DummyDevice", (
|
||||
f"Expected DummyDevice, got {current_platform.device_name}, "
|
||||
"possibly because current_platform is imported before the plugin"
|
||||
f" is loaded. The first import:\n{_init_trace}")
|
@ -22,7 +22,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
|
||||
get_quantization_config)
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.platforms import current_platform, interface
|
||||
from vllm.platforms import CpuArchEnum
|
||||
from vllm.tracing import is_otel_available, otel_import_error_traceback
|
||||
from vllm.transformers_utils.config import (
|
||||
ConfigFormat, get_config, get_hf_image_processor_config,
|
||||
@ -349,6 +349,7 @@ class ModelConfig:
|
||||
self.is_hybrid = self._init_is_hybrid()
|
||||
self.has_inner_state = self._init_has_inner_state()
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
if current_platform.is_neuron():
|
||||
self.override_neuron_config = override_neuron_config
|
||||
else:
|
||||
@ -589,6 +590,7 @@ class ModelConfig:
|
||||
raise ValueError(
|
||||
f"Unknown quantization method: {self.quantization}. Must "
|
||||
f"be one of {supported_quantization}.")
|
||||
from vllm.platforms import current_platform
|
||||
current_platform.verify_quantization(self.quantization)
|
||||
if self.quantization not in optimized_quantization_methods:
|
||||
logger.warning(
|
||||
@ -644,6 +646,7 @@ class ModelConfig:
|
||||
|
||||
# Reminder: Please update docs/source/usage/compatibility_matrix.md
|
||||
# If the feature combo become valid
|
||||
from vllm.platforms import current_platform
|
||||
if not current_platform.is_async_output_supported(self.enforce_eager):
|
||||
logger.warning(
|
||||
"Async output processing is not supported on the "
|
||||
@ -1012,6 +1015,7 @@ class CacheConfig:
|
||||
raise ValueError(
|
||||
"GPU memory utilization must be less than 1.0. Got "
|
||||
f"{self.gpu_memory_utilization}.")
|
||||
from vllm.platforms import current_platform
|
||||
if (current_platform.is_cuda() and self.block_size is not None
|
||||
and self.block_size > 32):
|
||||
raise ValueError("CUDA Paged Attention kernel only supports "
|
||||
@ -1279,6 +1283,7 @@ class ParallelConfig:
|
||||
f"distributed executor backend "
|
||||
f"'{self.distributed_executor_backend}'.")
|
||||
ray_only_devices = ["tpu", "hpu"]
|
||||
from vllm.platforms import current_platform
|
||||
if (current_platform.device_type in ray_only_devices
|
||||
and self.world_size > 1):
|
||||
if self.distributed_executor_backend is None:
|
||||
@ -1327,7 +1332,7 @@ class ParallelConfig:
|
||||
def _verify_args(self) -> None:
|
||||
# Lazy import to avoid circular import
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
if self.distributed_executor_backend not in (
|
||||
"ray", "mp", None) and not (isinstance(
|
||||
self.distributed_executor_backend, type) and issubclass(
|
||||
@ -1528,6 +1533,7 @@ class DeviceConfig:
|
||||
def __init__(self, device: str = "auto") -> None:
|
||||
if device == "auto":
|
||||
# Automated device type detection
|
||||
from vllm.platforms import current_platform
|
||||
self.device_type = current_platform.device_type
|
||||
if not self.device_type:
|
||||
raise RuntimeError("Failed to infer device type")
|
||||
@ -2241,9 +2247,10 @@ def _get_and_verify_dtype(
|
||||
else:
|
||||
torch_dtype = config_dtype
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
if (current_platform.is_cpu()
|
||||
and current_platform.get_cpu_architecture()
|
||||
== interface.CpuArchEnum.POWERPC
|
||||
== CpuArchEnum.POWERPC
|
||||
and (config_dtype == torch.float16
|
||||
or config_dtype == torch.float32)):
|
||||
logger.info(
|
||||
@ -3083,6 +3090,7 @@ class VllmConfig:
|
||||
model_config: ModelConfig,
|
||||
load_config: LoadConfig) -> Optional[QuantizationConfig]:
|
||||
"""Get the quantization config."""
|
||||
from vllm.platforms import current_platform
|
||||
if model_config.quantization is not None:
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
get_quant_config)
|
||||
@ -3145,6 +3153,7 @@ class VllmConfig:
|
||||
self.quant_config = VllmConfig._get_quantization_config(
|
||||
self.model_config, self.load_config)
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
if self.scheduler_config is not None and \
|
||||
self.model_config is not None and \
|
||||
self.scheduler_config.chunked_prefill_enabled and \
|
||||
|
@ -39,7 +39,6 @@ import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op, supports_custom_op
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -194,6 +193,7 @@ class GroupCoordinator:
|
||||
assert self.cpu_group is not None
|
||||
assert self.device_group is not None
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
if current_platform.is_cuda_alike():
|
||||
self.device = torch.device(f"cuda:{local_rank}")
|
||||
else:
|
||||
@ -1188,6 +1188,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
|
||||
import ray # Lazy import Ray
|
||||
ray.shutdown()
|
||||
gc.collect()
|
||||
from vllm.platforms import current_platform
|
||||
if not current_platform.is_cpu():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
@ -18,7 +18,6 @@ from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.utils import check_gguf_file
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser, StoreBoolean
|
||||
@ -1094,6 +1093,7 @@ class EngineArgs:
|
||||
use_sliding_window = (model_config.get_sliding_window()
|
||||
is not None)
|
||||
use_spec_decode = self.speculative_model is not None
|
||||
from vllm.platforms import current_platform
|
||||
if (is_gpu and not use_sliding_window and not use_spec_decode
|
||||
and not self.enable_lora
|
||||
and not self.enable_prompt_adapter
|
||||
|
@ -8,7 +8,6 @@ import msgspec
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.executor.msgspec_utils import decode_hook, encode_hook
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
|
||||
from vllm.utils import get_ip
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
@ -229,6 +228,7 @@ def initialize_ray_cluster(
|
||||
the default Ray cluster address.
|
||||
"""
|
||||
assert_ray_available()
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Connect to a ray cluster.
|
||||
if current_platform.is_rocm() or current_platform.is_xpu():
|
||||
|
@ -6,7 +6,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.guided_decoding.utils import (
|
||||
convert_lark_to_gbnf, grammar_is_likely_lark,
|
||||
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.platforms import CpuArchEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer
|
||||
@ -39,6 +39,7 @@ def maybe_backend_fallback(
|
||||
|
||||
if guided_params.backend == "xgrammar":
|
||||
# xgrammar only has x86 wheels for linux, fallback to outlines
|
||||
from vllm.platforms import current_platform
|
||||
if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
|
||||
logger.warning("xgrammar is only supported on x86 CPUs. "
|
||||
"Falling back to use outlines instead.")
|
||||
|
@ -18,7 +18,6 @@ import cloudpickle
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
|
||||
supports_cross_encoding, supports_multimodal,
|
||||
@ -273,6 +272,7 @@ def _try_load_model_cls(
|
||||
model_arch: str,
|
||||
model: _BaseRegisteredModel,
|
||||
) -> Optional[Type[nn.Module]]:
|
||||
from vllm.platforms import current_platform
|
||||
current_platform.verify_model_arch(model_arch)
|
||||
try:
|
||||
return model.load_model_cls()
|
||||
|
@ -3,10 +3,9 @@ from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def set_random_seed(seed: int) -> None:
|
||||
from vllm.platforms import current_platform
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
|
||||
@ -38,6 +37,7 @@ def set_weight_attrs(
|
||||
# This sometimes causes OOM errors during model loading. To avoid this,
|
||||
# we sync the param tensor after its weight loader is called.
|
||||
# TODO(woosuk): Remove this hack once we have a better solution.
|
||||
from vllm.platforms import current_platform
|
||||
if current_platform.is_tpu() and key == "weight_loader":
|
||||
value = _make_synced_weight_loader(value)
|
||||
setattr(weight, key, value)
|
||||
|
@ -1,123 +1,223 @@
|
||||
import logging
|
||||
import traceback
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from vllm.plugins import load_plugins_by_group
|
||||
from vllm.utils import resolve_obj_by_qualname
|
||||
|
||||
from .interface import _Backend # noqa: F401
|
||||
from .interface import CpuArchEnum, Platform, PlatformEnum, UnspecifiedPlatform
|
||||
from .interface import CpuArchEnum, Platform, PlatformEnum
|
||||
|
||||
current_platform: Platform
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because
|
||||
# they only indicate the build configuration, not the runtime environment.
|
||||
# For example, people can install a cuda build of pytorch but run on tpu.
|
||||
|
||||
is_tpu = False
|
||||
try:
|
||||
# While it's technically possible to install libtpu on a non-TPU machine,
|
||||
# this is a very uncommon scenario. Therefore, we assume that libtpu is
|
||||
# installed if and only if the machine has TPUs.
|
||||
import libtpu # noqa: F401
|
||||
is_tpu = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
is_cuda = False
|
||||
|
||||
try:
|
||||
import pynvml
|
||||
pynvml.nvmlInit()
|
||||
def tpu_platform_plugin() -> Optional[str]:
|
||||
is_tpu = False
|
||||
try:
|
||||
if pynvml.nvmlDeviceGetCount() > 0:
|
||||
# While it's technically possible to install libtpu on a
|
||||
# non-TPU machine, this is a very uncommon scenario. Therefore,
|
||||
# we assume that libtpu is installed if and only if the machine
|
||||
# has TPUs.
|
||||
import libtpu # noqa: F401
|
||||
is_tpu = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return "vllm.platforms.tpu.TpuPlatform" if is_tpu else None
|
||||
|
||||
|
||||
def cuda_platform_plugin() -> Optional[str]:
|
||||
is_cuda = False
|
||||
|
||||
try:
|
||||
import pynvml
|
||||
pynvml.nvmlInit()
|
||||
try:
|
||||
if pynvml.nvmlDeviceGetCount() > 0:
|
||||
is_cuda = True
|
||||
finally:
|
||||
pynvml.nvmlShutdown()
|
||||
except Exception:
|
||||
# CUDA is supported on Jetson, but NVML may not be.
|
||||
import os
|
||||
|
||||
def cuda_is_jetson() -> bool:
|
||||
return os.path.isfile("/etc/nv_tegra_release") \
|
||||
or os.path.exists("/sys/class/tegra-firmware")
|
||||
|
||||
if cuda_is_jetson():
|
||||
is_cuda = True
|
||||
finally:
|
||||
pynvml.nvmlShutdown()
|
||||
except Exception:
|
||||
# CUDA is supported on Jetson, but NVML may not be.
|
||||
import os
|
||||
|
||||
def cuda_is_jetson() -> bool:
|
||||
return os.path.isfile("/etc/nv_tegra_release") \
|
||||
or os.path.exists("/sys/class/tegra-firmware")
|
||||
return "vllm.platforms.cuda.CudaPlatform" if is_cuda else None
|
||||
|
||||
if cuda_is_jetson():
|
||||
is_cuda = True
|
||||
|
||||
is_rocm = False
|
||||
def rocm_platform_plugin() -> Optional[str]:
|
||||
is_rocm = False
|
||||
|
||||
try:
|
||||
import amdsmi
|
||||
amdsmi.amdsmi_init()
|
||||
try:
|
||||
if len(amdsmi.amdsmi_get_processor_handles()) > 0:
|
||||
is_rocm = True
|
||||
finally:
|
||||
amdsmi.amdsmi_shut_down()
|
||||
except Exception:
|
||||
pass
|
||||
import amdsmi
|
||||
amdsmi.amdsmi_init()
|
||||
try:
|
||||
if len(amdsmi.amdsmi_get_processor_handles()) > 0:
|
||||
is_rocm = True
|
||||
finally:
|
||||
amdsmi.amdsmi_shut_down()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
is_hpu = False
|
||||
try:
|
||||
from importlib import util
|
||||
is_hpu = util.find_spec('habana_frameworks') is not None
|
||||
except Exception:
|
||||
pass
|
||||
return "vllm.platforms.rocm.RocmPlatform" if is_rocm else None
|
||||
|
||||
is_xpu = False
|
||||
|
||||
try:
|
||||
# installed IPEX if the machine has XPUs.
|
||||
import intel_extension_for_pytorch # noqa: F401
|
||||
import oneccl_bindings_for_pytorch # noqa: F401
|
||||
import torch
|
||||
if hasattr(torch, 'xpu') and torch.xpu.is_available():
|
||||
is_xpu = True
|
||||
except Exception:
|
||||
pass
|
||||
def hpu_platform_plugin() -> Optional[str]:
|
||||
is_hpu = False
|
||||
try:
|
||||
from importlib import util
|
||||
is_hpu = util.find_spec('habana_frameworks') is not None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
is_cpu = False
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
is_cpu = "cpu" in version("vllm")
|
||||
except Exception:
|
||||
pass
|
||||
return "vllm.platforms.hpu.HpuPlatform" if is_hpu else None
|
||||
|
||||
is_neuron = False
|
||||
try:
|
||||
import transformers_neuronx # noqa: F401
|
||||
is_neuron = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
is_openvino = False
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
is_openvino = "openvino" in version("vllm")
|
||||
except Exception:
|
||||
pass
|
||||
def xpu_platform_plugin() -> Optional[str]:
|
||||
is_xpu = False
|
||||
|
||||
if is_tpu:
|
||||
# people might install pytorch built with cuda but run on tpu
|
||||
# so we need to check tpu first
|
||||
from .tpu import TpuPlatform
|
||||
current_platform = TpuPlatform()
|
||||
elif is_cuda:
|
||||
from .cuda import CudaPlatform
|
||||
current_platform = CudaPlatform()
|
||||
elif is_rocm:
|
||||
from .rocm import RocmPlatform
|
||||
current_platform = RocmPlatform()
|
||||
elif is_hpu:
|
||||
from .hpu import HpuPlatform
|
||||
current_platform = HpuPlatform()
|
||||
elif is_xpu:
|
||||
from .xpu import XPUPlatform
|
||||
current_platform = XPUPlatform()
|
||||
elif is_cpu:
|
||||
from .cpu import CpuPlatform
|
||||
current_platform = CpuPlatform()
|
||||
elif is_neuron:
|
||||
from .neuron import NeuronPlatform
|
||||
current_platform = NeuronPlatform()
|
||||
elif is_openvino:
|
||||
from .openvino import OpenVinoPlatform
|
||||
current_platform = OpenVinoPlatform()
|
||||
else:
|
||||
current_platform = UnspecifiedPlatform()
|
||||
try:
|
||||
# installed IPEX if the machine has XPUs.
|
||||
import intel_extension_for_pytorch # noqa: F401
|
||||
import oneccl_bindings_for_pytorch # noqa: F401
|
||||
import torch
|
||||
if hasattr(torch, 'xpu') and torch.xpu.is_available():
|
||||
is_xpu = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
__all__ = ['Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum']
|
||||
return "vllm.platforms.xpu.XPUPlatform" if is_xpu else None
|
||||
|
||||
|
||||
def cpu_platform_plugin() -> Optional[str]:
|
||||
is_cpu = False
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
is_cpu = "cpu" in version("vllm")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None
|
||||
|
||||
|
||||
def neuron_platform_plugin() -> Optional[str]:
|
||||
is_neuron = False
|
||||
try:
|
||||
import transformers_neuronx # noqa: F401
|
||||
is_neuron = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return "vllm.platforms.neuron.NeuronPlatform" if is_neuron else None
|
||||
|
||||
|
||||
def openvino_platform_plugin() -> Optional[str]:
|
||||
is_openvino = False
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
is_openvino = "openvino" in version("vllm")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return "vllm.platforms.openvino.OpenVinoPlatform" if is_openvino else None
|
||||
|
||||
|
||||
builtin_platform_plugins = {
|
||||
'tpu': tpu_platform_plugin,
|
||||
'cuda': cuda_platform_plugin,
|
||||
'rocm': rocm_platform_plugin,
|
||||
'hpu': hpu_platform_plugin,
|
||||
'xpu': xpu_platform_plugin,
|
||||
'cpu': cpu_platform_plugin,
|
||||
'neuron': neuron_platform_plugin,
|
||||
'openvino': openvino_platform_plugin,
|
||||
}
|
||||
|
||||
|
||||
def resolve_current_platform_cls_qualname() -> str:
|
||||
platform_plugins = load_plugins_by_group('vllm.platform_plugins')
|
||||
|
||||
activated_plugins = []
|
||||
|
||||
for name, func in chain(builtin_platform_plugins.items(),
|
||||
platform_plugins.items()):
|
||||
try:
|
||||
assert callable(func)
|
||||
platform_cls_qualname = func()
|
||||
if platform_cls_qualname is not None:
|
||||
activated_plugins.append(name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
activated_builtin_plugins = list(
|
||||
set(activated_plugins) & set(builtin_platform_plugins.keys()))
|
||||
activated_oot_plugins = list(
|
||||
set(activated_plugins) & set(platform_plugins.keys()))
|
||||
|
||||
if len(activated_oot_plugins) >= 2:
|
||||
raise RuntimeError(
|
||||
"Only one platform plugin can be activated, but got: "
|
||||
f"{activated_oot_plugins}")
|
||||
elif len(activated_oot_plugins) == 1:
|
||||
platform_cls_qualname = platform_plugins[activated_oot_plugins[0]]()
|
||||
logger.info("Platform plugin %s is activated",
|
||||
activated_oot_plugins[0])
|
||||
elif len(activated_builtin_plugins) >= 2:
|
||||
raise RuntimeError(
|
||||
"Only one platform plugin can be activated, but got: "
|
||||
f"{activated_builtin_plugins}")
|
||||
elif len(activated_builtin_plugins) == 1:
|
||||
platform_cls_qualname = builtin_platform_plugins[
|
||||
activated_builtin_plugins[0]]()
|
||||
logger.info("Automatically detected platform %s.",
|
||||
activated_builtin_plugins[0])
|
||||
else:
|
||||
platform_cls_qualname = "vllm.interface.UnspecifiedPlatform"
|
||||
logger.info(
|
||||
"No platform detected, vLLM is running on UnspecifiedPlatform")
|
||||
return platform_cls_qualname
|
||||
|
||||
|
||||
_current_platform = None
|
||||
_init_trace: str = ''
|
||||
|
||||
if TYPE_CHECKING:
|
||||
current_platform: Platform
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
if name == 'current_platform':
|
||||
# lazy init current_platform.
|
||||
# 1. out-of-tree platform plugins need `from vllm.platforms import
|
||||
# Platform` so that they can inherit `Platform` class. Therefore,
|
||||
# we cannot resolve `current_platform` during the import of
|
||||
# `vllm.platforms`.
|
||||
# 2. when users use out-of-tree platform plugins, they might run
|
||||
# `import vllm`, some vllm internal code might access
|
||||
# `current_platform` during the import, and we need to make sure
|
||||
# `current_platform` is only resolved after the plugins are loaded
|
||||
# (we have tests for this, if any developer violate this, they will
|
||||
# see the test failures).
|
||||
global _current_platform
|
||||
if _current_platform is None:
|
||||
platform_cls_qualname = resolve_current_platform_cls_qualname()
|
||||
_current_platform = resolve_obj_by_qualname(
|
||||
platform_cls_qualname)()
|
||||
global _init_trace
|
||||
_init_trace = "".join(traceback.format_stack())
|
||||
return _current_platform
|
||||
else:
|
||||
return globals()[name]
|
||||
|
||||
|
||||
__all__ = [
|
||||
'Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum',
|
||||
"_init_trace"
|
||||
]
|
||||
|
@ -1,10 +1,10 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable, Dict
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -12,6 +12,39 @@ logger = logging.getLogger(__name__)
|
||||
plugins_loaded = False
|
||||
|
||||
|
||||
def load_plugins_by_group(group: str) -> Dict[str, Callable]:
|
||||
import sys
|
||||
if sys.version_info < (3, 10):
|
||||
from importlib_metadata import entry_points
|
||||
else:
|
||||
from importlib.metadata import entry_points
|
||||
|
||||
allowed_plugins = envs.VLLM_PLUGINS
|
||||
|
||||
discovered_plugins = entry_points(group=group)
|
||||
if len(discovered_plugins) == 0:
|
||||
logger.debug("No plugins for group %s found.", group)
|
||||
return {}
|
||||
logger.info("Available plugins for group %s:", group)
|
||||
for plugin in discovered_plugins:
|
||||
logger.info("name=%s, value=%s", plugin.name, plugin.value)
|
||||
if allowed_plugins is None:
|
||||
logger.info("all available plugins for group %s will be loaded.",
|
||||
group)
|
||||
logger.info("set environment variable VLLM_PLUGINS to control"
|
||||
" which plugins to load.")
|
||||
plugins = {}
|
||||
for plugin in discovered_plugins:
|
||||
if allowed_plugins is None or plugin.name in allowed_plugins:
|
||||
try:
|
||||
func = plugin.load()
|
||||
plugins[plugin.name] = func
|
||||
logger.info("plugin %s loaded.", plugin.name)
|
||||
except Exception:
|
||||
logger.exception("Failed to load plugin %s", plugin.name)
|
||||
return plugins
|
||||
|
||||
|
||||
def load_general_plugins():
|
||||
"""WARNING: plugins can be loaded for multiple times in different
|
||||
processes. They should be designed in a way that they can be loaded
|
||||
@ -26,6 +59,9 @@ def load_general_plugins():
|
||||
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
|
||||
# see https://github.com/vllm-project/vllm/issues/10619
|
||||
torch._inductor.config.compile_threads = 1
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_xpu():
|
||||
# see https://github.com/pytorch/pytorch/blob/8cada5cbe5450e17c26fb8b358116785324537b2/torch/_dynamo/config.py#L158 # noqa
|
||||
os.environ['TORCH_COMPILE_DISABLE'] = 'True'
|
||||
@ -47,33 +83,7 @@ def load_general_plugins():
|
||||
if plugins_loaded:
|
||||
return
|
||||
plugins_loaded = True
|
||||
import sys
|
||||
if sys.version_info < (3, 10):
|
||||
from importlib_metadata import entry_points
|
||||
else:
|
||||
from importlib.metadata import entry_points
|
||||
|
||||
allowed_plugins = envs.VLLM_PLUGINS
|
||||
|
||||
discovered_plugins = entry_points(group='vllm.general_plugins')
|
||||
if len(discovered_plugins) == 0:
|
||||
logger.debug("No plugins found.")
|
||||
return
|
||||
logger.info("Available plugins:")
|
||||
for plugin in discovered_plugins:
|
||||
logger.info("name=%s, value=%s, group=%s", plugin.name, plugin.value,
|
||||
plugin.group)
|
||||
if allowed_plugins is None:
|
||||
logger.info("all available plugins will be loaded.")
|
||||
logger.info("set environment variable VLLM_PLUGINS to control"
|
||||
" which plugins to load.")
|
||||
else:
|
||||
logger.info("plugins to load: %s", allowed_plugins)
|
||||
for plugin in discovered_plugins:
|
||||
if allowed_plugins is None or plugin.name in allowed_plugins:
|
||||
try:
|
||||
func = plugin.load()
|
||||
func()
|
||||
logger.info("plugin %s loaded.", plugin.name)
|
||||
except Exception:
|
||||
logger.exception("Failed to load plugin %s", plugin.name)
|
||||
plugins = load_plugins_by_group(group='vllm.general_plugins')
|
||||
# general plugins, we only need to execute the loaded functions
|
||||
for func in plugins.values():
|
||||
func()
|
||||
|
@ -6,7 +6,6 @@ import torch
|
||||
|
||||
from vllm.model_executor.layers.spec_decode_base_sampler import (
|
||||
SpecDecodeBaseSampler)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
|
||||
@ -94,6 +93,7 @@ class AsyncMetricsCollector:
|
||||
def maybe_collect_rejsample_metrics(
|
||||
self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
|
||||
# currently using cuda.Event, skip for any non_cuda_alike platform
|
||||
from vllm.platforms import current_platform
|
||||
if not current_platform.is_cuda_alike():
|
||||
return None
|
||||
|
||||
|
@ -17,7 +17,6 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.connections import global_http_connection
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
_config_home = envs.VLLM_CONFIG_ROOT
|
||||
@ -152,6 +151,7 @@ class UsageMessage:
|
||||
usage_context: UsageContext,
|
||||
extra_kvs: Dict[str, Any]) -> None:
|
||||
# Platform information
|
||||
from vllm.platforms import current_platform
|
||||
if current_platform.is_cuda_alike():
|
||||
device_property = torch.cuda.get_device_properties(0)
|
||||
self.gpu_count = torch.cuda.device_count()
|
||||
|
@ -50,7 +50,6 @@ from typing_extensions import ParamSpec, TypeIs, assert_never
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import enable_trace_function_call, init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
@ -609,6 +608,7 @@ def create_kv_caches_with_random_flash(
|
||||
seed: int = 0,
|
||||
device: Optional[str] = "cuda",
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
from vllm.platforms import current_platform
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
|
||||
@ -650,7 +650,7 @@ def create_kv_caches_with_random(
|
||||
raise ValueError(
|
||||
f"Does not support key cache of type fp8 with head_size {head_size}"
|
||||
)
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
|
||||
@ -703,6 +703,7 @@ def print_warning_once(msg: str) -> None:
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def is_pin_memory_available() -> bool:
|
||||
from vllm.platforms import current_platform
|
||||
return current_platform.is_pin_memory_available()
|
||||
|
||||
|
||||
@ -713,6 +714,7 @@ class DeviceMemoryProfiler:
|
||||
|
||||
def current_memory_usage(self) -> float:
|
||||
# Return the memory usage in bytes.
|
||||
from vllm.platforms import current_platform
|
||||
if current_platform.is_cuda_alike():
|
||||
torch.cuda.reset_peak_memory_stats(self.device)
|
||||
mem = torch.cuda.max_memory_allocated(self.device)
|
||||
@ -1066,6 +1068,7 @@ def _cuda_device_count_stateless(
|
||||
import torch.cuda
|
||||
import torch.version
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
if not torch.cuda._is_compiled():
|
||||
return 0
|
||||
if current_platform.is_rocm():
|
||||
@ -1673,6 +1676,7 @@ def direct_register_custom_op(
|
||||
return
|
||||
|
||||
if not supports_custom_op():
|
||||
from vllm.platforms import current_platform
|
||||
assert not current_platform.is_cuda_alike(), (
|
||||
"cuda platform needs torch>=2.4 to support custom op, "
|
||||
"chances are you are using an old version of pytorch "
|
||||
|
@ -12,7 +12,6 @@ from torch import is_tensor
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -265,13 +264,13 @@ class ModelRunnerBase(ABC, Generic[T]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@current_platform.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
model_input: T,
|
||||
kv_caches: Optional[List[torch.Tensor]],
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
**kwargs,
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
"""
|
||||
Execute the model on the given input.
|
||||
|
@ -544,6 +544,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
|
||||
model_input.record_step_event(current_stream)
|
||||
|
||||
if get_pp_group().is_last_rank and self.is_driver_worker:
|
||||
assert isinstance(output, list)
|
||||
assert len(
|
||||
output
|
||||
) == 1, "MultiStepModelRunner requires single-step base_models"
|
||||
|
@ -11,7 +11,6 @@ from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
|
||||
from vllm.utils import (enable_trace_function_call_for_thread,
|
||||
resolve_obj_by_qualname, update_environment_variables)
|
||||
@ -44,6 +43,8 @@ class WorkerBase(ABC):
|
||||
self.prompt_adapter_config = vllm_config.prompt_adapter_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
self.kv_transfer_config = vllm_config.kv_transfer_config
|
||||
from vllm.platforms import current_platform
|
||||
self.current_platform = current_platform
|
||||
|
||||
@abstractmethod
|
||||
def init_device(self) -> None:
|
||||
@ -74,17 +75,17 @@ class WorkerBase(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@current_platform.inference_mode()
|
||||
def start_worker_execution_loop(self) -> None:
|
||||
"""Execute model loop in parallel worker.
|
||||
|
||||
You can stop the loop by executing a driver worker with an empty output.
|
||||
See `stop_remote_worker_execution_loop` for more details.
|
||||
"""
|
||||
while True:
|
||||
output = self.execute_model(execute_model_req=None)
|
||||
if output is None:
|
||||
return None
|
||||
with self.current_platform.inference_mode():
|
||||
while True:
|
||||
output = self.execute_model(execute_model_req=None)
|
||||
if output is None:
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def execute_model(
|
||||
@ -352,6 +353,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
|
||||
model_execute_time = time.perf_counter() - start_time
|
||||
if not get_pp_group().is_last_rank:
|
||||
# output is IntermediateTensors
|
||||
assert isinstance(output, IntermediateTensors)
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_execute_time):
|
||||
output.tensors["model_execute_time"] = torch.tensor(
|
||||
|
Loading…
x
Reference in New Issue
Block a user