34 lines
836 B
Python
34 lines
836 B
Python
import importlib
|
|
import traceback
|
|
from typing import Callable
|
|
from unittest.mock import patch
|
|
|
|
|
|
def find_cuda_init(fn: Callable[[], object]) -> None:
    """
    Helper function to debug CUDA re-initialization errors.

    If `fn` initializes CUDA, prints the stack trace of how this happens.

    `torch.cuda._lazy_init` is temporarily replaced by a wrapper that records
    the call stack and then delegates to the real implementation. Only the
    first call made while CUDA is not yet initialized is recorded, so the
    printed trace points at the code that triggered initialization rather
    than at whichever CUDA call happened to run last.

    Args:
        fn: Zero-argument callable to run while the patch is active. Its
            return value is ignored; any exception it raises propagates to
            the caller (in which case nothing is printed).
    """
    from torch.cuda import _lazy_init, is_initialized

    stack = None

    def wrapper():
        nonlocal stack
        # `_lazy_init` may be invoked many times while `fn` runs. Keep only
        # the first call seen while CUDA is still uninitialized; later calls
        # (or calls after CUDA was already set up before `fn` started) would
        # otherwise overwrite the trace with an unrelated call site.
        if stack is None and not is_initialized():
            stack = traceback.extract_stack()
        return _lazy_init()

    # Patch the module attribute so that lookups of `torch.cuda._lazy_init`
    # made while `fn` runs resolve to `wrapper`.
    with patch("torch.cuda._lazy_init", wrapper):
        fn()

    if stack is not None:
        print("==== CUDA Initialized ====")
        print("".join(traceback.format_list(stack)).strip())
        print("==========================")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: report where (if anywhere) importing the target
    # module initializes CUDA.
    target_module = "vllm.model_executor.models.llava"
    find_cuda_init(lambda: importlib.import_module(target_module))
|