[Bugfix] Do not crash V0 engine on input errors (#13101)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
parent
ec8a5e5386
commit
3f808cc044
@ -18,6 +18,7 @@ from vllm.engine.multiprocessing.engine import MQLLMEngine
|
||||
from vllm.entrypoints.openai.api_server import build_async_engine_client
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import SequenceGroupMetadata
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
@ -292,3 +293,80 @@ async def test_engine_process_death(tmp_socket):
|
||||
await client.check_health()
|
||||
|
||||
client.close()
|
||||
|
||||
|
||||
def run_with_evil_input_processing(engine_args: AsyncEngineArgs,
|
||||
ipc_path: str):
|
||||
"""Simulate an exception while preparing inputs for the model.
|
||||
In the wild, this could be something like a multimodal input processor
|
||||
failing on invalid image data."""
|
||||
|
||||
# Make engine.
|
||||
engine = MQLLMEngine.from_engine_args(
|
||||
engine_args=engine_args,
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT,
|
||||
ipc_path=ipc_path)
|
||||
|
||||
runner = engine.engine.model_executor.driver_worker.worker.model_runner
|
||||
|
||||
# Raise error in the model runner when adding a sequence group.
|
||||
# See class ModelInputForGPUBuilder
|
||||
def raiser(_, seq_group_metadata: SequenceGroupMetadata):
|
||||
if seq_group_metadata.request_id.startswith("evil"):
|
||||
raise RAISED_ERROR(RAISED_VALUE)
|
||||
|
||||
runner.builder.per_seq_group_compute_fns.append(raiser)
|
||||
|
||||
# Run engine.
|
||||
engine.start()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failed_inputs(tmp_socket):
|
||||
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
|
||||
ipc_path=tmp_socket,
|
||||
run_fn=run_with_evil_input_processing) as engine:
|
||||
|
||||
client = await engine.make_client()
|
||||
assert client.is_running
|
||||
|
||||
# Engine should be healthy
|
||||
await client.check_health()
|
||||
|
||||
async def run_failing_request():
|
||||
async for _ in client.generate(
|
||||
prompt="Hello my name is",
|
||||
sampling_params=SamplingParams(max_tokens=10),
|
||||
request_id="evil" + str(uuid.uuid4())):
|
||||
pass
|
||||
|
||||
async def run_passing_request():
|
||||
async for _ in client.generate(
|
||||
prompt="Hello my name is",
|
||||
sampling_params=SamplingParams(max_tokens=10),
|
||||
request_id=str(uuid.uuid4())):
|
||||
pass
|
||||
|
||||
passing_tasks = [
|
||||
asyncio.create_task(run_passing_request()) for _ in range(10)
|
||||
]
|
||||
failing_tasks = [
|
||||
asyncio.create_task(run_failing_request()) for _ in range(10)
|
||||
]
|
||||
await asyncio.gather(*failing_tasks, return_exceptions=True)
|
||||
await asyncio.gather(*passing_tasks)
|
||||
|
||||
# All the bad inputs should have raised
|
||||
for task in failing_tasks:
|
||||
with pytest.raises(RAISED_ERROR):
|
||||
task.result()
|
||||
|
||||
# But all good inputs should have still succeeded
|
||||
for task in passing_tasks:
|
||||
task.result()
|
||||
|
||||
# And the engine should remain healthy
|
||||
assert not client.errored
|
||||
await client.check_health()
|
||||
|
||||
client.close()
|
||||
|
@ -60,6 +60,7 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
||||
from vllm.utils import (Counter, Device, deprecate_kwargs,
|
||||
resolve_obj_by_qualname, weak_bind)
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
from vllm.worker.model_runner_base import InputProcessingError
|
||||
|
||||
logger = init_logger(__name__)
|
||||
_LOCAL_LOGGING_INTERVAL_SEC = 5
|
||||
@ -410,6 +411,10 @@ class LLMEngine:
|
||||
|
||||
self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
|
||||
|
||||
# Flag to set when an input fails to process and the engine should run
|
||||
# the next step without re-scheduling.
|
||||
self._skip_scheduling_next_step = False
|
||||
|
||||
def _initialize_kv_caches(self) -> None:
|
||||
"""Initialize the KV cache in the worker(s).
|
||||
|
||||
@ -1334,7 +1339,11 @@ class LLMEngine:
|
||||
# Skip the scheduler if there are any remaining steps in the seq groups.
|
||||
# This ensures that the scheduler is only called again when the current
|
||||
# batch has completed.
|
||||
if not self._has_remaining_steps(seq_group_metadata_list):
|
||||
# The scheduler is also skipped if a single request caused the last
|
||||
# engine step to fail, and the previous schedule needs to be rerun.
|
||||
if not self._has_remaining_steps(
|
||||
seq_group_metadata_list
|
||||
) and not self._skip_scheduling_next_step:
|
||||
# Schedule iteration
|
||||
(seq_group_metadata_list, scheduler_outputs,
|
||||
allow_async_output_proc
|
||||
@ -1388,8 +1397,23 @@ class LLMEngine:
|
||||
execute_model_req.async_callback = self.async_callbacks[
|
||||
virtual_engine]
|
||||
|
||||
outputs = self.model_executor.execute_model(
|
||||
execute_model_req=execute_model_req)
|
||||
try:
|
||||
outputs = self.model_executor.execute_model(
|
||||
execute_model_req=execute_model_req)
|
||||
self._skip_scheduling_next_step = False
|
||||
except InputProcessingError as e:
|
||||
# The input for this request cannot be processed, so we must
|
||||
# abort it. If there are remaining requests in the batch that
|
||||
# have been scheduled, they will be retried on the next step.
|
||||
invalid_request_id = e.request_id
|
||||
self._abort_and_cache_schedule(
|
||||
request_id=invalid_request_id,
|
||||
virtual_engine=virtual_engine,
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
scheduler_outputs=scheduler_outputs,
|
||||
allow_async_output_proc=allow_async_output_proc)
|
||||
# Raise so the caller is notified that this request failed
|
||||
raise
|
||||
|
||||
# We need to do this here so that last step's sampled_token_ids can
|
||||
# be passed to the next iteration for PP.
|
||||
@ -1464,6 +1488,38 @@ class LLMEngine:
|
||||
|
||||
return ctx.request_outputs
|
||||
|
||||
def _abort_and_cache_schedule(
|
||||
self, request_id: str, virtual_engine: int,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
scheduler_outputs: SchedulerOutputs,
|
||||
allow_async_output_proc: bool) -> None:
|
||||
"""Aborts a single request, and caches the scheduler outputs minus that
|
||||
request. This allows the next step to continue processing the remaining
|
||||
requests without having to re-run the scheduler."""
|
||||
|
||||
# Abort the request and remove its sequence group from the current
|
||||
# schedule
|
||||
self.abort_request(request_id)
|
||||
for i, metadata in enumerate(seq_group_metadata_list):
|
||||
if metadata.request_id == request_id:
|
||||
del seq_group_metadata_list[i]
|
||||
break
|
||||
for i, group in enumerate(scheduler_outputs.scheduled_seq_groups):
|
||||
if group.seq_group.request_id == request_id:
|
||||
del scheduler_outputs.scheduled_seq_groups[i]
|
||||
break
|
||||
|
||||
# If there are still other sequence groups left in the schedule, cache
|
||||
# them and flag the engine to reuse the schedule.
|
||||
if len(seq_group_metadata_list) > 0:
|
||||
self._skip_scheduling_next_step = True
|
||||
# Reuse multi-step caching logic
|
||||
self._cache_scheduler_outputs_for_multi_step(
|
||||
virtual_engine=virtual_engine,
|
||||
scheduler_outputs=scheduler_outputs,
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
allow_async_output_proc=allow_async_output_proc)
|
||||
|
||||
def _has_remaining_steps(
|
||||
self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
|
||||
) -> bool:
|
||||
|
@ -27,6 +27,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.worker.model_runner_base import InputProcessingError
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -210,6 +211,14 @@ class MQLLMEngine:
|
||||
return self.engine.step()
|
||||
except SystemExit:
|
||||
raise
|
||||
except InputProcessingError as e:
|
||||
# Special case where we handle an error preparing the inputs for
|
||||
# a single request in the batch
|
||||
rpc_err = RPCError(request_id=e.request_id,
|
||||
is_engine_errored=False,
|
||||
exception=e.__cause__)
|
||||
self._send_outputs(rpc_err)
|
||||
return []
|
||||
except BaseException as e:
|
||||
self._set_errored(e)
|
||||
rpc_err = RPCError(request_id=None,
|
||||
|
@ -53,8 +53,8 @@ from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache,
|
||||
is_pin_memory_available, supports_dynamo,
|
||||
weak_ref_tensor)
|
||||
from vllm.worker.model_runner_base import (
|
||||
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
|
||||
_add_attn_metadata_broadcastable_dict,
|
||||
InputProcessingError, ModelRunnerBase, ModelRunnerInputBase,
|
||||
ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict,
|
||||
_add_sampling_metadata_broadcastable_dict,
|
||||
_init_attn_metadata_from_tensor_dict,
|
||||
_init_sampling_metadata_from_tensor_dict)
|
||||
@ -1216,7 +1216,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
"""
|
||||
self.builder.prepare(finished_requests_ids)
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
self.builder.add_seq_group(seq_group_metadata)
|
||||
try:
|
||||
self.builder.add_seq_group(seq_group_metadata)
|
||||
except Exception as e:
|
||||
# Raise an exception that tracks the ID of the bad request
|
||||
raise InputProcessingError(seq_group_metadata.request_id,
|
||||
str(e)) from e
|
||||
|
||||
self.builder.reset_cached_inter_data()
|
||||
|
||||
|
@ -261,3 +261,21 @@ class ModelRunnerWrapperBase:
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.model_runner, attr)
|
||||
|
||||
|
||||
class InputProcessingError(Exception):
|
||||
"""This exception is raised when an error occurs preparing the inputs for
|
||||
a single sequence group.
|
||||
This allows the engine to gracefully handle errors with a single sequence
|
||||
group without having to fail the entire batch.
|
||||
"""
|
||||
|
||||
def __init__(self, request_id, message):
|
||||
"""request_id is the id of the offending sequence group"""
|
||||
self.request_id = request_id
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
def __str__(self):
|
||||
return "Failed to prepare inputs for sequence group with request id: " \
|
||||
f"{self.request_id}, Error: {self.message}"
|
||||
|
Loading…
x
Reference in New Issue
Block a user