Use slow tokenizer for LLaMA (#84)
This commit is contained in:
parent
add055e151
commit
85eb631839
@ -7,12 +7,12 @@ from typing import List, Dict, Optional
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
import ray
|
||||
from transformers import AutoTokenizer
|
||||
import uvicorn
|
||||
|
||||
from cacheflow.core.server import (Server, add_server_arguments,
|
||||
process_server_arguments,
|
||||
initialize_cluster)
|
||||
from cacheflow.frontend.utils import get_tokenizer
|
||||
from cacheflow.sampling_params import SamplingParams
|
||||
from cacheflow.sequence import Sequence, SequenceGroup
|
||||
from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
|
||||
@ -44,7 +44,7 @@ class FastAPIServer:
|
||||
):
|
||||
self.block_size = block_size
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model)
|
||||
self.tokenizer = get_tokenizer(model)
|
||||
self.seq_group_counter = Counter()
|
||||
self.seq_counter = Counter()
|
||||
if server_use_ray:
|
||||
|
@ -1,8 +1,7 @@
|
||||
import time
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from cacheflow.frontend.utils import get_tokenizer
|
||||
from cacheflow.logger import init_logger
|
||||
from cacheflow.sampling_params import SamplingParams
|
||||
from cacheflow.sequence import Sequence, SequenceGroup
|
||||
@ -21,7 +20,7 @@ class SimpleFrontend:
|
||||
) -> None:
|
||||
self.block_size = block_size
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.tokenizer = get_tokenizer(model_name)
|
||||
self.seq_group_counter = Counter()
|
||||
self.seq_counter = Counter()
|
||||
self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []
|
||||
|
22
cacheflow/frontend/utils.py
Normal file
22
cacheflow/frontend/utils.py
Normal file
@ -0,0 +1,22 @@
|
||||
from typing import Union
|
||||
|
||||
from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast)
|
||||
|
||||
|
||||
_MODEL_TYPES_WITH_SLOW_TOKENIZER = [
|
||||
# LLaMA fast tokenizer has a bug related to protobuf.
|
||||
# See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
|
||||
"llama",
|
||||
]
|
||||
|
||||
|
||||
def get_tokenizer(
|
||||
model_name: str,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
if config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
|
||||
kwargs["use_fast"] = False
|
||||
return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)
|
Loading…
x
Reference in New Issue
Block a user