[Core][Bugfix][Perf] Introduce MQLLMEngine to avoid asyncio OH (#8157)

Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>

commit 7c7714d856 (parent 9d104b5beb)
@@ -43,13 +43,15 @@ steps:
   fast_check: true
   source_file_dependencies:
   - vllm/
+  - tests/mq_llm_engine
   - tests/async_engine
   - tests/test_inputs
   - tests/multimodal
   - tests/test_utils
   - tests/worker
   commands:
-  - pytest -v -s async_engine # Async Engine
+  - pytest -v -s mq_llm_engine # MQLLMEngine
+  - pytest -v -s async_engine # AsyncLLMEngine
   - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
   - pytest -v -s test_inputs.py
   - pytest -v -s multimodal
@@ -21,8 +21,8 @@ Traces can be visualized using https://ui.perfetto.dev/.
 .. tip::
 
     To stop the profiler - it flushes out all the profile trace files to the directory. This takes time, for example for about 100 requests worth of data for a llama 70b, it takes about 10 minutes to flush out on a H100.
-    Set the env variable VLLM_RPC_GET_DATA_TIMEOUT_MS to a big number before you start the server. Say something like 30 minutes.
-    ``export VLLM_RPC_GET_DATA_TIMEOUT_MS=1800000``
+    Set the env variable VLLM_RPC_TIMEOUT to a big number before you start the server. Say something like 30 minutes.
+    ``export VLLM_RPC_TIMEOUT=1800000``
 
 Example commands and usage:
 ===========================
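For reference, the same timeout can also be raised from Python before the server process starts, which is the pattern a later hunk in this commit uses for the TPU test. This is a minimal sketch, not part of the diff; the 1800000 ms value is simply the 30-minute example from the doc text above.

import os

# Must be set before the vLLM server/engine process reads its environment.
# 1800000 ms = 30 minutes, matching the doc example above.
os.environ["VLLM_RPC_TIMEOUT"] = "1800000"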
@@ -1,106 +0,0 @@
-import openai  # use the official client for correctness check
-import pytest
-import pytest_asyncio
-
-from ..utils import VLLM_PATH, RemoteOpenAIServer
-
-# any model with a chat template should work here
-MODEL_NAME = "facebook/opt-125m"
-chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
-assert chatml_jinja_path.exists()
-
-
-@pytest.fixture(scope="module")
-def server():
-    args = [
-        # use half precision for speed and memory savings in CI environment
-        "--dtype",
-        "float16",
-        "--max-model-len",
-        "2048",
-        "--enforce-eager",
-        "--chat-template",
-        str(chatml_jinja_path),
-    ]
-
-    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
-        yield remote_server
-
-
-@pytest_asyncio.fixture
-async def client(server):
-    async with server.get_async_client() as async_client:
-        yield async_client
-
-
-@pytest.mark.asyncio
-async def test_check_models(client: openai.AsyncOpenAI):
-    models = await client.models.list()
-    models = models.data
-    served_model = models[0]
-    assert served_model.id == MODEL_NAME
-    assert all(model.root == MODEL_NAME for model in models)
-
-
-@pytest.mark.asyncio
-async def test_single_completion(client: openai.AsyncOpenAI):
-    completion = await client.completions.create(model=MODEL_NAME,
-                                                 prompt="Hello, my name is",
-                                                 max_tokens=5,
-                                                 temperature=0.0)
-
-    assert completion.id is not None
-    assert len(completion.choices) == 1
-    assert len(completion.choices[0].text) >= 5
-    assert completion.choices[0].finish_reason == "length"
-    assert completion.usage == openai.types.CompletionUsage(
-        completion_tokens=5, prompt_tokens=6, total_tokens=11)
-
-    # test using token IDs
-    completion = await client.completions.create(
-        model=MODEL_NAME,
-        prompt=[0, 0, 0, 0, 0],
-        max_tokens=5,
-        temperature=0.0,
-    )
-    assert len(completion.choices[0].text) >= 5
-
-
-@pytest.mark.asyncio
-async def test_single_chat_session(client: openai.AsyncOpenAI):
-    messages = [{
-        "role": "system",
-        "content": "you are a helpful assistant"
-    }, {
-        "role": "user",
-        "content": "what is 1+1?"
-    }]
-
-    # test single completion
-    chat_completion = await client.chat.completions.create(model=MODEL_NAME,
-                                                           messages=messages,
-                                                           max_tokens=10,
-                                                           logprobs=True,
-                                                           top_logprobs=5)
-    assert chat_completion.id is not None
-    assert len(chat_completion.choices) == 1
-
-    choice = chat_completion.choices[0]
-    assert choice.finish_reason == "length"
-    assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=55, total_tokens=65)
-
-    message = choice.message
-    assert message.content is not None and len(message.content) >= 10
-    assert message.role == "assistant"
-    messages.append({"role": "assistant", "content": message.content})
-
-    # test multi-turn dialogue
-    messages.append({"role": "user", "content": "express your result in json"})
-    chat_completion = await client.chat.completions.create(
-        model=MODEL_NAME,
-        messages=messages,
-        max_tokens=10,
-    )
-    message = chat_completion.choices[0].message
-    assert message.content is not None and len(message.content) >= 0
@@ -1,120 +0,0 @@
-import asyncio
-import tempfile
-import unittest
-import unittest.mock
-import uuid
-
-import pytest
-import pytest_asyncio
-
-from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.entrypoints.openai.rpc.client import (AsyncEngineRPCClient,
-                                                RPCClientClosedError)
-from vllm.entrypoints.openai.rpc.server import AsyncEngineRPCServer
-
-
-@pytest.fixture(scope="function")
-def tmp_socket():
-    with tempfile.TemporaryDirectory() as td:
-        yield f"ipc://{td}/{uuid.uuid4()}"
-
-
-@pytest_asyncio.fixture(scope="function")
-async def dummy_server(tmp_socket, monkeypatch):
-    dummy_engine = unittest.mock.AsyncMock()
-
-    def dummy_engine_builder(*args, **kwargs):
-        return dummy_engine
-
-    with monkeypatch.context() as m:
-        m.setattr(AsyncLLMEngine, "from_engine_args", dummy_engine_builder)
-        server = AsyncEngineRPCServer(None, None, rpc_path=tmp_socket)
-
-    loop = asyncio.get_running_loop()
-    server_task = loop.create_task(server.run_server_loop())
-
-    try:
-        yield server
-    finally:
-        server_task.cancel()
-        server.cleanup()
-
-
-@pytest_asyncio.fixture(scope="function")
-async def client(tmp_socket):
-    client = AsyncEngineRPCClient(rpc_path=tmp_socket)
-    # Sanity check: the server is connected
-    await client._wait_for_server_rpc()
-
-    try:
-        yield client
-    finally:
-        client.close()
-
-
-@pytest.mark.asyncio
-async def test_client_data_methods_use_timeouts(monkeypatch, dummy_server,
-                                                client: AsyncEngineRPCClient):
-    with monkeypatch.context() as m:
-        # Make the server _not_ reply with a model config
-        m.setattr(dummy_server, "get_config", lambda x: None)
-        m.setattr(client, "_data_timeout", 10)
-
-        # And ensure the task completes anyway
-        # (client.setup() invokes server.get_config())
-        client_task = asyncio.get_running_loop().create_task(client.setup())
-        with pytest.raises(TimeoutError, match="Server didn't reply within"):
-            await asyncio.wait_for(client_task, timeout=0.05)
-
-
-@pytest.mark.asyncio
-async def test_client_aborts_use_timeouts(monkeypatch, dummy_server,
-                                          client: AsyncEngineRPCClient):
-    with monkeypatch.context() as m:
-        # Hang all abort requests
-        m.setattr(dummy_server, "abort", lambda x: None)
-        m.setattr(client, "_data_timeout", 10)
-
-        # The client should suppress timeouts on `abort`s
-        # and return normally, assuming the server will eventually
-        # abort the request.
-        client_task = asyncio.get_running_loop().create_task(
-            client.abort("test request id"))
-        await asyncio.wait_for(client_task, timeout=0.05)
-
-
-@pytest.mark.asyncio
-async def test_client_data_methods_reraise_exceptions(
-        monkeypatch, dummy_server, client: AsyncEngineRPCClient):
-    with monkeypatch.context() as m:
-        # Make the server raise some random exception
-        exception = RuntimeError("Client test exception")
-
-        def raiser():
-            raise exception
-
-        m.setattr(dummy_server.engine, "get_model_config", raiser)
-        m.setattr(client, "_data_timeout", 10)
-
-        client_task = asyncio.get_running_loop().create_task(client.setup())
-        # And ensure the task completes, raising the exception
-        with pytest.raises(RuntimeError, match=str(exception)):
-            await asyncio.wait_for(client_task, timeout=0.05)
-
-
-@pytest.mark.asyncio
-async def test_client_errors_after_closing(monkeypatch, dummy_server,
-                                           client: AsyncEngineRPCClient):
-
-    client.close()
-
-    # Healthchecks and generate requests will fail with explicit errors
-    with pytest.raises(RPCClientClosedError):
-        await client.check_health()
-    with pytest.raises(RPCClientClosedError):
-        async for _ in client.generate(None, None, None):
-            pass
-
-    # But no-ops like aborting will pass
-    await client.abort("test-request-id")
-    await client.do_log_stats()
@@ -18,38 +18,32 @@ TASK = "gsm8k"
 FILTER = "exact_match,strict-match"
 RTOL = 0.03
 EXPECTED_VALUE = 0.58
+DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
+MORE_ARGS_LIST = [["--enable-chunked-prefill"], ["--num-scheduler-steps", "8"]]
 
 
-@pytest.fixture(scope="module")
-def server():
-    args = [
-        "--max-model-len", "4096", "--enable-chunked-prefill",
-        "--disable-log-requests", "--enforce-eager"
-    ]
+@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
+def test_lm_eval_accuracy(more_args):
+    args = list(DEFAULT_ARGS)
+    args.extend(more_args)
+
+    print(f"Running with: {args}")
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
-        yield remote_server
-
-
-@pytest.fixture(scope="module")
-def server_data(server):
-    return {
-        "url": f"{server.url_for('v1')}/completions",
-    }
-
-
-def test_lm_eval_accuracy(server_data):
-    model_args = (f"model={MODEL_NAME},"
-                  f"base_url={server_data['url']},"
-                  f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
-
-    results = lm_eval.simple_evaluate(
-        model="local-completions",
-        model_args=model_args,
-        tasks=TASK,
-    )
-
-    measured_value = results["results"][TASK][FILTER]
-    assert (measured_value - RTOL < EXPECTED_VALUE
-            and measured_value + RTOL > EXPECTED_VALUE
-            ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
+        url = f"{remote_server.url_for('v1')}/completions"
+
+        model_args = (
+            f"model={MODEL_NAME},"
+            f"base_url={url},"
+            f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
+
+        results = lm_eval.simple_evaluate(
+            model="local-completions",
+            model_args=model_args,
+            tasks=TASK,
+        )
+
+        measured_value = results["results"][TASK][FILTER]
+        assert (measured_value - RTOL < EXPECTED_VALUE
+                and measured_value + RTOL > EXPECTED_VALUE
+                ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
@@ -5,7 +5,7 @@ from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
-from ..utils import VLLM_PATH
+from ...utils import VLLM_PATH
 
 chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
 assert chatml_jinja_path.exists()
@@ -1,40 +0,0 @@
-import time
-
-import pytest
-
-from vllm.entrypoints.openai.api_server import build_async_engine_client
-from vllm.entrypoints.openai.cli_args import make_arg_parser
-from vllm.utils import FlexibleArgumentParser
-
-
-@pytest.mark.asyncio
-async def test_mp_crash_detection():
-
-    parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
-    parser = make_arg_parser(parser)
-    args = parser.parse_args([])
-    # use an invalid tensor_parallel_size to trigger the
-    # error in the server
-    args.tensor_parallel_size = 65536
-
-    start = time.perf_counter()
-    async with build_async_engine_client(args):
-        pass
-    end = time.perf_counter()
-
-    assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s "
-                              "if there is an error in the startup.")
-
-
-@pytest.mark.asyncio
-async def test_mp_cuda_init():
-    # it should not crash, when cuda is initialized
-    # in the API server process
-    import torch
-    torch.cuda.init()
-    parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
-    parser = make_arg_parser(parser)
-    args = parser.parse_args([])
-
-    async with build_async_engine_client(args):
-        pass
@@ -4,7 +4,7 @@ from dataclasses import dataclass
 from unittest.mock import MagicMock
 
 from vllm.config import MultiModalConfig
-from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -52,8 +52,9 @@ def test_async_serving_chat_init():
 
 
 def test_serving_chat_should_set_correct_max_tokens():
-    mock_engine = MagicMock(spec=AsyncLLMEngine)
+    mock_engine = MagicMock(spec=MQLLMEngineClient)
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+    mock_engine.errored = False
 
     serving_chat = OpenAIServingChat(mock_engine,
                                      MockModelConfig(),
@@ -4,7 +4,7 @@ from unittest.mock import MagicMock
 import pytest
 
 from vllm.config import ModelConfig
-from vllm.engine.protocol import AsyncEngineClient
+from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                               LoadLoraAdapterRequest,
                                               UnloadLoraAdapterRequest)
@@ -18,7 +18,7 @@ LORA_UNLOADING_SUCCESS_MESSAGE = (
 
 
 async def _async_serving_engine_init():
-    mock_engine_client = MagicMock(spec=AsyncEngineClient)
+    mock_engine_client = MagicMock(spec=EngineClient)
     mock_model_config = MagicMock(spec=ModelConfig)
     # Set the max_model_len attribute to avoid missing attribute
    mock_model_config.max_model_len = 2048
@@ -44,5 +44,5 @@ async def test_shutdown_on_engine_failure(tmp_path):
                           prompt="Hello, my name is")
 
     # Now the server should shut down
-    return_code = remote_server.proc.wait(timeout=3)
+    return_code = remote_server.proc.wait(timeout=8)
     assert return_code is not None
tests/mq_llm_engine/test_abort.py (new file, 67 lines)
@@ -0,0 +1,67 @@
+"""Test that aborting is handled properly."""
+
+import asyncio
+import tempfile
+import uuid
+
+import pytest
+
+from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate
+from vllm.engine.arg_utils import AsyncEngineArgs
+
+MODEL = "google/gemma-1.1-2b-it"
+ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
+RAISED_ERROR = KeyError
+RAISED_VALUE = "foo"
+EXPECTED_TOKENS = 250
+
+
+@pytest.fixture(scope="function")
+def tmp_socket():
+    with tempfile.TemporaryDirectory() as td:
+        yield f"ipc://{td}/{uuid.uuid4()}"
+
+
+@pytest.mark.asyncio
+async def test_abort(tmp_socket):
+    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
+                           ipc_path=tmp_socket) as engine:
+
+        client = await engine.make_client()
+
+        request_id_to_be_aborted = "request-aborted"
+        request_ids_a = [f"request-a-{idx}" for idx in range(10)]
+        request_ids_b = [f"request-b-{idx}" for idx in range(10)]
+
+        # Requests started before one to be aborted.
+        tasks = []
+        for request_id in request_ids_a:
+            tasks.append(
+                asyncio.create_task(
+                    generate(client, request_id, EXPECTED_TOKENS)))
+
+        # Aborted.
+        task_aborted = asyncio.create_task(
+            generate(client, request_id_to_be_aborted, EXPECTED_TOKENS))
+
+        # Requests started after one to be aborted.
+        for request_id in request_ids_b:
+            tasks.append(
+                asyncio.create_task(
+                    generate(client, request_id, EXPECTED_TOKENS)))
+
+        # Actually abort.
+        await asyncio.sleep(0.5)
+        await client.abort(request_id_to_be_aborted)
+
+        # Confirm that we got all the EXPECTED tokens from the requests.
+        for task in tasks:
+            count, request_id = await task
+            assert count == EXPECTED_TOKENS, (
+                f"{request_id} generated only {count} tokens")
+
+        # Cancel task (this will hang indefinitely if not).
+        task_aborted.cancel()
+
+        # Shutdown.
+        client.close()
tests/mq_llm_engine/test_error_handling.py (new file, 244 lines)
@@ -0,0 +1,244 @@
+"""Test that various errors are handled properly."""
+
+import asyncio
+import tempfile
+import time
+import uuid
+from unittest.mock import Mock
+
+import pytest
+
+from tests.mq_llm_engine.utils import RemoteMQLLMEngine
+from vllm import SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.llm_engine import LLMEngine
+from vllm.engine.multiprocessing import MQEngineDeadError
+from vllm.engine.multiprocessing.engine import MQLLMEngine
+from vllm.entrypoints.openai.api_server import build_async_engine_client
+from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.lora.request import LoRARequest
+from vllm.usage.usage_lib import UsageContext
+from vllm.utils import FlexibleArgumentParser
+
+MODEL = "google/gemma-1.1-2b-it"
+ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
+RAISED_ERROR = KeyError
+RAISED_VALUE = "foo"
+
+
+@pytest.fixture(scope="function")
+def tmp_socket():
+    with tempfile.TemporaryDirectory() as td:
+        yield f"ipc://{td}/{uuid.uuid4()}"
+
+
+def run_with_evil_forward(engine_args: AsyncEngineArgs, ipc_path: str):
+    # Make engine.
+    engine = MQLLMEngine.from_engine_args(
+        engine_args=engine_args,
+        usage_context=UsageContext.UNKNOWN_CONTEXT,
+        ipc_path=ipc_path)
+
+    # Raise error during first forward pass.
+    engine.engine.model_executor.execute_model = Mock(
+        side_effect=RAISED_ERROR(RAISED_VALUE))
+
+    # Run engine.
+    engine.start()
+
+
+@pytest.mark.asyncio
+async def test_evil_forward(tmp_socket):
+    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
+                           ipc_path=tmp_socket,
+                           run_fn=run_with_evil_forward) as engine:
+
+        client = await engine.make_client()
+
+        # Server should be healthy after initial probe.
+        await asyncio.sleep(2.0)
+        await client.check_health()
+
+        # Throws an error in first forward pass.
+        with pytest.raises(RAISED_ERROR):
+            async for _ in client.generate(inputs="Hello my name is",
+                                           sampling_params=SamplingParams(),
+                                           request_id=uuid.uuid4()):
+                pass
+        assert client.errored
+
+        # Engine is errored, should get ENGINE_DEAD_ERROR.
+        with pytest.raises(MQEngineDeadError):
+            async for _ in client.generate(inputs="Hello my name is",
+                                           sampling_params=SamplingParams(),
+                                           request_id=uuid.uuid4()):
+                pass
+        assert client.errored
+
+        await asyncio.sleep(1.0)
+        with pytest.raises(RAISED_ERROR):
+            await client.check_health()
+        assert client.errored
+
+        # Shutdown.
+        client.close()
+
+
+def run_with_evil_model_executor_health(engine_args: AsyncEngineArgs,
+                                        ipc_path: str):
+    # Make engine.
+    engine = MQLLMEngine.from_engine_args(
+        engine_args=engine_args,
+        usage_context=UsageContext.UNKNOWN_CONTEXT,
+        ipc_path=ipc_path)
+
+    # Raise error during the health check.
+    engine.engine.model_executor.check_health = Mock(side_effect=RAISED_ERROR)
+
+    # Run engine.
+    engine.start()
+
+
+@pytest.mark.asyncio
+async def test_failed_health_check(tmp_socket):
+    with RemoteMQLLMEngine(
+            engine_args=ENGINE_ARGS,
+            ipc_path=tmp_socket,
+            run_fn=run_with_evil_model_executor_health) as engine:
+
+        client = await engine.make_client()
+        assert client.is_running
+
+        # Health probe should throw RAISED_ERROR.
+        await asyncio.sleep(15.)
+
+        with pytest.raises(RAISED_ERROR):
+            await client.check_health()
+        assert client.errored
+
+        # Generate call should throw ENGINE_DEAD_ERROR
+        with pytest.raises(MQEngineDeadError):
+            async for _ in client.generate(inputs="Hello my name is",
+                                           sampling_params=SamplingParams(),
+                                           request_id=uuid.uuid4()):
+                pass
+
+        client.close()
+
+
+def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str):
+    # Make engine.
+    engine = MQLLMEngine.from_engine_args(
+        engine_args=engine_args,
+        usage_context=UsageContext.UNKNOWN_CONTEXT,
+        ipc_path=ipc_path)
+
+    # Raise error during abort call.
+    engine.engine.abort_request = Mock(side_effect=RAISED_ERROR)
+
+    # Run engine.
+    engine.start()
+
+
+@pytest.mark.asyncio
+async def test_failed_abort(tmp_socket):
+    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
+                           ipc_path=tmp_socket,
+                           run_fn=run_with_evil_abort) as engine:

+        client = await engine.make_client()
+        assert client.is_running
+
+        # First check health should work.
+        await client.check_health()
+
+        # Trigger an abort on the client side.
+        async def bad_abort_after_2s():
+            await asyncio.sleep(2.0)
+            await client.abort(request_id="foo")
+
+        # Trigger an abort in 2s from now.
+        abort_task = asyncio.create_task(bad_abort_after_2s())
+
+        # Exception in abort() will happen during this generation.
+        # This will kill the engine and should return ENGINE_DEAD_ERROR
+        # with reference to the original KeyError("foo")
+        with pytest.raises(MQEngineDeadError) as execinfo:
+            async for _ in client.generate(
+                    inputs="Hello my name is",
+                    sampling_params=SamplingParams(max_tokens=2000),
+                    request_id=uuid.uuid4()):
+                pass
+        assert "KeyError" in repr(execinfo.value)
+        assert client.errored
+
+        await abort_task
+
+        # This should raise the original error.
+        with pytest.raises(RAISED_ERROR):
+            await client.check_health()
+
+        client.close()
+
+
+@pytest.mark.asyncio
+async def test_bad_request(tmp_socket):
+    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
+                           ipc_path=tmp_socket) as engine:
+
+        client = await engine.make_client()
+
+        # Invalid request should fail, but not crash the server.
+        with pytest.raises(ValueError):
+            async for _ in client.generate(inputs="Hello my name is",
+                                           sampling_params=SamplingParams(),
+                                           request_id="abcd-1",
+                                           lora_request=LoRARequest(
+                                               "invalid-lora", 1,
+                                               "invalid-path")):
+                pass
+
+        # This request should be okay.
+        async for _ in client.generate(inputs="Hello my name is",
+                                       sampling_params=SamplingParams(),
+                                       request_id="abcd-2"):
+            pass
+
+        # Shutdown.
+        client.close()
+
+
+@pytest.mark.asyncio
+async def test_mp_crash_detection(monkeypatch):
+
+    parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
+    parser = make_arg_parser(parser)
+    args = parser.parse_args([])
+
+    # When LLMEngine is loaded, it will crash.
+    def mock_init():
+        raise ValueError
+
+    monkeypatch.setattr(LLMEngine, "__init__", mock_init)
+
+    start = time.perf_counter()
+    async with build_async_engine_client(args):
+        pass
+    end = time.perf_counter()
+
+    assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s "
+                              "if there is an error in the startup.")
+
+
+@pytest.mark.asyncio
+async def test_mp_cuda_init():
+    # it should not crash, when cuda is initialized
+    # in the API server process
+    import torch
+    torch.cuda.init()
+    parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
+    parser = make_arg_parser(parser)
+    args = parser.parse_args([])
+
+    async with build_async_engine_client(args):
+        pass
tests/mq_llm_engine/test_load.py (new file, 57 lines)
@@ -0,0 +1,57 @@
+"""Test that the MQLLMEngine is able to handle 10k concurrent requests."""
+
+import asyncio
+import tempfile
+import uuid
+
+import pytest
+
+from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate
+from vllm.engine.arg_utils import AsyncEngineArgs
+
+MODEL = "google/gemma-1.1-2b-it"
+NUM_EXPECTED_TOKENS = 10
+NUM_REQUESTS = 10000
+
+# Scenarios to test for num generated token.
+ENGINE_ARGS = AsyncEngineArgs(model=MODEL, disable_log_requests=True)
+
+
+@pytest.fixture(scope="function")
+def tmp_socket():
+    with tempfile.TemporaryDirectory() as td:
+        yield f"ipc://{td}/{uuid.uuid4()}"
+
+
+@pytest.mark.asyncio
+async def test_load(tmp_socket):
+    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
+                           ipc_path=tmp_socket) as engine:
+
+        client = await engine.make_client()
+
+        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
+
+        # Create concurrent requests.
+        tasks = []
+        for request_id in request_ids:
+            tasks.append(
+                asyncio.create_task(
+                    generate(client, request_id, NUM_EXPECTED_TOKENS)))
+
+        # Confirm that we got all the EXPECTED tokens from the requests.
+        failed_request_id = None
+        tokens = None
+        for task in tasks:
+            num_generated_tokens, request_id = await task
+            if (num_generated_tokens != NUM_EXPECTED_TOKENS
+                    and failed_request_id is None):
+                failed_request_id = request_id
+                tokens = num_generated_tokens
+
+        assert failed_request_id is None, (
+            f"{failed_request_id} generated {tokens} but "
+            f"expected {NUM_EXPECTED_TOKENS}")
+
+        # Shutdown.
+        client.close()
tests/mq_llm_engine/utils.py (new file, 78 lines)
@@ -0,0 +1,78 @@
+import asyncio
+import multiprocessing
+from typing import Callable, Tuple, Union
+
+from vllm import SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.multiprocessing.client import MQLLMEngineClient
+from vllm.engine.multiprocessing.engine import MQLLMEngine
+from vllm.outputs import RequestOutput
+from vllm.usage.usage_lib import UsageContext
+
+
+async def generate(
+        client: MQLLMEngineClient,
+        request_id: str,
+        num_tokens: int,
+        return_output: bool = False) -> Union[RequestOutput, Tuple[int, str]]:
+
+    final_output = None
+    count = 0
+    async for out in client.generate(
+            request_id=request_id,
+            inputs="Hello my name is Robert and",
+            sampling_params=SamplingParams(max_tokens=num_tokens,
+                                           temperature=0)):
+
+        count += 1
+        final_output = out
+        await asyncio.sleep(0.)
+
+    if return_output:
+        return final_output
+
+    # Confirm we generated all the tokens we expected.
+    return count, request_id
+
+
+def run_normal(engine_args: AsyncEngineArgs, ipc_path: str):
+    # Make engine.
+    engine = MQLLMEngine.from_engine_args(
+        engine_args=engine_args,
+        usage_context=UsageContext.UNKNOWN_CONTEXT,
+        ipc_path=ipc_path)
+
+    # Run engine.
+    engine.start()
+
+
+class RemoteMQLLMEngine:
+
+    def __init__(self,
+                 engine_args: AsyncEngineArgs,
+                 ipc_path: str,
+                 run_fn: Callable = run_normal) -> None:
+
+        self.engine_args = engine_args
+        self.ipc_path = ipc_path
+        context = multiprocessing.get_context("spawn")
+        self.proc = context.Process(target=run_fn,
+                                    args=(engine_args, ipc_path))
+        self.proc.start()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.proc.kill()
+
+    async def make_client(self) -> MQLLMEngineClient:
+        engine_config = self.engine_args.create_engine_config()
+        client = MQLLMEngineClient(self.ipc_path, engine_config)
+        while True:
+            try:
+                await client.setup()
+                break
+            except TimeoutError:
+                assert self.proc.is_alive()
+        return client
@@ -1,5 +1,12 @@
+import os
+
 from ..utils import compare_two_settings
 
+# --enforce-eager on TPU causes graph compilation
+# this times out default Health Check in the MQLLMEngine,
+# so we set the timeout here to 30s
+os.environ["VLLM_RPC_TIMEOUT"] = "30000"
+
 
 def test_custom_dispatcher():
     compare_two_settings("google/gemma-2b",
@@ -119,7 +119,7 @@ class RemoteOpenAIServer:
     def __exit__(self, exc_type, exc_value, traceback):
         self.proc.terminate()
         try:
-            self.proc.wait(3)
+            self.proc.wait(8)
         except subprocess.TimeoutExpired:
             # force kill if needed
             self.proc.kill()
@@ -601,9 +601,12 @@ class AsyncLLMEngine:
         return self._errored_with is not None
 
     @property
-    def limit_concurrency(self) -> Optional[int]:
-        """Maximum number of concurrently running requests."""
-        return None
+    def dead_error(self) -> BaseException:
+        return AsyncEngineDeadError(
+            "Background loop is not running. If it was running, "
+            "inspect the output to find the stacktrace of the "
+            "error that caused the background loop to stop "
+            "(AsyncEngineDeadError).")
 
     def set_errored(self, exc: Exception) -> None:
         self._errored_with = exc
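The unused limit_concurrency property is replaced here with dead_error, giving callers a descriptive exception to raise once the background loop has died. A hedged caller-side sketch (not part of the diff; the engine_client name is hypothetical, while errored and dead_error are the properties shown in this class):

def ensure_engine_alive(engine_client) -> None:
    # If the background loop has died, surface the descriptive
    # AsyncEngineDeadError now instead of failing on the next request.
    if engine_client.errored:
        raise engine_client.dead_error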
@@ -1289,6 +1289,7 @@ class LLMEngine:
             # torch.distributed ops which may otherwise timeout, and unblocks
             # the RPC thread in the workers so that they can process any other
             # queued control plane messages, such as add/remove lora adapters.
+            logger.debug("Stopping remote worker execution loop.")
             self.model_executor.stop_remote_worker_execution_loop()
 
         return ctx.request_outputs
vllm/engine/multiprocessing/__init__.py (new file, 73 lines)
@@ -0,0 +1,73 @@
+from dataclasses import dataclass
+from enum import Enum
+from typing import List, Mapping, Optional, Union
+
+from vllm.inputs import PromptInputs
+from vllm.lora.request import LoRARequest
+from vllm.outputs import RequestOutput
+from vllm.prompt_adapter.request import PromptAdapterRequest
+from vllm.sampling_params import SamplingParams
+
+VLLM_RPC_SUCCESS_STR = "SUCCESS"
+
+IPC_INPUT_EXT = "_input_socket"
+IPC_OUTPUT_EXT = "_output_socket"
+IPC_HEALTH_EXT = "_health_socket"
+IPC_DATA_EXT = "_data_socket"
+
+
+class MQEngineDeadError(RuntimeError):
+    pass
+
+
+@dataclass
+class RPCGenerateRequest:
+    inputs: PromptInputs
+    sampling_params: SamplingParams
+    request_id: str
+    lora_request: Optional[LoRARequest] = None
+    trace_headers: Optional[Mapping[str, str]] = None
+    prompt_adapter_request: Optional[PromptAdapterRequest] = None
+
+
+@dataclass
+class RPCError:
+    request_id: Optional[str]
+    is_engine_errored: bool
+    exception: BaseException
+
+
+@dataclass
+class RPCAbortRequest:
+    request_id: str
+
+
+class RPCHealthRequest:
+    pass
+
+
+class RPCStartupRequest(Enum):
+    IS_SERVER_READY = 1
+
+
+@dataclass
+class RPCStartupResponse:
+    tracing_enabled: bool
+
+
+RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCHealthRequest,
+                      RPCStartupRequest]
+
+REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]
+
+
+def ENGINE_DEAD_ERROR(
+        error: Optional[BaseException] = None) -> MQEngineDeadError:
+    if error is None:
+        return MQEngineDeadError(
+            "Engine loop is not running. Inspect the stacktrace to "
+            "find the original error")
+
+    return MQEngineDeadError(
+        "Engine loop is not running. Inspect the stacktrace to "
+        f"find the original error: {repr(error)}.")
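The dataclasses above are the wire format between client and engine: the client pickles one of these request objects and sends it over a ZeroMQ ipc socket, as client.py below shows. A minimal sketch of that serialization step in isolation (not part of the diff; the prompt text and request id are made-up examples):

import pickle

from vllm import SamplingParams
from vllm.engine.multiprocessing import RPCAbortRequest, RPCGenerateRequest

# Build the same kind of message MQLLMEngineClient.generate() sends on the
# input socket.
request = RPCGenerateRequest(
    inputs="Hello my name is",                      # example prompt
    sampling_params=SamplingParams(max_tokens=8),
    request_id="example-request-0")                 # example id
request_bytes = pickle.dumps(request)

# Aborts travel over the same input socket as a separate message type.
abort_bytes = pickle.dumps(RPCAbortRequest(request_id="example-request-0"))

# The engine side unpickles the frame and dispatches on the message type.
assert isinstance(pickle.loads(request_bytes), RPCGenerateRequest)
assert isinstance(pickle.loads(abort_bytes), RPCAbortRequest)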
452
vllm/engine/multiprocessing/client.py
Normal file
452
vllm/engine/multiprocessing/client.py
Normal file
@ -0,0 +1,452 @@
|
|||||||
|
import asyncio
|
||||||
|
import copy
|
||||||
|
import pickle
|
||||||
|
from contextlib import contextmanager, suppress
|
||||||
|
from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional,
|
||||||
|
Union)
|
||||||
|
|
||||||
|
import cloudpickle
|
||||||
|
import zmq
|
||||||
|
import zmq.asyncio
|
||||||
|
from zmq import Frame # type: ignore[attr-defined]
|
||||||
|
from zmq.asyncio import Socket
|
||||||
|
|
||||||
|
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
|
||||||
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
|
# yapf conflicts with isort for this block
|
||||||
|
# yapf: disable
|
||||||
|
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
||||||
|
IPC_HEALTH_EXT, IPC_INPUT_EXT,
|
||||||
|
IPC_OUTPUT_EXT, RPC_REQUEST_T,
|
||||||
|
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
|
||||||
|
RPCError, RPCGenerateRequest,
|
||||||
|
RPCHealthRequest, RPCStartupRequest,
|
||||||
|
RPCStartupResponse)
|
||||||
|
# yapf: enable
|
||||||
|
from vllm.envs import VLLM_RPC_TIMEOUT
|
||||||
|
from vllm.inputs import PromptInputs
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||||
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||||
|
from vllm.sampling_params import SamplingParams
|
||||||
|
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MQClientClosedError(Exception):
|
||||||
|
"""Exception class raised when the client is used post-close.
|
||||||
|
|
||||||
|
The client can be closed, which closes the ZMQ context. This normally
|
||||||
|
happens on server shutdown. In some cases, methods like abort and
|
||||||
|
do_log_stats will still be called and then try to open a socket, which
|
||||||
|
causes a ZMQError and creates a huge stack trace.
|
||||||
|
So, we throw this error such that we can suppress it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class MQLLMEngineClient:
|
||||||
|
"""A client wrapper for MQLLMEngine that conforms to the
|
||||||
|
EngineClient protocol.
|
||||||
|
|
||||||
|
MQLLMEngine and MQLLMEngineClient are intended to run in separate
|
||||||
|
processes communicating via zeromq ipc sockets.
|
||||||
|
|
||||||
|
The entrypoint to MQLLMEngineClient is through the generate()
|
||||||
|
method. On generate() MQLLMEngine does three things:
|
||||||
|
- Creates an asyncio output queue
|
||||||
|
- Sends a RPCGenerateRequest to the MQLLMEngine via zmq
|
||||||
|
- Pulls RequestOutputs from its queue and yields them
|
||||||
|
|
||||||
|
MQLLMEngine runs two background loops:
|
||||||
|
- output_loop: the output loop pulls List[RequestOutput]
|
||||||
|
from the MQLLMEngine via zmq (each list is the output
|
||||||
|
of one engine_step in the LLMEngine). It then parses
|
||||||
|
the list and pushes individual request_outputs into
|
||||||
|
the corresponding output_queue such that they can be
|
||||||
|
consumed by the .generate() method.
|
||||||
|
- health_loop: the health loop queries the health socket
|
||||||
|
every N seconds, confirming the engine is healthy
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, ipc_path: str, engine_config: EngineConfig):
|
||||||
|
self.context = zmq.asyncio.Context()
|
||||||
|
self._errored_with: Optional[BaseException] = None
|
||||||
|
|
||||||
|
# Get the configs.
|
||||||
|
self.model_config = engine_config.model_config
|
||||||
|
self.decoding_config = engine_config.decoding_config
|
||||||
|
|
||||||
|
# Create the tokenizer group.
|
||||||
|
self.tokenizer = init_tokenizer_from_configs(
|
||||||
|
model_config=self.model_config,
|
||||||
|
scheduler_config=engine_config.scheduler_config,
|
||||||
|
parallel_config=engine_config.parallel_config,
|
||||||
|
enable_lora=bool(engine_config.lora_config),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send RPCGenerateRequest to the MQLLMEngine.
|
||||||
|
self.input_socket: Socket = self.context.socket(zmq.constants.PUSH)
|
||||||
|
self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}")
|
||||||
|
|
||||||
|
# Receive streams of RequestOutput from the MQLLMEngine.
|
||||||
|
self.output_socket: Socket = self.context.socket(zmq.constants.PULL)
|
||||||
|
self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")
|
||||||
|
|
||||||
|
# IPC path for ack of check_health requests.
|
||||||
|
self.health_socket: Socket = self.context.socket(zmq.constants.PULL)
|
||||||
|
self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
|
||||||
|
|
||||||
|
# IPC path for the data socket.
|
||||||
|
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
|
||||||
|
|
||||||
|
# Stream for each individual request.
|
||||||
|
self.output_queues: Dict[str, asyncio.Queue] = {}
|
||||||
|
self.output_loop = asyncio.create_task(self.run_output_handler_loop())
|
||||||
|
|
||||||
|
# Loop to check health of the LLMEngine periodically.
|
||||||
|
# Started after the MQLLMEngine is ready.
|
||||||
|
self.health_loop: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_unsupported_config(engine_args: AsyncEngineArgs):
|
||||||
|
if engine_args.pipeline_parallel_size > 1:
|
||||||
|
return True
|
||||||
|
|
||||||
|
is_embedding = ModelConfig(
|
||||||
|
model=engine_args.model,
|
||||||
|
revision=engine_args.revision,
|
||||||
|
tokenizer=engine_args.model,
|
||||||
|
tokenizer_mode="auto",
|
||||||
|
trust_remote_code=engine_args.trust_remote_code,
|
||||||
|
quantization=engine_args.quantization,
|
||||||
|
seed=0,
|
||||||
|
dtype="auto").embedding_mode
|
||||||
|
|
||||||
|
return is_embedding
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def get_data_socket(self) -> Iterator[Socket]:
|
||||||
|
socket = self.context.socket(zmq.constants.DEALER)
|
||||||
|
try:
|
||||||
|
socket.connect(self.data_ipc_path)
|
||||||
|
yield socket
|
||||||
|
finally:
|
||||||
|
socket.close(linger=0)
|
||||||
|
|
||||||
|
async def run_check_health_loop(self, timeout: int):
|
||||||
|
"""Background loop that continually probes the RPCServer for health.
|
||||||
|
|
||||||
|
The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which
|
||||||
|
the MQLLMEngine server is blocking on.
|
||||||
|
|
||||||
|
The Server replies on the HEALTH_SOCKET (rather than on the
|
||||||
|
OUTPUT_SOCKET such that the messages are not intermingled with
|
||||||
|
output streaming).
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
if await self.health_socket.poll(timeout=timeout) == 0:
|
||||||
|
# Wakeup every N seconds and do a health probe.
|
||||||
|
await self._send_one_way_rpc_request(
|
||||||
|
RPCHealthRequest(), self.input_socket)
|
||||||
|
|
||||||
|
# Wait for ack from the health socket.
|
||||||
|
await self._await_ack(error_message="Health check failed.",
|
||||||
|
socket=self.health_socket)
|
||||||
|
else:
|
||||||
|
# Server sent a health status message unprompted.
|
||||||
|
await self._check_success(
|
||||||
|
error_message="Health check failed.",
|
||||||
|
socket=self.health_socket)
|
||||||
|
|
||||||
|
logger.debug("Health probe successful.")
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.debug("Shutting down MQLLMEngineClient check health loop.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self._set_errored(e)
|
||||||
|
|
||||||
|
async def run_output_handler_loop(self):
|
||||||
|
"""Get RequestOutputs from Engine and stream to Request Queues"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
# Poll, checking for ENGINE_DEAD
|
||||||
|
while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT
|
||||||
|
) == 0:
|
||||||
|
logger.debug("Waiting for output from MQLLMEngine.")
|
||||||
|
|
||||||
|
# If errored, alert all running requests.
|
||||||
|
if self.errored:
|
||||||
|
for queue_j in tuple(self.output_queues.values()):
|
||||||
|
queue_j.put_nowait(
|
||||||
|
ENGINE_DEAD_ERROR(self._errored_with))
|
||||||
|
return
|
||||||
|
|
||||||
|
message: Frame = await self.output_socket.recv(copy=False)
|
||||||
|
request_outputs = pickle.loads(message.buffer)
|
||||||
|
|
||||||
|
is_error = isinstance(request_outputs,
|
||||||
|
(BaseException, RPCError))
|
||||||
|
if is_error:
|
||||||
|
if isinstance(request_outputs, RPCError):
|
||||||
|
rpc_error: RPCError = request_outputs
|
||||||
|
request_id = rpc_error.request_id
|
||||||
|
exception = rpc_error.exception
|
||||||
|
is_engine_errored = rpc_error.is_engine_errored
|
||||||
|
else:
|
||||||
|
# MPLLMEngine should always return an RPCError to
|
||||||
|
# the output_socket when an issue arises.
|
||||||
|
# If we are here, we are in a bad state and
|
||||||
|
# should shut down the server.
|
||||||
|
error: BaseException = request_outputs
|
||||||
|
logger.error(
|
||||||
|
"Received Exception %s rather than RPCError from "
|
||||||
|
"MPLLMEngine. This should never happen.", error)
|
||||||
|
request_id = None
|
||||||
|
exception = error
|
||||||
|
is_engine_errored = True
|
||||||
|
|
||||||
|
# Set to error state only on engine critical error
|
||||||
|
# (and record only the first one)
|
||||||
|
if is_engine_errored and not self._errored_with:
|
||||||
|
self._errored_with = exception
|
||||||
|
|
||||||
|
if request_id is None:
|
||||||
|
for queue_i in tuple(self.output_queues.values()):
|
||||||
|
queue_i.put_nowait(exception)
|
||||||
|
else:
|
||||||
|
queue = self.output_queues.get(request_id)
|
||||||
|
if queue is not None:
|
||||||
|
queue.put_nowait(exception)
|
||||||
|
else:
|
||||||
|
# Put each output into the appropriate steam.
|
||||||
|
for request_output in request_outputs:
|
||||||
|
queue = self.output_queues.get(
|
||||||
|
request_output.request_id)
|
||||||
|
if queue is not None:
|
||||||
|
queue.put_nowait(request_output)
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.debug("Shutting down MQLLMEngineClient output handler.")
|
||||||
|
|
||||||
|
async def setup(self):
|
||||||
|
"""Setup the client before it starts sending server requests."""
|
||||||
|
|
||||||
|
with self.get_data_socket() as socket:
|
||||||
|
# Wait until server is ready.
|
||||||
|
response = await self._wait_for_server_rpc(socket)
|
||||||
|
|
||||||
|
self.tracing_flag = response.tracing_enabled
|
||||||
|
|
||||||
|
# Start health_loop.
|
||||||
|
self.health_loop = asyncio.create_task(
|
||||||
|
self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT))
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Destroy the ZeroMQ Context."""
|
||||||
|
# Close all sockets and terminate the context.
|
||||||
|
self.context.destroy(linger=0)
|
||||||
|
|
||||||
|
# Cancel background tasks.
|
||||||
|
if self.health_loop is not None:
|
||||||
|
self.health_loop.cancel()
|
||||||
|
self.output_loop.cancel()
|
||||||
|
|
||||||
|
def _set_errored(self, e: BaseException):
|
||||||
|
logger.exception(repr(e))
|
||||||
|
if self._errored_with is None:
|
||||||
|
self._errored_with = e
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _send_get_data_rpc_request(request: RPCStartupRequest,
|
||||||
|
expected_type: Any,
|
||||||
|
error_message: str,
|
||||||
|
socket: Socket) -> Any:
|
||||||
|
"""Send an RPC request that is expecting data back."""
|
||||||
|
|
||||||
|
# Ping RPCServer with a request.
|
||||||
|
await socket.send_multipart((pickle.dumps(request), ), copy=False)
|
||||||
|
|
||||||
|
# Make sure the server responds in time.
|
||||||
|
if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
|
||||||
|
raise TimeoutError("RPCServer didn't reply within "
|
||||||
|
f"{VLLM_RPC_TIMEOUT} ms")
|
||||||
|
|
||||||
|
# Await the data from the Server.
|
||||||
|
frame = await socket.recv(copy=False)
|
||||||
|
data = pickle.loads(frame.buffer)
|
||||||
|
|
||||||
|
if isinstance(data, BaseException):
|
||||||
|
raise data
|
||||||
|
elif not isinstance(data, expected_type):
|
||||||
|
raise ValueError(error_message)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _send_one_way_rpc_request(request: RPC_REQUEST_T,
|
||||||
|
socket: Socket):
|
||||||
|
"""Send one-way RPC request to trigger an action."""
|
||||||
|
|
||||||
|
if socket.closed:
|
||||||
|
raise MQClientClosedError()
|
||||||
|
|
||||||
|
await socket.send_multipart((pickle.dumps(request), ))
|
||||||
|
|
||||||
|
async def _await_ack(self, error_message: str, socket: Socket):
|
||||||
|
"""Await acknowledgement that a request succeeded."""
|
||||||
|
|
||||||
|
if socket.closed:
|
||||||
|
raise MQClientClosedError()
|
||||||
|
|
||||||
|
if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
|
||||||
|
raise TimeoutError("MQLLMEngine didn't reply within "
|
||||||
|
f"{VLLM_RPC_TIMEOUT}ms")
|
||||||
|
|
||||||
|
await self._check_success(error_message, socket)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _check_success(error_message: str, socket: Socket):
|
||||||
|
"""Confirm that socket has a VLLM_RPC_SUCCESS_STR message"""
|
||||||
|
|
||||||
|
if socket.closed:
|
||||||
|
raise MQClientClosedError()
|
||||||
|
|
||||||
|
frame = await socket.recv(copy=False)
|
||||||
|
response = pickle.loads(frame.buffer)
|
||||||
|
|
||||||
|
# Raise error if unsuccessful
|
||||||
|
if isinstance(response, BaseException):
|
||||||
|
raise response
|
||||||
|
elif (not isinstance(response, str)
|
||||||
|
or response != VLLM_RPC_SUCCESS_STR):
|
||||||
|
raise ValueError(error_message)
|
||||||
|
|
||||||
|
async def get_tokenizer(self, lora_request: LoRARequest):
|
||||||
|
return await self.tokenizer.get_lora_tokenizer_async(lora_request)
|
||||||
|
|
||||||
|
async def get_decoding_config(self) -> DecodingConfig:
|
||||||
|
return self.decoding_config
|
||||||
|
|
||||||
|
async def get_model_config(self) -> ModelConfig:
|
||||||
|
return self.model_config
|
||||||
|
|
||||||
|
async def is_tracing_enabled(self) -> bool:
|
||||||
|
return self.tracing_flag
|
||||||
|
|
||||||
|
async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse:
|
||||||
|
"""Wait for the RPCServer to start up."""
|
||||||
|
|
||||||
|
return await self._send_get_data_rpc_request(
|
||||||
|
request=RPCStartupRequest.IS_SERVER_READY,
|
||||||
|
expected_type=RPCStartupResponse,
|
||||||
|
error_message="Unable to start RPC Server",
|
||||||
|
socket=socket)
|
||||||
|
|
||||||
|
async def abort(self, request_id: str):
|
||||||
|
"""Send an ABORT_REQUEST signal to the RPC Server"""
|
||||||
|
|
||||||
|
with suppress(MQClientClosedError):
|
||||||
|
await self._send_one_way_rpc_request(
|
||||||
|
request=RPCAbortRequest(request_id), socket=self.input_socket)
|
||||||
|
|
||||||
|
async def do_log_stats(self):
|
||||||
|
"""Ignore do_log_stats (handled on MQLLMEngine polling)"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def check_health(self):
|
||||||
|
"""
|
||||||
|
The check health loop probes the health status of the
|
||||||
|
Engine's health every N seconds and sets _errored_with
|
||||||
|
if the engine is unhealthy.
|
||||||
|
"""
|
||||||
|
if self._errored_with is not None:
|
||||||
|
raise self._errored_with
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_running(self) -> bool:
|
||||||
|
return not self.errored
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_stopped(self) -> bool:
|
||||||
|
return self.errored
|
||||||
|
|
||||||
|
@property
|
||||||
|
def errored(self) -> bool:
|
||||||
|
return self._errored_with is not None
|
||||||
|
|
||||||
|
async def generate(
|
||||||
|
self,
|
||||||
|
inputs: PromptInputs,
|
||||||
|
sampling_params: SamplingParams,
|
||||||
|
request_id: str,
|
||||||
|
lora_request: Optional[LoRARequest] = None,
|
||||||
|
trace_headers: Optional[Mapping[str, str]] = None,
|
||||||
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||||
|
) -> AsyncGenerator[RequestOutput, None]:
|
||||||
|
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
|
||||||
|
|
||||||
|
# If already dead, error out.
|
||||||
|
if self._errored_with is not None:
|
||||||
|
raise ENGINE_DEAD_ERROR(self._errored_with)
|
||||||
|
|
||||||
|
# 1) Create output queue for this requests.
|
||||||
|
queue: asyncio.Queue[Union[RequestOutput,
|
||||||
|
BaseException]] = asyncio.Queue()
|
||||||
|
self.output_queues[request_id] = queue
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 2) Detach logits processors so that they can be pickled
|
||||||
|
# separately (may require cloudpickle which is slower)
|
||||||
|
if sampling_params.logits_processors:
|
||||||
|
# Defensive shallow copy
|
||||||
|
sampling_params = copy.copy(sampling_params)
|
||||||
|
logits_processors = sampling_params.logits_processors
|
||||||
|
sampling_params.logits_processors = None
|
||||||
|
lp_bytes = cloudpickle.dumps(logits_processors)
|
||||||
|
else:
|
||||||
|
lp_bytes = None
|
||||||
|
|
||||||
|
request_bytes = pickle.dumps(
|
||||||
|
RPCGenerateRequest(
|
||||||
|
inputs=inputs,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
request_id=request_id,
|
||||||
|
lora_request=lora_request,
|
||||||
|
trace_headers=trace_headers,
|
||||||
|
prompt_adapter_request=prompt_adapter_request))
|
||||||
|
|
||||||
|
# 3) Send the RPCGenerateRequest to the MQLLMEngine.
|
||||||
|
parts = (request_bytes,
|
||||||
|
lp_bytes) if lp_bytes else (request_bytes, )
|
||||||
|
await self.input_socket.send_multipart(parts, copy=False)
|
||||||
|
|
||||||
|
# 4) Stream the RequestOutputs from the output queue. Note
|
||||||
|
# that the output_loop pushes RequestOutput objects to this
|
||||||
|
# queue after pulling them from the zmq socket.
|
||||||
|
finished = False
|
||||||
|
try:
|
||||||
|
while not finished:
|
||||||
|
request_output = await queue.get()
|
||||||
|
|
||||||
|
if isinstance(request_output, BaseException):
|
||||||
|
raise request_output
|
||||||
|
|
||||||
|
finished = request_output.finished
|
||||||
|
yield request_output
|
||||||
|
finally:
|
||||||
|
# Request was canceled by the client.
|
||||||
|
if not finished and not self.errored:
|
||||||
|
await self.abort(request_id)
|
||||||
|
finally:
|
||||||
|
self.output_queues.pop(request_id)
|
||||||
|
|
||||||
|
async def encode(self, *args,
|
||||||
|
**kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Embeddings not supported with multiprocessing backend")
|
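The client's generate() above registers a per-request asyncio.Queue in self.output_queues, ships the pickled RPCGenerateRequest over the input socket (cloudpickling any logits processors separately), and then yields whatever the background output loop pushes onto that queue until a finished output or an exception arrives. Below is a minimal, self-contained sketch of just that queue-per-request streaming pattern; FakeOutput, output_loop, and generate here are illustrative stand-ins, not vLLM APIs.

import asyncio
from dataclasses import dataclass
from typing import AsyncGenerator, Dict, Union


@dataclass
class FakeOutput:
    # Stand-in for vLLM's RequestOutput: a payload plus a finished flag.
    request_id: str
    text: str
    finished: bool


output_queues: Dict[str, "asyncio.Queue[Union[FakeOutput, BaseException]]"] = {}


async def output_loop() -> None:
    # Stand-in for the background task that drains the zmq output socket.
    for i in range(3):
        await asyncio.sleep(0.1)
        for request_id, queue in output_queues.items():
            queue.put_nowait(
                FakeOutput(request_id, f"token-{i}", finished=(i == 2)))


async def generate(request_id: str) -> AsyncGenerator[FakeOutput, None]:
    # 1) Register a per-request queue before sending the request out.
    queue: "asyncio.Queue[Union[FakeOutput, BaseException]]" = asyncio.Queue()
    output_queues[request_id] = queue
    try:
        finished = False
        while not finished:
            item = await queue.get()
            # Exceptions are funneled through the same queue and re-raised.
            if isinstance(item, BaseException):
                raise item
            finished = item.finished
            yield item
    finally:
        # Mirror of the client: always drop the queue when the stream ends.
        output_queues.pop(request_id, None)


async def main() -> None:
    asyncio.create_task(output_loop())
    async for out in generate("req-0"):
        print(out.text, out.finished)


asyncio.run(main())
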
321
vllm/engine/multiprocessing/engine.py
Normal file
@ -0,0 +1,321 @@
|
|||||||
|
import pickle
|
||||||
|
import signal
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Iterator, List, Optional, Union
|
||||||
|
|
||||||
|
import cloudpickle
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from vllm import AsyncEngineArgs, LLMEngine
|
||||||
|
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
|
||||||
|
ParallelConfig, SchedulerConfig)
|
||||||
|
# yapf conflicts with isort for this block
|
||||||
|
# yapf: disable
|
||||||
|
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
||||||
|
IPC_HEALTH_EXT, IPC_INPUT_EXT,
|
||||||
|
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
|
||||||
|
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
|
||||||
|
RPCError, RPCGenerateRequest,
|
||||||
|
RPCHealthRequest, RPCStartupRequest,
|
||||||
|
RPCStartupResponse)
|
||||||
|
# yapf: enable
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.outputs import RequestOutput
|
||||||
|
from vllm.usage.usage_lib import UsageContext
|
||||||
|
|
||||||
|
CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
|
||||||
|
SchedulerConfig, LoRAConfig]
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
POLLING_TIMEOUT_MS = 10000
|
||||||
|
HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), )
|
||||||
|
|
||||||
|
|
||||||
|
class MQLLMEngine:
|
||||||
|
"""A multiprocessing wrapper for :class:`LLMEngine`.
|
||||||
|
|
||||||
|
This class is used to wrap the :class:`LLMEngine` class to enable use
|
||||||
|
in a concurrent manner. It runs a background loop and uses zeromq to
|
||||||
|
receive new requests and stream outputs incrementally via ipc.
|
||||||
|
|
||||||
|
The :class:`LLMEngine.generate` is kicked off when a new
|
||||||
|
RPCGenerateRequest is received by the input_socket.
|
||||||
|
|
||||||
|
The self.engine_loop checks the input_socket for new requests,
|
||||||
|
adds them to the LLMEngine if there are any, calls the internal
|
||||||
|
:class:`LLMEngine.step()`, and sends the RequestOutputs back over
|
||||||
|
the output_socket.
|
||||||
|
|
||||||
|
If use_async_sockets is set, the logic associated with reading new
|
||||||
|
requests from the socket and sending data to the socket is passed
|
||||||
|
as a callback to the llm_engine, which calls the logic asynchronously
|
||||||
|
such that the IPC can be overlapped with the GPU.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ipc_path: Base path for zeromq interprocess messaging
|
||||||
|
use_async_sockets: Whether to make send/recv async with GPU
|
||||||
|
log_requests: Whether to log the requests.
|
||||||
|
*args: Arguments for :class:`LLMEngine`.
|
||||||
|
**kwargs: Arguments for :class:`LLMEngine`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
ipc_path: str,
|
||||||
|
use_async_sockets: bool,
|
||||||
|
*args,
|
||||||
|
log_requests: bool = True,
|
||||||
|
**kwargs) -> None:
|
||||||
|
self.engine = LLMEngine(*args, **kwargs)
|
||||||
|
self.log_requests = log_requests
|
||||||
|
|
||||||
|
self.use_async_sockets = use_async_sockets
|
||||||
|
if self.use_async_sockets:
|
||||||
|
self.engine.process_request_outputs_callback = \
|
||||||
|
self._async_socket_engine_callback
|
||||||
|
|
||||||
|
self.ctx = zmq.Context() # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
# Receive input from the client.
|
||||||
|
self.input_socket = self.ctx.socket(zmq.constants.PULL)
|
||||||
|
self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}")
|
||||||
|
|
||||||
|
# Send output stream back to client.
|
||||||
|
self.output_socket = self.ctx.socket(zmq.constants.PUSH)
|
||||||
|
self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")
|
||||||
|
|
||||||
|
# Send health status back to client.
|
||||||
|
self.health_socket = self.ctx.socket(zmq.constants.PUSH)
|
||||||
|
self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
|
||||||
|
|
||||||
|
# IPC path for the data socket.
|
||||||
|
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
|
||||||
|
|
||||||
|
# Error state.
|
||||||
|
self._errored_with: Optional[BaseException] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dead_error(self) -> BaseException:
|
||||||
|
if self._errored_with is not None:
|
||||||
|
return ENGINE_DEAD_ERROR(self._errored_with)
|
||||||
|
else:
|
||||||
|
return ENGINE_DEAD_ERROR()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_engine_args(cls, engine_args: AsyncEngineArgs,
|
||||||
|
usage_context: UsageContext, ipc_path: str):
|
||||||
|
"""Creates an MQLLMEngine from the engine arguments."""
|
||||||
|
|
||||||
|
engine_config = engine_args.create_engine_config()
|
||||||
|
|
||||||
|
executor_class = LLMEngine._get_executor_cls(engine_config)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
ipc_path=ipc_path,
|
||||||
|
use_async_sockets=engine_config.model_config.use_async_output_proc,
|
||||||
|
**engine_config.to_dict(),
|
||||||
|
executor_class=executor_class,
|
||||||
|
log_requests=not engine_args.disable_log_requests,
|
||||||
|
log_stats=not engine_args.disable_log_stats,
|
||||||
|
usage_context=usage_context)
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
logger.debug("Starting Startup Loop.")
|
||||||
|
self.run_startup_loop()
|
||||||
|
logger.debug("Starting Engine Loop.")
|
||||||
|
self.run_engine_loop()
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(repr(e))
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.debug("Shutting down MQLLMEngine.")
|
||||||
|
finally:
|
||||||
|
logger.debug("MQLLMEngine is shut down.")
|
||||||
|
self.cleanup()
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Cleanup zeromq state on shutdown."""
|
||||||
|
# Closes all sockets and destroys context.
|
||||||
|
self.ctx.destroy(linger=0)
|
||||||
|
del self.engine
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def make_data_socket(
|
||||||
|
self) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
|
||||||
|
socket = self.ctx.socket(zmq.constants.ROUTER)
|
||||||
|
try:
|
||||||
|
socket.bind(self.data_ipc_path)
|
||||||
|
yield socket
|
||||||
|
finally:
|
||||||
|
socket.close(linger=0)
|
||||||
|
|
||||||
|
def run_startup_loop(self) -> None:
|
||||||
|
"""Startup loop for sending data from Engine -> Client."""
|
||||||
|
|
||||||
|
with self.make_data_socket() as socket:
|
||||||
|
response: Union[RPCStartupResponse, BaseException]
|
||||||
|
try:
|
||||||
|
identity, message = socket.recv_multipart(copy=False)
|
||||||
|
request: RPCStartupRequest = pickle.loads(message.buffer)
|
||||||
|
|
||||||
|
# Handle the query from the Client.
|
||||||
|
if request == RPCStartupRequest.IS_SERVER_READY:
|
||||||
|
tracing_enabled = self.engine.is_tracing_enabled()
|
||||||
|
response = RPCStartupResponse(
|
||||||
|
tracing_enabled=tracing_enabled)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
response = e
|
||||||
|
|
||||||
|
socket.send_multipart((identity, pickle.dumps(response)),
|
||||||
|
copy=False)
|
||||||
|
|
||||||
|
def run_engine_loop(self):
|
||||||
|
"""Core busy loop of the LLMEngine."""
|
||||||
|
|
||||||
|
while True:
|
||||||
|
if not self.engine.has_unfinished_requests():
|
||||||
|
# Poll until there is work to do.
|
||||||
|
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
|
||||||
|
self.engine.do_log_stats()
|
||||||
|
logger.debug("Waiting for new requests in engine loop.")
|
||||||
|
|
||||||
|
# Handle any input from the client.
|
||||||
|
self.handle_new_input()
|
||||||
|
|
||||||
|
# Engine step.
|
||||||
|
request_outputs = self.engine_step()
|
||||||
|
|
||||||
|
# Send request outputs (if async, done in engine_step callback).
|
||||||
|
if not self.use_async_sockets:
|
||||||
|
self._send_outputs(request_outputs)
|
||||||
|
|
||||||
|
def engine_step(self) -> List[RequestOutput]:
|
||||||
|
"""Engine step wrapper with error handling."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self.engine.step()
|
||||||
|
except SystemExit:
|
||||||
|
raise
|
||||||
|
except BaseException as e:
|
||||||
|
self._set_errored(e)
|
||||||
|
rpc_err = RPCError(request_id=None,
|
||||||
|
is_engine_errored=True,
|
||||||
|
exception=e)
|
||||||
|
self._send_outputs(rpc_err)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def handle_new_input(self):
|
||||||
|
"""Handle new input from the socket"""
|
||||||
|
try:
|
||||||
|
while self.input_socket.poll(timeout=0) != 0:
|
||||||
|
frames = self.input_socket.recv_multipart(copy=False)
|
||||||
|
request = pickle.loads(frames[0].buffer)
|
||||||
|
|
||||||
|
if isinstance(request, RPCGenerateRequest):
|
||||||
|
if len(frames) > 1:
|
||||||
|
# Use cloudpickle for logits processors
|
||||||
|
lprocs = cloudpickle.loads(frames[1].buffer)
|
||||||
|
request.sampling_params.logits_processors = lprocs
|
||||||
|
self._handle_generate_request(request)
|
||||||
|
elif isinstance(request, RPCAbortRequest):
|
||||||
|
self._handle_abort_request(request)
|
||||||
|
elif isinstance(request, RPCHealthRequest):
|
||||||
|
self._handle_health_request()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown RPCRequest Type: {request}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self._set_errored(e)
|
||||||
|
self._send_unhealthy(e)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def _handle_generate_request(self, request: RPCGenerateRequest):
|
||||||
|
"""Handle RPCGenerateRequest by adding it to the LLMEngine."""
|
||||||
|
request_id = request.request_id
|
||||||
|
|
||||||
|
if self._errored_with is not None:
|
||||||
|
rpc_err = RPCError(request_id=request_id,
|
||||||
|
is_engine_errored=True,
|
||||||
|
exception=ENGINE_DEAD_ERROR(self._errored_with))
|
||||||
|
self._send_outputs(rpc_err)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.engine.add_request(
|
||||||
|
request_id=request_id,
|
||||||
|
inputs=request.inputs,
|
||||||
|
params=request.sampling_params,
|
||||||
|
lora_request=request.lora_request,
|
||||||
|
trace_headers=request.trace_headers,
|
||||||
|
prompt_adapter_request=request.prompt_adapter_request)
|
||||||
|
|
||||||
|
if self.log_requests:
|
||||||
|
logger.info("Added request %s.", request.request_id)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# We do not set self._errored = True here, since the error
|
||||||
|
# is due to an issue adding this request to the engine,
|
||||||
|
# rather than an issue with the engine itself.
|
||||||
|
is_errored = self._errored_with is not None
|
||||||
|
rpc_err = RPCError(request_id=request_id,
|
||||||
|
is_engine_errored=is_errored,
|
||||||
|
exception=e)
|
||||||
|
self._send_outputs(rpc_err)
|
||||||
|
|
||||||
|
# Remove request from the engine.
|
||||||
|
self.engine.abort_request(request_id)
|
||||||
|
|
||||||
|
def _handle_abort_request(self, request: RPCAbortRequest):
|
||||||
|
self.engine.abort_request(request.request_id)
|
||||||
|
if self.log_requests:
|
||||||
|
logger.info("Aborted request %s.", request.request_id)
|
||||||
|
|
||||||
|
def _handle_health_request(self):
|
||||||
|
if self._errored_with is not None:
|
||||||
|
self._send_unhealthy(self._errored_with)
|
||||||
|
|
||||||
|
# Raises error if unhealthy.
|
||||||
|
self.engine.check_health()
|
||||||
|
self._send_healthy()
|
||||||
|
|
||||||
|
def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
|
||||||
|
"""Send List of RequestOutput to RPCClient."""
|
||||||
|
if outputs:
|
||||||
|
output_bytes = pickle.dumps(outputs)
|
||||||
|
self.output_socket.send_multipart((output_bytes, ), copy=False)
|
||||||
|
|
||||||
|
def _send_healthy(self):
|
||||||
|
"""Send HEALTHY message to RPCClient."""
|
||||||
|
self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False)
|
||||||
|
|
||||||
|
def _send_unhealthy(self, error: BaseException):
|
||||||
|
"""Send UNHEALTHY message to RPCClient."""
|
||||||
|
error_bytes = pickle.dumps(error)
|
||||||
|
self.health_socket.send_multipart((error_bytes, ), copy=False)
|
||||||
|
|
||||||
|
def _async_socket_engine_callback(self,
|
||||||
|
request_outputs: REQUEST_OUTPUTS_T):
|
||||||
|
"""Callback used by engine to make socket handling async with GPU."""
|
||||||
|
self._send_outputs(request_outputs)
|
||||||
|
self.handle_new_input()
|
||||||
|
|
||||||
|
def _set_errored(self, e: BaseException):
|
||||||
|
"""Log and set errored status if this is the first issue."""
|
||||||
|
if self._errored_with is None:
|
||||||
|
self._errored_with = e
|
||||||
|
|
||||||
|
|
||||||
|
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
|
||||||
|
ipc_path: str):
|
||||||
|
|
||||||
|
def signal_handler(*_) -> None:
|
||||||
|
# Interrupt server on sigterm
|
||||||
|
raise KeyboardInterrupt("MQLLMEngine terminated")
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
|
||||||
|
engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
|
||||||
|
usage_context=usage_context,
|
||||||
|
ipc_path=ipc_path)
|
||||||
|
engine.start()
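MQLLMEngine above binds a PULL socket for client requests, PUSH sockets for outputs and health, and a ROUTER data socket for the startup handshake, then drives everything from run_engine_loop: poll while idle, drain new requests, step the engine, push outputs. A stripped-down sketch of that poll/drain/step/push loop in plain pyzmq follows; the ipc path and the echo "step" are placeholders, not vLLM behavior.

import pickle

import zmq

IPC_PATH = "ipc:///tmp/demo_engine"  # placeholder path, not vLLM's naming
POLLING_TIMEOUT_MS = 10000

ctx = zmq.Context()
input_socket = ctx.socket(zmq.PULL)   # client -> engine requests
input_socket.bind(f"{IPC_PATH}_input")
output_socket = ctx.socket(zmq.PUSH)  # engine -> client outputs
output_socket.bind(f"{IPC_PATH}_output")

pending = []  # stand-in for the engine's unfinished requests

try:
    while True:
        if not pending:
            # Poll until there is work to do, like run_engine_loop().
            while input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
                pass  # the real loop logs stats here
        # Drain any new requests without blocking.
        while input_socket.poll(timeout=0) != 0:
            frames = input_socket.recv_multipart(copy=False)
            pending.append(pickle.loads(frames[0].buffer))
        # "Step" the engine: here we simply echo requests back as outputs.
        if pending:
            outputs = pending[:]
            pending.clear()
            output_socket.send_multipart((pickle.dumps(outputs), ), copy=False)
except KeyboardInterrupt:
    pass
finally:
    ctx.destroy(linger=0)
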
|
@ -14,8 +14,8 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
|
|||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class AsyncEngineClient(Protocol):
|
class EngineClient(Protocol):
|
||||||
"""Protocol class for Clients to AsyncLLMEngine"""
|
"""Protocol class for Clients to Engine"""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_running(self) -> bool:
|
def is_running(self) -> bool:
|
||||||
@ -30,8 +30,8 @@ class AsyncEngineClient(Protocol):
|
|||||||
...
|
...
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def limit_concurrency(self) -> Optional[int]:
|
def dead_error(self) -> BaseException:
|
||||||
"""Maximum number of concurrently running requests."""
|
...
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
|
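The hunk above renames AsyncEngineClient to EngineClient while keeping it a @runtime_checkable typing.Protocol, so the in-process AsyncLLMEngine and the new MQLLMEngineClient can both satisfy it purely structurally. A small illustration of that structural-typing pattern (the class names here are made up for the example):

from typing import Protocol, runtime_checkable


@runtime_checkable
class EngineClientLike(Protocol):
    """Illustrative stand-in for the EngineClient protocol."""

    @property
    def is_running(self) -> bool:
        ...

    @property
    def errored(self) -> bool:
        ...


class DummyClient:
    """No inheritance needed: providing the members is enough."""

    @property
    def is_running(self) -> bool:
        return True

    @property
    def errored(self) -> bool:
        return False


client = DummyClient()
# runtime_checkable only verifies that the members exist, not their types.
assert isinstance(client, EngineClientLike)
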
@ -1,21 +1,21 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import signal
|
import signal
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI, Request, Response
|
from fastapi import FastAPI, Request, Response
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.engine.async_llm_engine import AsyncEngineDeadError
|
from vllm.engine.async_llm_engine import AsyncEngineDeadError
|
||||||
|
from vllm.engine.multiprocessing import MQEngineDeadError
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import find_process_using_port
|
from vllm.utils import find_process_using_port
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def serve_http(app: FastAPI, limit_concurrency: Optional[int],
|
async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
|
||||||
**uvicorn_kwargs: Any):
|
|
||||||
logger.info("Available routes are:")
|
logger.info("Available routes are:")
|
||||||
for route in app.routes:
|
for route in app.routes:
|
||||||
methods = getattr(route, "methods", None)
|
methods = getattr(route, "methods", None)
|
||||||
@ -26,15 +26,6 @@ async def serve_http(app: FastAPI, limit_concurrency: Optional[int],
|
|||||||
|
|
||||||
logger.info("Route: %s, Methods: %s", path, ', '.join(methods))
|
logger.info("Route: %s, Methods: %s", path, ', '.join(methods))
|
||||||
|
|
||||||
# Set concurrency limits in uvicorn if running in multiprocessing mode
|
|
||||||
# since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536).
|
|
||||||
if limit_concurrency is not None:
|
|
||||||
logger.info(
|
|
||||||
"Launching Uvicorn with --limit_concurrency %s. To avoid this "
|
|
||||||
"limit at the expense of performance run with "
|
|
||||||
"--disable-frontend-multiprocessing", limit_concurrency)
|
|
||||||
uvicorn_kwargs["limit_concurrency"] = limit_concurrency
|
|
||||||
|
|
||||||
config = uvicorn.Config(app, **uvicorn_kwargs)
|
config = uvicorn.Config(app, **uvicorn_kwargs)
|
||||||
server = uvicorn.Server(config)
|
server = uvicorn.Server(config)
|
||||||
_add_shutdown_handlers(app, server)
|
_add_shutdown_handlers(app, server)
|
||||||
@ -63,7 +54,7 @@ async def serve_http(app: FastAPI, limit_concurrency: Optional[int],
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
"port %s is used by process %s launched with command:\n%s",
|
"port %s is used by process %s launched with command:\n%s",
|
||||||
port, process, " ".join(process.cmdline()))
|
port, process, " ".join(process.cmdline()))
|
||||||
logger.info("Gracefully stopping http server")
|
logger.info("Shutting down FastAPI HTTP server.")
|
||||||
return server.shutdown()
|
return server.shutdown()
|
||||||
|
|
||||||
|
|
||||||
@ -90,7 +81,7 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
|
|||||||
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
|
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||||
|
|
||||||
@app.exception_handler(AsyncEngineDeadError)
|
@app.exception_handler(AsyncEngineDeadError)
|
||||||
async def engine_dead_handler(_, __):
|
async def async_engine_dead_handler(_, __):
|
||||||
"""Kill the server if the async engine is already dead. It will
|
"""Kill the server if the async engine is already dead. It will
|
||||||
not handle any further requests."""
|
not handle any further requests."""
|
||||||
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
|
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
|
||||||
@ -99,3 +90,14 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
|
|||||||
server.should_exit = True
|
server.should_exit = True
|
||||||
|
|
||||||
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
|
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||||
|
|
||||||
|
@app.exception_handler(MQEngineDeadError)
|
||||||
|
async def mq_engine_dead_handler(_, __):
|
||||||
|
"""Kill the server if the mq engine is already dead. It will
|
||||||
|
not handle any further requests."""
|
||||||
|
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
|
||||||
|
logger.fatal("MQLLMEngine is already dead, terminating server "
|
||||||
|
"process")
|
||||||
|
server.should_exit = True
|
||||||
|
|
||||||
|
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||||
|
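The launcher changes above drop the --limit-concurrency plumbing and register a second dead-engine handler, so MQEngineDeadError gets the same treatment as AsyncEngineDeadError: flip server.should_exit and answer 500. A minimal sketch of that wiring with a placeholder error type (the handler and uvicorn APIs are standard FastAPI/uvicorn; everything else is illustrative):

from http import HTTPStatus

import uvicorn
from fastapi import FastAPI, Response


class DemoEngineDeadError(RuntimeError):
    """Placeholder for an engine-dead error type."""


app = FastAPI()


@app.get("/generate")
async def generate():
    # Simulate the backend engine having died.
    raise DemoEngineDeadError("engine process is gone")


def add_shutdown_handler(app: FastAPI, server: uvicorn.Server) -> None:

    @app.exception_handler(DemoEngineDeadError)
    async def engine_dead_handler(_, __):
        # Tell uvicorn to stop accepting work; the in-flight request still
        # gets a 500 back instead of hanging.
        server.should_exit = True
        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)


if __name__ == "__main__":
    config = uvicorn.Config(app, host="127.0.0.1", port=8000)
    server = uvicorn.Server(config)
    add_shutdown_handler(app, server)
    server.run()
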
@ -26,7 +26,9 @@ import vllm.envs as envs
|
|||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
from vllm.engine.protocol import AsyncEngineClient
|
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||||
|
from vllm.engine.multiprocessing.engine import run_mp_engine
|
||||||
|
from vllm.engine.protocol import EngineClient
|
||||||
from vllm.entrypoints.launcher import serve_http
|
from vllm.entrypoints.launcher import serve_http
|
||||||
from vllm.entrypoints.logger import RequestLogger
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
from vllm.entrypoints.openai.cli_args import make_arg_parser
|
from vllm.entrypoints.openai.cli_args import make_arg_parser
|
||||||
@ -44,8 +46,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
|||||||
TokenizeRequest,
|
TokenizeRequest,
|
||||||
TokenizeResponse,
|
TokenizeResponse,
|
||||||
UnloadLoraAdapterRequest)
|
UnloadLoraAdapterRequest)
|
||||||
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
|
|
||||||
from vllm.entrypoints.openai.rpc.server import run_rpc_server
|
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||||
@ -67,29 +67,16 @@ logger = init_logger('vllm.entrypoints.openai.api_server')
|
|||||||
_running_tasks: Set[asyncio.Task] = set()
|
_running_tasks: Set[asyncio.Task] = set()
|
||||||
|
|
||||||
|
|
||||||
def model_is_embedding(model_name: str, trust_remote_code: bool,
|
|
||||||
quantization: Optional[str],
|
|
||||||
revision: Optional[str]) -> bool:
|
|
||||||
return ModelConfig(model=model_name,
|
|
||||||
revision=revision,
|
|
||||||
tokenizer=model_name,
|
|
||||||
tokenizer_mode="auto",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
quantization=quantization,
|
|
||||||
seed=0,
|
|
||||||
dtype="auto").embedding_mode
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
try:
|
try:
|
||||||
if app.state.log_stats:
|
if app.state.log_stats:
|
||||||
async_engine_client = app.state.engine_client
|
engine_client: EngineClient = app.state.engine_client
|
||||||
|
|
||||||
async def _force_log():
|
async def _force_log():
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(10)
|
await asyncio.sleep(10.)
|
||||||
await async_engine_client.do_log_stats()
|
await engine_client.do_log_stats()
|
||||||
|
|
||||||
task = asyncio.create_task(_force_log())
|
task = asyncio.create_task(_force_log())
|
||||||
_running_tasks.add(task)
|
_running_tasks.add(task)
|
||||||
@ -108,9 +95,9 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def build_async_engine_client(
|
async def build_async_engine_client(
|
||||||
args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
|
args: Namespace) -> AsyncIterator[Optional[EngineClient]]:
|
||||||
|
|
||||||
# Context manager to handle async_engine_client lifecycle
|
# Context manager to handle engine_client lifecycle
|
||||||
# Ensures everything is shutdown and cleaned up on error/exit
|
# Ensures everything is shutdown and cleaned up on error/exit
|
||||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||||
|
|
||||||
@ -123,19 +110,18 @@ async def build_async_engine_client(
|
|||||||
async def build_async_engine_client_from_engine_args(
|
async def build_async_engine_client_from_engine_args(
|
||||||
engine_args: AsyncEngineArgs,
|
engine_args: AsyncEngineArgs,
|
||||||
disable_frontend_multiprocessing: bool = False,
|
disable_frontend_multiprocessing: bool = False,
|
||||||
) -> AsyncIterator[Optional[AsyncEngineClient]]:
|
) -> AsyncIterator[Optional[EngineClient]]:
|
||||||
"""
|
"""
|
||||||
Create AsyncEngineClient, either:
|
Create EngineClient, either:
|
||||||
- in-process using the AsyncLLMEngine Directly
|
- in-process using the AsyncLLMEngine Directly
|
||||||
- multiprocess using AsyncLLMEngine RPC
|
- multiprocess using AsyncLLMEngine RPC
|
||||||
|
|
||||||
Returns the Client or None if the creation failed.
|
Returns the Client or None if the creation failed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# If manually triggered or embedding model, use AsyncLLMEngine in process.
|
# Fall back
|
||||||
# TODO: support embedding model via RPC.
|
# TODO: fill out feature matrix.
|
||||||
if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
|
if (MQLLMEngineClient.is_unsupported_config(engine_args)
|
||||||
engine_args.quantization, engine_args.revision)
|
|
||||||
or disable_frontend_multiprocessing):
|
or disable_frontend_multiprocessing):
|
||||||
engine_config = engine_args.create_engine_config()
|
engine_config = engine_args.create_engine_config()
|
||||||
uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
|
uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
|
||||||
@ -173,56 +159,60 @@ async def build_async_engine_client_from_engine_args(
|
|||||||
"and vLLM will properly handle cleanup.")
|
"and vLLM will properly handle cleanup.")
|
||||||
|
|
||||||
# Select random path for IPC.
|
# Select random path for IPC.
|
||||||
rpc_path = get_open_zmq_ipc_path()
|
ipc_path = get_open_zmq_ipc_path()
|
||||||
logger.info("Multiprocessing frontend to use %s for RPC Path.",
|
logger.info("Multiprocessing frontend to use %s for IPC Path.",
|
||||||
rpc_path)
|
ipc_path)
|
||||||
|
|
||||||
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
|
# Start RPCServer in separate process (holds the LLMEngine).
|
||||||
# NOTE: Actually, this is not true yet. We still need to support
|
|
||||||
# embedding models via RPC (see TODO above)
|
|
||||||
rpc_client = AsyncEngineRPCClient(rpc_path)
|
|
||||||
|
|
||||||
# Start RPCServer in separate process (holds the AsyncLLMEngine).
|
|
||||||
context = multiprocessing.get_context("spawn")
|
|
||||||
# the current process might have CUDA context,
|
# the current process might have CUDA context,
|
||||||
# so we need to spawn a new process
|
# so we need to spawn a new process
|
||||||
rpc_server_process = context.Process(
|
context = multiprocessing.get_context("spawn")
|
||||||
target=run_rpc_server,
|
|
||||||
args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path))
|
engine_process = context.Process(target=run_mp_engine,
|
||||||
rpc_server_process.start()
|
args=(engine_args,
|
||||||
logger.info("Started engine process with PID %d",
|
UsageContext.OPENAI_API_SERVER,
|
||||||
rpc_server_process.pid)
|
ipc_path))
|
||||||
|
engine_process.start()
|
||||||
|
logger.info("Started engine process with PID %d", engine_process.pid)
|
||||||
|
|
||||||
|
# Build RPCClient, which conforms to EngineClient Protocol.
|
||||||
|
# NOTE: Actually, this is not true yet. We still need to support
|
||||||
|
# embedding models via RPC (see TODO above)
|
||||||
|
engine_config = engine_args.create_engine_config()
|
||||||
|
mp_engine_client = MQLLMEngineClient(ipc_path, engine_config)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
await rpc_client.setup()
|
await mp_engine_client.setup()
|
||||||
break
|
break
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
if not rpc_server_process.is_alive():
|
if not engine_process.is_alive():
|
||||||
logger.error(
|
logger.error("Engine process died before responding "
|
||||||
"RPCServer process died before responding "
|
"to readiness probe")
|
||||||
"to readiness probe")
|
|
||||||
yield None
|
yield None
|
||||||
return
|
return
|
||||||
|
|
||||||
yield rpc_client # type: ignore[misc]
|
yield mp_engine_client # type: ignore[misc]
|
||||||
finally:
|
finally:
|
||||||
# Ensure rpc server process was terminated
|
# Ensure rpc server process was terminated
|
||||||
rpc_server_process.terminate()
|
engine_process.terminate()
|
||||||
|
|
||||||
# Close all open connections to the backend
|
# Close all open connections to the backend
|
||||||
rpc_client.close()
|
mp_engine_client.close()
|
||||||
|
|
||||||
# Wait for server process to join
|
# Wait for engine process to join
|
||||||
rpc_server_process.join()
|
engine_process.join(4)
|
||||||
|
if engine_process.exitcode is None:
|
||||||
|
# Kill if still alive after the 4 second join timeout above
|
||||||
|
engine_process.kill()
|
||||||
|
|
||||||
# Lazy import for prometheus multiprocessing.
|
# Lazy import for prometheus multiprocessing.
|
||||||
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
|
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
|
||||||
# before prometheus_client is imported.
|
# before prometheus_client is imported.
|
||||||
# See https://prometheus.github.io/client_python/multiprocess/
|
# See https://prometheus.github.io/client_python/multiprocess/
|
||||||
from prometheus_client import multiprocess
|
from prometheus_client import multiprocess
|
||||||
multiprocess.mark_process_dead(rpc_server_process.pid)
|
multiprocess.mark_process_dead(engine_process.pid)
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@ -270,7 +260,7 @@ def embedding(request: Request) -> OpenAIServingEmbedding:
|
|||||||
return request.app.state.openai_serving_embedding
|
return request.app.state.openai_serving_embedding
|
||||||
|
|
||||||
|
|
||||||
def engine_client(request: Request) -> AsyncEngineClient:
|
def engine_client(request: Request) -> EngineClient:
|
||||||
return request.app.state.engine_client
|
return request.app.state.engine_client
|
||||||
|
|
||||||
|
|
||||||
@ -473,7 +463,7 @@ def build_app(args: Namespace) -> FastAPI:
|
|||||||
|
|
||||||
|
|
||||||
def init_app_state(
|
def init_app_state(
|
||||||
async_engine_client: AsyncEngineClient,
|
engine_client: EngineClient,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
state: State,
|
state: State,
|
||||||
args: Namespace,
|
args: Namespace,
|
||||||
@ -488,11 +478,11 @@ def init_app_state(
|
|||||||
else:
|
else:
|
||||||
request_logger = RequestLogger(max_log_len=args.max_log_len)
|
request_logger = RequestLogger(max_log_len=args.max_log_len)
|
||||||
|
|
||||||
state.engine_client = async_engine_client
|
state.engine_client = engine_client
|
||||||
state.log_stats = not args.disable_log_stats
|
state.log_stats = not args.disable_log_stats
|
||||||
|
|
||||||
state.openai_serving_chat = OpenAIServingChat(
|
state.openai_serving_chat = OpenAIServingChat(
|
||||||
async_engine_client,
|
engine_client,
|
||||||
model_config,
|
model_config,
|
||||||
served_model_names,
|
served_model_names,
|
||||||
args.response_role,
|
args.response_role,
|
||||||
@ -504,7 +494,7 @@ def init_app_state(
|
|||||||
enable_auto_tools=args.enable_auto_tool_choice,
|
enable_auto_tools=args.enable_auto_tool_choice,
|
||||||
tool_parser=args.tool_call_parser)
|
tool_parser=args.tool_call_parser)
|
||||||
state.openai_serving_completion = OpenAIServingCompletion(
|
state.openai_serving_completion = OpenAIServingCompletion(
|
||||||
async_engine_client,
|
engine_client,
|
||||||
model_config,
|
model_config,
|
||||||
served_model_names,
|
served_model_names,
|
||||||
lora_modules=args.lora_modules,
|
lora_modules=args.lora_modules,
|
||||||
@ -513,13 +503,13 @@ def init_app_state(
|
|||||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||||
)
|
)
|
||||||
state.openai_serving_embedding = OpenAIServingEmbedding(
|
state.openai_serving_embedding = OpenAIServingEmbedding(
|
||||||
async_engine_client,
|
engine_client,
|
||||||
model_config,
|
model_config,
|
||||||
served_model_names,
|
served_model_names,
|
||||||
request_logger=request_logger,
|
request_logger=request_logger,
|
||||||
)
|
)
|
||||||
state.openai_serving_tokenization = OpenAIServingTokenization(
|
state.openai_serving_tokenization = OpenAIServingTokenization(
|
||||||
async_engine_client,
|
engine_client,
|
||||||
model_config,
|
model_config,
|
||||||
served_model_names,
|
served_model_names,
|
||||||
lora_modules=args.lora_modules,
|
lora_modules=args.lora_modules,
|
||||||
@ -541,21 +531,20 @@ async def run_server(args, **uvicorn_kwargs) -> None:
|
|||||||
|
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
|
||||||
async with build_async_engine_client(args) as async_engine_client:
|
async with build_async_engine_client(args) as engine_client:
|
||||||
# If None, creation of the client failed and we exit.
|
# If None, creation of the client failed and we exit.
|
||||||
if async_engine_client is None:
|
if engine_client is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
app = build_app(args)
|
app = build_app(args)
|
||||||
|
|
||||||
model_config = await async_engine_client.get_model_config()
|
model_config = await engine_client.get_model_config()
|
||||||
init_app_state(async_engine_client, model_config, app.state, args)
|
init_app_state(engine_client, model_config, app.state, args)
|
||||||
|
|
||||||
temp_socket.close()
|
temp_socket.close()
|
||||||
|
|
||||||
shutdown_task = await serve_http(
|
shutdown_task = await serve_http(
|
||||||
app,
|
app,
|
||||||
limit_concurrency=async_engine_client.limit_concurrency,
|
|
||||||
host=args.host,
|
host=args.host,
|
||||||
port=args.port,
|
port=args.port,
|
||||||
log_level=args.uvicorn_log_level,
|
log_level=args.uvicorn_log_level,
|
||||||
|
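The api_server changes above spawn the engine via run_mp_engine in a separate process, probe it until it reports ready, and on teardown terminate it, join with a timeout, and kill it if it has not exited. A condensed sketch of that spawn/probe/teardown lifecycle using only the standard library (the engine function and readiness signal are stand-ins):

import multiprocessing
import time


def run_engine(ready_event) -> None:
    # Stand-in for run_mp_engine: pretend to initialize, then serve forever.
    time.sleep(1.0)
    ready_event.set()
    while True:
        time.sleep(0.5)


if __name__ == "__main__":
    # Spawn (not fork): the parent process may already hold a CUDA context.
    context = multiprocessing.get_context("spawn")
    ready = context.Event()
    engine_process = context.Process(target=run_engine, args=(ready, ))
    engine_process.start()

    try:
        # Readiness probe with a timeout, akin to the client setup() loop.
        while not ready.wait(timeout=1.0):
            if not engine_process.is_alive():
                raise RuntimeError("Engine process died before readiness")
        print("engine ready, serving...")
    finally:
        engine_process.terminate()
        engine_process.join(4)
        if engine_process.exitcode is None:
            # Kill if it did not stop within the grace period.
            engine_process.kill()
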
@ -1,50 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Mapping, Optional, Union
|
|
||||||
|
|
||||||
from vllm.inputs import PromptInputs
|
|
||||||
from vllm.lora.request import LoRARequest
|
|
||||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
|
||||||
from vllm.sampling_params import SamplingParams
|
|
||||||
|
|
||||||
# Success string used for RPC instructions.
|
|
||||||
VLLM_RPC_SUCCESS_STR = "SUCCESS"
|
|
||||||
|
|
||||||
# Minimum value of ZMQ.SOCKET_LIMIT to run mp.
|
|
||||||
VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000
|
|
||||||
|
|
||||||
# HWM is set to Infinity.
|
|
||||||
VLLM_RPC_ZMQ_HWM = 0
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RPCGenerateRequest:
|
|
||||||
inputs: PromptInputs
|
|
||||||
sampling_params: SamplingParams
|
|
||||||
request_id: str
|
|
||||||
lora_request: Optional[LoRARequest] = None
|
|
||||||
trace_headers: Optional[Mapping[str, str]] = None
|
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RPCAbortRequest:
|
|
||||||
request_id: str
|
|
||||||
|
|
||||||
|
|
||||||
class RPCUtilityRequest(Enum):
|
|
||||||
IS_SERVER_READY = 1
|
|
||||||
GET_MODEL_CONFIG = 2
|
|
||||||
GET_DECODING_CONFIG = 3
|
|
||||||
GET_PARALLEL_CONFIG = 4
|
|
||||||
GET_SCHEDULER_CONFIG = 5
|
|
||||||
GET_LORA_CONFIG = 6
|
|
||||||
DO_LOG_STATS = 7
|
|
||||||
IS_SERVER_HEALTHY = 8
|
|
||||||
IS_TRACING_ENABLED = 9
|
|
||||||
START_PROFILE = 10
|
|
||||||
STOP_PROFILE = 11
|
|
||||||
|
|
||||||
|
|
||||||
RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
|
|
||||||
RPCUtilityRequest]
|
|
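The removed module above defined the RPC vocabulary as plain dataclasses plus an enum that travel over the wire with pickle, while callables such as logits processors need cloudpickle (as the new multiprocessing client and engine still do). A tiny illustration of that split, with hypothetical message types:

import pickle
from dataclasses import dataclass
from enum import Enum

import cloudpickle


@dataclass
class DemoGenerateRequest:
    # Hypothetical stand-in for a generate-style RPC message.
    request_id: str
    prompt: str


class DemoUtilityRequest(Enum):
    IS_SERVER_READY = 1


# Plain dataclasses and Enum members round-trip through stdlib pickle,
# and Enum members keep their identity after unpickling.
req = pickle.loads(pickle.dumps(DemoGenerateRequest("req-0", "hello")))
assert req.prompt == "hello"
assert pickle.loads(pickle.dumps(
    DemoUtilityRequest.IS_SERVER_READY)) is DemoUtilityRequest.IS_SERVER_READY

# Callables such as lambdas are not picklable by the stdlib, which is why
# logits processors go through cloudpickle (at some serialization cost).
scale = cloudpickle.loads(cloudpickle.dumps(lambda logits: logits * 0.5))
assert scale(10) == 5.0
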
@ -1,451 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import pickle
|
|
||||||
from contextlib import contextmanager, suppress
|
|
||||||
from typing import Any, AsyncGenerator, Iterator, Mapping, Optional
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
import cloudpickle
|
|
||||||
import zmq
|
|
||||||
import zmq.asyncio
|
|
||||||
from zmq import Frame # type: ignore[attr-defined]
|
|
||||||
from zmq.asyncio import Socket
|
|
||||||
|
|
||||||
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
|
|
||||||
ParallelConfig, SchedulerConfig)
|
|
||||||
# yapf: disable
|
|
||||||
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
|
|
||||||
VLLM_RPC_SOCKET_LIMIT_CUTOFF,
|
|
||||||
VLLM_RPC_SUCCESS_STR,
|
|
||||||
VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
|
|
||||||
RPCGenerateRequest, RPCUtilityRequest)
|
|
||||||
# yapf: enable
|
|
||||||
from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS
|
|
||||||
from vllm.inputs import PromptInputs
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.lora.request import LoRARequest
|
|
||||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
|
||||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
|
||||||
from vllm.sampling_params import SamplingParams
|
|
||||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
# Path used for inprocess proxy.
|
|
||||||
INPROC_PROXY_PATH = f"inproc://{uuid4()}"
|
|
||||||
|
|
||||||
|
|
||||||
class RPCClientClosedError(Exception):
|
|
||||||
"""Exception class raised when the client is used post-close.
|
|
||||||
|
|
||||||
The client can be closed, which closes the ZMQ context. This normally
|
|
||||||
happens on server shutdown. In some cases, methods like abort and
|
|
||||||
do_log_stats will still be called and then try to open a socket, which
|
|
||||||
causes a ZMQError and creates a huge stack trace.
|
|
||||||
So, we throw this error such that we can suppress it.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncEngineRPCClient:
|
|
||||||
"""
|
|
||||||
RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
|
|
||||||
|
|
||||||
The overall design mirrors the Asynchronous Client Server Pattern
|
|
||||||
https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern
|
|
||||||
|
|
||||||
On startup, the RPCClient:
|
|
||||||
- makes DEALER socket (to_rpc_server) that connects to the RPCServer
|
|
||||||
via ipc, which uses unix sockets under the hood
|
|
||||||
(https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html)
|
|
||||||
- makes ROUTER socket (from_api_server) that binds to a random
|
|
||||||
inproc address, which uses memory under the hood
|
|
||||||
(https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html)
|
|
||||||
- runs a proxy in a background asyncio task between
|
|
||||||
from_api_server (ROUTER, inproc) and to_rpc_server (DEALER, ipc)
|
|
||||||
|
|
||||||
Each request handled by the asyncio api_server calls generate():
|
|
||||||
- make a DEALER socket that connects to from_api_server via inproc
|
|
||||||
- send an RPCGenerateRequest to the inproc socket
|
|
||||||
- background proxy forwards the request from inproc -> ipc
|
|
||||||
- RPCServer responds to the request one token at a time over ipc
|
|
||||||
- background proxy forwards the response from ipc -> inproc
|
|
||||||
|
|
||||||
The connection looks like this:
|
|
||||||
DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER
|
|
||||||
|
|
||||||
Message routing is performed via identities that are managed by the
|
|
||||||
ROUTER socket. ROUTER sockets track every connection they have and
|
|
||||||
tell the caller about these. The way they tell the caller is to stick
|
|
||||||
the connection identity in front of each message received. When we
|
|
||||||
send the message via a ROUTER, we first send an identity frame.
|
|
||||||
See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope
|
|
||||||
for more details on connection identities.
|
|
||||||
|
|
||||||
This proxy design enables us to use a single unix socket, which
|
|
||||||
improves performance by avoiding syscalls (~5%) and avoids resource limits
|
|
||||||
such as ulimit, which defaults to 1024 on ubuntu.
|
|
||||||
|
|
||||||
Note: we run set_hwm(0) on each socket, which sets the HWM to inf,
|
|
||||||
which is required to avoid dropping messages under high load.
|
|
||||||
This is generally not advisable. However, since we are in control
|
|
||||||
of both sides of the connection + failure on either side is
|
|
||||||
catastrophic to the overall system health and memory profiling
|
|
||||||
suggests limited memory overhead relative to asyncio, we will
|
|
||||||
proceed for now.
|
|
||||||
|
|
||||||
See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks
|
|
||||||
for more details on high water marks.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, rpc_path: str):
|
|
||||||
self.context = zmq.asyncio.Context()
|
|
||||||
self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS
|
|
||||||
self._errored = False
|
|
||||||
|
|
||||||
# Maximum number of sockets that can be opened (typically 65536).
|
|
||||||
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
|
|
||||||
socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT)
|
|
||||||
assert isinstance(socket_limit, int)
|
|
||||||
if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF:
|
|
||||||
raise ValueError(
|
|
||||||
f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps "
|
|
||||||
"the number of concurrent requests vLLM can process. Launch "
|
|
||||||
"vLLM with --disable-frontend-multiprocessing and open a "
|
|
||||||
"GitHub issue so we can investigate.")
|
|
||||||
|
|
||||||
# We only have 1 ipc connection that uses unix sockets, so
|
|
||||||
# safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. will
|
|
||||||
# not run into ulimit issues)
|
|
||||||
self.context.set(zmq.constants.MAX_SOCKETS, socket_limit)
|
|
||||||
|
|
||||||
# IPC connection to RPC Server (uses unix sockets).
|
|
||||||
self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER)
|
|
||||||
self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM)
|
|
||||||
self.to_rpc_server.bind(rpc_path)
|
|
||||||
|
|
||||||
# In process proxy to RPC Server (uses memory-based messaging).
|
|
||||||
self.from_api_server: Socket = self.context.socket(
|
|
||||||
zmq.constants.ROUTER)
|
|
||||||
self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM)
|
|
||||||
self.from_api_server.bind(INPROC_PROXY_PATH)
|
|
||||||
|
|
||||||
# Asyncio background task for the proxy.
|
|
||||||
self.proxy_in_task = asyncio.create_task(
|
|
||||||
self.run_proxy(self.from_api_server, self.to_rpc_server))
|
|
||||||
self.proxy_out_task = asyncio.create_task(
|
|
||||||
self.run_proxy(self.to_rpc_server, self.from_api_server))
|
|
||||||
|
|
||||||
# Since we open 1 inproc socket per request, we have a hard cap on
|
|
||||||
# the number of requests that can run in vLLM with frontend
|
|
||||||
# multiprocessing. This value is used by uvicorn to launch
|
|
||||||
# with --limit-concurrency to return 503 when server is overloaded.
|
|
||||||
# We need 2 sockets per request - 2:
|
|
||||||
# 1 for generate(), 1 for abort(), do_log_stats(), check_health()
|
|
||||||
self.limit_concurrency = socket_limit // 2 - 2
|
|
||||||
|
|
||||||
async def run_proxy(self, socket_from: Socket, socket_to: Socket):
|
|
||||||
"""Background task that runs a proxy"""
|
|
||||||
while True:
|
|
||||||
frames = await socket_from.recv_multipart(copy=False)
|
|
||||||
await socket_to.send_multipart(frames, copy=False)
|
|
||||||
|
|
||||||
async def setup(self):
|
|
||||||
"""Setup the client before it starts sending server requests."""
|
|
||||||
|
|
||||||
# Wait until server is ready.
|
|
||||||
await self._wait_for_server_rpc()
|
|
||||||
|
|
||||||
# Get the configs.
|
|
||||||
self.model_config = await self._get_model_config_rpc()
|
|
||||||
self.decoding_config = await self._get_decoding_config_rpc()
|
|
||||||
self.tracing_flag = await self._is_tracing_enabled_rpc()
|
|
||||||
|
|
||||||
# Create the tokenizer group.
|
|
||||||
# TODO: refactor OAI server to avoid needing this info.
|
|
||||||
self.tokenizer = init_tokenizer_from_configs(
|
|
||||||
model_config=self.model_config,
|
|
||||||
scheduler_config=(await self._get_scheduler_config_rpc()),
|
|
||||||
parallel_config=(await self._get_parallel_config_rpc()),
|
|
||||||
enable_lora=bool(await self._get_lora_config_rpc()),
|
|
||||||
)
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
"""Destroy the ZeroMQ Context."""
|
|
||||||
# Close all sockets associated with this context and
|
|
||||||
# then terminate the context.
|
|
||||||
self.from_api_server.close()
|
|
||||||
self.to_rpc_server.close()
|
|
||||||
self.context.destroy()
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def to_proxy_socket(self) -> Iterator[Socket]:
|
|
||||||
# Connect to the RPCServer via the proxy.
|
|
||||||
|
|
||||||
# Raise a sensible error if the client was already closed.
|
|
||||||
# This can happen if a server shutdown is triggered but some coroutines
|
|
||||||
# are still running requests.
|
|
||||||
# There should not be a race condition with this check because we don't
|
|
||||||
# yield to the event loop between here and opening the socket.
|
|
||||||
if self.context.closed:
|
|
||||||
raise RPCClientClosedError("The ZMQ client has already shut down")
|
|
||||||
|
|
||||||
# Note that we use DEALER to enable asynchronous communication
|
|
||||||
# to enable streaming.
|
|
||||||
socket = self.context.socket(zmq.constants.DEALER)
|
|
||||||
socket.set_hwm(VLLM_RPC_ZMQ_HWM)
|
|
||||||
try:
|
|
||||||
socket.connect(INPROC_PROXY_PATH)
|
|
||||||
yield socket
|
|
||||||
finally:
|
|
||||||
socket.close(linger=0)
|
|
||||||
|
|
||||||
async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
|
|
||||||
expected_type: Any,
|
|
||||||
error_message: str) -> Any:
|
|
||||||
"""Send an RPC request that is expecting data back."""
|
|
||||||
|
|
||||||
with self.to_proxy_socket() as socket:
|
|
||||||
# Ping RPCServer with a request.
|
|
||||||
await socket.send_multipart((cloudpickle.dumps(request), ),
|
|
||||||
copy=False)
|
|
||||||
|
|
||||||
# Make sure the server responds
|
|
||||||
if await socket.poll(timeout=self._data_timeout) == 0:
|
|
||||||
raise TimeoutError("Server didn't reply within "
|
|
||||||
f"{self._data_timeout} ms")
|
|
||||||
|
|
||||||
# Await the data from the Server.
|
|
||||||
frame = await socket.recv(copy=False)
|
|
||||||
assert isinstance(frame, Frame)
|
|
||||||
data = pickle.loads(frame.buffer)
|
|
||||||
|
|
||||||
if isinstance(data, Exception):
|
|
||||||
# Re-raise exceptions returned by the server
|
|
||||||
raise data
|
|
||||||
|
|
||||||
if not isinstance(data, expected_type):
|
|
||||||
# LoRAConfig can be None.
|
|
||||||
if expected_type == LoRAConfig and data is None:
|
|
||||||
pass
|
|
||||||
elif isinstance(data, Exception):
|
|
||||||
logger.error(error_message)
|
|
||||||
raise data
|
|
||||||
else:
|
|
||||||
raise ValueError(error_message)
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
async def _send_one_way_rpc_request(self,
|
|
||||||
request: RPC_REQUEST_TYPE,
|
|
||||||
error_message: str,
|
|
||||||
socket: Optional[Socket] = None):
|
|
||||||
"""Send one-way RPC request to trigger an action."""
|
|
||||||
|
|
||||||
async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE):
|
|
||||||
|
|
||||||
await socket.send_multipart((cloudpickle.dumps(request), ))
|
|
||||||
|
|
||||||
if await socket.poll(timeout=self._data_timeout) == 0:
|
|
||||||
raise TimeoutError("Server didn't reply within "
|
|
||||||
f"{self._data_timeout} ms")
|
|
||||||
|
|
||||||
frame = await socket.recv(copy=False)
|
|
||||||
assert isinstance(frame, Frame)
|
|
||||||
return pickle.loads(frame.buffer)
|
|
||||||
|
|
||||||
# Make a new socket connection.
|
|
||||||
if socket is None:
|
|
||||||
with self.to_proxy_socket() as socket:
|
|
||||||
response = await do_rpc_call(socket, request)
|
|
||||||
|
|
||||||
# Use existing socket connection.
|
|
||||||
else:
|
|
||||||
response = await do_rpc_call(socket, request)
|
|
||||||
|
|
||||||
if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
|
|
||||||
if isinstance(response, Exception):
|
|
||||||
logger.error(error_message)
|
|
||||||
raise response
|
|
||||||
raise ValueError(error_message)
|
|
||||||
|
|
||||||
async def get_tokenizer(self, lora_request: LoRARequest):
|
|
||||||
return await self.tokenizer.get_lora_tokenizer_async(lora_request)
|
|
||||||
|
|
||||||
async def get_decoding_config(self) -> DecodingConfig:
|
|
||||||
return self.decoding_config
|
|
||||||
|
|
||||||
async def get_model_config(self) -> ModelConfig:
|
|
||||||
return self.model_config
|
|
||||||
|
|
||||||
async def is_tracing_enabled(self) -> bool:
|
|
||||||
return self.tracing_flag
|
|
||||||
|
|
||||||
async def _wait_for_server_rpc(self):
|
|
||||||
"""Wait for the RPCServer to start up."""
|
|
||||||
|
|
||||||
await self._send_one_way_rpc_request(
|
|
||||||
request=RPCUtilityRequest.IS_SERVER_READY,
|
|
||||||
error_message="Unable to start RPC Server")
|
|
||||||
|
|
||||||
async def _get_model_config_rpc(self) -> ModelConfig:
|
|
||||||
"""Get the ModelConfig object from the RPC Server"""
|
|
||||||
|
|
||||||
return await self._send_get_data_rpc_request(
|
|
||||||
RPCUtilityRequest.GET_MODEL_CONFIG,
|
|
||||||
expected_type=ModelConfig,
|
|
||||||
error_message="Could not get ModelConfig from RPC Server")
|
|
||||||
|
|
||||||
async def _get_decoding_config_rpc(self) -> DecodingConfig:
|
|
||||||
"""Get DecodingConfig from the RPCServer"""
|
|
||||||
|
|
||||||
return await self._send_get_data_rpc_request(
|
|
||||||
RPCUtilityRequest.GET_DECODING_CONFIG,
|
|
||||||
expected_type=DecodingConfig,
|
|
||||||
error_message="Could not get DecodingConfig from RPC Server")
|
|
||||||
|
|
||||||
async def _get_parallel_config_rpc(self) -> ParallelConfig:
|
|
||||||
"""Get ParallelConfig from the RPCServer"""
|
|
||||||
|
|
||||||
return await self._send_get_data_rpc_request(
|
|
||||||
RPCUtilityRequest.GET_PARALLEL_CONFIG,
|
|
||||||
expected_type=ParallelConfig,
|
|
||||||
error_message="Could not get ParallelConfig from RPC Server")
|
|
||||||
|
|
||||||
async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
|
|
||||||
"""Get SchedulerConfig from the RPCServer"""
|
|
||||||
|
|
||||||
return await self._send_get_data_rpc_request(
|
|
||||||
RPCUtilityRequest.GET_SCHEDULER_CONFIG,
|
|
||||||
expected_type=SchedulerConfig,
|
|
||||||
error_message="Could not get SchedulerConfig from RPC Server")
|
|
||||||
|
|
||||||
async def _get_lora_config_rpc(self) -> LoRAConfig:
|
|
||||||
"""Get LoRAConfig from the RPCServer"""
|
|
||||||
|
|
||||||
return await self._send_get_data_rpc_request(
|
|
||||||
RPCUtilityRequest.GET_LORA_CONFIG,
|
|
||||||
expected_type=LoRAConfig,
|
|
||||||
error_message="Could not get LoRAConfig from RPC Server")
|
|
||||||
|
|
||||||
async def _is_tracing_enabled_rpc(self) -> bool:
|
|
||||||
"""Get is_tracing_enabled flag from the RPCServer"""
|
|
||||||
|
|
||||||
return await self._send_get_data_rpc_request(
|
|
||||||
RPCUtilityRequest.IS_TRACING_ENABLED,
|
|
||||||
expected_type=bool,
|
|
||||||
error_message="Could not get is_tracing_enabled from RPC Server")
|
|
||||||
|
|
||||||
async def abort(self, request_id: str):
|
|
||||||
"""Send an ABORT_REQUEST signal to the RPC Server"""
|
|
||||||
|
|
||||||
# Suppress timeouts as well.
|
|
||||||
# In cases where the server is busy processing requests and a very
|
|
||||||
# large volume of abort requests arrive, it is likely that the server
|
|
||||||
# will not be able to ack all of them in time. We have seen this when
|
|
||||||
# we abort 20k requests at once while another 2k are processing - many
|
|
||||||
# of them time out, but we see the server successfully abort all of the
|
|
||||||
# requests.
|
|
||||||
# In this case we assume that the server has received or will receive
|
|
||||||
# these abort requests, and ignore the timeout. This prevents a massive
|
|
||||||
# wall of `TimeoutError` stack traces.
|
|
||||||
with suppress(RPCClientClosedError, TimeoutError):
|
|
||||||
await self._send_one_way_rpc_request(
|
|
||||||
request=RPCAbortRequest(request_id),
|
|
||||||
error_message=f"RPCAbortRequest {request_id} failed")
|
|
||||||
|
|
||||||
async def do_log_stats(self):
|
|
||||||
"""Send a DO_LOG_STATS signal to the RPC Server"""
|
|
||||||
with suppress(RPCClientClosedError):
|
|
||||||
await self._send_one_way_rpc_request(
|
|
||||||
request=RPCUtilityRequest.DO_LOG_STATS,
|
|
||||||
error_message="RPCRequest DO_LOG_STATS failed.")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_running(self) -> bool:
|
|
||||||
return not self._errored
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_stopped(self) -> bool:
|
|
||||||
return self._errored
|
|
||||||
|
|
||||||
@property
|
|
||||||
def errored(self) -> bool:
|
|
||||||
return self._errored
|
|
||||||
|
|
||||||
async def generate(
|
|
||||||
self,
|
|
||||||
inputs: PromptInputs,
|
|
||||||
sampling_params: SamplingParams,
|
|
||||||
request_id: str,
|
|
||||||
lora_request: Optional[LoRARequest] = None,
|
|
||||||
trace_headers: Optional[Mapping[str, str]] = None,
|
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
|
||||||
) -> AsyncGenerator[RequestOutput, None]:
|
|
||||||
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
|
|
||||||
|
|
||||||
finished = False
|
|
||||||
try:
|
|
||||||
with self.to_proxy_socket() as socket:
|
|
||||||
# Send RPCGenerateRequest to the RPCServer.
|
|
||||||
await socket.send_multipart((cloudpickle.dumps(
|
|
||||||
RPCGenerateRequest(
|
|
||||||
inputs=inputs,
|
|
||||||
sampling_params=sampling_params,
|
|
||||||
request_id=request_id,
|
|
||||||
lora_request=lora_request,
|
|
||||||
trace_headers=trace_headers,
|
|
||||||
prompt_adapter_request=prompt_adapter_request)), ))
|
|
||||||
|
|
||||||
# Stream back the results from the RPC Server.
|
|
||||||
while not finished:
|
|
||||||
message = await socket.recv(copy=False)
|
|
||||||
assert isinstance(message, Frame)
|
|
||||||
request_output = pickle.loads(message.buffer)
|
|
||||||
|
|
||||||
if isinstance(request_output, Exception):
|
|
||||||
# On exception, check if the server is still healthy
|
|
||||||
# possibly setting the `errored` property.
|
|
||||||
if not self._errored:
|
|
||||||
try:
|
|
||||||
await self.check_health(socket=socket)
|
|
||||||
except Exception as e:
|
|
||||||
self._errored = True
|
|
||||||
logger.exception(repr(e))
|
|
||||||
|
|
||||||
# NB: do before raising here so that the flag is set
|
|
||||||
# by the time the caller receives this exception
|
|
||||||
raise request_output
|
|
||||||
|
|
||||||
finished = request_output.finished
|
|
||||||
yield request_output
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# Request was canceled by the client.
|
|
||||||
if not finished and not self._errored:
|
|
||||||
await self.abort(request_id)
|
|
||||||
|
|
||||||
async def check_health(self, socket: Optional[Socket] = None) -> None:
|
|
||||||
"""Raise if unhealthy"""
|
|
||||||
|
|
||||||
await self._send_one_way_rpc_request(
|
|
||||||
request=RPCUtilityRequest.IS_SERVER_HEALTHY,
|
|
||||||
error_message="Got Unhealthy response from RPC Server",
|
|
||||||
socket=socket)
|
|
||||||
|
|
||||||
async def encode(self, *args,
|
|
||||||
**kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Embeddings not supported with multiprocessing backend")
|
|
||||||
|
|
||||||
async def start_profile(self) -> None:
|
|
||||||
"""Start profiling the engine"""
|
|
||||||
|
|
||||||
await self._send_one_way_rpc_request(
|
|
||||||
request=RPCUtilityRequest.START_PROFILE,
|
|
||||||
error_message="RPCRequest START_PROFILE failed.")
|
|
||||||
|
|
||||||
async def stop_profile(self) -> None:
|
|
||||||
"""Stop profiling the engine"""
|
|
||||||
|
|
||||||
await self._send_one_way_rpc_request(
|
|
||||||
request=RPCUtilityRequest.STOP_PROFILE,
|
|
||||||
error_message="RPCRequest STOP_PROFILE failed.")
|
|
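The removed client's docstring above leans on the ROUTER/DEALER envelope: the ROUTER prefixes each incoming message with the sender's connection identity, and replying with that identity as the first frame routes the answer back to the right per-request DEALER. The sketch below shows that envelope mechanic in blocking pyzmq (paths and payloads are placeholders); the new MQLLMEngine design replaces this proxy with a single PULL/PUSH socket pair per direction.

import threading

import zmq

ctx = zmq.Context()

# ROUTER side: every message arrives prefixed with the sender's identity,
# and sending with that identity as the first frame routes the reply back.
router = ctx.socket(zmq.ROUTER)
router.bind("inproc://demo-proxy")


def serve_one() -> None:
    identity, payload = router.recv_multipart()
    router.send_multipart((identity, b"reply to " + payload))


server = threading.Thread(target=serve_one)
server.start()

# Client side: one DEALER per request, as in the removed AsyncEngineRPCClient.
dealer = ctx.socket(zmq.DEALER)
dealer.connect("inproc://demo-proxy")
dealer.send(b"request-0")
print(dealer.recv())  # b'reply to request-0'

server.join()
dealer.close(linger=0)
router.close(linger=0)
ctx.term()
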
@@ -1,243 +0,0 @@
import asyncio
import pickle
import signal
from typing import Any, Coroutine, Union

import cloudpickle
import uvloop
import zmq
import zmq.asyncio
from typing_extensions import Never
from zmq import Frame  # type: ignore[attr-defined]
from zmq.asyncio import Socket

from vllm import AsyncEngineArgs, AsyncLLMEngine
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig)
from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR,
                                         VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
                                         RPCGenerateRequest, RPCUtilityRequest)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext

logger = init_logger(__name__)

CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
                    SchedulerConfig, LoRAConfig]


class AsyncEngineRPCServer:

    def __init__(self, async_engine_args: AsyncEngineArgs,
                 usage_context: UsageContext, rpc_path: str):
        # Initialize engine first.
        self.engine = AsyncLLMEngine.from_engine_args(
            async_engine_args, usage_context=usage_context)

        # Initialize context.
        self.context = zmq.asyncio.Context()

        # Init socket.
        self.socket: Socket = self.context.socket(zmq.constants.DEALER)
        self.socket.set_hwm(VLLM_RPC_ZMQ_HWM)
        self.socket.connect(rpc_path)

    def cleanup(self):
        """Cleanup all resources."""
        self.socket.close()
        self.context.destroy()
        # Clear the engine reference so that it can be GC'ed.
        del self.engine

    async def get_config(self, identity, request):
        try:
            config: CONFIG_TYPE
            if request == RPCUtilityRequest.GET_MODEL_CONFIG:
                config = await self.engine.get_model_config()
            elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
                config = await self.engine.get_decoding_config()
            elif request == RPCUtilityRequest.GET_LORA_CONFIG:
                config = await self.engine.get_lora_config()
            elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
                config = await self.engine.get_scheduler_config()
            elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
                config = await self.engine.get_parallel_config()
            else:
                raise ValueError("Unknown Config Request: %s", request)

            await self.socket.send_multipart((identity, pickle.dumps(config)),
                                             copy=False)

        except Exception as e:
            await self.socket.send_multipart((identity, pickle.dumps(e)),
                                             copy=False)

    async def is_tracing_enabled(self, identity):
        """Send the is_tracing_enabled flag"""
        tracing_flag = await self.engine.is_tracing_enabled()

        await self.socket.send_multipart(
            (identity, pickle.dumps(tracing_flag)))

    async def do_log_stats(self, identity):
        """Log stats and confirm success."""
        await self.engine.do_log_stats()

        await self.socket.send_multipart(
            (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))

    async def is_server_ready(self, identity):
        """Notify the client that we are ready."""
        await self.socket.send_multipart(
            (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))

    async def abort(self, identity, request: RPCAbortRequest):
        """Abort request and notify the client of success."""
        try:
            # Abort the request in the llm engine.
            await self.engine.abort(request.request_id)
            result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR
        except Exception as e:
            result = e
        await self.socket.send_multipart((identity, pickle.dumps(result)))

    async def generate(self, identity, generate_request: RPCGenerateRequest):
        try:
            results_generator = self.engine.generate(
                generate_request.inputs,
                sampling_params=generate_request.sampling_params,
                request_id=generate_request.request_id,
                lora_request=generate_request.lora_request,
                trace_headers=generate_request.trace_headers,
                prompt_adapter_request=generate_request.prompt_adapter_request)

            async for request_output in results_generator:
                await self.socket.send_multipart(
                    (identity, pickle.dumps(request_output)), copy=False)

        except Exception as e:
            await self.socket.send_multipart((identity, pickle.dumps(e)),
                                             copy=False)

    async def check_health(self, identity):
        try:
            await self.engine.check_health()
            await self.socket.send_multipart(
                (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))

        except Exception as e:
            await self.socket.send_multipart((identity, pickle.dumps(e)),
                                             copy=False)

    async def start_profile(self, identity):
        logger.info("Starting profiler...")
        await self.engine.start_profile()
        logger.info("Profiler started.")

        await self.socket.send_multipart((
            identity,
            pickle.dumps(VLLM_RPC_SUCCESS_STR),
        ))

    async def stop_profile(self, identity):
        logger.info("Stopping profiler...")
        await self.engine.stop_profile()
        logger.info("Profiler stopped.")

        await self.socket.send_multipart((
            identity,
            pickle.dumps(VLLM_RPC_SUCCESS_STR),
        ))

    def _make_handler_coro(self, identity,
                           message: Frame) -> Coroutine[Any, Any, Never]:
        """Route the zmq message to the handler coroutine."""

        request = cloudpickle.loads(message.buffer)

        if isinstance(request, RPCGenerateRequest):
            return self.generate(identity, request)

        elif isinstance(request, RPCAbortRequest):
            return self.abort(identity, request)

        elif isinstance(request, RPCUtilityRequest):
            if request in [
                    RPCUtilityRequest.GET_MODEL_CONFIG,
                    RPCUtilityRequest.GET_PARALLEL_CONFIG,
                    RPCUtilityRequest.GET_DECODING_CONFIG,
                    RPCUtilityRequest.GET_SCHEDULER_CONFIG,
                    RPCUtilityRequest.GET_LORA_CONFIG
            ]:
                return self.get_config(identity, request)
            elif request == RPCUtilityRequest.DO_LOG_STATS:
                return self.do_log_stats(identity)
            elif request == RPCUtilityRequest.IS_SERVER_READY:
                return self.is_server_ready(identity)
            elif request == RPCUtilityRequest.IS_SERVER_HEALTHY:
                return self.check_health(identity)
            elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
                return self.is_tracing_enabled(identity)
            elif request == RPCUtilityRequest.START_PROFILE:
                return self.start_profile(identity)
            elif request == RPCUtilityRequest.STOP_PROFILE:
                return self.stop_profile(identity)
            else:
                raise ValueError(f"Unknown RPCUtilityRequest type: {request}")

        else:
            raise ValueError(f"Unknown RPCRequest type: {request}")

    async def run_server_loop(self):
        """Inner RPC Server Loop"""

        running_tasks = set()
        while True:
            # Wait for a request.
            identity, message = await self.socket.recv_multipart(copy=False)

            # Process the request async.
            task = asyncio.create_task(
                self._make_handler_coro(identity, message))

            # We need to keep around a strong reference to the task,
            # to avoid the task disappearing mid-execution as running tasks
            # can be GC'ed. Below is a common "fire-and-forget" tasks
            # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
            running_tasks.add(task)
            task.add_done_callback(running_tasks.discard)


async def run_server(server: AsyncEngineRPCServer):
    # Put the server task into the asyncio loop.
    loop = asyncio.get_running_loop()
    server_task = loop.create_task(server.run_server_loop())

    # Interruption handling.
    def signal_handler() -> None:
        # Kill the server on interrupt / terminate
        server_task.cancel()

    loop.add_signal_handler(signal.SIGINT, signal_handler)
    loop.add_signal_handler(signal.SIGTERM, signal_handler)

    try:
        await server_task
    except asyncio.CancelledError:
        logger.info("vLLM ZMQ RPC Server was interrupted.")
    finally:
        # Clean up all resources.
        server.cleanup()


def run_rpc_server(async_engine_args: AsyncEngineArgs,
                   usage_context: UsageContext, rpc_path: str):

    def signal_handler(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("AsyncEngineRPCServer terminated")

    signal.signal(signal.SIGTERM, signal_handler)

    server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path)
    uvloop.run(run_server(server))
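`run_server_loop()` in the removed server keeps a strong reference to each handler task, as its comment notes: tasks created with `asyncio.create_task` can be garbage-collected mid-execution if nothing holds on to them. A minimal, generic sketch of that fire-and-forget idiom (nothing vLLM-specific; `handle` and `serve_forever` are made-up names):

import asyncio


async def handle(i: int) -> None:
    await asyncio.sleep(0.01)
    print(f"handled request {i}")


async def serve_forever(num_requests: int) -> None:
    running_tasks: set[asyncio.Task] = set()
    for i in range(num_requests):
        task = asyncio.create_task(handle(i))
        running_tasks.add(task)                        # strong reference
        task.add_done_callback(running_tasks.discard)  # drop it when done
    # Wait for any stragglers before shutting down.
    await asyncio.gather(*running_tasks)


asyncio.run(serve_forever(3))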
@@ -9,7 +9,7 @@ from typing import Union
 from fastapi import Request
 
 from vllm.config import ModelConfig
-from vllm.engine.protocol import AsyncEngineClient
+from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
                                          apply_hf_chat_template,
                                          apply_mistral_chat_template,
@@ -45,7 +45,7 @@ logger = init_logger(__name__)
 class OpenAIServingChat(OpenAIServing):
 
     def __init__(self,
-                 async_engine_client: AsyncEngineClient,
+                 engine_client: EngineClient,
                  model_config: ModelConfig,
                  served_model_names: List[str],
                  response_role: str,
@@ -57,7 +57,7 @@ class OpenAIServingChat(OpenAIServing):
                  return_tokens_as_token_ids: bool = False,
                  enable_auto_tools: bool = False,
                  tool_parser: Optional[str] = None):
-        super().__init__(async_engine_client=async_engine_client,
+        super().__init__(engine_client=engine_client,
                          model_config=model_config,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules,
@@ -105,6 +105,12 @@ class OpenAIServingChat(OpenAIServing):
             logger.error("Error with model %s", error_check_ret)
             return error_check_ret
 
+        # If the engine is dead, raise the engine's DEAD_ERROR.
+        # This is required for the streaming case, where we return a
+        # success status before we actually start generating text :).
+        if self.engine_client.errored:
+            raise self.engine_client.dead_error
+
         try:
             (
                 lora_request,
@@ -112,8 +118,7 @@ class OpenAIServingChat(OpenAIServing):
             ) = self._maybe_get_adapters(request)
 
             model_config = self.model_config
-            tokenizer = await self.async_engine_client.get_tokenizer(
-                lora_request)
+            tokenizer = await self.engine_client.get_tokenizer(lora_request)
 
             conversation, mm_data_future = parse_chat_messages_futures(
                 request.messages, model_config, tokenizer)
@@ -207,8 +212,8 @@ class OpenAIServingChat(OpenAIServing):
             if mm_data is not None:
                 engine_inputs["multi_modal_data"] = mm_data
 
-            is_tracing_enabled = (
-                await self.async_engine_client.is_tracing_enabled())
+            is_tracing_enabled = (await
+                                  self.engine_client.is_tracing_enabled())
             trace_headers = None
             if is_tracing_enabled and raw_request:
                 trace_headers = extract_trace_headers(raw_request.headers)
@@ -216,7 +221,7 @@ class OpenAIServingChat(OpenAIServing):
                   and contains_trace_headers(raw_request.headers)):
                 log_tracing_disabled_warning()
 
-            result_generator = self.async_engine_client.generate(
+            result_generator = self.engine_client.generate(
                 engine_inputs,
                 sampling_params,
                 request_id,
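The added `errored`/`dead_error` guard exists because a streaming endpoint commits to a success status before generation starts, so a dead engine has to be surfaced up front. A hedged FastAPI sketch of the same idea (`FakeClient` and the `/v1/demo` route are invented for illustration and are not vLLM's API):

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


class FakeClient:
    # Stand-in for an engine client that can flip to an errored state.
    errored = False
    dead_error = RuntimeError("engine is dead")

    async def generate(self, prompt: str):
        for word in prompt.split():
            yield word + " "


client = FakeClient()


@app.post("/v1/demo")
async def demo(prompt: str):
    # Raise early so the caller gets an error status instead of a
    # stream that breaks midway through.
    if client.errored:
        raise client.dead_error

    async def stream():
        async for chunk in client.generate(prompt):
            yield chunk

    return StreamingResponse(stream(), media_type="text/plain")

# Serve with e.g.:  uvicorn demo_module:app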
@@ -8,7 +8,7 @@ from typing import Tuple, Union, cast
 from fastapi import Request
 
 from vllm.config import ModelConfig
-from vllm.engine.protocol import AsyncEngineClient
+from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -43,7 +43,7 @@ class OpenAIServingCompletion(OpenAIServing):
 
     def __init__(
         self,
-        async_engine_client: AsyncEngineClient,
+        engine_client: EngineClient,
         model_config: ModelConfig,
         served_model_names: List[str],
         *,
@@ -52,7 +52,7 @@ class OpenAIServingCompletion(OpenAIServing):
         request_logger: Optional[RequestLogger],
         return_tokens_as_token_ids: bool = False,
     ):
-        super().__init__(async_engine_client=async_engine_client,
+        super().__init__(engine_client=engine_client,
                          model_config=model_config,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules,
@@ -78,6 +78,12 @@ class OpenAIServingCompletion(OpenAIServing):
         if error_check_ret is not None:
             return error_check_ret
 
+        # If the engine is dead, raise the engine's DEAD_ERROR.
+        # This is required for the streaming case, where we return a
+        # success status before we actually start generating text :).
+        if self.engine_client.errored:
+            raise self.engine_client.dead_error
+
         # Return error for unsupported features.
         if request.suffix is not None:
             return self.create_error_response(
@@ -95,8 +101,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 prompt_adapter_request,
             ) = self._maybe_get_adapters(request)
 
-            tokenizer = await self.async_engine_client.get_tokenizer(
-                lora_request)
+            tokenizer = await self.engine_client.get_tokenizer(lora_request)
 
             guided_decode_logits_processor = (
                 await self._guided_decode_logits_processor(request, tokenizer))
@@ -124,8 +129,8 @@ class OpenAIServingCompletion(OpenAIServing):
                                  lora_request=lora_request,
                                  prompt_adapter_request=prompt_adapter_request)
 
-                is_tracing_enabled = (
-                    await self.async_engine_client.is_tracing_enabled())
+                is_tracing_enabled = (await
+                                      self.engine_client.is_tracing_enabled())
                 trace_headers = None
                 if is_tracing_enabled:
                     trace_headers = extract_trace_headers(raw_request.headers)
@@ -133,7 +138,7 @@ class OpenAIServingCompletion(OpenAIServing):
                                               raw_request.headers):
                     log_tracing_disabled_warning()
 
-                generator = self.async_engine_client.generate(
+                generator = self.engine_client.generate(
                     {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
                     sampling_params,
                     request_id_item,
@@ -8,7 +8,7 @@ from fastapi import Request
 from typing_extensions import assert_never
 
 from vllm.config import ModelConfig
-from vllm.engine.protocol import AsyncEngineClient
+from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
                                               EmbeddingResponse,
@@ -71,13 +71,13 @@ class OpenAIServingEmbedding(OpenAIServing):
 
     def __init__(
         self,
-        async_engine_client: AsyncEngineClient,
+        engine_client: EngineClient,
         model_config: ModelConfig,
         served_model_names: List[str],
         *,
         request_logger: Optional[RequestLogger],
     ):
-        super().__init__(async_engine_client=async_engine_client,
+        super().__init__(engine_client=engine_client,
                          model_config=model_config,
                          served_model_names=served_model_names,
                          lora_modules=None,
@@ -118,8 +118,7 @@ class OpenAIServingEmbedding(OpenAIServing):
                 prompt_adapter_request,
             ) = self._maybe_get_adapters(request)
 
-            tokenizer = await self.async_engine_client.get_tokenizer(
-                lora_request)
+            tokenizer = await self.engine_client.get_tokenizer(lora_request)
 
             pooling_params = request.to_pooling_params()
 
@@ -144,7 +143,7 @@ class OpenAIServingEmbedding(OpenAIServing):
                         "Prompt adapter is not supported "
                         "for embedding models")
 
-                generator = self.async_engine_client.encode(
+                generator = self.engine_client.encode(
                     {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
                     pooling_params,
                     request_id_item,
@@ -8,7 +8,7 @@ from pydantic import Field
 from typing_extensions import Annotated
 
 from vllm.config import ModelConfig
-from vllm.engine.protocol import AsyncEngineClient
+from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -64,7 +64,7 @@ class OpenAIServing:
 
     def __init__(
         self,
-        async_engine_client: AsyncEngineClient,
+        engine_client: EngineClient,
         model_config: ModelConfig,
         served_model_names: List[str],
         *,
@@ -75,7 +75,7 @@ class OpenAIServing:
     ):
         super().__init__()
 
-        self.async_engine_client = async_engine_client
+        self.engine_client = engine_client
         self.model_config = model_config
         self.max_model_len = model_config.max_model_len
 
@@ -159,7 +159,7 @@ class OpenAIServing:
     async def _guided_decode_logits_processor(
            self, request: Union[ChatCompletionRequest, CompletionRequest],
            tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
-        decoding_config = await self.async_engine_client.get_decoding_config()
+        decoding_config = await self.engine_client.get_decoding_config()
        guided_decoding_backend = request.guided_decoding_backend \
            or decoding_config.guided_decoding_backend
        return await get_guided_decoding_logits_processor(
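The last hunk above resolves the guided-decoding backend by letting a per-request value override the server-wide decoding config. A tiny sketch of that fallback pattern (the dataclasses are stand-ins, not vLLM's config types):

from dataclasses import dataclass
from typing import Optional


@dataclass
class DecodingConfig:
    guided_decoding_backend: str = "outlines"


@dataclass
class DemoRequest:
    guided_decoding_backend: Optional[str] = None


def pick_backend(request: DemoRequest, decoding_config: DecodingConfig) -> str:
    # Same shape as:
    #   request.guided_decoding_backend or decoding_config.guided_decoding_backend
    return (request.guided_decoding_backend
            or decoding_config.guided_decoding_backend)


assert pick_backend(DemoRequest(), DecodingConfig()) == "outlines"
assert pick_backend(DemoRequest("lm-format-enforcer"),
                    DecodingConfig()) == "lm-format-enforcer"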
@@ -1,7 +1,7 @@
 from typing import List, Optional, Union
 
 from vllm.config import ModelConfig
-from vllm.engine.protocol import AsyncEngineClient
+from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
                                          apply_mistral_chat_template,
                                          load_chat_template,
@@ -29,7 +29,7 @@ class OpenAIServingTokenization(OpenAIServing):
 
     def __init__(
         self,
-        async_engine_client: AsyncEngineClient,
+        engine_client: EngineClient,
         model_config: ModelConfig,
         served_model_names: List[str],
         *,
@@ -37,7 +37,7 @@ class OpenAIServingTokenization(OpenAIServing):
         request_logger: Optional[RequestLogger],
         chat_template: Optional[str],
     ):
-        super().__init__(async_engine_client=async_engine_client,
+        super().__init__(engine_client=engine_client,
                          model_config=model_config,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules,
@@ -66,7 +66,7 @@ class OpenAIServingTokenization(OpenAIServing):
             prompt_adapter_request,
         ) = self._maybe_get_adapters(request)
 
-        tokenizer = await self.async_engine_client.get_tokenizer(lora_request)
+        tokenizer = await self.engine_client.get_tokenizer(lora_request)
 
         prompt: Union[str, List[int]]
         if isinstance(request, TokenizeChatRequest):
@@ -132,7 +132,7 @@ class OpenAIServingTokenization(OpenAIServing):
             prompt_adapter_request,
         ) = self._maybe_get_adapters(request)
 
-        tokenizer = await self.async_engine_client.get_tokenizer(lora_request)
+        tokenizer = await self.engine_client.get_tokenizer(lora_request)
 
         self._log_inputs(request_id,
                          request.tokens,
@@ -57,7 +57,7 @@ if TYPE_CHECKING:
     VERBOSE: bool = False
     VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
     VLLM_TEST_FORCE_FP8_MARLIN: bool = False
-    VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000
+    VLLM_RPC_TIMEOUT: int = 10000  # ms
     VLLM_PLUGINS: Optional[List[str]] = None
     VLLM_TORCH_PROFILER_DIR: Optional[str] = None
     VLLM_USE_TRITON_AWQ: bool = False
@@ -393,8 +393,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
 
     # Time in ms for the zmq client to wait for a response from the backend
     # server for simple data operations
-    "VLLM_RPC_GET_DATA_TIMEOUT_MS":
-    lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")),
+    "VLLM_RPC_TIMEOUT":
+    lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")),
 
     # a list of plugin names to load, separated by commas.
     # if this is not set, it means all plugins will be loaded
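The renamed knob is read lazily through a callable, so changes to the environment are picked up on the next lookup rather than at import time. A toy version of that pattern (not `vllm.envs` itself):

import os
from typing import Any, Callable, Dict

environment_variables: Dict[str, Callable[[], Any]] = {
    # Milliseconds the zmq client waits for a backend response.
    "VLLM_RPC_TIMEOUT": lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")),
}

print(environment_variables["VLLM_RPC_TIMEOUT"]())  # 10000 (default, in ms)
os.environ["VLLM_RPC_TIMEOUT"] = "60000"
print(environment_variables["VLLM_RPC_TIMEOUT"]())  # 60000, picked up lazily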
@@ -106,6 +106,7 @@ class CPUExecutor(ExecutorBase):
             )) for rank in range(1, world_size)
         ]
 
+        self.worker_monitor = None
         if world_size != 1 or is_async:
             if is_async:
                 async_worker_list = self.workers + [self.driver_worker]
@@ -168,6 +168,8 @@ class ProcessWorkerWrapper:
         self.tasks[task_id] = future
         try:
             self._task_queue.put((task_id, method, args, kwargs))
+        except SystemExit:
+            raise
         except BaseException as e:
             del self.tasks[task_id]
             raise ChildProcessError("worker died") from e
@@ -222,6 +224,8 @@ def _run_worker_process(
         try:
             executor = getattr(worker, method)
             output = executor(*args, **kwargs)
+        except SystemExit:
+            raise
         except KeyboardInterrupt:
             break
         except BaseException as e:
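The new `except SystemExit: raise` clauses keep interpreter shutdown from being swallowed by the broad `except BaseException` that follows, which would otherwise misreport an intentional exit as a worker failure. A self-contained sketch of that ordering (the names are illustrative, not vLLM's):

def submit(task_queue_put, payload):
    try:
        task_queue_put(payload)
    except SystemExit:
        # Let interpreter shutdown propagate untouched.
        raise
    except BaseException as e:
        raise ChildProcessError("worker died") from e


def flaky_put(payload):
    raise SystemExit(0)  # simulate a worker process that is shutting down


try:
    submit(flaky_put, ("task-1", "execute_model"))
except SystemExit:
    print("shutdown propagated, not wrapped as ChildProcessError")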