[Core] Add fault tolerance for RayTokenizerGroupPool
(#5748)
This commit is contained in:
parent
7b99314301
commit
67882dbb44
@ -1,5 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import List, Optional
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -100,3 +102,100 @@ async def test_tokenizer_group_ray_pool_env_var_propagation(
|
|||||||
max_num_seqs=1,
|
max_num_seqs=1,
|
||||||
max_input_length=None)
|
max_input_length=None)
|
||||||
tokenizer_pool.ping()
|
tokenizer_pool.ping()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
|
||||||
|
async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
|
||||||
|
"""Test that Ray tokenizer pool group can recover from failures and
|
||||||
|
if that's not possible, mark itself as unhealthy."""
|
||||||
|
|
||||||
|
class FailingTokenizerGroup(TokenizerGroup):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
*args,
|
||||||
|
fail_at: Optional[List[int]] = None,
|
||||||
|
**kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.i = 0
|
||||||
|
self.fail_at = fail_at or []
|
||||||
|
|
||||||
|
def encode(self, *args, **kwargs):
|
||||||
|
self.i += 1
|
||||||
|
if self.i in self.fail_at:
|
||||||
|
sys.exit(1)
|
||||||
|
return super().encode(*args, **kwargs)
|
||||||
|
|
||||||
|
class FailingRayTokenizerGroupPool(RayTokenizerGroupPool):
|
||||||
|
_worker_cls = FailingTokenizerGroup
|
||||||
|
|
||||||
|
# Fail at first iteration
|
||||||
|
fail_at = [1]
|
||||||
|
tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
|
||||||
|
tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
|
||||||
|
tokenizer_pool_config,
|
||||||
|
tokenizer_id="gpt2",
|
||||||
|
enable_lora=False,
|
||||||
|
max_num_seqs=1,
|
||||||
|
max_input_length=None,
|
||||||
|
fail_at=fail_at)
|
||||||
|
tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()
|
||||||
|
|
||||||
|
# Modify fail at to not fail at all (will be re-read when actor is
|
||||||
|
# re-initialized).
|
||||||
|
fail_at[0] = 1000
|
||||||
|
|
||||||
|
# We should recover successfully.
|
||||||
|
await tokenizer_group_pool.encode_async(request_id="1",
|
||||||
|
prompt="prompt",
|
||||||
|
lora_request=None)
|
||||||
|
await tokenizer_group_pool.encode_async(request_id="1",
|
||||||
|
prompt="prompt",
|
||||||
|
lora_request=None)
|
||||||
|
|
||||||
|
# Check that we have a new actor
|
||||||
|
assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)
|
||||||
|
assert tokenizer_group_pool.tokenizer_actors != tokenizer_actors
|
||||||
|
|
||||||
|
# Fail at first iteration
|
||||||
|
fail_at = [1]
|
||||||
|
tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
|
||||||
|
tokenizer_pool_config,
|
||||||
|
tokenizer_id="gpt2",
|
||||||
|
enable_lora=False,
|
||||||
|
max_num_seqs=1,
|
||||||
|
max_input_length=None,
|
||||||
|
fail_at=fail_at)
|
||||||
|
|
||||||
|
# We should fail after re-initialization.
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await tokenizer_group_pool.encode_async(request_id="1",
|
||||||
|
prompt="prompt",
|
||||||
|
lora_request=None)
|
||||||
|
|
||||||
|
# check_health should raise the same thing
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
tokenizer_group_pool.check_health()
|
||||||
|
|
||||||
|
# Ensure that non-ActorDiedErrors are still propagated correctly and do not
|
||||||
|
# cause a re-initialization.
|
||||||
|
fail_at = []
|
||||||
|
tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
|
||||||
|
tokenizer_pool_config,
|
||||||
|
tokenizer_id="gpt2",
|
||||||
|
enable_lora=False,
|
||||||
|
max_num_seqs=1,
|
||||||
|
max_input_length=2,
|
||||||
|
fail_at=fail_at)
|
||||||
|
tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()
|
||||||
|
|
||||||
|
# Prompt too long error
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await tokenizer_group_pool.encode_async(request_id="1",
|
||||||
|
prompt="prompt" * 100,
|
||||||
|
lora_request=None)
|
||||||
|
await tokenizer_group_pool.encode_async(request_id="1",
|
||||||
|
prompt="prompt",
|
||||||
|
lora_request=None)
|
||||||
|
# Actors should stay the same.
|
||||||
|
assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors
|
||||||
|
@ -310,6 +310,8 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def check_health_async(self) -> None:
|
async def check_health_async(self) -> None:
|
||||||
|
if self.tokenizer:
|
||||||
|
self.tokenizer.check_health()
|
||||||
self.model_executor.check_health()
|
self.model_executor.check_health()
|
||||||
|
|
||||||
|
|
||||||
|
@ -1013,6 +1013,8 @@ class LLMEngine:
|
|||||||
return self.model_executor.pin_lora(lora_id)
|
return self.model_executor.pin_lora(lora_id)
|
||||||
|
|
||||||
def check_health(self) -> None:
|
def check_health(self) -> None:
|
||||||
|
if self.tokenizer:
|
||||||
|
self.tokenizer.check_health()
|
||||||
self.model_executor.check_health()
|
self.model_executor.check_health()
|
||||||
|
|
||||||
def is_tracing_enabled(self) -> bool:
|
def is_tracing_enabled(self) -> bool:
|
||||||
|
@ -53,3 +53,7 @@ class BaseTokenizerGroup(ABC):
|
|||||||
) -> "PreTrainedTokenizer":
|
) -> "PreTrainedTokenizer":
|
||||||
"""Get a tokenizer for a LoRA request."""
|
"""Get a tokenizer for a LoRA request."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def check_health(self):
|
||||||
|
"""Raise exception if the tokenizer group is unhealthy."""
|
||||||
|
return
|
||||||
|
@ -2,17 +2,21 @@ import asyncio
|
|||||||
import os
|
import os
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from ray.exceptions import ActorDiedError
|
||||||
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
|
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
from vllm.config import TokenizerPoolConfig
|
from vllm.config import TokenizerPoolConfig
|
||||||
from vllm.executor.ray_utils import ray
|
from vllm.executor.ray_utils import ray
|
||||||
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
|
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
|
||||||
BaseTokenizerGroup)
|
BaseTokenizerGroup)
|
||||||
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
|
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
|
||||||
TokenizerGroup)
|
TokenizerGroup)
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RayTokenizerGroupPool(BaseTokenizerGroup):
|
class RayTokenizerGroupPool(BaseTokenizerGroup):
|
||||||
"""A Ray-based pool of TokenizerGroups for async tokenization."""
|
"""A Ray-based pool of TokenizerGroups for async tokenization."""
|
||||||
@ -46,24 +50,28 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
|
|||||||
ray_actor_options: dict, **tokenizer_config):
|
ray_actor_options: dict, **tokenizer_config):
|
||||||
# Store a local copy of the TokenizerGroup for quick access
|
# Store a local copy of the TokenizerGroup for quick access
|
||||||
# to underlying HF tokenizers.
|
# to underlying HF tokenizers.
|
||||||
|
self._tokenizer_config = {
|
||||||
|
"tokenizer_id": tokenizer_id,
|
||||||
|
"enable_lora": enable_lora,
|
||||||
|
"max_num_seqs": max_num_seqs,
|
||||||
|
"max_input_length": max_input_length,
|
||||||
|
**tokenizer_config
|
||||||
|
}
|
||||||
self._local_tokenizer_group = self._worker_cls(
|
self._local_tokenizer_group = self._worker_cls(
|
||||||
tokenizer_id=tokenizer_id,
|
**self._tokenizer_config, )
|
||||||
enable_lora=enable_lora,
|
|
||||||
max_num_seqs=max_num_seqs,
|
|
||||||
max_input_length=max_input_length,
|
|
||||||
**tokenizer_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
ray_tokenizer_group_cls = ray.remote(
|
self._ray_tokenizer_group_cls = ray.remote(
|
||||||
self._worker_cls).options(**ray_actor_options)
|
self._worker_cls).options(**ray_actor_options)
|
||||||
self.tokenizer_actors = [
|
self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)]
|
||||||
ray_tokenizer_group_cls.remote(tokenizer_id, enable_lora,
|
|
||||||
max_num_seqs, max_input_length,
|
|
||||||
**tokenizer_config)
|
|
||||||
for _ in range(num_actors)
|
|
||||||
]
|
|
||||||
self._idle_actors: Optional[asyncio.Queue] = None
|
self._idle_actors: Optional[asyncio.Queue] = None
|
||||||
|
|
||||||
|
# If set, actor is unhealthy. Will reraise on the next
|
||||||
|
# check_health call.
|
||||||
|
self._exception: Optional[ActorDiedError] = None
|
||||||
|
|
||||||
|
def _init_actor(self) -> ray.ObjectRef:
|
||||||
|
return self._ray_tokenizer_group_cls.remote(**self._tokenizer_config)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pool_size(self) -> int:
|
def pool_size(self) -> int:
|
||||||
return len(self.tokenizer_actors)
|
return len(self.tokenizer_actors)
|
||||||
@ -78,6 +86,22 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
|
|||||||
for actor in self.tokenizer_actors:
|
for actor in self.tokenizer_actors:
|
||||||
self._idle_actors.put_nowait(actor)
|
self._idle_actors.put_nowait(actor)
|
||||||
|
|
||||||
|
def _finalize_encode(self, actor: ray.ObjectRef,
|
||||||
|
original_actor: ray.ObjectRef, actor_is_alive: bool):
|
||||||
|
assert self._idle_actors is not None
|
||||||
|
# Cleanup the dead actor.
|
||||||
|
if not actor_is_alive or original_actor is not actor:
|
||||||
|
self.tokenizer_actors.remove(original_actor)
|
||||||
|
if actor_is_alive:
|
||||||
|
# Put the actor back in the queue.
|
||||||
|
# This is done in a finally block to ensure that the actor is
|
||||||
|
# always put back in the queue, even if an exception/cancellation
|
||||||
|
# is raised.
|
||||||
|
self._idle_actors.put_nowait(actor)
|
||||||
|
# Add back the new actor.
|
||||||
|
if original_actor is not actor:
|
||||||
|
self.tokenizer_actors.append(actor)
|
||||||
|
|
||||||
def encode(self,
|
def encode(self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
request_id: Optional[str] = None,
|
request_id: Optional[str] = None,
|
||||||
@ -88,23 +112,41 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
|
|||||||
The actor is then put back in the queue for future use.
|
The actor is then put back in the queue for future use.
|
||||||
This is blocking.
|
This is blocking.
|
||||||
"""
|
"""
|
||||||
|
self.check_health()
|
||||||
self._ensure_queue_initialized()
|
self._ensure_queue_initialized()
|
||||||
assert self._idle_actors is not None
|
assert self._idle_actors is not None
|
||||||
|
|
||||||
if self._idle_actors.empty():
|
if self._idle_actors.empty():
|
||||||
raise RuntimeError("No idle actors available.")
|
raise RuntimeError("No idle actors available.")
|
||||||
actor = self._idle_actors.get_nowait()
|
actor = self._idle_actors.get_nowait()
|
||||||
|
actor_is_alive = True
|
||||||
|
original_actor = actor
|
||||||
try:
|
try:
|
||||||
ret = ray.get(
|
ret = ray.get(
|
||||||
actor.encode.remote(request_id=request_id,
|
actor.encode.remote(request_id=request_id,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
lora_request=lora_request))
|
lora_request=lora_request))
|
||||||
|
except ActorDiedError as e:
|
||||||
|
# If the actor is dead, we first try to reinitialize it.
|
||||||
|
logger.warning("%s died with ActorDiedError, reinitializing.",
|
||||||
|
actor,
|
||||||
|
exc_info=e)
|
||||||
|
actor = self._init_actor()
|
||||||
|
try:
|
||||||
|
ret = ray.get(
|
||||||
|
actor.encode.remote(request_id=request_id,
|
||||||
|
prompt=prompt,
|
||||||
|
lora_request=lora_request))
|
||||||
|
except ActorDiedError as e:
|
||||||
|
logger.error(
|
||||||
|
"%s died for second time in a row, marking "
|
||||||
|
"RayTokenizerGroupPool as unhealthy.", actor)
|
||||||
|
actor_is_alive = False
|
||||||
|
if not self._exception:
|
||||||
|
self._exception = e
|
||||||
|
self.check_health()
|
||||||
finally:
|
finally:
|
||||||
# Put the actor back in the queue.
|
self._finalize_encode(actor, original_actor, actor_is_alive)
|
||||||
# This is done in a finally block to ensure that the actor is
|
|
||||||
# always put back in the queue, even if an exception/cancellation
|
|
||||||
# is raised.
|
|
||||||
self._idle_actors.put_nowait(actor)
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
async def encode_async(
|
async def encode_async(
|
||||||
@ -120,20 +162,37 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
|
|||||||
The actor is then put back in the queue for future use.
|
The actor is then put back in the queue for future use.
|
||||||
This is non-blocking.
|
This is non-blocking.
|
||||||
"""
|
"""
|
||||||
|
self.check_health()
|
||||||
self._ensure_queue_initialized()
|
self._ensure_queue_initialized()
|
||||||
assert self._idle_actors is not None
|
assert self._idle_actors is not None
|
||||||
|
|
||||||
actor = await self._idle_actors.get()
|
actor = await self._idle_actors.get()
|
||||||
|
actor_is_alive = True
|
||||||
|
original_actor = actor
|
||||||
try:
|
try:
|
||||||
ret = await actor.encode.remote(request_id=request_id,
|
ret = await actor.encode.remote(request_id=request_id,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
lora_request=lora_request)
|
lora_request=lora_request)
|
||||||
|
except ActorDiedError as e:
|
||||||
|
# If the actor is dead, we first try to reinitialize it.
|
||||||
|
logger.warning("%s died with ActorDiedError, reinitializing.",
|
||||||
|
actor,
|
||||||
|
exc_info=e)
|
||||||
|
actor = self._init_actor()
|
||||||
|
try:
|
||||||
|
ret = await actor.encode.remote(request_id=request_id,
|
||||||
|
prompt=prompt,
|
||||||
|
lora_request=lora_request)
|
||||||
|
except ActorDiedError as e:
|
||||||
|
logger.error(
|
||||||
|
"%s died for second time in a row, marking "
|
||||||
|
"RayTokenizerGroupPool as unhealthy.", actor)
|
||||||
|
actor_is_alive = False
|
||||||
|
if not self._exception:
|
||||||
|
self._exception = e
|
||||||
|
self.check_health()
|
||||||
finally:
|
finally:
|
||||||
# Put the actor back in the queue.
|
self._finalize_encode(actor, original_actor, actor_is_alive)
|
||||||
# This is done in a finally block to ensure that the actor is
|
|
||||||
# always put back in the queue, even if an exception/cancellation
|
|
||||||
# is raised.
|
|
||||||
self._idle_actors.put_nowait(actor)
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def get_max_input_len(self,
|
def get_max_input_len(self,
|
||||||
@ -155,6 +214,11 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
|
|||||||
return await self._local_tokenizer_group.get_lora_tokenizer_async(
|
return await self._local_tokenizer_group.get_lora_tokenizer_async(
|
||||||
lora_request)
|
lora_request)
|
||||||
|
|
||||||
|
def check_health(self):
|
||||||
|
if self._exception:
|
||||||
|
raise RuntimeError(
|
||||||
|
"TokenizerGroupPool is unhealthy.") from self._exception
|
||||||
|
|
||||||
|
|
||||||
def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None:
|
def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None:
|
||||||
"""Copy over all current process environment variables to the runtime_env.
|
"""Copy over all current process environment variables to the runtime_env.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user