# SPDX-License-Identifier: Apache-2.0
from copy import deepcopy
from transformers import AutoTokenizer
from vllm.transformers_utils.tokenizer import get_cached_tokenizer


def test_cached_tokenizer():
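    """Verify that get_cached_tokenizer is a transparent wrapper.

    get_cached_tokenizer caches tokenizer properties (e.g. all_special_ids)
    that transformers would otherwise recompute on every access. Encoding
    output and all special-token metadata must still match the plain HF
    tokenizer, including special tokens added after loading.
    """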
    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"})
    reference_tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<SEP>"]})
    cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer))
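
    # The cached wrapper must behave identically to the reference tokenizer,
    # both for encoding and for every special-token view.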
    assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode(
        "prompt")
    assert set(reference_tokenizer.all_special_ids) == set(
        cached_tokenizer.all_special_ids)
    assert set(reference_tokenizer.all_special_tokens) == set(
        cached_tokenizer.all_special_tokens)
    assert set(reference_tokenizer.all_special_tokens_extended) == set(
        cached_tokenizer.all_special_tokens_extended)