Use slow tokenizer for LLaMA (#84)

Authored by Woosuk Kwon on 2023-05-09 16:03:44 -07:00, committed by GitHub
parent add055e151
commit 85eb631839
3 changed files with 26 additions and 5 deletions


@@ -7,12 +7,12 @@ from typing import List, Dict, Optional
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
 import ray
-from transformers import AutoTokenizer
 import uvicorn
 from cacheflow.core.server import (Server, add_server_arguments,
                                    process_server_arguments,
                                    initialize_cluster)
+from cacheflow.frontend.utils import get_tokenizer
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence, SequenceGroup
 from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
@@ -44,7 +44,7 @@ class FastAPIServer:
     ):
         self.block_size = block_size
-        self.tokenizer = AutoTokenizer.from_pretrained(model)
+        self.tokenizer = get_tokenizer(model)
         self.seq_group_counter = Counter()
         self.seq_counter = Counter()
         if server_use_ray:


@@ -1,8 +1,7 @@
 import time
 from typing import List, Optional, Tuple
-from transformers import AutoTokenizer
+from cacheflow.frontend.utils import get_tokenizer
 from cacheflow.logger import init_logger
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence, SequenceGroup
@@ -21,7 +20,7 @@ class SimpleFrontend:
     ) -> None:
         self.block_size = block_size
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.tokenizer = get_tokenizer(model_name)
         self.seq_group_counter = Counter()
         self.seq_counter = Counter()
         self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []


@@ -0,0 +1,22 @@
+from typing import Union
+
+from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
+                          PreTrainedTokenizerFast)
+
+
+_MODEL_TYPES_WITH_SLOW_TOKENIZER = [
+    # LLaMA fast tokenizer has a bug related to protobuf.
+    # See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
+    "llama",
+]
+
+
+def get_tokenizer(
+    model_name: str,
+    *args,
+    **kwargs,
+) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    config = AutoConfig.from_pretrained(model_name)
+    if config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
+        kwargs["use_fast"] = False
+    return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)
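
For context, a minimal usage sketch of the new helper as the frontends call it. The checkpoint names below ("facebook/opt-125m", "huggyllama/llama-7b") are illustrative placeholders, not part of this commit; any non-LLaMA and LLaMA checkpoints would behave the same way.

from cacheflow.frontend.utils import get_tokenizer

# Non-LLaMA models keep the default behavior of
# AutoTokenizer.from_pretrained, i.e. the fast (Rust) tokenizer.
opt_tokenizer = get_tokenizer("facebook/opt-125m")  # placeholder model name

# For LLaMA checkpoints, config.model_type == "llama", so the helper
# forces use_fast=False and returns the slow (Python) tokenizer,
# avoiding the protobuf-related bug in the fast LLaMA tokenizer.
llama_tokenizer = get_tokenizer("huggyllama/llama-7b")  # placeholder model name

token_ids = llama_tokenizer.encode("Hello, world!")
text = llama_tokenizer.decode(token_ids)

Note that because kwargs["use_fast"] is set unconditionally for listed model types, a caller passing use_fast=True for a LLaMA model is still handed the slow tokenizer.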