[Model] Update multi-modal processor to support Mantis(LLaVA) model (#10711)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 1c768fe537
commit 39e227c7ae
@@ -362,6 +362,7 @@ steps:
   - tests/models/embedding/vision_language
   - tests/models/encoder_decoder/vision_language
   commands:
+  - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
   - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
   - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
   - pytest -v -s models/embedding/vision_language -m core_model
@@ -377,6 +378,7 @@ steps:
   - tests/models/embedding/vision_language
   - tests/models/encoder_decoder/vision_language
   commands:
+  - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
   - pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model'
   # HACK - run phi3v tests separately to sidestep this transformers bug
   # https://github.com/huggingface/transformers/issues/34307
@@ -555,7 +555,7 @@ Text Generation
   * - :code:`LlavaForConditionalGeneration`
     - LLaVA-1.5
     - T + I\ :sup:`E+`
-    - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
+    - :code:`llava-hf/llava-1.5-7b-hf`, :code:`TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc.
     -
     - ✅︎
   * - :code:`LlavaNextForConditionalGeneration`
@@ -664,6 +664,10 @@ Text Generation
 .. note::
   vLLM currently only supports adding LoRA to the language backbone of multimodal models.
 
+.. note::
+  To use :code:`TIGER-Lab/Mantis-8B-siglip-llama3`, you have to install their GitHub repo (:code:`pip install git+https://github.com/TIGER-AI-Lab/Mantis.git`)
+  and pass :code:`--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
+
 .. note::
   The official :code:`openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
   For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
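To make the new note concrete, here is a minimal offline-inference sketch of the documented override (an illustration only, assuming the Mantis GitHub repo is already installed); it mirrors the run_mantis example added further below.

from vllm import LLM

# Sketch only: per the note above, the architecture is overridden to
# MantisForConditionalGeneration because the checkpoint's own config does not
# select the Mantis implementation by itself.
llm = LLM(
    model="TIGER-Lab/Mantis-8B-siglip-llama3",
    max_model_len=4096,
    hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
)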
@@ -419,6 +419,22 @@ def run_aria(question: str, modality: str):
     return llm, prompt, stop_token_ids
 
 
+# Mantis
+def run_mantis(question: str, modality: str):
+    assert modality == "image"
+
+    llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa: E501
+    prompt = llama3_template.format(f"{question}\n<image>")
+
+    llm = LLM(
+        model="TIGER-Lab/Mantis-8B-siglip-llama3",
+        max_model_len=4096,
+        hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
+    )
+    stop_token_ids = [128009]
+    return llm, prompt, stop_token_ids
+
+
 model_example_map = {
     "llava": run_llava,
     "llava-next": run_llava_next,
@@ -441,6 +457,7 @@ model_example_map = {
     "glm4v": run_glm4v,
     "idefics3": run_idefics3,
     "aria": run_aria,
+    "mantis": run_mantis,
 }
 
 
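For reference, a minimal sketch of how the values returned by run_mantis are typically consumed; it assumes run_mantis from the example above is in scope, and the image path and question are placeholders, not part of this change.

from PIL import Image

from vllm import SamplingParams

llm, prompt, stop_token_ids = run_mantis("What is in this image?", "image")
image = Image.open("image.jpg").convert("RGB")  # placeholder path

sampling_params = SamplingParams(temperature=0.0,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    },
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)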
@@ -24,9 +24,6 @@ mistral_common[opencv] >= 1.5.0 # required for pixtral test
 datamodel_code_generator # required for minicpm3 test
 lm-eval[api]==0.4.4 # required for model evaluation test
 
-# TODO: Add this after fully implementing llava(mantis)
-# git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test
-
 # quantization
 bitsandbytes>=0.44.0
 buildkite-test-collector==0.1.9
@@ -34,7 +34,7 @@ COMMON_BROADCAST_SETTINGS = {
     "dtype": "half",
     "max_tokens": 5,
     "tensor_parallel_size": 2,
-    "model_kwargs": {"device_map": "auto"},
+    "hf_model_kwargs": {"device_map": "auto"},
     "image_size_factors": [(.25, 0.5, 1.0)],
     "distributed_executor_backend": (
         "ray",
@@ -108,7 +108,7 @@ VLM_TEST_SETTINGS = {
             "cherry_blossom": "What is in the picture?",
         }),
         auto_cls=AutoModelForVision2Seq,
-        postprocess_inputs=model_utils.get_key_type_post_processor(
+        postprocess_inputs=model_utils.cast_dtype_post_processor(
             "pixel_values"
         ),
         vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
@@ -151,7 +151,7 @@ VLM_TEST_SETTINGS = {
             "cherry_blossom": "<vlm_image>Please infer the season with reason.",
         }),
         multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.", # noqa: E501
-        postprocess_inputs=model_utils.get_key_type_post_processor("pixel_values"),
+        postprocess_inputs=model_utils.cast_dtype_post_processor("pixel_values"),
         stop_str=["<|im_end|>"],
         image_size_factors=[(0.10, 0.15)],
         max_tokens=64,
@@ -177,7 +177,7 @@ VLM_TEST_SETTINGS = {
         prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:",
         max_model_len=4096,
         auto_cls=AutoModelForVision2Seq,
-        postprocess_inputs=model_utils.get_key_type_post_processor(
+        postprocess_inputs=model_utils.cast_dtype_post_processor(
             "pixel_values"
         ),
         # For chameleon, we only compare the sequences
@@ -281,7 +281,7 @@ VLM_TEST_SETTINGS = {
         prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
         num_video_frames=16,
         max_model_len=16384,
-        postprocess_inputs=model_utils.get_key_type_post_processor(
+        postprocess_inputs=model_utils.cast_dtype_post_processor(
            "pixel_values_videos"
         ),
         auto_cls=AutoModelForVision2Seq,
@@ -306,6 +306,20 @@ VLM_TEST_SETTINGS = {
         vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output,
         image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
     ),
+    "mantis": VLMTestInfo(
+        models=["TIGER-Lab/Mantis-8B-siglip-llama3"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
+        max_model_len=4096,
+        postprocess_inputs=model_utils.cast_dtype_post_processor(
+            "pixel_values"
+        ),
+        vllm_runner_kwargs={"hf_overrides": {"architectures": ["MantisForConditionalGeneration"]}}, # noqa: E501
+        get_stop_token_ids=lambda tok: [128009],
+        auto_cls=AutoModelForVision2Seq,
+        vllm_output_post_proc=model_utils.mantis_vllm_to_hf_output,
+        patch_hf_runner=model_utils.mantis_patch_hf_runner,
+    ),
     "minicpmv_25": VLMTestInfo(
         models=["openbmb/MiniCPM-Llama3-V-2_5"],
         test_type=VLMTestType.IMAGE,
@@ -342,7 +356,7 @@ VLM_TEST_SETTINGS = {
         # max_num_seqs=2,
         # task="generate",
         # # use eager mode for hf runner since phi3v didn't work with flash_attn
-        # model_kwargs={"_attn_implementation": "eager"},
+        # hf_model_kwargs={"_attn_implementation": "eager"},
         # use_tokenizer_eos=True,
         # vllm_output_post_proc=model_utils.phi3v_vllm_to_hf_output,
         # num_logprobs=10,
@@ -373,7 +387,7 @@ VLM_TEST_SETTINGS = {
         prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:",
         max_model_len=4096,
         auto_cls=AutoModelForVision2Seq,
-        postprocess_inputs=model_utils.get_key_type_post_processor(
+        postprocess_inputs=model_utils.cast_dtype_post_processor(
             "pixel_values"
         ),
         vllm_output_post_proc = lambda vllm_output, model: vllm_output[:2],
@@ -438,7 +452,7 @@ VLM_TEST_SETTINGS = {
         test_type=VLMTestType.CUSTOM_INPUTS,
         max_model_len=16384,
         max_num_seqs=2,
-        postprocess_inputs=model_utils.get_key_type_post_processor(
+        postprocess_inputs=model_utils.cast_dtype_post_processor(
             "pixel_values"
         ),
         auto_cls=AutoModelForVision2Seq,
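As a cross-check, the prompt_formatter in the new "mantis" entry wraps the image prompt in the same Llama-3 chat template as the llama3_template string used in the offline-inference example; a standalone illustration (the question text is a placeholder):

# Reproduces the lambda from the "mantis" VLMTestInfo entry above.
prompt_formatter = lambda img_prompt: (
    f"<|start_header_id|>user<|end_header_id|>\n\n{img_prompt}"
    f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n")

print(prompt_formatter("What is in this image?\n<image>"))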
@@ -3,9 +3,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
 import torch
 from PIL.Image import Image
-from transformers import AutoTokenizer, BatchEncoding
+from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizerBase
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
+from vllm.config import TaskOption
+
 from .....conftest import HfRunner, VllmRunner
 from .types import RunnerOutput
 
@@ -28,13 +30,15 @@ def run_test(
     use_tokenizer_eos: bool,
     postprocess_inputs: Callable[[BatchEncoding], BatchEncoding],
     comparator: Callable[..., None],
-    get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]],
+    get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase],
+                                          List[int]]],
     stop_str: Optional[List[str]],
     tokenizer_mode: str,
     limit_mm_per_prompt: Dict[str, int],
-    model_kwargs: Optional[Dict[str, Any]],
+    vllm_runner_kwargs: Optional[Dict[str, Any]],
+    hf_model_kwargs: Optional[Dict[str, Any]],
     patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]],
-    task: str = "auto",
+    task: TaskOption = "auto",
     runner_mm_key: str = "images",
     distributed_executor_backend: Optional[str] = None,
     tensor_parallel_size: int = 1,
@@ -58,6 +62,9 @@ def run_test(
     if stop_str:
         vllm_kwargs["stop"] = stop_str
 
+    if vllm_runner_kwargs is None:
+        vllm_runner_kwargs = {}
+
     with vllm_runner(model,
                      tokenizer_mode=tokenizer_mode,
                      max_model_len=max_model_len,
@@ -67,7 +74,8 @@ def run_test(
                      tensor_parallel_size=tensor_parallel_size,
                      distributed_executor_backend=distributed_executor_backend,
                      enforce_eager=enforce_eager,
-                     task=task) as vllm_model:
+                     task=task,
+                     **vllm_runner_kwargs) as vllm_model:
         for prompts, media in vllm_inputs:
             vllm_kwargs[runner_mm_key] = media
             vllm_output = vllm_model.generate_greedy_logprobs(
@@ -78,7 +86,7 @@ def run_test(
                    dtype=dtype,
                    auto_cls=auto_cls,
                    postprocess_inputs=postprocess_inputs,
-                   model_kwargs=model_kwargs)
+                   model_kwargs=hf_model_kwargs)
 
     # Some models need to patch things like the model processor, e.g., internvl
     if patch_hf_runner is not None:
@@ -126,6 +126,16 @@ def llava_onevision_vllm_to_hf_output(vllm_output: RunnerOutput,
     return hf_output_ids, hf_output_str, out_logprobs
 
 
+def mantis_vllm_to_hf_output(vllm_output: RunnerOutput,
+                             model: str) -> RunnerOutput:
+    """Sanitize vllm output [mantis] to compare with hf output."""
+    output_ids, output_str, out_logprobs = vllm_output
+
+    hf_output_str = output_str + "<|eot_id|>"
+
+    return output_ids, hf_output_str, out_logprobs
+
+
 def phi3v_vllm_to_hf_output(vllm_output: RunnerOutput,
                             model: str) -> RunnerOutput:
     """Sanitize vllm output [phi3v] to be comparable with hf output."""
@@ -184,7 +194,7 @@ def get_llava_embeddings(image_assets: _ImageAssets):
 
 
 ####### postprocessors to run on HF BatchEncoding
-def get_key_type_post_processor(
+def cast_dtype_post_processor(
         hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]:
     """Gets a handle to a post processor which converts a given key into a
     target data type."""
@@ -418,3 +428,26 @@ def _internvl_generate(
     )
 
     return outputs
+
+
+def mantis_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+    from mantis.models.mllava import MLlavaProcessor
+
+    hf_model.processor = MLlavaProcessor.from_pretrained(hf_model.model_name)
+
+    orig_generate = hf_model.model.generate
+    tokenizer = hf_model.processor.tokenizer
+
+    def _generate(self, *args, **kwargs):
+        return orig_generate(
+            *args,
+            **kwargs,
+            eos_token_id=[
+                tokenizer.eos_token_id,
+                tokenizer.convert_tokens_to_ids("<|eot_id|>"),
+            ],
+        )
+
+    hf_model.model.generate = types.MethodType(_generate, hf_model.model)
+
+    return hf_model
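mantis_patch_hf_runner above rebinds model.generate with types.MethodType so existing callers keep using the same attribute while extra eos_token_id values are injected. A standalone toy illustration of that pattern (the class and names are hypothetical, not from this change):

import types


class Greeter:

    def greet(self, name: str) -> str:
        return f"hello {name}"


g = Greeter()
orig_greet = g.greet  # keep a handle to the original bound method


def _greet(self, name: str) -> str:
    # forward to the original and adjust the result, mirroring how _generate
    # forwards to orig_generate with an extra eos_token_id list
    return orig_greet(name) + "!"


g.greet = types.MethodType(_greet, g)  # rebind the wrapper on this instance
print(g.greet("mantis"))  # -> hello mantis!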
@@ -7,9 +7,11 @@ from typing import (Any, Callable, Dict, Iterable, List, NamedTuple, Optional,
 import torch
 from PIL.Image import Image
 from pytest import MarkDecorator
-from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding
+from transformers import (AutoModelForCausalLM, BatchEncoding,
+                          PreTrainedTokenizerBase)
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
+from vllm.config import TaskOption
 from vllm.sequence import SampleLogprobs
 from vllm.utils import identity
 
@@ -66,7 +68,7 @@ class ImageSizeWrapper(NamedTuple):
 class VLMTestInfo(NamedTuple):
     """Holds the configuration for 1+ tests for one model architecture."""
 
-    models: Union[List[str]]
+    models: List[str]
     test_type: Union[VLMTestType, Iterable[VLMTestType]]
 
     # Should be None only if this is a CUSTOM_INPUTS test
@@ -92,18 +94,20 @@ class VLMTestInfo(NamedTuple):
     enforce_eager: bool = True
     max_model_len: int = 1024
     max_num_seqs: int = 256
-    task: str = "auto"
+    task: TaskOption = "auto"
     tensor_parallel_size: int = 1
+    vllm_runner_kwargs: Optional[Dict[str, Any]] = None
 
     # Optional callable which gets a list of token IDs from the model tokenizer
-    get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]] = None
+    get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase],
+                                          List[int]]] = None
     # Optional list of strings to stop generation, useful when stop tokens are
     # not special tokens in the tokenizer
     stop_str: Optional[List[str]] = None
 
     # Exposed options for HF runner
-    model_kwargs: Optional[Dict[str, Any]] = None
-    # Indicates we should explicitly pass the EOS from the tokeniezr
+    hf_model_kwargs: Optional[Dict[str, Any]] = None
+    # Indicates we should explicitly pass the EOS from the tokenizer
     use_tokenizer_eos: bool = False
     auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM
     # Callable to pass to the HF runner to run on inputs; for now, we also pass
@@ -164,6 +168,7 @@ class VLMTestInfo(NamedTuple):
             "max_num_seqs": self.max_num_seqs,
             "task": self.task,
             "tensor_parallel_size": self.tensor_parallel_size,
+            "vllm_runner_kwargs": self.vllm_runner_kwargs,
             "hf_output_post_proc": self.hf_output_post_proc,
             "vllm_output_post_proc": self.vllm_output_post_proc,
             "auto_cls": self.auto_cls,
@@ -171,8 +176,8 @@ class VLMTestInfo(NamedTuple):
             "postprocess_inputs": self.postprocess_inputs,
             "comparator": self.comparator,
             "get_stop_token_ids": self.get_stop_token_ids,
+            "hf_model_kwargs": self.hf_model_kwargs,
             "stop_str": self.stop_str,
-            "model_kwargs": self.model_kwargs,
             "patch_hf_runner": self.patch_hf_runner,
             "tokenizer_mode": self.tokenizer_mode
         }
@@ -176,6 +176,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     "LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"), # noqa: E501
     "LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"), # noqa: E501
     "LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501
+    "MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3"), # noqa: E501
     "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
                                 trust_remote_code=True),
     "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
@@ -3,16 +3,14 @@ from typing import Optional
 import torch
 
 from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
-                                              create_metadata_for_llava,
-                                              dummy_mm_kwargs_for_llava,
+                                              LlavaProcessor,
                                               get_max_llava_image_tokens)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 
 
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
-@MULTIMODAL_REGISTRY.register_processor_by_metadata(create_metadata_for_llava,
-                                                    dummy_mm_kwargs_for_llava)
+@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor)
 class MyLlava(LlavaForConditionalGeneration):
 
     def compute_logits(
@@ -22,10 +22,11 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
-from vllm.multimodal.processing import (InputProcessingContext,
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        InputProcessingContext,
                                         ModalityProcessingMetadata,
                                         MultiModalProcessingMetadata,
-                                        MultiModalProcessor, PromptReplacement)
+                                        PromptReplacement)
 from vllm.sequence import IntermediateTensors
 
 from .clip import (CLIPVisionModel, dummy_image_for_clip,
@@ -163,7 +164,13 @@ def create_metadata_for_llava(
     }
 
 
-class LlavaProcessor(MultiModalProcessor):
+class LlavaProcessor(BaseMultiModalProcessor):
 
+    def __init__(self, ctx: InputProcessingContext) -> None:
+        super().__init__(
+            ctx=ctx,
+            metadata=create_metadata_for_llava(ctx),
+        )
+
     def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
         if getattr(hf_processor, "__is_patched__", False):
@@ -193,7 +200,30 @@ class LlavaProcessor(MultiModalProcessor):
         self,
         mm_counts: Mapping[str, int],
     ) -> MultiModalKwargs:
-        return dummy_mm_kwargs_for_llava(self.ctx, mm_counts)
+        hf_config = self.ctx.get_hf_config(LlavaConfig)
+        vision_config = hf_config.vision_config
+        num_images = mm_counts["image"]
+
+        if isinstance(vision_config, CLIPVisionConfig):
+            data = dummy_image_for_clip(vision_config, num_images)
+        elif isinstance(vision_config, SiglipVisionConfig):
+            data = dummy_image_for_siglip(vision_config, num_images)
+        elif isinstance(vision_config, PixtralVisionConfig):
+            data = dummy_image_for_pixtral_hf(vision_config, num_images)
+        else:
+            msg = f"Unsupported vision config: {type(vision_config)}"
+            raise NotImplementedError(msg)
+
+        hf_processor = self._get_hf_processor()
+        image_processor = hf_processor.image_processor  # type: ignore
+        hf_inputs = image_processor.preprocess(data['image'],
+                                               return_tensors="pt")
+        is_pixtral = isinstance(hf_processor, PixtralProcessor)
+
+        return MultiModalKwargs(
+            **hf_inputs,
+            is_pixtral=torch.tensor(is_pixtral),
+        )
 
 
 class LlavaLikeConfig(Protocol):
@@ -277,10 +307,7 @@ def init_vision_tower_for_llava(
 
 
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
-@MULTIMODAL_REGISTRY.register_processor(lambda ctx: LlavaProcessor(
-    ctx=ctx,
-    metadata=create_metadata_for_llava(ctx),
-))
+@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor)
 class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     # BitandBytes specific attributes
     bitsandbytes_stacked_params_mapping = {
@@ -559,3 +586,28 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
                                                    torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
+
+
+class MantisProcessor(LlavaProcessor):
+
+    def _get_hf_processor(self) -> ProcessorMixin:
+        try:
+            from mantis.models.mllava import MLlavaProcessor
+        except ModuleNotFoundError as exc:
+            raise ModuleNotFoundError(
+                "You need to `pip install "
+                "git+https://github.com/TIGER-AI-Lab/Mantis.git` "
+                "to use this model") from exc
+
+        processor = MLlavaProcessor.from_pretrained(
+            self.ctx.model_config.tokenizer)
+        assert isinstance(processor, ProcessorMixin)
+        return processor
+
+
+# To use this model, please use
+# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
+@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
+@MULTIMODAL_REGISTRY.register_processor(MantisProcessor)
+class MantisForConditionalGeneration(LlavaForConditionalGeneration):
+    pass
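MantisProcessor._get_hf_processor above defers importing the Mantis package until the processor is actually requested, so the extra dependency is only needed when this model is used. A toy sketch of that lazy-import pattern (the package name and message here are hypothetical, not part of this change):

def load_optional_processor():
    # Importing inside the function keeps the optional dependency off the
    # module-import path; a missing install surfaces as an actionable error.
    try:
        from some_optional_pkg import SomeProcessor  # hypothetical package
    except ModuleNotFoundError as exc:
        raise ModuleNotFoundError(
            "You need to `pip install some-optional-pkg` to use this model"
        ) from exc
    return SomeProcessor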
@@ -152,6 +152,7 @@ _MULTIMODAL_MODELS = {
     "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
     "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501
     "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501
+    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"), # noqa: E501
     "MiniCPMV": ("minicpmv", "MiniCPMV"),
     "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
     "NVLM_D": ("nvlm_d", "NVLM_D_Model"),
@@ -529,9 +529,9 @@ def iter_placeholders(
         yield placeholder
 
 
-class MultiModalProcessor(ABC):
+class BaseMultiModalProcessor(ABC):
     """
-    Helper class to process multi-modal inputs to be used in vLLM.
+    Abstract base class to process multi-modal inputs to be used in vLLM.
     """
 
     def __init__(
@@ -15,7 +15,7 @@ from .audio import AudioPlugin
 from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
 from .image import ImagePlugin
 from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
-from .processing import MultiModalProcessingMetadata, MultiModalProcessor
+from .processing import BaseMultiModalProcessor
 from .video import VideoPlugin
 
 if TYPE_CHECKING:
@@ -26,7 +26,7 @@ logger = init_logger(__name__)
 N = TypeVar("N", bound=Type[nn.Module])
 
 MultiModalProcessorFactory: TypeAlias = Callable[[InputProcessingContext],
-                                                 MultiModalProcessor]
+                                                 BaseMultiModalProcessor]
 """
 Constructs a :class:`MultiModalProcessor` instance from the context.
 
@@ -311,41 +311,6 @@ class MultiModalRegistry:
 
         return wrapper
 
-    def register_processor_by_metadata(
-        self,
-        metadata_factory: Callable[[InputProcessingContext],
-                                   MultiModalProcessingMetadata],
-        get_dummy_mm_kwargs: Callable[
-            [InputProcessingContext, Mapping[str, int]], MultiModalKwargs],
-    ):
-        """
-        Convenience method to register a multi-modal processor to a model class
-        according to a function that constructs its metadata.
-
-        When the model receives multi-modal data, the provided function is
-        invoked to transform the data into a dictionary of model inputs.
-
-        See also:
-            - :ref:`input_processing_pipeline`
-            - :ref:`enabling_multimodal_inputs`
-        """
-
-        class ConcreteMultiModalProcessor(MultiModalProcessor):
-
-            def _get_dummy_mm_kwargs(
-                self,
-                mm_counts: Mapping[str, int],
-            ) -> MultiModalKwargs:
-                return get_dummy_mm_kwargs(self.ctx, mm_counts)
-
-        def factory(ctx: InputProcessingContext):
-            return ConcreteMultiModalProcessor(
-                ctx=ctx,
-                metadata=metadata_factory(ctx),
-            )
-
-        return self.register_processor(factory)
-
     def has_processor(self, model_config: "ModelConfig") -> bool:
         """
         Test whether a multi-modal processor is defined for a specific model.
@@ -360,7 +325,7 @@ class MultiModalRegistry:
         self,
         model_config: "ModelConfig",
         tokenizer: AnyTokenizer,
-    ) -> MultiModalProcessor:
+    ) -> BaseMultiModalProcessor:
         """
         Create a multi-modal processor for a specific model and tokenizer.
         """