[V1][Frontend] Improve Shutdown And Logs (#11737)
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Andrew Feldman <afeldman@neuralmagic.com>
Co-authored-by: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent 3c776dcefb
commit 2b05b8ce69
@@ -552,6 +552,7 @@ steps:
   # - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
   - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
   - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
+  - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown

 - label: Plugin Tests (2 GPUs) # 40min
   working_dir: "/vllm-workspace/tests"
tests/v1/shutdown/test_delete.py (new file, 97 lines)
@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
"""Test that we handle a startup Error and shutdown."""

import pytest

from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES,
                                     SHUTDOWN_TEST_TIMEOUT_SEC)
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import RequestOutputKind
from vllm.utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM

MODELS = ["meta-llama/Llama-3.2-1B"]


@pytest.mark.asyncio
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("send_one_request", [False, True])
async def test_async_llm_delete(model: str, tensor_parallel_size: int,
                                send_one_request: bool) -> None:
    """Test that AsyncLLM frees GPU memory upon deletion.
    AsyncLLM always uses an MP client.

    Args:
        model: model under test
        tensor_parallel_size: degree of tensor parallelism
        send_one_request: send one request to engine before deleting
    """
    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    engine_args = AsyncEngineArgs(model=model,
                                  enforce_eager=True,
                                  tensor_parallel_size=tensor_parallel_size)

    # Instantiate AsyncLLM; make request to complete any deferred
    # initialization; then delete instance
    async_llm = AsyncLLM.from_engine_args(engine_args)
    if send_one_request:
        async for _ in async_llm.generate(
                "Hello my name is",
                request_id="abc",
                sampling_params=SamplingParams(
                    max_tokens=1, output_kind=RequestOutputKind.DELTA)):
            pass
    del async_llm

    # Confirm all the processes are cleaned up.
    wait_for_gpu_memory_to_clear(
        devices=list(range(tensor_parallel_size)),
        threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
    )


@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("enable_multiprocessing", [True])
@pytest.mark.parametrize("send_one_request", [False, True])
def test_llm_delete(monkeypatch, model: str, tensor_parallel_size: int,
                    enable_multiprocessing: bool,
                    send_one_request: bool) -> None:
    """Test that LLM frees GPU memory upon deletion.
    TODO(andy) - LLM without multiprocessing.

    Args:
        model: model under test
        tensor_parallel_size: degree of tensor parallelism
        enable_multiprocessing: enable workers in separate process(es)
        send_one_request: send one request to engine before deleting
    """
    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    with monkeypatch.context() as m:
        MP_VALUE = "1" if enable_multiprocessing else "0"
        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE)

        # Instantiate LLM; make request to complete any deferred
        # initialization; then delete instance
        llm = LLM(model=model,
                  enforce_eager=True,
                  tensor_parallel_size=tensor_parallel_size)
        if send_one_request:
            llm.generate("Hello my name is",
                         sampling_params=SamplingParams(max_tokens=1))
        del llm

        # Confirm all the processes are cleaned up.
        wait_for_gpu_memory_to_clear(
            devices=list(range(tensor_parallel_size)),
            threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
        )
tests/v1/shutdown/test_forward_error.py (new file, 129 lines)
@@ -0,0 +1,129 @@
# SPDX-License-Identifier: Apache-2.0
"""Test that we handle an Error in model forward and shutdown."""

import asyncio

import pytest

from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES,
                                     SHUTDOWN_TEST_TIMEOUT_SEC)
from vllm import LLM, AsyncEngineArgs, SamplingParams
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineDeadError

MODELS = ["meta-llama/Llama-3.2-1B"]


def evil_forward(self, *args, **kwargs):
    """Evil forward method that raises an exception after 10 calls."""
    NUMBER_OF_GOOD_PASSES = 10

    if not hasattr(self, "num_calls"):
        self.num_calls = 0

    if (self.num_calls == NUMBER_OF_GOOD_PASSES
            and get_tensor_model_parallel_rank() == 0):
        raise Exception("Simulated illegal memory access on Rank 0!")
    self.num_calls += 1

    return self.model(*args, **kwargs)


@pytest.mark.asyncio
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("model", MODELS)
async def test_async_llm_model_error(monkeypatch, tensor_parallel_size: int,
                                     model: str) -> None:
    """Test that AsyncLLM propagates a forward pass error and frees memory.

    AsyncLLM always uses an MP client.
    """
    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    # Monkeypatch an error in the model.
    monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward)

    engine_args = AsyncEngineArgs(model=model,
                                  enforce_eager=True,
                                  tensor_parallel_size=tensor_parallel_size)
    async_llm = AsyncLLM.from_engine_args(engine_args)

    async def generate(request_id: str):
        generator = async_llm.generate("Hello my name is",
                                       request_id=request_id,
                                       sampling_params=SamplingParams())
        try:
            async for _ in generator:
                pass
        except Exception as e:
            return e

    NUM_REQS = 3
    tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)]
    outputs = await asyncio.gather(*tasks)

    # Every request should get an EngineDeadError.
    for output in outputs:
        assert isinstance(output, EngineDeadError)

    # AsyncLLM should be errored.
    assert async_llm.errored

    # We should not be able to make another request.
    with pytest.raises(EngineDeadError):
        async for _ in async_llm.generate("Hello my name is",
                                          request_id="abc",
                                          sampling_params=SamplingParams()):
            raise Exception("We should not get here.")

    # Confirm all the processes are cleaned up.
    wait_for_gpu_memory_to_clear(
        devices=list(range(tensor_parallel_size)),
        threshold_bytes=2 * 2**30,
        timeout_s=60,
    )

    # NOTE: shutdown is handled by the API Server if an exception
    # occurs, so it is expected that we would need to call this.
    async_llm.shutdown()


@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("enable_multiprocessing", [True])
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("model", MODELS)
def test_llm_model_error(monkeypatch, tensor_parallel_size: int,
                         enable_multiprocessing: bool, model: str) -> None:
    """Test that LLM propagates a forward pass error and frees memory.
    TODO(andy) - LLM without multiprocessing; LLM with multiprocessing
    and >1 rank
    """
    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    with monkeypatch.context() as m:

        MP_VALUE = "1" if enable_multiprocessing else "0"
        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE)

        # Monkeypatch an error in the model.
        m.setattr(LlamaForCausalLM, "forward", evil_forward)

        llm = LLM(model=model,
                  enforce_eager=True,
                  tensor_parallel_size=tensor_parallel_size)

        with pytest.raises(
                EngineDeadError if enable_multiprocessing else Exception):
            llm.generate("Hello my name is Robert and I")

        # Confirm all the processes are cleaned up.
        wait_for_gpu_memory_to_clear(
            devices=list(range(tensor_parallel_size)),
            threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
        )
tests/v1/shutdown/test_processor_error.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# SPDX-License-Identifier: Apache-2.0
"""Test error handling in Processor. Should not impact other reqs."""

import asyncio

import pytest

from tests.v1.shutdown.utils import SHUTDOWN_TEST_TIMEOUT_SEC
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs.data import TokensPrompt
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineGenerateError

MODELS = ["meta-llama/Llama-3.2-1B"]


@pytest.mark.asyncio
@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
async def test_async_llm_processor_error(model: str) -> None:
    """Test that AsyncLLM propagates a processor error.
    Test empty tokens prompt (failure) and non-empty prompt (no failure.)
    AsyncLLM always uses an MP client.
    """
    engine_args = AsyncEngineArgs(model=model, enforce_eager=True)
    async_llm = AsyncLLM.from_engine_args(engine_args)

    async def generate(request_id: str):
        # [] is not allowed and will raise a ValueError in Processor.
        generator = async_llm.generate(TokensPrompt([]),
                                       request_id=request_id,
                                       sampling_params=SamplingParams())
        try:
            async for _ in generator:
                pass
        except Exception as e:
            return e

    NUM_REQS = 3
    tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)]
    outputs = await asyncio.gather(*tasks)

    # Every request should get an EngineGenerateError.
    for output in outputs:
        with pytest.raises(EngineGenerateError):
            raise output

    # AsyncLLM should not be errored.
    assert not async_llm.errored

    # This should be no problem.
    EXPECTED_TOKENS = 5
    outputs = []
    async for out in async_llm.generate(
            "Hello my name is",
            request_id="abc",
            sampling_params=SamplingParams(
                max_tokens=EXPECTED_TOKENS,
                output_kind=RequestOutputKind.DELTA)):
        outputs.append(out)

    generated_tokens = []
    for out in outputs:
        generated_tokens.extend(out.outputs[0].token_ids)
    assert len(generated_tokens) == EXPECTED_TOKENS

    async_llm.shutdown()
tests/v1/shutdown/test_startup_error.py (new file, 97 lines)
@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
"""Test that we handle a startup Error and shutdown."""

import pytest

from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES,
                                     SHUTDOWN_TEST_TIMEOUT_SEC)
from vllm import LLM
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM

MODELS = ["meta-llama/Llama-3.2-1B"]


def evil_method(self, *args, **kwargs):
    """Evil method that raises an exception."""

    if get_tensor_model_parallel_rank() == 0:
        raise Exception("Simulated Error in startup!")

    return self.model(*args, **kwargs, intermediate_tensors=None)


@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("failing_method", ["forward", "load_weights"])
def test_async_llm_startup_error(monkeypatch, model: str,
                                 tensor_parallel_size: int,
                                 failing_method: str) -> None:
    """Test that AsyncLLM propagates an __init__ error & frees memory.
    Test profiling (forward()) and load weights failures.
    AsyncLLM always uses an MP client.
    """
    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    # Monkeypatch an error in the model.
    monkeypatch.setattr(LlamaForCausalLM, failing_method, evil_method)

    engine_args = AsyncEngineArgs(model=model,
                                  enforce_eager=True,
                                  tensor_parallel_size=tensor_parallel_size)

    # Confirm we get an exception.
    with pytest.raises(Exception, match="initialization failed"):
        _ = AsyncLLM.from_engine_args(engine_args)

    # Confirm all the processes are cleaned up.
    wait_for_gpu_memory_to_clear(
        devices=list(range(tensor_parallel_size)),
        threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
    )


@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("enable_multiprocessing", [True])
@pytest.mark.parametrize("failing_method", ["forward", "load_weights"])
def test_llm_startup_error(monkeypatch, model: str, tensor_parallel_size: int,
                           enable_multiprocessing: bool,
                           failing_method: str) -> None:
    """Test that LLM propagates an __init__ error and frees memory.
    Test profiling (forward()) and load weights failures.
    TODO(andy) - LLM without multiprocessing.
    """
    if model != "meta-llama/Llama-3.2-1B":
        pytest.skip(reason="Only test meta-llama/Llama-3.2-1B")
    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    with monkeypatch.context() as m:

        MP_VALUE = "1" if enable_multiprocessing else "0"
        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE)

        # Monkeypatch an error in the model.
        monkeypatch.setattr(LlamaForCausalLM, failing_method, evil_method)

        with pytest.raises(
                Exception,
                match="initialization failed"
                if enable_multiprocessing else "Simulated Error in startup!"):
            _ = LLM(model=model,
                    enforce_eager=True,
                    tensor_parallel_size=tensor_parallel_size)

        # Confirm all the processes are cleaned up.
        wait_for_gpu_memory_to_clear(
            devices=list(range(tensor_parallel_size)),
            threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
        )
|
5
tests/v1/shutdown/utils.py
Normal file
5
tests/v1/shutdown/utils.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""Shutdown test utils"""
|
||||||
|
|
||||||
|
SHUTDOWN_TEST_TIMEOUT_SEC = 120
|
||||||
|
SHUTDOWN_TEST_THRESHOLD_BYTES = 2 * 2**30
|
Changes to MessageQueue (cancellable reads and a timeout-aware recv):

@@ -7,11 +7,13 @@ import time
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from multiprocessing import shared_memory
-from typing import List, Optional, Tuple, Union
+from threading import Event
+from typing import Any, List, Optional, Tuple, Union
 from unittest.mock import patch

 import torch
 import torch.distributed as dist
+import zmq
 from torch.distributed import ProcessGroup
 from zmq import IPV6  # type: ignore
 from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore

@@ -400,7 +402,9 @@ class MessageQueue:
             break

     @contextmanager
-    def acquire_read(self, timeout: Optional[float] = None):
+    def acquire_read(self,
+                     timeout: Optional[float] = None,
+                     cancel: Optional[Event] = None):
         assert self._is_local_reader, "Only readers can acquire read"
         start_time = time.monotonic()
         n_warning = 1

@@ -430,6 +434,9 @@ class MessageQueue:
                         )
                         n_warning += 1

+                    if cancel is not None and cancel.is_set():
+                        raise RuntimeError("cancelled")
+
                     # if we time out, raise an exception
                     if (timeout is not None
                             and time.monotonic() - start_time > timeout):

@@ -464,10 +471,12 @@ class MessageQueue:
         if self.n_remote_reader > 0:
             self.remote_socket.send(serialized_obj)

-    def dequeue(self, timeout: Optional[float] = None):
+    def dequeue(self,
+                timeout: Optional[float] = None,
+                cancel: Optional[Event] = None):
         """ Read from message queue with optional timeout (in seconds) """
         if self._is_local_reader:
-            with self.acquire_read(timeout) as buf:
+            with self.acquire_read(timeout, cancel) as buf:
                 overflow = buf[0] == 1
                 if not overflow:
                     # no need to know the size of serialized object

@@ -475,15 +484,21 @@ class MessageQueue:
                     # see https://docs.python.org/3/library/pickle.html
                     obj = pickle.loads(buf[1:])
             if overflow:
-                recv = self.local_socket.recv()
-                obj = pickle.loads(recv)
+                obj = MessageQueue.recv(self.local_socket, timeout)
         elif self._is_remote_reader:
-            recv = self.remote_socket.recv()
-            obj = pickle.loads(recv)
+            obj = MessageQueue.recv(self.remote_socket, timeout)
         else:
             raise RuntimeError("Only readers can dequeue")
         return obj

+    @staticmethod
+    def recv(socket: zmq.Socket, timeout: Optional[float]) -> Any:
+        timeout_ms = None if timeout is None else int(timeout * 1000)
+        if not socket.poll(timeout=timeout_ms):
+            raise TimeoutError
+        recv = socket.recv(copy=False)
+        return pickle.loads(recv.buffer)
+
     def broadcast_object(self, obj=None):
         if self._is_writer:
             self.enqueue(obj)
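As an aside, a minimal sketch of how a reader can use the new cancel parameter so a blocked dequeue() is unblocked during shutdown. The queue instance, the handle() callback, and the event name are illustrative assumptions, not part of this commit:

from threading import Event

shutdown_event = Event()

def reader_loop(message_queue, handle):
    """Drain a MessageQueue until shutdown_event is set."""
    while not shutdown_event.is_set():
        try:
            # dequeue() raises RuntimeError("cancelled") if shutdown_event
            # is set while we are blocked waiting for the writer, and
            # TimeoutError if nothing arrives within the timeout.
            obj = message_queue.dequeue(timeout=30, cancel=shutdown_event)
        except (RuntimeError, TimeoutError):
            continue
        handle(obj)

# Elsewhere, during shutdown:
# shutdown_event.set()  # unblocks the reader thread promptly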
Changes to serve_http and the API server shutdown handlers:

@@ -12,9 +12,11 @@ from fastapi import FastAPI, Request, Response
 from vllm import envs
 from vllm.engine.async_llm_engine import AsyncEngineDeadError
 from vllm.engine.multiprocessing import MQEngineDeadError
+from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.ssl import SSLCertRefresher
 from vllm.logger import init_logger
 from vllm.utils import find_process_using_port
+from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError

 logger = init_logger(__name__)

@@ -40,6 +42,8 @@ async def serve_http(app: FastAPI,

     loop = asyncio.get_running_loop()

+    watchdog_task = loop.create_task(
+        watchdog_loop(server, app.state.engine_client))
     server_task = loop.create_task(
         server.serve(sockets=[sock] if sock else None))

@@ -52,6 +56,7 @@ async def serve_http(app: FastAPI,
     def signal_handler() -> None:
         # prevents the uvicorn signal handler to exit early
         server_task.cancel()
+        watchdog_task.cancel()
         if ssl_cert_refresher:
             ssl_cert_refresher.stop()

@@ -73,48 +78,69 @@ async def serve_http(app: FastAPI,
             port, process, " ".join(process.cmdline()))
         logger.info("Shutting down FastAPI HTTP server.")
         return server.shutdown()
+    finally:
+        watchdog_task.cancel()
+
+
+async def watchdog_loop(server: uvicorn.Server, engine: EngineClient):
+    """
+    Watchdog task that runs in the background, checking
+    for error state in the engine. Needed to trigger shutdown
+    if an exception arises in a StreamingResponse() generator.
+    """
+    VLLM_WATCHDOG_TIME_S = 5.0
+    while True:
+        await asyncio.sleep(VLLM_WATCHDOG_TIME_S)
+        terminate_if_errored(server, engine)
+
+
+def terminate_if_errored(server: uvicorn.Server, engine: EngineClient):
+    """
+    See discussions here on shutting down a uvicorn server
+    https://github.com/encode/uvicorn/discussions/1103
+    In this case we cannot await the server shutdown here
+    because the handler must first return to close the connection
+    for this request.
+    """
+    engine_errored = engine.errored and not engine.is_running
+    if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine_errored:
+        server.should_exit = True
+

 def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
-    """Adds handlers for fatal errors that should crash the server"""
+    """
+    VLLM V1 AsyncLLM catches exceptions and returns
+    only two types: EngineGenerateError and EngineDeadError.
+
+    EngineGenerateError is raised by the per request generate()
+    method. This error could be request specific (and therefore
+    recoverable - e.g. if there is an error in input processing).
+
+    EngineDeadError is raised by the background output_handler
+    method. This error is global and therefore not recoverable.
+
+    We register these @app.exception_handlers to return nice
+    responses to the end user if they occur and shut down if needed.
+    See https://fastapi.tiangolo.com/tutorial/handling-errors/
+    for more details on how exception handlers work.
+
+    If an exception is encountered in a StreamingResponse
+    generator, the exception is not raised, since we already sent
+    a 200 status. Rather, we send an error message as the next chunk.
+    Since the exception is not raised, this means that the server
+    will not automatically shut down. Instead, we use the watchdog
+    background task to check for errored state.
+    """

     @app.exception_handler(RuntimeError)
-    async def runtime_error_handler(request: Request, __):
-        """On generic runtime error, check to see if the engine has died.
-        It probably has, in which case the server will no longer be able to
-        handle requests. Trigger a graceful shutdown with a SIGTERM."""
-        engine = request.app.state.engine_client
-        if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
-                and not engine.is_running):
-            logger.fatal("AsyncLLMEngine has failed, terminating server "
-                         "process")
-            # See discussions here on shutting down a uvicorn server
-            # https://github.com/encode/uvicorn/discussions/1103
-            # In this case we cannot await the server shutdown here because
-            # this handler must first return to close the connection for
-            # this request.
-            server.should_exit = True
-
-        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
-
     @app.exception_handler(AsyncEngineDeadError)
-    async def async_engine_dead_handler(_, __):
-        """Kill the server if the async engine is already dead. It will
-        not handle any further requests."""
-        if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
-            logger.fatal("AsyncLLMEngine is already dead, terminating server "
-                         "process")
-            server.should_exit = True
-
-        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
-
     @app.exception_handler(MQEngineDeadError)
-    async def mq_engine_dead_handler(_, __):
-        """Kill the server if the mq engine is already dead. It will
-        not handle any further requests."""
-        if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
-            logger.fatal("MQLLMEngine is already dead, terminating server "
-                         "process")
-            server.should_exit = True
+    @app.exception_handler(EngineDeadError)
+    @app.exception_handler(EngineGenerateError)
+    async def runtime_exception_handler(request: Request, __):
+        terminate_if_errored(
+            server=server,
+            engine=request.app.state.engine_client,
+        )

         return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
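For illustration, a minimal, self-contained sketch (not from this PR) of the stacked @app.exception_handler pattern used above: several exception types routed to one handler that flips uvicorn's should_exit flag. ToyEngineError, the /boom route, and the port are illustrative assumptions:

from http import HTTPStatus

import uvicorn
from fastapi import FastAPI, Request, Response

app = FastAPI()
server = uvicorn.Server(uvicorn.Config(app, host="127.0.0.1", port=8000))


class ToyEngineError(RuntimeError):
    """Stand-in for EngineDeadError / EngineGenerateError."""


@app.exception_handler(RuntimeError)
@app.exception_handler(ToyEngineError)
async def handle_engine_errors(request: Request, __):
    # Ask uvicorn to stop serving; the handler must return first so the
    # connection for this request can be closed cleanly.
    server.should_exit = True
    return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)


@app.get("/boom")
async def boom():
    raise ToyEngineError("simulated engine failure")

# The server would be started with: server.run()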
Changes to EngineCoreRequestType:

@@ -156,3 +156,5 @@ class EngineCoreRequestType(enum.Enum):
     ABORT = b'\x01'
     START_DP = b'\x02'
     UTILITY = b'\x03'
+    # Sentinel used within EngineCoreProc.
+    EXECUTOR_FAILED = b'\x04'
Changes to AsyncLLM:

@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

 import asyncio
 import logging
-import os
 from collections.abc import AsyncGenerator, Mapping
 from copy import copy
 from typing import Optional, Union

@@ -26,9 +24,10 @@ from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import Device, cdiv, kill_process_tree
+from vllm.utils import Device, cdiv
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.core_client import EngineCoreClient
+from vllm.v1.engine.core_client import AsyncMPClient, DPAsyncMPClient
+from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
 from vllm.v1.engine.output_processor import (OutputProcessor,
                                              RequestOutputCollector)
 from vllm.v1.engine.parallel_sampling import ParentRequest

@@ -61,8 +60,6 @@ class AsyncLLM(EngineClient):
             "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
             "VLLM_USE_V1=0 or 1 and report this issue on Github.")

-        assert start_engine_loop
-
         self.model_config = vllm_config.model_config
         self.vllm_config = vllm_config
         self.log_requests = log_requests

@@ -99,15 +96,23 @@ class AsyncLLM(EngineClient):
             log_stats=self.log_stats)

         # EngineCore (starts the engine in background process).
-        self.engine_core = EngineCoreClient.make_client(
-            multiprocess_mode=True,
-            asyncio_mode=True,
+        core_client_class = AsyncMPClient if (
+            vllm_config.parallel_config.data_parallel_size
+            == 1) else DPAsyncMPClient
+
+        self.engine_core = core_client_class(
             vllm_config=vllm_config,
             executor_class=executor_class,
             log_stats=self.log_stats,
         )

         self.output_handler: Optional[asyncio.Task] = None
+        try:
+            # Start output handler eagerly if we are in the asyncio eventloop.
+            asyncio.get_running_loop()
+            self._run_output_handler()
+        except RuntimeError:
+            pass

     @classmethod
     def from_vllm_config(

@@ -165,6 +170,9 @@ class AsyncLLM(EngineClient):
             usage_context=usage_context,
         )

+    def __del__(self):
+        self.shutdown()
+
     def shutdown(self):
         """Shutdown, cleaning up the background proc and IPC."""

@@ -187,6 +195,9 @@ class AsyncLLM(EngineClient):
     ) -> RequestOutputCollector:
         """Add new request to the AsyncLLM."""

+        if self.errored:
+            raise EngineDeadError()
+
         assert isinstance(params, SamplingParams), \
             "Pooling is not supported in V1"

@@ -261,9 +272,7 @@ class AsyncLLM(EngineClient):
         # We start the output_handler on the first call to generate() so
         # we can call __init__ before the event loop, which enables us
         # to handle startup failure gracefully in the OpenAI server.
-        if self.output_handler is None:
-            self.output_handler = asyncio.create_task(
-                self._run_output_handler())
+        self._run_output_handler()

         q = await self.add_request(
             request_id,

@@ -288,62 +297,96 @@ class AsyncLLM(EngineClient):
                 finished = out.finished
                 yield out

-            # If the request is disconnected by the client, the
-            # generate() task will be canceled. So, we abort the
-            # request if we end up here.
+            # If the request is disconnected by the client, generate()
+            # is cancelled. So, we abort the request if we end up here.
         except asyncio.CancelledError:
             await self.abort(request_id)
+            if self.log_requests:
+                logger.info("Request %s aborted.", request_id)
             raise

-    async def _run_output_handler(self):
+        # Engine is dead. Do not abort since we shut down.
+        except EngineDeadError:
+            if self.log_requests:
+                logger.info("Request %s failed (engine dead).", request_id)
+            raise
+
+        # Request validation error.
+        except ValueError:
+            if self.log_requests:
+                logger.info("Request %s failed (bad request).", request_id)
+            raise
+
+        # Unexpected error in the generate() task (possibly recoverable).
+        except Exception as e:
+            await self.abort(request_id)
+            if self.log_requests:
+                logger.info("Request %s failed.", request_id)
+            raise EngineGenerateError() from e
+
+    def _run_output_handler(self):
         """Background loop: pulls from EngineCore and pushes to AsyncStreams."""

-        try:
-            while True:
-                # 1) Pull EngineCoreOutputs from the EngineCore.
-                outputs = await self.engine_core.get_output_async()
-                num_outputs = len(outputs.outputs)
-
-                iteration_stats = IterationStats() if (
-                    self.log_stats and num_outputs) else None
-
-                # Split outputs into chunks of at most
-                # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
-                # event loop for too long.
-                if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
-                    slices = (outputs.outputs, )
-                else:
-                    slices = np.array_split(
-                        outputs.outputs,
-                        cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))
-
-                for i, outputs_slice in enumerate(slices):
-                    # 2) Process EngineCoreOutputs.
-                    processed_outputs = self.output_processor.process_outputs(
-                        outputs_slice, outputs.timestamp, iteration_stats)
-                    # NOTE: RequestOutputs are pushed to their queues.
-                    assert not processed_outputs.request_outputs
-
-                    # Allow other asyncio tasks to run between chunks
-                    if i + 1 < len(slices):
-                        await asyncio.sleep(0)
-
-                # 3) Abort any reqs that finished due to stop strings.
-                await self.engine_core.abort_requests_async(
-                    processed_outputs.reqs_to_abort)
-
-                # 4) Logging.
-                # TODO(rob): make into a coroutine and launch it in
-                # background thread once Prometheus overhead is non-trivial.
-                self._record_stats(
-                    engine_index=outputs.engine_index,
-                    scheduler_stats=outputs.scheduler_stats,
-                    iteration_stats=iteration_stats,
-                )
-        except Exception as e:
-            logger.exception("EngineCore output handler hit an error: %s", e)
-            kill_process_tree(os.getpid())
+        if self.output_handler is not None:
+            return
+
+        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
+        # object, or else it won't be garbage collected and cleaned up properly.
+        engine_core = self.engine_core
+        output_processor = self.output_processor
+        log_stats = self.log_stats
+        stat_loggers = self.stat_loggers if log_stats else None
+
+        async def output_handler():
+            try:
+                while True:
+                    # 1) Pull EngineCoreOutputs from the EngineCore.
+                    outputs = await engine_core.get_output_async()
+                    num_outputs = len(outputs.outputs)
+
+                    iteration_stats = IterationStats() if (
+                        log_stats and num_outputs) else None
+
+                    # Split outputs into chunks of at most
+                    # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
+                    # event loop for too long.
+                    if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
+                        slices = (outputs.outputs, )
+                    else:
+                        slices = np.array_split(
+                            outputs.outputs,
+                            cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))
+
+                    for i, outputs_slice in enumerate(slices):
+                        # 2) Process EngineCoreOutputs.
+                        processed_outputs = output_processor.process_outputs(
+                            outputs_slice, outputs.timestamp, iteration_stats)
+                        # NOTE: RequestOutputs are pushed to their queues.
+                        assert not processed_outputs.request_outputs
+
+                        # Allow other asyncio tasks to run between chunks
+                        if i + 1 < len(slices):
+                            await asyncio.sleep(0)
+
+                    # 3) Abort any reqs that finished due to stop strings.
+                    await engine_core.abort_requests_async(
+                        processed_outputs.reqs_to_abort)
+
+                    # 4) Logging.
+                    # TODO(rob): make into a coroutine and launch it in
+                    # background thread once Prometheus overhead is non-trivial.
+                    if stat_loggers:
+                        assert outputs.scheduler_stats is not None
+                        AsyncLLM._record_stats(
+                            stat_loggers[outputs.engine_index],
+                            scheduler_stats=outputs.scheduler_stats,
+                            iteration_stats=iteration_stats,
+                        )
+            except Exception as e:
+                logger.exception("AsyncLLM output_handler failed.")
+                output_processor.propagate_error(e)
+
+        self.output_handler = asyncio.create_task(output_handler())

     async def abort(self, request_id: str) -> None:
         """Abort RequestId in OutputProcessor and EngineCore."""

@@ -354,17 +397,15 @@ class AsyncLLM(EngineClient):
         if self.log_requests:
             logger.info("Aborted request %s.", request_id)

+    @staticmethod
     def _record_stats(
-        self,
-        scheduler_stats: Optional[SchedulerStats],
+        stat_loggers: list[StatLoggerBase],
+        scheduler_stats: SchedulerStats,
         iteration_stats: Optional[IterationStats],
-        engine_index: int = 0,
     ):
-        if not self.log_stats:
-            return
-
-        assert scheduler_stats is not None
-        for stat_logger in self.stat_loggers[engine_index]:
+        """static so that it can be used from the output_handler task
+        without a circular ref to AsyncLLM."""
+        for stat_logger in stat_loggers:
             stat_logger.record(scheduler_stats=scheduler_stats,
                                iteration_stats=iteration_stats)

@@ -451,16 +492,17 @@ class AsyncLLM(EngineClient):

     @property
     def is_running(self) -> bool:
-        return True
+        # Is None before the loop is started.
+        return self.output_handler is None or not self.output_handler.done()

     @property
     def is_stopped(self) -> bool:
-        return False
+        return self.errored

     @property
     def errored(self) -> bool:
-        return False
+        return self.engine_core.resources.engine_dead or not self.is_running

     @property
     def dead_error(self) -> BaseException:
-        return Exception()  # TODO: implement
+        return EngineDeadError()
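For illustration, a minimal sketch of how a caller might distinguish the two error types that generate() now surfaces. It assumes an already-constructed AsyncLLM named async_llm; the prompt, request id, and wrapper name are illustrative:

from vllm import SamplingParams
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError


async def safe_generate(async_llm, prompt: str, request_id: str):
    try:
        async for out in async_llm.generate(
                prompt,
                request_id=request_id,
                sampling_params=SamplingParams(max_tokens=16)):
            yield out
    except EngineGenerateError:
        # Request-specific failure (e.g. bad input); the engine and other
        # in-flight requests may still be healthy.
        raise
    except EngineDeadError:
        # Global failure: the background engine is gone, so no further
        # requests will succeed until the server restarts.
        raise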
Changes to EngineCore and EngineCoreProc:

@@ -11,9 +11,7 @@ from logging import DEBUG
 from typing import Any, Callable, Optional, TypeVar, Union

 import msgspec
-import psutil
 import zmq
-import zmq.asyncio

 from vllm.config import ParallelConfig, VllmConfig
 from vllm.distributed import stateless_destroy_torch_distributed_process_group

@@ -22,8 +20,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
-from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
-                        zmq_socket_ctx)
+from vllm.utils import resolve_obj_by_qualname, zmq_socket_ctx
 from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
                                          unify_kv_cache_configs)
 from vllm.v1.core.sched.interface import SchedulerInterface

@@ -50,12 +47,11 @@ _R = TypeVar('_R')  # Return type for collective_rpc
 class EngineCore:
     """Inner loop of vLLM's Engine."""

-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        executor_class: type[Executor],
-        log_stats: bool,
-    ):
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 executor_class: type[Executor],
+                 log_stats: bool,
+                 executor_fail_callback: Optional[Callable] = None):
         assert vllm_config.model_config.runner_type != "pooling"

         logger.info("Initializing a V1 LLM engine (v%s) with config: %s",

@@ -65,6 +61,9 @@ class EngineCore:

         # Setup Model.
         self.model_executor = executor_class(vllm_config)
+        if executor_fail_callback is not None:
+            self.model_executor.register_failure_callback(
+                executor_fail_callback)

         # Setup KV Caches and update CacheConfig after profiling.
         num_gpu_blocks, num_cpu_blocks, kv_cache_config = \

@@ -254,7 +253,8 @@ class EngineCore:
         return engine_core_outputs

     def shutdown(self):
-        self.model_executor.shutdown()
+        if self.model_executor:
+            self.model_executor.shutdown()

     def profile(self, is_start: bool = True):
         self.model_executor.profile(is_start)

@@ -308,6 +308,8 @@ class EngineCore:
 class EngineCoreProc(EngineCore):
     """ZMQ-wrapper for running EngineCore in background process."""

+    ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD'
+
     def __init__(
         self,
         input_path: str,

@@ -317,11 +319,16 @@ class EngineCoreProc(EngineCore):
         log_stats: bool,
         engine_index: int = 0,
     ):
-        super().__init__(vllm_config, executor_class, log_stats)
+        input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
+
+        executor_fail_callback = lambda: input_queue.put_nowait(
+            (EngineCoreRequestType.EXECUTOR_FAILED, b''))
+
+        super().__init__(vllm_config, executor_class, log_stats,
+                         executor_fail_callback)

         self.step_fn = (self.step if self.batch_queue is None else
                         self.step_with_batch_queue)

         self.global_unfinished_reqs = False

         # Background Threads and Queues for IO. These enable us to

@@ -329,15 +336,16 @@ class EngineCoreProc(EngineCore):
         # and to overlap some serialization/deserialization with the
         # model forward pass.
         # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
-        self.input_queue: queue.Queue[tuple[EngineCoreRequestType,
-                                            Any]] = queue.Queue()
-        self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
+        self.input_queue = input_queue
+        self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]()
         threading.Thread(target=self.process_input_socket,
                          args=(input_path, engine_index),
                          daemon=True).start()
-        threading.Thread(target=self.process_output_socket,
-                         args=(output_path, engine_index),
-                         daemon=True).start()
+        self.output_thread = threading.Thread(
+            target=self.process_output_socket,
+            args=(output_path, engine_index),
+            daemon=True)
+        self.output_thread.start()

     @staticmethod
     def run_engine_core(*args,

@@ -364,7 +372,6 @@ class EngineCoreProc(EngineCore):
         signal.signal(signal.SIGTERM, signal_handler)
         signal.signal(signal.SIGINT, signal_handler)

-        parent_process = psutil.Process().parent()
         engine_core: Optional[EngineCoreProc] = None
         try:
             parallel_config: ParallelConfig = kwargs[

@@ -380,13 +387,15 @@ class EngineCoreProc(EngineCore):
             engine_core.run_busy_loop()

         except SystemExit:
-            logger.debug("EngineCore interrupted.")
-
-        except Exception:
-            traceback = get_exception_traceback()
-            logger.error("EngineCore hit an exception: %s", traceback)
-            parent_process.send_signal(signal.SIGUSR1)
+            logger.debug("EngineCore exiting.")
+
+        except Exception as e:
+            if engine_core is None:
+                logger.exception("EngineCore failed to start.")
+            else:
+                logger.exception("EngineCore encountered a fatal error.")
+                engine_core._send_engine_dead()
+            raise e
         finally:
             if engine_core is not None:
                 engine_core.shutdown()

@@ -458,6 +467,11 @@ class EngineCoreProc(EngineCore):
                     f" failed: {str(e)}")
             self.output_queue.put_nowait(
                 EngineCoreOutputs(utility_output=output))
+        elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
+            raise RuntimeError("Executor failed.")
+        else:
+            logger.error("Unrecognized input request type encountered: %s",
+                         request_type)

     @staticmethod
     def _convert_msgspec_args(method, args):

@@ -473,6 +487,18 @@ class EngineCoreProc(EngineCore):
                 and not isinstance(v, p.annotation) else v
                 for v, p in zip(args, arg_types))

+    def _send_engine_dead(self):
+        """Send EngineDead status to the EngineCoreClient."""
+
+        # Put ENGINE_CORE_DEAD in the queue.
+        self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)
+
+        # Wait until msg sent by the daemon before shutdown.
+        self.output_thread.join(timeout=5.0)
+        if self.output_thread.is_alive():
+            logger.fatal("vLLM shutdown signal from EngineCore failed "
+                         "to send. Please report this issue.")
+
     def process_input_socket(self, input_path: str, engine_index: int):
         """Input socket IO thread."""

@@ -511,9 +537,16 @@ class EngineCoreProc(EngineCore):
         # Reuse send buffer.
         buffer = bytearray()

-        with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
+        # We must set linger to ensure the ENGINE_CORE_DEAD
+        # message is sent prior to closing the socket.
+        with zmq_socket_ctx(output_path, zmq.constants.PUSH,
+                            linger=4000) as socket:
             while True:
                 outputs = self.output_queue.get()
+                if outputs == EngineCoreProc.ENGINE_CORE_DEAD:
+                    socket.send(outputs, copy=False)
+                    break
+                assert not isinstance(outputs, bytes)
                 outputs.engine_index = engine_index
                 buffers = encoder.encode_into(outputs, buffer)
                 socket.send_multipart(buffers, copy=False)
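As a side note, a minimal generic sketch of the queue-sentinel pattern the ENGINE_CORE_DEAD handling above relies on: the failing producer enqueues a sentinel, and the output I/O thread forwards it as its final message before exiting. The names and payloads here are illustrative, not vLLM API:

import queue
import threading

SENTINEL = b"DEAD"
out_queue: "queue.Queue[bytes]" = queue.Queue()


def send(item: bytes) -> None:
    """Illustrative stand-in for the real socket send."""
    print("sent", item)


def output_thread() -> None:
    # Forward items until the sentinel arrives, then stop cleanly.
    while True:
        item = out_queue.get()
        send(item)
        if item == SENTINEL:
            break


t = threading.Thread(target=output_thread, daemon=True)
t.start()
out_queue.put_nowait(b"payload")
out_queue.put_nowait(SENTINEL)  # signal shutdown
t.join(timeout=5.0)             # wait for the final send, as _send_engine_dead does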
@ -1,14 +1,10 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
|
||||||
import queue
|
import queue
|
||||||
import signal
|
|
||||||
import threading
|
|
||||||
import uuid
|
import uuid
|
||||||
import weakref
|
import weakref
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Awaitable
|
from collections.abc import Awaitable, Sequence
|
||||||
from concurrent.futures import Future
|
from concurrent.futures import Future
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
@ -21,10 +17,11 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path,
|
from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path,
|
||||||
kill_process_tree, make_zmq_socket)
|
make_zmq_socket)
|
||||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||||
EngineCoreRequestType, UtilityOutput)
|
EngineCoreRequestType, UtilityOutput)
|
||||||
from vllm.v1.engine.core import EngineCore, EngineCoreProc
|
from vllm.v1.engine.core import EngineCore, EngineCoreProc
|
||||||
|
from vllm.v1.engine.exceptions import EngineDeadError
|
||||||
from vllm.v1.executor.abstract import Executor
|
from vllm.v1.executor.abstract import Executor
|
||||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
|
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
|
||||||
from vllm.v1.utils import BackgroundProcHandle
|
from vllm.v1.utils import BackgroundProcHandle
|
||||||
@@ -305,14 +302,22 @@ class BackgroundResources:
     core_engines: list[CoreEngine] = field(default_factory=list)
     output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
     input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
+    output_queue_task: Optional[asyncio.Task] = None
     shutdown_path: Optional[str] = None
 
+    # Set if any of the engines are dead. Here so that the output
+    # processing threads can access it without holding a ref to the client.
+    engine_dead: bool = False
+
     def __call__(self):
         """Clean up background resources."""
 
         for core_engine in self.core_engines:
             core_engine.close()
 
+        if self.output_queue_task is not None:
+            self.output_queue_task.cancel()
+
         # ZMQ context termination can hang if the sockets
         # aren't explicitly closed first.
         if self.output_socket is not None:
@@ -327,6 +332,12 @@ class BackgroundResources:
             # Send shutdown signal.
             shutdown_sender.send(b'')
 
+    def validate_alive(self, frames: Sequence[zmq.Frame]):
+        if len(frames) == 1 and (frames[0].buffer
+                                 == EngineCoreProc.ENGINE_CORE_DEAD):
+            self.engine_dead = True
+            raise EngineDeadError()
+
 
 class MPClient(EngineCoreClient):
     """
@@ -348,27 +359,6 @@ class MPClient(EngineCoreClient):
         executor_class: type[Executor],
         log_stats: bool,
     ):
-        # The child processes will send SIGUSR1 when unrecoverable
-        # errors happen. We kill the process tree here so that the
-        # stack trace is very evident.
-        # TODO(rob): rather than killing the main process, we should
-        # figure out how to raise an AsyncEngineDeadError and
-        # handle at the API server level so we can return a better
-        # error code to the clients calling vLLM.
-        def sigusr1_handler(signum, frame):
-            logger.fatal("Got fatal signal from worker processes, shutting "
-                         "down. See stack trace above for root cause issue.")
-            kill_process_tree(os.getpid())
-
-        if threading.current_thread() == threading.main_thread():
-            signal.signal(signal.SIGUSR1, sigusr1_handler)
-        else:
-            logger.warning("SIGUSR1 handler not installed because we are not "
-                           "running in the main thread. In this case the "
-                           "forked engine process may not be killed when "
-                           "an exception is raised, and you need to handle "
-                           "the engine process shutdown manually.")
-
         # Serialization setup.
         self.encoder = MsgpackEncoder()
         self.decoder = MsgpackDecoder(EngineCoreOutputs)
@@ -378,32 +368,37 @@
         self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx
 
         # This will ensure resources created so far are closed
         # when the client is garbage collected, even if an
         # exception is raised mid-construction.
         self.resources = BackgroundResources(ctx=sync_ctx)
         self._finalizer = weakref.finalize(self, self.resources)
+        success = False
+        try:
+            # Paths and sockets for IPC.
+            self.output_path = get_open_zmq_ipc_path()
+            input_path = get_open_zmq_ipc_path()
+            self.input_socket = make_zmq_socket(self.ctx,
+                                                input_path,
+                                                zmq.ROUTER,
+                                                bind=True)
+            self.resources.input_socket = self.input_socket
 
-        # Paths and sockets for IPC.
-        self.output_path = get_open_zmq_ipc_path()
-        input_path = get_open_zmq_ipc_path()
-        self.input_socket = make_zmq_socket(self.ctx,
-                                            input_path,
-                                            zmq.ROUTER,
-                                            bind=True)
-        self.resources.input_socket = self.input_socket
+            new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
+                vllm_config, executor_class, log_stats, input_path, self.
+                output_path, index, local_dp_rank)
 
-        new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
-            vllm_config, executor_class, log_stats, input_path, self.
-            output_path, index, local_dp_rank)
+            # Start engine core process(es).
+            self._init_core_engines(vllm_config, new_core_engine,
+                                    self.resources.core_engines)
 
-        # Start engine core process(es).
-        self._init_core_engines(vllm_config, new_core_engine,
-                                self.resources.core_engines)
+            # Wait for engine core process(es) to start.
+            self._wait_for_engine_startup()
 
-        # Wait for engine core process(es) to start.
-        self._wait_for_engine_startup()
-
-        self.utility_results: dict[int, AnyFuture] = {}
+            self.utility_results: dict[int, AnyFuture] = {}
+            success = True
+        finally:
+            if not success:
+                self._finalizer()
 
     def _wait_for_engine_startup(self):
         # Get a sync handle to the socket which can be sync or async.
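The success flag plus finally block above exists so that a failure partway through construction still releases whatever was already created. A minimal, hypothetical illustration of that weakref.finalize pattern (not the MPClient implementation):

# Hedged sketch of the construction-safety pattern: resources are registered
# with a finalizer first, and the finalizer is invoked manually if __init__
# does not reach the end. Illustrative only.
import weakref


class Resources:
    """Stands in for BackgroundResources: a callable that releases handles."""

    def __init__(self):
        self.handles: list[str] = []

    def __call__(self):
        for h in reversed(self.handles):
            print(f"closing {h}")
        self.handles.clear()


class Client:

    def __init__(self, fail: bool = False):
        self.resources = Resources()
        # Runs automatically at garbage collection, or when called directly.
        self._finalizer = weakref.finalize(self, self.resources)
        success = False
        try:
            self.resources.handles.append("input_socket")
            if fail:
                raise RuntimeError("startup failed")
            self.resources.handles.append("engine_proc")
            success = True
        finally:
            if not success:
                # Clean up whatever was created before the failure.
                self._finalizer()


if __name__ == "__main__":
    Client(fail=False)
    try:
        Client(fail=True)  # prints "closing input_socket" before raising
    except RuntimeError:
        pass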
@@ -451,8 +446,18 @@
         self.core_engine = core_engine
 
     def shutdown(self):
+        # Terminate background resources.
         self._finalizer()
 
+    def _format_exception(self, e: Exception) -> Exception:
+        """If errored, use EngineDeadError so root cause is clear."""
+        return EngineDeadError(
+            suppress_context=True) if self.resources.engine_dead else e
+
+    def ensure_alive(self):
+        if self.resources.engine_dead:
+            raise EngineDeadError()
+
 
 def _process_utility_output(output: UtilityOutput,
                             utility_results: dict[int, AnyFuture]):
@@ -476,7 +481,7 @@ class SyncMPClient(MPClient):
             log_stats=log_stats,
         )
 
-        self.outputs_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
+        self.outputs_queue = queue.Queue[Union[EngineCoreOutputs, Exception]]()
 
         # Ensure that the outputs socket processing thread does not have
         # a ref to the client which prevents gc.
@@ -487,7 +492,8 @@
         outputs_queue = self.outputs_queue
 
         shutdown_path = get_open_zmq_inproc_path()
-        self.resources.shutdown_path = shutdown_path
+        resources = self.resources
+        resources.shutdown_path = shutdown_path
 
         def process_outputs_socket():
             shutdown_socket = ctx.socket(zmq.PAIR)
@@ -506,12 +512,15 @@
                         break
 
                     frames = out_socket.recv_multipart(copy=False)
+                    resources.validate_alive(frames)
                     outputs = decoder.decode(frames)
                     if outputs.utility_output:
                         _process_utility_output(outputs.utility_output,
                                                 utility_results)
                     else:
                         outputs_queue.put_nowait(outputs)
+            except Exception as e:
+                outputs_queue.put_nowait(e)
             finally:
                 # Close sockets.
                 shutdown_socket.close(linger=0)
@@ -524,9 +533,16 @@
         self.output_queue_thread.start()
 
     def get_output(self) -> EngineCoreOutputs:
-        return self.outputs_queue.get()
+        # If an exception arises in process_outputs_socket task,
+        # it is forwarded to the outputs_queue so we can raise it
+        # from this (run_output_handler) task to shut down the server.
+        outputs = self.outputs_queue.get()
+        if isinstance(outputs, Exception):
+            raise self._format_exception(outputs) from None
+        return outputs
 
     def _send_input(self, request_type: EngineCoreRequestType, request: Any):
+        self.ensure_alive()
         # (Identity, RequestType, SerializedRequest)
         msg = (self.core_engine.identity, request_type.value,
                *self.encoder.encode(request))
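Forwarding exceptions through the same queue as regular outputs lets the consumer re-raise them at a single, well-defined point instead of silently losing the background thread. A generic sketch of the pattern with hypothetical names:

# Hedged sketch: a background reader pushes either results or the exception
# it hit onto one queue; the consumer re-raises on get(). Illustrative only.
import queue
import threading
from typing import Union

q: "queue.Queue[Union[int, Exception]]" = queue.Queue()

def producer() -> None:
    try:
        for i in range(3):
            q.put(i)
        raise ConnectionError("socket closed")  # simulated failure
    except Exception as e:
        q.put(e)  # forward the error instead of dying silently

def get_output() -> int:
    item = q.get()
    if isinstance(item, Exception):
        raise item from None  # surface the producer's failure to the caller
    return item

if __name__ == "__main__":
    threading.Thread(target=producer, daemon=True).start()
    try:
        while True:
            print(get_output())
    except ConnectionError as e:
        print(f"producer failed: {e}")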
@@ -608,61 +624,81 @@ class AsyncMPClient(MPClient):
             log_stats=log_stats,
         )
 
-        self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None
-        self.queue_task: Optional[asyncio.Task] = None
-
-        self.outputs_handler: Optional[Callable[
-            [AsyncMPClient, EngineCoreOutputs], Awaitable[None]]] = None
+        self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs,
+                                                 Exception]]()
+        try:
+            # If we are running in an asyncio event loop, start the queue task.
+            # Otherwise, it will be started lazily. If it is not started here,
+            # we could miss EXECUTOR_FAILED messages from engine core if they
+            # occur prior to any requests being sent.
+            asyncio.get_running_loop()
+            self._ensure_output_queue_task()
+        except RuntimeError:
+            pass
 
     def _ensure_output_queue_task(self):
-        if self.outputs_queue is not None:
+        resources = self.resources
+        if resources.output_queue_task is not None:
             return
 
         # Perform IO in separate task to parallelize as much as possible.
         # Avoid task having direct reference back to the client.
-        self.outputs_queue = asyncio.Queue()
         decoder = self.decoder
         utility_results = self.utility_results
         outputs_queue = self.outputs_queue
-        output_handler = self.outputs_handler
+        output_handler: Optional[Callable[[AsyncMPClient, EngineCoreOutputs],
+                                          Awaitable[None]]] = getattr(
+                                              self.__class__,
+                                              "process_engine_outputs", None)
         _self_ref = weakref.ref(self) if output_handler else None
         output_path = self.output_path
         output_socket = make_zmq_socket(self.ctx, output_path,
                                         zmq.constants.PULL)
-        self.resources.output_socket = output_socket
+        resources.output_socket = output_socket
 
         async def process_outputs_socket():
-            while True:
-                frames = await output_socket.recv_multipart(copy=False)
-                outputs: EngineCoreOutputs = decoder.decode(frames)
-                if outputs.utility_output:
-                    _process_utility_output(outputs.utility_output,
-                                            utility_results)
-                    continue
-
-                if output_handler is not None:
-                    assert _self_ref is not None
-                    _self = _self_ref()
-                    if not _self:
-                        # Client has been garbage collected, abort.
-                        return
-                    await output_handler(_self, outputs)
-
-                if outputs.outputs or outputs.scheduler_stats:
-                    outputs_queue.put_nowait(outputs)
-
-        self.queue_task = asyncio.create_task(process_outputs_socket(),
-                                              name="EngineCoreOutputQueueTask")
+            try:
+                while True:
+                    frames = await output_socket.recv_multipart(copy=False)
+                    resources.validate_alive(frames)
+                    outputs: EngineCoreOutputs = decoder.decode(frames)
+                    if outputs.utility_output:
+                        _process_utility_output(outputs.utility_output,
+                                                utility_results)
+                        continue
+
+                    if output_handler is not None:
+                        assert _self_ref is not None
+                        _self = _self_ref()
+                        if not _self:
+                            # Client has been garbage collected, abort.
+                            return
+                        await output_handler(_self, outputs)
+
+                    if outputs.outputs or outputs.scheduler_stats:
+                        outputs_queue.put_nowait(outputs)
+            except Exception as e:
+                outputs_queue.put_nowait(e)
+
+        resources.output_queue_task = asyncio.create_task(
+            process_outputs_socket(), name="EngineCoreOutputQueueTask")
 
     async def get_output_async(self) -> EngineCoreOutputs:
         self._ensure_output_queue_task()
+        # If an exception arises in process_outputs_socket task,
+        # it is forwarded to the outputs_queue so we can raise it
+        # from this (run_output_handler) task to shut down the server.
         assert self.outputs_queue is not None
-        return await self.outputs_queue.get()
+        outputs = await self.outputs_queue.get()
+        if isinstance(outputs, Exception):
+            raise self._format_exception(outputs) from None
+        return outputs
 
     def _send_input(self,
                     request_type: EngineCoreRequestType,
                     request: Any,
                     engine: Optional[CoreEngine] = None) -> Awaitable[None]:
+        self.ensure_alive()
         if engine is None:
             engine = self.core_engine
 
@@ -671,6 +707,7 @@
 
     def _send_input_message(self, message: tuple[bytestr, ...],
                             engine: CoreEngine) -> Awaitable[None]:
+        self.ensure_alive()
         message = (engine.identity, ) + message
         return self.input_socket.send_multipart(message, copy=False)
 
@@ -754,18 +791,17 @@ class DPAsyncMPClient(AsyncMPClient):
 
     def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
                  log_stats: bool):
-        super().__init__(vllm_config, executor_class, log_stats)
-
-        assert len(self.core_engines) > 1
+        self.num_engines_running = 0
+        self.reqs_in_flight: dict[str, CoreEngine] = {}
+
+        super().__init__(vllm_config, executor_class, log_stats)
 
         # Control message used for triggering dp idle mode loop.
         self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
                              *self.encoder.encode(None))
 
-        self.num_engines_running = 0
-        self.reqs_in_flight: dict[str, CoreEngine] = {}
-
-        self.outputs_handler = DPAsyncMPClient.process_engine_outputs  # type: ignore[assignment]
+        assert len(self.core_engines) > 1
 
     def _init_core_engines(
         self,
vllm/v1/engine/exceptions.py  (new file, 16 lines)
@@ -0,0 +1,16 @@
+# SPDX-License-Identifier: Apache-2.0
+class EngineGenerateError(Exception):
+    """Raised when a AsyncLLM.generate() fails. Recoverable."""
+    pass
+
+
+class EngineDeadError(Exception):
+    """Raised when the EngineCore dies. Unrecoverable."""
+
+    def __init__(self, *args, suppress_context: bool = False, **kwargs):
+        ENGINE_DEAD_MESSAGE = "EngineCore encountered an issue. See stack trace (above) for the root cause."  # noqa: E501
+
+        super().__init__(ENGINE_DEAD_MESSAGE, *args, **kwargs)
+        # Make stack trace clearer when using with LLMEngine by
+        # silencing irrelevant ZMQError.
+        self.__suppress_context__ = suppress_context
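The suppress_context flag works by setting __suppress_context__ on the raised exception, which hides the implicit "During handling of the above exception..." chain in the printed traceback. A small stand-alone demonstration (generic, not vLLM code):

# Hedged sketch showing what __suppress_context__ changes in the traceback.
import traceback


class DeadError(Exception):

    def __init__(self, *args, suppress_context: bool = False):
        super().__init__("engine died", *args)
        self.__suppress_context__ = suppress_context


def fail(suppress: bool) -> None:
    try:
        raise OSError("irrelevant ZMQ-style error")
    except OSError:
        # With suppress_context=True, the OSError chain is omitted from the
        # printed traceback; with False it appears as implicit context.
        raise DeadError(suppress_context=suppress)


if __name__ == "__main__":
    for flag in (False, True):
        try:
            fail(flag)
        except DeadError:
            traceback.print_exc()
            print("-" * 40)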
@@ -28,32 +28,40 @@ class RequestOutputCollector:
 
     def __init__(self, output_kind: RequestOutputKind):
         self.aggregate = output_kind == RequestOutputKind.DELTA
-        self.output: Optional[RequestOutput] = None
+        self.output: Optional[Union[RequestOutput, Exception]] = None
         self.ready = asyncio.Event()
 
-    def put(self, output: RequestOutput) -> None:
-        if self.output is None:
+    def put(self, output: Union[RequestOutput, Exception]) -> None:
+        """Non-blocking put operation."""
+        if self.output is None or isinstance(output, Exception):
             self.output = output
             self.ready.set()
-        elif self.aggregate:
-            # Coalesce the outputs in delta case.
-            self.output.add(output)
-        else:
-            # Just replace latest in non-delta case.
-            self.output = output
+        elif isinstance(self.output, RequestOutput):
+            if self.aggregate:
+                # Coalesce the outputs in delta case.
+                self.output.add(output)
+            else:
+                # Just replace latest in non-delta case.
+                self.output = output
 
     async def get(self) -> RequestOutput:
+        """Get operation blocks on put event."""
         while (output := self.output) is None:
             await self.ready.wait()
         self.output = None
         self.ready.clear()
+        if isinstance(output, Exception):
+            raise output
         return output
 
     def get_nowait(self) -> Optional[RequestOutput]:
+        """Non-blocking get operation."""
         output = self.output
         if output is not None:
             self.output = None
             self.ready.clear()
+        if isinstance(output, Exception):
+            raise output
         return output
 
 
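The collector is effectively a one-slot asyncio mailbox: values coalesce, while an Exception pushed by the output handler re-raises inside the waiting generate() task. A cut-down illustration with hypothetical names, not the vLLM class:

# Hedged sketch: a one-slot async collector that coalesces values and
# re-raises an Exception pushed by the producer. Illustrative only.
import asyncio
from typing import Optional, Union


class Collector:

    def __init__(self) -> None:
        self.output: Optional[Union[str, Exception]] = None
        self.ready = asyncio.Event()

    def put(self, output: Union[str, Exception]) -> None:
        # An Exception always overwrites whatever is pending.
        if self.output is None or isinstance(output, Exception):
            self.output = output
        elif isinstance(self.output, str):
            self.output += output  # coalesce deltas
        self.ready.set()

    async def get(self) -> str:
        while (output := self.output) is None:
            await self.ready.wait()
        self.output = None
        self.ready.clear()
        if isinstance(output, Exception):
            raise output
        return output


async def main() -> None:
    c = Collector()
    c.put("Hello ")
    c.put("world")
    print(await c.get())  # "Hello world"
    c.put(RuntimeError("engine died"))
    try:
        await c.get()
    except RuntimeError as e:
        print(f"raised in consumer: {e}")


asyncio.run(main())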
@@ -235,6 +243,13 @@ class OutputProcessor:
     def has_unfinished_requests(self) -> bool:
         return len(self.request_states) > 0
 
+    def propagate_error(self, e: Exception):
+        """Propagate error to all generate() tasks."""
+
+        for _, state in self.request_states.items():
+            assert state.queue is not None
+            state.queue.put(e)
+
     def abort_requests(
         self,
         request_ids: Iterable[str],
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from concurrent.futures import Future
-from typing import Union
+from typing import Callable, Union
 
 import torch
 import torch.distributed as dist
@@ -15,6 +15,8 @@ from vllm.executor.uniproc_executor import ( # noqa
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
 
+FailureCallback = Callable[[], None]
+
 
 class Executor(ExecutorBase):
     """
@@ -62,6 +64,13 @@ class Executor(ExecutorBase):
                             args=(kv_cache_configs, ))
         self.collective_rpc("compile_or_warm_up_model")
 
+    def register_failure_callback(self, callback: FailureCallback):
+        """
+        Register a function to be called if the executor enters a permanent
+        failed state.
+        """
+        pass
+
     def determine_available_memory(self) -> list[int]:  # in bytes
         output = self.collective_rpc("determine_available_memory")
         return output
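The callback contract is intentionally small: the engine registers a zero-argument callable, and the executor either stores it or, if it has already failed, invokes it immediately. A hedged sketch of that contract with toy names:

# Hedged sketch of the failure-callback contract between an engine and its
# executor. Only the shape of the API mirrors the FailureCallback type above.
from typing import Callable, Optional

FailureCallback = Callable[[], None]


class ToyExecutor:

    def __init__(self) -> None:
        self.is_failed = False
        self.failure_callback: Optional[FailureCallback] = None

    def register_failure_callback(self, callback: FailureCallback) -> None:
        if self.is_failed:
            # Already dead: notify immediately so the caller never misses it.
            callback()
        else:
            self.failure_callback = callback

    def _on_worker_death(self) -> None:
        self.is_failed = True
        callback, self.failure_callback = self.failure_callback, None
        if callback is not None:
            callback()


if __name__ == "__main__":
    executor = ToyExecutor()
    executor.register_failure_callback(
        lambda: print("engine notified: executor failed"))
    executor._on_worker_death()  # prints the notification exactly once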
@@ -1,21 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
+import multiprocessing
 import os
 import pickle
 import signal
 import sys
+import threading
 import time
 import traceback
 import weakref
+from concurrent.futures import Future
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
+from multiprocessing.connection import Connection
 from multiprocessing.process import BaseProcess
-from typing import Any, Callable, Optional, Union
+from threading import Thread
+from typing import Any, Callable, Optional, Union, cast
 
 import cloudpickle
-import psutil
-import zmq
 
 from vllm.config import VllmConfig
 from vllm.distributed import (destroy_distributed_environment,
@@ -26,8 +28,9 @@ from vllm.executor.multiproc_worker_utils import (
     _add_prefix, set_multiprocessing_worker_envs)
 from vllm.logger import init_logger
 from vllm.utils import (get_distributed_init_method, get_mp_context,
-                        get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx)
-from vllm.v1.executor.abstract import Executor
+                        get_open_port)
+from vllm.v1.executor.abstract import Executor, FailureCallback
+from vllm.v1.outputs import ModelRunnerOutput
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -35,6 +38,8 @@ logger = init_logger(__name__)
 POLLING_TIMEOUT_MS = 5000
 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
 
+EXECUTE_MODEL_TIMEOUT_S = 30
+
 
 class MultiprocExecutor(Executor):
 
|
|||||||
# Call self.shutdown at exit to clean up
|
# Call self.shutdown at exit to clean up
|
||||||
# and ensure workers will be terminated.
|
# and ensure workers will be terminated.
|
||||||
self._finalizer = weakref.finalize(self, self.shutdown)
|
self._finalizer = weakref.finalize(self, self.shutdown)
|
||||||
|
self.is_failed = False
|
||||||
# The child processes will send SIGUSR1 when unrecoverable
|
self.shutdown_event = threading.Event()
|
||||||
# errors happen.
|
self.failure_callback: Optional[FailureCallback] = None
|
||||||
def sigusr1_handler(signum, frame):
|
|
||||||
logger.fatal(
|
|
||||||
"MulitprocExecutor got fatal signal from worker processes, "
|
|
||||||
"shutting down. See stack trace above for root cause issue.")
|
|
||||||
# Propagate error up to parent process.
|
|
||||||
parent_process = psutil.Process().parent()
|
|
||||||
parent_process.send_signal(signal.SIGUSR1)
|
|
||||||
self.shutdown()
|
|
||||||
|
|
||||||
signal.signal(signal.SIGUSR1, sigusr1_handler)
|
|
||||||
|
|
||||||
self.world_size = self.parallel_config.world_size
|
self.world_size = self.parallel_config.world_size
|
||||||
tensor_parallel_size = self.parallel_config.tensor_parallel_size
|
tensor_parallel_size = self.parallel_config.tensor_parallel_size
|
||||||
@ -78,28 +73,94 @@ class MultiprocExecutor(Executor):
|
|||||||
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
|
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
|
||||||
|
|
||||||
# Create workers
|
# Create workers
|
||||||
self.workers: list[WorkerProcHandle] = []
|
unready_workers: list[UnreadyWorkerProcHandle] = []
|
||||||
for rank in range(self.world_size):
|
success = False
|
||||||
worker = WorkerProc.make_worker_process(self.vllm_config, rank,
|
try:
|
||||||
rank,
|
for rank in range(self.world_size):
|
||||||
distributed_init_method,
|
unready_workers.append(
|
||||||
scheduler_output_handle)
|
WorkerProc.make_worker_process(
|
||||||
self.workers.append(worker)
|
vllm_config=self.vllm_config,
|
||||||
|
local_rank=rank,
|
||||||
|
rank=rank,
|
||||||
|
distributed_init_method=distributed_init_method,
|
||||||
|
input_shm_handle=scheduler_output_handle,
|
||||||
|
))
|
||||||
|
|
||||||
# Ensure message queues are ready. Will deadlock if re-ordered
|
# Workers must be created before wait_for_ready to avoid
|
||||||
# Must be kept consistent with the WorkerProc
|
# deadlock, since worker.init_device() does a device sync.
|
||||||
self.rpc_broadcast_mq.wait_until_ready()
|
self.workers = WorkerProc.wait_for_ready(unready_workers)
|
||||||
for w in self.workers:
|
|
||||||
w.worker_response_mq.wait_until_ready()
|
# Ensure message queues are ready. Will deadlock if re-ordered
|
||||||
|
# Must be kept consistent with the WorkerProc.
|
||||||
|
self.rpc_broadcast_mq.wait_until_ready()
|
||||||
|
for w in self.workers:
|
||||||
|
w.worker_response_mq.wait_until_ready()
|
||||||
|
|
||||||
|
self.start_worker_monitor()
|
||||||
|
success = True
|
||||||
|
finally:
|
||||||
|
if not success:
|
||||||
|
# Clean up the worker procs if there was a failure.
|
||||||
|
self._ensure_worker_termination(
|
||||||
|
[w.proc for w in unready_workers])
|
||||||
|
|
||||||
|
def start_worker_monitor(self):
|
||||||
|
workers = self.workers
|
||||||
|
self_ref = weakref.ref(self)
|
||||||
|
|
||||||
|
# Monitors worker process liveness. If any die unexpectedly,
|
||||||
|
# logs an error, shuts down the executor and invokes the failure
|
||||||
|
# callback to inform the engine.
|
||||||
|
def monitor_workers():
|
||||||
|
sentinels = [h.proc.sentinel for h in workers]
|
||||||
|
died = multiprocessing.connection.wait(sentinels)
|
||||||
|
_self = self_ref()
|
||||||
|
if not _self or getattr(_self, 'shutting_down', False):
|
||||||
|
return
|
||||||
|
_self.is_failed = True
|
||||||
|
proc_name = next(h.proc.name for h in workers
|
||||||
|
if h.proc.sentinel == died[0])
|
||||||
|
logger.error(
|
||||||
|
"Worker proc %s died unexpectedly, "
|
||||||
|
"shutting down executor.", proc_name)
|
||||||
|
_self.shutdown()
|
||||||
|
callback = _self.failure_callback
|
||||||
|
if callback is not None:
|
||||||
|
_self.failure_callback = None
|
||||||
|
callback()
|
||||||
|
|
||||||
|
Thread(target=monitor_workers,
|
||||||
|
daemon=True,
|
||||||
|
name="MultiprocWorkerMonitor").start()
|
||||||
|
|
||||||
|
def register_failure_callback(self, callback: FailureCallback):
|
||||||
|
if self.is_failed:
|
||||||
|
callback()
|
||||||
|
else:
|
||||||
|
self.failure_callback = callback
|
||||||
|
|
||||||
|
def execute_model(
|
||||||
|
self,
|
||||||
|
scheduler_output,
|
||||||
|
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
|
||||||
|
(output, ) = self.collective_rpc("execute_model",
|
||||||
|
args=(scheduler_output, ),
|
||||||
|
rank0_reply_only=True,
|
||||||
|
timeout=EXECUTE_MODEL_TIMEOUT_S)
|
||||||
|
return output
|
||||||
|
|
||||||
def collective_rpc(self,
|
def collective_rpc(self,
|
||||||
method: Union[str, Callable],
|
method: Union[str, Callable],
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = 180.0,
|
||||||
args: tuple = (),
|
args: tuple = (),
|
||||||
kwargs: Optional[dict] = None) -> list[Any]:
|
kwargs: Optional[dict] = None,
|
||||||
|
rank0_reply_only: bool = False) -> list[Any]:
|
||||||
start_time = time.monotonic()
|
start_time = time.monotonic()
|
||||||
kwargs = kwargs or {}
|
kwargs = kwargs or {}
|
||||||
|
|
||||||
|
if self.is_failed:
|
||||||
|
raise RuntimeError("Executor failed.")
|
||||||
|
|
||||||
# NOTE: If the args are heterogeneous, then we pack them into a list,
|
# NOTE: If the args are heterogeneous, then we pack them into a list,
|
||||||
# and unpack them in the method of every worker, because every worker
|
# and unpack them in the method of every worker, because every worker
|
||||||
# knows their own rank.
|
# knows their own rank.
|
||||||
@ -109,30 +170,30 @@ class MultiprocExecutor(Executor):
|
|||||||
else:
|
else:
|
||||||
send_method = cloudpickle.dumps(
|
send_method = cloudpickle.dumps(
|
||||||
method, protocol=pickle.HIGHEST_PROTOCOL)
|
method, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs))
|
self.rpc_broadcast_mq.enqueue(
|
||||||
|
(send_method, args, kwargs, rank0_reply_only))
|
||||||
|
|
||||||
responses = [None] * self.world_size
|
workers = (self.workers[0], ) if rank0_reply_only else self.workers
|
||||||
for w in self.workers:
|
responses = [None] * len(workers)
|
||||||
|
for w in workers:
|
||||||
dequeue_timeout = timeout - (time.monotonic() - start_time
|
dequeue_timeout = timeout - (time.monotonic() - start_time
|
||||||
) if timeout is not None else None
|
) if timeout is not None else None
|
||||||
status, result = w.worker_response_mq.dequeue(
|
status, result = w.worker_response_mq.dequeue(
|
||||||
timeout=dequeue_timeout)
|
timeout=dequeue_timeout, cancel=self.shutdown_event)
|
||||||
|
|
||||||
if status != WorkerProc.ResponseStatus.SUCCESS:
|
if status != WorkerProc.ResponseStatus.SUCCESS:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Worker failed with error %s, please check the"
|
f"Worker failed with error '{result}', please check the"
|
||||||
" stack trace above for the root cause", result)
|
" stack trace above for the root cause")
|
||||||
|
|
||||||
responses[w.rank] = result
|
responses[w.rank] = result
|
||||||
|
|
||||||
return responses
|
return responses
|
||||||
except TimeoutError as e:
|
except TimeoutError as e:
|
||||||
raise TimeoutError(f"RPC call to {method} timed out.") from e
|
raise TimeoutError(f"RPC call to {method} timed out.") from e
|
||||||
except Exception as e:
|
|
||||||
# Re-raise any other exceptions
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def _ensure_worker_termination(self):
|
@staticmethod
|
||||||
|
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
|
||||||
"""Ensure that all worker processes are terminated. Assumes workers have
|
"""Ensure that all worker processes are terminated. Assumes workers have
|
||||||
received termination requests. Waits for processing, then sends
|
received termination requests. Waits for processing, then sends
|
||||||
termination and kill signals if needed."""
|
termination and kill signals if needed."""
|
||||||
@ -150,7 +211,7 @@ class MultiprocExecutor(Executor):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Send SIGTERM if still running
|
# Send SIGTERM if still running
|
||||||
active_procs = [w.proc for w in self.workers if w.proc.is_alive()]
|
active_procs = [proc for proc in worker_procs if proc.is_alive()]
|
||||||
for p in active_procs:
|
for p in active_procs:
|
||||||
p.terminate()
|
p.terminate()
|
||||||
if not wait_for_termination(active_procs, 4):
|
if not wait_for_termination(active_procs, 4):
|
||||||
@ -159,22 +220,14 @@ class MultiprocExecutor(Executor):
|
|||||||
for p in active_procs:
|
for p in active_procs:
|
||||||
p.kill()
|
p.kill()
|
||||||
|
|
||||||
self._cleanup_sockets()
|
|
||||||
|
|
||||||
def _cleanup_sockets(self):
|
|
||||||
for w in self.workers:
|
|
||||||
# Remove the zmq ipc socket file
|
|
||||||
socket_path = w.ready_path.replace("ipc://", "")
|
|
||||||
if os and os.path.exists(socket_path):
|
|
||||||
os.remove(socket_path)
|
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
"""Properly shut down the executor and its workers"""
|
"""Properly shut down the executor and its workers"""
|
||||||
if not getattr(self, 'shutting_down', False):
|
if not getattr(self, 'shutting_down', False):
|
||||||
self.shutting_down = True
|
self.shutting_down = True
|
||||||
|
self.shutdown_event.set()
|
||||||
for w in self.workers:
|
for w in self.workers:
|
||||||
w.worker_response_mq = None
|
w.worker_response_mq = None
|
||||||
self._ensure_worker_termination()
|
self._ensure_worker_termination([w.proc for w in self.workers])
|
||||||
|
|
||||||
self.rpc_broadcast_mq = None
|
self.rpc_broadcast_mq = None
|
||||||
|
|
||||||
@ -183,13 +236,30 @@ class MultiprocExecutor(Executor):
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UnreadyWorkerProcHandle:
|
||||||
|
"""WorkerProcess handle before READY."""
|
||||||
|
proc: BaseProcess
|
||||||
|
rank: int
|
||||||
|
ready_pipe: Connection
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class WorkerProcHandle:
|
class WorkerProcHandle:
|
||||||
proc: BaseProcess
|
proc: BaseProcess
|
||||||
rank: int
|
rank: int
|
||||||
ready_path: str
|
|
||||||
worker_response_mq: MessageQueue # The worker process writes to this MQ
|
worker_response_mq: MessageQueue # The worker process writes to this MQ
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_unready_handle(
|
||||||
|
cls, unready_handle: UnreadyWorkerProcHandle,
|
||||||
|
worker_response_mq: MessageQueue) -> "WorkerProcHandle":
|
||||||
|
return cls(
|
||||||
|
proc=unready_handle.proc,
|
||||||
|
rank=unready_handle.rank,
|
||||||
|
worker_response_mq=worker_response_mq,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class WorkerProc:
|
class WorkerProc:
|
||||||
"""Wrapper that runs one Worker in a separate process."""
|
"""Wrapper that runs one Worker in a separate process."""
|
||||||
@ -203,7 +273,6 @@ class WorkerProc:
|
|||||||
rank: int,
|
rank: int,
|
||||||
distributed_init_method: str,
|
distributed_init_method: str,
|
||||||
input_shm_handle: Handle,
|
input_shm_handle: Handle,
|
||||||
ready_path: str,
|
|
||||||
):
|
):
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
|
wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
|
||||||
@ -231,18 +300,8 @@ class WorkerProc:
|
|||||||
|
|
||||||
# Initializes a message queue for sending the model output
|
# Initializes a message queue for sending the model output
|
||||||
self.worker_response_mq = MessageQueue(1, 1)
|
self.worker_response_mq = MessageQueue(1, 1)
|
||||||
worker_response_mq_handle = self.worker_response_mq.export_handle()
|
|
||||||
|
|
||||||
# Send Readiness signal to EngineCore process.
|
|
||||||
# Set linger here because we want to ensure the message has
|
|
||||||
# been sent before the context is closed.
|
|
||||||
with zmq_socket_ctx(ready_path, zmq.constants.PUSH,
|
|
||||||
linger=10000) as ready_socket:
|
|
||||||
payload = pickle.dumps(worker_response_mq_handle,
|
|
||||||
protocol=pickle.HIGHEST_PROTOCOL)
|
|
||||||
ready_socket.send_string(WorkerProc.READY_STR)
|
|
||||||
ready_socket.send(payload)
|
|
||||||
|
|
||||||
|
# Initialize device and loads weights
|
||||||
self.worker.init_device()
|
self.worker.init_device()
|
||||||
self.worker.load_model()
|
self.worker.load_model()
|
||||||
|
|
||||||
@ -253,12 +312,10 @@ class WorkerProc:
|
|||||||
rank: int,
|
rank: int,
|
||||||
distributed_init_method: str,
|
distributed_init_method: str,
|
||||||
input_shm_handle, # Receive SchedulerOutput
|
input_shm_handle, # Receive SchedulerOutput
|
||||||
) -> WorkerProcHandle:
|
) -> UnreadyWorkerProcHandle:
|
||||||
context = get_mp_context()
|
context = get_mp_context()
|
||||||
|
# (reader, writer)
|
||||||
# ZMQ path for worker to send ready message and shm_broadcast handle
|
reader, writer = context.Pipe(duplex=False)
|
||||||
# back to core process.
|
|
||||||
ready_path = get_open_zmq_ipc_path()
|
|
||||||
|
|
||||||
process_kwargs = {
|
process_kwargs = {
|
||||||
"vllm_config": vllm_config,
|
"vllm_config": vllm_config,
|
||||||
@ -266,24 +323,57 @@ class WorkerProc:
|
|||||||
"rank": rank,
|
"rank": rank,
|
||||||
"distributed_init_method": distributed_init_method,
|
"distributed_init_method": distributed_init_method,
|
||||||
"input_shm_handle": input_shm_handle,
|
"input_shm_handle": input_shm_handle,
|
||||||
"ready_path": ready_path,
|
"ready_pipe": (reader, writer),
|
||||||
}
|
}
|
||||||
# Run EngineCore busy loop in background process.
|
# Run EngineCore busy loop in background process.
|
||||||
proc = context.Process(target=WorkerProc.worker_main,
|
proc = context.Process(target=WorkerProc.worker_main,
|
||||||
kwargs=process_kwargs,
|
kwargs=process_kwargs,
|
||||||
|
name=f"VllmWorker-{rank}",
|
||||||
daemon=True)
|
daemon=True)
|
||||||
|
|
||||||
with zmq_socket_ctx(ready_path, zmq.constants.PULL) as ready_socket:
|
proc.start()
|
||||||
proc.start()
|
writer.close()
|
||||||
|
return UnreadyWorkerProcHandle(proc, rank, reader)
|
||||||
|
|
||||||
# Wait for startup
|
@staticmethod
|
||||||
worker_response_mq_handle = WorkerProc.wait_for_startup(
|
def wait_for_ready(
|
||||||
proc, ready_socket)
|
unready_proc_handles: list[UnreadyWorkerProcHandle]
|
||||||
|
) -> list[WorkerProcHandle]:
|
||||||
|
|
||||||
worker_response_mq = MessageQueue.create_from_handle(
|
e = Exception("WorkerProc initialization failed due to "
|
||||||
worker_response_mq_handle, 0)
|
"an exception in a background process. "
|
||||||
|
"See stack trace for root cause.")
|
||||||
|
|
||||||
return WorkerProcHandle(proc, rank, ready_path, worker_response_mq)
|
pipes = {handle.ready_pipe: handle for handle in unready_proc_handles}
|
||||||
|
ready_proc_handles: list[Optional[WorkerProcHandle]] = (
|
||||||
|
[None] * len(unready_proc_handles))
|
||||||
|
while pipes:
|
||||||
|
ready = multiprocessing.connection.wait(pipes.keys())
|
||||||
|
for pipe in ready:
|
||||||
|
assert isinstance(pipe, Connection)
|
||||||
|
try:
|
||||||
|
# Wait until the WorkerProc is ready.
|
||||||
|
unready_proc_handle = pipes.pop(pipe)
|
||||||
|
response: dict[str, Any] = pipe.recv()
|
||||||
|
if response["status"] != "READY":
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# Extract the message queue handle.
|
||||||
|
worker_response_mq = MessageQueue.create_from_handle(
|
||||||
|
response["handle"], 0)
|
||||||
|
ready_proc_handles[unready_proc_handle.rank] = (
|
||||||
|
WorkerProcHandle.from_unready_handle(
|
||||||
|
unready_proc_handle, worker_response_mq))
|
||||||
|
|
||||||
|
except EOFError:
|
||||||
|
e.__suppress_context__ = True
|
||||||
|
raise e from None
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Close connection.
|
||||||
|
pipe.close()
|
||||||
|
|
||||||
|
return cast(list[WorkerProcHandle], ready_proc_handles)
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
self.rpc_broadcast_mq = None
|
self.rpc_broadcast_mq = None
|
||||||
@ -312,51 +402,51 @@ class WorkerProc:
|
|||||||
signal.signal(signal.SIGINT, signal_handler)
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
worker = None
|
worker = None
|
||||||
|
# tuple[Connection, Connection]
|
||||||
|
reader, ready_writer = kwargs.pop("ready_pipe")
|
||||||
try:
|
try:
|
||||||
|
reader.close()
|
||||||
worker = WorkerProc(*args, **kwargs)
|
worker = WorkerProc(*args, **kwargs)
|
||||||
|
|
||||||
|
# Send READY once we know everything is loaded
|
||||||
|
ready_writer.send({
|
||||||
|
"status":
|
||||||
|
WorkerProc.READY_STR,
|
||||||
|
"handle":
|
||||||
|
worker.worker_response_mq.export_handle(),
|
||||||
|
})
|
||||||
|
|
||||||
# Ensure message queues are ready. Will deadlock if re-ordered.
|
# Ensure message queues are ready. Will deadlock if re-ordered.
|
||||||
# Must be kept consistent with the Executor
|
# Must be kept consistent with the Executor
|
||||||
worker.rpc_broadcast_mq.wait_until_ready()
|
worker.rpc_broadcast_mq.wait_until_ready()
|
||||||
worker.worker_response_mq.wait_until_ready()
|
worker.worker_response_mq.wait_until_ready()
|
||||||
|
ready_writer.close()
|
||||||
|
ready_writer = None
|
||||||
|
|
||||||
worker.worker_busy_loop()
|
worker.worker_busy_loop()
|
||||||
|
|
||||||
except SystemExit:
|
|
||||||
logger.debug("Worker interrupted.")
|
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# worker_busy_loop sends exceptions to Executor
|
# NOTE: if an Exception arises in busy_loop, we send
|
||||||
# for shutdown, but if there is an error in startup or an
|
# a FAILURE message over the MQ RPC to notify the Executor,
|
||||||
# error with IPC itself, we need to alert the parent.
|
# which triggers system shutdown.
|
||||||
psutil.Process().parent().send_signal(signal.SIGUSR1)
|
# TODO(rob): handle case where the MQ itself breaks.
|
||||||
raise
|
|
||||||
|
if ready_writer is not None:
|
||||||
|
logger.exception("WorkerProc failed to start.")
|
||||||
|
else:
|
||||||
|
logger.exception("WorkerProc failed.")
|
||||||
|
|
||||||
|
# The parent sends a SIGTERM to all worker processes if
|
||||||
|
# any worker dies. Set this value so we don't re-throw
|
||||||
|
# SystemExit() to avoid zmq exceptions in __del__.
|
||||||
|
shutdown_requested = True
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
if ready_writer is not None:
|
||||||
|
ready_writer.close()
|
||||||
# Clean up once worker exits busy loop
|
# Clean up once worker exits busy loop
|
||||||
if worker is not None:
|
if worker is not None:
|
||||||
worker.shutdown()
|
worker.shutdown()
|
||||||
worker = None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def wait_for_startup(
|
|
||||||
proc: BaseProcess,
|
|
||||||
ready_socket: zmq.Socket,
|
|
||||||
) -> Optional[Handle]:
|
|
||||||
"""Wait until the Worker is ready."""
|
|
||||||
|
|
||||||
# Wait for Worker to send READY.
|
|
||||||
while ready_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
|
|
||||||
logger.debug("Waiting for WorkerProc to startup.")
|
|
||||||
|
|
||||||
if not proc.is_alive():
|
|
||||||
raise RuntimeError("WorkerProc failed to start.")
|
|
||||||
|
|
||||||
message = ready_socket.recv_string()
|
|
||||||
assert message == WorkerProc.READY_STR
|
|
||||||
handle_frame = ready_socket.recv(copy=False)
|
|
||||||
handle = pickle.loads(handle_frame.buffer)
|
|
||||||
return handle
|
|
||||||
|
|
||||||
class ResponseStatus(Enum):
|
class ResponseStatus(Enum):
|
||||||
SUCCESS = auto()
|
SUCCESS = auto()
|
||||||
@ -365,7 +455,7 @@ class WorkerProc:
|
|||||||
def worker_busy_loop(self):
|
def worker_busy_loop(self):
|
||||||
"""Main busy loop for Multiprocessing Workers"""
|
"""Main busy loop for Multiprocessing Workers"""
|
||||||
while True:
|
while True:
|
||||||
method, args, kwargs = self.rpc_broadcast_mq.dequeue()
|
method, args, kwargs, rank0_only = self.rpc_broadcast_mq.dequeue()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if isinstance(method, str):
|
if isinstance(method, str):
|
||||||
@ -377,12 +467,14 @@ class WorkerProc:
|
|||||||
# Notes have been introduced in python 3.11
|
# Notes have been introduced in python 3.11
|
||||||
if hasattr(e, "add_note"):
|
if hasattr(e, "add_note"):
|
||||||
e.add_note(traceback.format_exc())
|
e.add_note(traceback.format_exc())
|
||||||
logger.exception("WorkerProc hit an exception: %s", exc_info=e)
|
logger.exception("WorkerProc hit an exception.")
|
||||||
# exception might not be serializable, so we convert it to
|
# exception might not be serializable, so we convert it to
|
||||||
# string, only for logging purpose.
|
# string, only for logging purpose.
|
||||||
self.worker_response_mq.enqueue(
|
if not rank0_only or self.rank == 0:
|
||||||
(WorkerProc.ResponseStatus.FAILURE, str(e)))
|
self.worker_response_mq.enqueue(
|
||||||
|
(WorkerProc.ResponseStatus.FAILURE, str(e)))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self.worker_response_mq.enqueue(
|
if not rank0_only or self.rank == 0:
|
||||||
(WorkerProc.ResponseStatus.SUCCESS, output))
|
self.worker_response_mq.enqueue(
|
||||||
|
(WorkerProc.ResponseStatus.SUCCESS, output))
|
||||||