make sure mistral_common not imported for non-mistral models (#12669)
When people use DeepSeek models, they have to resolve a `cv2` version conflict (see https://zhuanlan.zhihu.com/p/21064432691). This change adds a CI check and makes all imports of `cv2` lazy.

---------

Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 95460fc513
commit 20579c0fae
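The fix relies on the standard lazy-import idiom: heavy optional dependencies are imported inside the functions that need them instead of at module scope, so `import vllm` alone never loads them. Below is a minimal sketch of the idea, not the actual vLLM code; the `resize_frame` helper and the module list are illustrative only.

```python
import sys


def resize_frame(frame, size):
    # Hypothetical helper: OpenCV is imported only when a frame is actually
    # resized, so text-only workloads never pay for (or conflict with) cv2.
    import cv2
    new_width, new_height = size
    return cv2.resize(frame, (new_width, new_height))


# Optional heavy dependencies stay out of sys.modules until first use.
for name in ("cv2", "mistral_common"):
    assert name not in sys.modules
```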
@@ -50,9 +50,9 @@ steps:
   - tests/multimodal
   - tests/test_utils
   - tests/worker
-  - tests/standalone_tests/lazy_torch_compile.py
+  - tests/standalone_tests/lazy_imports.py
   commands:
-  - python3 standalone_tests/lazy_torch_compile.py
+  - python3 standalone_tests/lazy_imports.py
   - pytest -v -s mq_llm_engine # MQLLMEngine
   - pytest -v -s async_engine # AsyncLLMEngine
   - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
@@ -8,7 +8,17 @@ from contextlib import nullcontext
 
 from vllm_test_utils import BlameResult, blame
 
-module_name = "torch._inductor.async_compile"
+# List of modules that should not be imported too early.
+# Lazy import `torch._inductor.async_compile` to avoid creating
+# too many processes before we set the number of compiler threads.
+# Lazy import `cv2` to avoid bothering users who only use text models.
+# `cv2` can easily mess up the environment.
+module_names = ["torch._inductor.async_compile", "cv2"]
+
+
+def any_module_imported():
+    return any(module_name in sys.modules for module_name in module_names)
+
 
 # In CI, we only check finally if the module is imported.
 # If it is indeed imported, we can rerun the test with `use_blame=True`,
@@ -16,8 +26,7 @@ module_name = "torch._inductor.async_compile"
 # and help find the root cause.
 # We don't run it in CI by default because it is slow.
 use_blame = False
-context = blame(
-    lambda: module_name in sys.modules) if use_blame else nullcontext()
+context = blame(any_module_imported) if use_blame else nullcontext()
 with context as result:
     import vllm  # noqa
 
@@ -25,6 +34,6 @@ if use_blame:
     assert isinstance(result, BlameResult)
     print(f"the first import location is:\n{result.trace_stack}")
 
-assert module_name not in sys.modules, (
-    f"Module {module_name} is imported. To see the first"
+assert not any_module_imported(), (
+    f"Some the modules in {module_names} are imported. To see the first"
     f" import location, run the test with `use_blame=True`.")
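For reference, `vllm_test_utils.blame` (used above) is a context manager that watches a predicate while the wrapped code runs and records the stack trace at the moment the predicate first becomes true. A sketch of how the debugging mode is intended to be used, assuming `vllm_test_utils` is importable as in vLLM's CI setup:

```python
import sys

from vllm_test_utils import BlameResult, blame

# Find which import chain first pulls `cv2` into the process while
# importing vllm; the recorded trace points at the offending module.
with blame(lambda: "cv2" in sys.modules) as result:
    import vllm  # noqa

assert isinstance(result, BlameResult)
print(f"the first import location is:\n{result.trace_stack}")
```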
@@ -6,7 +6,6 @@ from io import BytesIO
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Optional
 
-import cv2
 import numpy as np
 import numpy.typing as npt
 from PIL import Image
@@ -95,6 +94,8 @@ def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
     new_height, new_width = size
     resized_frames = np.empty((num_frames, new_height, new_width, channels),
                               dtype=frames.dtype)
+    # lazy import cv2 to avoid bothering users who only use text models
+    import cv2
     for i, frame in enumerate(frames):
         resized_frame = cv2.resize(frame, (new_width, new_height))
         resized_frames[i] = resized_frame
@@ -8,21 +8,18 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
 
 import huggingface_hub
 from huggingface_hub import HfApi, hf_hub_download
-from mistral_common.protocol.instruct.request import ChatCompletionRequest
-from mistral_common.tokens.tokenizers.base import SpecialTokens
-# yapf: disable
-from mistral_common.tokens.tokenizers.mistral import (
-    MistralTokenizer as PublicMistralTokenizer)
-# yapf: enable
-from mistral_common.tokens.tokenizers.sentencepiece import (
-    SentencePieceTokenizer)
-from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
-                                                      Tekkenizer)
 
 from vllm.logger import init_logger
 from vllm.utils import is_list_of
 
 if TYPE_CHECKING:
+    # make sure `mistral_common` is lazy imported,
+    # so that users who only use non-mistral models
+    # will not be bothered by the dependency.
+    from mistral_common.protocol.instruct.request import ChatCompletionRequest
+    from mistral_common.tokens.tokenizers.mistral import (
+        MistralTokenizer as PublicMistralTokenizer)
+
     from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 
 logger = init_logger(__name__)
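The import hunk above combines two patterns: the `mistral_common` imports move under `if TYPE_CHECKING:` (evaluated only by type checkers), and annotations that refer to them become string literals, so the package is loaded at runtime only inside the methods that actually use it. A generic sketch of the pattern with a hypothetical `heavy_package` dependency:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers only; never executed at runtime.
    from heavy_package import HeavyTokenizer  # hypothetical dependency


class Wrapper:

    # The string annotation keeps `heavy_package` optional at import time.
    def __init__(self, tokenizer: "HeavyTokenizer") -> None:
        self.tokenizer = tokenizer


def load(path: str) -> Wrapper:
    # Deferred import: only callers of `load` pay for the dependency.
    from heavy_package import HeavyTokenizer
    return Wrapper(HeavyTokenizer(path))
```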
@@ -33,7 +30,7 @@ class Encoding:
     input_ids: Union[List[int], List[List[int]]]
 
 
-def maybe_serialize_tool_calls(request: ChatCompletionRequest):
+def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
     # SEE: https://github.com/vllm-project/vllm/pull/9951
     # Credits go to: @gcalmettes
     # NOTE: There is currently a bug in pydantic where attributes
@@ -108,12 +105,16 @@ def find_tokenizer_file(files: List[str]):
 
 class MistralTokenizer:
 
-    def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
+    def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
         self.mistral = tokenizer
         self.instruct = tokenizer.instruct_tokenizer
 
         tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
+        from mistral_common.tokens.tokenizers.tekken import (
+            SpecialTokenPolicy, Tekkenizer)
         self.is_tekken = isinstance(tokenizer_, Tekkenizer)
+        from mistral_common.tokens.tokenizers.sentencepiece import (
+            SentencePieceTokenizer)
         self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer)
         if self.is_tekken:
             # Make sure special tokens will not raise
@@ -153,6 +154,8 @@ class MistralTokenizer:
             assert Path(
                 path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}"
 
+        from mistral_common.tokens.tokenizers.mistral import (
+            MistralTokenizer as PublicMistralTokenizer)
         mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
         return cls(mistral_tokenizer)
 
@@ -181,6 +184,8 @@ class MistralTokenizer:
     # by the guided structured output backends.
     @property
     def all_special_tokens_extended(self) -> List[str]:
+        from mistral_common.tokens.tokenizers.base import SpecialTokens
+
         # tekken defines its own extended special tokens list
         if hasattr(self.tokenizer, "SPECIAL_TOKENS"):
             special_tokens = self.tokenizer.SPECIAL_TOKENS
@@ -284,6 +289,8 @@ class MistralTokenizer:
             if last_message["role"] == "assistant":
                 last_message["prefix"] = True
 
+        from mistral_common.protocol.instruct.request import (
+            ChatCompletionRequest)
         request = ChatCompletionRequest(messages=messages,
                                         tools=tools)  # type: ignore[type-var]
         encoded = self.mistral.encode_chat_completion(request)
@@ -292,6 +299,7 @@ class MistralTokenizer:
         return encoded.tokens
 
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        from mistral_common.tokens.tokenizers.base import SpecialTokens
         if self.is_tekken:
             tokens = [
                 t for t in tokens
@@ -363,6 +371,8 @@ class MistralTokenizer:
         ids: List[int],
         skip_special_tokens: bool = True,
     ) -> List[str]:
+        from mistral_common.tokens.tokenizers.base import SpecialTokens
+
         # TODO(Patrick) - potentially allow special tokens to not be skipped
         assert (
             skip_special_tokens