[Core] Move function tracing setup to util function (#4352)

This commit is contained in:
Nick Hill 2024-04-25 16:45:12 -07:00 committed by GitHub
parent 15e7c675b0
commit efffb63f58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 15 deletions

View File

@ -1,10 +1,13 @@
import asyncio
import datetime
import enum
import gc
import glob
import os
import socket
import subprocess
import tempfile
import threading
import uuid
import warnings
from collections import defaultdict
@ -18,7 +21,7 @@ import psutil
import torch
from packaging.version import Version, parse
from vllm.logger import init_logger
from vllm.logger import enable_trace_function_call, init_logger
T = TypeVar("T")
logger = init_logger(__name__)
@ -607,3 +610,19 @@ def find_nccl_library():
raise ValueError("NCCL only supports CUDA and ROCm backends.")
logger.info(f"Found nccl from library {so_file}")
return so_file
def enable_trace_function_call_for_thread() -> None:
"""Set up function tracing for the current thread,
if enabled via the VLLM_TRACE_FUNCTION environment variable
"""
if int(os.getenv("VLLM_TRACE_FUNCTION", "0")):
tmp_dir = tempfile.gettempdir()
filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
f"_thread_{threading.get_ident()}_"
f"at_{datetime.datetime.now()}.log").replace(" ", "_")
log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
filename)
os.makedirs(os.path.dirname(log_path), exist_ok=True)
enable_trace_function_call(log_path)

View File

@ -1,15 +1,13 @@
import datetime
import importlib
import os
import tempfile
import threading
from abc import ABC, abstractmethod
from typing import Dict, List, Set, Tuple
from vllm.logger import enable_trace_function_call, init_logger
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import get_vllm_instance_id, update_environment_variables
from vllm.utils import (enable_trace_function_call_for_thread,
update_environment_variables)
logger = init_logger(__name__)
@ -128,15 +126,7 @@ class WorkerWrapperBase:
function tracing if required.
Arguments are passed to the worker class constructor.
"""
if int(os.getenv("VLLM_TRACE_FUNCTION", "0")):
tmp_dir = tempfile.gettempdir()
filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
f"_thread_{threading.get_ident()}_"
f"at_{datetime.datetime.now()}.log").replace(" ", "_")
log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
filename)
os.makedirs(os.path.dirname(log_path), exist_ok=True)
enable_trace_function_call(log_path)
enable_trace_function_call_for_thread()
mod = importlib.import_module(self.worker_module_name)
worker_class = getattr(mod, self.worker_class_name)