[Core] Update dtype detection and defaults (#14858)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 8b3e94a357
commit f690372b68
@@ -60,7 +60,7 @@ class TestSetting:
    # embedding model
    TestSetting(
        model="BAAI/bge-multilingual-gemma2",
        model_args=["--task", "embed"],
        model_args=["--task", "embed", "--dtype", "bfloat16"],
        pp_size=1,
        tp_size=1,
        attn_backend="FLASH_ATTN",
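Note: with the new default of `--dtype auto` and the Gemma-specific bfloat16 downcast removed from `_get_and_verify_dtype` (see the vllm/config.py hunks at the end of this diff), a float32 Gemma-2 checkpoint would now resolve to float16, so this test pins bfloat16 explicitly. A hedged sketch of the equivalent offline call (illustration only, not part of the diff):

    from vllm import LLM

    # Explicitly request bfloat16 instead of relying on dtype="auto",
    # which would downcast a float32 checkpoint config to float16.
    llm = LLM(model="BAAI/bge-multilingual-gemma2",
              task="embed",
              dtype="bfloat16")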

@@ -14,8 +14,8 @@ import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
                          BatchFeature)
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          BatchEncoding, BatchFeature)
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from tests.models.utils import (TokensTextLogprobs,
@@ -23,7 +23,7 @@ from tests.models.utils import (TokensTextLogprobs,
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import TaskOption, TokenizerPoolConfig
from vllm.config import TaskOption, TokenizerPoolConfig, _get_and_verify_dtype
from vllm.connections import global_http_connection
from vllm.distributed import (cleanup_dist_env_and_memory,
                              init_distributed_environment,
@@ -34,8 +34,7 @@ from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
                        identity, is_list_of)
from vllm.utils import cuda_device_count_stateless, is_list_of

logger = init_logger(__name__)

@@ -271,14 +270,18 @@ _R = TypeVar("_R")

class HfRunner:

    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
    def get_default_device(self):
        from vllm.platforms import current_platform

        return ("cpu" if current_platform.is_cpu()
                or current_platform.is_openvino() else "cuda")

    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
        if x is None or isinstance(x, (bool, )):
            return x

        if device is None:
            device = "cpu" if current_platform.is_cpu(
            ) or current_platform.is_openvino() else "cuda"
            device = self.device

        if isinstance(x, dict):
            return {k: self.wrap_device(v, device) for k, v in x.items()}
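The refactored `wrap_device` defaults to `self.device`, which `get_default_device` computes once, instead of re-querying the platform on every call. A small usage sketch (here `runner` stands for any `HfRunner` instance and the tensor values are made up):

    import torch

    # Dicts are moved recursively; with no explicit device argument,
    # wrap_device now targets runner.device rather than re-detecting the platform.
    inputs = {"input_ids": torch.tensor([[1, 2, 3]])}
    inputs = runner.wrap_device(inputs)             # moved to runner.device
    cpu_inputs = runner.wrap_device(inputs, "cpu")  # explicit override still works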
@@ -291,45 +294,59 @@ class HfRunner:
    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        dtype: str = "auto",
        *,
        model_kwargs: Optional[dict[str, Any]] = None,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
        postprocess_inputs: Callable[..., BatchEncoding] = identity,
    ) -> None:
        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        self.config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        self.device = self.get_default_device()
        self.dtype = torch_dtype = _get_and_verify_dtype(self.config, dtype)

        model_kwargs = model_kwargs if model_kwargs is not None else {}
        model_kwargs.setdefault("torch_dtype", torch_dtype)

        if is_sentence_transformer:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer
            self.model = self.wrap_device(
                SentenceTransformer(
                    model_name,
                    device="cpu",
                    trust_remote_code=True,
                ).to(dtype=torch_dtype))

            self.model = SentenceTransformer(
                model_name,
                device=self.device,
                model_kwargs=model_kwargs,
                trust_remote_code=True,
            )
        elif is_cross_encoder:
            # Lazy init required for AMD CI
            from sentence_transformers import CrossEncoder
            self.model = CrossEncoder(model_name,
                                      device="cpu",
                                      trust_remote_code=True)
            self.model.model = self.wrap_device(self.model.model)\
                .to(dtype=torch_dtype)

            self.model = CrossEncoder(
                model_name,
                device=self.device,
                automodel_args=model_kwargs,
                trust_remote_code=True,
            )
        else:
            model_kwargs = model_kwargs if model_kwargs is not None else {}
            self.model = self.wrap_device(
                auto_cls.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    trust_remote_code=True,
                    **model_kwargs,
                ))
            model = auto_cls.from_pretrained(
                model_name,
                trust_remote_code=True,
                **model_kwargs,
            )

            if (getattr(model, "quantization_method", None) != "bitsandbytes"
                    and len({p.device
                             for p in model.parameters()}) < 2):
                model = model.to(self.device)

            self.model = model

        if not skip_tokenizer_init:
            self.tokenizer = AutoTokenizer.from_pretrained(
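With the default changed to "auto", `HfRunner` now derives its dtype from the checkpoint's HF config through `_get_and_verify_dtype` rather than hard-coding float16, and moves the model to `get_default_device()` unless it is bitsandbytes-quantized or already spread across devices. A hedged usage sketch (`hf_runner` is the existing fixture; the model name is only an example):

    # Sketch: no dtype argument means "auto", i.e. follow the checkpoint config.
    with hf_runner("facebook/opt-125m") as hf_model:
        assert isinstance(hf_model.dtype, torch.dtype)         # resolved torch dtype
        assert hf_model.model.device.type == hf_model.device   # placed in __init__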
@@ -349,16 +366,13 @@ class HfRunner:
        if skip_tokenizer_init:
            self.tokenizer = self.processor.tokenizer

        self.dtype = dtype
        self.postprocess_inputs = postprocess_inputs

    def get_inputs(
        self,
        prompts: list[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> list[BatchEncoding]:
    ) -> list[Union[BatchFeature, BatchEncoding]]:
        if images is not None:
            assert len(prompts) == len(images)

@@ -368,7 +382,7 @@
        if audios is not None:
            assert len(prompts) == len(audios)

        all_inputs: list[BatchEncoding] = []
        all_inputs: list[Union[BatchFeature, BatchEncoding]] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: dict[str, Any] = {
                "text": prompt,
@@ -384,7 +398,8 @@
                processor_kwargs["sampling_rate"] = sr

            inputs = self.processor(**processor_kwargs)
            inputs = self.postprocess_inputs(inputs, dtype=self.dtype)
            if isinstance(inputs, BatchFeature):
                inputs = inputs.to(dtype=self.dtype)

            all_inputs.append(inputs)

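The per-model input post-processing hook is gone: when the HF processor returns a `BatchFeature`, `get_inputs` casts it in one place with `.to(dtype=self.dtype)`, which in transformers only touches floating-point tensors. A small sketch of that behaviour (shapes are arbitrary):

    import torch
    from transformers import BatchFeature

    feats = BatchFeature({
        "input_ids": torch.ones(1, 4, dtype=torch.long),
        "pixel_values": torch.rand(1, 3, 224, 224),  # float32
    })
    feats = feats.to(dtype=torch.bfloat16)
    assert feats["input_ids"].dtype == torch.long         # integers untouched
    assert feats["pixel_values"].dtype == torch.bfloat16  # floats are cast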
@@ -417,7 +432,7 @@ class HfRunner:
        outputs: list[tuple[list[list[int]], list[str]]] = []
        for inputs in all_inputs:
            output_ids = self.model.generate(
                **self.wrap_device(inputs, device=self.model.device.type),
                **self.wrap_device(inputs),
                use_cache=True,
                **kwargs,
            )
@@ -488,7 +503,7 @@ class HfRunner:
        all_logprobs: list[list[torch.Tensor]] = []
        for inputs in all_inputs:
            output = self.model.generate(
                **self.wrap_device(inputs, device=self.model.device.type),
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
@@ -569,7 +584,7 @@ class HfRunner:

        for inputs in all_inputs:
            output = self.model.generate(
                **self.wrap_device(inputs, device=self.model.device.type),
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
@@ -620,19 +635,15 @@ class HfRunner:
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]

            encoder_inputs = self.wrap_device(
                self.processor(**processor_kwargs),
                device=self.model.device.type,
            )
            encoder_inputs = self.processor(**processor_kwargs)
            encoder_inputs = self.wrap_device(encoder_inputs)

            if decoder_prompt is None:
                decoder_input_ids = None
            else:
                decoder_input_ids = self.wrap_device(
                    self.tokenizer(decoder_prompt,
                                   return_tensors="pt").input_ids,
                    device=self.model.device.type,
                )
                decoder_inputs = self.tokenizer(decoder_prompt,
                                                return_tensors="pt")
                decoder_input_ids = self.wrap_device(decoder_inputs.input_ids)

            output = self.model.generate(
                decoder_input_ids=decoder_input_ids,
@@ -684,6 +695,7 @@ class VllmRunner:
    """
    The default value of some arguments have been modified from
    :class:`~vllm.LLM` as follows:

    - `trust_remote_code`: Set to `True` instead of `False` for convenience.
    - `seed`: Set to `0` instead of `None` for test reproducibility.
    - `max_model_len`: Set to `1024` instead of `None` to reduce memory usage.
@@ -701,10 +713,8 @@ class VllmRunner:
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = True,
        seed: Optional[int] = 0,
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        dtype: str = "auto",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
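`VllmRunner` switches its default from "half" to "auto" as well, so tests that do not care about precision now inherit the checkpoint's torch_dtype, while precision-sensitive tests keep passing `dtype` explicitly. A usage sketch (`vllm_runner` is the existing fixture; the model name is illustrative):

    with vllm_runner("facebook/opt-125m") as vllm_model:
        # dtype follows the HF config
        outputs = vllm_model.generate_greedy(["Hello, my name is"], 8)

    with vllm_runner("facebook/opt-125m", dtype="half") as vllm_model:
        # explicit override still supported
        outputs = vllm_model.generate_greedy(["Hello, my name is"], 8)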
@@ -1110,4 +1120,4 @@ def pytest_collection_modifyitems(config, items):
    skip_optional = pytest.mark.skip(reason="need --optional option to run")
    for item in items:
        if "optional" in item.keywords:
            item.add_marker(skip_optional)
            item.add_marker(skip_optional)

@@ -64,7 +64,6 @@ def test_multi_chat():
def test_chat_multi_image(image_urls: list[str]):
    llm = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        dtype="bfloat16",
        max_model_len=4096,
        max_num_seqs=5,
        enforce_eager=True,

@@ -18,8 +18,6 @@ TEST_AUDIO_URLS = [
@pytest.fixture(scope="module")
def server():
    args = [
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",

@@ -24,8 +24,6 @@ def server():
    args = [
        "--task",
        "generate",
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "32768",
        "--max-num-seqs",

@@ -25,8 +25,6 @@ def server():
    args = [
        "--task",
        "generate",
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",

@@ -28,8 +28,6 @@ def server():
    args = [
        "--task",
        "embed",
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",

@@ -34,7 +34,7 @@ def phi3v_model_config():
        tokenizer=PHI3V_MODEL_ID,
        tokenizer_mode="auto",
        trust_remote_code=True,
        dtype="bfloat16",
        dtype="auto",
        seed=0,
        limit_mm_per_prompt={
            "image": 2,
@@ -58,7 +58,7 @@ def mllama_model_config():
        tokenizer=MLLAMA_MODEL_ID,
        tokenizer_mode="auto",
        trust_remote_code=True,
        dtype="bfloat16",
        dtype="auto",
        seed=0,
        limit_mm_per_prompt={
            "image": 2,
@@ -669,7 +669,7 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
        tokenizer=MLLAMA_MODEL_ID,
        tokenizer_mode="auto",
        trust_remote_code=True,
        dtype="bfloat16",
        dtype="auto",
        seed=0,
        limit_mm_per_prompt={
            "image": 2,

@@ -5,11 +5,10 @@ from typing import Optional
import numpy as np
import pytest
import pytest_asyncio
from transformers import AutoModel, AutoTokenizer, BatchEncoding
from transformers import AutoModel, AutoTokenizer

from vllm.multimodal.audio import resample_audio
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

from ....conftest import HfRunner, VllmRunner
from ....utils import RemoteOpenAIServer
@@ -107,8 +106,6 @@ def run_test(
    **kwargs,
):
    """Inference result should be the same between hf and vllm."""
    torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
@@ -124,15 +121,7 @@ def run_test(
        for vllm_prompt, _, audio in prompts_and_audios
    ]

    def process(hf_inputs: BatchEncoding, **kwargs):
        hf_inputs["audio_values"] = hf_inputs["audio_values"] \
            .to(torch_dtype)  # type: ignore
        return hf_inputs

    with hf_runner(model,
                   dtype=dtype,
                   postprocess_inputs=process,
                   auto_cls=AutoModel) as hf_model:
    with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model:
        hf_outputs_per_audio = [
            hf_model.generate_greedy_logprobs_limit(
                [hf_prompt],
@@ -122,9 +122,6 @@ VLM_TEST_SETTINGS = {
            "cherry_blossom": "What is in the picture?",
        }),
        auto_cls=AutoModelForImageTextToText,
        postprocess_inputs=model_utils.cast_dtype_post_processor(
            "pixel_values"
        ),
        vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
        dtype="bfloat16",
        marks=[pytest.mark.skip(reason="vLLM does not support PrefixLM attention mask")],  # noqa: E501
@@ -179,7 +176,6 @@ VLM_TEST_SETTINGS = {
    # "cherry_blossom": "<vlm_image>Please infer the season with reason.",  # noqa: E501
    # }),
    # multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.",  # noqa: E501
    # postprocess_inputs=model_utils.cast_dtype_post_processor("pixel_values"),  # noqa: E501
    # stop_str=["<|im_end|>"],
    # image_size_factors=[(0.10, 0.15)],
    # max_tokens=64,
@@ -200,9 +196,6 @@ VLM_TEST_SETTINGS = {
        max_model_len=4096,
        max_num_seqs=2,
        auto_cls=AutoModelForImageTextToText,
        postprocess_inputs=model_utils.cast_dtype_post_processor(
            "pixel_values"
        ),
        # For chameleon, we only compare the sequences
        vllm_output_post_proc = lambda vllm_output, model: vllm_output[:2],
        hf_output_post_proc = lambda hf_output, model: hf_output[:2],
@@ -222,7 +215,6 @@ VLM_TEST_SETTINGS = {
        }),
        multi_image_prompt="image_1:<image>\nimage_2:<image>\nWhich image can we see the car and the tower?",  # noqa: E501
        patch_hf_runner=model_utils.deepseekvl2_patch_hf_runner,
        postprocess_inputs=model_utils.cast_dtype_post_processor("images"),
        hf_output_post_proc=model_utils.deepseekvl2_trunc_hf_output,
        stop_str=["<|end▁of▁sentence|>", "<|begin▁of▁sentence|>"],  # noqa: E501
        image_size_factors=[(), (1.0, ), (1.0, 1.0, 1.0), (0.1, 0.5, 1.0)],
@@ -258,7 +250,6 @@ VLM_TEST_SETTINGS = {
        max_model_len=4096,
        max_num_seqs=2,
        auto_cls=AutoModelForImageTextToText,
        dtype="bfloat16",
        vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
        patch_hf_runner=model_utils.gemma3_patch_hf_runner,
    ),
@@ -272,7 +263,6 @@ VLM_TEST_SETTINGS = {
        }),
        max_model_len=2048,
        max_num_seqs=2,
        dtype="bfloat16",
        get_stop_token_ids=lambda tok: [151329, 151336, 151338],
        patch_hf_runner=model_utils.glm4v_patch_hf_runner,
        # The image embeddings match with HF but the outputs of the language
@@ -295,7 +285,6 @@ VLM_TEST_SETTINGS = {
        }),
        multi_image_prompt="Image-1: <image>\nImage-2: <image>\nDescribe the two images in short.",  # noqa: E501
        max_model_len=8192,
        dtype="bfloat16",
        use_tokenizer_eos=True,
        num_logprobs=10,
        patch_hf_runner=model_utils.h2ovl_patch_hf_runner,
@@ -324,10 +313,6 @@ VLM_TEST_SETTINGS = {
        }),
        multi_image_prompt="Image-1: <image>\nImage-2: <image>\nDescribe the two images in short.",  # noqa: E501
        max_model_len=4096,
        # NOTE: Mono-InternVL-2B doesn't work with fp16,
        # it will result NaN during inference.
        # See: https://huggingface.co/OpenGVLab/Mono-InternVL-2B/discussions/9
        dtype="bfloat16",
        use_tokenizer_eos=True,
        patch_hf_runner=model_utils.internvl_patch_hf_runner,
    ),
@@ -351,9 +336,6 @@ VLM_TEST_SETTINGS = {
        prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n",  # noqa: E501
        num_video_frames=16,
        max_model_len=16384,
        postprocess_inputs=model_utils.cast_dtype_post_processor(
            "pixel_values_videos"
        ),
        auto_cls=AutoModelForVision2Seq,
        vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output,
        custom_test_opts=[CustomTestOptions(
@@ -378,9 +360,6 @@ VLM_TEST_SETTINGS = {
        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
        prompt_formatter=lambda img_prompt: f"<|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",  # noqa: E501
        max_model_len=4096,
        postprocess_inputs=model_utils.cast_dtype_post_processor(
            "pixel_values"
        ),
        get_stop_token_ids=lambda tok: [128009],
        auto_cls=AutoModelForImageTextToText,
        vllm_output_post_proc=model_utils.mantis_vllm_to_hf_output,
@@ -400,8 +379,8 @@ VLM_TEST_SETTINGS = {
        max_model_len=4096,
        max_num_seqs=2,
        get_stop_token_ids=lambda tok: [tok.eos_id, tok.eot_id],
        postprocess_inputs=model_utils.wrap_inputs_post_processor,
        hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
        patch_hf_runner=model_utils.minicpmv_25_patch_hf_runner,
    ),
    "minicpmo_26": VLMTestInfo(
        models=["openbmb/MiniCPM-o-2_6"],
@@ -411,11 +390,8 @@ VLM_TEST_SETTINGS = {
        max_model_len=4096,
        max_num_seqs=2,
        get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']),  # noqa: E501
        postprocess_inputs=model_utils.ignore_inputs_post_processor(
            "image_sizes"
        ),
        hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
        patch_hf_runner=model_utils.minicpmo_patch_hf_runner
        patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner,
    ),
    "minicpmv_26": VLMTestInfo(
        models=["openbmb/MiniCPM-V-2_6"],
@@ -425,10 +401,8 @@ VLM_TEST_SETTINGS = {
        max_model_len=4096,
        max_num_seqs=2,
        get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']),  # noqa: E501
        postprocess_inputs=model_utils.ignore_inputs_post_processor(
            "image_sizes"
        ),
        hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
        patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
    ),
    "molmo": VLMTestInfo(
        models=["allenai/Molmo-7B-D-0924"],
@@ -437,7 +411,6 @@ VLM_TEST_SETTINGS = {
        max_model_len=4096,
        max_num_seqs=2,
        patch_hf_runner=model_utils.molmo_patch_hf_runner,
        postprocess_inputs=model_utils.molmo_post_processor,
    ),
    # Tests for phi3v currently live in another file because of a bug in
    # transformers. Once this issue is fixed, we can enable them here instead.
@@ -482,9 +455,6 @@ VLM_TEST_SETTINGS = {
        prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:",
        max_model_len=4096,
        auto_cls=AutoModelForImageTextToText,
        postprocess_inputs=model_utils.cast_dtype_post_processor(
            "pixel_values"
        ),
        vllm_output_post_proc = lambda vllm_output, model: vllm_output[:2],
        hf_output_post_proc = lambda hf_output, model: hf_output[:2],
        comparator=check_outputs_equal,
@@ -529,9 +499,6 @@ VLM_TEST_SETTINGS = {
        test_type=VLMTestType.CUSTOM_INPUTS,
        max_model_len=16384,
        max_num_seqs=2,
        postprocess_inputs=model_utils.cast_dtype_post_processor(
            "pixel_values"
        ),
        auto_cls=AutoModelForVision2Seq,
        vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output,
        custom_test_opts=[CustomTestOptions(

@@ -4,7 +4,6 @@ from typing import Any, Callable, Optional, Union

import torch
from PIL.Image import Image
from transformers import BatchEncoding
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from vllm.config import TaskOption
@@ -31,7 +30,6 @@ def run_test(
    vllm_output_post_proc: Optional[Callable[[RunnerOutput, str], Any]],
    auto_cls: type[_BaseAutoModelClass],
    use_tokenizer_eos: bool,
    postprocess_inputs: Callable[[BatchEncoding], BatchEncoding],
    comparator: Callable[..., None],
    get_stop_token_ids: Optional[Callable[[AnyTokenizer], list[int]]],
    stop_str: Optional[list[str]],
@@ -101,7 +99,6 @@ def run_test(
    hf_model = hf_runner(model,
                         dtype=dtype,
                         auto_cls=auto_cls,
                         postprocess_inputs=postprocess_inputs,
                         model_kwargs=hf_model_kwargs)

    # Some models need to patch things like the model processor, e.g., internvl

@@ -6,16 +6,15 @@ typically specific to a small subset of models.
import re
import types
from pathlib import PosixPath
from typing import Callable, Optional, Union
from typing import Optional, Union

import torch
from PIL.Image import Image
from transformers import (AutoConfig, AutoTokenizer, BatchEncoding,
from transformers import (AutoConfig, AutoTokenizer, BatchFeature,
                          GenerationConfig)

from vllm.sequence import SampleLogprobs
from vllm.transformers_utils.tokenizer import patch_padding_side
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

from .....conftest import HfRunner, ImageAsset, _ImageAssets
from .types import RunnerOutput
@@ -211,40 +210,6 @@ def get_llava_embeddings(image_assets: _ImageAssets):
    return [asset.image_embeds for asset in image_assets]


####### postprocessors to run on HF BatchEncoding
def cast_dtype_post_processor(
        hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]:
    """Gets a handle to a post processor which converts a given key into a
    target data type."""

    def process(hf_inputs: BatchEncoding, dtype: str):
        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
        hf_inputs[hf_inp_key] = hf_inputs[hf_inp_key].to(torch_dtype)
        return hf_inputs

    return process


def ignore_inputs_post_processor(
        hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]:
    """Gets a handle to a post processor which ignores a given key."""

    def process(hf_inputs: BatchEncoding, dtype: str):
        del hf_inputs[hf_inp_key]
        return hf_inputs

    return process


def wrap_inputs_post_processor(hf_inputs: BatchEncoding, dtype: str):
    return {"model_inputs": hf_inputs}


def molmo_post_processor(hf_inputs: BatchEncoding, dtype: str):
    hf_inputs = cast_dtype_post_processor("images")(hf_inputs, dtype)
    return {k: v.unsqueeze(0) for k, v in hf_inputs.items()}


####### Prompt path encoders for models that need models on disk
def qwen_prompt_path_encoder(
        tmp_path: PosixPath, prompt: str, assets: Union[list[ImageAsset],
@@ -295,8 +260,7 @@ def deepseekvl2_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
            for k in inputs.keys()  # noqa
            if k not in ("seq_lens", "sft_format")
        }
        inputs = BatchEncoding(data=inputs, tensor_type="pt")
        return inputs
        return BatchFeature(data=inputs, tensor_type="pt")

    hf_model.processor = processor
    hf_model.model.get_output_embeddings = lambda: \
@@ -529,10 +493,52 @@ def mantis_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    return hf_model


def minicpmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
def minicpmv_25_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    orig_generate = hf_model.model.generate

    def _generate(self, *args, **kwargs):
    def _generate(
        self,
        *args,
        input_ids=None,
        pixel_values=None,
        image_sizes=None,
        image_bound=None,
        tgt_sizes=None,
        **kwargs,
    ):
        model_inputs = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "image_sizes": image_sizes,
            "image_bound": image_bound,
            "tgt_sizes": tgt_sizes,
        }
        for k in list(model_inputs.keys()):
            if model_inputs[k] is None:
                model_inputs.pop(k)

        return orig_generate(model_inputs, *args, decode_text=False, **kwargs)

    hf_model.model.generate = types.MethodType(_generate, hf_model.model)

    return hf_model


def minicpmo_26_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    orig_generate = hf_model.model.generate

    def _generate(self, *args, image_sizes=None, **kwargs):
        return orig_generate(*args, decode_text=False, **kwargs)

    hf_model.model.generate = types.MethodType(_generate, hf_model.model)

    return hf_model


def minicpmv_26_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    orig_generate = hf_model.model.generate

    def _generate(self, *args, image_sizes=None, **kwargs):
        return orig_generate(*args, decode_text=False, **kwargs)

    hf_model.model.generate = types.MethodType(_generate, hf_model.model)
@@ -551,10 +557,11 @@ def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:

    def _generate(self, max_new_tokens=None, do_sample=None, **kwargs):
        batch = {
            k: kwargs.pop(k)
            k: kwargs.pop(k).unsqueeze(0)
            for k in ("input_ids", "images", "image_input_idx", "image_masks")
            if k in kwargs
        }
        batch = BatchFeature(batch).to(dtype=self.dtype)

        return self.generate_from_batch(
            batch,

@@ -8,13 +8,12 @@ from typing import Any, Callable, NamedTuple, Optional, Union
import torch
from PIL.Image import Image
from pytest import MarkDecorator
from transformers import AutoModelForCausalLM, BatchEncoding
from transformers import AutoModelForCausalLM
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from vllm.config import TaskOption
from vllm.sequence import SampleLogprobs
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import identity

from .....conftest import IMAGE_ASSETS, HfRunner, ImageAsset, _ImageAssets
from ....utils import check_logprobs_close
@@ -110,11 +109,6 @@ class VLMTestInfo(NamedTuple):
    # Indicates we should explicitly pass the EOS from the tokenizer
    use_tokenizer_eos: bool = False
    auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM
    # Callable to pass to the HF runner to run on inputs; for now, we also pass
    # the data type to input post processing, because almost all of the uses of
    # postprocess_inputs are to fix the data types of BatchEncoding values.
    postprocess_inputs: Callable[[BatchEncoding, str],
                                 BatchEncoding] = identity
    patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]] = None

    # Post processors that if defined, will run oun the outputs of the
@@ -130,7 +124,7 @@ class VLMTestInfo(NamedTuple):
    # is all combinations of .models + all fields below
    max_tokens: Union[int, tuple[int]] = 128
    num_logprobs: Union[int, tuple[int]] = 5
    dtype: Union[str, Iterable[str]] = "half"
    dtype: Union[str, Union[list[str], tuple[str, ...]]] = "auto"
    distributed_executor_backend: Optional[Union[str, Iterable[str]]] = None
    # Only expanded in video tests
    num_video_frames: Union[int, tuple[int]] = 16
@@ -171,7 +165,6 @@ class VLMTestInfo(NamedTuple):
        "vllm_output_post_proc": self.vllm_output_post_proc,
        "auto_cls": self.auto_cls,
        "use_tokenizer_eos": self.use_tokenizer_eos,
        "postprocess_inputs": self.postprocess_inputs,
        "comparator": self.comparator,
        "get_stop_token_ids": self.get_stop_token_ids,
        "hf_model_kwargs": self.hf_model_kwargs,

@@ -1,12 +1,12 @@
# SPDX-License-Identifier: Apache-2.0

from functools import partial
from typing import Callable

import pytest
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import BatchEncoding, Qwen2VLForConditionalGeneration
from transformers import Qwen2VLForConditionalGeneration

from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ....utils import large_gpu_test
@@ -75,10 +75,6 @@ def apply_chat_template_and_add_eos(
    return prompt


def postprocess_inputs(hf_model: HfRunner, inputs: BatchEncoding, **kwargs):
    return hf_model.model.prepare_inputs_for_generation(**inputs, **kwargs)


def _run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
@@ -118,14 +114,8 @@ def _run_test(
    with hf_runner(model,
                   dtype=dtype,
                   auto_cls=Qwen2VLForConditionalGeneration) as hf_model:
        hf_model.postprocess_inputs = partial(
            postprocess_inputs,
            hf_model,
            cache_position=torch.arange(
                0,
                1,  # 1 for batch size
                requires_grad=False),
            use_cache=False)

        prompts = []
        for text, image, embed_text in zip(input_texts, input_images,
                                           embed_texts):
            # dse requires non-standard input processing
@@ -133,20 +123,34 @@ def _run_test(
            messages = get_messages(image, text, embed_text)
            prompt = apply_chat_template_and_add_eos(
                messages, hf_model.processor.apply_chat_template)
            inputs = hf_model.get_inputs(
                prompts=[[prompt]],
                images=[[image]],
            )
            with torch.no_grad():

            prompts.append(prompt)

        all_inputs = hf_model.get_inputs(
            prompts=prompts,
            images=input_images,
        )

        with torch.no_grad():
            all_outputs = []
            for inputs in all_inputs:
                inputs = hf_model.model.prepare_inputs_for_generation(
                    **inputs,
                    cache_position=torch.arange(1),  # 1 for batch size
                    use_cache=False,
                )
                outputs = hf_model.model(
                    **hf_model.wrap_device(inputs[0],
                                           device=hf_model.model.device.type),
                    **hf_model.wrap_device(inputs),
                    return_dict=True,
                    output_hidden_states=True,
                )
                pooled_output = torch.nn.functional.normalize(
                    outputs.hidden_states[-1][0, -1], p=2, dim=-1)
                hf_outputs.append(pooled_output.tolist())
                pooled_output = F.normalize(outputs.hidden_states[-1][0, -1],
                                            p=2,
                                            dim=-1)

                all_outputs.append(pooled_output.tolist())

            hf_outputs = all_outputs

    check_embeddings_close(
        embeddings_0_lst=hf_outputs,
|
||||
for inputs in all_inputs:
|
||||
# Based on: https://huggingface.co/royokong/e5-v
|
||||
outputs = hf_model.model(
|
||||
**hf_model.wrap_device(inputs,
|
||||
device=hf_model.model.device.type),
|
||||
**hf_model.wrap_device(inputs),
|
||||
return_dict=True,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
@ -53,8 +53,7 @@ def _run_test(
|
||||
for inputs in all_inputs:
|
||||
# Based on: https://github.com/TIGER-AI-Lab/VLM2Vec/blob/db3b951bccabba220c1f53ab46a734e50dd2fc08/src/model.py
|
||||
outputs = hf_model.model(
|
||||
**hf_model.wrap_device(inputs,
|
||||
device=hf_model.model.device.type),
|
||||
**hf_model.wrap_device(inputs),
|
||||
return_dict=True,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
@ -4,8 +4,7 @@ from typing import Optional, overload
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import (AutoConfig, AutoModelForImageTextToText,
|
||||
AutoTokenizer, BatchEncoding)
|
||||
from transformers import AutoConfig, AutoModelForImageTextToText, AutoTokenizer
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
@ -227,13 +226,9 @@ def _run_test(
|
||||
for prompts, images in inputs
|
||||
]
|
||||
|
||||
def process(hf_inputs: BatchEncoding, **kwargs):
|
||||
return hf_inputs
|
||||
|
||||
with hf_runner(model,
|
||||
dtype=dtype,
|
||||
model_kwargs={"device_map": "auto"},
|
||||
postprocess_inputs=process,
|
||||
auto_cls=AutoModelForImageTextToText) as hf_model:
|
||||
hf_outputs_per_image = [
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -254,9 +254,9 @@ def check_logprobs_close(
|
||||
def build_model_context(
|
||||
model_id: str,
|
||||
task: TaskOption = "auto",
|
||||
dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
mm_processor_kwargs: Optional[dict] = None,
|
||||
limit_mm_per_prompt: Optional[dict] = None,
|
||||
dtype: Union[str, torch.dtype] = "auto",
|
||||
mm_processor_kwargs: Optional[dict[str, Any]] = None,
|
||||
limit_mm_per_prompt: Optional[dict[str, int]] = None,
|
||||
disable_mm_preprocessor_cache: bool = True,
|
||||
):
|
||||
"""Creates an InputContext for a given model.
|
||||
@ -274,9 +274,6 @@ def build_model_context(
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
if dtype is None:
|
||||
dtype = "half"
|
||||
|
||||
model_config = ModelConfig(
|
||||
model_id,
|
||||
task=task,
|
||||
|
@ -853,7 +853,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="half",
|
||||
dtype="auto",
|
||||
revision=None,
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
)
|
||||
@ -892,7 +892,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="half",
|
||||
dtype="auto",
|
||||
revision=None,
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
)
|
||||
@ -965,7 +965,7 @@ def test_hf_processor_kwargs(model_id, call_kwargs, expected_kwargs):
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="half",
|
||||
dtype="auto",
|
||||
revision=None,
|
||||
)
|
||||
|
||||
|
@ -166,7 +166,7 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
|
||||
test_prompts = multilora_inference.create_test_prompts(lora_path)
|
||||
|
||||
# Serialize model before deserializing and binding LoRA adapters
|
||||
with vllm_runner(model_ref, ) as vllm_model:
|
||||
with vllm_runner(model_ref) as vllm_model:
|
||||
model_path = tmp_path / (model_ref + ".tensors")
|
||||
|
||||
vllm_model.apply_model(
|
||||
@ -208,7 +208,7 @@ def test_load_without_tensorizer_load_format(vllm_runner):
|
||||
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
|
||||
def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
|
||||
## Serialize model
|
||||
with vllm_runner(model_ref, ) as vllm_model:
|
||||
with vllm_runner(model_ref) as vllm_model:
|
||||
model_path = tmp_path / (model_ref + ".tensors")
|
||||
|
||||
vllm_model.apply_model(
|
||||
|
@ -50,7 +50,7 @@ def _get_test_sampling_params(
|
||||
"""Generate random sampling params for a batch."""
|
||||
|
||||
def get_mostly_n_gt1() -> int:
|
||||
"""Mostly n \in [2,20], ~1/3 n=1"""
|
||||
r"""Mostly n \in [2,20], ~1/3 n=1"""
|
||||
x = random.randint(0, 28)
|
||||
if x < 10:
|
||||
return 1
|
||||
|
@ -347,7 +347,7 @@ class ModelConfig:
|
||||
self.encoder_config = self._get_encoder_config()
|
||||
self.hf_image_processor_config = get_hf_image_processor_config(
|
||||
self.model, revision)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
||||
self.use_async_output_proc = use_async_output_proc
|
||||
self.mm_processor_kwargs = mm_processor_kwargs
|
||||
self.disable_mm_preprocessor_cache = disable_mm_preprocessor_cache
|
||||
@ -2526,6 +2526,14 @@ def _get_and_verify_dtype(
|
||||
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
|
||||
# because config.torch_dtype can be None.
|
||||
config_dtype = getattr(config, "torch_dtype", None)
|
||||
|
||||
# Fallbacks for multi-modal models if the root config
|
||||
# does not define torch_dtype
|
||||
if config_dtype is None and hasattr(config, "text_config"):
|
||||
config_dtype = getattr(config.text_config, "torch_dtype", None)
|
||||
if config_dtype is None and hasattr(config, "vision_config"):
|
||||
config_dtype = getattr(config.vision_config, "torch_dtype", None)
|
||||
|
||||
if config_dtype is None:
|
||||
config_dtype = torch.float32
|
||||
|
||||
@ -2533,16 +2541,8 @@ def _get_and_verify_dtype(
|
||||
dtype = dtype.lower()
|
||||
if dtype == "auto":
|
||||
if config_dtype == torch.float32:
|
||||
if config.model_type in ("gemma2", "gemma3", "gemma3_text"):
|
||||
logger.info(
|
||||
"For Gemma 2 and 3, we downcast float32 to bfloat16 "
|
||||
"instead of float16 by default. Please specify `dtype` "
|
||||
"if you want to use float16.")
|
||||
torch_dtype = torch.bfloat16
|
||||
else:
|
||||
# Following the common practice, we use float16 for float32
|
||||
# models.
|
||||
torch_dtype = torch.float16
|
||||
# Following common practice, we use float16 for float32 models
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
torch_dtype = config_dtype
|
||||
|
||||
|