[Core][Bugfix][Perf] Introduce MQLLMEngine to avoid asyncio OH (#8157)

Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
This commit is contained in:
Alexander Matveev 2024-09-18 09:56:58 -04:00 committed by GitHub
parent 9d104b5beb
commit 7c7714d856
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
36 changed files with 1464 additions and 1169 deletions

View File

@ -43,13 +43,15 @@ steps:
fast_check: true
source_file_dependencies:
- vllm/
- tests/mq_llm_engine
- tests/async_engine
- tests/test_inputs
- tests/multimodal
- tests/test_utils
- tests/worker
commands:
- pytest -v -s async_engine # Async Engine
- pytest -v -s mq_llm_engine # MQLLMEngine
- pytest -v -s async_engine # AsyncLLMEngine
- NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
- pytest -v -s test_inputs.py
- pytest -v -s multimodal

View File

@ -21,8 +21,8 @@ Traces can be visualized using https://ui.perfetto.dev/.
.. tip::
To stop the profiler - it flushes out all the profile trace files to the directory. This takes time, for example for about 100 requests worth of data for a llama 70b, it takes about 10 minutes to flush out on a H100.
Set the env variable VLLM_RPC_GET_DATA_TIMEOUT_MS to a big number before you start the server. Say something like 30 minutes.
``export VLLM_RPC_GET_DATA_TIMEOUT_MS=1800000``
Set the env variable VLLM_RPC_TIMEOUT to a big number before you start the server. Say something like 30 minutes.
``export VLLM_RPC_TIMEOUT=1800000``
Example commands and usage:
===========================

View File

@ -1,106 +0,0 @@
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
from ..utils import VLLM_PATH, RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m"
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()
@pytest.fixture(scope="module")
def server():
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"float16",
"--max-model-len",
"2048",
"--enforce-eager",
"--chat-template",
str(chatml_jinja_path),
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
async def test_check_models(client: openai.AsyncOpenAI):
models = await client.models.list()
models = models.data
served_model = models[0]
assert served_model.id == MODEL_NAME
assert all(model.root == MODEL_NAME for model in models)
@pytest.mark.asyncio
async def test_single_completion(client: openai.AsyncOpenAI):
completion = await client.completions.create(model=MODEL_NAME,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)
assert completion.id is not None
assert len(completion.choices) == 1
assert len(completion.choices[0].text) >= 5
assert completion.choices[0].finish_reason == "length"
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11)
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
assert len(completion.choices[0].text) >= 5
@pytest.mark.asyncio
async def test_single_chat_session(client: openai.AsyncOpenAI):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "what is 1+1?"
}]
# test single completion
chat_completion = await client.chat.completions.create(model=MODEL_NAME,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=5)
assert chat_completion.id is not None
assert len(chat_completion.choices) == 1
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=55, total_tokens=65)
message = choice.message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
messages.append({"role": "assistant", "content": message.content})
# test multi-turn dialogue
messages.append({"role": "user", "content": "express your result in json"})
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=10,
)
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0

View File

@ -1,120 +0,0 @@
import asyncio
import tempfile
import unittest
import unittest.mock
import uuid
import pytest
import pytest_asyncio
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.rpc.client import (AsyncEngineRPCClient,
RPCClientClosedError)
from vllm.entrypoints.openai.rpc.server import AsyncEngineRPCServer
@pytest.fixture(scope="function")
def tmp_socket():
with tempfile.TemporaryDirectory() as td:
yield f"ipc://{td}/{uuid.uuid4()}"
@pytest_asyncio.fixture(scope="function")
async def dummy_server(tmp_socket, monkeypatch):
dummy_engine = unittest.mock.AsyncMock()
def dummy_engine_builder(*args, **kwargs):
return dummy_engine
with monkeypatch.context() as m:
m.setattr(AsyncLLMEngine, "from_engine_args", dummy_engine_builder)
server = AsyncEngineRPCServer(None, None, rpc_path=tmp_socket)
loop = asyncio.get_running_loop()
server_task = loop.create_task(server.run_server_loop())
try:
yield server
finally:
server_task.cancel()
server.cleanup()
@pytest_asyncio.fixture(scope="function")
async def client(tmp_socket):
client = AsyncEngineRPCClient(rpc_path=tmp_socket)
# Sanity check: the server is connected
await client._wait_for_server_rpc()
try:
yield client
finally:
client.close()
@pytest.mark.asyncio
async def test_client_data_methods_use_timeouts(monkeypatch, dummy_server,
client: AsyncEngineRPCClient):
with monkeypatch.context() as m:
# Make the server _not_ reply with a model config
m.setattr(dummy_server, "get_config", lambda x: None)
m.setattr(client, "_data_timeout", 10)
# And ensure the task completes anyway
# (client.setup() invokes server.get_config())
client_task = asyncio.get_running_loop().create_task(client.setup())
with pytest.raises(TimeoutError, match="Server didn't reply within"):
await asyncio.wait_for(client_task, timeout=0.05)
@pytest.mark.asyncio
async def test_client_aborts_use_timeouts(monkeypatch, dummy_server,
client: AsyncEngineRPCClient):
with monkeypatch.context() as m:
# Hang all abort requests
m.setattr(dummy_server, "abort", lambda x: None)
m.setattr(client, "_data_timeout", 10)
# The client should suppress timeouts on `abort`s
# and return normally, assuming the server will eventually
# abort the request.
client_task = asyncio.get_running_loop().create_task(
client.abort("test request id"))
await asyncio.wait_for(client_task, timeout=0.05)
@pytest.mark.asyncio
async def test_client_data_methods_reraise_exceptions(
monkeypatch, dummy_server, client: AsyncEngineRPCClient):
with monkeypatch.context() as m:
# Make the server raise some random exception
exception = RuntimeError("Client test exception")
def raiser():
raise exception
m.setattr(dummy_server.engine, "get_model_config", raiser)
m.setattr(client, "_data_timeout", 10)
client_task = asyncio.get_running_loop().create_task(client.setup())
# And ensure the task completes, raising the exception
with pytest.raises(RuntimeError, match=str(exception)):
await asyncio.wait_for(client_task, timeout=0.05)
@pytest.mark.asyncio
async def test_client_errors_after_closing(monkeypatch, dummy_server,
client: AsyncEngineRPCClient):
client.close()
# Healthchecks and generate requests will fail with explicit errors
with pytest.raises(RPCClientClosedError):
await client.check_health()
with pytest.raises(RPCClientClosedError):
async for _ in client.generate(None, None, None):
pass
# But no-ops like aborting will pass
await client.abort("test-request-id")
await client.do_log_stats()

View File

@ -18,38 +18,32 @@ TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03
EXPECTED_VALUE = 0.58
DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
MORE_ARGS_LIST = [["--enable-chunked-prefill"], ["--num-scheduler-steps", "8"]]
@pytest.fixture(scope="module")
def server():
args = [
"--max-model-len", "4096", "--enable-chunked-prefill",
"--disable-log-requests", "--enforce-eager"
]
@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
def test_lm_eval_accuracy(more_args):
args = list(DEFAULT_ARGS)
args.extend(more_args)
print(f"Running with: {args}")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
url = f"{remote_server.url_for('v1')}/completions"
model_args = (
f"model={MODEL_NAME},"
f"base_url={url},"
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
@pytest.fixture(scope="module")
def server_data(server):
return {
"url": f"{server.url_for('v1')}/completions",
}
results = lm_eval.simple_evaluate(
model="local-completions",
model_args=model_args,
tasks=TASK,
)
def test_lm_eval_accuracy(server_data):
model_args = (f"model={MODEL_NAME},"
f"base_url={server_data['url']},"
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
results = lm_eval.simple_evaluate(
model="local-completions",
model_args=model_args,
tasks=TASK,
)
measured_value = results["results"][TASK][FILTER]
assert (measured_value - RTOL < EXPECTED_VALUE
and measured_value + RTOL > EXPECTED_VALUE
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
measured_value = results["results"][TASK][FILTER]
assert (measured_value - RTOL < EXPECTED_VALUE
and measured_value + RTOL > EXPECTED_VALUE
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"

View File

@ -5,7 +5,7 @@ from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.transformers_utils.tokenizer import get_tokenizer
from ..utils import VLLM_PATH
from ...utils import VLLM_PATH
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()

View File

@ -1,40 +0,0 @@
import time
import pytest
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser
@pytest.mark.asyncio
async def test_mp_crash_detection():
parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
args = parser.parse_args([])
# use an invalid tensor_parallel_size to trigger the
# error in the server
args.tensor_parallel_size = 65536
start = time.perf_counter()
async with build_async_engine_client(args):
pass
end = time.perf_counter()
assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s "
"if there is an error in the startup.")
@pytest.mark.asyncio
async def test_mp_cuda_init():
# it should not crash, when cuda is initialized
# in the API server process
import torch
torch.cuda.init()
parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
args = parser.parse_args([])
async with build_async_engine_client(args):
pass

View File

@ -4,7 +4,7 @@ from dataclasses import dataclass
from unittest.mock import MagicMock
from vllm.config import MultiModalConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.transformers_utils.tokenizer import get_tokenizer
@ -52,8 +52,9 @@ def test_async_serving_chat_init():
def test_serving_chat_should_set_correct_max_tokens():
mock_engine = MagicMock(spec=AsyncLLMEngine)
mock_engine = MagicMock(spec=MQLLMEngineClient)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
serving_chat = OpenAIServingChat(mock_engine,
MockModelConfig(),

View File

@ -4,7 +4,7 @@ from unittest.mock import MagicMock
import pytest
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (ErrorResponse,
LoadLoraAdapterRequest,
UnloadLoraAdapterRequest)
@ -18,7 +18,7 @@ LORA_UNLOADING_SUCCESS_MESSAGE = (
async def _async_serving_engine_init():
mock_engine_client = MagicMock(spec=AsyncEngineClient)
mock_engine_client = MagicMock(spec=EngineClient)
mock_model_config = MagicMock(spec=ModelConfig)
# Set the max_model_len attribute to avoid missing attribute
mock_model_config.max_model_len = 2048

View File

@ -44,5 +44,5 @@ async def test_shutdown_on_engine_failure(tmp_path):
prompt="Hello, my name is")
# Now the server should shut down
return_code = remote_server.proc.wait(timeout=3)
return_code = remote_server.proc.wait(timeout=8)
assert return_code is not None

View File

@ -0,0 +1,67 @@
"""Test that aborting is handled properly."""
import asyncio
import tempfile
import uuid
import pytest
from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate
from vllm.engine.arg_utils import AsyncEngineArgs
MODEL = "google/gemma-1.1-2b-it"
ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
RAISED_ERROR = KeyError
RAISED_VALUE = "foo"
EXPECTED_TOKENS = 250
@pytest.fixture(scope="function")
def tmp_socket():
with tempfile.TemporaryDirectory() as td:
yield f"ipc://{td}/{uuid.uuid4()}"
@pytest.mark.asyncio
async def test_abort(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket) as engine:
client = await engine.make_client()
request_id_to_be_aborted = "request-aborted"
request_ids_a = [f"request-a-{idx}" for idx in range(10)]
request_ids_b = [f"request-b-{idx}" for idx in range(10)]
# Requests started before one to be aborted.
tasks = []
for request_id in request_ids_a:
tasks.append(
asyncio.create_task(
generate(client, request_id, EXPECTED_TOKENS)))
# Aborted.
task_aborted = asyncio.create_task(
generate(client, request_id_to_be_aborted, EXPECTED_TOKENS))
# Requests started after one to be aborted.
for request_id in request_ids_b:
tasks.append(
asyncio.create_task(
generate(client, request_id, EXPECTED_TOKENS)))
# Actually abort.
await asyncio.sleep(0.5)
await client.abort(request_id_to_be_aborted)
# Confirm that we got all the EXPECTED tokens from the requests.
for task in tasks:
count, request_id = await task
assert count == EXPECTED_TOKENS, (
f"{request_id} generated only {count} tokens")
# Cancel task (this will hang indefinitely if not).
task_aborted.cancel()
# Shutdown.
client.close()

View File

@ -0,0 +1,244 @@
"""Test that various errors are handled properly."""
import asyncio
import tempfile
import time
import uuid
from unittest.mock import Mock
import pytest
from tests.mq_llm_engine.utils import RemoteMQLLMEngine
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.multiprocessing import MQEngineDeadError
from vllm.engine.multiprocessing.engine import MQLLMEngine
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.lora.request import LoRARequest
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
MODEL = "google/gemma-1.1-2b-it"
ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
RAISED_ERROR = KeyError
RAISED_VALUE = "foo"
@pytest.fixture(scope="function")
def tmp_socket():
with tempfile.TemporaryDirectory() as td:
yield f"ipc://{td}/{uuid.uuid4()}"
def run_with_evil_forward(engine_args: AsyncEngineArgs, ipc_path: str):
# Make engine.
engine = MQLLMEngine.from_engine_args(
engine_args=engine_args,
usage_context=UsageContext.UNKNOWN_CONTEXT,
ipc_path=ipc_path)
# Raise error during first forward pass.
engine.engine.model_executor.execute_model = Mock(
side_effect=RAISED_ERROR(RAISED_VALUE))
# Run engine.
engine.start()
@pytest.mark.asyncio
async def test_evil_forward(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket,
run_fn=run_with_evil_forward) as engine:
client = await engine.make_client()
# Server should be healthy after initial probe.
await asyncio.sleep(2.0)
await client.check_health()
# Throws an error in first forward pass.
with pytest.raises(RAISED_ERROR):
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
assert client.errored
# Engine is errored, should get ENGINE_DEAD_ERROR.
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
assert client.errored
await asyncio.sleep(1.0)
with pytest.raises(RAISED_ERROR):
await client.check_health()
assert client.errored
# Shutdown.
client.close()
def run_with_evil_model_executor_health(engine_args: AsyncEngineArgs,
ipc_path: str):
# Make engine.
engine = MQLLMEngine.from_engine_args(
engine_args=engine_args,
usage_context=UsageContext.UNKNOWN_CONTEXT,
ipc_path=ipc_path)
# Raise error during first forward pass.
engine.engine.model_executor.check_health = Mock(side_effect=RAISED_ERROR)
# Run engine.
engine.start()
@pytest.mark.asyncio
async def test_failed_health_check(tmp_socket):
with RemoteMQLLMEngine(
engine_args=ENGINE_ARGS,
ipc_path=tmp_socket,
run_fn=run_with_evil_model_executor_health) as engine:
client = await engine.make_client()
assert client.is_running
# Health probe should throw RAISED_ERROR.
await asyncio.sleep(15.)
with pytest.raises(RAISED_ERROR):
await client.check_health()
assert client.errored
# Generate call should throw ENGINE_DEAD_ERROR
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
client.close()
def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str):
# Make engine.
engine = MQLLMEngine.from_engine_args(
engine_args=engine_args,
usage_context=UsageContext.UNKNOWN_CONTEXT,
ipc_path=ipc_path)
# Raise error during abort call.
engine.engine.abort_request = Mock(side_effect=RAISED_ERROR)
# Run engine.
engine.start()
@pytest.mark.asyncio
async def test_failed_abort(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket,
run_fn=run_with_evil_abort) as engine:
client = await engine.make_client()
assert client.is_running
# Firsh check health should work.
await client.check_health()
# Trigger an abort on the client side.
async def bad_abort_after_2s():
await asyncio.sleep(2.0)
await client.abort(request_id="foo")
# Trigger an abort in 2s from now.
abort_task = asyncio.create_task(bad_abort_after_2s())
# Exception in abort() will happen during this generation.
# This will kill the engine and should return ENGINE_DEAD_ERROR
# with reference to the original KeyError("foo")
with pytest.raises(MQEngineDeadError) as execinfo:
async for _ in client.generate(
inputs="Hello my name is",
sampling_params=SamplingParams(max_tokens=2000),
request_id=uuid.uuid4()):
pass
assert "KeyError" in repr(execinfo.value)
assert client.errored
await abort_task
# This should raise the original error.
with pytest.raises(RAISED_ERROR):
await client.check_health()
client.close()
@pytest.mark.asyncio
async def test_bad_request(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket) as engine:
client = await engine.make_client()
# Invalid request should fail, but not crash the server.
with pytest.raises(ValueError):
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-1",
lora_request=LoRARequest(
"invalid-lora", 1,
"invalid-path")):
pass
# This request should be okay.
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-2"):
pass
# Shutdown.
client.close()
@pytest.mark.asyncio
async def test_mp_crash_detection(monkeypatch):
parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
args = parser.parse_args([])
# When LLMEngine is loaded, it will crash.
def mock_init():
raise ValueError
monkeypatch.setattr(LLMEngine, "__init__", mock_init)
start = time.perf_counter()
async with build_async_engine_client(args):
pass
end = time.perf_counter()
assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s "
"if there is an error in the startup.")
@pytest.mark.asyncio
async def test_mp_cuda_init():
# it should not crash, when cuda is initialized
# in the API server process
import torch
torch.cuda.init()
parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
parser = make_arg_parser(parser)
args = parser.parse_args([])
async with build_async_engine_client(args):
pass

View File

@ -0,0 +1,57 @@
"""Test that the MQLLMEngine is able to handle 10k concurrent requests."""
import asyncio
import tempfile
import uuid
import pytest
from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate
from vllm.engine.arg_utils import AsyncEngineArgs
MODEL = "google/gemma-1.1-2b-it"
NUM_EXPECTED_TOKENS = 10
NUM_REQUESTS = 10000
# Scenarios to test for num generated token.
ENGINE_ARGS = AsyncEngineArgs(model=MODEL, disable_log_requests=True)
@pytest.fixture(scope="function")
def tmp_socket():
with tempfile.TemporaryDirectory() as td:
yield f"ipc://{td}/{uuid.uuid4()}"
@pytest.mark.asyncio
async def test_load(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket) as engine:
client = await engine.make_client()
request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
# Create concurrent requests.
tasks = []
for request_id in request_ids:
tasks.append(
asyncio.create_task(
generate(client, request_id, NUM_EXPECTED_TOKENS)))
# Confirm that we got all the EXPECTED tokens from the requests.
failed_request_id = None
tokens = None
for task in tasks:
num_generated_tokens, request_id = await task
if (num_generated_tokens != NUM_EXPECTED_TOKENS
and failed_request_id is None):
failed_request_id = request_id
tokens = num_generated_tokens
assert failed_request_id is None, (
f"{failed_request_id} generated {tokens} but "
f"expected {NUM_EXPECTED_TOKENS}")
# Shutdown.
client.close()

View File

@ -0,0 +1,78 @@
import asyncio
import multiprocessing
from typing import Callable, Tuple, Union
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import MQLLMEngine
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext
async def generate(
client: MQLLMEngineClient,
request_id: str,
num_tokens: int,
return_output: bool = False) -> Union[RequestOutput, Tuple[int, str]]:
final_output = None
count = 0
async for out in client.generate(
request_id=request_id,
inputs="Hello my name is Robert and",
sampling_params=SamplingParams(max_tokens=num_tokens,
temperature=0)):
count += 1
final_output = out
await asyncio.sleep(0.)
if return_output:
return final_output
# Confirm we generated all the tokens we expected.
return count, request_id
def run_normal(engine_args: AsyncEngineArgs, ipc_path: str):
# Make engine.
engine = MQLLMEngine.from_engine_args(
engine_args=engine_args,
usage_context=UsageContext.UNKNOWN_CONTEXT,
ipc_path=ipc_path)
# Run engine.
engine.start()
class RemoteMQLLMEngine:
def __init__(self,
engine_args: AsyncEngineArgs,
ipc_path: str,
run_fn: Callable = run_normal) -> None:
self.engine_args = engine_args
self.ipc_path = ipc_path
context = multiprocessing.get_context("spawn")
self.proc = context.Process(target=run_fn,
args=(engine_args, ipc_path))
self.proc.start()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.proc.kill()
async def make_client(self) -> MQLLMEngineClient:
engine_config = self.engine_args.create_engine_config()
client = MQLLMEngineClient(self.ipc_path, engine_config)
while True:
try:
await client.setup()
break
except TimeoutError:
assert self.proc.is_alive()
return client

View File

@ -1,5 +1,12 @@
import os
from ..utils import compare_two_settings
# --enforce-eager on TPU causes graph compilation
# this times out default Health Check in the MQLLMEngine,
# so we set the timeout here to 30s
os.environ["VLLM_RPC_TIMEOUT"] = "30000"
def test_custom_dispatcher():
compare_two_settings("google/gemma-2b",

View File

@ -119,7 +119,7 @@ class RemoteOpenAIServer:
def __exit__(self, exc_type, exc_value, traceback):
self.proc.terminate()
try:
self.proc.wait(3)
self.proc.wait(8)
except subprocess.TimeoutExpired:
# force kill if needed
self.proc.kill()

View File

@ -601,9 +601,12 @@ class AsyncLLMEngine:
return self._errored_with is not None
@property
def limit_concurrency(self) -> Optional[int]:
"""Maximum number of concurrently running requests."""
return None
def dead_error(self) -> BaseException:
return AsyncEngineDeadError(
"Background loop is not running. If it was running, "
"inspect the output to find the stacktrace of the "
"error that caused the background loop to stop "
"(AsyncEngineDeadError).")
def set_errored(self, exc: Exception) -> None:
self._errored_with = exc

View File

@ -1289,6 +1289,7 @@ class LLMEngine:
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop()
return ctx.request_outputs

View File

@ -0,0 +1,73 @@
from dataclasses import dataclass
from enum import Enum
from typing import List, Mapping, Optional, Union
from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
VLLM_RPC_SUCCESS_STR = "SUCCESS"
IPC_INPUT_EXT = "_input_socket"
IPC_OUTPUT_EXT = "_output_socket"
IPC_HEALTH_EXT = "_health_socket"
IPC_DATA_EXT = "_data_socket"
class MQEngineDeadError(RuntimeError):
pass
@dataclass
class RPCGenerateRequest:
inputs: PromptInputs
sampling_params: SamplingParams
request_id: str
lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
@dataclass
class RPCError:
request_id: Optional[str]
is_engine_errored: bool
exception: BaseException
@dataclass
class RPCAbortRequest:
request_id: str
class RPCHealthRequest:
pass
class RPCStartupRequest(Enum):
IS_SERVER_READY = 1
@dataclass
class RPCStartupResponse:
tracing_enabled: bool
RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCHealthRequest,
RPCStartupRequest]
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]
def ENGINE_DEAD_ERROR(
error: Optional[BaseException] = None) -> MQEngineDeadError:
if error is None:
return MQEngineDeadError(
"Engine loop is not running. Inspect the stacktrace to "
"find the original error")
return MQEngineDeadError(
"Engine loop is not running. Inspect the stacktrace to "
f"find the original error: {repr(error)}.")

View File

@ -0,0 +1,452 @@
import asyncio
import copy
import pickle
from contextlib import contextmanager, suppress
from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional,
Union)
import cloudpickle
import zmq
import zmq.asyncio
from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCGenerateRequest,
RPCHealthRequest, RPCStartupRequest,
RPCStartupResponse)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
logger = init_logger(__name__)
class MQClientClosedError(Exception):
"""Exception class raised when the client is used post-close.
The client can be closed, which closes the ZMQ context. This normally
happens on server shutdown. In some cases, methods like abort and
do_log_stats will still be called and then try to open a socket, which
causes a ZMQError and creates a huge stack trace.
So, we throw this error such that we can suppress it.
"""
class MQLLMEngineClient:
"""A client wrapper for MQLLMEngine that conforms to the
EngineClient protocol.
MQLLMEngine and MQLLMEngineClient are intended to run in separate
processes communicating via zeromq ipc sockets.
The entrypoint to MQLLMEngineClient is through the generate()
method. On generate() MQLLMEngine does three things:
- Creates an asyncio output queue
- Sends a RPCGenerateRequest to the MQLLMEngine via zmq
- Pulls RequestOutputs from its queue and yields them
MQLLMEngine runs two background loops:
- output_loop: the output loop pulls List[RequestOutput]
from the MQLLMEngine via zmq (each list is the output
of one engine_step in the LLMEngine). It then parses
the list and pushes individual request_outputs into
the corresponding output_queue such that they can be
consumed by the .generate() method.
- health_loop: the health loop queries the health socket
every N seconds, confirming the engine is healthy
"""
def __init__(self, ipc_path: str, engine_config: EngineConfig):
self.context = zmq.asyncio.Context()
self._errored_with: Optional[BaseException] = None
# Get the configs.
self.model_config = engine_config.model_config
self.decoding_config = engine_config.decoding_config
# Create the tokenizer group.
self.tokenizer = init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=engine_config.scheduler_config,
parallel_config=engine_config.parallel_config,
enable_lora=bool(engine_config.lora_config),
)
# Send RPCGenerateRequest to the MQLLMEngine.
self.input_socket: Socket = self.context.socket(zmq.constants.PUSH)
self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}")
# Receive streams of RequestOutput from the MQLLMEngine.
self.output_socket: Socket = self.context.socket(zmq.constants.PULL)
self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")
# IPC path for ack of check_health requests.
self.health_socket: Socket = self.context.socket(zmq.constants.PULL)
self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
# Stream for each individual request.
self.output_queues: Dict[str, asyncio.Queue] = {}
self.output_loop = asyncio.create_task(self.run_output_handler_loop())
# Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready.
self.health_loop: Optional[asyncio.Task] = None
@staticmethod
def is_unsupported_config(engine_args: AsyncEngineArgs):
if engine_args.pipeline_parallel_size > 1:
return True
is_embedding = ModelConfig(
model=engine_args.model,
revision=engine_args.revision,
tokenizer=engine_args.model,
tokenizer_mode="auto",
trust_remote_code=engine_args.trust_remote_code,
quantization=engine_args.quantization,
seed=0,
dtype="auto").embedding_mode
return is_embedding
@contextmanager
def get_data_socket(self) -> Iterator[Socket]:
socket = self.context.socket(zmq.constants.DEALER)
try:
socket.connect(self.data_ipc_path)
yield socket
finally:
socket.close(linger=0)
async def run_check_health_loop(self, timeout: int):
"""Background loop that continually probes the RPCServer for health.
The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which
the MQLLMEngine server is blocking on.
The Server replies on the HEALTH_SOCKET (rather than on the
OUTPUT_SOCKET such that the messages are not intermingled with
output streaming).
"""
try:
while True:
if await self.health_socket.poll(timeout=timeout) == 0:
# Wakeup every N seconds and do a health probe.
await self._send_one_way_rpc_request(
RPCHealthRequest(), self.input_socket)
# Wait for ack from the health socket.
await self._await_ack(error_message="Health check failed.",
socket=self.health_socket)
else:
# Server sent a health status message unprompted.
await self._check_success(
error_message="Health check failed.",
socket=self.health_socket)
logger.debug("Health probe successful.")
except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient check health loop.")
except Exception as e:
self._set_errored(e)
async def run_output_handler_loop(self):
"""Get RequestOutputs from Engine and stream to Request Queues"""
try:
while True:
# Poll, checking for ENGINE_DEAD
while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT
) == 0:
logger.debug("Waiting for output from MQLLMEngine.")
# If errored, alert all running requests.
if self.errored:
for queue_j in tuple(self.output_queues.values()):
queue_j.put_nowait(
ENGINE_DEAD_ERROR(self._errored_with))
return
message: Frame = await self.output_socket.recv(copy=False)
request_outputs = pickle.loads(message.buffer)
is_error = isinstance(request_outputs,
(BaseException, RPCError))
if is_error:
if isinstance(request_outputs, RPCError):
rpc_error: RPCError = request_outputs
request_id = rpc_error.request_id
exception = rpc_error.exception
is_engine_errored = rpc_error.is_engine_errored
else:
# MPLLMEngine should always return an RPCError to
# the output_socket when an issue arises.
# If we are here, we are in a bad state and
# should shut down the server.
error: BaseException = request_outputs
logger.error(
"Received Exception %s rather than RPCError from "
"MPLLMEngine. This should never happen.", error)
request_id = None
exception = error
is_engine_errored = True
# Set to error state only on engine critical error
# (and record only the first one)
if is_engine_errored and not self._errored_with:
self._errored_with = exception
if request_id is None:
for queue_i in tuple(self.output_queues.values()):
queue_i.put_nowait(exception)
else:
queue = self.output_queues.get(request_id)
if queue is not None:
queue.put_nowait(exception)
else:
# Put each output into the appropriate steam.
for request_output in request_outputs:
queue = self.output_queues.get(
request_output.request_id)
if queue is not None:
queue.put_nowait(request_output)
except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient output handler.")
async def setup(self):
"""Setup the client before it starts sending server requests."""
with self.get_data_socket() as socket:
# Wait until server is ready.
response = await self._wait_for_server_rpc(socket)
self.tracing_flag = response.tracing_enabled
# Start health_loop.
self.health_loop = asyncio.create_task(
self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT))
def close(self):
"""Destroy the ZeroMQ Context."""
# Close all sockets and terminate the context.
self.context.destroy(linger=0)
# Cancel background tasks.
if self.health_loop is not None:
self.health_loop.cancel()
self.output_loop.cancel()
def _set_errored(self, e: BaseException):
logger.exception(repr(e))
if self._errored_with is None:
self._errored_with = e
@staticmethod
async def _send_get_data_rpc_request(request: RPCStartupRequest,
expected_type: Any,
error_message: str,
socket: Socket) -> Any:
"""Send an RPC request that is expecting data back."""
# Ping RPCServer with a request.
await socket.send_multipart((pickle.dumps(request), ), copy=False)
# Make sure the server responds in time.
if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
raise TimeoutError("RPCServer didn't reply within "
f"{VLLM_RPC_TIMEOUT} ms")
# Await the data from the Server.
frame = await socket.recv(copy=False)
data = pickle.loads(frame.buffer)
if isinstance(data, BaseException):
raise data
elif not isinstance(data, expected_type):
raise ValueError(error_message)
return data
@staticmethod
async def _send_one_way_rpc_request(request: RPC_REQUEST_T,
socket: Socket):
"""Send one-way RPC request to trigger an action."""
if socket.closed:
raise MQClientClosedError()
await socket.send_multipart((pickle.dumps(request), ))
async def _await_ack(self, error_message: str, socket: Socket):
"""Await acknowledgement that a request succeeded."""
if socket.closed:
raise MQClientClosedError()
if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
raise TimeoutError("MQLLMEngine didn't reply within "
f"{VLLM_RPC_TIMEOUT}ms")
await self._check_success(error_message, socket)
@staticmethod
async def _check_success(error_message: str, socket: Socket):
"""Confirm that socket has a VLLM_RPC_SUCCESS_STR message"""
if socket.closed:
raise MQClientClosedError()
frame = await socket.recv(copy=False)
response = pickle.loads(frame.buffer)
# Raise error if unsuccessful
if isinstance(response, BaseException):
raise response
elif (not isinstance(response, str)
or response != VLLM_RPC_SUCCESS_STR):
raise ValueError(error_message)
async def get_tokenizer(self, lora_request: LoRARequest):
return await self.tokenizer.get_lora_tokenizer_async(lora_request)
async def get_decoding_config(self) -> DecodingConfig:
return self.decoding_config
async def get_model_config(self) -> ModelConfig:
return self.model_config
async def is_tracing_enabled(self) -> bool:
return self.tracing_flag
async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse:
"""Wait for the RPCServer to start up."""
return await self._send_get_data_rpc_request(
request=RPCStartupRequest.IS_SERVER_READY,
expected_type=RPCStartupResponse,
error_message="Unable to start RPC Server",
socket=socket)
async def abort(self, request_id: str):
"""Send an ABORT_REQUEST signal to the RPC Server"""
with suppress(MQClientClosedError):
await self._send_one_way_rpc_request(
request=RPCAbortRequest(request_id), socket=self.input_socket)
async def do_log_stats(self):
"""Ignore do_log_stats (handled on MQLLMEngine polling)"""
pass
async def check_health(self):
"""
The check health loop probes the health status of the
Engine's health every N seconds and sets _errored_with
if the engine is unhealthy.
"""
if self._errored_with is not None:
raise self._errored_with
@property
def is_running(self) -> bool:
return not self.errored
@property
def is_stopped(self) -> bool:
return self.errored
@property
def errored(self) -> bool:
return self._errored_with is not None
async def generate(
self,
inputs: PromptInputs,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncGenerator[RequestOutput, None]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
# If already dead, error out.
if self._errored_with is not None:
raise ENGINE_DEAD_ERROR(self._errored_with)
# 1) Create output queue for this requests.
queue: asyncio.Queue[Union[RequestOutput,
BaseException]] = asyncio.Queue()
self.output_queues[request_id] = queue
try:
# 2) Detach logits processors so that they can be pickled
# separately (may require cloudpickle which is slower)
if sampling_params.logits_processors:
# Defensive shallow copy
sampling_params = copy.copy(sampling_params)
logits_processors = sampling_params.logits_processors
sampling_params.logits_processors = None
lp_bytes = cloudpickle.dumps(logits_processors)
else:
lp_bytes = None
request_bytes = pickle.dumps(
RPCGenerateRequest(
inputs=inputs,
sampling_params=sampling_params,
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request))
# 3) Send the RPCGenerateRequest to the MQLLMEngine.
parts = (request_bytes,
lp_bytes) if lp_bytes else (request_bytes, )
await self.input_socket.send_multipart(parts, copy=False)
# 4) Stream the RequestOutputs from the output queue. Note
# that the output_loop pushes RequestOutput objects to this
# queue after pulling them from the zmq socket.
finished = False
try:
while not finished:
request_output = await queue.get()
if isinstance(request_output, BaseException):
raise request_output
finished = request_output.finished
yield request_output
finally:
# Request was canceled by the client.
if not finished and not self.errored:
await self.abort(request_id)
finally:
self.output_queues.pop(request_id)
async def encode(self, *args,
**kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
raise NotImplementedError(
"Embeddings not supported with multiprocessing backend")

View File

@ -0,0 +1,321 @@
import pickle
import signal
from contextlib import contextmanager
from typing import Iterator, List, Optional, Union
import cloudpickle
import zmq
from vllm import AsyncEngineArgs, LLMEngine
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCGenerateRequest,
RPCHealthRequest, RPCStartupRequest,
RPCStartupResponse)
# yapf: enable
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext
CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
SchedulerConfig, LoRAConfig]
logger = init_logger(__name__)
POLLING_TIMEOUT_MS = 10000
HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), )
class MQLLMEngine:
"""A multiprocessing wrapper for :class:`LLMEngine`.
This class is used to wrap the :class:`LLMEngine` class to enable use
in concurrnet manner. It runs a background loop and uses zeromq to
receive new requests and stream outputs incrementally via ipc.
The :class:`LLMEngine.generate` is kicked off when a new
RPCGenerateRequest is received by the input_socket.
The self.engine_loop checks the input_socket for new requests,
adds them to the LLMEngine if there are any, calls the internal
:class:`LLMEngine.step()`, and sends the RequestOutputs back over
the output_socket.
If use_async_sockets is set, the logic associated with reading new
requests from the socket and sending data to the socket is passed
as a callback to the llm_engine, which calls the logic asynchronously
such that the IPC can be overlapped with the GPU.
Args:
ipc_path: Base path for zeromq interprocess messaging
use_async_sockets: Whether to make send/recv async with GPU
log_requests: Whether to log the requests.
*args: Arguments for :class:`LLMEngine`.
**kwargs: Arguments for :class:`LLMEngine`.
"""
def __init__(self,
ipc_path: str,
use_async_sockets: bool,
*args,
log_requests: bool = True,
**kwargs) -> None:
self.engine = LLMEngine(*args, **kwargs)
self.log_requests = log_requests
self.use_async_sockets = use_async_sockets
if self.use_async_sockets:
self.engine.process_request_outputs_callback = \
self._async_socket_engine_callback
self.ctx = zmq.Context() # type: ignore[attr-defined]
# Receive input from the client.
self.input_socket = self.ctx.socket(zmq.constants.PULL)
self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}")
# Send output stream back to client.
self.output_socket = self.ctx.socket(zmq.constants.PUSH)
self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")
# Send health status back to client.
self.health_socket = self.ctx.socket(zmq.constants.PUSH)
self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
# Error state.
self._errored_with: Optional[BaseException] = None
@property
def dead_error(self) -> BaseException:
if self._errored_with is not None:
return ENGINE_DEAD_ERROR(self._errored_with)
else:
return ENGINE_DEAD_ERROR()
@classmethod
def from_engine_args(cls, engine_args: AsyncEngineArgs,
usage_context: UsageContext, ipc_path: str):
"""Creates an MQLLMEngine from the engine arguments."""
engine_config = engine_args.create_engine_config()
executor_class = LLMEngine._get_executor_cls(engine_config)
return cls(
ipc_path=ipc_path,
use_async_sockets=engine_config.model_config.use_async_output_proc,
**engine_config.to_dict(),
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context)
def start(self):
try:
try:
logger.debug("Starting Startup Loop.")
self.run_startup_loop()
logger.debug("Starting Engine Loop.")
self.run_engine_loop()
except Exception as e:
logger.exception(repr(e))
except KeyboardInterrupt:
logger.debug("Shutting down MQLLMEngine.")
finally:
logger.debug("MQLLMEngine is shut down.")
self.cleanup()
def cleanup(self):
"""Cleanup zeromq state on shutdown."""
# Closes all sockets and destroys context.
self.ctx.destroy(linger=0)
del self.engine
@contextmanager
def make_data_socket(
self) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
socket = self.ctx.socket(zmq.constants.ROUTER)
try:
socket.bind(self.data_ipc_path)
yield socket
finally:
socket.close(linger=0)
def run_startup_loop(self) -> None:
"""Startup loop for sending data from Engine -> Client."""
with self.make_data_socket() as socket:
response: Union[RPCStartupResponse, BaseException]
try:
identity, message = socket.recv_multipart(copy=False)
request: RPCStartupRequest = pickle.loads(message.buffer)
# Handle the query from the Client.
if request == RPCStartupRequest.IS_SERVER_READY:
tracing_enabled = self.engine.is_tracing_enabled()
response = RPCStartupResponse(
tracing_enabled=tracing_enabled)
except Exception as e:
response = e
socket.send_multipart((identity, pickle.dumps(response)),
copy=False)
def run_engine_loop(self):
"""Core busy loop of the LLMEngine."""
while True:
if not self.engine.has_unfinished_requests():
# Poll until there is work to do.
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
self.engine.do_log_stats()
logger.debug("Waiting for new requests in engine loop.")
# Handle any input from the client.
self.handle_new_input()
# Engine step.
request_outputs = self.engine_step()
# Send request outputs (if async, done in engine_step callback).
if not self.use_async_sockets:
self._send_outputs(request_outputs)
def engine_step(self) -> List[RequestOutput]:
"""Engine step wrapper with error handling."""
try:
return self.engine.step()
except SystemExit:
raise
except BaseException as e:
self._set_errored(e)
rpc_err = RPCError(request_id=None,
is_engine_errored=True,
exception=e)
self._send_outputs(rpc_err)
raise e
def handle_new_input(self):
"""Handle new input from the socket"""
try:
while self.input_socket.poll(timeout=0) != 0:
frames = self.input_socket.recv_multipart(copy=False)
request = pickle.loads(frames[0].buffer)
if isinstance(request, RPCGenerateRequest):
if len(frames) > 1:
# Use cloudpickle for logits processors
lprocs = cloudpickle.loads(frames[1].buffer)
request.sampling_params.logits_processors = lprocs
self._handle_generate_request(request)
elif isinstance(request, RPCAbortRequest):
self._handle_abort_request(request)
elif isinstance(request, RPCHealthRequest):
self._handle_health_request()
else:
raise ValueError("Unknown RPCRequest Type: {request}")
except Exception as e:
self._set_errored(e)
self._send_unhealthy(e)
raise e
def _handle_generate_request(self, request: RPCGenerateRequest):
"""Handle RPCGenerateRequest by adding it to the LLMEngine."""
request_id = request.request_id
if self._errored_with is not None:
rpc_err = RPCError(request_id=request_id,
is_engine_errored=True,
exception=ENGINE_DEAD_ERROR(self._errored_with))
self._send_outputs(rpc_err)
try:
self.engine.add_request(
request_id=request_id,
inputs=request.inputs,
params=request.sampling_params,
lora_request=request.lora_request,
trace_headers=request.trace_headers,
prompt_adapter_request=request.prompt_adapter_request)
if self.log_requests:
logger.info("Added request %s.", request.request_id)
except Exception as e:
# We do not set self._errored = True here, since the error
# is due to an issue adding this request to the engine,
# rather than an issue with the engine itself.
is_errored = self._errored_with is not None
rpc_err = RPCError(request_id=request_id,
is_engine_errored=is_errored,
exception=e)
self._send_outputs(rpc_err)
# Remove request from the engine.
self.engine.abort_request(request_id)
def _handle_abort_request(self, request: RPCAbortRequest):
self.engine.abort_request(request.request_id)
if self.log_requests:
logger.info("Aborted request %s.", request.request_id)
def _handle_health_request(self):
if self._errored_with is not None:
self._send_unhealthy(self._errored_with)
# Raises error if unhealthy.
self.engine.check_health()
self._send_healthy()
def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
"""Send List of RequestOutput to RPCClient."""
if outputs:
output_bytes = pickle.dumps(outputs)
self.output_socket.send_multipart((output_bytes, ), copy=False)
def _send_healthy(self):
"""Send HEALTHY message to RPCClient."""
self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False)
def _send_unhealthy(self, error: BaseException):
"""Send UNHEALTHY message to RPCClient."""
error_bytes = pickle.dumps(error)
self.health_socket.send_multipart((error_bytes, ), copy=False)
def _async_socket_engine_callback(self,
request_outputs: REQUEST_OUTPUTS_T):
"""Callback used by engine to make socket handling async with GPU."""
self._send_outputs(request_outputs)
self.handle_new_input()
def _set_errored(self, e: BaseException):
"""Log and set errored status if this is the first issue."""
if self._errored_with is None:
self._errored_with = e
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
ipc_path: str):
def signal_handler(*_) -> None:
# Interrupt server on sigterm
raise KeyboardInterrupt("MQLLMEngine terminated")
signal.signal(signal.SIGTERM, signal_handler)
engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
usage_context=usage_context,
ipc_path=ipc_path)
engine.start()

View File

@ -14,8 +14,8 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
@runtime_checkable
class AsyncEngineClient(Protocol):
"""Protocol class for Clients to AsyncLLMEngine"""
class EngineClient(Protocol):
"""Protocol class for Clients to Engine"""
@property
def is_running(self) -> bool:
@ -30,8 +30,8 @@ class AsyncEngineClient(Protocol):
...
@property
def limit_concurrency(self) -> Optional[int]:
"""Maximum number of concurrently running requests."""
def dead_error(self) -> BaseException:
...
def generate(
self,

View File

@ -1,21 +1,21 @@
import asyncio
import signal
from http import HTTPStatus
from typing import Any, Optional
from typing import Any
import uvicorn
from fastapi import FastAPI, Request, Response
from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.multiprocessing import MQEngineDeadError
from vllm.logger import init_logger
from vllm.utils import find_process_using_port
logger = init_logger(__name__)
async def serve_http(app: FastAPI, limit_concurrency: Optional[int],
**uvicorn_kwargs: Any):
async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
logger.info("Available routes are:")
for route in app.routes:
methods = getattr(route, "methods", None)
@ -26,15 +26,6 @@ async def serve_http(app: FastAPI, limit_concurrency: Optional[int],
logger.info("Route: %s, Methods: %s", path, ', '.join(methods))
# Set concurrency limits in uvicorn if running in multiprocessing mode
# since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536).
if limit_concurrency is not None:
logger.info(
"Launching Uvicorn with --limit_concurrency %s. To avoid this "
"limit at the expense of performance run with "
"--disable-frontend-multiprocessing", limit_concurrency)
uvicorn_kwargs["limit_concurrency"] = limit_concurrency
config = uvicorn.Config(app, **uvicorn_kwargs)
server = uvicorn.Server(config)
_add_shutdown_handlers(app, server)
@ -63,7 +54,7 @@ async def serve_http(app: FastAPI, limit_concurrency: Optional[int],
logger.debug(
"port %s is used by process %s launched with command:\n%s",
port, process, " ".join(process.cmdline()))
logger.info("Gracefully stopping http server")
logger.info("Shutting down FastAPI HTTP server.")
return server.shutdown()
@ -90,7 +81,7 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
@app.exception_handler(AsyncEngineDeadError)
async def engine_dead_handler(_, __):
async def async_engine_dead_handler(_, __):
"""Kill the server if the async engine is already dead. It will
not handle any further requests."""
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
@ -99,3 +90,14 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
server.should_exit = True
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
@app.exception_handler(MQEngineDeadError)
async def mq_engine_dead_handler(_, __):
"""Kill the server if the mq engine is already dead. It will
not handle any further requests."""
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
logger.fatal("MQLLMEngine is already dead, terminating server "
"process")
server.should_exit = True
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)

View File

@ -26,7 +26,9 @@ import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser
@ -44,8 +46,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TokenizeRequest,
TokenizeResponse,
UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@ -67,29 +67,16 @@ logger = init_logger('vllm.entrypoints.openai.api_server')
_running_tasks: Set[asyncio.Task] = set()
def model_is_embedding(model_name: str, trust_remote_code: bool,
quantization: Optional[str],
revision: Optional[str]) -> bool:
return ModelConfig(model=model_name,
revision=revision,
tokenizer=model_name,
tokenizer_mode="auto",
trust_remote_code=trust_remote_code,
quantization=quantization,
seed=0,
dtype="auto").embedding_mode
@asynccontextmanager
async def lifespan(app: FastAPI):
try:
if app.state.log_stats:
async_engine_client = app.state.engine_client
engine_client: EngineClient = app.state.engine_client
async def _force_log():
while True:
await asyncio.sleep(10)
await async_engine_client.do_log_stats()
await asyncio.sleep(10.)
await engine_client.do_log_stats()
task = asyncio.create_task(_force_log())
_running_tasks.add(task)
@ -108,9 +95,9 @@ async def lifespan(app: FastAPI):
@asynccontextmanager
async def build_async_engine_client(
args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
args: Namespace) -> AsyncIterator[Optional[EngineClient]]:
# Context manager to handle async_engine_client lifecycle
# Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
engine_args = AsyncEngineArgs.from_cli_args(args)
@ -123,19 +110,18 @@ async def build_async_engine_client(
async def build_async_engine_client_from_engine_args(
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[Optional[AsyncEngineClient]]:
) -> AsyncIterator[Optional[EngineClient]]:
"""
Create AsyncEngineClient, either:
Create EngineClient, either:
- in-process using the AsyncLLMEngine Directly
- multiprocess using AsyncLLMEngine RPC
Returns the Client or None if the creation failed.
"""
# If manually triggered or embedding model, use AsyncLLMEngine in process.
# TODO: support embedding model via RPC.
if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
engine_args.quantization, engine_args.revision)
# Fall back
# TODO: fill out feature matrix.
if (MQLLMEngineClient.is_unsupported_config(engine_args)
or disable_frontend_multiprocessing):
engine_config = engine_args.create_engine_config()
uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
@ -173,56 +159,60 @@ async def build_async_engine_client_from_engine_args(
"and vLLM will properly handle cleanup.")
# Select random path for IPC.
rpc_path = get_open_zmq_ipc_path()
logger.info("Multiprocessing frontend to use %s for RPC Path.",
rpc_path)
ipc_path = get_open_zmq_ipc_path()
logger.info("Multiprocessing frontend to use %s for IPC Path.",
ipc_path)
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
rpc_client = AsyncEngineRPCClient(rpc_path)
# Start RPCServer in separate process (holds the AsyncLLMEngine).
context = multiprocessing.get_context("spawn")
# Start RPCServer in separate process (holds the LLMEngine).
# the current process might have CUDA context,
# so we need to spawn a new process
rpc_server_process = context.Process(
target=run_rpc_server,
args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path))
rpc_server_process.start()
logger.info("Started engine process with PID %d",
rpc_server_process.pid)
context = multiprocessing.get_context("spawn")
engine_process = context.Process(target=run_mp_engine,
args=(engine_args,
UsageContext.OPENAI_API_SERVER,
ipc_path))
engine_process.start()
logger.info("Started engine process with PID %d", engine_process.pid)
# Build RPCClient, which conforms to EngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
engine_config = engine_args.create_engine_config()
mp_engine_client = MQLLMEngineClient(ipc_path, engine_config)
try:
while True:
try:
await rpc_client.setup()
await mp_engine_client.setup()
break
except TimeoutError:
if not rpc_server_process.is_alive():
logger.error(
"RPCServer process died before responding "
"to readiness probe")
if not engine_process.is_alive():
logger.error("Engine process died before responding "
"to readiness probe")
yield None
return
yield rpc_client # type: ignore[misc]
yield mp_engine_client # type: ignore[misc]
finally:
# Ensure rpc server process was terminated
rpc_server_process.terminate()
engine_process.terminate()
# Close all open connections to the backend
rpc_client.close()
mp_engine_client.close()
# Wait for server process to join
rpc_server_process.join()
# Wait for engine process to join
engine_process.join(4)
if engine_process.exitcode is None:
# Kill if taking longer than 5 seconds to stop
engine_process.kill()
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import multiprocess
multiprocess.mark_process_dead(rpc_server_process.pid)
multiprocess.mark_process_dead(engine_process.pid)
router = APIRouter()
@ -270,7 +260,7 @@ def embedding(request: Request) -> OpenAIServingEmbedding:
return request.app.state.openai_serving_embedding
def engine_client(request: Request) -> AsyncEngineClient:
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
@ -473,7 +463,7 @@ def build_app(args: Namespace) -> FastAPI:
def init_app_state(
async_engine_client: AsyncEngineClient,
engine_client: EngineClient,
model_config: ModelConfig,
state: State,
args: Namespace,
@ -488,11 +478,11 @@ def init_app_state(
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
state.engine_client = async_engine_client
state.engine_client = engine_client
state.log_stats = not args.disable_log_stats
state.openai_serving_chat = OpenAIServingChat(
async_engine_client,
engine_client,
model_config,
served_model_names,
args.response_role,
@ -504,7 +494,7 @@ def init_app_state(
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser)
state.openai_serving_completion = OpenAIServingCompletion(
async_engine_client,
engine_client,
model_config,
served_model_names,
lora_modules=args.lora_modules,
@ -513,13 +503,13 @@ def init_app_state(
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
state.openai_serving_embedding = OpenAIServingEmbedding(
async_engine_client,
engine_client,
model_config,
served_model_names,
request_logger=request_logger,
)
state.openai_serving_tokenization = OpenAIServingTokenization(
async_engine_client,
engine_client,
model_config,
served_model_names,
lora_modules=args.lora_modules,
@ -541,21 +531,20 @@ async def run_server(args, **uvicorn_kwargs) -> None:
signal.signal(signal.SIGTERM, signal_handler)
async with build_async_engine_client(args) as async_engine_client:
async with build_async_engine_client(args) as engine_client:
# If None, creation of the client failed and we exit.
if async_engine_client is None:
if engine_client is None:
return
app = build_app(args)
model_config = await async_engine_client.get_model_config()
init_app_state(async_engine_client, model_config, app.state, args)
model_config = await engine_client.get_model_config()
init_app_state(engine_client, model_config, app.state, args)
temp_socket.close()
shutdown_task = await serve_http(
app,
limit_concurrency=async_engine_client.limit_concurrency,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,

View File

@ -1,50 +0,0 @@
from dataclasses import dataclass
from enum import Enum
from typing import Mapping, Optional, Union
from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
# Success string used for RPC instructions.
VLLM_RPC_SUCCESS_STR = "SUCCESS"
# Minimum value of ZMQ.SOCKET_LIMIT to run mp.
VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000
# HWM is set to Infinity.
VLLM_RPC_ZMQ_HWM = 0
@dataclass
class RPCGenerateRequest:
inputs: PromptInputs
sampling_params: SamplingParams
request_id: str
lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
@dataclass
class RPCAbortRequest:
request_id: str
class RPCUtilityRequest(Enum):
IS_SERVER_READY = 1
GET_MODEL_CONFIG = 2
GET_DECODING_CONFIG = 3
GET_PARALLEL_CONFIG = 4
GET_SCHEDULER_CONFIG = 5
GET_LORA_CONFIG = 6
DO_LOG_STATS = 7
IS_SERVER_HEALTHY = 8
IS_TRACING_ENABLED = 9
START_PROFILE = 10
STOP_PROFILE = 11
RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
RPCUtilityRequest]

View File

@ -1,451 +0,0 @@
import asyncio
import pickle
from contextlib import contextmanager, suppress
from typing import Any, AsyncGenerator, Iterator, Mapping, Optional
from uuid import uuid4
import cloudpickle
import zmq
import zmq.asyncio
from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
# yapf: disable
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
VLLM_RPC_SOCKET_LIMIT_CUTOFF,
VLLM_RPC_SUCCESS_STR,
VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
# yapf: enable
from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
logger = init_logger(__name__)
# Path used for inprocess proxy.
INPROC_PROXY_PATH = f"inproc://{uuid4()}"
class RPCClientClosedError(Exception):
"""Exception class raised when the client is used post-close.
The client can be closed, which closes the ZMQ context. This normally
happens on server shutdown. In some cases, methods like abort and
do_log_stats will still be called and then try to open a socket, which
causes a ZMQError and creates a huge stack trace.
So, we throw this error such that we can suppress it.
"""
class AsyncEngineRPCClient:
"""
RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
The overall design mirrors the Asynchronous Client Server Pattern
https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern
On startup, the RPCClient:
- makes DEALER socket (to_rpc_server) that connects to the RPCServer
via ipc, which uses unix sockets under the hood
(https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html)
- makes ROUTER socket (from_api_server) that binds to a random
inproc address, which uses memory under the hood
(https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html)
- runs a proxy in a background asyncio task between
from_api_server (ROUTER, inproc) and to_rpc_server (DEALER ipc, )
Each request handled by the asyncio api_server calls generate():
- make a DEALER socket that connects to from_api_server via inproc
- send a RCPGenerateRequest to the inproc socket
- background proxy forwards the request from inproc -> ipc
- RPCServer responds to the request one token at a time over ipc
- background proxy forwards the response from ipc -> inproc
The connection looks like this:
DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER
Message routing is performed via identities that are managed by the
ROUTER socket. ROUTER sockets track every connection it has and
tells the caller about these. The way it tells the caller is to stick
the connection identity in front of each message received. When we
send the message via a ROUTER, we first send an identity frame.
See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope
for more details on connection identities.
This proxy design enables us to use a single unix socket, which
improves performance by avoiding syscalls (~5%) and avoids resource limits
such as ulimit, which defaults to 1024 on ubuntu.
Note: we run set_hwm(0) on each socket, which sets the HWM to inf,
which is required to avoid dropping messages under high load.
This is generally not advisable. However, since we are in control
of both sides of the connection + failure on either side is
catastrophic to the overall system health and memory profiling
suggests limited memory overhead relative to asyncio, we will
proceed for now.
See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks
for more details on high water marks.
"""
def __init__(self, rpc_path: str):
self.context = zmq.asyncio.Context()
self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS
self._errored = False
# Maximum number of sockets that can be opened (typically 65536).
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT)
assert isinstance(socket_limit, int)
if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF:
raise ValueError(
f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps "
"the number of concurrent requests vLLM can process. Launch "
"vLLM with --disable-frontend-multiprocessing and open a "
"GitHub issue so we can investigate.")
# We only have 1 ipc connection that uses unix sockets, so
# safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. will
# not run into ulimit issues)
self.context.set(zmq.constants.MAX_SOCKETS, socket_limit)
# IPC connection to RPC Server (uses unix sockets).
self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER)
self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM)
self.to_rpc_server.bind(rpc_path)
# In process proxy to RPC Server (uses memory-based messaging).
self.from_api_server: Socket = self.context.socket(
zmq.constants.ROUTER)
self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM)
self.from_api_server.bind(INPROC_PROXY_PATH)
# Asyncio background task for the proxy.
self.proxy_in_task = asyncio.create_task(
self.run_proxy(self.from_api_server, self.to_rpc_server))
self.proxy_out_task = asyncio.create_task(
self.run_proxy(self.to_rpc_server, self.from_api_server))
# Since we open 1 inproc socket per request, we have a hard cap on
# the number of requests that can run in vLLM w. frontend
# mulitprocessing. This value is used uvicorn to launch
# with --limit-concurrency to return 503 when server is overloaded.
# We need 2 sockets per request - 2:
# 1 for generate(), 1 for abort(), do_log_stats(), check_health()
self.limit_concurrency = socket_limit // 2 - 2
async def run_proxy(self, socket_from: Socket, socket_to: Socket):
"""Background task that runs a proxy"""
while True:
frames = await socket_from.recv_multipart(copy=False)
await socket_to.send_multipart(frames, copy=False)
async def setup(self):
"""Setup the client before it starts sending server requests."""
# Wait until server is ready.
await self._wait_for_server_rpc()
# Get the configs.
self.model_config = await self._get_model_config_rpc()
self.decoding_config = await self._get_decoding_config_rpc()
self.tracing_flag = await self._is_tracing_enabled_rpc()
# Create the tokenizer group.
# TODO: refactor OAI server to avoid needing this info.
self.tokenizer = init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=(await self._get_scheduler_config_rpc()),
parallel_config=(await self._get_parallel_config_rpc()),
enable_lora=bool(await self._get_lora_config_rpc()),
)
def close(self):
"""Destroy the ZeroMQ Context."""
# Close all sockets associated with this context and
# then terminate the context.
self.from_api_server.close()
self.to_rpc_server.close()
self.context.destroy()
@contextmanager
def to_proxy_socket(self) -> Iterator[Socket]:
# Connect to the RPCServer via the proxy.
# Raise a sensible error if the client was already closed.
# This can happen if a server shutdown is triggered but some coroutines
# are still running requests.
# There should not be a race condition with this check because we don't
# yield to the event loop between here and opening the socket.
if self.context.closed:
raise RPCClientClosedError("The ZMQ client has already shut down")
# Note that we use DEALER to enable asynchronous communication
# to enable streaming.
socket = self.context.socket(zmq.constants.DEALER)
socket.set_hwm(VLLM_RPC_ZMQ_HWM)
try:
socket.connect(INPROC_PROXY_PATH)
yield socket
finally:
socket.close(linger=0)
async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
expected_type: Any,
error_message: str) -> Any:
"""Send an RPC request that is expecting data back."""
with self.to_proxy_socket() as socket:
# Ping RPCServer with a request.
await socket.send_multipart((cloudpickle.dumps(request), ),
copy=False)
# Make sure the server responds
if await socket.poll(timeout=self._data_timeout) == 0:
raise TimeoutError("Server didn't reply within "
f"{self._data_timeout} ms")
# Await the data from the Server.
frame = await socket.recv(copy=False)
assert isinstance(frame, Frame)
data = pickle.loads(frame.buffer)
if isinstance(data, Exception):
# Re-raise exceptions returned by the server
raise data
if not isinstance(data, expected_type):
# LoRAConfig can be None.
if expected_type == LoRAConfig and data is None:
pass
elif isinstance(data, Exception):
logger.error(error_message)
raise data
else:
raise ValueError(error_message)
return data
async def _send_one_way_rpc_request(self,
request: RPC_REQUEST_TYPE,
error_message: str,
socket: Optional[Socket] = None):
"""Send one-way RPC request to trigger an action."""
async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE):
await socket.send_multipart((cloudpickle.dumps(request), ))
if await socket.poll(timeout=self._data_timeout) == 0:
raise TimeoutError("Server didn't reply within "
f"{self._data_timeout} ms")
frame = await socket.recv(copy=False)
assert isinstance(frame, Frame)
return pickle.loads(frame.buffer)
# Make a new socket connection.
if socket is None:
with self.to_proxy_socket() as socket:
response = await do_rpc_call(socket, request)
# Use existing socket connection.
else:
response = await do_rpc_call(socket, request)
if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
if isinstance(response, Exception):
logger.error(error_message)
raise response
raise ValueError(error_message)
async def get_tokenizer(self, lora_request: LoRARequest):
return await self.tokenizer.get_lora_tokenizer_async(lora_request)
async def get_decoding_config(self) -> DecodingConfig:
return self.decoding_config
async def get_model_config(self) -> ModelConfig:
return self.model_config
async def is_tracing_enabled(self) -> bool:
return self.tracing_flag
async def _wait_for_server_rpc(self):
"""Wait for the RPCServer to start up."""
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.IS_SERVER_READY,
error_message="Unable to start RPC Server")
async def _get_model_config_rpc(self) -> ModelConfig:
"""Get the ModelConfig object from the RPC Server"""
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_MODEL_CONFIG,
expected_type=ModelConfig,
error_message="Could not get ModelConfig from RPC Server")
async def _get_decoding_config_rpc(self) -> DecodingConfig:
"""Get DecodingConfig from the RPCServer"""
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_DECODING_CONFIG,
expected_type=DecodingConfig,
error_message="Could not get DecodingConfig from RPC Server")
async def _get_parallel_config_rpc(self) -> ParallelConfig:
"""Get ParallelConfig from the RPCServer"""
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_PARALLEL_CONFIG,
expected_type=ParallelConfig,
error_message="Could not get ParallelConfig from RPC Server")
async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
"""Get SchedulerConfig from the RPCServer"""
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_SCHEDULER_CONFIG,
expected_type=SchedulerConfig,
error_message="Could not get SchedulerConfig from RPC Server")
async def _get_lora_config_rpc(self) -> LoRAConfig:
"""Get LoRAConfig from the RPCServer"""
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_LORA_CONFIG,
expected_type=LoRAConfig,
error_message="Could not get LoRAConfig from RPC Server")
async def _is_tracing_enabled_rpc(self) -> bool:
"""Get is_tracing_enabled flag from the RPCServer"""
return await self._send_get_data_rpc_request(
RPCUtilityRequest.IS_TRACING_ENABLED,
expected_type=bool,
error_message="Could not get is_tracing_enabled from RPC Server")
async def abort(self, request_id: str):
"""Send an ABORT_REQUEST signal to the RPC Server"""
# Suppress timeouts as well.
# In cases where the server is busy processing requests and a very
# large volume of abort requests arrive, it is likely that the server
# will not be able to ack all of them in time. We have seen this when
# we abort 20k requests at once while another 2k are processing- many
# of them time out, but we see the server successfully abort all of the
# requests.
# In this case we assume that the server has received or will receive
# these abort requests, and ignore the timeout. This prevents a massive
# wall of `TimeoutError` stack traces.
with suppress(RPCClientClosedError, TimeoutError):
await self._send_one_way_rpc_request(
request=RPCAbortRequest(request_id),
error_message=f"RPCAbortRequest {request_id} failed")
async def do_log_stats(self):
"""Send a DO_LOG_STATS signal to the RPC Server"""
with suppress(RPCClientClosedError):
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.DO_LOG_STATS,
error_message="RPCRequest DO_LOG_STATS failed.")
@property
def is_running(self) -> bool:
return not self._errored
@property
def is_stopped(self) -> bool:
return self._errored
@property
def errored(self) -> bool:
return self._errored
async def generate(
self,
inputs: PromptInputs,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncGenerator[RequestOutput, None]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
finished = False
try:
with self.to_proxy_socket() as socket:
# Send RPCGenerateRequest to the RPCServer.
await socket.send_multipart((cloudpickle.dumps(
RPCGenerateRequest(
inputs=inputs,
sampling_params=sampling_params,
request_id=request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request)), ))
# Stream back the results from the RPC Server.
while not finished:
message = await socket.recv(copy=False)
assert isinstance(message, Frame)
request_output = pickle.loads(message.buffer)
if isinstance(request_output, Exception):
# On exception, check if the server is still healthy
# possibly setting the `errored` property.
if not self._errored:
try:
await self.check_health(socket=socket)
except Exception as e:
self._errored = True
logger.exception(repr(e))
# NB: do before raising here so that the flag is set
# by the time the caller receives this exception
raise request_output
finished = request_output.finished
yield request_output
finally:
# Request was canceled by the client.
if not finished and not self._errored:
await self.abort(request_id)
async def check_health(self, socket: Optional[Socket] = None) -> None:
"""Raise if unhealthy"""
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.IS_SERVER_HEALTHY,
error_message="Got Unhealthy response from RPC Server",
socket=socket)
async def encode(self, *args,
**kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
raise NotImplementedError(
"Embeddings not supported with multiprocessing backend")
async def start_profile(self) -> None:
"""Start profiling the engine"""
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.START_PROFILE,
error_message="RPCRequest START_PROFILE failed.")
async def stop_profile(self) -> None:
"""Stop profiling the engine"""
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.STOP_PROFILE,
error_message="RPCRequest STOP_PROFILE failed.")

View File

@ -1,243 +0,0 @@
import asyncio
import pickle
import signal
from typing import Any, Coroutine, Union
import cloudpickle
import uvloop
import zmq
import zmq.asyncio
from typing_extensions import Never
from zmq import Frame # type: ignore[attr-defined]
from zmq.asyncio import Socket
from vllm import AsyncEngineArgs, AsyncLLMEngine
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR,
VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
logger = init_logger(__name__)
CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
SchedulerConfig, LoRAConfig]
class AsyncEngineRPCServer:
def __init__(self, async_engine_args: AsyncEngineArgs,
usage_context: UsageContext, rpc_path: str):
# Initialize engine first.
self.engine = AsyncLLMEngine.from_engine_args(
async_engine_args, usage_context=usage_context)
# Initialize context.
self.context = zmq.asyncio.Context()
# Init socket.
self.socket: Socket = self.context.socket(zmq.constants.DEALER)
self.socket.set_hwm(VLLM_RPC_ZMQ_HWM)
self.socket.connect(rpc_path)
def cleanup(self):
"""Cleanup all resources."""
self.socket.close()
self.context.destroy()
# Clear the engine reference so that it can be GC'ed.
del self.engine
async def get_config(self, identity, request):
try:
config: CONFIG_TYPE
if request == RPCUtilityRequest.GET_MODEL_CONFIG:
config = await self.engine.get_model_config()
elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
config = await self.engine.get_decoding_config()
elif request == RPCUtilityRequest.GET_LORA_CONFIG:
config = await self.engine.get_lora_config()
elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
config = await self.engine.get_scheduler_config()
elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
config = await self.engine.get_parallel_config()
else:
raise ValueError("Unknown Config Request: %s", request)
await self.socket.send_multipart((identity, pickle.dumps(config)),
copy=False)
except Exception as e:
await self.socket.send_multipart((identity, pickle.dumps(e)),
copy=False)
async def is_tracing_enabled(self, identity):
"""Send the is_tracing_enabled flag"""
tracing_flag = await self.engine.is_tracing_enabled()
await self.socket.send_multipart(
(identity, pickle.dumps(tracing_flag)))
async def do_log_stats(self, identity):
"""Log stats and confirm success."""
await self.engine.do_log_stats()
await self.socket.send_multipart(
(identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
async def is_server_ready(self, identity):
"""Notify the client that we are ready."""
await self.socket.send_multipart(
(identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
async def abort(self, identity, request: RPCAbortRequest):
"""Abort request and notify the client of success."""
try:
# Abort the request in the llm engine.
await self.engine.abort(request.request_id)
result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR
except Exception as e:
result = e
await self.socket.send_multipart((identity, pickle.dumps(result)))
async def generate(self, identity, generate_request: RPCGenerateRequest):
try:
results_generator = self.engine.generate(
generate_request.inputs,
sampling_params=generate_request.sampling_params,
request_id=generate_request.request_id,
lora_request=generate_request.lora_request,
trace_headers=generate_request.trace_headers,
prompt_adapter_request=generate_request.prompt_adapter_request)
async for request_output in results_generator:
await self.socket.send_multipart(
(identity, pickle.dumps(request_output)), copy=False)
except Exception as e:
await self.socket.send_multipart((identity, pickle.dumps(e)),
copy=False)
async def check_health(self, identity):
try:
await self.engine.check_health()
await self.socket.send_multipart(
(identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
except Exception as e:
await self.socket.send_multipart((identity, pickle.dumps(e)),
copy=False)
async def start_profile(self, identity):
logger.info("Starting profiler...")
await self.engine.start_profile()
logger.info("Profiler started.")
await self.socket.send_multipart((
identity,
pickle.dumps(VLLM_RPC_SUCCESS_STR),
))
async def stop_profile(self, identity):
logger.info("Stopping profiler...")
await self.engine.stop_profile()
logger.info("Profiler stopped.")
await self.socket.send_multipart((
identity,
pickle.dumps(VLLM_RPC_SUCCESS_STR),
))
def _make_handler_coro(self, identity,
message: Frame) -> Coroutine[Any, Any, Never]:
"""Route the zmq message to the handler coroutine."""
request = cloudpickle.loads(message.buffer)
if isinstance(request, RPCGenerateRequest):
return self.generate(identity, request)
elif isinstance(request, RPCAbortRequest):
return self.abort(identity, request)
elif isinstance(request, RPCUtilityRequest):
if request in [
RPCUtilityRequest.GET_MODEL_CONFIG,
RPCUtilityRequest.GET_PARALLEL_CONFIG,
RPCUtilityRequest.GET_DECODING_CONFIG,
RPCUtilityRequest.GET_SCHEDULER_CONFIG,
RPCUtilityRequest.GET_LORA_CONFIG
]:
return self.get_config(identity, request)
elif request == RPCUtilityRequest.DO_LOG_STATS:
return self.do_log_stats(identity)
elif request == RPCUtilityRequest.IS_SERVER_READY:
return self.is_server_ready(identity)
elif request == RPCUtilityRequest.IS_SERVER_HEALTHY:
return self.check_health(identity)
elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
return self.is_tracing_enabled(identity)
elif request == RPCUtilityRequest.START_PROFILE:
return self.start_profile(identity)
elif request == RPCUtilityRequest.STOP_PROFILE:
return self.stop_profile(identity)
else:
raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
else:
raise ValueError(f"Unknown RPCRequest type: {request}")
async def run_server_loop(self):
"""Inner RPC Server Loop"""
running_tasks = set()
while True:
# Wait for a request.
identity, message = await self.socket.recv_multipart(copy=False)
# Process the request async.
task = asyncio.create_task(
self._make_handler_coro(identity, message))
# We need to keep around a strong reference to the task,
# to avoid the task disappearing mid-execution as running tasks
# can be GC'ed. Below is a common "fire-and-forget" tasks
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
running_tasks.add(task)
task.add_done_callback(running_tasks.discard)
async def run_server(server: AsyncEngineRPCServer):
# Put the server task into the asyncio loop.
loop = asyncio.get_running_loop()
server_task = loop.create_task(server.run_server_loop())
# Interruption handling.
def signal_handler() -> None:
# Kill the server on interrupt / terminate
server_task.cancel()
loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)
try:
await server_task
except asyncio.CancelledError:
logger.info("vLLM ZMQ RPC Server was interrupted.")
finally:
# Clean up all resources.
server.cleanup()
def run_rpc_server(async_engine_args: AsyncEngineArgs,
usage_context: UsageContext, rpc_path: str):
def signal_handler(*_) -> None:
# Interrupt server on sigterm while initializing
raise KeyboardInterrupt("AsyncEngineRPCServer terminated")
signal.signal(signal.SIGTERM, signal_handler)
server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path)
uvloop.run(run_server(server))

View File

@ -9,7 +9,7 @@ from typing import Union
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage,
apply_hf_chat_template,
apply_mistral_chat_template,
@ -45,7 +45,7 @@ logger = init_logger(__name__)
class OpenAIServingChat(OpenAIServing):
def __init__(self,
async_engine_client: AsyncEngineClient,
engine_client: EngineClient,
model_config: ModelConfig,
served_model_names: List[str],
response_role: str,
@ -57,7 +57,7 @@ class OpenAIServingChat(OpenAIServing):
return_tokens_as_token_ids: bool = False,
enable_auto_tools: bool = False,
tool_parser: Optional[str] = None):
super().__init__(async_engine_client=async_engine_client,
super().__init__(engine_client=engine_client,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
@ -105,6 +105,12 @@ class OpenAIServingChat(OpenAIServing):
logger.error("Error with model %s", error_check_ret)
return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
try:
(
lora_request,
@ -112,8 +118,7 @@ class OpenAIServingChat(OpenAIServing):
) = self._maybe_get_adapters(request)
model_config = self.model_config
tokenizer = await self.async_engine_client.get_tokenizer(
lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
conversation, mm_data_future = parse_chat_messages_futures(
request.messages, model_config, tokenizer)
@ -207,8 +212,8 @@ class OpenAIServingChat(OpenAIServing):
if mm_data is not None:
engine_inputs["multi_modal_data"] = mm_data
is_tracing_enabled = (
await self.async_engine_client.is_tracing_enabled())
is_tracing_enabled = (await
self.engine_client.is_tracing_enabled())
trace_headers = None
if is_tracing_enabled and raw_request:
trace_headers = extract_trace_headers(raw_request.headers)
@ -216,7 +221,7 @@ class OpenAIServingChat(OpenAIServing):
and contains_trace_headers(raw_request.headers)):
log_tracing_disabled_warning()
result_generator = self.async_engine_client.generate(
result_generator = self.engine_client.generate(
engine_inputs,
sampling_params,
request_id,

View File

@ -8,7 +8,7 @@ from typing import Tuple, Union, cast
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
@ -43,7 +43,7 @@ class OpenAIServingCompletion(OpenAIServing):
def __init__(
self,
async_engine_client: AsyncEngineClient,
engine_client: EngineClient,
model_config: ModelConfig,
served_model_names: List[str],
*,
@ -52,7 +52,7 @@ class OpenAIServingCompletion(OpenAIServing):
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
):
super().__init__(async_engine_client=async_engine_client,
super().__init__(engine_client=engine_client,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
@ -78,6 +78,12 @@ class OpenAIServingCompletion(OpenAIServing):
if error_check_ret is not None:
return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
# Return error for unsupported features.
if request.suffix is not None:
return self.create_error_response(
@ -95,8 +101,7 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.async_engine_client.get_tokenizer(
lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))
@ -124,8 +129,8 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
is_tracing_enabled = (
await self.async_engine_client.is_tracing_enabled())
is_tracing_enabled = (await
self.engine_client.is_tracing_enabled())
trace_headers = None
if is_tracing_enabled:
trace_headers = extract_trace_headers(raw_request.headers)
@ -133,7 +138,7 @@ class OpenAIServingCompletion(OpenAIServing):
raw_request.headers):
log_tracing_disabled_warning()
generator = self.async_engine_client.generate(
generator = self.engine_client.generate(
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
sampling_params,
request_id_item,

View File

@ -8,7 +8,7 @@ from fastapi import Request
from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
EmbeddingResponse,
@ -71,13 +71,13 @@ class OpenAIServingEmbedding(OpenAIServing):
def __init__(
self,
async_engine_client: AsyncEngineClient,
engine_client: EngineClient,
model_config: ModelConfig,
served_model_names: List[str],
*,
request_logger: Optional[RequestLogger],
):
super().__init__(async_engine_client=async_engine_client,
super().__init__(engine_client=engine_client,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=None,
@ -118,8 +118,7 @@ class OpenAIServingEmbedding(OpenAIServing):
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.async_engine_client.get_tokenizer(
lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
pooling_params = request.to_pooling_params()
@ -144,7 +143,7 @@ class OpenAIServingEmbedding(OpenAIServing):
"Prompt adapter is not supported "
"for embedding models")
generator = self.async_engine_client.encode(
generator = self.engine_client.encode(
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
pooling_params,
request_id_item,

View File

@ -8,7 +8,7 @@ from pydantic import Field
from typing_extensions import Annotated
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
@ -64,7 +64,7 @@ class OpenAIServing:
def __init__(
self,
async_engine_client: AsyncEngineClient,
engine_client: EngineClient,
model_config: ModelConfig,
served_model_names: List[str],
*,
@ -75,7 +75,7 @@ class OpenAIServing:
):
super().__init__()
self.async_engine_client = async_engine_client
self.engine_client = engine_client
self.model_config = model_config
self.max_model_len = model_config.max_model_len
@ -159,7 +159,7 @@ class OpenAIServing:
async def _guided_decode_logits_processor(
self, request: Union[ChatCompletionRequest, CompletionRequest],
tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
decoding_config = await self.async_engine_client.get_decoding_config()
decoding_config = await self.engine_client.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
return await get_guided_decoding_logits_processor(

View File

@ -1,7 +1,7 @@
from typing import List, Optional, Union
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
apply_mistral_chat_template,
load_chat_template,
@ -29,7 +29,7 @@ class OpenAIServingTokenization(OpenAIServing):
def __init__(
self,
async_engine_client: AsyncEngineClient,
engine_client: EngineClient,
model_config: ModelConfig,
served_model_names: List[str],
*,
@ -37,7 +37,7 @@ class OpenAIServingTokenization(OpenAIServing):
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
):
super().__init__(async_engine_client=async_engine_client,
super().__init__(engine_client=engine_client,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
@ -66,7 +66,7 @@ class OpenAIServingTokenization(OpenAIServing):
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.async_engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
prompt: Union[str, List[int]]
if isinstance(request, TokenizeChatRequest):
@ -132,7 +132,7 @@ class OpenAIServingTokenization(OpenAIServing):
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.async_engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
self._log_inputs(request_id,
request.tokens,

View File

@ -57,7 +57,7 @@ if TYPE_CHECKING:
VERBOSE: bool = False
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
VLLM_TEST_FORCE_FP8_MARLIN: bool = False
VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000
VLLM_RPC_TIMEOUT: int = 10000 # ms
VLLM_PLUGINS: Optional[List[str]] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_USE_TRITON_AWQ: bool = False
@ -393,8 +393,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# Time in ms for the zmq client to wait for a response from the backend
# server for simple data operations
"VLLM_RPC_GET_DATA_TIMEOUT_MS":
lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")),
"VLLM_RPC_TIMEOUT":
lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")),
# a list of plugin names to load, separated by commas.
# if this is not set, it means all plugins will be loaded

View File

@ -106,6 +106,7 @@ class CPUExecutor(ExecutorBase):
)) for rank in range(1, world_size)
]
self.worker_monitor = None
if world_size != 1 or is_async:
if is_async:
async_worker_list = self.workers + [self.driver_worker]

View File

@ -168,6 +168,8 @@ class ProcessWorkerWrapper:
self.tasks[task_id] = future
try:
self._task_queue.put((task_id, method, args, kwargs))
except SystemExit:
raise
except BaseException as e:
del self.tasks[task_id]
raise ChildProcessError("worker died") from e
@ -222,6 +224,8 @@ def _run_worker_process(
try:
executor = getattr(worker, method)
output = executor(*args, **kwargs)
except SystemExit:
raise
except KeyboardInterrupt:
break
except BaseException as e: