Use slow tokenizer for LLaMA (#84)

Woosuk Kwon authored on 2023-05-09 16:03:44 -07:00; committed by GitHub
parent add055e151
commit 85eb631839
3 changed files with 26 additions and 5 deletions


@@ -7,12 +7,12 @@ from typing import List, Dict, Optional
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
 import ray
-from transformers import AutoTokenizer
 import uvicorn
 
 from cacheflow.core.server import (Server, add_server_arguments,
                                    process_server_arguments,
                                    initialize_cluster)
+from cacheflow.frontend.utils import get_tokenizer
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence, SequenceGroup
 from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
@@ -44,7 +44,7 @@ class FastAPIServer:
     ):
         self.block_size = block_size
-        self.tokenizer = AutoTokenizer.from_pretrained(model)
+        self.tokenizer = get_tokenizer(model)
         self.seq_group_counter = Counter()
         self.seq_counter = Counter()
         if server_use_ray:


@@ -1,8 +1,7 @@
 import time
 from typing import List, Optional, Tuple
 
-from transformers import AutoTokenizer
-
+from cacheflow.frontend.utils import get_tokenizer
 from cacheflow.logger import init_logger
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence, SequenceGroup
@@ -21,7 +20,7 @@ class SimpleFrontend:
     ) -> None:
         self.block_size = block_size
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.tokenizer = get_tokenizer(model_name)
         self.seq_group_counter = Counter()
         self.seq_counter = Counter()
         self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []


@@ -0,0 +1,22 @@
+from typing import Union
+
+from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
+                          PreTrainedTokenizerFast)
+
+
+_MODEL_TYPES_WITH_SLOW_TOKENIZER = [
+    # LLaMA fast tokenizer has a bug related to protobuf.
+    # See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
+    "llama",
+]
+
+
+def get_tokenizer(
+    model_name: str,
+    *args,
+    **kwargs,
+) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    config = AutoConfig.from_pretrained(model_name)
+    if config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
+        kwargs["use_fast"] = False
+    return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)
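
For context, a minimal usage sketch of the new helper (not part of this commit; the model names below are illustrative assumptions, any local checkpoint works): for a model whose config reports model_type "llama", get_tokenizer injects use_fast=False so the slow SentencePiece-based tokenizer is returned, while other architectures keep the default fast tokenizer, and any extra arguments are forwarded to AutoTokenizer.from_pretrained unchanged.

# Illustrative usage of cacheflow.frontend.utils.get_tokenizer.
# Model names are assumptions, not taken from this commit.
from transformers import PreTrainedTokenizerFast

from cacheflow.frontend.utils import get_tokenizer

# LLaMA: config.model_type == "llama", so use_fast=False is injected and the
# slow (SentencePiece-based) tokenizer is returned.
llama_tokenizer = get_tokenizer("huggyllama/llama-7b")
assert not isinstance(llama_tokenizer, PreTrainedTokenizerFast)

# Other architectures (e.g. OPT) keep the default fast tokenizer.
opt_tokenizer = get_tokenizer("facebook/opt-125m")
assert isinstance(opt_tokenizer, PreTrainedTokenizerFast)

# Additional args/kwargs are passed through to AutoTokenizer.from_pretrained.
padded_tokenizer = get_tokenizer("facebook/opt-125m", padding_side="left")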