bugfix: Fix signature mismatch in benchmark's get_tokenizer
function (#11982)
Signed-off-by: elijah <f1renze.142857@gmail.com>
This commit is contained in:
parent
a7d59688fb
commit
c6db21313c
@ -417,14 +417,35 @@ def get_model(pretrained_model_name_or_path: str) -> str:
|
||||
|
||||
|
||||
def get_tokenizer(
|
||||
pretrained_model_name_or_path: str, trust_remote_code: bool
|
||||
pretrained_model_name_or_path: str,
|
||||
tokenizer_mode: str = "auto",
|
||||
trust_remote_code: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
||||
if pretrained_model_name_or_path is not None and not os.path.exists(
|
||||
pretrained_model_name_or_path):
|
||||
pretrained_model_name_or_path = get_model(
|
||||
pretrained_model_name_or_path)
|
||||
return AutoTokenizer.from_pretrained(pretrained_model_name_or_path,
|
||||
trust_remote_code=trust_remote_code)
|
||||
if tokenizer_mode == "slow":
|
||||
if kwargs.get("use_fast", False):
|
||||
raise ValueError(
|
||||
"Cannot use the fast tokenizer in slow tokenizer mode.")
|
||||
kwargs["use_fast"] = False
|
||||
if tokenizer_mode == "mistral":
|
||||
try:
|
||||
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||
except ImportError as e:
|
||||
raise ImportError("MistralTokenizer requires vllm package.\n"
|
||||
"Please install it with `pip install vllm` "
|
||||
"to use mistral tokenizer mode.") from e
|
||||
return MistralTokenizer.from_pretrained(
|
||||
str(pretrained_model_name_or_path))
|
||||
else:
|
||||
return AutoTokenizer.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
ASYNC_REQUEST_FUNCS = {
|
||||
|
Loading…
x
Reference in New Issue
Block a user