[core][bugfix] configure env var during import vllm (#12209)

Signed-off-by: youkaichao <youkaichao@gmail.com>
commit c222f47992 (parent 170eb35079)
Author: youkaichao <youkaichao@gmail.com>
Date: 2025-01-20 19:35:59 +08:00 (committed by GitHub)
4 changed files with 37 additions and 45 deletions
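The gist of the change: the body of `configure_as_vllm_process()` moves to module scope in the package `__init__`, so the common environment variables are set as a side effect of `import vllm` instead of by an explicit call that every external process had to remember to make. A minimal sketch of the import-time pattern, using the env var from the diff (the package name `mypkg` is hypothetical):

# mypkg/__init__.py (hypothetical) -- statements at module scope run exactly
# once per process, on first import, so every process that imports the
# package ends up with the same environment.
import os

# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'

The first diff below appears to be the Ray-based RLHF example, which previously had to call the function by hand.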

@@ -19,7 +19,7 @@ from ray.util.placement_group import placement_group
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 from transformers import AutoModelForCausalLM
-from vllm import LLM, SamplingParams, configure_as_vllm_process
+from vllm import LLM, SamplingParams
 from vllm.utils import get_ip, get_open_port
 from vllm.worker.worker import Worker
@@ -98,12 +98,7 @@ class MyLLM(LLM):
 """
 Start the training process, here we use huggingface transformers
 as an example to hold a model on GPU 0.
-It is important for all the processes outside of vLLM to call
-`configure_as_vllm_process` to set some common environment variables
-the same as vLLM workers.
 """
-
-configure_as_vllm_process()
 
 train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
 train_model.to("cuda:0")
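With that call gone, a process outside vLLM that shares GPUs with vLLM workers needs no explicit setup: the `from vllm import LLM, SamplingParams` at the top of the example already performs the configuration on import. A sketch of the resulting minimal usage (model name as in the example; everything else is illustrative):

from transformers import AutoModelForCausalLM

import vllm  # noqa: F401  the import itself configures NCCL_CUMEM_ENABLE etc.

train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
train_model.to("cuda:0")

The next diff appears to be the package's top-level `__init__` module, where the function body becomes plain module-level code.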

@@ -1,4 +1,7 @@
 """vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
+import os
+
+import torch
 
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -17,16 +20,10 @@ from vllm.sampling_params import SamplingParams
 from .version import __version__, __version_tuple__
 
-def configure_as_vllm_process():
-    """
-    set some common config/environment variables that should be set
-    for all processes created by vllm and all processes
-    that interact with vllm workers.
-    """
-    import os
-
-    import torch
-
-    # see https://github.com/NVIDIA/nccl/issues/1234
-    os.environ['NCCL_CUMEM_ENABLE'] = '0'
+# set some common config/environment variables that should be set
+# for all processes created by vllm and all processes
+# that interact with vllm workers.
+# they are executed whenever `import vllm` is called.
+
+# see https://github.com/NVIDIA/nccl/issues/1234
+os.environ['NCCL_CUMEM_ENABLE'] = '0'
 
@@ -36,25 +33,6 @@ def configure_as_vllm_process():
-    # see https://github.com/vllm-project/vllm/issues/10619
-    torch._inductor.config.compile_threads = 1
-
-    from vllm.platforms import current_platform
-    if current_platform.is_xpu():
-        # see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
-        torch._dynamo.config.disable = True
-    elif current_platform.is_hpu():
-        # NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
-        # does not support torch.compile
-        # Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
-        # torch.compile support
-        is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1'
-        if is_lazy:
-            torch._dynamo.config.disable = True
-            # NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
-            # requires enabling lazy collectives
-            # see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501
-            os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true'
+# see https://github.com/vllm-project/vllm/issues/10619
+torch._inductor.config.compile_threads = 1
 
 __all__ = [
     "__version__",
     "__version_tuple__",
@@ -80,5 +58,4 @@ __all__ = [
     "AsyncEngineArgs",
     "initialize_ray_cluster",
     "PoolingParams",
-    "configure_as_vllm_process",
 ]
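A consequence worth noting: the configuration is observable immediately after import. A quick sanity check (not part of the commit, just a sketch):

import os

import torch
import vllm  # noqa: F401  module-level configuration runs on this line

assert os.environ.get('NCCL_CUMEM_ENABLE') == '0'
assert torch._inductor.config.compile_threads == 1

The next diff appears to be the plugins module, which absorbs the platform-specific half of the deleted function.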

@@ -1,6 +1,9 @@
 import logging
+import os
 from typing import Callable, Dict
 
+import torch
+
 import vllm.envs as envs
 
 logger = logging.getLogger(__name__)
@@ -51,6 +54,26 @@ def load_general_plugins():
     if plugins_loaded:
         return
     plugins_loaded = True
+
+    # some platform-specific configurations
+    from vllm.platforms import current_platform
+
+    if current_platform.is_xpu():
+        # see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
+        torch._dynamo.config.disable = True
+    elif current_platform.is_hpu():
+        # NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
+        # does not support torch.compile
+        # Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
+        # torch.compile support
+        is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1'
+        if is_lazy:
+            torch._dynamo.config.disable = True
+            # NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
+            # requires enabling lazy collectives
+            # see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501
+            os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true'
+
     plugins = load_plugins_by_group(group='vllm.general_plugins')
     # general plugins, we only need to execute the loaded functions
     for func in plugins.values():
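`vllm.general_plugins` is a setuptools entry-point group; `load_plugins_by_group` lives in this same module but is not shown in the hunk. A hedged sketch of what such a helper plausibly looks like (the real implementation may differ):

from importlib.metadata import entry_points
from typing import Callable, Dict

def load_plugins_by_group(group: str) -> Dict[str, Callable]:
    # Resolve every entry point that installed packages registered
    # under `group` to the callable it names (Python 3.10+ keyword form).
    return {ep.name: ep.load() for ep in entry_points(group=group)}

Moving the XPU/HPU logic here, behind the `plugins_loaded` guard, keeps it idempotent and runs it at engine/worker startup rather than on every bare `import vllm`. The last diff appears to be the worker wrapper, which previously called `configure_as_vllm_process()` explicitly.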

@@ -535,9 +535,6 @@ class WorkerWrapperBase:
         kwargs = all_kwargs[self.rpc_rank]
         enable_trace_function_call_for_thread(self.vllm_config)
 
-        from vllm import configure_as_vllm_process
-        configure_as_vllm_process()
-
         from vllm.plugins import load_general_plugins
         load_general_plugins()
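Taken together, a worker process now initializes in two stages (a summary of the diff; WorkerWrapperBase internals beyond the lines above are assumed):

# 1. `import vllm` sets process-wide env vars (NCCL_CUMEM_ENABLE, ...) and
#    caps torch._inductor compile threads, for every process that imports it.
# 2. load_general_plugins() applies the XPU/HPU-specific knobs and then runs
#    any installed `vllm.general_plugins` entry points; the module-level
#    plugins_loaded flag makes repeated calls a no-op.
from vllm.plugins import load_general_plugins

load_general_plugins()  # safe to call more than once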