2024-03-15 16:37:01 -07:00
|
|
|
import asyncio
|
2024-03-25 23:59:47 +09:00
|
|
|
import os
|
2024-03-15 16:37:01 -07:00
|
|
|
from unittest.mock import patch
|
|
|
|
|
2024-03-25 23:59:47 +09:00
|
|
|
import pytest
|
2024-03-15 16:37:01 -07:00
|
|
|
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
2024-03-25 23:59:47 +09:00
|
|
|
|
2024-03-15 16:37:01 -07:00
|
|
|
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
|
|
|
|
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
|
|
|
|
RayTokenizerGroupPool)
|
|
|
|
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
|
|
|
|
TokenizerGroup)
|
2024-03-25 23:59:47 +09:00
|
|
|
|
2024-03-15 16:37:01 -07:00
|
|
|
from ..conftest import get_tokenizer_pool_config
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
async def test_tokenizer_group(tokenizer_group_type):
    """Compare a tokenizer group against the plain ``gpt2`` tokenizer.

    The group is built through ``get_tokenizer_group`` for each
    parametrized pool type, and both its synchronous and asynchronous
    entry points are checked against the reference tokenizer.
    """
    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")

    pool_config = get_tokenizer_pool_config(tokenizer_group_type)
    tokenizer_group = get_tokenizer_group(
        pool_config,
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
    )

    # Synchronous encode must produce the reference token ids.
    expected_sync = reference_tokenizer.encode("prompt")
    actual_sync = tokenizer_group.encode(request_id="request_id",
                                         prompt="prompt",
                                         lora_request=None)
    assert expected_sync == actual_sync

    # Asynchronous encode must produce the same token ids.
    expected_async = reference_tokenizer.encode("prompt")
    actual_async = await tokenizer_group.encode_async(
        request_id="request_id", prompt="prompt", lora_request=None)
    assert expected_async == actual_async

    # With no LoRA request the group should hand back a real tokenizer.
    base_tokenizer = tokenizer_group.get_lora_tokenizer(None)
    assert isinstance(base_tokenizer, PreTrainedTokenizerBase)

    # The sync and async lookups must agree with each other.
    sync_tokenizer = tokenizer_group.get_lora_tokenizer(None)
    async_tokenizer = await tokenizer_group.get_lora_tokenizer_async(None)
    assert sync_tokenizer == async_tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_pool(tokenizer_group_type):
    """Check that a tokenizer pool handles more requests than its size.

    Five times ``pool_size`` encode requests are issued concurrently and
    every result is compared with the reference ``gpt2`` tokenizer.
    """
    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")

    pool_config = get_tokenizer_pool_config(tokenizer_group_type)
    tokenizer_group_pool = get_tokenizer_group(
        pool_config,
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
    )

    # Deliberately oversubscribe the pool so requests have to queue.
    num_requests = tokenizer_group_pool.pool_size * 5
    prompts = [f"prompt {idx}" for idx in range(num_requests)]

    pending = []
    for idx, text in enumerate(prompts):
        pending.append(
            tokenizer_group_pool.encode_async(request_id=str(idx),
                                              prompt=text,
                                              lora_request=None))
    results = await asyncio.gather(*pending)

    # gather() keeps submission order, so results line up with prompts.
    expected_results = [reference_tokenizer.encode(text) for text in prompts]
    assert results == expected_results
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_ray_pool_env_var_propagation(
        tokenizer_group_type):
    """Check that the caller's env vars are visible to tokenizer workers.

    A worker whose ``ping`` asserts on an environment variable is plugged
    into the Ray pool. Pinging must fail while the variable is unset in
    the caller and succeed for a pool created while it is set.
    """
    env_var = "MY_ENV_VAR"

    class EnvVarCheckerTokenizerGroup(TokenizerGroup):

        def ping(self):
            # Fails unless the variable is set in the worker's environment.
            observed = os.environ.get(env_var)
            assert observed == "1"
            return super().ping()

    class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool):
        # Swap in the env-checking worker for this pool.
        _worker_cls = EnvVarCheckerTokenizerGroup

    def build_pool():
        """Create a fresh config and an env-checking pool from it."""
        pool_config = get_tokenizer_pool_config(tokenizer_group_type)
        return EnvVarCheckerRayTokenizerGroupPool.from_config(
            pool_config,
            tokenizer_id="gpt2",
            enable_lora=False,
            max_num_seqs=1,
            max_input_length=None)

    # Variable not set in the caller: the worker-side assertion must fail.
    tokenizer_pool = build_pool()
    with pytest.raises(AssertionError):
        tokenizer_pool.ping()

    # Variable set in the caller: a new pool must see it and ping cleanly.
    with patch.dict(os.environ, {env_var: "1"}):
        tokenizer_pool = build_pool()
        tokenizer_pool.ping()
|