vllm/tests/lora/test_tokenizer.py
Antoni Baum 9b945daaf1
[Experimental] Add multi-LoRA support (#1804)
Co-authored-by: Chen Shen <scv119@gmail.com>
Co-authored-by: Shreyas Krishnaswamy <shrekris@anyscale.com>
Co-authored-by: Avnish Narayan <avnish@anyscale.com>
2024-01-23 15:26:37 -08:00

70 lines
2.7 KiB
Python

import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer import TokenizerGroup, get_lora_tokenizer
@pytest.mark.asyncio
async def test_transformers_tokenizer():
    """Compare a non-LoRA ``TokenizerGroup`` against the plain HF "gpt2" tokenizer.

    Both the sync ``encode`` and the async ``encode_async`` paths must yield the
    same token ids as ``AutoTokenizer``. With no LoRA request, the group must
    return a ``PreTrainedTokenizerBase``, and the sync and async getters must
    agree (compare ``==``).
    """
    hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    group = TokenizerGroup(
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
    )

    # Reference ids from the stock HF tokenizer; both group paths must match.
    expected_ids = hf_tokenizer.encode("prompt")
    encode_kwargs = {
        "request_id": "request_id",
        "prompt": "prompt",
        "lora_request": None,
    }
    assert expected_ids == group.encode(**encode_kwargs)
    assert expected_ids == await group.encode_async(**encode_kwargs)

    # With no LoRA request, the sync and async getters must agree.
    base_tokenizer = group.get_lora_tokenizer(None)
    assert isinstance(base_tokenizer, PreTrainedTokenizerBase)
    assert base_tokenizer == await group.get_lora_tokenizer_async(None)


@pytest.mark.asyncio
async def test_transformers_tokenizer_lora(sql_lora_files):
    """Check a LoRA-enabled ``TokenizerGroup`` against the adapter's own tokenizer.

    Encoding under the SQL LoRA request (sync and async) must match the
    tokenizer loaded directly from ``sql_lora_files``. Two further checks:

    * Without a request, the group returns a ``PreTrainedTokenizerBase``.
    * With the request, it returns a different ``PreTrainedTokenizerBase``.

    In both cases the sync and async getters must agree (compare ``==``).
    """
    adapter_reference = AutoTokenizer.from_pretrained(sql_lora_files)
    group = TokenizerGroup(
        tokenizer_id="gpt2",
        enable_lora=True,
        max_num_seqs=1,
        max_input_length=None,
    )
    lora_request = LoRARequest("1", 1, sql_lora_files)

    # Encoding with the LoRA request must reproduce the adapter tokenizer's ids.
    expected_ids = adapter_reference.encode("prompt")
    encode_kwargs = {
        "request_id": "request_id",
        "prompt": "prompt",
        "lora_request": lora_request,
    }
    assert expected_ids == group.encode(**encode_kwargs)
    assert expected_ids == await group.encode_async(**encode_kwargs)

    # No request: a tokenizer is still returned, identically via sync/async.
    base_tokenizer = group.get_lora_tokenizer(None)
    assert isinstance(base_tokenizer, PreTrainedTokenizerBase)
    assert base_tokenizer == await group.get_lora_tokenizer_async(None)

    # With the request: a tokenizer distinct from the no-request one,
    # again identical between the sync and async getters.
    lora_tokenizer = group.get_lora_tokenizer(lora_request)
    assert isinstance(lora_tokenizer, PreTrainedTokenizerBase)
    assert lora_tokenizer != base_tokenizer
    assert lora_tokenizer == await group.get_lora_tokenizer_async(lora_request)


def test_get_lora_tokenizer(sql_lora_files, tmpdir):
    """Exercise the module-level ``get_lora_tokenizer`` helper.

    Expected results:

    * ``None`` request -> ``None``.
    * Request pointing at ``sql_lora_files`` -> a tokenizer whose
      ``get_added_vocab()`` is non-empty.
    * Request pointing at pytest's fresh ``tmpdir`` -> ``None``.

    The missing-tokenizer cases are checked with ``is None`` rather than
    truthiness. ``not tokenizer`` would consult the tokenizer's ``__len__``,
    so any falsy return value would also slip through.
    """
    # No request at all: nothing should be loaded.
    lora_request = None
    tokenizer = get_lora_tokenizer(lora_request)
    assert tokenizer is None

    # The SQL LoRA files are expected to include a tokenizer with added tokens.
    lora_request = LoRARequest("1", 1, sql_lora_files)
    tokenizer = get_lora_tokenizer(lora_request)
    assert tokenizer is not None
    assert tokenizer.get_added_vocab()

    # A freshly created temp directory contains no tokenizer files.
    lora_request = LoRARequest("1", 1, str(tmpdir))
    tokenizer = get_lora_tokenizer(lora_request)
    assert tokenizer is None