make sure mistral_common not imported for non-mistral models (#12669)
When people use DeepSeek models, they find that they need to resolve a `cv2` version conflict (see https://zhuanlan.zhihu.com/p/21064432691). I added the check and made all imports of `cv2` lazy. --------- Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 95460fc513
commit 20579c0fae
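The change defers heavyweight, conflict-prone imports (`cv2`, `mistral_common`) from module import time to first use. A minimal sketch of the pattern, assuming OpenCV is installed; `load_frame` is an illustrative helper, not a function from this PR:

import numpy.typing as npt


def load_frame(path: str) -> npt.NDArray:
    # Import cv2 only when the function is actually called, so that a plain
    # `import vllm` never pulls in OpenCV for text-only users.
    import cv2
    frame = cv2.imread(path)
    # OpenCV returns BGR; convert to RGB for consistency with PIL-based code.
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)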
@@ -50,9 +50,9 @@ steps:
 - tests/multimodal
 - tests/test_utils
 - tests/worker
-- tests/standalone_tests/lazy_torch_compile.py
+- tests/standalone_tests/lazy_imports.py
 commands:
-- python3 standalone_tests/lazy_torch_compile.py
+- python3 standalone_tests/lazy_imports.py
 - pytest -v -s mq_llm_engine # MQLLMEngine
 - pytest -v -s async_engine # AsyncLLMEngine
 - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
@@ -8,7 +8,17 @@ from contextlib import nullcontext

 from vllm_test_utils import BlameResult, blame

-module_name = "torch._inductor.async_compile"
+# List of modules that should not be imported too early.
+# Lazy import `torch._inductor.async_compile` to avoid creating
+# too many processes before we set the number of compiler threads.
+# Lazy import `cv2` to avoid bothering users who only use text models.
+# `cv2` can easily mess up the environment.
+module_names = ["torch._inductor.async_compile", "cv2"]
+
+
+def any_module_imported():
+    return any(module_name in sys.modules for module_name in module_names)
+
+
 # In CI, we only check finally if the module is imported.
 # If it is indeed imported, we can rerun the test with `use_blame=True`,
@@ -16,8 +26,7 @@ module_name = "torch._inductor.async_compile"
 # and help find the root cause.
 # We don't run it in CI by default because it is slow.
 use_blame = False
-context = blame(
-    lambda: module_name in sys.modules) if use_blame else nullcontext()
+context = blame(any_module_imported) if use_blame else nullcontext()
 with context as result:
     import vllm  # noqa
@@ -25,6 +34,6 @@ if use_blame:
     assert isinstance(result, BlameResult)
     print(f"the first import location is:\n{result.trace_stack}")

-assert module_name not in sys.modules, (
-    f"Module {module_name} is imported. To see the first"
+assert not any_module_imported(), (
+    f"Some of the modules in {module_names} are imported. To see the first"
     f" import location, run the test with `use_blame=True`.")
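The test above relies on `vllm_test_utils.blame` to capture the stack trace of the first offending import; the core check itself is just a `sys.modules` membership test after importing the package. A standalone sketch of that check, independent of `vllm_test_utils` and using the same module list as the test:

import sys

# Modules that should stay out of sys.modules after `import vllm`.
forbidden = ["torch._inductor.async_compile", "cv2"]

import vllm  # noqa: E402  # the import under test

leaked = [name for name in forbidden if name in sys.modules]
assert not leaked, f"eagerly imported: {leaked}"
print("no forbidden modules were imported eagerly")

The `blame` helper wraps this check so the first import location can be printed; it is disabled in CI by default because tracing every import is slow.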
@@ -6,7 +6,6 @@ from io import BytesIO
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Optional

-import cv2
 import numpy as np
 import numpy.typing as npt
 from PIL import Image
@@ -95,6 +94,8 @@ def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
     new_height, new_width = size
     resized_frames = np.empty((num_frames, new_height, new_width, channels),
                               dtype=frames.dtype)
+    # lazy import cv2 to avoid bothering users who only use text models
+    import cv2
     for i, frame in enumerate(frames):
         resized_frame = cv2.resize(frame, (new_width, new_height))
         resized_frames[i] = resized_frame
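A related variation on the same idea, sketched here as a hypothetical helper (`_require_cv2` is not part of this PR): deferring the import also gives a chance to raise a clearer error when OpenCV is absent, since only the video code paths need it.

def _require_cv2():
    # Defer the OpenCV import; fail with an actionable message if missing.
    try:
        import cv2
    except ImportError as e:
        raise ImportError(
            "OpenCV is required for video inputs; install "
            "opencv-python-headless") from e
    return cv2


# usage: cv2 = _require_cv2(); cv2.resize(frame, (width, height))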
@@ -8,21 +8,18 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast

 import huggingface_hub
 from huggingface_hub import HfApi, hf_hub_download
-from mistral_common.protocol.instruct.request import ChatCompletionRequest
-from mistral_common.tokens.tokenizers.base import SpecialTokens
-# yapf: disable
-from mistral_common.tokens.tokenizers.mistral import (
-    MistralTokenizer as PublicMistralTokenizer)
-# yapf: enable
-from mistral_common.tokens.tokenizers.sentencepiece import (
-    SentencePieceTokenizer)
-from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
-                                                     Tekkenizer)

 from vllm.logger import init_logger
 from vllm.utils import is_list_of

 if TYPE_CHECKING:
+    # make sure `mistral_common` is lazy imported,
+    # so that users who only use non-mistral models
+    # will not be bothered by the dependency.
+    from mistral_common.protocol.instruct.request import ChatCompletionRequest
+    from mistral_common.tokens.tokenizers.mistral import (
+        MistralTokenizer as PublicMistralTokenizer)
+
     from vllm.entrypoints.chat_utils import ChatCompletionMessageParam

 logger = init_logger(__name__)
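For readers unfamiliar with the trick above: imports placed under `if TYPE_CHECKING:` are seen by static type checkers but never executed at runtime, so the imported names must then be referenced as string (forward-reference) annotations. A generic sketch of the pattern; the module and class names are placeholders, not from vllm:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by type checkers (mypy, pyright), never at runtime.
    from heavy_optional_dep import HeavyClient


def make_request(client: "HeavyClient") -> str:
    # The quoted annotation keeps runtime free of the heavy dependency;
    # the object is simply used duck-typed here.
    return client.request("ping")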
@@ -33,7 +30,7 @@ class Encoding:
     input_ids: Union[List[int], List[List[int]]]


-def maybe_serialize_tool_calls(request: ChatCompletionRequest):
+def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
    # SEE: https://github.com/vllm-project/vllm/pull/9951
    # Credits go to: @gcalmettes
    # NOTE: There is currently a bug in pydantic where attributes
@@ -108,12 +105,16 @@ def find_tokenizer_file(files: List[str]):

 class MistralTokenizer:

-    def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
+    def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
         self.mistral = tokenizer
         self.instruct = tokenizer.instruct_tokenizer

         tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
+        from mistral_common.tokens.tokenizers.tekken import (
+            SpecialTokenPolicy, Tekkenizer)
         self.is_tekken = isinstance(tokenizer_, Tekkenizer)
+        from mistral_common.tokens.tokenizers.sentencepiece import (
+            SentencePieceTokenizer)
         self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer)
         if self.is_tekken:
             # Make sure special tokens will not raise
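A natural worry with imports inside `__init__` and other methods is repeated cost. In practice Python caches modules in `sys.modules`, so only the first import pays the load price and later ones are essentially a dictionary lookup. A small sketch demonstrating this (timings are illustrative and machine-dependent):

import sys
import time


def timed_import(name: str) -> float:
    start = time.perf_counter()
    __import__(name)  # same machinery an `import` statement uses
    return time.perf_counter() - start


first = timed_import("json")   # real module load (unless already cached)
second = timed_import("json")  # served from the sys.modules cache
print(f"first: {first:.6f}s, second: {second:.6f}s")
print("cached:", "json" in sys.modules)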
@@ -153,6 +154,8 @@ class MistralTokenizer:
         assert Path(
             path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}"

+        from mistral_common.tokens.tokenizers.mistral import (
+            MistralTokenizer as PublicMistralTokenizer)
         mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
         return cls(mistral_tokenizer)

@@ -181,6 +184,8 @@ class MistralTokenizer:
     # by the guided structured output backends.
     @property
     def all_special_tokens_extended(self) -> List[str]:
+        from mistral_common.tokens.tokenizers.base import SpecialTokens
+
         # tekken defines its own extended special tokens list
         if hasattr(self.tokenizer, "SPECIAL_TOKENS"):
             special_tokens = self.tokenizer.SPECIAL_TOKENS
@@ -284,6 +289,8 @@ class MistralTokenizer:
         if last_message["role"] == "assistant":
             last_message["prefix"] = True

+        from mistral_common.protocol.instruct.request import (
+            ChatCompletionRequest)
         request = ChatCompletionRequest(messages=messages,
                                         tools=tools)  # type: ignore[type-var]
         encoded = self.mistral.encode_chat_completion(request)
@@ -292,6 +299,7 @@ class MistralTokenizer:
         return encoded.tokens

     def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        from mistral_common.tokens.tokenizers.base import SpecialTokens
         if self.is_tekken:
             tokens = [
                 t for t in tokens
@@ -363,6 +371,8 @@ class MistralTokenizer:
         ids: List[int],
         skip_special_tokens: bool = True,
     ) -> List[str]:
+        from mistral_common.tokens.tokenizers.base import SpecialTokens
+
         # TODO(Patrick) - potentially allow special tokens to not be skipped
         assert (
             skip_special_tokens