[Core] Move function tracing setup to util function (#4352)

This commit is contained in:
Nick Hill 2024-04-25 16:45:12 -07:00 committed by GitHub
parent 15e7c675b0
commit efffb63f58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 15 deletions

View File

@ -1,10 +1,13 @@
import asyncio import asyncio
import datetime
import enum import enum
import gc import gc
import glob import glob
import os import os
import socket import socket
import subprocess import subprocess
import tempfile
import threading
import uuid import uuid
import warnings import warnings
from collections import defaultdict from collections import defaultdict
@ -18,7 +21,7 @@ import psutil
import torch import torch
from packaging.version import Version, parse from packaging.version import Version, parse
from vllm.logger import init_logger from vllm.logger import enable_trace_function_call, init_logger
T = TypeVar("T") T = TypeVar("T")
logger = init_logger(__name__) logger = init_logger(__name__)
@ -607,3 +610,19 @@ def find_nccl_library():
raise ValueError("NCCL only supports CUDA and ROCm backends.") raise ValueError("NCCL only supports CUDA and ROCm backends.")
logger.info(f"Found nccl from library {so_file}") logger.info(f"Found nccl from library {so_file}")
return so_file return so_file
def enable_trace_function_call_for_thread() -> None:
"""Set up function tracing for the current thread,
if enabled via the VLLM_TRACE_FUNCTION environment variable
"""
if int(os.getenv("VLLM_TRACE_FUNCTION", "0")):
tmp_dir = tempfile.gettempdir()
filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
f"_thread_{threading.get_ident()}_"
f"at_{datetime.datetime.now()}.log").replace(" ", "_")
log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
filename)
os.makedirs(os.path.dirname(log_path), exist_ok=True)
enable_trace_function_call(log_path)

View File

@ -1,15 +1,13 @@
import datetime
import importlib import importlib
import os import os
import tempfile
import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Set, Tuple from typing import Dict, List, Set, Tuple
from vllm.logger import enable_trace_function_call, init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import get_vllm_instance_id, update_environment_variables from vllm.utils import (enable_trace_function_call_for_thread,
update_environment_variables)
logger = init_logger(__name__) logger = init_logger(__name__)
@ -128,15 +126,7 @@ class WorkerWrapperBase:
function tracing if required. function tracing if required.
Arguments are passed to the worker class constructor. Arguments are passed to the worker class constructor.
""" """
if int(os.getenv("VLLM_TRACE_FUNCTION", "0")): enable_trace_function_call_for_thread()
tmp_dir = tempfile.gettempdir()
filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
f"_thread_{threading.get_ident()}_"
f"at_{datetime.datetime.now()}.log").replace(" ", "_")
log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
filename)
os.makedirs(os.path.dirname(log_path), exist_ok=True)
enable_trace_function_call(log_path)
mod = importlib.import_module(self.worker_module_name) mod = importlib.import_module(self.worker_module_name)
worker_class = getattr(mod, self.worker_class_name) worker_class = getattr(mod, self.worker_class_name)