make sure mistral_common not imported for non-mistral models (#12669)

When people use DeepSeek models, they find that they need to resolve a cv2 version conflict; see https://zhuanlan.zhihu.com/p/21064432691 .

I added the check and made all imports of `cv2` lazy.
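
A minimal sketch of the deferred-import pattern this change applies: move `import cv2` from module level into the function that actually needs it, so text-only users never load, or conflict with, OpenCV. The `resize_frame` helper below is illustrative only, not the actual vLLM function:

```python
import numpy as np
import numpy.typing as npt


def resize_frame(frame: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
    """Illustrative helper; vLLM applies the same idea in its video utils."""
    # Deferred import: users who never touch video inputs never import cv2,
    # so an incompatible opencv-python install cannot break text-only usage.
    import cv2
    new_height, new_width = size
    return cv2.resize(frame, (new_width, new_height))
```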

---------

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao 2025-02-03 13:40:25 +08:00 committed by GitHub
parent 95460fc513
commit 20579c0fae
4 changed files with 40 additions and 20 deletions


@@ -50,9 +50,9 @@ steps:
   - tests/multimodal
   - tests/test_utils
   - tests/worker
-  - tests/standalone_tests/lazy_torch_compile.py
+  - tests/standalone_tests/lazy_imports.py
   commands:
-  - python3 standalone_tests/lazy_torch_compile.py
+  - python3 standalone_tests/lazy_imports.py
   - pytest -v -s mq_llm_engine # MQLLMEngine
   - pytest -v -s async_engine # AsyncLLMEngine
   - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py


@@ -8,7 +8,17 @@ from contextlib import nullcontext
 from vllm_test_utils import BlameResult, blame

-module_name = "torch._inductor.async_compile"
+# List of modules that should not be imported too early.
+# Lazy import `torch._inductor.async_compile` to avoid creating
+# too many processes before we set the number of compiler threads.
+# Lazy import `cv2` to avoid bothering users who only use text models.
+# `cv2` can easily mess up the environment.
+module_names = ["torch._inductor.async_compile", "cv2"]
+
+
+def any_module_imported():
+    return any(module_name in sys.modules for module_name in module_names)

 # In CI, we only check finally if the module is imported.
 # If it is indeed imported, we can rerun the test with `use_blame=True`,
@@ -16,8 +26,7 @@ module_name = "torch._inductor.async_compile"
 # and help find the root cause.
 # We don't run it in CI by default because it is slow.
 use_blame = False
-context = blame(
-    lambda: module_name in sys.modules) if use_blame else nullcontext()
+context = blame(any_module_imported) if use_blame else nullcontext()

 with context as result:
     import vllm  # noqa
@@ -25,6 +34,6 @@ if use_blame:
     assert isinstance(result, BlameResult)
     print(f"the first import location is:\n{result.trace_stack}")

-assert module_name not in sys.modules, (
-    f"Module {module_name} is imported. To see the first"
+assert not any_module_imported(), (
+    f"Some of the modules in {module_names} are imported. To see the first"
     f" import location, run the test with `use_blame=True`.")


@@ -6,7 +6,6 @@ from io import BytesIO
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Optional

-import cv2
 import numpy as np
 import numpy.typing as npt
 from PIL import Image
@@ -95,6 +94,8 @@ def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
     new_height, new_width = size
     resized_frames = np.empty((num_frames, new_height, new_width, channels),
                               dtype=frames.dtype)
+    # lazy import cv2 to avoid bothering users who only use text models
+    import cv2
     for i, frame in enumerate(frames):
         resized_frame = cv2.resize(frame, (new_width, new_height))
         resized_frames[i] = resized_frame


@@ -8,21 +8,18 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
 import huggingface_hub
 from huggingface_hub import HfApi, hf_hub_download
-from mistral_common.protocol.instruct.request import ChatCompletionRequest
-from mistral_common.tokens.tokenizers.base import SpecialTokens
-# yapf: disable
-from mistral_common.tokens.tokenizers.mistral import (
-    MistralTokenizer as PublicMistralTokenizer)
-# yapf: enable
-from mistral_common.tokens.tokenizers.sentencepiece import (
-    SentencePieceTokenizer)
-from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
-                                                     Tekkenizer)

 from vllm.logger import init_logger
 from vllm.utils import is_list_of

 if TYPE_CHECKING:
+    # make sure `mistral_common` is lazy imported,
+    # so that users who only use non-mistral models
+    # will not be bothered by the dependency.
+    from mistral_common.protocol.instruct.request import ChatCompletionRequest
+    from mistral_common.tokens.tokenizers.mistral import (
+        MistralTokenizer as PublicMistralTokenizer)
     from vllm.entrypoints.chat_utils import ChatCompletionMessageParam

 logger = init_logger(__name__)
@@ -33,7 +30,7 @@ class Encoding:
     input_ids: Union[List[int], List[List[int]]]


-def maybe_serialize_tool_calls(request: ChatCompletionRequest):
+def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
     # SEE: https://github.com/vllm-project/vllm/pull/9951
     # Credits go to: @gcalmettes
     # NOTE: There is currently a bug in pydantic where attributes
@@ -108,12 +105,16 @@ def find_tokenizer_file(files: List[str]):
 class MistralTokenizer:

-    def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
+    def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
         self.mistral = tokenizer
         self.instruct = tokenizer.instruct_tokenizer

         tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
+        from mistral_common.tokens.tokenizers.tekken import (
+            SpecialTokenPolicy, Tekkenizer)
         self.is_tekken = isinstance(tokenizer_, Tekkenizer)
+        from mistral_common.tokens.tokenizers.sentencepiece import (
+            SentencePieceTokenizer)
         self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer)
         if self.is_tekken:
             # Make sure special tokens will not raise
@@ -153,6 +154,8 @@ class MistralTokenizer:
             assert Path(
                 path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}"

+        from mistral_common.tokens.tokenizers.mistral import (
+            MistralTokenizer as PublicMistralTokenizer)
         mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
         return cls(mistral_tokenizer)
@@ -181,6 +184,8 @@ class MistralTokenizer:
     # by the guided structured output backends.
     @property
     def all_special_tokens_extended(self) -> List[str]:
+        from mistral_common.tokens.tokenizers.base import SpecialTokens
+
         # tekken defines its own extended special tokens list
         if hasattr(self.tokenizer, "SPECIAL_TOKENS"):
             special_tokens = self.tokenizer.SPECIAL_TOKENS
@@ -284,6 +289,8 @@ class MistralTokenizer:
         if last_message["role"] == "assistant":
             last_message["prefix"] = True

+        from mistral_common.protocol.instruct.request import (
+            ChatCompletionRequest)
         request = ChatCompletionRequest(messages=messages,
                                         tools=tools)  # type: ignore[type-var]
         encoded = self.mistral.encode_chat_completion(request)
@@ -292,6 +299,7 @@ class MistralTokenizer:
         return encoded.tokens

     def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        from mistral_common.tokens.tokenizers.base import SpecialTokens
         if self.is_tekken:
             tokens = [
                 t for t in tokens
@@ -363,6 +371,8 @@ class MistralTokenizer:
         ids: List[int],
         skip_special_tokens: bool = True,
     ) -> List[str]:
+        from mistral_common.tokens.tokenizers.base import SpecialTokens
+
         # TODO(Patrick) - potentially allow special tokens to not be skipped
         assert (
             skip_special_tokens
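
The tokenizer changes above all rely on the standard `TYPE_CHECKING` idiom: the `mistral_common` imports are only evaluated by static type checkers, annotations refer to the classes by string name, and the real imports happen inside the methods that need them. A minimal sketch of the idiom under hypothetical names (`optional_dep` and `HeavyTokenizer` are placeholders, not real packages):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers (mypy, pyright), never at runtime, so
    # the optional dependency is not required just to import this module.
    from optional_dep import HeavyTokenizer


class Wrapper:

    def __init__(self, tokenizer: "HeavyTokenizer") -> None:
        # The quoted annotation is not resolved at runtime, so no import
        # happens here.
        self.tokenizer = tokenizer

    @classmethod
    def from_file(cls, path: str) -> "Wrapper":
        # Deferred import: only callers who actually construct the wrapper
        # need the optional dependency installed.
        from optional_dep import HeavyTokenizer
        return cls(HeavyTokenizer.from_file(path))
```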