[Bugfix] Do not crash V0 engine on input errors (#13101)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
Joe Runde 2025-02-26 04:07:29 -07:00 committed by GitHub
parent ec8a5e5386
commit 3f808cc044
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 172 additions and 6 deletions

View File

@ -18,6 +18,7 @@ from vllm.engine.multiprocessing.engine import MQLLMEngine
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.lora.request import LoRARequest
from vllm.sequence import SequenceGroupMetadata
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
@ -292,3 +293,80 @@ async def test_engine_process_death(tmp_socket):
await client.check_health()
client.close()
def run_with_evil_input_processing(engine_args: AsyncEngineArgs,
ipc_path: str):
"""Simulate an exception while preparing inputs for the model.
In the wild, this could be something like a multimodal input processor
failing on invalid image data."""
# Make engine.
engine = MQLLMEngine.from_engine_args(
engine_args=engine_args,
usage_context=UsageContext.UNKNOWN_CONTEXT,
ipc_path=ipc_path)
runner = engine.engine.model_executor.driver_worker.worker.model_runner
# Raise error in the model runner when adding a sequence group.
# See class ModelInputForGPUBuilder
def raiser(_, seq_group_metadata: SequenceGroupMetadata):
if seq_group_metadata.request_id.startswith("evil"):
raise RAISED_ERROR(RAISED_VALUE)
runner.builder.per_seq_group_compute_fns.append(raiser)
# Run engine.
engine.start()
@pytest.mark.asyncio
async def test_failed_inputs(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket,
run_fn=run_with_evil_input_processing) as engine:
client = await engine.make_client()
assert client.is_running
# Engine should be healthy
await client.check_health()
async def run_failing_request():
async for _ in client.generate(
prompt="Hello my name is",
sampling_params=SamplingParams(max_tokens=10),
request_id="evil" + str(uuid.uuid4())):
pass
async def run_passing_request():
async for _ in client.generate(
prompt="Hello my name is",
sampling_params=SamplingParams(max_tokens=10),
request_id=str(uuid.uuid4())):
pass
passing_tasks = [
asyncio.create_task(run_passing_request()) for _ in range(10)
]
failing_tasks = [
asyncio.create_task(run_failing_request()) for _ in range(10)
]
await asyncio.gather(*failing_tasks, return_exceptions=True)
await asyncio.gather(*passing_tasks)
# All the bad inputs should have raised
for task in failing_tasks:
with pytest.raises(RAISED_ERROR):
task.result()
# But all good inputs should have still succeeded
for task in passing_tasks:
task.result()
# And the engine should remain healthy
assert not client.errored
await client.check_health()
client.close()

View File

@ -60,6 +60,7 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
from vllm.utils import (Counter, Device, deprecate_kwargs,
resolve_obj_by_qualname, weak_bind)
from vllm.version import __version__ as VLLM_VERSION
from vllm.worker.model_runner_base import InputProcessingError
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
@ -410,6 +411,10 @@ class LLMEngine:
self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
# Flag to set when an input fails to process and the engine should run
# the next step without re-scheduling.
self._skip_scheduling_next_step = False
def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
@ -1334,7 +1339,11 @@ class LLMEngine:
# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):
# The scheduler is also skipped if a single request caused the last
# engine step to fail, and the previous schedule needs to be rerun.
if not self._has_remaining_steps(
seq_group_metadata_list
) and not self._skip_scheduling_next_step:
# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
@ -1388,8 +1397,23 @@ class LLMEngine:
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]
outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)
try:
outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)
self._skip_scheduling_next_step = False
except InputProcessingError as e:
# The input for this request cannot be processed, so we must
# abort it. If there are remaining requests in the batch that
# have been scheduled, they will be retried on the next step.
invalid_request_id = e.request_id
self._abort_and_cache_schedule(
request_id=invalid_request_id,
virtual_engine=virtual_engine,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
allow_async_output_proc=allow_async_output_proc)
# Raise so the caller is notified that this request failed
raise
# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
@ -1464,6 +1488,38 @@ class LLMEngine:
return ctx.request_outputs
def _abort_and_cache_schedule(
self, request_id: str, virtual_engine: int,
seq_group_metadata_list: List[SequenceGroupMetadata],
scheduler_outputs: SchedulerOutputs,
allow_async_output_proc: bool) -> None:
"""Aborts a single request, and caches the scheduler outputs minus that
request. This allows the next step to continue processing the remaining
requests without having to re-run the scheduler."""
# Abort the request and remove its sequence group from the current
# schedule
self.abort_request(request_id)
for i, metadata in enumerate(seq_group_metadata_list):
if metadata.request_id == request_id:
del seq_group_metadata_list[i]
break
for i, group in enumerate(scheduler_outputs.scheduled_seq_groups):
if group.seq_group.request_id == request_id:
del scheduler_outputs.scheduled_seq_groups[i]
break
# If there are still other sequence groups left in the schedule, cache
# them and flag the engine to reuse the schedule.
if len(seq_group_metadata_list) > 0:
self._skip_scheduling_next_step = True
# Reuse multi-step caching logic
self._cache_scheduler_outputs_for_multi_step(
virtual_engine=virtual_engine,
scheduler_outputs=scheduler_outputs,
seq_group_metadata_list=seq_group_metadata_list,
allow_async_output_proc=allow_async_output_proc)
def _has_remaining_steps(
self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
) -> bool:

View File

@ -27,6 +27,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext
from vllm.worker.model_runner_base import InputProcessingError
logger = init_logger(__name__)
@ -210,6 +211,14 @@ class MQLLMEngine:
return self.engine.step()
except SystemExit:
raise
except InputProcessingError as e:
# Special case where we handle an error preparing the inputs for
# a single request in the batch
rpc_err = RPCError(request_id=e.request_id,
is_engine_errored=False,
exception=e.__cause__)
self._send_outputs(rpc_err)
return []
except BaseException as e:
self._set_errored(e)
rpc_err = RPCError(request_id=None,

View File

@ -53,8 +53,8 @@ from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache,
is_pin_memory_available, supports_dynamo,
weak_ref_tensor)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
InputProcessingError, ModelRunnerBase, ModelRunnerInputBase,
ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
@ -1216,7 +1216,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"""
self.builder.prepare(finished_requests_ids)
for seq_group_metadata in seq_group_metadata_list:
self.builder.add_seq_group(seq_group_metadata)
try:
self.builder.add_seq_group(seq_group_metadata)
except Exception as e:
# Raise an exception that tracks the ID of the bad request
raise InputProcessingError(seq_group_metadata.request_id,
str(e)) from e
self.builder.reset_cached_inter_data()

View File

@ -261,3 +261,21 @@ class ModelRunnerWrapperBase:
def __getattr__(self, attr):
return getattr(self.model_runner, attr)
class InputProcessingError(Exception):
"""This exception is raised when an error occurs preparing the inputs for
a single sequence group.
This allows the engine to gracefully handle errors with a single sequence
group without having to fail the entire batch.
"""
def __init__(self, request_id, message):
"""request_id is the id of the offending sequence group"""
self.request_id = request_id
self.message = message
super().__init__(self.message)
def __str__(self):
return "Failed to prepare inputs for sequence group with request id: " \
f"{self.request_id}, Error: {self.message}"