diff --git a/vllm/utils.py b/vllm/utils.py index 15c8818c..79ac1db0 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1,10 +1,13 @@ import asyncio +import datetime import enum import gc import glob import os import socket import subprocess +import tempfile +import threading import uuid import warnings from collections import defaultdict @@ -18,7 +21,7 @@ import psutil import torch from packaging.version import Version, parse -from vllm.logger import init_logger +from vllm.logger import enable_trace_function_call, init_logger T = TypeVar("T") logger = init_logger(__name__) @@ -607,3 +610,19 @@ def find_nccl_library(): raise ValueError("NCCL only supports CUDA and ROCm backends.") logger.info(f"Found nccl from library {so_file}") return so_file + + +def enable_trace_function_call_for_thread() -> None: + """Set up function tracing for the current thread, + if enabled via the VLLM_TRACE_FUNCTION environment variable + """ + + if int(os.getenv("VLLM_TRACE_FUNCTION", "0")): + tmp_dir = tempfile.gettempdir() + filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" + f"_thread_{threading.get_ident()}_" + f"at_{datetime.datetime.now()}.log").replace(" ", "_") + log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(), + filename) + os.makedirs(os.path.dirname(log_path), exist_ok=True) + enable_trace_function_call(log_path) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index b5dade0a..0a89e3a7 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,15 +1,13 @@ -import datetime import importlib import os -import tempfile -import threading from abc import ABC, abstractmethod from typing import Dict, List, Set, Tuple -from vllm.logger import enable_trace_function_call, init_logger +from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata -from vllm.utils import get_vllm_instance_id, update_environment_variables +from vllm.utils import (enable_trace_function_call_for_thread, + update_environment_variables) logger = init_logger(__name__) @@ -128,15 +126,7 @@ class WorkerWrapperBase: function tracing if required. Arguments are passed to the worker class constructor. """ - if int(os.getenv("VLLM_TRACE_FUNCTION", "0")): - tmp_dir = tempfile.gettempdir() - filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" - f"_thread_{threading.get_ident()}_" - f"at_{datetime.datetime.now()}.log").replace(" ", "_") - log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(), - filename) - os.makedirs(os.path.dirname(log_path), exist_ok=True) - enable_trace_function_call(log_path) + enable_trace_function_call_for_thread() mod = importlib.import_module(self.worker_module_name) worker_class = getattr(mod, self.worker_class_name)