[Core][Bugfix][Perf] Introduce MQLLMEngine to avoid asyncio OH (#8157)

Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>

parent 9d104b5beb
commit 7c7714d856
@ -43,13 +43,15 @@ steps:
  fast_check: true
  source_file_dependencies:
  - vllm/
  - tests/mq_llm_engine
  - tests/async_engine
  - tests/test_inputs
  - tests/multimodal
  - tests/test_utils
  - tests/worker
  commands:
  - pytest -v -s async_engine # Async Engine
  - pytest -v -s mq_llm_engine # MQLLMEngine
  - pytest -v -s async_engine # AsyncLLMEngine
  - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
  - pytest -v -s test_inputs.py
  - pytest -v -s multimodal
@ -21,8 +21,8 @@ Traces can be visualized using https://ui.perfetto.dev/.
.. tip::

    To stop the profiler - it flushes out all the profile trace files to the directory. This takes time, for example for about 100 requests worth of data for a Llama 70B, it takes about 10 minutes to flush out on an H100.
    Set the env variable VLLM_RPC_GET_DATA_TIMEOUT_MS to a big number before you start the server. Say something like 30 minutes.
    ``export VLLM_RPC_GET_DATA_TIMEOUT_MS=1800000``
    Set the env variable VLLM_RPC_TIMEOUT to a big number before you start the server. Say something like 30 minutes.
    ``export VLLM_RPC_TIMEOUT=1800000``

Example commands and usage:
===========================
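A minimal sketch of applying the same setting from Python rather than the shell, assuming the server is launched as a subprocess via the standard OpenAI-compatible entrypoint module; the model name and command line here are illustrative only, and the 30-minute value simply mirrors the tip above.

    # Sketch: launch the OpenAI-compatible server with a 30-minute RPC timeout,
    # mirroring the ``export VLLM_RPC_TIMEOUT=1800000`` tip above.
    import os
    import subprocess

    env = dict(os.environ)
    env["VLLM_RPC_TIMEOUT"] = "1800000"  # milliseconds (assumed value from the tip)

    proc = subprocess.Popen(
        ["python", "-m", "vllm.entrypoints.openai.api_server",
         "--model", "facebook/opt-125m"],  # model choice is illustrative
        env=env,
    )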
@ -1,106 +0,0 @@
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from ..utils import VLLM_PATH, RemoteOpenAIServer
|
||||
|
||||
# any model with a chat template should work here
|
||||
MODEL_NAME = "facebook/opt-125m"
|
||||
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
|
||||
assert chatml_jinja_path.exists()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"float16",
|
||||
"--max-model-len",
|
||||
"2048",
|
||||
"--enforce-eager",
|
||||
"--chat-template",
|
||||
str(chatml_jinja_path),
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_models(client: openai.AsyncOpenAI):
|
||||
models = await client.models.list()
|
||||
models = models.data
|
||||
served_model = models[0]
|
||||
assert served_model.id == MODEL_NAME
|
||||
assert all(model.root == MODEL_NAME for model in models)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_completion(client: openai.AsyncOpenAI):
|
||||
completion = await client.completions.create(model=MODEL_NAME,
|
||||
prompt="Hello, my name is",
|
||||
max_tokens=5,
|
||||
temperature=0.0)
|
||||
|
||||
assert completion.id is not None
|
||||
assert len(completion.choices) == 1
|
||||
assert len(completion.choices[0].text) >= 5
|
||||
assert completion.choices[0].finish_reason == "length"
|
||||
assert completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=5, prompt_tokens=6, total_tokens=11)
|
||||
|
||||
# test using token IDs
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
assert len(completion.choices[0].text) >= 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_chat_session(client: openai.AsyncOpenAI):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "what is 1+1?"
|
||||
}]
|
||||
|
||||
# test single completion
|
||||
chat_completion = await client.chat.completions.create(model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
logprobs=True,
|
||||
top_logprobs=5)
|
||||
assert chat_completion.id is not None
|
||||
assert len(chat_completion.choices) == 1
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
assert choice.finish_reason == "length"
|
||||
assert chat_completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=10, prompt_tokens=55, total_tokens=65)
|
||||
|
||||
message = choice.message
|
||||
assert message.content is not None and len(message.content) >= 10
|
||||
assert message.role == "assistant"
|
||||
messages.append({"role": "assistant", "content": message.content})
|
||||
|
||||
# test multi-turn dialogue
|
||||
messages.append({"role": "user", "content": "express your result in json"})
|
||||
chat_completion = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
)
|
||||
message = chat_completion.choices[0].message
|
||||
assert message.content is not None and len(message.content) >= 0
|
@ -1,120 +0,0 @@
|
||||
import asyncio
|
||||
import tempfile
|
||||
import unittest
|
||||
import unittest.mock
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.openai.rpc.client import (AsyncEngineRPCClient,
|
||||
RPCClientClosedError)
|
||||
from vllm.entrypoints.openai.rpc.server import AsyncEngineRPCServer
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def tmp_socket():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
yield f"ipc://{td}/{uuid.uuid4()}"
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def dummy_server(tmp_socket, monkeypatch):
|
||||
dummy_engine = unittest.mock.AsyncMock()
|
||||
|
||||
def dummy_engine_builder(*args, **kwargs):
|
||||
return dummy_engine
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setattr(AsyncLLMEngine, "from_engine_args", dummy_engine_builder)
|
||||
server = AsyncEngineRPCServer(None, None, rpc_path=tmp_socket)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
server_task = loop.create_task(server.run_server_loop())
|
||||
|
||||
try:
|
||||
yield server
|
||||
finally:
|
||||
server_task.cancel()
|
||||
server.cleanup()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def client(tmp_socket):
|
||||
client = AsyncEngineRPCClient(rpc_path=tmp_socket)
|
||||
# Sanity check: the server is connected
|
||||
await client._wait_for_server_rpc()
|
||||
|
||||
try:
|
||||
yield client
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_data_methods_use_timeouts(monkeypatch, dummy_server,
|
||||
client: AsyncEngineRPCClient):
|
||||
with monkeypatch.context() as m:
|
||||
# Make the server _not_ reply with a model config
|
||||
m.setattr(dummy_server, "get_config", lambda x: None)
|
||||
m.setattr(client, "_data_timeout", 10)
|
||||
|
||||
# And ensure the task completes anyway
|
||||
# (client.setup() invokes server.get_config())
|
||||
client_task = asyncio.get_running_loop().create_task(client.setup())
|
||||
with pytest.raises(TimeoutError, match="Server didn't reply within"):
|
||||
await asyncio.wait_for(client_task, timeout=0.05)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_aborts_use_timeouts(monkeypatch, dummy_server,
|
||||
client: AsyncEngineRPCClient):
|
||||
with monkeypatch.context() as m:
|
||||
# Hang all abort requests
|
||||
m.setattr(dummy_server, "abort", lambda x: None)
|
||||
m.setattr(client, "_data_timeout", 10)
|
||||
|
||||
# The client should suppress timeouts on `abort`s
|
||||
# and return normally, assuming the server will eventually
|
||||
# abort the request.
|
||||
client_task = asyncio.get_running_loop().create_task(
|
||||
client.abort("test request id"))
|
||||
await asyncio.wait_for(client_task, timeout=0.05)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_data_methods_reraise_exceptions(
|
||||
monkeypatch, dummy_server, client: AsyncEngineRPCClient):
|
||||
with monkeypatch.context() as m:
|
||||
# Make the server raise some random exception
|
||||
exception = RuntimeError("Client test exception")
|
||||
|
||||
def raiser():
|
||||
raise exception
|
||||
|
||||
m.setattr(dummy_server.engine, "get_model_config", raiser)
|
||||
m.setattr(client, "_data_timeout", 10)
|
||||
|
||||
client_task = asyncio.get_running_loop().create_task(client.setup())
|
||||
# And ensure the task completes, raising the exception
|
||||
with pytest.raises(RuntimeError, match=str(exception)):
|
||||
await asyncio.wait_for(client_task, timeout=0.05)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_errors_after_closing(monkeypatch, dummy_server,
|
||||
client: AsyncEngineRPCClient):
|
||||
|
||||
client.close()
|
||||
|
||||
# Healthchecks and generate requests will fail with explicit errors
|
||||
with pytest.raises(RPCClientClosedError):
|
||||
await client.check_health()
|
||||
with pytest.raises(RPCClientClosedError):
|
||||
async for _ in client.generate(None, None, None):
|
||||
pass
|
||||
|
||||
# But no-ops like aborting will pass
|
||||
await client.abort("test-request-id")
|
||||
await client.do_log_stats()
|
@ -18,38 +18,32 @@ TASK = "gsm8k"
|
||||
FILTER = "exact_match,strict-match"
|
||||
RTOL = 0.03
|
||||
EXPECTED_VALUE = 0.58
|
||||
DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
|
||||
MORE_ARGS_LIST = [["--enable-chunked-prefill"], ["--num-scheduler-steps", "8"]]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = [
|
||||
"--max-model-len", "4096", "--enable-chunked-prefill",
|
||||
"--disable-log-requests", "--enforce-eager"
|
||||
]
|
||||
@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
|
||||
def test_lm_eval_accuracy(more_args):
|
||||
args = list(DEFAULT_ARGS)
|
||||
args.extend(more_args)
|
||||
|
||||
print(f"Running with: {args}")
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
url = f"{remote_server.url_for('v1')}/completions"
|
||||
|
||||
model_args = (
|
||||
f"model={MODEL_NAME},"
|
||||
f"base_url={url},"
|
||||
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server_data(server):
|
||||
return {
|
||||
"url": f"{server.url_for('v1')}/completions",
|
||||
}
|
||||
results = lm_eval.simple_evaluate(
|
||||
model="local-completions",
|
||||
model_args=model_args,
|
||||
tasks=TASK,
|
||||
)
|
||||
|
||||
|
||||
def test_lm_eval_accuracy(server_data):
|
||||
model_args = (f"model={MODEL_NAME},"
|
||||
f"base_url={server_data['url']},"
|
||||
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
|
||||
|
||||
results = lm_eval.simple_evaluate(
|
||||
model="local-completions",
|
||||
model_args=model_args,
|
||||
tasks=TASK,
|
||||
)
|
||||
|
||||
measured_value = results["results"][TASK][FILTER]
|
||||
assert (measured_value - RTOL < EXPECTED_VALUE
|
||||
and measured_value + RTOL > EXPECTED_VALUE
|
||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
measured_value = results["results"][TASK][FILTER]
|
||||
assert (measured_value - RTOL < EXPECTED_VALUE
|
||||
and measured_value + RTOL > EXPECTED_VALUE
|
||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
|
@ -5,7 +5,7 @@ from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.transformers_utils.tokenizer import get_tokenizer

from ..utils import VLLM_PATH
from ...utils import VLLM_PATH

chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()
@ -1,40 +0,0 @@
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.api_server import build_async_engine_client
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mp_crash_detection():
|
||||
|
||||
parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
|
||||
parser = make_arg_parser(parser)
|
||||
args = parser.parse_args([])
|
||||
# use an invalid tensor_parallel_size to trigger the
|
||||
# error in the server
|
||||
args.tensor_parallel_size = 65536
|
||||
|
||||
start = time.perf_counter()
|
||||
async with build_async_engine_client(args):
|
||||
pass
|
||||
end = time.perf_counter()
|
||||
|
||||
assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s "
|
||||
"if there is an error in the startup.")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mp_cuda_init():
|
||||
# It should not crash when CUDA is initialized
|
||||
# in the API server process
|
||||
import torch
|
||||
torch.cuda.init()
|
||||
parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
|
||||
parser = make_arg_parser(parser)
|
||||
args = parser.parse_args([])
|
||||
|
||||
async with build_async_engine_client(args):
|
||||
pass
|
@ -4,7 +4,7 @@ from dataclasses import dataclass
from unittest.mock import MagicMock

from vllm.config import MultiModalConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.transformers_utils.tokenizer import get_tokenizer
@ -52,8 +52,9 @@ def test_async_serving_chat_init():


def test_serving_chat_should_set_correct_max_tokens():
    mock_engine = MagicMock(spec=AsyncLLMEngine)
    mock_engine = MagicMock(spec=MQLLMEngineClient)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
    mock_engine.errored = False

    serving_chat = OpenAIServingChat(mock_engine,
                                     MockModelConfig(),
@ -4,7 +4,7 @@ from unittest.mock import MagicMock
import pytest

from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                              LoadLoraAdapterRequest,
                                              UnloadLoraAdapterRequest)
@ -18,7 +18,7 @@ LORA_UNLOADING_SUCCESS_MESSAGE = (


async def _async_serving_engine_init():
    mock_engine_client = MagicMock(spec=AsyncEngineClient)
    mock_engine_client = MagicMock(spec=EngineClient)
    mock_model_config = MagicMock(spec=ModelConfig)
    # Set the max_model_len attribute to avoid missing attribute
    mock_model_config.max_model_len = 2048

@ -44,5 +44,5 @@ async def test_shutdown_on_engine_failure(tmp_path):
                                     prompt="Hello, my name is")

    # Now the server should shut down
    return_code = remote_server.proc.wait(timeout=3)
    return_code = remote_server.proc.wait(timeout=8)
    assert return_code is not None
tests/mq_llm_engine/test_abort.py (new file, 67 lines)
@ -0,0 +1,67 @@
|
||||
"""Test that aborting is handled properly."""
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
|
||||
MODEL = "google/gemma-1.1-2b-it"
|
||||
ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
|
||||
RAISED_ERROR = KeyError
|
||||
RAISED_VALUE = "foo"
|
||||
EXPECTED_TOKENS = 250
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def tmp_socket():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
yield f"ipc://{td}/{uuid.uuid4()}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abort(tmp_socket):
|
||||
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
|
||||
ipc_path=tmp_socket) as engine:
|
||||
|
||||
client = await engine.make_client()
|
||||
|
||||
request_id_to_be_aborted = "request-aborted"
|
||||
request_ids_a = [f"request-a-{idx}" for idx in range(10)]
|
||||
request_ids_b = [f"request-b-{idx}" for idx in range(10)]
|
||||
|
||||
# Requests started before one to be aborted.
|
||||
tasks = []
|
||||
for request_id in request_ids_a:
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
generate(client, request_id, EXPECTED_TOKENS)))
|
||||
|
||||
# Aborted.
|
||||
task_aborted = asyncio.create_task(
|
||||
generate(client, request_id_to_be_aborted, EXPECTED_TOKENS))
|
||||
|
||||
# Requests started after one to be aborted.
|
||||
for request_id in request_ids_b:
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
generate(client, request_id, EXPECTED_TOKENS)))
|
||||
|
||||
# Actually abort.
|
||||
await asyncio.sleep(0.5)
|
||||
await client.abort(request_id_to_be_aborted)
|
||||
|
||||
# Confirm that we got all the EXPECTED tokens from the requests.
|
||||
for task in tasks:
|
||||
count, request_id = await task
|
||||
assert count == EXPECTED_TOKENS, (
|
||||
f"{request_id} generated only {count} tokens")
|
||||
|
||||
# Cancel the aborted request's task (awaiting it would hang indefinitely).
|
||||
task_aborted.cancel()
|
||||
|
||||
# Shutdown.
|
||||
client.close()
|
tests/mq_llm_engine/test_error_handling.py (new file, 244 lines)
@ -0,0 +1,244 @@
|
||||
"""Test that various errors are handled properly."""
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.mq_llm_engine.utils import RemoteMQLLMEngine
|
||||
from vllm import SamplingParams
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.engine.multiprocessing import MQEngineDeadError
|
||||
from vllm.engine.multiprocessing.engine import MQLLMEngine
|
||||
from vllm.entrypoints.openai.api_server import build_async_engine_client
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
MODEL = "google/gemma-1.1-2b-it"
|
||||
ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
|
||||
RAISED_ERROR = KeyError
|
||||
RAISED_VALUE = "foo"
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def tmp_socket():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
yield f"ipc://{td}/{uuid.uuid4()}"
|
||||
|
||||
|
||||
def run_with_evil_forward(engine_args: AsyncEngineArgs, ipc_path: str):
|
||||
# Make engine.
|
||||
engine = MQLLMEngine.from_engine_args(
|
||||
engine_args=engine_args,
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT,
|
||||
ipc_path=ipc_path)
|
||||
|
||||
# Raise error during first forward pass.
|
||||
engine.engine.model_executor.execute_model = Mock(
|
||||
side_effect=RAISED_ERROR(RAISED_VALUE))
|
||||
|
||||
# Run engine.
|
||||
engine.start()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_evil_forward(tmp_socket):
|
||||
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
|
||||
ipc_path=tmp_socket,
|
||||
run_fn=run_with_evil_forward) as engine:
|
||||
|
||||
client = await engine.make_client()
|
||||
|
||||
# Server should be healthy after initial probe.
|
||||
await asyncio.sleep(2.0)
|
||||
await client.check_health()
|
||||
|
||||
# Throws an error in first forward pass.
|
||||
with pytest.raises(RAISED_ERROR):
|
||||
async for _ in client.generate(inputs="Hello my name is",
|
||||
sampling_params=SamplingParams(),
|
||||
request_id=uuid.uuid4()):
|
||||
pass
|
||||
assert client.errored
|
||||
|
||||
# Engine is errored, should get ENGINE_DEAD_ERROR.
|
||||
with pytest.raises(MQEngineDeadError):
|
||||
async for _ in client.generate(inputs="Hello my name is",
|
||||
sampling_params=SamplingParams(),
|
||||
request_id=uuid.uuid4()):
|
||||
pass
|
||||
assert client.errored
|
||||
|
||||
await asyncio.sleep(1.0)
|
||||
with pytest.raises(RAISED_ERROR):
|
||||
await client.check_health()
|
||||
assert client.errored
|
||||
|
||||
# Shutdown.
|
||||
client.close()
|
||||
|
||||
|
||||
def run_with_evil_model_executor_health(engine_args: AsyncEngineArgs,
|
||||
ipc_path: str):
|
||||
# Make engine.
|
||||
engine = MQLLMEngine.from_engine_args(
|
||||
engine_args=engine_args,
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT,
|
||||
ipc_path=ipc_path)
|
||||
|
||||
# Raise error during first forward pass.
|
||||
engine.engine.model_executor.check_health = Mock(side_effect=RAISED_ERROR)
|
||||
|
||||
# Run engine.
|
||||
engine.start()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failed_health_check(tmp_socket):
|
||||
with RemoteMQLLMEngine(
|
||||
engine_args=ENGINE_ARGS,
|
||||
ipc_path=tmp_socket,
|
||||
run_fn=run_with_evil_model_executor_health) as engine:
|
||||
|
||||
client = await engine.make_client()
|
||||
assert client.is_running
|
||||
|
||||
# Health probe should throw RAISED_ERROR.
|
||||
await asyncio.sleep(15.)
|
||||
|
||||
with pytest.raises(RAISED_ERROR):
|
||||
await client.check_health()
|
||||
assert client.errored
|
||||
|
||||
# Generate call should throw ENGINE_DEAD_ERROR
|
||||
with pytest.raises(MQEngineDeadError):
|
||||
async for _ in client.generate(inputs="Hello my name is",
|
||||
sampling_params=SamplingParams(),
|
||||
request_id=uuid.uuid4()):
|
||||
pass
|
||||
|
||||
client.close()
|
||||
|
||||
|
||||
def run_with_evil_abort(engine_args: AsyncEngineArgs, ipc_path: str):
|
||||
# Make engine.
|
||||
engine = MQLLMEngine.from_engine_args(
|
||||
engine_args=engine_args,
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT,
|
||||
ipc_path=ipc_path)
|
||||
|
||||
# Raise error during abort call.
|
||||
engine.engine.abort_request = Mock(side_effect=RAISED_ERROR)
|
||||
|
||||
# Run engine.
|
||||
engine.start()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failed_abort(tmp_socket):
|
||||
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
|
||||
ipc_path=tmp_socket,
|
||||
run_fn=run_with_evil_abort) as engine:
|
||||
|
||||
client = await engine.make_client()
|
||||
assert client.is_running
|
||||
|
||||
# First health check should work.
|
||||
await client.check_health()
|
||||
|
||||
# Trigger an abort on the client side.
|
||||
async def bad_abort_after_2s():
|
||||
await asyncio.sleep(2.0)
|
||||
await client.abort(request_id="foo")
|
||||
|
||||
# Trigger an abort in 2s from now.
|
||||
abort_task = asyncio.create_task(bad_abort_after_2s())
|
||||
|
||||
# Exception in abort() will happen during this generation.
|
||||
# This will kill the engine and should return ENGINE_DEAD_ERROR
|
||||
# with reference to the original KeyError("foo")
|
||||
with pytest.raises(MQEngineDeadError) as execinfo:
|
||||
async for _ in client.generate(
|
||||
inputs="Hello my name is",
|
||||
sampling_params=SamplingParams(max_tokens=2000),
|
||||
request_id=uuid.uuid4()):
|
||||
pass
|
||||
assert "KeyError" in repr(execinfo.value)
|
||||
assert client.errored
|
||||
|
||||
await abort_task
|
||||
|
||||
# This should raise the original error.
|
||||
with pytest.raises(RAISED_ERROR):
|
||||
await client.check_health()
|
||||
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bad_request(tmp_socket):
|
||||
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
|
||||
ipc_path=tmp_socket) as engine:
|
||||
|
||||
client = await engine.make_client()
|
||||
|
||||
# Invalid request should fail, but not crash the server.
|
||||
with pytest.raises(ValueError):
|
||||
async for _ in client.generate(inputs="Hello my name is",
|
||||
sampling_params=SamplingParams(),
|
||||
request_id="abcd-1",
|
||||
lora_request=LoRARequest(
|
||||
"invalid-lora", 1,
|
||||
"invalid-path")):
|
||||
pass
|
||||
|
||||
# This request should be okay.
|
||||
async for _ in client.generate(inputs="Hello my name is",
|
||||
sampling_params=SamplingParams(),
|
||||
request_id="abcd-2"):
|
||||
pass
|
||||
|
||||
# Shutdown.
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mp_crash_detection(monkeypatch):
|
||||
|
||||
parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
|
||||
parser = make_arg_parser(parser)
|
||||
args = parser.parse_args([])
|
||||
|
||||
# When LLMEngine is loaded, it will crash.
|
||||
def mock_init():
|
||||
raise ValueError
|
||||
|
||||
monkeypatch.setattr(LLMEngine, "__init__", mock_init)
|
||||
|
||||
start = time.perf_counter()
|
||||
async with build_async_engine_client(args):
|
||||
pass
|
||||
end = time.perf_counter()
|
||||
|
||||
assert end - start < 60, ("Expected vLLM to gracefully shutdown in <60s "
|
||||
"if there is an error in the startup.")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mp_cuda_init():
|
||||
# It should not crash when CUDA is initialized
|
||||
# in the API server process
|
||||
import torch
|
||||
torch.cuda.init()
|
||||
parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
|
||||
parser = make_arg_parser(parser)
|
||||
args = parser.parse_args([])
|
||||
|
||||
async with build_async_engine_client(args):
|
||||
pass
|
tests/mq_llm_engine/test_load.py (new file, 57 lines)
@ -0,0 +1,57 @@
|
||||
"""Test that the MQLLMEngine is able to handle 10k concurrent requests."""
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
|
||||
MODEL = "google/gemma-1.1-2b-it"
|
||||
NUM_EXPECTED_TOKENS = 10
|
||||
NUM_REQUESTS = 10000
|
||||
|
||||
# Scenarios to test for num generated token.
|
||||
ENGINE_ARGS = AsyncEngineArgs(model=MODEL, disable_log_requests=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def tmp_socket():
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
yield f"ipc://{td}/{uuid.uuid4()}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load(tmp_socket):
|
||||
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
|
||||
ipc_path=tmp_socket) as engine:
|
||||
|
||||
client = await engine.make_client()
|
||||
|
||||
request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]
|
||||
|
||||
# Create concurrent requests.
|
||||
tasks = []
|
||||
for request_id in request_ids:
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
generate(client, request_id, NUM_EXPECTED_TOKENS)))
|
||||
|
||||
# Confirm that we got all the EXPECTED tokens from the requests.
|
||||
failed_request_id = None
|
||||
tokens = None
|
||||
for task in tasks:
|
||||
num_generated_tokens, request_id = await task
|
||||
if (num_generated_tokens != NUM_EXPECTED_TOKENS
|
||||
and failed_request_id is None):
|
||||
failed_request_id = request_id
|
||||
tokens = num_generated_tokens
|
||||
|
||||
assert failed_request_id is None, (
|
||||
f"{failed_request_id} generated {tokens} but "
|
||||
f"expected {NUM_EXPECTED_TOKENS}")
|
||||
|
||||
# Shutdown.
|
||||
client.close()
|
tests/mq_llm_engine/utils.py (new file, 78 lines)
@ -0,0 +1,78 @@
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
from typing import Callable, Tuple, Union
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||
from vllm.engine.multiprocessing.engine import MQLLMEngine
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
|
||||
|
||||
async def generate(
|
||||
client: MQLLMEngineClient,
|
||||
request_id: str,
|
||||
num_tokens: int,
|
||||
return_output: bool = False) -> Union[RequestOutput, Tuple[int, str]]:
|
||||
|
||||
final_output = None
|
||||
count = 0
|
||||
async for out in client.generate(
|
||||
request_id=request_id,
|
||||
inputs="Hello my name is Robert and",
|
||||
sampling_params=SamplingParams(max_tokens=num_tokens,
|
||||
temperature=0)):
|
||||
|
||||
count += 1
|
||||
final_output = out
|
||||
await asyncio.sleep(0.)
|
||||
|
||||
if return_output:
|
||||
return final_output
|
||||
|
||||
# Confirm we generated all the tokens we expected.
|
||||
return count, request_id
|
||||
|
||||
|
||||
def run_normal(engine_args: AsyncEngineArgs, ipc_path: str):
|
||||
# Make engine.
|
||||
engine = MQLLMEngine.from_engine_args(
|
||||
engine_args=engine_args,
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT,
|
||||
ipc_path=ipc_path)
|
||||
|
||||
# Run engine.
|
||||
engine.start()
|
||||
|
||||
|
||||
class RemoteMQLLMEngine:
|
||||
|
||||
def __init__(self,
|
||||
engine_args: AsyncEngineArgs,
|
||||
ipc_path: str,
|
||||
run_fn: Callable = run_normal) -> None:
|
||||
|
||||
self.engine_args = engine_args
|
||||
self.ipc_path = ipc_path
|
||||
context = multiprocessing.get_context("spawn")
|
||||
self.proc = context.Process(target=run_fn,
|
||||
args=(engine_args, ipc_path))
|
||||
self.proc.start()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.proc.kill()
|
||||
|
||||
async def make_client(self) -> MQLLMEngineClient:
|
||||
engine_config = self.engine_args.create_engine_config()
|
||||
client = MQLLMEngineClient(self.ipc_path, engine_config)
|
||||
while True:
|
||||
try:
|
||||
await client.setup()
|
||||
break
|
||||
except TimeoutError:
|
||||
assert self.proc.is_alive()
|
||||
return client
|
@ -1,5 +1,12 @@
import os

from ..utils import compare_two_settings

# --enforce-eager on TPU causes graph compilation
# this times out default Health Check in the MQLLMEngine,
# so we set the timeout here to 30s
os.environ["VLLM_RPC_TIMEOUT"] = "30000"


def test_custom_dispatcher():
    compare_two_settings("google/gemma-2b",
@ -119,7 +119,7 @@ class RemoteOpenAIServer:
    def __exit__(self, exc_type, exc_value, traceback):
        self.proc.terminate()
        try:
            self.proc.wait(3)
            self.proc.wait(8)
        except subprocess.TimeoutExpired:
            # force kill if needed
            self.proc.kill()
@ -601,9 +601,12 @@ class AsyncLLMEngine:
        return self._errored_with is not None

    @property
    def limit_concurrency(self) -> Optional[int]:
        """Maximum number of concurrently running requests."""
        return None
    def dead_error(self) -> BaseException:
        return AsyncEngineDeadError(
            "Background loop is not running. If it was running, "
            "inspect the output to find the stacktrace of the "
            "error that caused the background loop to stop "
            "(AsyncEngineDeadError).")

    def set_errored(self, exc: Exception) -> None:
        self._errored_with = exc

@ -1289,6 +1289,7 @@ class LLMEngine:
            # torch.distributed ops which may otherwise timeout, and unblocks
            # the RPC thread in the workers so that they can process any other
            # queued control plane messages, such as add/remove lora adapters.
            logger.debug("Stopping remote worker execution loop.")
            self.model_executor.stop_remote_worker_execution_loop()

        return ctx.request_outputs
vllm/engine/multiprocessing/__init__.py (new file, 73 lines)
@ -0,0 +1,73 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Mapping, Optional, Union
|
||||
|
||||
from vllm.inputs import PromptInputs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
VLLM_RPC_SUCCESS_STR = "SUCCESS"
|
||||
|
||||
IPC_INPUT_EXT = "_input_socket"
|
||||
IPC_OUTPUT_EXT = "_output_socket"
|
||||
IPC_HEALTH_EXT = "_health_socket"
|
||||
IPC_DATA_EXT = "_data_socket"
|
||||
|
||||
|
||||
class MQEngineDeadError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class RPCGenerateRequest:
|
||||
inputs: PromptInputs
|
||||
sampling_params: SamplingParams
|
||||
request_id: str
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
trace_headers: Optional[Mapping[str, str]] = None
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RPCError:
|
||||
request_id: Optional[str]
|
||||
is_engine_errored: bool
|
||||
exception: BaseException
|
||||
|
||||
|
||||
@dataclass
|
||||
class RPCAbortRequest:
|
||||
request_id: str
|
||||
|
||||
|
||||
class RPCHealthRequest:
|
||||
pass
|
||||
|
||||
|
||||
class RPCStartupRequest(Enum):
|
||||
IS_SERVER_READY = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class RPCStartupResponse:
|
||||
tracing_enabled: bool
|
||||
|
||||
|
||||
RPC_REQUEST_T = Union[RPCGenerateRequest, RPCAbortRequest, RPCHealthRequest,
|
||||
RPCStartupRequest]
|
||||
|
||||
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]
|
||||
|
||||
|
||||
def ENGINE_DEAD_ERROR(
|
||||
error: Optional[BaseException] = None) -> MQEngineDeadError:
|
||||
if error is None:
|
||||
return MQEngineDeadError(
|
||||
"Engine loop is not running. Inspect the stacktrace to "
|
||||
"find the original error")
|
||||
|
||||
return MQEngineDeadError(
|
||||
"Engine loop is not running. Inspect the stacktrace to "
|
||||
f"find the original error: {repr(error)}.")
|
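The dataclasses above define the wire protocol between client and engine: every request message is pickled and pushed over a zeromq ipc socket, and the engine pulls and decodes it. Below is a minimal sketch of that framing under simplifying assumptions (a single PUSH/PULL pair, a plain string prompt, no cloudpickle handling for logits processors and no health/data sockets); it is not the full client/engine setup.

    # Sketch only: serialize an RPCGenerateRequest and ship it over a zmq ipc
    # socket, the same framing used between MQLLMEngineClient and MQLLMEngine.
    import pickle
    import tempfile
    import uuid

    import zmq

    from vllm.engine.multiprocessing import IPC_INPUT_EXT, RPCGenerateRequest
    from vllm.sampling_params import SamplingParams

    ipc_path = f"ipc://{tempfile.gettempdir()}/{uuid.uuid4()}"

    ctx = zmq.Context()
    engine_side = ctx.socket(zmq.PULL)   # the engine binds the input socket
    engine_side.bind(f"{ipc_path}{IPC_INPUT_EXT}")
    client_side = ctx.socket(zmq.PUSH)   # the client connects to it
    client_side.connect(f"{ipc_path}{IPC_INPUT_EXT}")

    # Build and send a request exactly as the client does: pickled frames.
    request = RPCGenerateRequest(inputs="Hello my name is",
                                 sampling_params=SamplingParams(max_tokens=8),
                                 request_id="sketch-1")
    client_side.send_multipart((pickle.dumps(request), ))

    # The engine side receives the frames and unpickles the first one.
    frames = engine_side.recv_multipart(copy=False)
    decoded = pickle.loads(frames[0].buffer)
    assert isinstance(decoded, RPCGenerateRequest)

    ctx.destroy(linger=0)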
vllm/engine/multiprocessing/client.py (new file, 452 lines)
@ -0,0 +1,452 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import pickle
|
||||
from contextlib import contextmanager, suppress
|
||||
from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional,
|
||||
Union)
|
||||
|
||||
import cloudpickle
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from zmq import Frame # type: ignore[attr-defined]
|
||||
from zmq.asyncio import Socket
|
||||
|
||||
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
||||
IPC_HEALTH_EXT, IPC_INPUT_EXT,
|
||||
IPC_OUTPUT_EXT, RPC_REQUEST_T,
|
||||
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
|
||||
RPCError, RPCGenerateRequest,
|
||||
RPCHealthRequest, RPCStartupRequest,
|
||||
RPCStartupResponse)
|
||||
# yapf: enable
|
||||
from vllm.envs import VLLM_RPC_TIMEOUT
|
||||
from vllm.inputs import PromptInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MQClientClosedError(Exception):
|
||||
"""Exception class raised when the client is used post-close.
|
||||
|
||||
The client can be closed, which closes the ZMQ context. This normally
|
||||
happens on server shutdown. In some cases, methods like abort and
|
||||
do_log_stats will still be called and then try to open a socket, which
|
||||
causes a ZMQError and creates a huge stack trace.
|
||||
So, we throw this error such that we can suppress it.
|
||||
"""
|
||||
|
||||
|
||||
class MQLLMEngineClient:
|
||||
"""A client wrapper for MQLLMEngine that conforms to the
|
||||
EngineClient protocol.
|
||||
|
||||
MQLLMEngine and MQLLMEngineClient are intended to run in separate
|
||||
processes communicating via zeromq ipc sockets.
|
||||
|
||||
The entrypoint to MQLLMEngineClient is through the generate()
|
||||
method. On generate(), MQLLMEngineClient does three things:
|
||||
- Creates an asyncio output queue
|
||||
- Sends a RPCGenerateRequest to the MQLLMEngine via zmq
|
||||
- Pulls RequestOutputs from its queue and yields them
|
||||
|
||||
MQLLMEngineClient runs two background loops:
|
||||
- output_loop: the output loop pulls List[RequestOutput]
|
||||
from the MQLLMEngine via zmq (each list is the output
|
||||
of one engine_step in the LLMEngine). It then parses
|
||||
the list and pushes individual request_outputs into
|
||||
the corresponding output_queue such that they can be
|
||||
consumed by the .generate() method.
|
||||
- health_loop: the health loop queries the health socket
|
||||
every N seconds, confirming the engine is healthy
|
||||
"""
|
||||
|
||||
def __init__(self, ipc_path: str, engine_config: EngineConfig):
|
||||
self.context = zmq.asyncio.Context()
|
||||
self._errored_with: Optional[BaseException] = None
|
||||
|
||||
# Get the configs.
|
||||
self.model_config = engine_config.model_config
|
||||
self.decoding_config = engine_config.decoding_config
|
||||
|
||||
# Create the tokenizer group.
|
||||
self.tokenizer = init_tokenizer_from_configs(
|
||||
model_config=self.model_config,
|
||||
scheduler_config=engine_config.scheduler_config,
|
||||
parallel_config=engine_config.parallel_config,
|
||||
enable_lora=bool(engine_config.lora_config),
|
||||
)
|
||||
|
||||
# Send RPCGenerateRequest to the MQLLMEngine.
|
||||
self.input_socket: Socket = self.context.socket(zmq.constants.PUSH)
|
||||
self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}")
|
||||
|
||||
# Receive streams of RequestOutput from the MQLLMEngine.
|
||||
self.output_socket: Socket = self.context.socket(zmq.constants.PULL)
|
||||
self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")
|
||||
|
||||
# IPC path for ack of check_health requests.
|
||||
self.health_socket: Socket = self.context.socket(zmq.constants.PULL)
|
||||
self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
|
||||
|
||||
# IPC path for the data socket.
|
||||
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
|
||||
|
||||
# Stream for each individual request.
|
||||
self.output_queues: Dict[str, asyncio.Queue] = {}
|
||||
self.output_loop = asyncio.create_task(self.run_output_handler_loop())
|
||||
|
||||
# Loop to check health of the LLMEngine periodically.
|
||||
# Started after the MQLLMEngine is ready.
|
||||
self.health_loop: Optional[asyncio.Task] = None
|
||||
|
||||
@staticmethod
|
||||
def is_unsupported_config(engine_args: AsyncEngineArgs):
|
||||
if engine_args.pipeline_parallel_size > 1:
|
||||
return True
|
||||
|
||||
is_embedding = ModelConfig(
|
||||
model=engine_args.model,
|
||||
revision=engine_args.revision,
|
||||
tokenizer=engine_args.model,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=engine_args.trust_remote_code,
|
||||
quantization=engine_args.quantization,
|
||||
seed=0,
|
||||
dtype="auto").embedding_mode
|
||||
|
||||
return is_embedding
|
||||
|
||||
@contextmanager
|
||||
def get_data_socket(self) -> Iterator[Socket]:
|
||||
socket = self.context.socket(zmq.constants.DEALER)
|
||||
try:
|
||||
socket.connect(self.data_ipc_path)
|
||||
yield socket
|
||||
finally:
|
||||
socket.close(linger=0)
|
||||
|
||||
async def run_check_health_loop(self, timeout: int):
|
||||
"""Background loop that continually probes the RPCServer for health.
|
||||
|
||||
The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which
|
||||
the MQLLMEngine server is blocking on.
|
||||
|
||||
The Server replies on the HEALTH_SOCKET (rather than on the
|
||||
OUTPUT_SOCKET such that the messages are not intermingled with
|
||||
output streaming).
|
||||
"""
|
||||
|
||||
try:
|
||||
while True:
|
||||
if await self.health_socket.poll(timeout=timeout) == 0:
|
||||
# Wakeup every N seconds and do a health probe.
|
||||
await self._send_one_way_rpc_request(
|
||||
RPCHealthRequest(), self.input_socket)
|
||||
|
||||
# Wait for ack from the health socket.
|
||||
await self._await_ack(error_message="Health check failed.",
|
||||
socket=self.health_socket)
|
||||
else:
|
||||
# Server sent a health status message unprompted.
|
||||
await self._check_success(
|
||||
error_message="Health check failed.",
|
||||
socket=self.health_socket)
|
||||
|
||||
logger.debug("Health probe successful.")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Shutting down MQLLMEngineClient check health loop.")
|
||||
|
||||
except Exception as e:
|
||||
self._set_errored(e)
|
||||
|
||||
async def run_output_handler_loop(self):
|
||||
"""Get RequestOutputs from Engine and stream to Request Queues"""
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Poll, checking for ENGINE_DEAD
|
||||
while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT
|
||||
) == 0:
|
||||
logger.debug("Waiting for output from MQLLMEngine.")
|
||||
|
||||
# If errored, alert all running requests.
|
||||
if self.errored:
|
||||
for queue_j in tuple(self.output_queues.values()):
|
||||
queue_j.put_nowait(
|
||||
ENGINE_DEAD_ERROR(self._errored_with))
|
||||
return
|
||||
|
||||
message: Frame = await self.output_socket.recv(copy=False)
|
||||
request_outputs = pickle.loads(message.buffer)
|
||||
|
||||
is_error = isinstance(request_outputs,
|
||||
(BaseException, RPCError))
|
||||
if is_error:
|
||||
if isinstance(request_outputs, RPCError):
|
||||
rpc_error: RPCError = request_outputs
|
||||
request_id = rpc_error.request_id
|
||||
exception = rpc_error.exception
|
||||
is_engine_errored = rpc_error.is_engine_errored
|
||||
else:
|
||||
# MQLLMEngine should always return an RPCError to
|
||||
# the output_socket when an issue arises.
|
||||
# If we are here, we are in a bad state and
|
||||
# should shut down the server.
|
||||
error: BaseException = request_outputs
|
||||
logger.error(
|
||||
"Received Exception %s rather than RPCError from "
|
||||
"MPLLMEngine. This should never happen.", error)
|
||||
request_id = None
|
||||
exception = error
|
||||
is_engine_errored = True
|
||||
|
||||
# Set to error state only on engine critical error
|
||||
# (and record only the first one)
|
||||
if is_engine_errored and not self._errored_with:
|
||||
self._errored_with = exception
|
||||
|
||||
if request_id is None:
|
||||
for queue_i in tuple(self.output_queues.values()):
|
||||
queue_i.put_nowait(exception)
|
||||
else:
|
||||
queue = self.output_queues.get(request_id)
|
||||
if queue is not None:
|
||||
queue.put_nowait(exception)
|
||||
else:
|
||||
# Put each output into the appropriate stream.
|
||||
for request_output in request_outputs:
|
||||
queue = self.output_queues.get(
|
||||
request_output.request_id)
|
||||
if queue is not None:
|
||||
queue.put_nowait(request_output)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Shutting down MQLLMEngineClient output handler.")
|
||||
|
||||
async def setup(self):
|
||||
"""Setup the client before it starts sending server requests."""
|
||||
|
||||
with self.get_data_socket() as socket:
|
||||
# Wait until server is ready.
|
||||
response = await self._wait_for_server_rpc(socket)
|
||||
|
||||
self.tracing_flag = response.tracing_enabled
|
||||
|
||||
# Start health_loop.
|
||||
self.health_loop = asyncio.create_task(
|
||||
self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT))
|
||||
|
||||
def close(self):
|
||||
"""Destroy the ZeroMQ Context."""
|
||||
# Close all sockets and terminate the context.
|
||||
self.context.destroy(linger=0)
|
||||
|
||||
# Cancel background tasks.
|
||||
if self.health_loop is not None:
|
||||
self.health_loop.cancel()
|
||||
self.output_loop.cancel()
|
||||
|
||||
def _set_errored(self, e: BaseException):
|
||||
logger.exception(repr(e))
|
||||
if self._errored_with is None:
|
||||
self._errored_with = e
|
||||
|
||||
@staticmethod
|
||||
async def _send_get_data_rpc_request(request: RPCStartupRequest,
|
||||
expected_type: Any,
|
||||
error_message: str,
|
||||
socket: Socket) -> Any:
|
||||
"""Send an RPC request that is expecting data back."""
|
||||
|
||||
# Ping RPCServer with a request.
|
||||
await socket.send_multipart((pickle.dumps(request), ), copy=False)
|
||||
|
||||
# Make sure the server responds in time.
|
||||
if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
|
||||
raise TimeoutError("RPCServer didn't reply within "
|
||||
f"{VLLM_RPC_TIMEOUT} ms")
|
||||
|
||||
# Await the data from the Server.
|
||||
frame = await socket.recv(copy=False)
|
||||
data = pickle.loads(frame.buffer)
|
||||
|
||||
if isinstance(data, BaseException):
|
||||
raise data
|
||||
elif not isinstance(data, expected_type):
|
||||
raise ValueError(error_message)
|
||||
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
async def _send_one_way_rpc_request(request: RPC_REQUEST_T,
|
||||
socket: Socket):
|
||||
"""Send one-way RPC request to trigger an action."""
|
||||
|
||||
if socket.closed:
|
||||
raise MQClientClosedError()
|
||||
|
||||
await socket.send_multipart((pickle.dumps(request), ))
|
||||
|
||||
async def _await_ack(self, error_message: str, socket: Socket):
|
||||
"""Await acknowledgement that a request succeeded."""
|
||||
|
||||
if socket.closed:
|
||||
raise MQClientClosedError()
|
||||
|
||||
if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
|
||||
raise TimeoutError("MQLLMEngine didn't reply within "
|
||||
f"{VLLM_RPC_TIMEOUT}ms")
|
||||
|
||||
await self._check_success(error_message, socket)
|
||||
|
||||
@staticmethod
|
||||
async def _check_success(error_message: str, socket: Socket):
|
||||
"""Confirm that socket has a VLLM_RPC_SUCCESS_STR message"""
|
||||
|
||||
if socket.closed:
|
||||
raise MQClientClosedError()
|
||||
|
||||
frame = await socket.recv(copy=False)
|
||||
response = pickle.loads(frame.buffer)
|
||||
|
||||
# Raise error if unsuccessful
|
||||
if isinstance(response, BaseException):
|
||||
raise response
|
||||
elif (not isinstance(response, str)
|
||||
or response != VLLM_RPC_SUCCESS_STR):
|
||||
raise ValueError(error_message)
|
||||
|
||||
async def get_tokenizer(self, lora_request: LoRARequest):
|
||||
return await self.tokenizer.get_lora_tokenizer_async(lora_request)
|
||||
|
||||
async def get_decoding_config(self) -> DecodingConfig:
|
||||
return self.decoding_config
|
||||
|
||||
async def get_model_config(self) -> ModelConfig:
|
||||
return self.model_config
|
||||
|
||||
async def is_tracing_enabled(self) -> bool:
|
||||
return self.tracing_flag
|
||||
|
||||
async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse:
|
||||
"""Wait for the RPCServer to start up."""
|
||||
|
||||
return await self._send_get_data_rpc_request(
|
||||
request=RPCStartupRequest.IS_SERVER_READY,
|
||||
expected_type=RPCStartupResponse,
|
||||
error_message="Unable to start RPC Server",
|
||||
socket=socket)
|
||||
|
||||
async def abort(self, request_id: str):
|
||||
"""Send an ABORT_REQUEST signal to the RPC Server"""
|
||||
|
||||
with suppress(MQClientClosedError):
|
||||
await self._send_one_way_rpc_request(
|
||||
request=RPCAbortRequest(request_id), socket=self.input_socket)
|
||||
|
||||
async def do_log_stats(self):
|
||||
"""Ignore do_log_stats (handled on MQLLMEngine polling)"""
|
||||
pass
|
||||
|
||||
async def check_health(self):
|
||||
"""
|
||||
The check health loop probes the health status of the
|
||||
Engine's health every N seconds and sets _errored_with
|
||||
if the engine is unhealthy.
|
||||
"""
|
||||
if self._errored_with is not None:
|
||||
raise self._errored_with
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return not self.errored
|
||||
|
||||
@property
|
||||
def is_stopped(self) -> bool:
|
||||
return self.errored
|
||||
|
||||
@property
|
||||
def errored(self) -> bool:
|
||||
return self._errored_with is not None
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
inputs: PromptInputs,
|
||||
sampling_params: SamplingParams,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
|
||||
|
||||
# If already dead, error out.
|
||||
if self._errored_with is not None:
|
||||
raise ENGINE_DEAD_ERROR(self._errored_with)
|
||||
|
||||
# 1) Create output queue for this request.
|
||||
queue: asyncio.Queue[Union[RequestOutput,
|
||||
BaseException]] = asyncio.Queue()
|
||||
self.output_queues[request_id] = queue
|
||||
|
||||
try:
|
||||
# 2) Detach logits processors so that they can be pickled
|
||||
# separately (may require cloudpickle which is slower)
|
||||
if sampling_params.logits_processors:
|
||||
# Defensive shallow copy
|
||||
sampling_params = copy.copy(sampling_params)
|
||||
logits_processors = sampling_params.logits_processors
|
||||
sampling_params.logits_processors = None
|
||||
lp_bytes = cloudpickle.dumps(logits_processors)
|
||||
else:
|
||||
lp_bytes = None
|
||||
|
||||
request_bytes = pickle.dumps(
|
||||
RPCGenerateRequest(
|
||||
inputs=inputs,
|
||||
sampling_params=sampling_params,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request))
|
||||
|
||||
# 3) Send the RPCGenerateRequest to the MQLLMEngine.
|
||||
parts = (request_bytes,
|
||||
lp_bytes) if lp_bytes else (request_bytes, )
|
||||
await self.input_socket.send_multipart(parts, copy=False)
|
||||
|
||||
# 4) Stream the RequestOutputs from the output queue. Note
|
||||
# that the output_loop pushes RequestOutput objects to this
|
||||
# queue after pulling them from the zmq socket.
|
||||
finished = False
|
||||
try:
|
||||
while not finished:
|
||||
request_output = await queue.get()
|
||||
|
||||
if isinstance(request_output, BaseException):
|
||||
raise request_output
|
||||
|
||||
finished = request_output.finished
|
||||
yield request_output
|
||||
finally:
|
||||
# Request was canceled by the client.
|
||||
if not finished and not self.errored:
|
||||
await self.abort(request_id)
|
||||
finally:
|
||||
self.output_queues.pop(request_id)
|
||||
|
||||
async def encode(self, *args,
|
||||
**kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
|
||||
raise NotImplementedError(
|
||||
"Embeddings not supported with multiprocessing backend")
|
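Putting the pieces together, here is a minimal end-to-end sketch of driving MQLLMEngineClient against an MQLLMEngine running in a separate process, modeled on tests/mq_llm_engine/utils.py above. The model name and ipc path are illustrative, and a real OpenAI server would instead go through build_async_engine_client; this is only a usage sketch of the classes introduced in this commit.

    # Sketch based on tests/mq_llm_engine/utils.py: run MQLLMEngine in a
    # spawned process and stream outputs through MQLLMEngineClient over ipc.
    import asyncio
    import multiprocessing
    import tempfile
    import uuid

    from vllm import SamplingParams
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.multiprocessing.client import MQLLMEngineClient
    from vllm.engine.multiprocessing.engine import MQLLMEngine
    from vllm.usage.usage_lib import UsageContext

    ENGINE_ARGS = AsyncEngineArgs(model="facebook/opt-125m")  # illustrative model


    def run_engine(engine_args: AsyncEngineArgs, ipc_path: str):
        # The engine process owns the model and binds the zmq sockets.
        engine = MQLLMEngine.from_engine_args(
            engine_args=engine_args,
            usage_context=UsageContext.UNKNOWN_CONTEXT,
            ipc_path=ipc_path)
        engine.start()


    async def main():
        ipc_path = f"ipc://{tempfile.gettempdir()}/{uuid.uuid4()}"
        proc = multiprocessing.get_context("spawn").Process(
            target=run_engine, args=(ENGINE_ARGS, ipc_path))
        proc.start()

        client = MQLLMEngineClient(ipc_path, ENGINE_ARGS.create_engine_config())
        # Retry setup until the engine finishes loading (as make_client does).
        while True:
            try:
                await client.setup()
                break
            except TimeoutError:
                assert proc.is_alive()

        async for output in client.generate(
                inputs="Hello my name is",
                sampling_params=SamplingParams(max_tokens=8),
                request_id="sketch-1"):
            if output.finished:
                print(output.outputs[0].text)

        client.close()
        proc.kill()


    if __name__ == "__main__":
        asyncio.run(main())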
vllm/engine/multiprocessing/engine.py (new file, 321 lines)
@ -0,0 +1,321 @@
|
||||
import pickle
|
||||
import signal
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator, List, Optional, Union
|
||||
|
||||
import cloudpickle
|
||||
import zmq
|
||||
|
||||
from vllm import AsyncEngineArgs, LLMEngine
|
||||
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig)
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
||||
IPC_HEALTH_EXT, IPC_INPUT_EXT,
|
||||
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
|
||||
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
|
||||
RPCError, RPCGenerateRequest,
|
||||
RPCHealthRequest, RPCStartupRequest,
|
||||
RPCStartupResponse)
|
||||
# yapf: enable
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
|
||||
CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
|
||||
SchedulerConfig, LoRAConfig]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
POLLING_TIMEOUT_MS = 10000
|
||||
HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), )
|
||||
|
||||
|
||||
class MQLLMEngine:
|
||||
"""A multiprocessing wrapper for :class:`LLMEngine`.
|
||||
|
||||
This class is used to wrap the :class:`LLMEngine` class to enable use
|
||||
in a concurrent manner. It runs a background loop and uses zeromq to
|
||||
receive new requests and stream outputs incrementally via ipc.
|
||||
|
||||
The :class:`LLMEngine.generate` is kicked off when a new
|
||||
RPCGenerateRequest is received by the input_socket.
|
||||
|
||||
The self.engine_loop checks the input_socket for new requests,
|
||||
adds them to the LLMEngine if there are any, calls the internal
|
||||
:class:`LLMEngine.step()`, and sends the RequestOutputs back over
|
||||
the output_socket.
|
||||
|
||||
If use_async_sockets is set, the logic associated with reading new
|
||||
requests from the socket and sending data to the socket is passed
|
||||
as a callback to the llm_engine, which calls the logic asynchronously
|
||||
such that the IPC can be overlapped with the GPU.
|
||||
|
||||
Args:
|
||||
ipc_path: Base path for zeromq interprocess messaging
|
||||
use_async_sockets: Whether to make send/recv async with GPU
|
||||
log_requests: Whether to log the requests.
|
||||
*args: Arguments for :class:`LLMEngine`.
|
||||
**kwargs: Arguments for :class:`LLMEngine`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ipc_path: str,
|
||||
use_async_sockets: bool,
|
||||
*args,
|
||||
log_requests: bool = True,
|
||||
**kwargs) -> None:
|
||||
self.engine = LLMEngine(*args, **kwargs)
|
||||
self.log_requests = log_requests
|
||||
|
||||
self.use_async_sockets = use_async_sockets
|
||||
if self.use_async_sockets:
|
||||
self.engine.process_request_outputs_callback = \
|
||||
self._async_socket_engine_callback
|
||||
|
||||
self.ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
|
||||
# Receive input from the client.
|
||||
self.input_socket = self.ctx.socket(zmq.constants.PULL)
|
||||
self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}")
|
||||
|
||||
# Send output stream back to client.
|
||||
self.output_socket = self.ctx.socket(zmq.constants.PUSH)
|
||||
self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")
|
||||
|
||||
# Send health status back to client.
|
||||
self.health_socket = self.ctx.socket(zmq.constants.PUSH)
|
||||
self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
|
||||
|
||||
# IPC path for the data socket.
|
||||
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
|
||||
|
||||
# Error state.
|
||||
self._errored_with: Optional[BaseException] = None
|
||||
|
||||
@property
|
||||
def dead_error(self) -> BaseException:
|
||||
if self._errored_with is not None:
|
||||
return ENGINE_DEAD_ERROR(self._errored_with)
|
||||
else:
|
||||
return ENGINE_DEAD_ERROR()
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(cls, engine_args: AsyncEngineArgs,
|
||||
usage_context: UsageContext, ipc_path: str):
|
||||
"""Creates an MQLLMEngine from the engine arguments."""
|
||||
|
||||
engine_config = engine_args.create_engine_config()
|
||||
|
||||
executor_class = LLMEngine._get_executor_cls(engine_config)
|
||||
|
||||
return cls(
|
||||
ipc_path=ipc_path,
|
||||
use_async_sockets=engine_config.model_config.use_async_output_proc,
|
||||
**engine_config.to_dict(),
|
||||
executor_class=executor_class,
|
||||
log_requests=not engine_args.disable_log_requests,
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
usage_context=usage_context)
|
||||
|
||||
def start(self):
|
||||
try:
|
||||
try:
|
||||
logger.debug("Starting Startup Loop.")
|
||||
self.run_startup_loop()
|
||||
logger.debug("Starting Engine Loop.")
|
||||
self.run_engine_loop()
|
||||
except Exception as e:
|
||||
logger.exception(repr(e))
|
||||
except KeyboardInterrupt:
|
||||
logger.debug("Shutting down MQLLMEngine.")
|
||||
finally:
|
||||
logger.debug("MQLLMEngine is shut down.")
|
||||
self.cleanup()
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup zeromq state on shutdown."""
|
||||
# Closes all sockets and destroys context.
|
||||
self.ctx.destroy(linger=0)
|
||||
del self.engine
|
||||
|
||||
    @contextmanager
    def make_data_socket(
            self) -> Iterator[zmq.Socket]:  # type: ignore[name-defined]
        socket = self.ctx.socket(zmq.constants.ROUTER)
        try:
            socket.bind(self.data_ipc_path)
            yield socket
        finally:
            socket.close(linger=0)

    def run_startup_loop(self) -> None:
        """Startup loop for sending data from Engine -> Client."""

        with self.make_data_socket() as socket:
            response: Union[RPCStartupResponse, BaseException]
            try:
                identity, message = socket.recv_multipart(copy=False)
                request: RPCStartupRequest = pickle.loads(message.buffer)

                # Handle the query from the Client.
                if request == RPCStartupRequest.IS_SERVER_READY:
                    tracing_enabled = self.engine.is_tracing_enabled()
                    response = RPCStartupResponse(
                        tracing_enabled=tracing_enabled)

            except Exception as e:
                response = e

            socket.send_multipart((identity, pickle.dumps(response)),
                                  copy=False)

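run_startup_loop answers a single readiness query over a short-lived ROUTER socket. The matching client-side probe is not part of this hunk; a hedged sketch, with the IPC_DATA_EXT suffix assumed and the startup request object supplied by the caller, could look like this:

# Hedged sketch of the client half of the startup handshake above: a DEALER
# socket pairs with the engine's ROUTER data socket, sends a pickled startup
# request, and waits for either a startup response or a pickled exception.
import pickle
import zmq

IPC_DATA_EXT = "_data_socket"  # assumed suffix, illustration only


def probe_engine(ipc_path: str, startup_request, timeout_ms: int = 10_000):
    ctx = zmq.Context()
    socket = ctx.socket(zmq.DEALER)
    try:
        socket.connect(f"{ipc_path}{IPC_DATA_EXT}")
        socket.send(pickle.dumps(startup_request))
        if socket.poll(timeout=timeout_ms) == 0:
            raise TimeoutError("Engine did not answer the readiness probe")
        response = pickle.loads(socket.recv())
        if isinstance(response, BaseException):
            # The engine pickles and forwards startup failures.
            raise response
        return response
    finally:
        socket.close(linger=0)
        ctx.term()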
    def run_engine_loop(self):
        """Core busy loop of the LLMEngine."""

        while True:
            if not self.engine.has_unfinished_requests():
                # Poll until there is work to do.
                while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
                    self.engine.do_log_stats()
                    logger.debug("Waiting for new requests in engine loop.")

            # Handle any input from the client.
            self.handle_new_input()

            # Engine step.
            request_outputs = self.engine_step()

            # Send request outputs (if async, done in engine_step callback).
            if not self.use_async_sockets:
                self._send_outputs(request_outputs)

    def engine_step(self) -> List[RequestOutput]:
        """Engine step wrapper with error handling."""

        try:
            return self.engine.step()
        except SystemExit:
            raise
        except BaseException as e:
            self._set_errored(e)
            rpc_err = RPCError(request_id=None,
                               is_engine_errored=True,
                               exception=e)
            self._send_outputs(rpc_err)
            raise e

    def handle_new_input(self):
        """Handle new input from the socket."""
        try:
            while self.input_socket.poll(timeout=0) != 0:
                frames = self.input_socket.recv_multipart(copy=False)
                request = pickle.loads(frames[0].buffer)

                if isinstance(request, RPCGenerateRequest):
                    if len(frames) > 1:
                        # Use cloudpickle for logits processors
                        lprocs = cloudpickle.loads(frames[1].buffer)
                        request.sampling_params.logits_processors = lprocs
                    self._handle_generate_request(request)
                elif isinstance(request, RPCAbortRequest):
                    self._handle_abort_request(request)
                elif isinstance(request, RPCHealthRequest):
                    self._handle_health_request()
                else:
                    raise ValueError(f"Unknown RPCRequest Type: {request}")

        except Exception as e:
            self._set_errored(e)
            self._send_unhealthy(e)
            raise e

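handle_new_input expects a specific multipart framing: frame 0 is the pickled request, and an optional frame 1 carries cloudpickled logits processors, which are often not plain-picklable. A hedged sketch of the sending side of that contract, assuming an RPCGenerateRequest-like object with a sampling_params attribute:

# Sketch of how a sender could frame a generate request for the loop above.
import pickle

import cloudpickle
import zmq


def send_generate_request(input_socket: zmq.Socket, request) -> None:
    lprocs = request.sampling_params.logits_processors
    # Strip the processors before pickling the request itself; the engine
    # reattaches them from the second frame.
    request.sampling_params.logits_processors = None

    frames = [pickle.dumps(request)]
    if lprocs:
        frames.append(cloudpickle.dumps(lprocs))

    input_socket.send_multipart(frames, copy=False)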
    def _handle_generate_request(self, request: RPCGenerateRequest):
        """Handle RPCGenerateRequest by adding it to the LLMEngine."""
        request_id = request.request_id

        if self._errored_with is not None:
            rpc_err = RPCError(request_id=request_id,
                               is_engine_errored=True,
                               exception=ENGINE_DEAD_ERROR(self._errored_with))
            self._send_outputs(rpc_err)

        try:
            self.engine.add_request(
                request_id=request_id,
                inputs=request.inputs,
                params=request.sampling_params,
                lora_request=request.lora_request,
                trace_headers=request.trace_headers,
                prompt_adapter_request=request.prompt_adapter_request)

            if self.log_requests:
                logger.info("Added request %s.", request.request_id)

        except Exception as e:
            # We do not set self._errored = True here, since the error
            # is due to an issue adding this request to the engine,
            # rather than an issue with the engine itself.
            is_errored = self._errored_with is not None
            rpc_err = RPCError(request_id=request_id,
                               is_engine_errored=is_errored,
                               exception=e)
            self._send_outputs(rpc_err)

            # Remove request from the engine.
            self.engine.abort_request(request_id)

    def _handle_abort_request(self, request: RPCAbortRequest):
        self.engine.abort_request(request.request_id)
        if self.log_requests:
            logger.info("Aborted request %s.", request.request_id)

    def _handle_health_request(self):
        if self._errored_with is not None:
            self._send_unhealthy(self._errored_with)

        # Raises error if unhealthy.
        self.engine.check_health()
        self._send_healthy()

    def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
        """Send List of RequestOutput to RPCClient."""
        if outputs:
            output_bytes = pickle.dumps(outputs)
            self.output_socket.send_multipart((output_bytes, ), copy=False)

    def _send_healthy(self):
        """Send HEALTHY message to RPCClient."""
        self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False)

    def _send_unhealthy(self, error: BaseException):
        """Send UNHEALTHY message to RPCClient."""
        error_bytes = pickle.dumps(error)
        self.health_socket.send_multipart((error_bytes, ), copy=False)

    def _async_socket_engine_callback(self,
                                      request_outputs: REQUEST_OUTPUTS_T):
        """Callback used by engine to make socket handling async with GPU."""
        self._send_outputs(request_outputs)
        self.handle_new_input()

    def _set_errored(self, e: BaseException):
        """Log and set errored status if this is the first issue."""
        if self._errored_with is None:
            self._errored_with = e


def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
                  ipc_path: str):

    def signal_handler(*_) -> None:
        # Interrupt server on sigterm
        raise KeyboardInterrupt("MQLLMEngine terminated")

    signal.signal(signal.SIGTERM, signal_handler)

    engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
                                          usage_context=usage_context,
                                          ipc_path=ipc_path)
    engine.start()
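Taken together with the api_server changes further down, the intended wiring is: spawn run_mp_engine in a fresh process, then drive it through MQLLMEngineClient over the same ipc_path. A hedged end-to-end sketch follows; the constructor, setup() and close() calls are mirrored from the hunks below and some imports (get_open_zmq_ipc_path, UsageContext) are assumed to live where they do today, so treat this as an illustration rather than a supported entry point:

import asyncio
import multiprocessing

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.usage.usage_lib import UsageContext
from vllm.utils import get_open_zmq_ipc_path  # assumed location


async def main():
    engine_args = AsyncEngineArgs(model="facebook/opt-125m")
    ipc_path = get_open_zmq_ipc_path()

    # The engine must live in a spawned process so it owns its own CUDA context.
    ctx = multiprocessing.get_context("spawn")
    engine_process = ctx.Process(
        target=run_mp_engine,
        args=(engine_args, UsageContext.OPENAI_API_SERVER, ipc_path))
    engine_process.start()

    client = MQLLMEngineClient(ipc_path, engine_args.create_engine_config())
    try:
        await client.setup()          # readiness handshake over the data socket
        await client.do_log_stats()   # any EngineClient call now goes over IPC
    finally:
        client.close()
        engine_process.terminate()
        engine_process.join()


if __name__ == "__main__":
    asyncio.run(main())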
@@ -14,8 +14,8 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer


@runtime_checkable
class AsyncEngineClient(Protocol):
    """Protocol class for Clients to AsyncLLMEngine"""
class EngineClient(Protocol):
    """Protocol class for Clients to Engine"""

    @property
    def is_running(self) -> bool:
@@ -30,8 +30,8 @@ class AsyncEngineClient(Protocol):
        ...

    @property
    def limit_concurrency(self) -> Optional[int]:
        """Maximum number of concurrently running requests."""
    def dead_error(self) -> BaseException:
        ...

    def generate(
        self,
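The rename to EngineClient keeps the @runtime_checkable Protocol shape, so the in-process AsyncLLMEngine and the new MQLLMEngineClient can both sit behind the same structural type. A small self-contained illustration of that pattern, using invented names rather than vLLM's:

from typing import Protocol, runtime_checkable


@runtime_checkable
class TinyClient(Protocol):

    @property
    def is_running(self) -> bool:
        ...

    async def do_log_stats(self) -> None:
        ...


class InProcessBackend:
    """Structurally satisfies TinyClient without subclassing it."""

    @property
    def is_running(self) -> bool:
        return True

    async def do_log_stats(self) -> None:
        print("stats logged")


# Passes because runtime_checkable protocols check attribute presence.
assert isinstance(InProcessBackend(), TinyClient)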
@@ -1,21 +1,21 @@
import asyncio
import signal
from http import HTTPStatus
from typing import Any, Optional
from typing import Any

import uvicorn
from fastapi import FastAPI, Request, Response

from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.multiprocessing import MQEngineDeadError
from vllm.logger import init_logger
from vllm.utils import find_process_using_port

logger = init_logger(__name__)


async def serve_http(app: FastAPI, limit_concurrency: Optional[int],
                     **uvicorn_kwargs: Any):
async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
    logger.info("Available routes are:")
    for route in app.routes:
        methods = getattr(route, "methods", None)
@@ -26,15 +26,6 @@ async def serve_http(app: FastAPI, limit_concurrency: Optional[int],

        logger.info("Route: %s, Methods: %s", path, ', '.join(methods))

    # Set concurrency limits in uvicorn if running in multiprocessing mode
    # since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536).
    if limit_concurrency is not None:
        logger.info(
            "Launching Uvicorn with --limit_concurrency %s. To avoid this "
            "limit at the expense of performance run with "
            "--disable-frontend-multiprocessing", limit_concurrency)
        uvicorn_kwargs["limit_concurrency"] = limit_concurrency

    config = uvicorn.Config(app, **uvicorn_kwargs)
    server = uvicorn.Server(config)
    _add_shutdown_handlers(app, server)
@@ -63,7 +54,7 @@ async def serve_http(app: FastAPI, limit_concurrency: Optional[int],
            logger.debug(
                "port %s is used by process %s launched with command:\n%s",
                port, process, " ".join(process.cmdline()))
        logger.info("Gracefully stopping http server")
        logger.info("Shutting down FastAPI HTTP server.")
        return server.shutdown()


@@ -90,7 +81,7 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)

    @app.exception_handler(AsyncEngineDeadError)
    async def engine_dead_handler(_, __):
    async def async_engine_dead_handler(_, __):
        """Kill the server if the async engine is already dead. It will
        not handle any further requests."""
        if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
@@ -99,3 +90,14 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
            server.should_exit = True

        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)

    @app.exception_handler(MQEngineDeadError)
    async def mq_engine_dead_handler(_, __):
        """Kill the server if the mq engine is already dead. It will
        not handle any further requests."""
        if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
            logger.fatal("MQLLMEngine is already dead, terminating server "
                         "process")
            server.should_exit = True

        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
@@ -26,7 +26,9 @@ import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser
@@ -44,8 +46,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              TokenizeRequest,
                                              TokenizeResponse,
                                              UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@@ -67,29 +67,16 @@ logger = init_logger('vllm.entrypoints.openai.api_server')
_running_tasks: Set[asyncio.Task] = set()


def model_is_embedding(model_name: str, trust_remote_code: bool,
                       quantization: Optional[str],
                       revision: Optional[str]) -> bool:
    return ModelConfig(model=model_name,
                       revision=revision,
                       tokenizer=model_name,
                       tokenizer_mode="auto",
                       trust_remote_code=trust_remote_code,
                       quantization=quantization,
                       seed=0,
                       dtype="auto").embedding_mode


@asynccontextmanager
async def lifespan(app: FastAPI):
    try:
        if app.state.log_stats:
            async_engine_client = app.state.engine_client
            engine_client: EngineClient = app.state.engine_client

            async def _force_log():
                while True:
                    await asyncio.sleep(10)
                    await async_engine_client.do_log_stats()
                    await asyncio.sleep(10.)
                    await engine_client.do_log_stats()

            task = asyncio.create_task(_force_log())
            _running_tasks.add(task)
@@ -108,9 +95,9 @@ async def lifespan(app: FastAPI):

@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
        args: Namespace) -> AsyncIterator[Optional[EngineClient]]:

    # Context manager to handle async_engine_client lifecycle
    # Context manager to handle engine_client lifecycle
    # Ensures everything is shutdown and cleaned up on error/exit
    engine_args = AsyncEngineArgs.from_cli_args(args)

@@ -123,19 +110,18 @@
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[Optional[AsyncEngineClient]]:
) -> AsyncIterator[Optional[EngineClient]]:
    """
    Create AsyncEngineClient, either:
    Create EngineClient, either:
        - in-process using the AsyncLLMEngine Directly
        - multiprocess using AsyncLLMEngine RPC

    Returns the Client or None if the creation failed.
    """

    # If manually triggered or embedding model, use AsyncLLMEngine in process.
    # TODO: support embedding model via RPC.
    if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
                           engine_args.quantization, engine_args.revision)
    # Fall back
    # TODO: fill out feature matrix.
    if (MQLLMEngineClient.is_unsupported_config(engine_args)
            or disable_frontend_multiprocessing):
        engine_config = engine_args.create_engine_config()
        uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
@@ -173,56 +159,60 @@
                "and vLLM will properly handle cleanup.")

        # Select random path for IPC.
        rpc_path = get_open_zmq_ipc_path()
        logger.info("Multiprocessing frontend to use %s for RPC Path.",
                    rpc_path)
        ipc_path = get_open_zmq_ipc_path()
        logger.info("Multiprocessing frontend to use %s for IPC Path.",
                    ipc_path)

        # Build RPCClient, which conforms to AsyncEngineClient Protocol.
        # NOTE: Actually, this is not true yet. We still need to support
        # embedding models via RPC (see TODO above)
        rpc_client = AsyncEngineRPCClient(rpc_path)

        # Start RPCServer in separate process (holds the AsyncLLMEngine).
        context = multiprocessing.get_context("spawn")
        # Start RPCServer in separate process (holds the LLMEngine).
        # the current process might have CUDA context,
        # so we need to spawn a new process
        rpc_server_process = context.Process(
            target=run_rpc_server,
            args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path))
        rpc_server_process.start()
        logger.info("Started engine process with PID %d",
                    rpc_server_process.pid)
        context = multiprocessing.get_context("spawn")

        engine_process = context.Process(target=run_mp_engine,
                                         args=(engine_args,
                                               UsageContext.OPENAI_API_SERVER,
                                               ipc_path))
        engine_process.start()
        logger.info("Started engine process with PID %d", engine_process.pid)

        # Build RPCClient, which conforms to EngineClient Protocol.
        # NOTE: Actually, this is not true yet. We still need to support
        # embedding models via RPC (see TODO above)
        engine_config = engine_args.create_engine_config()
        mp_engine_client = MQLLMEngineClient(ipc_path, engine_config)

        try:
            while True:
                try:
                    await rpc_client.setup()
                    await mp_engine_client.setup()
                    break
                except TimeoutError:
                    if not rpc_server_process.is_alive():
                        logger.error(
                            "RPCServer process died before responding "
                            "to readiness probe")
                    if not engine_process.is_alive():
                        logger.error("Engine process died before responding "
                                     "to readiness probe")
                        yield None
                        return

            yield rpc_client  # type: ignore[misc]
            yield mp_engine_client  # type: ignore[misc]
        finally:
            # Ensure rpc server process was terminated
            rpc_server_process.terminate()
            engine_process.terminate()

            # Close all open connections to the backend
            rpc_client.close()
            mp_engine_client.close()

            # Wait for server process to join
            rpc_server_process.join()
            # Wait for engine process to join
            engine_process.join(4)
            if engine_process.exitcode is None:
                # Kill if taking longer than 5 seconds to stop
                engine_process.kill()

            # Lazy import for prometheus multiprocessing.
            # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
            # before prometheus_client is imported.
            # See https://prometheus.github.io/client_python/multiprocess/
            from prometheus_client import multiprocess
            multiprocess.mark_process_dead(rpc_server_process.pid)
            multiprocess.mark_process_dead(engine_process.pid)


router = APIRouter()
@@ -270,7 +260,7 @@ def embedding(request: Request) -> OpenAIServingEmbedding:
    return request.app.state.openai_serving_embedding


def engine_client(request: Request) -> AsyncEngineClient:
def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client


@@ -473,7 +463,7 @@ def build_app(args: Namespace) -> FastAPI:


def init_app_state(
    async_engine_client: AsyncEngineClient,
    engine_client: EngineClient,
    model_config: ModelConfig,
    state: State,
    args: Namespace,
@@ -488,11 +478,11 @@ def init_app_state(
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    state.engine_client = async_engine_client
    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats

    state.openai_serving_chat = OpenAIServingChat(
        async_engine_client,
        engine_client,
        model_config,
        served_model_names,
        args.response_role,
@@ -504,7 +494,7 @@ def init_app_state(
        enable_auto_tools=args.enable_auto_tool_choice,
        tool_parser=args.tool_call_parser)
    state.openai_serving_completion = OpenAIServingCompletion(
        async_engine_client,
        engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
@@ -513,13 +503,13 @@ def init_app_state(
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    state.openai_serving_embedding = OpenAIServingEmbedding(
        async_engine_client,
        engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )
    state.openai_serving_tokenization = OpenAIServingTokenization(
        async_engine_client,
        engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
@@ -541,21 +531,20 @@ async def run_server(args, **uvicorn_kwargs) -> None:

    signal.signal(signal.SIGTERM, signal_handler)

    async with build_async_engine_client(args) as async_engine_client:
    async with build_async_engine_client(args) as engine_client:
        # If None, creation of the client failed and we exit.
        if async_engine_client is None:
        if engine_client is None:
            return

        app = build_app(args)

        model_config = await async_engine_client.get_model_config()
        init_app_state(async_engine_client, model_config, app.state, args)
        model_config = await engine_client.get_model_config()
        init_app_state(engine_client, model_config, app.state, args)

        temp_socket.close()

        shutdown_task = await serve_http(
            app,
            limit_concurrency=async_engine_client.limit_concurrency,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
@ -1,50 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Mapping, Optional, Union
|
||||
|
||||
from vllm.inputs import PromptInputs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
# Success string used for RPC instructions.
|
||||
VLLM_RPC_SUCCESS_STR = "SUCCESS"
|
||||
|
||||
# Minimum value of ZMQ.SOCKET_LIMIT to run mp.
|
||||
VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000
|
||||
|
||||
# HWM is set to Infinity.
|
||||
VLLM_RPC_ZMQ_HWM = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class RPCGenerateRequest:
|
||||
inputs: PromptInputs
|
||||
sampling_params: SamplingParams
|
||||
request_id: str
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
trace_headers: Optional[Mapping[str, str]] = None
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RPCAbortRequest:
|
||||
request_id: str
|
||||
|
||||
|
||||
class RPCUtilityRequest(Enum):
|
||||
IS_SERVER_READY = 1
|
||||
GET_MODEL_CONFIG = 2
|
||||
GET_DECODING_CONFIG = 3
|
||||
GET_PARALLEL_CONFIG = 4
|
||||
GET_SCHEDULER_CONFIG = 5
|
||||
GET_LORA_CONFIG = 6
|
||||
DO_LOG_STATS = 7
|
||||
IS_SERVER_HEALTHY = 8
|
||||
IS_TRACING_ENABLED = 9
|
||||
START_PROFILE = 10
|
||||
STOP_PROFILE = 11
|
||||
|
||||
|
||||
RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
|
||||
RPCUtilityRequest]
|
@ -1,451 +0,0 @@
|
||||
import asyncio
|
||||
import pickle
|
||||
from contextlib import contextmanager, suppress
|
||||
from typing import Any, AsyncGenerator, Iterator, Mapping, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import cloudpickle
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from zmq import Frame # type: ignore[attr-defined]
|
||||
from zmq.asyncio import Socket
|
||||
|
||||
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig)
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
|
||||
VLLM_RPC_SOCKET_LIMIT_CUTOFF,
|
||||
VLLM_RPC_SUCCESS_STR,
|
||||
VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
|
||||
RPCGenerateRequest, RPCUtilityRequest)
|
||||
# yapf: enable
|
||||
from vllm.envs import VLLM_RPC_GET_DATA_TIMEOUT_MS
|
||||
from vllm.inputs import PromptInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Path used for inprocess proxy.
|
||||
INPROC_PROXY_PATH = f"inproc://{uuid4()}"
|
||||
|
||||
|
||||
class RPCClientClosedError(Exception):
|
||||
"""Exception class raised when the client is used post-close.
|
||||
|
||||
The client can be closed, which closes the ZMQ context. This normally
|
||||
happens on server shutdown. In some cases, methods like abort and
|
||||
do_log_stats will still be called and then try to open a socket, which
|
||||
causes a ZMQError and creates a huge stack trace.
|
||||
So, we throw this error such that we can suppress it.
|
||||
"""
|
||||
|
||||
|
||||
class AsyncEngineRPCClient:
|
||||
"""
|
||||
RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
|
||||
|
||||
The overall design mirrors the Asynchronous Client Server Pattern
|
||||
https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern
|
||||
|
||||
On startup, the RPCClient:
|
||||
- makes DEALER socket (to_rpc_server) that connects to the RPCServer
|
||||
via ipc, which uses unix sockets under the hood
|
||||
(https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html)
|
||||
- makes ROUTER socket (from_api_server) that binds to a random
|
||||
inproc address, which uses memory under the hood
|
||||
(https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html)
|
||||
- runs a proxy in a background asyncio task between
|
||||
from_api_server (ROUTER, inproc) and to_rpc_server (DEALER ipc, )
|
||||
|
||||
Each request handled by the asyncio api_server calls generate():
|
||||
- make a DEALER socket that connects to from_api_server via inproc
|
||||
- send a RCPGenerateRequest to the inproc socket
|
||||
- background proxy forwards the request from inproc -> ipc
|
||||
- RPCServer responds to the request one token at a time over ipc
|
||||
- background proxy forwards the response from ipc -> inproc
|
||||
|
||||
The connection looks like this:
|
||||
DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER
|
||||
|
||||
Message routing is performed via identities that are managed by the
|
||||
ROUTER socket. ROUTER sockets track every connection it has and
|
||||
tells the caller about these. The way it tells the caller is to stick
|
||||
the connection identity in front of each message received. When we
|
||||
send the message via a ROUTER, we first send an identity frame.
|
||||
See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope
|
||||
for more details on connection identities.
|
||||
|
||||
This proxy design enables us to use a single unix socket, which
|
||||
improves performance by avoiding syscalls (~5%) and avoids resource limits
|
||||
such as ulimit, which defaults to 1024 on ubuntu.
|
||||
|
||||
Note: we run set_hwm(0) on each socket, which sets the HWM to inf,
|
||||
which is required to avoid dropping messages under high load.
|
||||
This is generally not advisable. However, since we are in control
|
||||
of both sides of the connection + failure on either side is
|
||||
catastrophic to the overall system health and memory profiling
|
||||
suggests limited memory overhead relative to asyncio, we will
|
||||
proceed for now.
|
||||
|
||||
See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks
|
||||
for more details on high water marks.
|
||||
"""
|
||||
|
||||
def __init__(self, rpc_path: str):
|
||||
self.context = zmq.asyncio.Context()
|
||||
self._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS
|
||||
self._errored = False
|
||||
|
||||
# Maximum number of sockets that can be opened (typically 65536).
|
||||
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
|
||||
socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT)
|
||||
assert isinstance(socket_limit, int)
|
||||
if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF:
|
||||
raise ValueError(
|
||||
f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps "
|
||||
"the number of concurrent requests vLLM can process. Launch "
|
||||
"vLLM with --disable-frontend-multiprocessing and open a "
|
||||
"GitHub issue so we can investigate.")
|
||||
|
||||
# We only have 1 ipc connection that uses unix sockets, so
|
||||
# safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. will
|
||||
# not run into ulimit issues)
|
||||
self.context.set(zmq.constants.MAX_SOCKETS, socket_limit)
|
||||
|
||||
# IPC connection to RPC Server (uses unix sockets).
|
||||
self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER)
|
||||
self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM)
|
||||
self.to_rpc_server.bind(rpc_path)
|
||||
|
||||
# In process proxy to RPC Server (uses memory-based messaging).
|
||||
self.from_api_server: Socket = self.context.socket(
|
||||
zmq.constants.ROUTER)
|
||||
self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM)
|
||||
self.from_api_server.bind(INPROC_PROXY_PATH)
|
||||
|
||||
# Asyncio background task for the proxy.
|
||||
self.proxy_in_task = asyncio.create_task(
|
||||
self.run_proxy(self.from_api_server, self.to_rpc_server))
|
||||
self.proxy_out_task = asyncio.create_task(
|
||||
self.run_proxy(self.to_rpc_server, self.from_api_server))
|
||||
|
||||
# Since we open 1 inproc socket per request, we have a hard cap on
|
||||
# the number of requests that can run in vLLM w. frontend
|
||||
# mulitprocessing. This value is used uvicorn to launch
|
||||
# with --limit-concurrency to return 503 when server is overloaded.
|
||||
# We need 2 sockets per request - 2:
|
||||
# 1 for generate(), 1 for abort(), do_log_stats(), check_health()
|
||||
self.limit_concurrency = socket_limit // 2 - 2
|
||||
|
||||
async def run_proxy(self, socket_from: Socket, socket_to: Socket):
|
||||
"""Background task that runs a proxy"""
|
||||
while True:
|
||||
frames = await socket_from.recv_multipart(copy=False)
|
||||
await socket_to.send_multipart(frames, copy=False)
|
||||
|
||||
async def setup(self):
|
||||
"""Setup the client before it starts sending server requests."""
|
||||
|
||||
# Wait until server is ready.
|
||||
await self._wait_for_server_rpc()
|
||||
|
||||
# Get the configs.
|
||||
self.model_config = await self._get_model_config_rpc()
|
||||
self.decoding_config = await self._get_decoding_config_rpc()
|
||||
self.tracing_flag = await self._is_tracing_enabled_rpc()
|
||||
|
||||
# Create the tokenizer group.
|
||||
# TODO: refactor OAI server to avoid needing this info.
|
||||
self.tokenizer = init_tokenizer_from_configs(
|
||||
model_config=self.model_config,
|
||||
scheduler_config=(await self._get_scheduler_config_rpc()),
|
||||
parallel_config=(await self._get_parallel_config_rpc()),
|
||||
enable_lora=bool(await self._get_lora_config_rpc()),
|
||||
)
|
||||
|
||||
def close(self):
|
||||
"""Destroy the ZeroMQ Context."""
|
||||
# Close all sockets associated with this context and
|
||||
# then terminate the context.
|
||||
self.from_api_server.close()
|
||||
self.to_rpc_server.close()
|
||||
self.context.destroy()
|
||||
|
||||
@contextmanager
|
||||
def to_proxy_socket(self) -> Iterator[Socket]:
|
||||
# Connect to the RPCServer via the proxy.
|
||||
|
||||
# Raise a sensible error if the client was already closed.
|
||||
# This can happen if a server shutdown is triggered but some coroutines
|
||||
# are still running requests.
|
||||
# There should not be a race condition with this check because we don't
|
||||
# yield to the event loop between here and opening the socket.
|
||||
if self.context.closed:
|
||||
raise RPCClientClosedError("The ZMQ client has already shut down")
|
||||
|
||||
# Note that we use DEALER to enable asynchronous communication
|
||||
# to enable streaming.
|
||||
socket = self.context.socket(zmq.constants.DEALER)
|
||||
socket.set_hwm(VLLM_RPC_ZMQ_HWM)
|
||||
try:
|
||||
socket.connect(INPROC_PROXY_PATH)
|
||||
yield socket
|
||||
finally:
|
||||
socket.close(linger=0)
|
||||
|
||||
async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
|
||||
expected_type: Any,
|
||||
error_message: str) -> Any:
|
||||
"""Send an RPC request that is expecting data back."""
|
||||
|
||||
with self.to_proxy_socket() as socket:
|
||||
# Ping RPCServer with a request.
|
||||
await socket.send_multipart((cloudpickle.dumps(request), ),
|
||||
copy=False)
|
||||
|
||||
# Make sure the server responds
|
||||
if await socket.poll(timeout=self._data_timeout) == 0:
|
||||
raise TimeoutError("Server didn't reply within "
|
||||
f"{self._data_timeout} ms")
|
||||
|
||||
# Await the data from the Server.
|
||||
frame = await socket.recv(copy=False)
|
||||
assert isinstance(frame, Frame)
|
||||
data = pickle.loads(frame.buffer)
|
||||
|
||||
if isinstance(data, Exception):
|
||||
# Re-raise exceptions returned by the server
|
||||
raise data
|
||||
|
||||
if not isinstance(data, expected_type):
|
||||
# LoRAConfig can be None.
|
||||
if expected_type == LoRAConfig and data is None:
|
||||
pass
|
||||
elif isinstance(data, Exception):
|
||||
logger.error(error_message)
|
||||
raise data
|
||||
else:
|
||||
raise ValueError(error_message)
|
||||
|
||||
return data
|
||||
|
||||
async def _send_one_way_rpc_request(self,
|
||||
request: RPC_REQUEST_TYPE,
|
||||
error_message: str,
|
||||
socket: Optional[Socket] = None):
|
||||
"""Send one-way RPC request to trigger an action."""
|
||||
|
||||
async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE):
|
||||
|
||||
await socket.send_multipart((cloudpickle.dumps(request), ))
|
||||
|
||||
if await socket.poll(timeout=self._data_timeout) == 0:
|
||||
raise TimeoutError("Server didn't reply within "
|
||||
f"{self._data_timeout} ms")
|
||||
|
||||
frame = await socket.recv(copy=False)
|
||||
assert isinstance(frame, Frame)
|
||||
return pickle.loads(frame.buffer)
|
||||
|
||||
# Make a new socket connection.
|
||||
if socket is None:
|
||||
with self.to_proxy_socket() as socket:
|
||||
response = await do_rpc_call(socket, request)
|
||||
|
||||
# Use existing socket connection.
|
||||
else:
|
||||
response = await do_rpc_call(socket, request)
|
||||
|
||||
if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
|
||||
if isinstance(response, Exception):
|
||||
logger.error(error_message)
|
||||
raise response
|
||||
raise ValueError(error_message)
|
||||
|
||||
async def get_tokenizer(self, lora_request: LoRARequest):
|
||||
return await self.tokenizer.get_lora_tokenizer_async(lora_request)
|
||||
|
||||
async def get_decoding_config(self) -> DecodingConfig:
|
||||
return self.decoding_config
|
||||
|
||||
async def get_model_config(self) -> ModelConfig:
|
||||
return self.model_config
|
||||
|
||||
async def is_tracing_enabled(self) -> bool:
|
||||
return self.tracing_flag
|
||||
|
||||
async def _wait_for_server_rpc(self):
|
||||
"""Wait for the RPCServer to start up."""
|
||||
|
||||
await self._send_one_way_rpc_request(
|
||||
request=RPCUtilityRequest.IS_SERVER_READY,
|
||||
error_message="Unable to start RPC Server")
|
||||
|
||||
async def _get_model_config_rpc(self) -> ModelConfig:
|
||||
"""Get the ModelConfig object from the RPC Server"""
|
||||
|
||||
return await self._send_get_data_rpc_request(
|
||||
RPCUtilityRequest.GET_MODEL_CONFIG,
|
||||
expected_type=ModelConfig,
|
||||
error_message="Could not get ModelConfig from RPC Server")
|
||||
|
||||
async def _get_decoding_config_rpc(self) -> DecodingConfig:
|
||||
"""Get DecodingConfig from the RPCServer"""
|
||||
|
||||
return await self._send_get_data_rpc_request(
|
||||
RPCUtilityRequest.GET_DECODING_CONFIG,
|
||||
expected_type=DecodingConfig,
|
||||
error_message="Could not get DecodingConfig from RPC Server")
|
||||
|
||||
async def _get_parallel_config_rpc(self) -> ParallelConfig:
|
||||
"""Get ParallelConfig from the RPCServer"""
|
||||
|
||||
return await self._send_get_data_rpc_request(
|
||||
RPCUtilityRequest.GET_PARALLEL_CONFIG,
|
||||
expected_type=ParallelConfig,
|
||||
error_message="Could not get ParallelConfig from RPC Server")
|
||||
|
||||
async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
|
||||
"""Get SchedulerConfig from the RPCServer"""
|
||||
|
||||
return await self._send_get_data_rpc_request(
|
||||
RPCUtilityRequest.GET_SCHEDULER_CONFIG,
|
||||
expected_type=SchedulerConfig,
|
||||
error_message="Could not get SchedulerConfig from RPC Server")
|
||||
|
||||
async def _get_lora_config_rpc(self) -> LoRAConfig:
|
||||
"""Get LoRAConfig from the RPCServer"""
|
||||
|
||||
return await self._send_get_data_rpc_request(
|
||||
RPCUtilityRequest.GET_LORA_CONFIG,
|
||||
expected_type=LoRAConfig,
|
||||
error_message="Could not get LoRAConfig from RPC Server")
|
||||
|
||||
async def _is_tracing_enabled_rpc(self) -> bool:
|
||||
"""Get is_tracing_enabled flag from the RPCServer"""
|
||||
|
||||
return await self._send_get_data_rpc_request(
|
||||
RPCUtilityRequest.IS_TRACING_ENABLED,
|
||||
expected_type=bool,
|
||||
error_message="Could not get is_tracing_enabled from RPC Server")
|
||||
|
||||
async def abort(self, request_id: str):
|
||||
"""Send an ABORT_REQUEST signal to the RPC Server"""
|
||||
|
||||
# Suppress timeouts as well.
|
||||
# In cases where the server is busy processing requests and a very
|
||||
# large volume of abort requests arrive, it is likely that the server
|
||||
# will not be able to ack all of them in time. We have seen this when
|
||||
# we abort 20k requests at once while another 2k are processing- many
|
||||
# of them time out, but we see the server successfully abort all of the
|
||||
# requests.
|
||||
# In this case we assume that the server has received or will receive
|
||||
# these abort requests, and ignore the timeout. This prevents a massive
|
||||
# wall of `TimeoutError` stack traces.
|
||||
with suppress(RPCClientClosedError, TimeoutError):
|
||||
await self._send_one_way_rpc_request(
|
||||
request=RPCAbortRequest(request_id),
|
||||
error_message=f"RPCAbortRequest {request_id} failed")
|
||||
|
||||
async def do_log_stats(self):
|
||||
"""Send a DO_LOG_STATS signal to the RPC Server"""
|
||||
with suppress(RPCClientClosedError):
|
||||
await self._send_one_way_rpc_request(
|
||||
request=RPCUtilityRequest.DO_LOG_STATS,
|
||||
error_message="RPCRequest DO_LOG_STATS failed.")
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return not self._errored
|
||||
|
||||
@property
|
||||
def is_stopped(self) -> bool:
|
||||
return self._errored
|
||||
|
||||
@property
|
||||
def errored(self) -> bool:
|
||||
return self._errored
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
inputs: PromptInputs,
|
||||
sampling_params: SamplingParams,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
|
||||
|
||||
finished = False
|
||||
try:
|
||||
with self.to_proxy_socket() as socket:
|
||||
# Send RPCGenerateRequest to the RPCServer.
|
||||
await socket.send_multipart((cloudpickle.dumps(
|
||||
RPCGenerateRequest(
|
||||
inputs=inputs,
|
||||
sampling_params=sampling_params,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request)), ))
|
||||
|
||||
# Stream back the results from the RPC Server.
|
||||
while not finished:
|
||||
message = await socket.recv(copy=False)
|
||||
assert isinstance(message, Frame)
|
||||
request_output = pickle.loads(message.buffer)
|
||||
|
||||
if isinstance(request_output, Exception):
|
||||
# On exception, check if the server is still healthy
|
||||
# possibly setting the `errored` property.
|
||||
if not self._errored:
|
||||
try:
|
||||
await self.check_health(socket=socket)
|
||||
except Exception as e:
|
||||
self._errored = True
|
||||
logger.exception(repr(e))
|
||||
|
||||
# NB: do before raising here so that the flag is set
|
||||
# by the time the caller receives this exception
|
||||
raise request_output
|
||||
|
||||
finished = request_output.finished
|
||||
yield request_output
|
||||
|
||||
finally:
|
||||
# Request was canceled by the client.
|
||||
if not finished and not self._errored:
|
||||
await self.abort(request_id)
|
||||
|
||||
async def check_health(self, socket: Optional[Socket] = None) -> None:
|
||||
"""Raise if unhealthy"""
|
||||
|
||||
await self._send_one_way_rpc_request(
|
||||
request=RPCUtilityRequest.IS_SERVER_HEALTHY,
|
||||
error_message="Got Unhealthy response from RPC Server",
|
||||
socket=socket)
|
||||
|
||||
async def encode(self, *args,
|
||||
**kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
|
||||
raise NotImplementedError(
|
||||
"Embeddings not supported with multiprocessing backend")
|
||||
|
||||
async def start_profile(self) -> None:
|
||||
"""Start profiling the engine"""
|
||||
|
||||
await self._send_one_way_rpc_request(
|
||||
request=RPCUtilityRequest.START_PROFILE,
|
||||
error_message="RPCRequest START_PROFILE failed.")
|
||||
|
||||
async def stop_profile(self) -> None:
|
||||
"""Stop profiling the engine"""
|
||||
|
||||
await self._send_one_way_rpc_request(
|
||||
request=RPCUtilityRequest.STOP_PROFILE,
|
||||
error_message="RPCRequest STOP_PROFILE failed.")
|
@ -1,243 +0,0 @@
|
||||
import asyncio
|
||||
import pickle
|
||||
import signal
|
||||
from typing import Any, Coroutine, Union
|
||||
|
||||
import cloudpickle
|
||||
import uvloop
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from typing_extensions import Never
|
||||
from zmq import Frame # type: ignore[attr-defined]
|
||||
from zmq.asyncio import Socket
|
||||
|
||||
from vllm import AsyncEngineArgs, AsyncLLMEngine
|
||||
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig)
|
||||
from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR,
|
||||
VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
|
||||
RPCGenerateRequest, RPCUtilityRequest)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
|
||||
SchedulerConfig, LoRAConfig]
|
||||
|
||||
|
||||
class AsyncEngineRPCServer:
|
||||
|
||||
def __init__(self, async_engine_args: AsyncEngineArgs,
|
||||
usage_context: UsageContext, rpc_path: str):
|
||||
# Initialize engine first.
|
||||
self.engine = AsyncLLMEngine.from_engine_args(
|
||||
async_engine_args, usage_context=usage_context)
|
||||
|
||||
# Initialize context.
|
||||
self.context = zmq.asyncio.Context()
|
||||
|
||||
# Init socket.
|
||||
self.socket: Socket = self.context.socket(zmq.constants.DEALER)
|
||||
self.socket.set_hwm(VLLM_RPC_ZMQ_HWM)
|
||||
self.socket.connect(rpc_path)
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup all resources."""
|
||||
self.socket.close()
|
||||
self.context.destroy()
|
||||
# Clear the engine reference so that it can be GC'ed.
|
||||
del self.engine
|
||||
|
||||
async def get_config(self, identity, request):
|
||||
try:
|
||||
config: CONFIG_TYPE
|
||||
if request == RPCUtilityRequest.GET_MODEL_CONFIG:
|
||||
config = await self.engine.get_model_config()
|
||||
elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
|
||||
config = await self.engine.get_decoding_config()
|
||||
elif request == RPCUtilityRequest.GET_LORA_CONFIG:
|
||||
config = await self.engine.get_lora_config()
|
||||
elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
|
||||
config = await self.engine.get_scheduler_config()
|
||||
elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
|
||||
config = await self.engine.get_parallel_config()
|
||||
else:
|
||||
raise ValueError("Unknown Config Request: %s", request)
|
||||
|
||||
await self.socket.send_multipart((identity, pickle.dumps(config)),
|
||||
copy=False)
|
||||
|
||||
except Exception as e:
|
||||
await self.socket.send_multipart((identity, pickle.dumps(e)),
|
||||
copy=False)
|
||||
|
||||
async def is_tracing_enabled(self, identity):
|
||||
"""Send the is_tracing_enabled flag"""
|
||||
tracing_flag = await self.engine.is_tracing_enabled()
|
||||
|
||||
await self.socket.send_multipart(
|
||||
(identity, pickle.dumps(tracing_flag)))
|
||||
|
||||
async def do_log_stats(self, identity):
|
||||
"""Log stats and confirm success."""
|
||||
await self.engine.do_log_stats()
|
||||
|
||||
await self.socket.send_multipart(
|
||||
(identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
|
||||
|
||||
async def is_server_ready(self, identity):
|
||||
"""Notify the client that we are ready."""
|
||||
await self.socket.send_multipart(
|
||||
(identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
|
||||
|
||||
async def abort(self, identity, request: RPCAbortRequest):
|
||||
"""Abort request and notify the client of success."""
|
||||
try:
|
||||
# Abort the request in the llm engine.
|
||||
await self.engine.abort(request.request_id)
|
||||
result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR
|
||||
except Exception as e:
|
||||
result = e
|
||||
await self.socket.send_multipart((identity, pickle.dumps(result)))
|
||||
|
||||
async def generate(self, identity, generate_request: RPCGenerateRequest):
|
||||
try:
|
||||
results_generator = self.engine.generate(
|
||||
generate_request.inputs,
|
||||
sampling_params=generate_request.sampling_params,
|
||||
request_id=generate_request.request_id,
|
||||
lora_request=generate_request.lora_request,
|
||||
trace_headers=generate_request.trace_headers,
|
||||
prompt_adapter_request=generate_request.prompt_adapter_request)
|
||||
|
||||
async for request_output in results_generator:
|
||||
await self.socket.send_multipart(
|
||||
(identity, pickle.dumps(request_output)), copy=False)
|
||||
|
||||
except Exception as e:
|
||||
await self.socket.send_multipart((identity, pickle.dumps(e)),
|
||||
copy=False)
|
||||
|
||||
async def check_health(self, identity):
|
||||
try:
|
||||
await self.engine.check_health()
|
||||
await self.socket.send_multipart(
|
||||
(identity, pickle.dumps(VLLM_RPC_SUCCESS_STR)))
|
||||
|
||||
except Exception as e:
|
||||
await self.socket.send_multipart((identity, pickle.dumps(e)),
|
||||
copy=False)
|
||||
|
||||
async def start_profile(self, identity):
|
||||
logger.info("Starting profiler...")
|
||||
await self.engine.start_profile()
|
||||
logger.info("Profiler started.")
|
||||
|
||||
await self.socket.send_multipart((
|
||||
identity,
|
||||
pickle.dumps(VLLM_RPC_SUCCESS_STR),
|
||||
))
|
||||
|
||||
async def stop_profile(self, identity):
|
||||
logger.info("Stopping profiler...")
|
||||
await self.engine.stop_profile()
|
||||
logger.info("Profiler stopped.")
|
||||
|
||||
await self.socket.send_multipart((
|
||||
identity,
|
||||
pickle.dumps(VLLM_RPC_SUCCESS_STR),
|
||||
))
|
||||
|
||||
def _make_handler_coro(self, identity,
|
||||
message: Frame) -> Coroutine[Any, Any, Never]:
|
||||
"""Route the zmq message to the handler coroutine."""
|
||||
|
||||
request = cloudpickle.loads(message.buffer)
|
||||
|
||||
if isinstance(request, RPCGenerateRequest):
|
||||
return self.generate(identity, request)
|
||||
|
||||
elif isinstance(request, RPCAbortRequest):
|
||||
return self.abort(identity, request)
|
||||
|
||||
elif isinstance(request, RPCUtilityRequest):
|
||||
if request in [
|
||||
RPCUtilityRequest.GET_MODEL_CONFIG,
|
||||
RPCUtilityRequest.GET_PARALLEL_CONFIG,
|
||||
RPCUtilityRequest.GET_DECODING_CONFIG,
|
||||
RPCUtilityRequest.GET_SCHEDULER_CONFIG,
|
||||
RPCUtilityRequest.GET_LORA_CONFIG
|
||||
]:
|
||||
return self.get_config(identity, request)
|
||||
elif request == RPCUtilityRequest.DO_LOG_STATS:
|
||||
return self.do_log_stats(identity)
|
||||
elif request == RPCUtilityRequest.IS_SERVER_READY:
|
||||
return self.is_server_ready(identity)
|
||||
elif request == RPCUtilityRequest.IS_SERVER_HEALTHY:
|
||||
return self.check_health(identity)
|
||||
elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
|
||||
return self.is_tracing_enabled(identity)
|
||||
elif request == RPCUtilityRequest.START_PROFILE:
|
||||
return self.start_profile(identity)
|
||||
elif request == RPCUtilityRequest.STOP_PROFILE:
|
||||
return self.stop_profile(identity)
|
||||
else:
|
||||
raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown RPCRequest type: {request}")
|
||||
|
||||
async def run_server_loop(self):
|
||||
"""Inner RPC Server Loop"""
|
||||
|
||||
running_tasks = set()
|
||||
while True:
|
||||
# Wait for a request.
|
||||
identity, message = await self.socket.recv_multipart(copy=False)
|
||||
|
||||
# Process the request async.
|
||||
task = asyncio.create_task(
|
||||
self._make_handler_coro(identity, message))
|
||||
|
||||
# We need to keep around a strong reference to the task,
|
||||
# to avoid the task disappearing mid-execution as running tasks
|
||||
# can be GC'ed. Below is a common "fire-and-forget" tasks
|
||||
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
|
||||
running_tasks.add(task)
|
||||
task.add_done_callback(running_tasks.discard)
|
||||
|
||||
|
||||
async def run_server(server: AsyncEngineRPCServer):
|
||||
# Put the server task into the asyncio loop.
|
||||
loop = asyncio.get_running_loop()
|
||||
server_task = loop.create_task(server.run_server_loop())
|
||||
|
||||
# Interruption handling.
|
||||
def signal_handler() -> None:
|
||||
# Kill the server on interrupt / terminate
|
||||
server_task.cancel()
|
||||
|
||||
loop.add_signal_handler(signal.SIGINT, signal_handler)
|
||||
loop.add_signal_handler(signal.SIGTERM, signal_handler)
|
||||
|
||||
try:
|
||||
await server_task
|
||||
except asyncio.CancelledError:
|
||||
logger.info("vLLM ZMQ RPC Server was interrupted.")
|
||||
finally:
|
||||
# Clean up all resources.
|
||||
server.cleanup()
|
||||
|
||||
|
||||
def run_rpc_server(async_engine_args: AsyncEngineArgs,
|
||||
usage_context: UsageContext, rpc_path: str):
|
||||
|
||||
def signal_handler(*_) -> None:
|
||||
# Interrupt server on sigterm while initializing
|
||||
raise KeyboardInterrupt("AsyncEngineRPCServer terminated")
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path)
|
||||
uvloop.run(run_server(server))
|
@ -9,7 +9,7 @@ from typing import Union
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import AsyncEngineClient
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (ConversationMessage,
|
||||
apply_hf_chat_template,
|
||||
apply_mistral_chat_template,
|
||||
@ -45,7 +45,7 @@ logger = init_logger(__name__)
|
||||
class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
def __init__(self,
|
||||
async_engine_client: AsyncEngineClient,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
served_model_names: List[str],
|
||||
response_role: str,
|
||||
@ -57,7 +57,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
enable_auto_tools: bool = False,
|
||||
tool_parser: Optional[str] = None):
|
||||
super().__init__(async_engine_client=async_engine_client,
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
served_model_names=served_model_names,
|
||||
lora_modules=lora_modules,
|
||||
@ -105,6 +105,12 @@ class OpenAIServingChat(OpenAIServing):
|
||||
logger.error("Error with model %s", error_check_ret)
|
||||
return error_check_ret
|
||||
|
||||
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||
# This is required for the streaming case, where we return a
|
||||
# success status before we actually start generating text :).
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
@ -112,8 +118,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
model_config = self.model_config
|
||||
tokenizer = await self.async_engine_client.get_tokenizer(
|
||||
lora_request)
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
conversation, mm_data_future = parse_chat_messages_futures(
|
||||
request.messages, model_config, tokenizer)
|
||||
@ -207,8 +212,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
if mm_data is not None:
|
||||
engine_inputs["multi_modal_data"] = mm_data
|
||||
|
||||
is_tracing_enabled = (
|
||||
await self.async_engine_client.is_tracing_enabled())
|
||||
is_tracing_enabled = (await
|
||||
self.engine_client.is_tracing_enabled())
|
||||
trace_headers = None
|
||||
if is_tracing_enabled and raw_request:
|
||||
trace_headers = extract_trace_headers(raw_request.headers)
|
||||
@ -216,7 +221,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
and contains_trace_headers(raw_request.headers)):
|
||||
log_tracing_disabled_warning()
|
||||
|
||||
result_generator = self.async_engine_client.generate(
|
||||
result_generator = self.engine_client.generate(
|
||||
engine_inputs,
|
||||
sampling_params,
|
||||
request_id,
|
||||
|
@ -8,7 +8,7 @@ from typing import Tuple, Union, cast
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import AsyncEngineClient
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
@ -43,7 +43,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
async_engine_client: AsyncEngineClient,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
served_model_names: List[str],
|
||||
*,
|
||||
@ -52,7 +52,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
request_logger: Optional[RequestLogger],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
):
|
||||
super().__init__(async_engine_client=async_engine_client,
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
served_model_names=served_model_names,
|
||||
lora_modules=lora_modules,
|
||||
@ -78,6 +78,12 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
# If the engine is dead, raise the engine's DEAD_ERROR.
|
||||
# This is required for the streaming case, where we return a
|
||||
# success status before we actually start generating text :).
|
||||
if self.engine_client.errored:
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
# Return error for unsupported features.
|
||||
if request.suffix is not None:
|
||||
return self.create_error_response(
|
||||
@ -95,8 +101,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.async_engine_client.get_tokenizer(
|
||||
lora_request)
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
guided_decode_logits_processor = (
|
||||
await self._guided_decode_logits_processor(request, tokenizer))
|
||||
@ -124,8 +129,8 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
is_tracing_enabled = (
|
||||
await self.async_engine_client.is_tracing_enabled())
|
||||
is_tracing_enabled = (await
|
||||
self.engine_client.is_tracing_enabled())
|
||||
trace_headers = None
|
||||
if is_tracing_enabled:
|
||||
trace_headers = extract_trace_headers(raw_request.headers)
|
||||
@ -133,7 +138,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
raw_request.headers):
|
||||
log_tracing_disabled_warning()
|
||||
|
||||
generator = self.async_engine_client.generate(
|
||||
generator = self.engine_client.generate(
|
||||
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
|
||||
sampling_params,
|
||||
request_id_item,
|
||||
|
@ -8,7 +8,7 @@ from fastapi import Request
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import AsyncEngineClient
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
@ -71,13 +71,13 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
async_engine_client: AsyncEngineClient,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
served_model_names: List[str],
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
):
|
||||
super().__init__(async_engine_client=async_engine_client,
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
served_model_names=served_model_names,
|
||||
lora_modules=None,
|
||||
@ -118,8 +118,7 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.async_engine_client.get_tokenizer(
|
||||
lora_request)
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
@ -144,7 +143,7 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
"Prompt adapter is not supported "
|
||||
"for embedding models")
|
||||
|
||||
generator = self.async_engine_client.encode(
|
||||
generator = self.engine_client.encode(
|
||||
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
|
@@ -8,7 +8,7 @@ from pydantic import Field
 from typing_extensions import Annotated

 from vllm.config import ModelConfig
-from vllm.engine.protocol import AsyncEngineClient
+from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -64,7 +64,7 @@ class OpenAIServing:

     def __init__(
         self,
-        async_engine_client: AsyncEngineClient,
+        engine_client: EngineClient,
         model_config: ModelConfig,
         served_model_names: List[str],
         *,
@@ -75,7 +75,7 @@ class OpenAIServing:
     ):
         super().__init__()

-        self.async_engine_client = async_engine_client
+        self.engine_client = engine_client
         self.model_config = model_config
         self.max_model_len = model_config.max_model_len

@@ -159,7 +159,7 @@ class OpenAIServing:
     async def _guided_decode_logits_processor(
             self, request: Union[ChatCompletionRequest, CompletionRequest],
             tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
-        decoding_config = await self.async_engine_client.get_decoding_config()
+        decoding_config = await self.engine_client.get_decoding_config()
         guided_decoding_backend = request.guided_decoding_backend \
             or decoding_config.guided_decoding_backend
         return await get_guided_decoding_logits_processor(
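The base-class change above is the hub of the rename: `OpenAIServing` now stores the client as `self.engine_client`, and every subclass forwards it under that keyword. Below is a toy sketch of the forwarding pattern with simplified constructor signatures; the class names are made up and do not exist in vLLM.

class ServingBase:
    def __init__(self, engine_client, model_config, served_model_names):
        # One canonical attribute name for the engine handle.
        self.engine_client = engine_client
        self.model_config = model_config
        self.served_model_names = served_model_names


class ServingEcho(ServingBase):
    def __init__(self, engine_client, model_config, served_model_names):
        # Subclasses pass the client through under the new keyword, so every
        # handler reaches the engine via self.engine_client.
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         served_model_names=served_model_names)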
@@ -1,7 +1,7 @@
 from typing import List, Optional, Union

 from vllm.config import ModelConfig
-from vllm.engine.protocol import AsyncEngineClient
+from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
                                          apply_mistral_chat_template,
                                          load_chat_template,
@@ -29,7 +29,7 @@ class OpenAIServingTokenization(OpenAIServing):

     def __init__(
         self,
-        async_engine_client: AsyncEngineClient,
+        engine_client: EngineClient,
         model_config: ModelConfig,
         served_model_names: List[str],
         *,
@@ -37,7 +37,7 @@ class OpenAIServingTokenization(OpenAIServing):
         request_logger: Optional[RequestLogger],
         chat_template: Optional[str],
     ):
-        super().__init__(async_engine_client=async_engine_client,
+        super().__init__(engine_client=engine_client,
                          model_config=model_config,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules,
@@ -66,7 +66,7 @@ class OpenAIServingTokenization(OpenAIServing):
             prompt_adapter_request,
         ) = self._maybe_get_adapters(request)

-        tokenizer = await self.async_engine_client.get_tokenizer(lora_request)
+        tokenizer = await self.engine_client.get_tokenizer(lora_request)

         prompt: Union[str, List[int]]
         if isinstance(request, TokenizeChatRequest):
@@ -132,7 +132,7 @@ class OpenAIServingTokenization(OpenAIServing):
             prompt_adapter_request,
         ) = self._maybe_get_adapters(request)

-        tokenizer = await self.async_engine_client.get_tokenizer(lora_request)
+        tokenizer = await self.engine_client.get_tokenizer(lora_request)

         self._log_inputs(request_id,
                          request.tokens,
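Both tokenization hunks resolve the tokenizer through the client on every request rather than caching one, since the tokenizer can differ when a LoRA adapter is active. A hedged sketch of that per-request lookup; the helper name and the HF-style `encode()` call in the usage note are assumptions, not vLLM code.

async def resolve_tokenizer(engine_client, lora_request=None):
    # The engine owns the tokenizer; asking per request lets a
    # LoRA-specific tokenizer be returned when an adapter is active.
    return await engine_client.get_tokenizer(lora_request)

# Example usage (assuming an HF-style tokenizer object):
#   tokenizer = await resolve_tokenizer(engine_client)
#   token_ids = tokenizer.encode("Hello, world")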
@@ -57,7 +57,7 @@ if TYPE_CHECKING:
     VERBOSE: bool = False
     VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
     VLLM_TEST_FORCE_FP8_MARLIN: bool = False
-    VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000
+    VLLM_RPC_TIMEOUT: int = 10000  # ms
     VLLM_PLUGINS: Optional[List[str]] = None
     VLLM_TORCH_PROFILER_DIR: Optional[str] = None
     VLLM_USE_TRITON_AWQ: bool = False
@@ -393,8 +393,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {

     # Time in ms for the zmq client to wait for a response from the backend
     # server for simple data operations
-    "VLLM_RPC_GET_DATA_TIMEOUT_MS":
-    lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")),
+    "VLLM_RPC_TIMEOUT":
+    lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")),

     # a list of plugin names to load, separated by commas.
     # if this is not set, it means all plugins will be loaded
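The environment-variable rename above keeps the lambda-registry pattern, so the timeout is read lazily from the process environment. Below is a small standalone sketch of how such a registry behaves; only the `VLLM_RPC_TIMEOUT` entry comes from the diff, the surrounding module machinery is omitted and the usage lines are illustrative.

import os
from typing import Any, Callable, Dict

environment_variables: Dict[str, Callable[[], Any]] = {
    # Time in ms for the zmq client to wait for a response from the backend.
    "VLLM_RPC_TIMEOUT": lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")),
}

# Override before starting the server, e.g. for long profiler flushes:
#   export VLLM_RPC_TIMEOUT=1800000    # 30 minutes
os.environ.setdefault("VLLM_RPC_TIMEOUT", "1800000")
timeout_ms = environment_variables["VLLM_RPC_TIMEOUT"]()
# timeout_ms is 1800000 here, unless the variable was already set earlier.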
@@ -106,6 +106,7 @@ class CPUExecutor(ExecutorBase):
                 )) for rank in range(1, world_size)
             ]

+        self.worker_monitor = None
         if world_size != 1 or is_async:
             if is_async:
                 async_worker_list = self.workers + [self.driver_worker]
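The single added line above initializes `worker_monitor` unconditionally, so later code can test it without hitting `AttributeError` in the single-process synchronous case. A toy illustration of that guard follows; `ToyExecutor` is not the real `CPUExecutor`.

class ToyExecutor:
    def __init__(self, world_size: int, is_async: bool):
        self.workers: list = []
        # Always define the attribute; it stays None when no monitor is needed.
        self.worker_monitor = None
        if world_size != 1 or is_async:
            self.worker_monitor = object()  # stand-in for a real worker monitor

    def shutdown(self) -> None:
        if self.worker_monitor is not None:
            # Close the monitor here; nothing to do for the toy stand-in.
            pass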
@@ -168,6 +168,8 @@ class ProcessWorkerWrapper:
         self.tasks[task_id] = future
         try:
             self._task_queue.put((task_id, method, args, kwargs))
+        except SystemExit:
+            raise
         except BaseException as e:
             del self.tasks[task_id]
             raise ChildProcessError("worker died") from e
@@ -222,6 +224,8 @@ def _run_worker_process(
             try:
                 executor = getattr(worker, method)
                 output = executor(*args, **kwargs)
+            except SystemExit:
+                raise
             except KeyboardInterrupt:
                 break
             except BaseException as e:
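Both hunks above add an explicit `except SystemExit: raise` so interpreter shutdown is never swallowed by the broad `BaseException` handlers that follow. Here is a self-contained sketch of the enqueue-side pattern, assuming a plain `queue.Queue` and `concurrent.futures.Future` in place of vLLM's multiprocessing machinery; the function and argument names are illustrative.

import queue
from concurrent.futures import Future


def submit_task(task_queue: queue.Queue, tasks: dict, task_id: str,
                method: str, *args, **kwargs) -> Future:
    future: Future = Future()
    tasks[task_id] = future
    try:
        task_queue.put((task_id, method, args, kwargs))
    except SystemExit:
        raise                      # never swallow interpreter shutdown
    except BaseException as e:
        del tasks[task_id]         # don't leak the pending future
        raise ChildProcessError("worker died") from e
    return future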