[Core] add an option to log every function call to for debugging hang/crash in distributed inference (#4079)
Co-authored-by: Simon Mo <simon.mo@hey.com>
This commit is contained in:
parent
8f9c28fd40
commit
8a7a3e4436
@ -40,7 +40,7 @@ steps:
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py
|
||||
|
||||
- label: Engine Test
|
||||
command: pytest -v -s engine tokenization test_sequence.py test_config.py
|
||||
command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
|
||||
|
||||
- label: Entrypoints Test
|
||||
commands:
|
||||
|
2
.github/ISSUE_TEMPLATE/400-bug report.yml
vendored
2
.github/ISSUE_TEMPLATE/400-bug report.yml
vendored
@ -57,6 +57,8 @@ body:
|
||||
If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com.
|
||||
|
||||
Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.
|
||||
|
||||
If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1` . All the function calls in vllm will be recorded. Inspect these log files, and tell which function crashes or hangs.
|
||||
placeholder: |
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
|
27
tests/test_logger.py
Normal file
27
tests/test_logger.py
Normal file
@ -0,0 +1,27 @@
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
from vllm.logger import enable_trace_function_call
|
||||
|
||||
|
||||
def f1(x):
|
||||
return f2(x)
|
||||
|
||||
|
||||
def f2(x):
|
||||
return x
|
||||
|
||||
|
||||
def test_trace_function_call():
|
||||
fd, path = tempfile.mkstemp()
|
||||
cur_dir = os.path.dirname(__file__)
|
||||
enable_trace_function_call(path, cur_dir)
|
||||
f1(1)
|
||||
with open(path, 'r') as f:
|
||||
content = f.read()
|
||||
|
||||
assert "f1" in content
|
||||
assert "f2" in content
|
||||
sys.settrace(None)
|
||||
os.remove(path)
|
@ -10,7 +10,7 @@ from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||
make_async)
|
||||
get_vllm_instance_id, make_async)
|
||||
|
||||
if ray is not None:
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
@ -133,12 +133,18 @@ class RayGPUExecutor(ExecutorBase):
|
||||
for node_id, gpu_ids in node_gpus.items():
|
||||
node_gpus[node_id] = sorted(gpu_ids)
|
||||
|
||||
# Set CUDA_VISIBLE_DEVICES for the driver and workers.
|
||||
VLLM_INSTANCE_ID = get_vllm_instance_id()
|
||||
|
||||
# Set environment variables for the driver and workers.
|
||||
all_args_to_update_environment_variables = []
|
||||
for (node_id, _) in worker_node_and_gpu_ids:
|
||||
all_args_to_update_environment_variables.append([{
|
||||
"CUDA_VISIBLE_DEVICES":
|
||||
",".join(map(str, node_gpus[node_id]))
|
||||
",".join(map(str, node_gpus[node_id])),
|
||||
"VLLM_INSTANCE_ID":
|
||||
VLLM_INSTANCE_ID,
|
||||
"VLLM_TRACE_FUNCTION":
|
||||
os.getenv("VLLM_TRACE_FUNCTION", "0"),
|
||||
}])
|
||||
self._run_workers("update_environment_variables",
|
||||
all_args=all_args_to_update_environment_variables)
|
||||
|
@ -1,9 +1,11 @@
|
||||
# Adapted from
|
||||
# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
|
||||
"""Logging configuration for vLLM."""
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))
|
||||
@ -65,3 +67,53 @@ def init_logger(name: str):
|
||||
logger.addHandler(_default_handler)
|
||||
logger.propagate = False
|
||||
return logger
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _trace_calls(log_path, root_dir, frame, event, arg=None):
|
||||
if event in ['call', 'return']:
|
||||
# Extract the filename, line number, function name, and the code object
|
||||
filename = frame.f_code.co_filename
|
||||
lineno = frame.f_lineno
|
||||
func_name = frame.f_code.co_name
|
||||
if not filename.startswith(root_dir):
|
||||
# only log the functions in the vllm root_dir
|
||||
return
|
||||
# Log every function call or return
|
||||
try:
|
||||
with open(log_path, 'a') as f:
|
||||
if event == 'call':
|
||||
f.write(f"{datetime.datetime.now()} Call to"
|
||||
f" {func_name} in {filename}:{lineno}\n")
|
||||
else:
|
||||
f.write(f"{datetime.datetime.now()} Return from"
|
||||
f" {func_name} in {filename}:{lineno}\n")
|
||||
except NameError:
|
||||
# modules are deleted during shutdown
|
||||
pass
|
||||
return partial(_trace_calls, log_path, root_dir)
|
||||
|
||||
|
||||
def enable_trace_function_call(log_file_path: str,
|
||||
root_dir: Optional[str] = None):
|
||||
"""
|
||||
Enable tracing of every function call in code under `root_dir`.
|
||||
This is useful for debugging hangs or crashes.
|
||||
`log_file_path` is the path to the log file.
|
||||
`root_dir` is the root directory of the code to trace. If None, it is the
|
||||
vllm root directory.
|
||||
|
||||
Note that this call is thread-level, any threads calling this function
|
||||
will have the trace enabled. Other threads will not be affected.
|
||||
"""
|
||||
logger.warning(
|
||||
"VLLM_TRACE_FUNCTION is enabled. It will record every"
|
||||
" function executed by Python. This will slow down the code. It "
|
||||
"is suggested to be used for debugging hang or crashes only.")
|
||||
logger.info(f"Trace frame log is saved to {log_file_path}")
|
||||
if root_dir is None:
|
||||
# by default, this is the vllm root directory
|
||||
root_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
sys.settrace(partial(_trace_calls, log_file_path, root_dir))
|
||||
|
@ -163,6 +163,17 @@ def random_uuid() -> str:
|
||||
return str(uuid.uuid4().hex)
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_vllm_instance_id():
|
||||
"""
|
||||
If the environment variable VLLM_INSTANCE_ID is set, return it.
|
||||
Otherwise, return a random UUID.
|
||||
Instance id represents an instance of the VLLM. All processes in the same
|
||||
instance should have the same instance id.
|
||||
"""
|
||||
return os.environ.get("VLLM_INSTANCE_ID", f"vllm-instance-{random_uuid()}")
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def in_wsl() -> bool:
|
||||
# Reference: https://github.com/microsoft/WSL/issues/4071
|
||||
@ -274,7 +285,7 @@ def get_open_port() -> int:
|
||||
|
||||
def update_environment_variables(envs: Dict[str, str]):
|
||||
for k, v in envs.items():
|
||||
if k in os.environ:
|
||||
if k in os.environ and os.environ[k] != v:
|
||||
logger.warning(f"Overwriting environment variable {k} "
|
||||
f"from '{os.environ[k]}' to '{v}'")
|
||||
os.environ[k] = v
|
||||
|
@ -1,12 +1,15 @@
|
||||
import datetime
|
||||
import importlib
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Set, Tuple
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logger import enable_trace_function_call, init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
from vllm.utils import update_environment_variables
|
||||
from vllm.utils import get_vllm_instance_id, update_environment_variables
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -115,9 +118,20 @@ class WorkerWrapperBase:
|
||||
|
||||
def init_worker(self, *args, **kwargs):
|
||||
"""
|
||||
Actual initialization of the worker class.
|
||||
Actual initialization of the worker class, and set up
|
||||
function tracing if required.
|
||||
Arguments are passed to the worker class constructor.
|
||||
"""
|
||||
if int(os.getenv("VLLM_TRACE_FUNCTION", "0")):
|
||||
tmp_dir = tempfile.gettempdir()
|
||||
filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
|
||||
f"_thread_{threading.get_ident()}_"
|
||||
f"at_{datetime.datetime.now()}.log").replace(" ", "_")
|
||||
log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
|
||||
filename)
|
||||
os.makedirs(os.path.dirname(log_path), exist_ok=True)
|
||||
enable_trace_function_call(log_path)
|
||||
|
||||
mod = importlib.import_module(self.worker_module_name)
|
||||
worker_class = getattr(mod, self.worker_class_name)
|
||||
self.worker = worker_class(*args, **kwargs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user