Use slow tokenizer for LLaMA (#84)
parent add055e151
commit 85eb631839
@@ -7,12 +7,12 @@ from typing import List, Dict, Optional
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
 import ray
-from transformers import AutoTokenizer
 import uvicorn
 
 from cacheflow.core.server import (Server, add_server_arguments,
                                    process_server_arguments,
                                    initialize_cluster)
+from cacheflow.frontend.utils import get_tokenizer
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence, SequenceGroup
 from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
@@ -44,7 +44,7 @@ class FastAPIServer:
     ):
         self.block_size = block_size
 
-        self.tokenizer = AutoTokenizer.from_pretrained(model)
+        self.tokenizer = get_tokenizer(model)
         self.seq_group_counter = Counter()
         self.seq_counter = Counter()
         if server_use_ray:
@@ -1,8 +1,7 @@
 import time
 from typing import List, Optional, Tuple
 
-from transformers import AutoTokenizer
-
+from cacheflow.frontend.utils import get_tokenizer
 from cacheflow.logger import init_logger
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence, SequenceGroup
@@ -21,7 +20,7 @@ class SimpleFrontend:
     ) -> None:
         self.block_size = block_size
 
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.tokenizer = get_tokenizer(model_name)
         self.seq_group_counter = Counter()
         self.seq_counter = Counter()
         self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []
cacheflow/frontend/utils.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+from typing import Union
+
+from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
+                          PreTrainedTokenizerFast)
+
+
+_MODEL_TYPES_WITH_SLOW_TOKENIZER = [
+    # LLaMA fast tokenizer has a bug related to protobuf.
+    # See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
+    "llama",
+]
+
+
+def get_tokenizer(
+    model_name: str,
+    *args,
+    **kwargs,
+) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    config = AutoConfig.from_pretrained(model_name)
+    if config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
+        kwargs["use_fast"] = False
+    return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)
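For reference, a minimal usage sketch of the new helper (not part of the commit): get_tokenizer forwards any positional and keyword arguments to AutoTokenizer.from_pretrained, but forces use_fast=False whenever the model type reported by AutoConfig is listed in _MODEL_TYPES_WITH_SLOW_TOKENIZER (currently only "llama"). The model names below are illustrative examples, not taken from this diff.

from cacheflow.frontend.utils import get_tokenizer

# LLaMA-family model: the helper overrides use_fast to False, so the slow
# (SentencePiece-based) tokenizer is returned, avoiding the protobuf bug.
# "huggyllama/llama-7b" is only an example model name.
llama_tokenizer = get_tokenizer("huggyllama/llama-7b")

# Non-LLaMA model: kwargs pass through unchanged and the fast tokenizer is
# used when available. "facebook/opt-125m" is likewise just an example.
opt_tokenizer = get_tokenizer("facebook/opt-125m", use_fast=True)

token_ids = llama_tokenizer("Hello, world!").input_ids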