[Core] Add fault tolerance for RayTokenizerGroupPool (#5748)

This commit is contained in:
Antoni Baum 2024-06-25 10:15:10 -07:00 committed by GitHub
parent 7b99314301
commit 67882dbb44
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 194 additions and 23 deletions

View File

@ -1,5 +1,7 @@
import asyncio import asyncio
import os import os
import sys
from typing import List, Optional
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@ -100,3 +102,100 @@ async def test_tokenizer_group_ray_pool_env_var_propagation(
max_num_seqs=1, max_num_seqs=1,
max_input_length=None) max_input_length=None)
tokenizer_pool.ping() tokenizer_pool.ping()
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
    """Test that Ray tokenizer pool group can recover from failures and
    if that's not possible, mark itself as unhealthy."""

    class FailingTokenizerGroup(TokenizerGroup):
        # Worker-side tokenizer group that hard-exits the actor process
        # (sys.exit, not a Python exception) on selected encode() calls,
        # emulating an abrupt Ray actor death.

        def __init__(self,
                     *args,
                     fail_at: Optional[List[int]] = None,
                     **kwargs):
            super().__init__(*args, **kwargs)
            # Count of encode() calls seen so far by this worker.
            self.i = 0
            # 1-based call indices at which this worker should crash.
            self.fail_at = fail_at or []

        def encode(self, *args, **kwargs):
            self.i += 1
            if self.i in self.fail_at:
                # Kill the actor process outright so the pool observes an
                # ActorDiedError rather than an ordinary exception.
                sys.exit(1)
            return super().encode(*args, **kwargs)

    class FailingRayTokenizerGroupPool(RayTokenizerGroupPool):
        # Pool whose remote workers are the failure-injecting group above.
        _worker_cls = FailingTokenizerGroup

    # Fail at first iteration
    fail_at = [1]
    tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
        tokenizer_pool_config,
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
        fail_at=fail_at)
    tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()

    # Modify fail at to not fail at all (will be re-read when actor is
    # re-initialized).
    fail_at[0] = 1000

    # We should recover successfully.
    await tokenizer_group_pool.encode_async(request_id="1",
                                            prompt="prompt",
                                            lora_request=None)
    await tokenizer_group_pool.encode_async(request_id="1",
                                            prompt="prompt",
                                            lora_request=None)

    # Check that we have a new actor
    assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)
    assert tokenizer_group_pool.tokenizer_actors != tokenizer_actors

    # Fail at first iteration
    fail_at = [1]
    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
        tokenizer_pool_config,
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
        fail_at=fail_at)

    # We should fail after re-initialization.
    # (fail_at=[1] is re-read by the replacement actor, so the retry dies
    # too, which marks the whole pool unhealthy.)
    with pytest.raises(RuntimeError):
        await tokenizer_group_pool.encode_async(request_id="1",
                                                prompt="prompt",
                                                lora_request=None)

    # check_health should raise the same thing
    with pytest.raises(RuntimeError):
        tokenizer_group_pool.check_health()

    # Ensure that non-ActorDiedErrors are still propagated correctly and do not
    # cause a re-initialization.
    fail_at = []
    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
        tokenizer_pool_config,
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=2,
        fail_at=fail_at)
    tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()

    # Prompt too long error
    with pytest.raises(ValueError):
        await tokenizer_group_pool.encode_async(request_id="1",
                                                prompt="prompt" * 100,
                                                lora_request=None)
    # Pool must remain usable after an ordinary (non-fatal) error.
    await tokenizer_group_pool.encode_async(request_id="1",
                                            prompt="prompt",
                                            lora_request=None)
    # Actors should stay the same.
    assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors

View File

@ -310,6 +310,8 @@ class _AsyncLLMEngine(LLMEngine):
) )
async def check_health_async(self) -> None: async def check_health_async(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()
self.model_executor.check_health() self.model_executor.check_health()

View File

@ -1013,6 +1013,8 @@ class LLMEngine:
return self.model_executor.pin_lora(lora_id) return self.model_executor.pin_lora(lora_id)
def check_health(self) -> None: def check_health(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()
self.model_executor.check_health() self.model_executor.check_health()
def is_tracing_enabled(self) -> bool: def is_tracing_enabled(self) -> bool:

View File

@ -53,3 +53,7 @@ class BaseTokenizerGroup(ABC):
) -> "PreTrainedTokenizer": ) -> "PreTrainedTokenizer":
"""Get a tokenizer for a LoRA request.""" """Get a tokenizer for a LoRA request."""
pass pass
def check_health(self):
"""Raise exception if the tokenizer group is unhealthy."""
return

View File

@ -2,17 +2,21 @@ import asyncio
import os import os
from typing import List, Optional from typing import List, Optional
from ray.exceptions import ActorDiedError
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from vllm.config import TokenizerPoolConfig from vllm.config import TokenizerPoolConfig
from vllm.executor.ray_utils import ray from vllm.executor.ray_utils import ray
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup) BaseTokenizerGroup)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import ( from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
TokenizerGroup) TokenizerGroup)
logger = init_logger(__name__)
class RayTokenizerGroupPool(BaseTokenizerGroup): class RayTokenizerGroupPool(BaseTokenizerGroup):
"""A Ray-based pool of TokenizerGroups for async tokenization.""" """A Ray-based pool of TokenizerGroups for async tokenization."""
@ -46,24 +50,28 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
ray_actor_options: dict, **tokenizer_config): ray_actor_options: dict, **tokenizer_config):
# Store a local copy of the TokenizerGroup for quick access # Store a local copy of the TokenizerGroup for quick access
# to underlying HF tokenizers. # to underlying HF tokenizers.
self._tokenizer_config = {
"tokenizer_id": tokenizer_id,
"enable_lora": enable_lora,
"max_num_seqs": max_num_seqs,
"max_input_length": max_input_length,
**tokenizer_config
}
self._local_tokenizer_group = self._worker_cls( self._local_tokenizer_group = self._worker_cls(
tokenizer_id=tokenizer_id, **self._tokenizer_config, )
enable_lora=enable_lora,
max_num_seqs=max_num_seqs,
max_input_length=max_input_length,
**tokenizer_config,
)
ray_tokenizer_group_cls = ray.remote( self._ray_tokenizer_group_cls = ray.remote(
self._worker_cls).options(**ray_actor_options) self._worker_cls).options(**ray_actor_options)
self.tokenizer_actors = [ self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)]
ray_tokenizer_group_cls.remote(tokenizer_id, enable_lora,
max_num_seqs, max_input_length,
**tokenizer_config)
for _ in range(num_actors)
]
self._idle_actors: Optional[asyncio.Queue] = None self._idle_actors: Optional[asyncio.Queue] = None
# If set, actor is unhealthy. Will reraise on the next
# check_health call.
self._exception: Optional[ActorDiedError] = None
def _init_actor(self) -> ray.ObjectRef:
return self._ray_tokenizer_group_cls.remote(**self._tokenizer_config)
@property @property
def pool_size(self) -> int: def pool_size(self) -> int:
return len(self.tokenizer_actors) return len(self.tokenizer_actors)
@ -78,6 +86,22 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
for actor in self.tokenizer_actors: for actor in self.tokenizer_actors:
self._idle_actors.put_nowait(actor) self._idle_actors.put_nowait(actor)
def _finalize_encode(self, actor: ray.ObjectRef,
original_actor: ray.ObjectRef, actor_is_alive: bool):
assert self._idle_actors is not None
# Cleanup the dead actor.
if not actor_is_alive or original_actor is not actor:
self.tokenizer_actors.remove(original_actor)
if actor_is_alive:
# Put the actor back in the queue.
# This is done in a finally block to ensure that the actor is
# always put back in the queue, even if an exception/cancellation
# is raised.
self._idle_actors.put_nowait(actor)
# Add back the new actor.
if original_actor is not actor:
self.tokenizer_actors.append(actor)
def encode(self, def encode(self,
prompt: str, prompt: str,
request_id: Optional[str] = None, request_id: Optional[str] = None,
@ -88,23 +112,41 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
The actor is then put back in the queue for future use. The actor is then put back in the queue for future use.
This is blocking. This is blocking.
""" """
self.check_health()
self._ensure_queue_initialized() self._ensure_queue_initialized()
assert self._idle_actors is not None assert self._idle_actors is not None
if self._idle_actors.empty(): if self._idle_actors.empty():
raise RuntimeError("No idle actors available.") raise RuntimeError("No idle actors available.")
actor = self._idle_actors.get_nowait() actor = self._idle_actors.get_nowait()
actor_is_alive = True
original_actor = actor
try: try:
ret = ray.get( ret = ray.get(
actor.encode.remote(request_id=request_id, actor.encode.remote(request_id=request_id,
prompt=prompt, prompt=prompt,
lora_request=lora_request)) lora_request=lora_request))
except ActorDiedError as e:
# If the actor is dead, we first try to reinitialize it.
logger.warning("%s died with ActorDiedError, reinitializing.",
actor,
exc_info=e)
actor = self._init_actor()
try:
ret = ray.get(
actor.encode.remote(request_id=request_id,
prompt=prompt,
lora_request=lora_request))
except ActorDiedError as e:
logger.error(
"%s died for second time in a row, marking "
"RayTokenizerGroupPool as unhealthy.", actor)
actor_is_alive = False
if not self._exception:
self._exception = e
self.check_health()
finally: finally:
# Put the actor back in the queue. self._finalize_encode(actor, original_actor, actor_is_alive)
# This is done in a finally block to ensure that the actor is
# always put back in the queue, even if an exception/cancellation
# is raised.
self._idle_actors.put_nowait(actor)
return ret return ret
async def encode_async( async def encode_async(
@ -120,20 +162,37 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
The actor is then put back in the queue for future use. The actor is then put back in the queue for future use.
This is non-blocking. This is non-blocking.
""" """
self.check_health()
self._ensure_queue_initialized() self._ensure_queue_initialized()
assert self._idle_actors is not None assert self._idle_actors is not None
actor = await self._idle_actors.get() actor = await self._idle_actors.get()
actor_is_alive = True
original_actor = actor
try: try:
ret = await actor.encode.remote(request_id=request_id, ret = await actor.encode.remote(request_id=request_id,
prompt=prompt, prompt=prompt,
lora_request=lora_request) lora_request=lora_request)
except ActorDiedError as e:
# If the actor is dead, we first try to reinitialize it.
logger.warning("%s died with ActorDiedError, reinitializing.",
actor,
exc_info=e)
actor = self._init_actor()
try:
ret = await actor.encode.remote(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
except ActorDiedError as e:
logger.error(
"%s died for second time in a row, marking "
"RayTokenizerGroupPool as unhealthy.", actor)
actor_is_alive = False
if not self._exception:
self._exception = e
self.check_health()
finally: finally:
# Put the actor back in the queue. self._finalize_encode(actor, original_actor, actor_is_alive)
# This is done in a finally block to ensure that the actor is
# always put back in the queue, even if an exception/cancellation
# is raised.
self._idle_actors.put_nowait(actor)
return ret return ret
def get_max_input_len(self, def get_max_input_len(self,
@ -155,6 +214,11 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
return await self._local_tokenizer_group.get_lora_tokenizer_async( return await self._local_tokenizer_group.get_lora_tokenizer_async(
lora_request) lora_request)
def check_health(self):
if self._exception:
raise RuntimeError(
"TokenizerGroupPool is unhealthy.") from self._exception
def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None: def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None:
"""Copy over all current process environment variables to the runtime_env. """Copy over all current process environment variables to the runtime_env.