bugfix: Fix signature mismatch in benchmark's get_tokenizer function (#11982)
Signed-off-by: elijah <f1renze.142857@gmail.com>
This commit is contained in:
parent a7d59688fb
commit c6db21313c
@@ -417,14 +417,35 @@ def get_model(pretrained_model_name_or_path: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def get_tokenizer(
|
def get_tokenizer(
|
||||||
pretrained_model_name_or_path: str, trust_remote_code: bool
|
pretrained_model_name_or_path: str,
|
||||||
|
tokenizer_mode: str = "auto",
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
**kwargs,
|
||||||
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
||||||
if pretrained_model_name_or_path is not None and not os.path.exists(
|
if pretrained_model_name_or_path is not None and not os.path.exists(
|
||||||
pretrained_model_name_or_path):
|
pretrained_model_name_or_path):
|
||||||
pretrained_model_name_or_path = get_model(
|
pretrained_model_name_or_path = get_model(
|
||||||
pretrained_model_name_or_path)
|
pretrained_model_name_or_path)
|
||||||
return AutoTokenizer.from_pretrained(pretrained_model_name_or_path,
|
if tokenizer_mode == "slow":
|
||||||
trust_remote_code=trust_remote_code)
|
if kwargs.get("use_fast", False):
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot use the fast tokenizer in slow tokenizer mode.")
|
||||||
|
kwargs["use_fast"] = False
|
||||||
|
if tokenizer_mode == "mistral":
|
||||||
|
try:
|
||||||
|
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError("MistralTokenizer requires vllm package.\n"
|
||||||
|
"Please install it with `pip install vllm` "
|
||||||
|
"to use mistral tokenizer mode.") from e
|
||||||
|
return MistralTokenizer.from_pretrained(
|
||||||
|
str(pretrained_model_name_or_path))
|
||||||
|
else:
|
||||||
|
return AutoTokenizer.from_pretrained(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
ASYNC_REQUEST_FUNCS = {
|
ASYNC_REQUEST_FUNCS = {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user