
# Tests for TokenizerGroup and LoRA tokenizer utilities.
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import TokenizerGroup, get_lora_tokenizer

@pytest.mark.asyncio
async def test_transformers_tokenizer():
    """A LoRA-disabled TokenizerGroup must encode identically to the
    underlying HuggingFace tokenizer, via both the sync and async paths,
    and must hand back the same base tokenizer from both lookup methods."""
    expected = AutoTokenizer.from_pretrained("gpt2")
    group = TokenizerGroup(
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
    )

    # Synchronous encode matches the reference tokenizer.
    sync_ids = group.encode(request_id="request_id",
                            prompt="prompt",
                            lora_request=None)
    assert sync_ids == expected.encode("prompt")

    # Asynchronous encode produces the same token ids.
    async_ids = await group.encode_async(request_id="request_id",
                                         prompt="prompt",
                                         lora_request=None)
    assert async_ids == expected.encode("prompt")

    # With no LoRA request, both accessors return the (same) base tokenizer.
    base_tokenizer = group.get_lora_tokenizer(None)
    assert isinstance(base_tokenizer, PreTrainedTokenizerBase)
    assert base_tokenizer == await group.get_lora_tokenizer_async(None)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_transformers_tokenizer_lora(sql_lora_files):
    """A LoRA-enabled TokenizerGroup must encode with the adapter's
    tokenizer when a LoRARequest is supplied, while still serving a
    distinct base tokenizer for requests without LoRA."""
    expected = AutoTokenizer.from_pretrained(sql_lora_files)
    group = TokenizerGroup(
        tokenizer_id="gpt2",
        enable_lora=True,
        max_num_seqs=1,
        max_input_length=None,
    )
    lora_request = LoRARequest("1", 1, sql_lora_files)

    # Sync encode with the LoRA request matches the adapter tokenizer.
    sync_ids = group.encode(request_id="request_id",
                            prompt="prompt",
                            lora_request=lora_request)
    assert sync_ids == expected.encode("prompt")

    # Async encode agrees as well.
    async_ids = await group.encode_async(request_id="request_id",
                                         prompt="prompt",
                                         lora_request=lora_request)
    assert async_ids == expected.encode("prompt")

    # The base (non-LoRA) tokenizer is still available and consistent
    # across the sync and async accessors.
    base_tokenizer = group.get_lora_tokenizer(None)
    assert isinstance(base_tokenizer, PreTrainedTokenizerBase)
    assert base_tokenizer == await group.get_lora_tokenizer_async(None)

    # The LoRA tokenizer is a real tokenizer, differs from the base one,
    # and is consistent across the sync and async accessors.
    lora_tokenizer = group.get_lora_tokenizer(lora_request)
    assert isinstance(lora_tokenizer, PreTrainedTokenizerBase)
    assert lora_tokenizer != base_tokenizer
    assert lora_tokenizer == await group.get_lora_tokenizer_async(lora_request)
|
|
|
|
|
|
def test_get_lora_tokenizer(sql_lora_files, tmpdir):
    """get_lora_tokenizer yields a tokenizer only for a valid LoRA path:
    no request or an empty directory produces a falsy result."""
    # No LoRA request at all -> no tokenizer.
    assert not get_lora_tokenizer(None)

    # A genuine LoRA directory carries extra vocabulary entries.
    adapter_tokenizer = get_lora_tokenizer(LoRARequest("1", 1, sql_lora_files))
    assert adapter_tokenizer.get_added_vocab()

    # An empty directory cannot supply a tokenizer.
    assert not get_lora_tokenizer(LoRARequest("1", 1, str(tmpdir)))
|