Asynchronous tokenization (#2879)

Antoni Baum 2024-03-15 16:37:01 -07:00 committed by GitHub
parent 8fa7357f2d
commit fb96c1e98c
17 changed files with 658 additions and 153 deletions

View File

@@ -28,7 +28,7 @@ steps:
num_gpus: 2 # only support 1 or 2 for now.
- label: Engine Test
command: pytest -v -s engine test_sequence.py
command: pytest -v -s engine tokenization test_sequence.py
- label: Entrypoints Test
command: pytest -v -s entrypoints

View File

@@ -25,23 +25,21 @@ def _query_server_long(prompt: str) -> dict:
@pytest.fixture
def api_server():
def api_server(tokenizer_pool_size: int):
script_path = Path(__file__).parent.joinpath(
"api_server_async_engine.py").absolute()
uvicorn_process = subprocess.Popen([
sys.executable,
"-u",
str(script_path),
"--model",
"facebook/opt-125m",
"--host",
"127.0.0.1",
sys.executable, "-u",
str(script_path), "--model", "facebook/opt-125m", "--host",
"127.0.0.1", "--tokenizer-pool-size",
str(tokenizer_pool_size)
])
yield
uvicorn_process.terminate()
def test_api_server(api_server):
@pytest.mark.parametrize("tokenizer_pool_size", [0, 2])
def test_api_server(api_server, tokenizer_pool_size: int):
"""
Run the API server and test it.

View File

@@ -7,6 +7,7 @@ from transformers import AutoModelForCausalLM
from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.config import TokenizerPoolConfig
_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
@@ -258,3 +259,13 @@ class VllmRunner:
@pytest.fixture
def vllm_runner():
return VllmRunner
def get_tokenizer_pool_config(tokenizer_group_type):
if tokenizer_group_type is None:
return None
if tokenizer_group_type == "ray":
return TokenizerPoolConfig(pool_size=1,
pool_type="ray",
extra_config={})
raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")

View File

@@ -1,69 +0,0 @@
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import TokenizerGroup, get_lora_tokenizer
@pytest.mark.asyncio
async def test_transformers_tokenizer():
reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer = TokenizerGroup(
tokenizer_id="gpt2",
enable_lora=False,
max_num_seqs=1,
max_input_length=None,
)
assert reference_tokenizer.encode("prompt") == tokenizer.encode(
request_id="request_id", prompt="prompt", lora_request=None)
assert reference_tokenizer.encode(
"prompt") == await tokenizer.encode_async(request_id="request_id",
prompt="prompt",
lora_request=None)
assert isinstance(tokenizer.get_lora_tokenizer(None),
PreTrainedTokenizerBase)
assert tokenizer.get_lora_tokenizer(
None) == await tokenizer.get_lora_tokenizer_async(None)
@pytest.mark.asyncio
async def test_transformers_tokenizer_lora(sql_lora_files):
reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files)
tokenizer = TokenizerGroup(
tokenizer_id="gpt2",
enable_lora=True,
max_num_seqs=1,
max_input_length=None,
)
lora_request = LoRARequest("1", 1, sql_lora_files)
assert reference_tokenizer.encode("prompt") == tokenizer.encode(
request_id="request_id", prompt="prompt", lora_request=lora_request)
assert reference_tokenizer.encode(
"prompt") == await tokenizer.encode_async(request_id="request_id",
prompt="prompt",
lora_request=lora_request)
assert isinstance(tokenizer.get_lora_tokenizer(None),
PreTrainedTokenizerBase)
assert tokenizer.get_lora_tokenizer(
None) == await tokenizer.get_lora_tokenizer_async(None)
assert isinstance(tokenizer.get_lora_tokenizer(lora_request),
PreTrainedTokenizerBase)
assert tokenizer.get_lora_tokenizer(
lora_request) != tokenizer.get_lora_tokenizer(None)
assert tokenizer.get_lora_tokenizer(
lora_request) == await tokenizer.get_lora_tokenizer_async(lora_request)
def test_get_lora_tokenizer(sql_lora_files, tmpdir):
lora_request = None
tokenizer = get_lora_tokenizer(lora_request)
assert not tokenizer
lora_request = LoRARequest("1", 1, sql_lora_files)
tokenizer = get_lora_tokenizer(lora_request)
assert tokenizer.get_added_vocab()
lora_request = LoRARequest("1", 1, str(tmpdir))
tokenizer = get_lora_tokenizer(lora_request)
assert not tokenizer

View File

@@ -0,0 +1,53 @@
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
from vllm.transformers_utils.tokenizer import get_lora_tokenizer
from ..conftest import get_tokenizer_pool_config
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files)
tokenizer_group = get_tokenizer_group(
get_tokenizer_pool_config(tokenizer_group_type),
tokenizer_id="gpt2",
enable_lora=True,
max_num_seqs=1,
max_input_length=None,
)
lora_request = LoRARequest("1", 1, sql_lora_files)
assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
request_id="request_id", prompt="prompt", lora_request=lora_request)
assert reference_tokenizer.encode(
"prompt") == await tokenizer_group.encode_async(
request_id="request_id",
prompt="prompt",
lora_request=lora_request)
assert isinstance(tokenizer_group.get_lora_tokenizer(None),
PreTrainedTokenizerBase)
assert tokenizer_group.get_lora_tokenizer(
None) == await tokenizer_group.get_lora_tokenizer_async(None)
assert isinstance(tokenizer_group.get_lora_tokenizer(lora_request),
PreTrainedTokenizerBase)
assert tokenizer_group.get_lora_tokenizer(
lora_request) != tokenizer_group.get_lora_tokenizer(None)
assert tokenizer_group.get_lora_tokenizer(
lora_request) == await tokenizer_group.get_lora_tokenizer_async(
lora_request)
def test_get_lora_tokenizer(sql_lora_files, tmpdir):
lora_request = None
tokenizer = get_lora_tokenizer(lora_request)
assert not tokenizer
lora_request = LoRARequest("1", 1, sql_lora_files)
tokenizer = get_lora_tokenizer(lora_request)
assert tokenizer.get_added_vocab()
lora_request = LoRARequest("1", 1, str(tmpdir))
tokenizer = get_lora_tokenizer(lora_request)
assert not tokenizer

View File

View File

@@ -0,0 +1,20 @@
from copy import deepcopy
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from transformers import AutoTokenizer
def test_cached_tokenizer():
reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"})
reference_tokenizer.add_special_tokens(
{"additional_special_tokens": ["<SEP>"]})
cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer))
assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode(
"prompt")
assert set(reference_tokenizer.all_special_ids) == set(
cached_tokenizer.all_special_ids)
assert set(reference_tokenizer.all_special_tokens) == set(
cached_tokenizer.all_special_tokens)
assert set(reference_tokenizer.all_special_tokens_extended) == set(
cached_tokenizer.all_special_tokens_extended)

View File

@@ -0,0 +1,100 @@
import os
import pytest
import asyncio
from unittest.mock import patch
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
RayTokenizerGroupPool)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
TokenizerGroup)
from ..conftest import get_tokenizer_pool_config
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
async def test_tokenizer_group(tokenizer_group_type):
reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer_group = get_tokenizer_group(
get_tokenizer_pool_config(tokenizer_group_type),
tokenizer_id="gpt2",
enable_lora=False,
max_num_seqs=1,
max_input_length=None,
)
assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
request_id="request_id", prompt="prompt", lora_request=None)
assert reference_tokenizer.encode(
"prompt") == await tokenizer_group.encode_async(
request_id="request_id", prompt="prompt", lora_request=None)
assert isinstance(tokenizer_group.get_lora_tokenizer(None),
PreTrainedTokenizerBase)
assert tokenizer_group.get_lora_tokenizer(
None) == await tokenizer_group.get_lora_tokenizer_async(None)
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_pool(tokenizer_group_type):
reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer_group_pool = get_tokenizer_group(
get_tokenizer_pool_config(tokenizer_group_type),
tokenizer_id="gpt2",
enable_lora=False,
max_num_seqs=1,
max_input_length=None,
)
# Send multiple requests to the tokenizer group pool
# (more than the pool size)
# and check that all requests are processed correctly.
num_requests = tokenizer_group_pool.pool_size * 5
requests = [
tokenizer_group_pool.encode_async(request_id=str(i),
prompt=f"prompt {i}",
lora_request=None)
for i in range(num_requests)
]
results = await asyncio.gather(*requests)
expected_results = [
reference_tokenizer.encode(f"prompt {i}") for i in range(num_requests)
]
assert results == expected_results
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_ray_pool_env_var_propagation(
tokenizer_group_type):
"""Test that env vars from caller process are propagated to
tokenizer Ray actors."""
env_var = "MY_ENV_VAR"
class EnvVarCheckerTokenizerGroup(TokenizerGroup):
def ping(self):
assert os.environ.get(env_var) == "1"
return super().ping()
class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool):
_worker_cls = EnvVarCheckerTokenizerGroup
tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config(
tokenizer_pool_config,
tokenizer_id="gpt2",
enable_lora=False,
max_num_seqs=1,
max_input_length=None)
with pytest.raises(AssertionError):
tokenizer_pool.ping()
with patch.dict(os.environ, {env_var: "1"}):
tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
tokenizer_pool = EnvVarCheckerRayTokenizerGroupPool.from_config(
tokenizer_pool_config,
tokenizer_id="gpt2",
enable_lora=False,
max_num_seqs=1,
max_input_length=None)
tokenizer_pool.ping()

View File

@@ -3,6 +3,7 @@ from dataclasses import dataclass
import os
from packaging.version import Version
import json
import torch
from transformers import PretrainedConfig
@@ -389,6 +390,58 @@ class CacheConfig:
logger.warning("Possibly too large swap space. " + msg)
@dataclass
class TokenizerPoolConfig:
"""Configuration for the tokenizer pool.
Args:
pool_size: Number of tokenizer workers in the pool.
pool_type: Type of the pool.
extra_config: Additional config for the pool.
The way the config will be used depends on the
pool type.
"""
pool_size: int
pool_type: str
extra_config: dict
def __post_init__(self):
if self.pool_type not in ("ray", ):
raise ValueError(f"Unknown pool type: {self.pool_type}")
if not isinstance(self.extra_config, dict):
raise ValueError("extra_config must be a dictionary.")
@classmethod
def create_config(
cls, tokenizer_pool_size: int, tokenizer_pool_type: str,
tokenizer_pool_extra_config: Optional[Union[str, dict]]
) -> Optional["TokenizerPoolConfig"]:
"""Create a TokenizerPoolConfig from the given parameters.
If tokenizer_pool_size is 0, return None.
Args:
tokenizer_pool_size: Number of tokenizer workers in the pool.
tokenizer_pool_type: Type of the pool.
tokenizer_pool_extra_config: Additional config for the pool.
The way the config will be used depends on the
pool type. This can be a JSON string (will be parsed).
"""
if tokenizer_pool_size:
if isinstance(tokenizer_pool_extra_config, str):
tokenizer_pool_extra_config_parsed = json.loads(
tokenizer_pool_extra_config)
else:
tokenizer_pool_extra_config_parsed = (
tokenizer_pool_extra_config or {})
tokenizer_pool_config = cls(tokenizer_pool_size,
tokenizer_pool_type,
tokenizer_pool_extra_config_parsed)
else:
tokenizer_pool_config = None
return tokenizer_pool_config
class ParallelConfig:
"""Configuration for the distributed execution.
@@ -403,6 +456,8 @@ class ParallelConfig:
parallel and large models.
disable_custom_all_reduce: Disable the custom all-reduce kernel and
fall back to NCCL.
tokenizer_pool_config: Config for the tokenizer pool.
If None, will use synchronous tokenization.
ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
"""
@@ -414,6 +469,7 @@ class ParallelConfig:
worker_use_ray: bool,
max_parallel_loading_workers: Optional[int] = None,
disable_custom_all_reduce: bool = False,
tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
ray_workers_use_nsight: bool = False,
placement_group: Optional["PlacementGroup"] = None,
) -> None:
@@ -430,6 +486,7 @@ class ParallelConfig:
self.worker_use_ray = worker_use_ray
self.max_parallel_loading_workers = max_parallel_loading_workers
self.disable_custom_all_reduce = disable_custom_all_reduce
self.tokenizer_pool_config = tokenizer_pool_config
self.ray_workers_use_nsight = ray_workers_use_nsight
self.placement_group = placement_group
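
For reference, a minimal sketch (not part of the diff) of how the new TokenizerPoolConfig.create_config helper behaves; every name below comes from the hunk above.

from vllm.config import TokenizerPoolConfig

# Pool size 0 disables the pool entirely: create_config returns None and the
# engine keeps using synchronous tokenization.
assert TokenizerPoolConfig.create_config(0, "ray", None) is None

# A non-zero pool size yields a config. extra_config may be a JSON string
# (as on the CLI); it is parsed into a dict.
config = TokenizerPoolConfig.create_config(2, "ray", '{"num_cpus": 1}')
assert config.pool_size == 2
assert config.pool_type == "ray"
assert config.extra_config == {"num_cpus": 1}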

View File

@@ -4,7 +4,8 @@ from dataclasses import dataclass
from typing import Optional, Tuple
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig)
ParallelConfig, SchedulerConfig, LoRAConfig,
TokenizerPoolConfig)
@dataclass
@@ -40,6 +41,9 @@ class EngineArgs:
enforce_eager: bool = False
max_context_len_to_capture: int = 8192
disable_custom_all_reduce: bool = False
tokenizer_pool_size: int = 0
tokenizer_pool_type: str = "ray"
tokenizer_pool_extra_config: Optional[dict] = None
enable_lora: bool = False
max_loras: int = 1
max_lora_rank: int = 16
@@ -249,6 +253,25 @@ class EngineArgs:
action='store_true',
default=EngineArgs.disable_custom_all_reduce,
help='See ParallelConfig')
parser.add_argument('--tokenizer-pool-size',
type=int,
default=EngineArgs.tokenizer_pool_size,
help='Size of tokenizer pool to use for '
'asynchronous tokenization. If 0, will '
'use synchronous tokenization.')
parser.add_argument('--tokenizer-pool-type',
type=str,
default=EngineArgs.tokenizer_pool_type,
help='Type of tokenizer pool to use for '
'asynchronous tokenization. Ignored '
'if tokenizer_pool_size is 0.')
parser.add_argument('--tokenizer-pool-extra-config',
type=str,
default=EngineArgs.tokenizer_pool_extra_config,
help='Extra config for tokenizer pool. '
'This should be a JSON string that will be '
'parsed into a dictionary. Ignored if '
'tokenizer_pool_size is 0.')
# LoRA related configs
parser.add_argument('--enable-lora',
action='store_true',
@@ -312,14 +335,16 @@ class EngineArgs:
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype,
model_config.get_sliding_window(),
self.enable_prefix_caching)
parallel_config = ParallelConfig(self.pipeline_parallel_size,
self.tensor_parallel_size,
self.worker_use_ray,
self.max_parallel_loading_workers,
model_config.get_sliding_window())
parallel_config = ParallelConfig(
self.pipeline_parallel_size, self.tensor_parallel_size,
self.worker_use_ray, self.max_parallel_loading_workers,
self.disable_custom_all_reduce,
self.ray_workers_use_nsight)
TokenizerPoolConfig.create_config(
self.tokenizer_pool_size,
self.tokenizer_pool_type,
self.tokenizer_pool_extra_config,
), self.ray_workers_use_nsight)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs,
model_config.max_model_len,

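The three new EngineArgs fields correspond one-to-one to the CLI flags added above (--tokenizer-pool-size, --tokenizer-pool-type, --tokenizer-pool-extra-config). Below is a hedged sketch of setting them programmatically; the field names are taken from the hunks above and the model name from this commit's tests.

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",
    tokenizer_pool_size=2,             # 0 (the default) keeps tokenization synchronous
    tokenizer_pool_type="ray",         # the only pool type accepted by TokenizerPoolConfig
    tokenizer_pool_extra_config=None,  # optional dict (or JSON string on the CLI)
)
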
View File

@@ -17,8 +17,9 @@ from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Logprob, SamplerOutput, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
TokenizerGroup)
from vllm.transformers_utils.tokenizer import detokenize_incrementally
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group)
from vllm.utils import Counter
logger = init_logger(__name__)
@@ -102,6 +103,10 @@ class LLMEngine:
parallel_config, scheduler_config,
device_config, lora_config)
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self.tokenizer.ping()
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
@@ -152,6 +157,7 @@ class LLMEngine:
def _init_tokenizer(self, **tokenizer_init_kwargs):
init_kwargs = dict(
tokenizer_id=self.model_config.tokenizer,
enable_lora=bool(self.lora_config),
max_num_seqs=self.scheduler_config.max_num_seqs,
max_input_length=None,
@@ -159,8 +165,9 @@ class LLMEngine:
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs)
self.tokenizer: TokenizerGroup = TokenizerGroup(
self.model_config.tokenizer, **init_kwargs)
self.tokenizer: BaseTokenizerGroup = get_tokenizer_group(
self.parallel_config.tokenizer_pool_config, **init_kwargs)
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
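
Putting the engine wiring together, a hedged end-to-end sketch; it assumes the LLM entrypoint forwards extra keyword arguments such as tokenizer_pool_size through to EngineArgs, which this diff does not show.

from vllm import LLM, SamplingParams

# Assumption: extra keyword arguments reach EngineArgs, so the tokenizer pool
# can be requested the same way as other engine options.
llm = LLM(model="facebook/opt-125m", tokenizer_pool_size=2)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)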

View File

@@ -5,12 +5,48 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.utils import make_async, LRUCache
from vllm.utils import make_async
from vllm.transformers_utils.tokenizers import *
logger = init_logger(__name__)
def get_cached_tokenizer(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
"""Get tokenizer with cached properties.
This will patch the tokenizer object in place.
By default, transformers will recompute multiple tokenizer properties
each time they are called, leading to a significant slowdown. This
function caches these properties for faster access."""
tokenizer_all_special_ids = set(tokenizer.all_special_ids)
tokenizer_all_special_tokens_extended = (
tokenizer.all_special_tokens_extended)
tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
class CachedTokenizer(tokenizer.__class__):
@property
def all_special_ids(self):
return tokenizer_all_special_ids
@property
def all_special_tokens(self):
return tokenizer_all_special_tokens
@property
def all_special_tokens_extended(self):
return tokenizer_all_special_tokens_extended
CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"
tokenizer.__class__ = CachedTokenizer
return tokenizer
def get_tokenizer(
tokenizer_name: str,
*args,
@@ -64,7 +100,7 @@ def get_tokenizer(
logger.warning(
"Using a slow tokenizer. This might cause a significant "
"slowdown. Consider using a fast tokenizer instead.")
return tokenizer
return get_cached_tokenizer(tokenizer)
def get_lora_tokenizer(lora_request: LoRARequest, *args,
@@ -88,65 +124,6 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
get_lora_tokenizer_async = make_async(get_lora_tokenizer)
class TokenizerGroup:
"""A group of tokenizers that can be used for LoRA adapters."""
def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
max_input_length: Optional[int], **tokenizer_config):
self.tokenizer_id = tokenizer_id
self.tokenizer_config = tokenizer_config
self.enable_lora = enable_lora
self.max_input_length = max_input_length
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
if enable_lora:
self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
else:
self.lora_tokenizers = None
def encode(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = self.get_lora_tokenizer(lora_request)
return tokenizer.encode(prompt)
async def encode_async(
self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = await self.get_lora_tokenizer_async(lora_request)
return tokenizer.encode(prompt)
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
tokenizer = (get_lora_tokenizer(
lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
tokenizer = (await get_lora_tokenizer_async(
lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
def _convert_tokens_to_string_with_added_encoders(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
output_tokens: List[str],

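A short sketch (not in the diff) of what the new get_cached_tokenizer helper does when applied to a tokenizer directly; get_tokenizer now applies it automatically via the return statement changed above.

from transformers import AutoTokenizer
from vllm.transformers_utils.tokenizer import get_cached_tokenizer

tokenizer = get_cached_tokenizer(AutoTokenizer.from_pretrained("gpt2"))

# The instance's class is swapped for a dynamically created Cached* subclass
# whose special-token properties return precomputed values instead of being
# rebuilt by transformers on every access.
print(type(tokenizer).__name__)    # e.g. CachedGPT2TokenizerFast
print(tokenizer.all_special_ids)   # served from the cached set
print(tokenizer.encode("prompt"))  # encoding behaviour is unchanged
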
View File

@@ -0,0 +1,32 @@
from typing import Optional
from vllm.config import TokenizerPoolConfig
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
TokenizerGroup)
from vllm.engine.ray_utils import ray
if ray:
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
RayTokenizerGroupPool)
else:
RayTokenizerGroupPool = None
def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
**init_kwargs) -> BaseTokenizerGroup:
if tokenizer_pool_config is None:
return TokenizerGroup(**init_kwargs)
if tokenizer_pool_config.pool_type == "ray":
if RayTokenizerGroupPool is None:
raise ImportError(
"RayTokenizerGroupPool is not available. Please install "
"the ray package to use the Ray tokenizer group pool.")
return RayTokenizerGroupPool.from_config(tokenizer_pool_config,
**init_kwargs)
else:
raise ValueError(
f"Unknown pool type: {tokenizer_pool_config.pool_type}")
__all__ = ["get_tokenizer_group", "BaseTokenizerGroup"]
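
A hedged sketch of the factory's dispatch, using only names introduced in this commit; the pool size and max_num_seqs values are arbitrary.

from vllm.config import TokenizerPoolConfig
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group

# No pool config: a plain in-process TokenizerGroup (synchronous tokenization).
sync_group = get_tokenizer_group(None,
                                 tokenizer_id="gpt2",
                                 enable_lora=False,
                                 max_num_seqs=256,
                                 max_input_length=None)

# A "ray" pool config: a RayTokenizerGroupPool of remote tokenizer actors
# (raises ImportError if Ray is not installed).
pool_group = get_tokenizer_group(TokenizerPoolConfig(pool_size=2,
                                                     pool_type="ray",
                                                     extra_config={}),
                                 tokenizer_id="gpt2",
                                 enable_lora=False,
                                 max_num_seqs=256,
                                 max_input_length=None)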

View File

@@ -0,0 +1,48 @@
from abc import ABC, abstractmethod
from typing import List, Optional
from transformers import PreTrainedTokenizer
from vllm.lora.request import LoRARequest
class BaseTokenizerGroup(ABC):
"""A group of tokenizers that can be used for LoRA adapters."""
@abstractmethod
def ping(self) -> bool:
"""Check if the tokenizer group is alive."""
pass
@abstractmethod
def get_max_input_len(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
"""Get the maximum input length for the LoRA request."""
pass
@abstractmethod
def encode(self, prompt: str, request_id: Optional[str],
lora_request: Optional[LoRARequest]) -> List[int]:
"""Encode a prompt using the tokenizer group."""
pass
@abstractmethod
async def encode_async(self, prompt: str, request_id: Optional[str],
lora_request: Optional[LoRARequest]) -> List[int]:
"""Encode a prompt using the tokenizer group."""
pass
@abstractmethod
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
"""Get a tokenizer for a LoRA request."""
pass
@abstractmethod
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
"""Get a tokenizer for a LoRA request."""
pass
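
A hedged, illustrative sketch (the tokenize helper below is not part of the commit) of how engine code can treat any BaseTokenizerGroup uniformly, whether tokenization runs in-process or in a Ray actor pool.

from typing import List, Optional

from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup

async def tokenize(group: BaseTokenizerGroup,
                   request_id: str,
                   prompt: str,
                   lora_request: Optional[LoRARequest] = None) -> List[int]:
    group.ping()  # liveness check; trivial locally, an RPC for the Ray pool
    return await group.encode_async(request_id=request_id,
                                    prompt=prompt,
                                    lora_request=lora_request)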

View File

@@ -0,0 +1,166 @@
import asyncio
import os
from typing import List, Optional
from transformers import PreTrainedTokenizer
from vllm.config import TokenizerPoolConfig
from vllm.lora.request import LoRARequest
from vllm.engine.ray_utils import ray
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
TokenizerGroup)
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
class RayTokenizerGroupPool(BaseTokenizerGroup):
"""A Ray-based pool of TokenizerGroups for async tokenization."""
# Class to use for workers making up the pool.
_worker_cls = TokenizerGroup
@classmethod
def from_config(cls, tokenizer_pool_config: TokenizerPoolConfig,
**init_kwargs) -> "RayTokenizerGroupPool":
ray_actor_options = (tokenizer_pool_config.extra_config or {
"num_cpus": 0
})
ray_actor_options.setdefault(
"scheduling_strategy",
NodeAffinitySchedulingStrategy(
node_id=ray.get_runtime_context().get_node_id(), soft=True))
# Carry over the env vars to the actors.
# This is necessary for API keys and such.
ray_actor_options.setdefault("runtime_env", {})
_carry_over_env_vars_to_runtime_env(ray_actor_options["runtime_env"])
init_kwargs["num_actors"] = tokenizer_pool_config.pool_size
init_kwargs["ray_actor_options"] = ray_actor_options
return cls(**init_kwargs)
def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
max_input_length: Optional[int], num_actors: int,
ray_actor_options: dict, **tokenizer_config):
# Store a local copy of the TokenizerGroup for quick access
# to underlying HF tokenizers.
self._local_tokenizer_group = self._worker_cls(
tokenizer_id=tokenizer_id,
enable_lora=enable_lora,
max_num_seqs=max_num_seqs,
max_input_length=max_input_length,
)
ray_tokenizer_group_cls = ray.remote(
self._worker_cls).options(**ray_actor_options)
self.tokenizer_actors = [
ray_tokenizer_group_cls.remote(tokenizer_id, enable_lora,
max_num_seqs, max_input_length,
**tokenizer_config)
for _ in range(num_actors)
]
self._idle_actors: Optional[asyncio.Queue] = None
@property
def pool_size(self) -> int:
return len(self.tokenizer_actors)
def ping(self):
return ray.get(
[actor.ping.remote() for actor in self.tokenizer_actors])
def _ensure_queue_initialized(self):
if self._idle_actors is None:
self._idle_actors = asyncio.Queue()
for actor in self.tokenizer_actors:
self._idle_actors.put_nowait(actor)
def encode(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
"""Encode a prompt using the tokenizer group.
We pick an idle actor and use it to encode the prompt.
The actor is then put back in the queue for future use.
This is blocking.
"""
self._ensure_queue_initialized()
if self._idle_actors.empty():
raise RuntimeError("No idle actors available.")
actor = self._idle_actors.get_nowait()
try:
ret = ray.get(
actor.encode.remote(request_id=request_id,
prompt=prompt,
lora_request=lora_request))
finally:
# Put the actor back in the queue.
# This is done in a finally block to ensure that the actor is
# always put back in the queue, even if an exception/cancellation
# is raised.
self._idle_actors.put_nowait(actor)
return ret
async def encode_async(
self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
"""Encode a prompt using the tokenizer group.
We pick an idle actor and use it to encode the prompt.
If there are no idle actors, we wait until one becomes
available.
The actor is then put back in the queue for future use.
This is non-blocking.
"""
self._ensure_queue_initialized()
actor = await self._idle_actors.get()
try:
ret = await actor.encode.remote(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
finally:
# Put the actor back in the queue.
# This is done in a finally block to ensure that the actor is
# always put back in the queue, even if an exception/cancellation
# is raised.
self._idle_actors.put_nowait(actor)
return ret
def get_max_input_len(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
"""Get the maximum input length for the LoRA request."""
return self._local_tokenizer_group.get_max_input_len(lora_request)
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
return self._local_tokenizer_group.get_lora_tokenizer(lora_request)
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
return await self._local_tokenizer_group.get_lora_tokenizer_async(
lora_request)
def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None:
"""Copy over all current process environment variables to the runtime_env.
The variables in runtime_env will take precedence over the current process
environment variables.
runtime_env will be modified in place."""
env_vars = os.environ.copy()
runtime_env.setdefault("env_vars", {})
env_vars.update(runtime_env["env_vars"])
runtime_env["env_vars"] = env_vars

View File

@@ -0,0 +1,80 @@
from typing import List, Optional
from transformers import PreTrainedTokenizer
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import (get_lora_tokenizer,
get_lora_tokenizer_async)
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.utils import LRUCache
from vllm.transformers_utils.tokenizer import get_tokenizer
class TokenizerGroup(BaseTokenizerGroup):
"""A group of tokenizers that can be used for LoRA adapters."""
def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
max_input_length: Optional[int], **tokenizer_config):
self.tokenizer_id = tokenizer_id
self.tokenizer_config = tokenizer_config
self.enable_lora = enable_lora
self.max_input_length = max_input_length
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
if enable_lora:
self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
else:
self.lora_tokenizers = None
def ping(self) -> bool:
"""Check if the tokenizer group is alive."""
return True
def get_max_input_len(self,
lora_request: Optional[LoRARequest] = None
) -> Optional[int]:
"""Get the maximum input length for the LoRA request."""
return self.max_input_length
def encode(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = self.get_lora_tokenizer(lora_request)
return tokenizer.encode(prompt)
async def encode_async(
self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
tokenizer = await self.get_lora_tokenizer_async(lora_request)
return tokenizer.encode(prompt)
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
tokenizer = (get_lora_tokenizer(
lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
if not lora_request or not self.enable_lora:
return self.tokenizer
if lora_request.lora_int_id not in self.lora_tokenizers:
tokenizer = (await get_lora_tokenizer_async(
lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
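
Finally, a hedged sketch of the per-LoRA tokenizer caching in the local TokenizerGroup above; the adapter path below is a placeholder, not an artifact from this commit.

from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer_group.tokenizer_group import TokenizerGroup

group = TokenizerGroup(tokenizer_id="gpt2",
                       enable_lora=True,
                       max_num_seqs=4,     # also caps the LoRA tokenizer LRU cache
                       max_input_length=None)

# Without a LoRA request the shared base tokenizer is used.
base_ids = group.encode(prompt="prompt", request_id="0", lora_request=None)

# With a LoRA request the adapter's tokenizer is loaded once (falling back to
# the base tokenizer if the adapter has none) and cached by lora_int_id.
lora_request = LoRARequest("adapter", 1, "/path/to/lora_adapter")  # placeholder path
lora_ids = group.encode(prompt="prompt", request_id="1", lora_request=lora_request)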