[VLM] Reorganize profiling/processing-related code (#11812)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent f12141170a
commit 2a0596bc48
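For orientation, the diff below replaces the old ProcessingMixin / BaseProfilingInfo pair with a BaseProcessingInfo class plus a BaseDummyInputsBuilder, moves dummy-data generation behind MultiModalProfiler, and extends register_processor with info= and dummy_inputs= arguments. The following is a minimal sketch of that pattern, not a drop-in implementation: the "MyModel*" names are hypothetical, and the vllm imports and method names are simply the ones this diff itself introduces, so they may differ in other vLLM versions.

# Sketch only: mirrors the registration pattern used throughout this commit.
# "MyModel*" names are made up; the vllm imports are those added in the diff.
from typing import Mapping, Optional

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs


class MyModelProcessingInfo(BaseProcessingInfo):
    """Shared config/processor lookups, reachable as `self.info` elsewhere."""

    def get_hf_config(self):
        return self.ctx.get_hf_config()

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": 1}

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        return {"image": self.get_num_image_tokens()}

    def get_num_image_tokens(self) -> int:
        # Model-specific; e.g. the BLIP-2 change below returns
        # hf_config.num_query_tokens here.
        return 32


class MyModelDummyInputsBuilder(BaseDummyInputsBuilder[MyModelProcessingInfo]):
    """Builds profiling inputs; reads sizes and counts through `self.info`."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        raise NotImplementedError  # see the per-model builders in the diff


class MyModelMultiModalProcessor(
        BaseMultiModalProcessor[MyModelProcessingInfo]):
    """Prompt-replacement logic; also reaches the HF config via `self.info`."""


@MULTIMODAL_REGISTRY.register_processor(MyModelMultiModalProcessor,
                                        info=MyModelProcessingInfo,
                                        dummy_inputs=MyModelDummyInputsBuilder)
class MyModelForConditionalGeneration:
    ...  # nn.Module subclass in the real models

On the test side, the diff builds processors the same way throughout: MULTIMODAL_REGISTRY.create_processor(model_config, tokenizer=...), with dummy-data generation going through MultiModalProfiler(processor).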
@@ -4,24 +4,17 @@ from functools import partial
 import pytest
 from PIL import Image
 from pqdm.threads import pqdm
-from transformers import AutoTokenizer
 
-from vllm.inputs import InputProcessingContext
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.parse import ImageSize
+from vllm.multimodal.processing import BaseMultiModalProcessor
+from vllm.multimodal.utils import cached_get_tokenizer
 
 from ....utils import build_model_context
 
 
-# Fixtures lazy import to avoid initializing CUDA during test collection
-@pytest.fixture()
-def processor_for_llava_next():
-    from vllm.model_executor.models.llava_next import (
-        LlavaNextMultiModalProcessor)
-    return LlavaNextMultiModalProcessor
-
-
 def _validate_image_prompt_replacements_one(
-    processor,
+    processor: BaseMultiModalProcessor,
     num_imgs: int,
     failed_size_excs: list[tuple[ImageSize, Exception]],
     image_size: ImageSize,
@@ -78,20 +71,17 @@ def _test_image_prompt_replacements(
 
 @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
 @pytest.mark.parametrize("num_imgs", [1, 2])
-def test_processor_prompt_replacements_regression(
-    processor_for_llava_next,
-    model_id: str,
-    num_imgs: int,
-):
+def test_processor_prompt_replacements_regression(model_id, num_imgs):
     ctx = build_model_context(
         model_name=model_id,
         tokenizer_name=model_id,
         mm_processor_kwargs=None,
         limit_mm_per_prompt={"image": num_imgs},
     )
-    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-    ctx = InputProcessingContext(ctx.model_config, tokenizer)
-    processor = processor_for_llava_next(ctx)
+    processor = MULTIMODAL_REGISTRY.create_processor(
+        ctx.model_config,
+        tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
+    )
 
     image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
                     (488, 183), (2560, 1669)]
@@ -111,20 +101,17 @@ def test_processor_prompt_replacements_regression(
                   "Comment this out to run it manually.")
 @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
 @pytest.mark.parametrize("num_imgs", [1])
-def test_processor_prompt_replacements_all(
-    processor_for_llava_next,
-    model_id: str,
-    num_imgs: int,
-):
+def test_processor_prompt_replacements_all(model_id, num_imgs):
     ctx = build_model_context(
         model_name=model_id,
         tokenizer_name=model_id,
         mm_processor_kwargs=None,
         limit_mm_per_prompt={"image": num_imgs},
     )
-    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-    ctx = InputProcessingContext(ctx.model_config, tokenizer)
-    processor = processor_for_llava_next(ctx)
+    processor = MULTIMODAL_REGISTRY.create_processor(
+        ctx.model_config,
+        tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
+    )
 
     seen_aspect_ratios = set[float]()
     image_sizes = list[ImageSize]()
@@ -4,24 +4,17 @@ from functools import partial
 import pytest
 from PIL import Image
 from pqdm.threads import pqdm
-from transformers import AutoTokenizer
 
-from vllm.inputs import InputProcessingContext
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.parse import ImageSize
+from vllm.multimodal.processing import BaseMultiModalProcessor
+from vllm.multimodal.utils import cached_get_tokenizer
 
 from ....utils import build_model_context
 
 
-# Fixtures lazy import to avoid initializing CUDA during test collection
-@pytest.fixture()
-def processor_for_llava_onevision():
-    from vllm.model_executor.models.llava_onevision import (
-        LlavaOnevisionMultiModalProcessor)
-    return LlavaOnevisionMultiModalProcessor
-
-
 def _validate_image_prompt_replacements_one(
-    processor,
+    processor: BaseMultiModalProcessor,
     num_imgs: int,
     failed_size_excs: list[tuple[ImageSize, Exception]],
     image_size: ImageSize,
@@ -77,20 +70,17 @@ def _test_image_prompt_replacements(
 @pytest.mark.parametrize("model_id",
                          ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"])
 @pytest.mark.parametrize("num_imgs", [1, 2])
-def test_processor_prompt_replacements_regression(
-    processor_for_llava_onevision,
-    model_id: str,
-    num_imgs: int,
-):
+def test_processor_prompt_replacements_regression(model_id, num_imgs):
     ctx = build_model_context(
         model_name=model_id,
         tokenizer_name=model_id,
         mm_processor_kwargs=None,
         limit_mm_per_prompt={"image": num_imgs},
     )
-    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-    ctx = InputProcessingContext(ctx.model_config, tokenizer)
-    processor = processor_for_llava_onevision(ctx)
+    processor = MULTIMODAL_REGISTRY.create_processor(
+        ctx.model_config,
+        tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
+    )
 
     image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
                     (488, 183), (2560, 1669)]
@@ -111,20 +101,17 @@ def test_processor_prompt_replacements_regression(
 @pytest.mark.parametrize("model_id",
                          ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"])
 @pytest.mark.parametrize("num_imgs", [1])
-def test_processor_prompt_replacements_all(
-    processor_for_llava_onevision,
-    model_id: str,
-    num_imgs: int,
-):
+def test_processor_prompt_replacements_all(model_id, num_imgs):
     ctx = build_model_context(
         model_name=model_id,
         tokenizer_name=model_id,
         mm_processor_kwargs=None,
         limit_mm_per_prompt={"image": num_imgs},
     )
-    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-    ctx = InputProcessingContext(ctx.model_config, tokenizer)
-    processor = processor_for_llava_onevision(ctx)
+    processor = MULTIMODAL_REGISTRY.create_processor(
+        ctx.model_config,
+        tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
+    )
 
     seen_aspect_ratios = set[float]()
     image_sizes = list[ImageSize]()
@@ -1,21 +1,13 @@
 """Tests for phi3v's multimodal preprocessing kwargs."""
 import pytest
-from transformers import AutoTokenizer
 
-from vllm.inputs import InputProcessingContext
-from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.utils import cached_get_tokenizer
 
 from .....conftest import _ImageAssets
 from ....utils import build_model_context
 
 
-# Wrap lazy imports to avoid initializing CUDA during test collection
-@pytest.fixture()
-def processor_for_phi3v():
-    from vllm.model_executor.models.phi3v import Phi3VMultiModalProcessor
-    return Phi3VMultiModalProcessor
-
-
 @pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"])
 # yapf: disable
 @pytest.mark.parametrize(
@@ -29,7 +21,6 @@ def processor_for_phi3v():
 # yapf: enable
 @pytest.mark.parametrize("num_imgs", [1, 2])
 def test_processor_override(
-    processor_for_phi3v,
     image_assets: _ImageAssets,
     model_id: str,
     mm_processor_kwargs: dict[str, int],
@@ -37,21 +28,26 @@ def test_processor_override(
     num_imgs: int,
 ):
     """Ensure input_processor_for_phi3v handles num_crops properly."""
+    # Avoid initializing CUDA early
+    from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
+
     ctx = build_model_context(
         model_name=model_id,
         tokenizer_name=model_id,
         trust_remote_code=True,
         limit_mm_per_prompt={"image": num_imgs},
     )
-    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-    ctx = InputProcessingContext(ctx.model_config, tokenizer)
+    tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
+    processor = MULTIMODAL_REGISTRY.create_processor(
+        ctx.model_config,
+        tokenizer=tokenizer,
+    )
 
     # Build the image str / prompt based on the number of images we pass
     img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
     prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
     mm_data = {"image": [image_assets[0].pil_image] * num_imgs}
 
-    processor = processor_for_phi3v(ctx)
     processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
 
     # Ensure we have the right number of placeholders per num_crops size
@@ -1,19 +1,12 @@
 import pytest
-from transformers import AutoTokenizer
 
-from vllm.inputs import InputProcessingContext
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.utils import cached_get_tokenizer
 
 from .....conftest import _ImageAssets
 from ....utils import build_model_context
 
 
-# Fixtures lazy import to avoid initializing CUDA during test collection
-@pytest.fixture()
-def processor_for_qwen2_vl():
-    from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor
-    return Qwen2VLMultiModalProcessor
-
-
 @pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])
 # yapf: disable
 @pytest.mark.parametrize(
@@ -24,7 +17,6 @@ def processor_for_qwen2_vl():
 # yapf: enable
 @pytest.mark.parametrize("num_imgs", [1, 2])
 def test_processor_override(
-    processor_for_qwen2_vl,
    image_assets: _ImageAssets,
     model_id: str,
     mm_processor_kwargs: dict[str, object],
@@ -39,18 +31,20 @@ def test_processor_override(
         mm_processor_kwargs=None,
         limit_mm_per_prompt={"image": num_imgs},
     )
-    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-    ctx = InputProcessingContext(ctx.model_config, tokenizer)
+    tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
+    processor = MULTIMODAL_REGISTRY.create_processor(
+        ctx.model_config,
+        tokenizer=tokenizer,
+    )
 
     # Build the image str / prompt based on the number of images we pass
     prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs
     mm_data = {"image": [image_assets[0].pil_image] * num_imgs}
 
-    processor = processor_for_qwen2_vl(ctx)
     processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
 
     # Ensure we have the right number of placeholders per num_crops size
-    hf_processor = processor._get_hf_processor(**mm_processor_kwargs)
+    hf_processor = processor.info.get_hf_processor(**mm_processor_kwargs)
     image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token)
     img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
     pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape
@@ -10,12 +10,17 @@ from PIL import Image
 from vllm.config import ModelConfig
 from vllm.inputs import InputProcessingContext
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
-                                        _PlaceholderInfo, find_mm_placeholders,
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.multimodal.processing import (PlaceholderInfo, ProcessingCache,
+                                        PromptReplacement,
+                                        find_mm_placeholders,
                                         find_text_matches, find_token_matches,
                                         iter_token_matches,
                                         replace_text_matches,
                                         replace_token_matches)
+# yapf: enable
+from vllm.multimodal.profiling import MultiModalProfiler
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import full_groupby
@@ -431,7 +436,7 @@ def test_find_replace_tokens(
             [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
             {
                 "pattern_1": [
-                    _PlaceholderInfo(
+                    PlaceholderInfo(
                         modality="pattern_1",
                         item_idx=0,
                         start_idx=6,
@@ -445,13 +450,13 @@ def test_find_replace_tokens(
             [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
             {
                 "pattern_1": [
-                    _PlaceholderInfo(
+                    PlaceholderInfo(
                         modality="pattern_1",
                         item_idx=0,
                         start_idx=1,
                         replacement=[32000, 32000],
                     ),
-                    _PlaceholderInfo(
+                    PlaceholderInfo(
                         modality="pattern_1",
                         item_idx=1,
                         start_idx=5,
@@ -459,7 +464,7 @@ def test_find_replace_tokens(
                     ),
                 ],
                 "pattern_3": [
-                    _PlaceholderInfo(
+                    PlaceholderInfo(
                         modality="pattern_3",
                         item_idx=0,
                         start_idx=7,
@@ -472,13 +477,13 @@ def test_find_replace_tokens(
             [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
             {
                 "pattern_1": [
-                    _PlaceholderInfo(
+                    PlaceholderInfo(
                         modality="pattern_1",
                         item_idx=0,
                         start_idx=1,
                         replacement=[32000, 32000],
                     ),
-                    _PlaceholderInfo(
+                    PlaceholderInfo(
                         modality="pattern_1",
                         item_idx=1,
                         start_idx=3,
@@ -486,7 +491,7 @@ def test_find_replace_tokens(
                     ),
                 ],
                 "pattern_3": [
-                    _PlaceholderInfo(
+                    PlaceholderInfo(
                         modality="pattern_3",
                         item_idx=0,
                         start_idx=6,
@@ -577,19 +582,15 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
         revision=None,
         limit_mm_per_prompt=limit_mm_per_prompt,
     )
-    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
 
-    processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
-    ctx = InputProcessingContext(
+    processor = MULTIMODAL_REGISTRY.create_processor(
         model_config,
         tokenizer=cached_get_tokenizer(model_config.tokenizer),
     )
-
-    processor = processor_factory(ctx, cache=None)
-    profiler = processor.profiling_info
+    profiler = MultiModalProfiler(processor)
 
     mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
-    profiler.get_supported_mm_limits = mock_supported_mm_limits
+    processor.info.get_supported_mm_limits = mock_supported_mm_limits
 
     if is_valid:
         exc_ctx = nullcontext()
@@ -597,7 +598,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
         exc_ctx = pytest.raises(ValueError, match="this model only supports")
 
     with exc_ctx:
-        profiler.get_mm_limits()
+        profiler.get_dummy_data(model_config.max_model_len)
 
 
 @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@@ -620,16 +621,12 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
         revision=None,
         limit_mm_per_prompt=limit_mm_per_prompt,
     )
-    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
 
-    processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
-    ctx = InputProcessingContext(
+    processor = MULTIMODAL_REGISTRY.create_processor(
         model_config,
         tokenizer=cached_get_tokenizer(model_config.tokenizer),
     )
 
-    processor = processor_factory(ctx, cache=None)
-
     rng = np.random.RandomState(0)
     image = _rand_img(rng, min_wh=128, max_wh=256)
     if num_images == 0:
@@ -681,9 +678,9 @@ def _test_processing_cache_correctness(
         hf_overrides=hf_overrides,
         limit_mm_per_prompt=limit_mm_per_prompt,
     )
-    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
 
-    processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
+    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
+    factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
     ctx = InputProcessingContext(
         model_config,
         tokenizer=cached_get_tokenizer(model_config.tokenizer),
@@ -691,8 +688,9 @@ def _test_processing_cache_correctness(
     # Ensure that it can fit all of the data
     cache = ProcessingCache(capacity=1 << 30)
 
-    baseline_processor = processor_factory(ctx, cache=None)
-    cached_processor = processor_factory(ctx, cache=cache)
+    baseline_processor = factories.build_processor(ctx, cache=None)
+    cached_processor = factories.build_processor(ctx, cache=cache)
+    dummy_inputs = baseline_processor.dummy_inputs
 
     rng = np.random.RandomState(0)
 
@@ -724,7 +722,7 @@ def _test_processing_cache_correctness(
     }
 
     mm_counts = {k: len(vs) for k, vs in mm_data.items()}
-    prompt = baseline_processor.profiling_info.get_dummy_processor_inputs(
+    prompt = dummy_inputs.get_dummy_processor_inputs(
         model_config.max_model_len,
         mm_counts,
     ).prompt_text
@@ -2,13 +2,17 @@ from typing import Optional
 
 import torch
 
-from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
-                                              LlavaMultiModalProcessor)
+from vllm.model_executor.models.llava import (LlavaDummyInputsBuilder,
+                                              LlavaForConditionalGeneration,
+                                              LlavaMultiModalProcessor,
+                                              LlavaProcessingInfo)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 
 
-@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor,
+                                        info=LlavaProcessingInfo,
+                                        dummy_inputs=LlavaDummyInputsBuilder)
 class MyLlava(LlavaForConditionalGeneration):
 
     def compute_logits(
@@ -7,7 +7,7 @@ from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
-from vllm.multimodal.processing import MultiModalDataDict, MultiModalInputsV2
+from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputsV2
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.utils import print_info_once, print_warning_once
@@ -323,6 +323,7 @@ class InputRegistry:
         # Avoid circular import
         from vllm.model_executor.model_loader import get_model_architecture
         from vllm.multimodal import MultiModalKwargs
+        from vllm.multimodal.profiling import MultiModalProfiler
         from vllm.multimodal.utils import cached_get_tokenizer
 
         if mm_registry.has_processor(model_config):
@@ -331,7 +332,8 @@ class InputRegistry:
                 trust_remote_code=model_config.trust_remote_code,
             )
             processor = mm_registry.create_processor(model_config, tokenizer)
-            dummy_data = processor.get_dummy_data(seq_len)
+            profiler = MultiModalProfiler(processor)
+            dummy_data = profiler.get_dummy_data(seq_len)
         else:
             model_cls, _ = get_model_architecture(model_config)
             if is_encoder_data:
@@ -23,10 +23,10 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                     NestedTensors)
+from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessingMixin,
-                                        PromptReplacement)
-from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
+                                        BaseProcessingInfo, PromptReplacement)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.aria import (AriaMoELMConfig,
                                                   AriaVisionConfig)
@@ -445,33 +445,33 @@ def build_mm_projector(config: PretrainedConfig):
     )
 
 
-class AriaProcessingMixin(ProcessingMixin):
+class AriaProcessingInfo(BaseProcessingInfo):
 
-    def _get_hf_config(self):
+    def get_hf_config(self):
         return self.ctx.get_hf_config()
 
-    def _get_vision_config(self) -> AriaVisionConfig:
-        return self._get_hf_config().vision_config
-
-    def _get_num_image_tokens(self) -> int:
-        hf_config = self._get_hf_config()
-        return max(hf_config.projector_patch_to_query_dict.values())
-
-
-class AriaProfilingInfo(AriaProcessingMixin, BaseProfilingInfo):
+    def get_vision_config(self) -> AriaVisionConfig:
+        return self.get_hf_config().vision_config
 
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}
 
     def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
-        return {"image": self._get_num_image_tokens()}
+        return {"image": self.get_num_image_tokens()}
+
+    def get_num_image_tokens(self) -> int:
+        hf_config = self.get_hf_config()
+        return max(hf_config.projector_patch_to_query_dict.values())
+
+
+class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]):
 
     def get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
-        vision_config = self._get_vision_config()
+        vision_config = self.info.get_vision_config()
 
         max_image_size = vision_config.image_size
         num_images = mm_counts.get("image", 0)
@@ -483,7 +483,7 @@ class AriaProfilingInfo(AriaProcessingMixin, BaseProfilingInfo):
                                    num_images=num_images)
         }
 
-        hf_processor = self._get_hf_processor()
+        hf_processor = self.info.get_hf_processor()
         image_token: str = hf_processor.image_token  # type: ignore
 
         return ProcessorInputs(
@@ -492,10 +492,7 @@ class AriaProfilingInfo(AriaProcessingMixin, BaseProfilingInfo):
         )
 
 
-class AriaMultiModalProcessor(AriaProcessingMixin, BaseMultiModalProcessor):
+class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]):
 
-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return AriaProfilingInfo(self.ctx)
-
     def _get_mm_fields_config(
         self,
@@ -513,10 +510,10 @@ class AriaMultiModalProcessor(AriaProcessingMixin, BaseMultiModalProcessor):
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        hf_config = self._get_hf_config()
+        hf_config = self.info.get_hf_config()
         image_token_id = hf_config.image_token_index
 
-        num_image_tokens = self._get_num_image_tokens()
+        num_image_tokens = self.info.get_num_image_tokens()
 
         return [
             PromptReplacement(
@@ -527,7 +524,9 @@ class AriaMultiModalProcessor(AriaProcessingMixin, BaseMultiModalProcessor):
         ]
 
 
-@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor,
+                                        info=AriaProcessingInfo,
+                                        dummy_inputs=AriaDummyInputsBuilder)
 class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
     """
     Aria model for conditional generation tasks.
@@ -17,10 +17,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
+from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessingMixin,
-                                        PromptReplacement)
-from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
+                                        BaseProcessingInfo, PromptReplacement)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 
 from .blip import BlipVisionModel
@@ -397,30 +397,30 @@ class Blip2QFormerModel(nn.Module):
         return sequence_output
 
 
-class Blip2ProcessingMixin(ProcessingMixin):
+class Blip2ProcessingInfo(BaseProcessingInfo):
 
-    def _get_hf_config(self):
+    def get_hf_config(self):
         return self.ctx.get_hf_config(Blip2Config)
 
-    def _get_num_image_tokens(self) -> int:
-        hf_config = self._get_hf_config()
-        return hf_config.num_query_tokens
-
-
-class Blip2ProfilingInfo(Blip2ProcessingMixin, BaseProfilingInfo):
-
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": 1}
 
     def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
-        return {"image": self._get_num_image_tokens()}
+        return {"image": self.get_num_image_tokens()}
+
+    def get_num_image_tokens(self) -> int:
+        hf_config = self.get_hf_config()
+        return hf_config.num_query_tokens
+
+
+class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]):
 
     def get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
-        hf_config = self._get_hf_config()
+        hf_config = self.info.get_hf_config()
         vision_config = hf_config.vision_config
 
         max_image_size = vision_config.image_size
@@ -439,10 +439,7 @@ class Blip2ProfilingInfo(Blip2ProcessingMixin, BaseProfilingInfo):
         )
 
 
-class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor):
+class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
 
-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return Blip2ProfilingInfo(self.ctx)
-
     def _get_mm_fields_config(
         self,
@@ -460,7 +457,7 @@ class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor):
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        num_image_tokens = self._get_num_image_tokens()
+        num_image_tokens = self.info.get_num_image_tokens()
 
         return [
             PromptReplacement(
@@ -491,7 +488,9 @@ class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor):
         return result
 
 
-@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor,
+                                        info=Blip2ProcessingInfo,
+                                        dummy_inputs=Blip2DummyInputsBuilder)
 class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -30,10 +30,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
+from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessingMixin,
-                                        PromptReplacement)
-from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
+                                        BaseProcessingInfo, PromptReplacement)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.utils import print_warning_once
 
@@ -49,33 +49,34 @@ class ChameleonImagePixelInputs(TypedDict):
     """Shape: `(batch_size * num_images, num_channels, height, width)`"""
 
 
-class ChameleonProcessingMixin(ProcessingMixin):
+class ChameleonProcessingInfo(BaseProcessingInfo):
 
-    def _get_hf_config(self):
+    def get_hf_config(self):
         return self.ctx.get_hf_config(ChameleonConfig)
 
-    def _get_hf_processor(self):
+    def get_hf_processor(self):
         return self.ctx.get_hf_processor(ChameleonProcessor)
 
-    def _get_num_image_tokens(self) -> int:
-        processor = self._get_hf_processor()
-        return processor.image_seq_length
-
-
-class ChameleonProfilingInfo(ChameleonProcessingMixin, BaseProfilingInfo):
-
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": 1}
 
     def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
-        return {"image": self._get_num_image_tokens()}
+        return {"image": self.get_num_image_tokens()}
+
+    def get_num_image_tokens(self) -> int:
+        processor = self.get_hf_processor()
+        return processor.image_seq_length
+
+
+class ChameleonDummyInputsBuilder(
+        BaseDummyInputsBuilder[ChameleonProcessingInfo]):
 
     def get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
-        config = self._get_hf_config()
+        config = self.info.get_hf_config()
 
         width = height = config.vq_config.resolution
         num_images = mm_counts.get("image", 0)
@@ -93,11 +94,8 @@ class ChameleonProfilingInfo(ChameleonProcessingMixin, BaseProfilingInfo):
         )
 
 
-class ChameleonMultiModalProcessor(ChameleonProcessingMixin,
-                                   BaseMultiModalProcessor):
+class ChameleonMultiModalProcessor(
+        BaseMultiModalProcessor[ChameleonProcessingInfo]):
 
-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return ChameleonProfilingInfo(self.ctx)
-
     def _get_mm_fields_config(
         self,
@@ -112,7 +110,7 @@ class ChameleonMultiModalProcessor(ChameleonProcessingMixin,
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        processor = self._get_hf_processor(**hf_processor_mm_kwargs)
+        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
 
         return [
             PromptReplacement(
@@ -120,7 +118,7 @@ class ChameleonMultiModalProcessor(ChameleonProcessingMixin,
                 target="<image>",
                 replacement="".join([
                     processor.image_start_token,
-                    processor.image_token * self._get_num_image_tokens(),
+                    processor.image_token * self.info.get_num_image_tokens(),
                     processor.image_end_token,
                 ]),
             )
@@ -916,7 +914,10 @@ class ChameleonModel(nn.Module):
         return hidden_states
 
 
-@MULTIMODAL_REGISTRY.register_processor(ChameleonMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(
+    ChameleonMultiModalProcessor,
+    info=ChameleonProcessingInfo,
+    dummy_inputs=ChameleonDummyInputsBuilder)
 class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP):
 
@@ -33,11 +33,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
-from vllm.multimodal.parse import ImageProcessorItems, ImageSize
+from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
+                                   MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessingMixin,
-                                        PromptReplacement)
-from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
+                                        BaseProcessingInfo, PromptReplacement)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsMultiModal, SupportsPP
@@ -64,24 +64,38 @@ class FuyuImagePatchInputs(TypedDict):
     """
 
 
-class FuyuProcessingMixin(ProcessingMixin):
+class FuyuProcessingInfo(BaseProcessingInfo):
 
-    def _get_hf_config(self):
+    def get_hf_config(self):
         return self.ctx.get_hf_config(FuyuConfig)
 
-    def _get_hf_processor(self):
+    def get_hf_processor(self):
         return self.ctx.get_hf_processor(FuyuProcessor)
 
-    def _get_image_processor(self) -> FuyuImageProcessor:
-        return self._get_hf_processor().image_processor
+    def get_image_processor(self) -> FuyuImageProcessor:
+        return self.get_hf_processor().image_processor
 
-    def _get_image_feature_grid_size(
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": 1}
+
+    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+        target_width, target_height = self.get_image_size_with_most_features()
+
+        max_ncols, max_nrows = self.get_image_feature_grid_size(
+            image_width=target_width,
+            image_height=target_height,
+        )
+        max_image_tokens = (max_ncols + 1) * max_nrows
+
+        return {"image": max_image_tokens}
+
+    def get_image_feature_grid_size(
         self,
         *,
         image_width: int,
         image_height: int,
     ) -> tuple[int, int]:
-        image_processor = self._get_image_processor()
+        image_processor = self.get_image_processor()
         target_width = image_processor.size["width"]
         target_height = image_processor.size["height"]
 
@@ -97,34 +111,21 @@ class FuyuProcessingMixin(ProcessingMixin):
         nrows = math.ceil(image_height / 30)
         return ncols, nrows
 
-
-class FuyuProfilingInfo(FuyuProcessingMixin, BaseProfilingInfo):
-
-    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"image": 1}
-
-    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
-        target_width, target_height = self._get_image_size_with_most_features()
-
-        max_ncols, max_nrows = self._get_image_feature_grid_size(
-            image_width=target_width,
-            image_height=target_height,
-        )
-        max_image_tokens = (max_ncols + 1) * max_nrows
-
-        return {"image": max_image_tokens}
-
-    def _get_image_size_with_most_features(self) -> ImageSize:
-        image_processor = self._get_image_processor()
+    def get_image_size_with_most_features(self) -> ImageSize:
+        image_processor = self.get_image_processor()
         return ImageSize(width=image_processor.size["width"],
                          height=image_processor.size["height"])
+
+
+class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
 
     def get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
-        target_width, target_height = self._get_image_size_with_most_features()
+        target_width, target_height = \
+            self.info.get_image_size_with_most_features()
         num_images = mm_counts.get("image", 0)
 
         mm_data = {
@@ -140,10 +141,7 @@ class FuyuProfilingInfo(FuyuProcessingMixin, BaseProfilingInfo):
         )
 
 
-class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
+class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
 
-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return FuyuProfilingInfo(self.ctx)
-
     def _call_hf_processor(
         self,
@@ -156,7 +154,7 @@ class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
             # Avoid warning from HF logger for text-only input
             # Input_ids format: bos_token_id + prompt_token_ids + boa_token_id
             # Tokenizer won't add boa_token_id by default, we add it manually.
-            tokenizer = self._get_tokenizer()
+            tokenizer = self.info.get_tokenizer()
             boa_token_id: int = tokenizer.vocab["<0x04>"]  # type: ignore
             prompt_ids = tokenizer.encode(prompt) + [boa_token_id]
             return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
@@ -196,10 +194,10 @@ class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        hf_config = self._get_hf_config()
+        hf_config = self.info.get_hf_config()
         bos_token_id = hf_config.bos_token_id
 
-        tokenizer = self._get_tokenizer()
+        tokenizer = self.info.get_tokenizer()
         eot_token_id = tokenizer.bos_token_id
         assert isinstance(eot_token_id, int)
 
@@ -207,7 +205,7 @@ class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
             images = mm_items.get_items("image", ImageProcessorItems)
             image_size = images.get_image_size(item_idx)
 
-            ncols, nrows = self._get_image_feature_grid_size(
+            ncols, nrows = self.info.get_image_feature_grid_size(
                 image_width=image_size.width,
                 image_height=image_size.height,
             )
@@ -244,7 +242,9 @@ class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
         return result
 
 
-@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor,
+                                        info=FuyuProcessingInfo,
+                                        dummy_inputs=FuyuDummyInputsBuilder)
 class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -1,7 +1,7 @@
-from abc import ABC, abstractmethod
+from abc import abstractmethod
 from functools import cached_property
 from typing import (Final, Iterable, List, Literal, Mapping, Optional,
-                    Protocol, Set, Tuple, TypedDict, Union)
+                    Protocol, Set, Tuple, TypedDict, TypeVar, Union)
 
 import torch
 import torch.nn as nn
@@ -25,11 +25,11 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
-                                   ImageSize)
+                                   ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessingCache,
-                                        ProcessingMixin, PromptReplacement)
-from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
+                                        BaseProcessingInfo, ProcessingCache,
+                                        PromptReplacement)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 
 from .clip import CLIPVisionModel
@@ -105,34 +105,23 @@ class LlavaLikeProcessor(Protocol):
     image_token: Final[str]
 
 
-class BaseLlavaProcessingMixin(ProcessingMixin, ABC):
+class BaseLlavaProcessingInfo(BaseProcessingInfo):
 
-    def _get_hf_config(self) -> LlavaLikeConfig:
+    def get_hf_config(self) -> LlavaLikeConfig:
         return self.ctx.get_hf_config(LlavaConfig)
 
-    def _get_vision_encoder_info(self):
-        return get_vision_encoder_info(self._get_hf_config())
+    def get_vision_encoder_info(self):
+        return get_vision_encoder_info(self.get_hf_config())
 
     @abstractmethod
-    def _get_hf_processor(self) -> LlavaLikeProcessor:
+    def get_hf_processor(self) -> LlavaLikeProcessor:
         raise NotImplementedError
 
-    def _get_num_image_tokens(
-        self,
-        *,
-        image_width: int,
-        image_height: int,
-    ) -> int:
-        hf_config = self._get_hf_config()
-        vision_encoder_info = self._get_vision_encoder_info()
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None}
 
-        return self._apply_feature_select_strategy(
-            hf_config.vision_feature_select_strategy,
-            vision_encoder_info.get_num_image_tokens(
-                image_width=image_width,
-                image_height=image_height,
-            ),
-        )
+    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+        return {"image": self.get_max_image_tokens()}
 
     def _apply_feature_select_strategy(
         self,
@@ -147,28 +136,42 @@ class BaseLlavaProcessingMixin(ProcessingMixin, ABC):
         msg = f"Unexpected feature select strategy: {strategy!r}"
         raise NotImplementedError(msg)
 
+    def get_num_image_tokens(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+    ) -> int:
+        hf_config = self.get_hf_config()
+        vision_encoder_info = self.get_vision_encoder_info()
 
-class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo):
+        return self._apply_feature_select_strategy(
+            hf_config.vision_feature_select_strategy,
+            vision_encoder_info.get_num_image_tokens(
+                image_width=image_width,
+                image_height=image_height,
+            ),
+        )
 
-    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"image": None}
-
-    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
-        return {"image": self._get_max_image_tokens()}
-
-    def _get_image_size_with_most_features(self) -> ImageSize:
-        vision_encoder_info = self._get_vision_encoder_info()
+    def get_image_size_with_most_features(self) -> ImageSize:
+        vision_encoder_info = self.get_vision_encoder_info()
         width = height = vision_encoder_info.get_image_size()
         return ImageSize(width=width, height=height)
 
-    def _get_max_image_tokens(self) -> int:
-        target_width, target_height = self._get_image_size_with_most_features()
+    def get_max_image_tokens(self) -> int:
+        target_width, target_height = self.get_image_size_with_most_features()
 
-        return self._get_num_image_tokens(
+        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
         )
 
+
+_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)
+
+
+class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
+
     def get_dummy_processor_inputs(
         self,
         seq_len: int,
@@ -176,9 +179,10 @@ class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo):
     ) -> ProcessorInputs:
         num_images = mm_counts.get("image", 0)
 
-        processor = self._get_hf_processor()
+        processor = self.info.get_hf_processor()
         image_token = processor.image_token
-        target_width, target_height = self._get_image_size_with_most_features()
+        target_width, target_height = \
+            self.info.get_image_size_with_most_features()
 
         mm_data = {
             "image":
@@ -193,23 +197,13 @@ class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo):
         )
 
 
-class LlavaProcessingMixin(BaseLlavaProcessingMixin):
+class LlavaProcessingInfo(BaseLlavaProcessingInfo):
 
-    def _get_hf_processor(self):
+    def get_hf_processor(self):
         return self.ctx.get_hf_processor(LlavaProcessor)
|
return self.ctx.get_hf_processor(LlavaProcessor)
|
||||||
|
|
||||||
|
|
||||||
class LlavaProfilingInfo(LlavaProcessingMixin, BaseLlavaProfilingInfo):
|
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
|
|
||||||
BaseMultiModalProcessor):
|
|
||||||
|
|
||||||
# Copied from BaseMultiModalProcessor
|
|
||||||
@abstractmethod
|
|
||||||
def _get_profiling_info(self) -> BaseProfilingInfo:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
# Copied from BaseMultiModalProcessor
|
# Copied from BaseMultiModalProcessor
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -226,7 +220,7 @@ class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
|
|||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
out_mm_kwargs: MultiModalKwargs,
|
out_mm_kwargs: MultiModalKwargs,
|
||||||
) -> list[PromptReplacement]:
|
) -> list[PromptReplacement]:
|
||||||
hf_config = self._get_hf_config()
|
hf_config = self.info.get_hf_config()
|
||||||
image_token_id = hf_config.image_token_index
|
image_token_id = hf_config.image_token_index
|
||||||
|
|
||||||
def get_replacement(item_idx: int):
|
def get_replacement(item_idx: int):
|
||||||
@ -237,7 +231,7 @@ class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
|
|||||||
num_image_tokens = images.get_feature_size(item_idx)
|
num_image_tokens = images.get_feature_size(item_idx)
|
||||||
else:
|
else:
|
||||||
image_size = images.get_image_size(item_idx)
|
image_size = images.get_image_size(item_idx)
|
||||||
num_image_tokens = self._get_num_image_tokens(
|
num_image_tokens = self.info.get_num_image_tokens(
|
||||||
image_width=image_size.width,
|
image_width=image_size.width,
|
||||||
image_height=image_size.height,
|
image_height=image_size.height,
|
||||||
)
|
)
|
||||||
@ -253,10 +247,8 @@ class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
|
class LlavaMultiModalProcessor(
|
||||||
|
BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
|
||||||
def _get_profiling_info(self) -> BaseProfilingInfo:
|
|
||||||
return LlavaProfilingInfo(self.ctx)
|
|
||||||
|
|
||||||
def _get_mm_fields_config(
|
def _get_mm_fields_config(
|
||||||
self,
|
self,
|
||||||
@ -269,21 +261,14 @@ class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PixtralHFProcessingMixin(BaseLlavaProcessingMixin):
|
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
|
||||||
|
|
||||||
def _get_hf_processor(self):
|
def get_hf_processor(self):
|
||||||
return self.ctx.get_hf_processor(PixtralProcessor)
|
return self.ctx.get_hf_processor(PixtralProcessor)
|
||||||
|
|
||||||
|
|
||||||
class PixtralHFProfilingInfo(PixtralHFProcessingMixin, BaseLlavaProfilingInfo):
|
class PixtralHFMultiModalProcessor(
|
||||||
pass
|
BaseMultiModalProcessor[PixtralHFProcessingInfo]):
|
||||||
|
|
||||||
|
|
||||||
class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin,
|
|
||||||
BaseMultiModalProcessor):
|
|
||||||
|
|
||||||
def _get_profiling_info(self) -> BaseProfilingInfo:
|
|
||||||
return PixtralHFProfilingInfo(self.ctx)
|
|
||||||
|
|
||||||
def _call_hf_processor(
|
def _call_hf_processor(
|
||||||
self,
|
self,
|
||||||
@ -328,10 +313,10 @@ class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin,
|
|||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
out_mm_kwargs: MultiModalKwargs,
|
out_mm_kwargs: MultiModalKwargs,
|
||||||
) -> list[PromptReplacement]:
|
) -> list[PromptReplacement]:
|
||||||
hf_config = self._get_hf_config()
|
hf_config = self.info.get_hf_config()
|
||||||
image_token_id = hf_config.image_token_index
|
image_token_id = hf_config.image_token_index
|
||||||
|
|
||||||
processor = self._get_hf_processor()
|
processor = self.info.get_hf_processor()
|
||||||
image_token = processor.image_token
|
image_token = processor.image_token
|
||||||
image_break_token = processor.image_break_token
|
image_break_token = processor.image_break_token
|
||||||
image_end_token = processor.image_end_token
|
image_end_token = processor.image_end_token
|
||||||
@ -363,26 +348,40 @@ class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin,
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_llava_or_pixtral_hf_info(
|
||||||
|
ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
|
||||||
|
hf_config = ctx.get_hf_config(LlavaConfig)
|
||||||
|
|
||||||
|
if isinstance(hf_config.vision_config, PixtralVisionConfig):
|
||||||
|
return PixtralHFProcessingInfo(ctx)
|
||||||
|
|
||||||
|
return LlavaProcessingInfo(ctx)
|
||||||
|
|
||||||
|
|
||||||
def _build_llava_or_pixtral_hf_processor(
|
def _build_llava_or_pixtral_hf_processor(
|
||||||
ctx: InputProcessingContext,
|
info: _I,
|
||||||
|
dummy_inputs: BaseDummyInputsBuilder[_I],
|
||||||
*,
|
*,
|
||||||
cache: Optional[ProcessingCache] = None,
|
cache: Optional[ProcessingCache] = None,
|
||||||
enable_sanity_checks: bool = True,
|
enable_sanity_checks: bool = True,
|
||||||
) -> BaseMultiModalProcessor:
|
) -> BaseMultiModalProcessor:
|
||||||
hf_config = ctx.get_hf_config(LlavaConfig)
|
if isinstance(info, PixtralHFProcessingInfo):
|
||||||
|
|
||||||
if isinstance(hf_config.vision_config, PixtralVisionConfig):
|
|
||||||
return PixtralHFMultiModalProcessor(
|
return PixtralHFMultiModalProcessor(
|
||||||
ctx,
|
info,
|
||||||
|
dummy_inputs, # type: ignore
|
||||||
cache=cache,
|
cache=cache,
|
||||||
enable_sanity_checks=enable_sanity_checks,
|
enable_sanity_checks=enable_sanity_checks,
|
||||||
)
|
)
|
||||||
|
|
||||||
return LlavaMultiModalProcessor(
|
if isinstance(info, LlavaProcessingInfo):
|
||||||
ctx,
|
return LlavaMultiModalProcessor(
|
||||||
cache=cache,
|
info,
|
||||||
enable_sanity_checks=enable_sanity_checks,
|
dummy_inputs, # type: ignore
|
||||||
)
|
cache=cache,
|
||||||
|
enable_sanity_checks=enable_sanity_checks,
|
||||||
|
)
|
||||||
|
|
||||||
|
raise NotImplementedError(type(info))
|
||||||
|
|
||||||
|
|
||||||
def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
|
def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
|
||||||
@ -460,7 +459,9 @@ def init_vision_tower_for_llava(
|
|||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
|
|
||||||
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor)
|
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
|
||||||
|
info=_build_llava_or_pixtral_hf_info,
|
||||||
|
dummy_inputs=LlavaDummyInputsBuilder)
|
||||||
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||||
# BitandBytes specific attributes
|
# BitandBytes specific attributes
|
||||||
bitsandbytes_stacked_params_mapping = {
|
bitsandbytes_stacked_params_mapping = {
|
||||||
@ -727,11 +728,11 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
|||||||
mm_data: MultiModalDataDict,
|
mm_data: MultiModalDataDict,
|
||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
) -> MultiModalInputsV2:
|
) -> MultiModalInputsV2:
|
||||||
hf_config = self._get_hf_config()
|
hf_config = self.info.get_hf_config()
|
||||||
image_token_id = hf_config.image_token_index
|
image_token_id = hf_config.image_token_index
|
||||||
|
|
||||||
# Assume that it doesn't depend on the image size
|
# Assume that it doesn't depend on the image size
|
||||||
num_image_tokens = self._get_num_image_tokens(
|
num_image_tokens = self.info.get_num_image_tokens(
|
||||||
image_width=-1,
|
image_width=-1,
|
||||||
image_height=-1,
|
image_height=-1,
|
||||||
)
|
)
|
||||||
@ -796,6 +797,8 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
|||||||
|
|
||||||
# To use this model, please use
|
# To use this model, please use
|
||||||
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
|
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
|
||||||
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor)
|
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
|
||||||
|
info=LlavaProcessingInfo,
|
||||||
|
dummy_inputs=LlavaDummyInputsBuilder)
|
||||||
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
|
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
|
||||||
pass
|
pass
|
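The hunks above fold the old BaseLlavaProcessingMixin / BaseLlavaProfilingInfo pair into a single BaseLlavaProcessingInfo plus a LlavaDummyInputsBuilder, and register_processor now wires together the processor, an info factory and a dummy-inputs builder. A minimal sketch of the info-dispatch idea from _build_llava_or_pixtral_hf_info, using made-up stand-in classes rather than the real LlavaConfig / PixtralVisionConfig types:

# Illustrative sketch only: stand-ins approximating the dispatch above.
from dataclasses import dataclass


@dataclass
class _FakeVisionConfig:
    model_type: str


@dataclass
class _FakeHFConfig:
    vision_config: _FakeVisionConfig


class _LlavaInfo:      # stands in for LlavaProcessingInfo
    pass


class _PixtralInfo:    # stands in for PixtralHFProcessingInfo
    pass


def build_info(hf_config: _FakeHFConfig):
    # Same shape as _build_llava_or_pixtral_hf_info: pick the info class
    # from the vision backbone type, fall back to the plain LLaVA info.
    if hf_config.vision_config.model_type == "pixtral":
        return _PixtralInfo()
    return _LlavaInfo()


assert isinstance(build_info(_FakeHFConfig(_FakeVisionConfig("pixtral"))), _PixtralInfo)

Keeping the dispatch in a small factory lets one registration entry cover both the CLIP/SigLIP and Pixtral vision backbones.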
@@ -1,6 +1,7 @@
+from abc import abstractmethod
 from functools import cached_property
 from typing import (Final, Iterable, List, Literal, Mapping, Optional,
-                    Protocol, Set, Tuple, TypedDict, Union)
+                    Protocol, Set, Tuple, TypedDict, TypeVar, Union)

 import torch
 import torch.nn as nn
@@ -16,13 +17,12 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
 from vllm.multimodal.parse import ImageSize
-from vllm.multimodal.profiling import BaseProfilingInfo
 from vllm.sequence import IntermediateTensors

 from .clip import CLIPVisionModel
 from .interfaces import SupportsMultiModal, SupportsPP
-from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingMixin,
+from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
-                    BaseLlavaProfilingInfo, LlavaLikeConfig,
+                    LlavaDummyInputsBuilder, LlavaLikeConfig,
                     LlavaMultiModalProjector, init_vision_tower_for_llava)
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
@@ -65,23 +65,23 @@ class LlavaNextLikeConfig(LlavaLikeConfig, Protocol):
     image_grid_pinpoints: Final[list[list[int]]]


-class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):
+class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):

-    def _get_hf_config(self) -> LlavaNextLikeConfig:
+    def get_hf_config(self) -> LlavaNextLikeConfig:
         return self.ctx.get_hf_config(LlavaNextConfig)

-    def _get_hf_processor(self):
+    def get_hf_processor(self):
         return self.ctx.get_hf_processor(LlavaNextProcessor)

     # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113
-    def _get_num_image_tokens(
+    def get_num_image_tokens(
         self,
         *,
         image_width: int,
         image_height: int,
     ) -> int:
-        hf_config = self._get_hf_config()
+        hf_config = self.get_hf_config()
-        vision_encoder_info = self._get_vision_encoder_info()
+        vision_encoder_info = self.get_vision_encoder_info()

         base_feature_size = self._apply_feature_select_strategy(
             hf_config.vision_feature_select_strategy,
@@ -140,16 +140,13 @@ class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):

         return (unpadded_features, newline_features)

+    def get_image_size_with_most_features(self) -> ImageSize:
-class LlavaNextProfilingInfo(LlavaNextProcessingMixin, BaseLlavaProfilingInfo):
+        hf_config = self.get_hf_config()

-    def _get_image_size_with_most_features(self) -> ImageSize:
-        hf_config = self._get_hf_config()

         largest_feature_size, largest_feature_pinpoint = 0, None
         for (height, width) in hf_config.image_grid_pinpoints:
-            feat_size = self._get_num_image_tokens(image_width=width,
+            feat_size = self.get_num_image_tokens(image_width=width,
-                                                   image_height=height)
+                                                  image_height=height)
             if feat_size > largest_feature_size:
                 largest_feature_size = feat_size
                 largest_feature_pinpoint = ImageSize(width=width,
@@ -161,11 +158,23 @@ class LlavaNextProfilingInfo(LlavaNextProcessingMixin, BaseLlavaProfilingInfo):
         return largest_feature_pinpoint


-class LlavaNextMultiModalProcessor(LlavaNextProcessingMixin,
+_I = TypeVar("_I", bound=LlavaNextProcessingInfo)
-                                   BaseLlavaMultiModalProcessor):

-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return LlavaNextProfilingInfo(self.ctx)
+class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]):

+    # Copied from BaseMultiModalProcessor
+    @abstractmethod
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        raise NotImplementedError


+class LlavaNextMultiModalProcessor(
+        BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]):

     def _get_mm_fields_config(
         self,
@@ -179,7 +188,9 @@ class LlavaNextMultiModalProcessor(LlavaNextProcessingMixin,
         )


-@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor,
+                                        info=LlavaNextProcessingInfo,
+                                        dummy_inputs=LlavaDummyInputsBuilder)
 class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP):

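In this file get_image_size_with_most_features moves onto LlavaNextProcessingInfo and scans image_grid_pinpoints for the size that yields the largest feature count. A hedged, standalone sketch of that scan, with num_tokens standing in for get_num_image_tokens:

# Standalone sketch of the pinpoint scan shown above; not the real class.
from typing import Callable, List, Tuple


def pick_largest_pinpoint(
    pinpoints: List[List[int]],
    num_tokens: Callable[[int, int], int],
) -> Tuple[int, int]:
    largest_size, largest_pinpoint = 0, None
    for height, width in pinpoints:
        feat_size = num_tokens(width, height)
        if feat_size > largest_size:
            largest_size = feat_size
            largest_pinpoint = (width, height)
    if largest_size == 0 or largest_pinpoint is None:
        raise ValueError("Cannot have a largest feature size of 0!")
    return largest_pinpoint


# Toy usage with a made-up token count proportional to area:
print(pick_largest_pinpoint([[336, 672], [672, 672]], lambda w, h: w * h // 196))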
@@ -17,12 +17,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                     NestedTensors)
-from vllm.multimodal.parse import (ImageSize, VideoEmbeddingItems,
+from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
-                                   VideoProcessorItems)
+                                   VideoEmbeddingItems, VideoProcessorItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessingMixin,
+                                        BaseProcessingInfo, PromptReplacement)
-                                        PromptReplacement)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
-from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of

@@ -47,33 +46,52 @@ class LlavaNextVideoPixelInputs(TypedDict):
     """


-class LlavaNextVideoProcessingMixin(ProcessingMixin):
+class LlavaNextVideoProcessingInfo(BaseProcessingInfo):

-    def _get_hf_config(self):
+    def get_hf_config(self):
         return self.ctx.get_hf_config(LlavaNextVideoConfig)

-    def _get_vision_encoder_info(self):
+    def get_vision_encoder_info(self):
-        return get_vision_encoder_info(self._get_hf_config())
+        return get_vision_encoder_info(self.get_hf_config())

-    def _get_hf_processor(self):
+    def get_hf_processor(self):
         return self.ctx.get_hf_processor(LlavaNextVideoProcessor)

+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"video": 1}

+    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+        target_width, target_height = self.get_image_size_with_most_features()

+        max_video_tokens = self.get_num_video_tokens(
+            image_width=target_width,
+            image_height=target_height,
+            num_frames=self.get_num_frames_with_most_features(seq_len),
+        )

+        return {"video": max_video_tokens}

+    def get_image_size_with_most_features(self) -> ImageSize:
+        vision_encoder_info = self.get_vision_encoder_info()
+        width = height = vision_encoder_info.get_image_size()
+        return ImageSize(width=width, height=height)

     def _get_num_frame_tokens(
         self,
         *,
         image_width: int,
         image_height: int,
     ) -> int:
-        hf_config = self._get_hf_config()
+        hf_config = self.get_hf_config()
         spatial_pool_stride = hf_config.spatial_pool_stride

-        vision_encoder_info = self._get_vision_encoder_info()
+        vision_encoder_info = self.get_vision_encoder_info()
         patch_grid_length = vision_encoder_info.get_patch_grid_length()
         pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)

         return pooled_grid_length * pooled_grid_length

-    def _get_num_video_tokens(
+    def get_num_video_tokens(
         self,
         *,
         image_width: int,
@@ -87,37 +105,14 @@ class LlavaNextVideoProcessingMixin(ProcessingMixin):

         return num_frame_tokens * num_frames


-class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
-                                  BaseProfilingInfo):

-    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"video": 1}

-    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
-        target_width, target_height = self._get_image_size_with_most_features()

-        max_video_tokens = self._get_num_video_tokens(
-            image_width=target_width,
-            image_height=target_height,
-            num_frames=self._get_dummy_num_frames(seq_len),
-        )

-        return {"video": max_video_tokens}

-    def _get_image_size_with_most_features(self) -> ImageSize:
-        vision_encoder_info = self._get_vision_encoder_info()
-        width = height = vision_encoder_info.get_image_size()
-        return ImageSize(width=width, height=height)

     def _get_max_video_frames(self, max_tokens: int) -> int:
-        target_width, target_height = self._get_image_size_with_most_features()
+        target_width, target_height = self.get_image_size_with_most_features()

         num_frames = 0

         while True:
             next_num_frames = num_frames + 1
-            next_max_tokens = self._get_num_video_tokens(
+            next_max_tokens = self.get_num_video_tokens(
                 image_width=target_width,
                 image_height=target_height,
                 num_frames=next_num_frames,
@@ -130,7 +125,7 @@ class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,

         return num_frames

-    def _get_dummy_num_frames(self, seq_len: int) -> int:
+    def get_num_frames_with_most_features(self, seq_len: int) -> int:
         mm_config = self.ctx.get_mm_config()
         max_videos = mm_config.limit_per_prompt.get("video", 1)

@@ -138,6 +133,10 @@ class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,

         return max(max_total_frames // max(max_videos, 1), 1)


+class LlavaNextVideoDummyInputsBuilder(
+        BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo]):

     def get_dummy_processor_inputs(
         self,
         seq_len: int,
@@ -145,16 +144,20 @@ class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
     ) -> ProcessorInputs:
         num_videos = mm_counts.get("video", 0)

-        processor = self._get_hf_processor()
+        processor = self.info.get_hf_processor()
         video_token = processor.video_token
-        target_width, target_height = self._get_image_size_with_most_features()
+        target_width, target_height = \
+            self.info.get_image_size_with_most_features()
+        target_num_frames = \
+            self.info.get_num_frames_with_most_features(seq_len)

         mm_data = {
             "video":
             self._get_dummy_videos(
                 width=target_width,
                 height=target_height,
-                num_frames=self._get_dummy_num_frames(seq_len),
+                num_frames=target_num_frames,
                 num_videos=num_videos,
             )
         }
@@ -165,11 +168,8 @@ class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
         )


-class LlavaNextVideoMultiModalProcessor(LlavaNextVideoProcessingMixin,
+class LlavaNextVideoMultiModalProcessor(
-                                        BaseMultiModalProcessor):
+        BaseMultiModalProcessor[LlavaNextVideoProcessingInfo]):

-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return LlavaNextVideoProfilingInfo(self.ctx)

     def _get_mm_fields_config(
         self,
@@ -184,7 +184,7 @@ class LlavaNextVideoMultiModalProcessor(LlavaNextVideoProcessingMixin,
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        hf_config = self._get_hf_config()
+        hf_config = self.info.get_hf_config()
         video_token_id = hf_config.video_token_index

         def get_replacement(item_idx: int):
@@ -195,7 +195,7 @@ class LlavaNextVideoMultiModalProcessor(LlavaNextVideoProcessingMixin,
                 num_video_tokens = videos.get_feature_size(item_idx)
             else:
                 image_size = videos.get_frame_size(item_idx)
-                num_video_tokens = self._get_num_video_tokens(
+                num_video_tokens = self.info.get_num_video_tokens(
                     image_width=image_size.width,
                     image_height=image_size.height,
                     num_frames=videos.get_num_frames(item_idx),
@@ -269,7 +269,11 @@ class LlavaNextMultiModalProjector(nn.Module):
         return hidden_states


-@MULTIMODAL_REGISTRY.register_processor(LlavaNextVideoMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(
+    LlavaNextVideoMultiModalProcessor,
+    info=LlavaNextVideoProcessingInfo,
+    dummy_inputs=LlavaNextVideoDummyInputsBuilder,
+)
 class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
                                              SupportsPP):

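The video profiling logic above folds _get_max_video_frames and the renamed get_num_frames_with_most_features into LlavaNextVideoProcessingInfo. A self-contained sketch of the same frame-budget loop, with tokens_per_video standing in for get_num_video_tokens evaluated at the largest image size:

# Standalone sketch of the frame-budget search shown above.
from typing import Callable


def max_video_frames(max_tokens: int,
                     tokens_per_video: Callable[[int], int]) -> int:
    # Grow the frame count until the next frame would exceed the budget.
    num_frames = 0
    while True:
        next_num_frames = num_frames + 1
        if tokens_per_video(next_num_frames) > max_tokens:
            break
        num_frames = next_num_frames
    return num_frames


def frames_with_most_features(seq_len: int, max_videos: int,
                              tokens_per_video: Callable[[int], int]) -> int:
    max_total_frames = max_video_frames(seq_len, tokens_per_video)
    return max(max_total_frames // max(max_videos, 1), 1)


# Toy example: 144 tokens per frame, 4096-token budget, one video per prompt.
print(frames_with_most_features(4096, 1, lambda n: 144 * n))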
@@ -17,19 +17,20 @@ from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
-from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
+                                    NestedTensors)
-                                   VideoEmbeddingItems, VideoProcessorItems)
+from vllm.multimodal.parse import (MultiModalDataItems, VideoEmbeddingItems,
-from vllm.multimodal.processing import MultiModalFieldConfig, PromptReplacement
+                                   VideoProcessorItems)
-from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
+from vllm.multimodal.processing import PromptReplacement
+from vllm.multimodal.profiling import ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of

 from .clip import CLIPVisionModel
 from .interfaces import SupportsMultiModal, SupportsPP
-from .llava import BaseLlavaProfilingInfo, init_vision_tower_for_llava
+from .llava import LlavaDummyInputsBuilder, init_vision_tower_for_llava
-from .llava_next import (LlavaNextLikeConfig, LlavaNextMultiModalProcessor,
+from .llava_next import (BaseLlavaNextMultiModalProcessor, LlavaNextLikeConfig,
-                         LlavaNextProcessingMixin)
+                         LlavaNextProcessingInfo)
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
@@ -89,14 +90,23 @@ class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol):
     video_token_index: Final[int]


-class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
+class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):

-    def _get_hf_config(self) -> LlavaOnevisionLikeConfig:
+    def get_hf_config(self) -> LlavaOnevisionLikeConfig:
         return self.ctx.get_hf_config(LlavaOnevisionConfig)

-    def _get_hf_processor(self):
+    def get_hf_processor(self):
         return self.ctx.get_hf_processor(LlavaOnevisionProcessor)

+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None, "video": None}

+    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+        return {
+            "image": self.get_max_image_tokens(),
+            "video": self.get_max_video_tokens(seq_len),
+        }

     # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
     # with additional logic afterwards taken from LlavaOnevisionProcessor
     def _get_num_unpadded_features(
@@ -141,16 +151,16 @@ class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
         image_width: int,
         image_height: int,
     ) -> int:
-        hf_config = self._get_hf_config()
+        hf_config = self.get_hf_config()
         spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2)

-        vision_encoder_info = self._get_vision_encoder_info()
+        vision_encoder_info = self.get_vision_encoder_info()
         patch_grid_length = vision_encoder_info.get_patch_grid_length()
         pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)

         return pooled_grid_length * pooled_grid_length

-    def _get_num_video_tokens(
+    def get_num_video_tokens(
         self,
         *,
         image_width: int,
@@ -164,43 +174,14 @@ class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):

         return num_frame_tokens * num_frames + 1  # Newline token


-class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
-                                  BaseLlavaProfilingInfo):

-    def _get_image_size_with_most_features(self) -> ImageSize:
-        hf_config = self._get_hf_config()
-        largest_feature_size, largest_feature_pinpoint = 0, None
-        for (height, width) in hf_config.image_grid_pinpoints:
-            feat_size = self._get_num_image_tokens(image_width=width,
-                                                   image_height=height)
-            if feat_size > largest_feature_size:
-                largest_feature_size = feat_size
-                largest_feature_pinpoint = ImageSize(width=width,
-                                                     height=height)

-        if largest_feature_size == 0 or largest_feature_pinpoint is None:
-            raise ValueError("Cannot have a largest feature size of 0!")

-        return largest_feature_pinpoint

-    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"image": None, "video": None}

-    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
-        return {
-            "image": self._get_max_image_tokens(),
-            "video": self._get_max_video_tokens(seq_len),
-        }

     def _get_max_video_frames(self, max_tokens: int) -> int:
-        target_width, target_height = self._get_image_size_with_most_features()
+        target_width, target_height = self.get_image_size_with_most_features()

         num_frames = 0

         while True:
             next_num_frames = num_frames + 1
-            next_max_tokens = self._get_num_video_tokens(
+            next_max_tokens = self.get_num_video_tokens(
                 image_width=target_width,
                 image_height=target_height,
                 num_frames=next_num_frames,
@@ -213,12 +194,12 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,

         return num_frames

-    def _get_dummy_num_frames(self, seq_len: int) -> int:
+    def get_num_frames_with_most_features(self, seq_len: int) -> int:
         mm_config = self.ctx.get_mm_config()
         max_images = mm_config.limit_per_prompt.get("image", 1)
         max_videos = mm_config.limit_per_prompt.get("video", 1)

-        max_image_tokens = self._get_max_image_tokens() * max_images
+        max_image_tokens = self.get_max_image_tokens() * max_images
         max_total_frames = self._get_max_video_frames(seq_len -
                                                       max_image_tokens)
         max_frames_per_video = min(max_total_frames // max(max_videos, 1),
@@ -226,15 +207,19 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,

         return max(max_frames_per_video, 1)

-    def _get_max_video_tokens(self, seq_len: int) -> int:
+    def get_max_video_tokens(self, seq_len: int) -> int:
-        target_width, target_height = self._get_image_size_with_most_features()
+        target_width, target_height = self.get_image_size_with_most_features()

-        return self._get_num_video_tokens(
+        return self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
-            num_frames=self._get_dummy_num_frames(seq_len),
+            num_frames=self.get_num_frames_with_most_features(seq_len),
         )


+class LlavaOnevisionDummyInputsBuilder(
+        LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]):

     def get_dummy_processor_inputs(
         self,
         seq_len: int,
@@ -243,10 +228,14 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
         num_images = mm_counts.get("image", 0)
         num_videos = mm_counts.get("video", 0)

-        processor = self._get_hf_processor()
+        processor = self.info.get_hf_processor()
         image_token = processor.image_token
         video_token = processor.video_token
-        target_width, target_height = self._get_image_size_with_most_features()
+        target_width, target_height = \
+            self.info.get_image_size_with_most_features()
+        target_num_frames = \
+            self.info.get_num_frames_with_most_features(seq_len)

         mm_data = {
             "image":
@@ -257,7 +246,7 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
             self._get_dummy_videos(
                 width=target_width,
                 height=target_height,
-                num_frames=self._get_dummy_num_frames(seq_len),
+                num_frames=target_num_frames,
                 num_videos=num_videos,
             )
         }
@@ -268,11 +257,8 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
         )


-class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
+class LlavaOnevisionMultiModalProcessor(
-                                        LlavaNextMultiModalProcessor):
+        BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]):

-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return LlavaOnevisionProfilingInfo(self.ctx)

     def _get_mm_fields_config(
         self,
@@ -303,7 +289,7 @@ class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
             mm_kwargs=mm_kwargs,
         )

-        processor = self._get_hf_processor()
+        processor = self.info.get_hf_processor()
         video_token = processor.video_token

         # LLaVA-OneVision processor doesn't support multiple videos
@@ -345,7 +331,7 @@ class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
             out_mm_kwargs=out_mm_kwargs,
         )

-        hf_config = self._get_hf_config()
+        hf_config = self.info.get_hf_config()
         video_token_id = hf_config.video_token_index

         def get_video_replacement(item_idx: int):
@@ -356,7 +342,7 @@ class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
                 num_video_tokens = videos.get_feature_size(item_idx)
             else:
                 image_size = videos.get_frame_size(item_idx)
-                num_video_tokens = self._get_num_video_tokens(
+                num_video_tokens = self.info.get_num_video_tokens(
                     image_width=image_size.width,
                     image_height=image_size.height,
                     num_frames=videos.get_num_frames(item_idx),
@@ -393,7 +379,10 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
         return hidden_states


-@MULTIMODAL_REGISTRY.register_processor(LlavaOnevisionMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(
+    LlavaOnevisionMultiModalProcessor,
+    info=LlavaOnevisionProcessingInfo,
+    dummy_inputs=LlavaOnevisionDummyInputsBuilder)
 class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                                              SupportsPP):

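For LLaVA-OneVision the per-frame token count comes from pooling the vision patch grid by spatial_pool_stride, and get_num_video_tokens adds one trailing newline token per video. A small sketch of that arithmetic, with illustrative values rather than numbers taken from a real config:

# Sketch of the per-frame/per-video token arithmetic shown above.
import math


def num_frame_tokens(patch_grid_length: int, spatial_pool_stride: int = 2) -> int:
    pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)
    return pooled_grid_length * pooled_grid_length


def num_video_tokens(patch_grid_length: int, num_frames: int) -> int:
    # num_frame_tokens * num_frames + 1 newline token, as in
    # get_num_video_tokens above.
    return num_frame_tokens(patch_grid_length) * num_frames + 1


# e.g. a 24x24 patch grid pooled with stride 2 -> 144 tokens per frame.
print(num_frame_tokens(24), num_video_tokens(24, 16))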
@@ -34,13 +34,12 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
-                                   ImageSize)
+                                   ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessingMixin,
+                                        BaseProcessingInfo,
-                                        PromptReplacement,
+                                        BoundPromptReplacement,
-                                        _BoundPromptReplacement,
+                                        PlaceholderInfo, PromptReplacement)
-                                        _PlaceholderInfo)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
-from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of

@@ -302,9 +301,9 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
         return image_features_hd_newline


-class Phi3VProcessingMixin(ProcessingMixin):
+class Phi3VProcessingInfo(BaseProcessingInfo):

-    def _get_hf_processor(
+    def get_hf_processor(
         self,
         *,
         num_crops: Optional[int] = None,
@@ -314,39 +313,42 @@ class Phi3VProcessingMixin(ProcessingMixin):

         return self.ctx.get_hf_processor()

-    def _get_num_image_tokens(
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None}

+    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+        target_width, target_height = self.get_image_size_with_most_features()

+        max_image_tokens = self.get_num_image_tokens(
+            image_width=target_width,
+            image_height=target_height,
+            processor=None,
+        )

+        return {"image": max_image_tokens}

+    def get_num_image_tokens(
         self,
         *,
         image_width: int,
         image_height: int,
+        processor: Optional[ProcessorMixin],
     ) -> int:
-        processor = self._get_hf_processor()
+        if processor is None:
+            processor = self.get_hf_processor()

         return processor.calc_num_image_tokens_from_image_size(  # type: ignore
             width=image_width,
             height=image_height,
         )

+    def get_image_size_with_most_features(self) -> ImageSize:
-class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo):

-    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"image": None}

-    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
-        target_width, target_height = self._get_image_size_with_most_features()

-        max_image_tokens = self._get_num_image_tokens(
-            image_width=target_width,
-            image_height=target_height,
-        )

-        return {"image": max_image_tokens}

-    def _get_image_size_with_most_features(self) -> ImageSize:
         # Result in the max possible feature size (h:w = 16:1)
         return ImageSize(height=8000, width=50)


+class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]):

     def get_dummy_processor_inputs(
         self,
         seq_len: int,
@@ -354,7 +356,8 @@ class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo):
     ) -> ProcessorInputs:
         num_images = mm_counts.get("image", 0)

-        target_width, target_height = self._get_image_size_with_most_features()
+        target_width, target_height = \
+            self.info.get_image_size_with_most_features()

         mm_data = {
             "image":
@@ -363,7 +366,7 @@ class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo):
                                    num_images=num_images)
         }

-        hf_processor = self._get_hf_processor()
+        hf_processor = self.info.get_hf_processor()
         image_tokens: list[str] = hf_processor.img_tokens  # type: ignore

         return ProcessorInputs(
@@ -372,10 +375,7 @@ class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo):
         )


-class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
+class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):

-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return Phi3VProfilingInfo(self.ctx)

     def _call_hf_processor(
         self,
@@ -416,10 +416,10 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
         hf_processor_mm_kwargs: Mapping[str, Any],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs)
+        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
         image_tokens: list[str] = hf_processor.img_tokens  # type: ignore

-        tokenizer = self._get_tokenizer()
+        tokenizer = self.info.get_tokenizer()
         bos_token_id = tokenizer.bos_token_id
         assert isinstance(bos_token_id, int)

@@ -431,9 +431,10 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
                 num_image_tokens = images.get_feature_size(item_idx)
             else:
                 image_size = images.get_image_size(item_idx)
-                num_image_tokens = self._get_num_image_tokens(
+                num_image_tokens = self.info.get_num_image_tokens(
                     image_width=image_size.width,
                     image_height=image_size.height,
+                    processor=hf_processor,
                 )

             return [_IMAGE_TOKEN_ID] * num_image_tokens + [bos_token_id]
@@ -451,9 +452,9 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
     def _apply_prompt_replacements(
         self,
         token_ids: list[int],
-        mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
+        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
         mm_item_counts: Mapping[str, int],
-    ) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]:
+    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]:
         token_ids, text, placeholders = super()._apply_prompt_replacements(
             token_ids=token_ids,
             mm_prompt_repls=mm_prompt_repls,
@@ -466,7 +467,7 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
             token_ids = [token_ids[0], *token_ids[2:]]
             placeholders = {
                 modality: [
-                    _PlaceholderInfo(
+                    PlaceholderInfo(
                         modality=p.modality,
                         item_idx=p.item_idx,
                         start_idx=p.start_idx - 1,
@@ -499,7 +500,9 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
         return result


-@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor,
+                                        info=Phi3VProcessingInfo,
+                                        dummy_inputs=Phi3VDummyInputsBuilder)
 class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
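Phi-3-Vision's get_num_image_tokens now takes an optional processor argument, so callers that already built the HF processor (as _get_prompt_replacements does above) can reuse it instead of constructing a new one per call. A sketch of that default-to-factory pattern with stand-in names, not the real Phi-3V classes:

# Illustrative only: the "optional processor" pattern from the diff above.
from typing import Callable, Optional


class _FakeProcessor:
    def calc_num_image_tokens_from_image_size(self, width: int, height: int) -> int:
        # Made-up formula purely for illustration.
        return (width // 336 + 1) * (height // 336 + 1) * 144


def get_num_image_tokens(image_width: int,
                         image_height: int,
                         processor: Optional[_FakeProcessor],
                         make_processor: Callable[[], _FakeProcessor]) -> int:
    if processor is None:
        processor = make_processor()
    return processor.calc_num_image_tokens_from_image_size(width=image_width,
                                                           height=image_height)


print(get_num_image_tokens(672, 672, None, _FakeProcessor))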
@@ -38,11 +38,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                     NestedTensors)
-from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataParser
+from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
+                                   MultiModalDataParser)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessingMixin,
+                                        BaseProcessingInfo, PromptReplacement)
-                                        PromptReplacement)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
-from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsMultiModal, SupportsPP
@@ -80,12 +80,12 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
     return feat_lengths, output_lengths


-class Qwen2AudioProcessingMixin(ProcessingMixin):
+class Qwen2AudioProcessingInfo(BaseProcessingInfo):

-    def _get_hf_config(self):
+    def get_hf_config(self):
         return self.ctx.get_hf_config(Qwen2AudioConfig)

-    def _get_hf_processor(
+    def get_hf_processor(
         self,
         *,
         # Ignored in initialization
@@ -93,36 +93,37 @@ class Qwen2AudioProcessingMixin(ProcessingMixin):
     ) -> Qwen2AudioProcessor:
         return self.ctx.get_hf_processor(Qwen2AudioProcessor)

-    def _get_feature_extractor(
+    def get_feature_extractor(
         self,
         *,
         # Ignored in initialization
         sampling_rate: Optional[int] = None,
     ) -> WhisperFeatureExtractor:
-        hf_processor = self._get_hf_processor(sampling_rate=sampling_rate)
+        hf_processor = self.get_hf_processor(sampling_rate=sampling_rate)
         feature_extractor = hf_processor.feature_extractor  # type: ignore
         assert isinstance(feature_extractor, WhisperFeatureExtractor)
         return feature_extractor


-class Qwen2AudioProfilingInfo(Qwen2AudioProcessingMixin, BaseProfilingInfo):

     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"audio": None}

     def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
-        hf_config = self._get_hf_config()
+        hf_config = self.get_hf_config()
         max_source_positions = hf_config.audio_config.max_source_positions
         max_output_lengths = (max_source_positions - 2) // 2 + 1

         return {"audio": max_output_lengths}


+class Qwen2AudioDummyInputsBuilder(
+        BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):

     def get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
-        feature_extractor = self._get_feature_extractor()
+        feature_extractor = self.info.get_feature_extractor()

         sampling_rate = feature_extractor.sampling_rate
         audio_len = feature_extractor.chunk_length * sampling_rate
@@ -139,14 +140,11 @@ class Qwen2AudioProfilingInfo(Qwen2AudioProcessingMixin, BaseProfilingInfo):
         )


-class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
+class Qwen2AudioMultiModalProcessor(
-                                    BaseMultiModalProcessor):
+        BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):

-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return Qwen2AudioProfilingInfo(self.ctx)

     def _get_data_parser(self) -> MultiModalDataParser:
-        feature_extractor = self._get_feature_extractor()
+        feature_extractor = self.info.get_feature_extractor()
         return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)

     def _call_hf_processor(
@@ -161,7 +159,7 @@ class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
         if audios:
             mm_data["audios"] = audios

-        feature_extractor = self._get_feature_extractor(**mm_kwargs)
+        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
         mm_kwargs = dict(
             **mm_kwargs,
             sampling_rate=feature_extractor.sampling_rate,
@@ -194,7 +192,7 @@ class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        hf_config = self._get_hf_config()
+        hf_config = self.info.get_hf_config()
         placeholder = hf_config.audio_token_index

         feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
@@ -234,10 +232,13 @@ class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
         # has already performed processing for multi-audio input when the input
         # audios are short (the corresponding placeholders may take up fewer
         # tokens than the number of audio items)
-        return not hasattr(self._get_hf_processor(), "audio_token")
+        return not hasattr(self.info.get_hf_processor(), "audio_token")
|
||||||
|
|
||||||
|
|
||||||
@MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor)
|
@MULTIMODAL_REGISTRY.register_processor(
|
||||||
|
Qwen2AudioMultiModalProcessor,
|
||||||
|
info=Qwen2AudioProcessingInfo,
|
||||||
|
dummy_inputs=Qwen2AudioDummyInputsBuilder)
|
||||||
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
|
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||||
SupportsPP):
|
SupportsPP):
|
||||||
|
|
||||||
|
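The Qwen2-Audio hunks above (presumably vllm/model_executor/models/qwen2_audio.py) replace the old ProcessingMixin/ProfilingInfo pair with a ProcessingInfo class, a DummyInputsBuilder, and a generic processor, wired together through the registry decorator. The sketch below shows how a model plugin would assemble the three pieces under the new API; the My* names, token counts, and placeholder bodies are hypothetical, and only the base classes and the register_processor keywords come from this diff.

# Hedged sketch of the new three-part plugin layout (hypothetical My* names;
# base classes and decorator keywords as introduced by this commit).
from collections.abc import Mapping
from typing import Optional

import torch.nn as nn

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs


class MyProcessingInfo(BaseProcessingInfo):
    # Limit/token bookkeeping now lives here instead of on a ProfilingInfo class.

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"audio": None}  # None = unlimited items per prompt

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        return {"audio": 750}  # illustrative worst-case token count


class MyDummyInputsBuilder(BaseDummyInputsBuilder[MyProcessingInfo]):
    # The builder reaches helpers through self.info instead of inheriting a mixin.

    def get_dummy_processor_inputs(self, seq_len, mm_counts) -> ProcessorInputs:
        return ProcessorInputs(prompt_text="", mm_data={})  # trivial placeholder


class MyMultiModalProcessor(BaseMultiModalProcessor[MyProcessingInfo]):
    ...  # _get_mm_fields_config / _get_prompt_replacements etc. omitted


@MULTIMODAL_REGISTRY.register_processor(MyMultiModalProcessor,
                                        info=MyProcessingInfo,
                                        dummy_inputs=MyDummyInputsBuilder)
class MyModelForConditionalGeneration(nn.Module):
    ...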
@@ -57,11 +57,10 @@ from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalFieldConfig, MultiModalKwargs,
                                     NestedTensors, VideoItem)
 from vllm.multimodal.parse import (ImageSize, ModalityDataItems,
-                                   MultiModalDataParser)
+                                   MultiModalDataItems, MultiModalDataParser)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessingMixin,
-                                        PromptReplacement)
-from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
+                                        BaseProcessingInfo, PromptReplacement)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
@@ -709,12 +708,12 @@ class Qwen2MultiModalDataParser(MultiModalDataParser):
         return super()._parse_video_data(data)


-class Qwen2VLProcessingMixin(ProcessingMixin):
+class Qwen2VLProcessingInfo(BaseProcessingInfo):

-    def _get_hf_config(self):
+    def get_hf_config(self):
         return self.ctx.get_hf_config(Qwen2VLConfig)

-    def _get_hf_processor(
+    def get_hf_processor(
         self,
         *,
         min_pixels: Optional[int] = None,
@@ -736,18 +735,27 @@ class Qwen2VLProcessingMixin(ProcessingMixin):

         return hf_processor

-    def _get_image_processor(
+    def get_image_processor(
         self,
         *,
         min_pixels: Optional[int] = None,
         max_pixels: Optional[int] = None,
     ):
-        hf_processor = self._get_hf_processor(min_pixels=min_pixels,
-                                              max_pixels=max_pixels)
+        hf_processor = self.get_hf_processor(min_pixels=min_pixels,
+                                             max_pixels=max_pixels)
         image_processor = hf_processor.image_processor  # type: ignore
         assert isinstance(image_processor, Qwen2VLImageProcessor)
         return image_processor

+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None, "video": None}
+
+    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+        return {
+            "image": self.get_max_image_tokens(),
+            "video": self.get_max_video_tokens(seq_len),
+        }
+
     def _get_vision_info(
         self,
         *,
@@ -755,15 +763,17 @@ class Qwen2VLProcessingMixin(ProcessingMixin):
         image_height: int,
         num_frames: int = 1,
         do_resize: bool = True,
+        image_processor: Optional[Qwen2VLImageProcessor],
     ) -> tuple[ImageSize, int]:
-        hf_config = self._get_hf_config()
+        if image_processor is None:
+            image_processor = self.get_image_processor()
+
+        hf_config = self.get_hf_config()
         vision_config = hf_config.vision_config
         patch_size = vision_config.patch_size
         merge_size = vision_config.spatial_merge_size
         temporal_patch_size = vision_config.temporal_patch_size

-        image_processor = self._get_image_processor()
-
         if do_resize:
             resized_height, resized_width = smart_resize(
                 height=image_height,
@@ -787,70 +797,65 @@ class Qwen2VLProcessingMixin(ProcessingMixin):

         return preprocessed_size, num_vision_tokens

-    def _get_num_image_tokens(
+    def get_num_image_tokens(
         self,
         *,
         image_width: int,
         image_height: int,
+        image_processor: Optional[Qwen2VLImageProcessor],
     ) -> int:
         _, num_image_tokens = self._get_vision_info(
             image_width=image_width,
             image_height=image_height,
+            image_processor=image_processor,
         )
         return num_image_tokens

-    def _get_num_video_tokens(
+    def get_num_video_tokens(
         self,
         *,
         image_width: int,
         image_height: int,
         num_frames: int,
+        image_processor: Optional[Qwen2VLImageProcessor],
     ) -> int:
         _, num_video_tokens = self._get_vision_info(
             image_width=image_width,
             image_height=image_height,
             num_frames=num_frames,
+            image_processor=image_processor,
         )
         return num_video_tokens

-
-class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
-
-    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"image": None, "video": None}
-
-    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
-        return {
-            "image": self._get_max_image_tokens(),
-            "video": self._get_max_video_tokens(seq_len),
-        }
-
-    def _get_image_size_with_most_features(self) -> ImageSize:
+    def get_image_size_with_most_features(self) -> ImageSize:
         max_image_size, _ = self._get_vision_info(
             image_width=9999999,
             image_height=9999999,
+            image_processor=None,
         )
         return max_image_size

-    def _get_max_image_tokens(self) -> int:
-        target_width, target_height = self._get_image_size_with_most_features()
+    def get_max_image_tokens(self) -> int:
+        target_width, target_height = self.get_image_size_with_most_features()

-        return self._get_num_image_tokens(
+        return self.get_num_image_tokens(
             image_width=target_width,
             image_height=target_height,
+            image_processor=None,
         )

     def _get_max_video_frames(self, max_tokens: int) -> int:
-        target_width, target_height = self._get_image_size_with_most_features()
+        target_width, target_height = self.get_image_size_with_most_features()

         num_frames = 0

         while True:
             next_num_frames = num_frames + 1
-            next_max_tokens = self._get_num_video_tokens(
+            next_max_tokens = self.get_num_video_tokens(
                 image_width=target_width,
                 image_height=target_height,
                 num_frames=next_num_frames,
+                image_processor=None,
             )

             if next_max_tokens > max_tokens:
@@ -860,12 +865,12 @@ class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):

         return num_frames

-    def _get_dummy_num_frames(self, seq_len: int) -> int:
+    def get_num_frames_with_most_features(self, seq_len: int) -> int:
         mm_config = self.ctx.get_mm_config()
         max_images = mm_config.limit_per_prompt.get("image", 1)
         max_videos = mm_config.limit_per_prompt.get("video", 1)

-        max_image_tokens = self._get_max_image_tokens() * max_images
+        max_image_tokens = self.get_max_image_tokens() * max_images
         max_total_frames = self._get_max_video_frames(seq_len -
                                                       max_image_tokens)

@@ -877,15 +882,19 @@ class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):

         return num_frames

-    def _get_max_video_tokens(self, seq_len: int) -> int:
-        target_width, target_height = self._get_image_size_with_most_features()
+    def get_max_video_tokens(self, seq_len: int) -> int:
+        target_width, target_height = self.get_image_size_with_most_features()

-        return self._get_num_video_tokens(
+        return self.get_num_video_tokens(
             image_width=target_width,
             image_height=target_height,
-            num_frames=self._get_dummy_num_frames(seq_len),
+            num_frames=self.get_num_frames_with_most_features(seq_len),
+            image_processor=None,
         )


+class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
+
     def get_dummy_processor_inputs(
         self,
         seq_len: int,
@@ -894,10 +903,14 @@ class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
         num_images = mm_counts.get("image", 0)
         num_videos = mm_counts.get("video", 0)

-        hf_processor = self._get_hf_processor()
+        hf_processor = self.info.get_hf_processor()
         image_token: str = hf_processor.image_token
         video_token: str = hf_processor.video_token
-        target_width, target_height = self._get_image_size_with_most_features()
+
+        target_width, target_height = \
+            self.info.get_image_size_with_most_features()
+        target_num_frames = \
+            self.info.get_num_frames_with_most_features(seq_len)

         mm_data = {
             "image":
@@ -908,7 +921,7 @@ class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
             self._get_dummy_videos(
                 width=target_width,
                 height=target_height,
-                num_frames=self._get_dummy_num_frames(seq_len),
+                num_frames=target_num_frames,
                 num_videos=num_videos,
             )
         }
@@ -919,11 +932,8 @@ class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
         )


-class Qwen2VLMultiModalProcessor(Qwen2VLProcessingMixin,
-                                 BaseMultiModalProcessor):
+class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
+                                 ):

-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return Qwen2VLProfilingInfo(self.ctx)
-
     def _get_data_parser(self) -> MultiModalDataParser:
         return Qwen2MultiModalDataParser()

@@ -934,8 +944,9 @@ class Qwen2VLMultiModalProcessor(Qwen2VLProcessingMixin,
         hf_processor_mm_kwargs: Mapping[str, Any],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs)
-        image_processor = self._get_image_processor(**hf_processor_mm_kwargs)
+        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+        image_processor = self.info.get_image_processor(
+            **hf_processor_mm_kwargs)

         # NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
         # image_token and video_token registered
@@ -991,7 +1002,9 @@ class Qwen2VLMultiModalProcessor(Qwen2VLProcessingMixin,
         )


-@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,
+                                        info=Qwen2VLProcessingInfo,
+                                        dummy_inputs=Qwen2VLDummyInputsBuilder)
 class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
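In the Qwen2-VL hunks the vision-token math moves onto the ProcessingInfo class and gains an explicit image_processor parameter (passing None falls back to get_image_processor()). Below is a small hedged example of querying it; the helper function name and the 1280x720 size are illustrative, and `processor` is assumed to be an already-built BaseMultiModalProcessor[Qwen2VLProcessingInfo].

# Hedged sketch: per-item token counts are now queried via processor.info.
def qwen2_vl_image_tokens(processor, width: int = 1280, height: int = 720) -> int:
    info = processor.info  # Qwen2VLProcessingInfo
    return info.get_num_image_tokens(
        image_width=width,
        image_height=height,
        image_processor=None,  # None -> info.get_image_processor() is used
    )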
@@ -24,11 +24,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                     NestedTensors)
-from vllm.multimodal.parse import MultiModalDataParser
+from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessingMixin,
-                                        PromptReplacement)
-from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
+                                        BaseProcessingInfo, PromptReplacement)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig

@@ -59,9 +58,9 @@ UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
                             UltravoxAudioEmbeddingInputs]


-class UltravoxProcessingMixin(ProcessingMixin):
+class UltravoxProcessingInfo(BaseProcessingInfo):

-    def _get_hf_processor(
+    def get_hf_processor(
         self,
         *,
         # Ignored in initialization
@@ -76,37 +75,38 @@ class UltravoxProcessingMixin(ProcessingMixin):
         hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
         return hf_processor

-    def _get_feature_extractor(
+    def get_feature_extractor(
         self,
         *,
         # Ignored in initialization
         sampling_rate: Optional[int] = None,
     ) -> WhisperFeatureExtractor:
-        hf_processor = self._get_hf_processor(sampling_rate=sampling_rate)
+        hf_processor = self.get_hf_processor(sampling_rate=sampling_rate)
         audio_processor = hf_processor.audio_processor  # type: ignore
         feature_extractor = audio_processor.feature_extractor  # type: ignore
         assert isinstance(feature_extractor, WhisperFeatureExtractor)
         return feature_extractor


-class UltravoxProfilingInfo(UltravoxProcessingMixin, BaseProfilingInfo):
-
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"audio": None}

     def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
-        feature_extractor = self._get_feature_extractor()
+        feature_extractor = self.get_feature_extractor()
         max_audio_tokens = math.ceil(feature_extractor.chunk_length *
                                      _AUDIO_TOKENS_PER_SECOND)

         return {"audio": max_audio_tokens}


+class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
+                                 ):
+
     def get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
-        feature_extractor = self._get_feature_extractor()
+        feature_extractor = self.info.get_feature_extractor()

         sampling_rate = feature_extractor.sampling_rate
         audio_len = feature_extractor.chunk_length * sampling_rate
@@ -123,14 +123,11 @@ class UltravoxProfilingInfo(UltravoxProcessingMixin, BaseProfilingInfo):
     )


-class UltravoxMultiModalProcessor(UltravoxProcessingMixin,
-                                  BaseMultiModalProcessor):
+class UltravoxMultiModalProcessor(
+        BaseMultiModalProcessor[UltravoxProcessingInfo]):

-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return UltravoxProfilingInfo(self.ctx)
-
     def _get_data_parser(self) -> MultiModalDataParser:
-        feature_extractor = self._get_feature_extractor()
+        feature_extractor = self.info.get_feature_extractor()
         return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)

     def _call_hf_processor(
@@ -141,7 +138,7 @@ class UltravoxMultiModalProcessor(UltravoxProcessingMixin,
     ) -> BatchFeature:
         # Text-only input not supported in composite processor
         if not mm_data:
-            tokenizer = self._get_tokenizer()
+            tokenizer = self.info.get_tokenizer()

             prompt_ids = tokenizer.encode(
                 prompt,
@@ -160,7 +157,7 @@ class UltravoxMultiModalProcessor(UltravoxProcessingMixin,
             mm_kwargs=mm_kwargs,
         )

-        feature_extractor = self._get_feature_extractor()
+        feature_extractor = self.info.get_feature_extractor()
         mm_kwargs = dict(
             **mm_kwargs,
             sampling_rate=feature_extractor.sampling_rate,
@@ -208,7 +205,7 @@ class UltravoxMultiModalProcessor(UltravoxProcessingMixin,
         hf_processor_mm_kwargs: Mapping[str, Any],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs)
+        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
         placeholder = hf_processor.audio_token_replacement  # type: ignore

         def get_replacement_ultravox(item_idx: int):
@@ -342,7 +339,10 @@ class ModifiedWhisperEncoder(WhisperEncoder):
         return hidden_states


-@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor,
+                                        info=UltravoxProcessingInfo,
+                                        dummy_inputs=UltravoxDummyInputsBuilder
+                                        )
 class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):

     hf_to_vllm_mapper = WeightsMapper(
@@ -4,12 +4,13 @@ from collections import defaultdict
 from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
 from dataclasses import dataclass, field
 from functools import lru_cache
-from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union
+from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
+                    TypeVar, Union)

 from transformers import BatchFeature, PretrainedConfig, ProcessorMixin

-from vllm import envs
-from vllm.inputs import DummyData, InputProcessingContext
+import vllm.envs as envs
+from vllm.inputs import InputProcessingContext
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                                encode_tokens)
@@ -20,7 +21,9 @@ from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
                      MultiModalInputsV2, MultiModalKwargs,
                      MultiModalKwargsItem, PlaceholderRange)
 from .parse import MultiModalDataItems, MultiModalDataParser
-from .profiling import BaseProfilingInfo
+
+if TYPE_CHECKING:
+    from .profiling import BaseDummyInputsBuilder

 logger = init_logger(__name__)

@@ -46,8 +49,8 @@ class PromptReplacement:
     if it does not depend on the input.
     """

-    def bind(self, tokenizer: AnyTokenizer) -> "_BoundPromptReplacement":
-        return _BoundPromptReplacement(
+    def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptReplacement":
+        return BoundPromptReplacement(
             tokenizer=tokenizer,
             modality=self.modality,
             _target=self.target,
@@ -128,7 +131,7 @@ class _BoundPromptSequence:


 @dataclass
-class _BoundPromptReplacement:
+class BoundPromptReplacement:
     tokenizer: AnyTokenizer = field(repr=False)
     modality: str

@@ -207,7 +210,7 @@ def iter_token_matches(

 @dataclass(repr=False)
 class _PromptReplacementMatch(ABC):
-    prompt_repl: _BoundPromptReplacement
+    prompt_repl: BoundPromptReplacement

     @property
     def modality(self) -> str:
@@ -255,7 +258,7 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch):


 @dataclass
-class _PlaceholderInfo:
+class PlaceholderInfo:
     modality: str
     item_idx: int
     start_idx: int
@@ -274,7 +277,7 @@ class _PlaceholderInfo:

 def find_token_matches(
     prompt: list[int],
-    prompt_repls: Sequence[_BoundPromptReplacement],
+    prompt_repls: Sequence[BoundPromptReplacement],
 ) -> list[_PromptReplacementTokenMatch]:
     """Return each target of :code:`prompt_repls` found in :code:`prompt`."""
     return [
@@ -286,7 +289,7 @@ def find_token_matches(

 def find_text_matches(
     prompt: str,
-    prompt_repls: Sequence[_BoundPromptReplacement],
+    prompt_repls: Sequence[BoundPromptReplacement],
 ) -> list[_PromptReplacementTextMatch]:
     """Return each target of :code:`prompt_repls` found in :code:`prompt`."""
     return [
@@ -390,9 +393,9 @@ def replace_text_matches(
 def _iter_modality_placeholders(
     prompt: list[int],
     modality: str,
-    modality_repls: Sequence[_BoundPromptReplacement],
+    modality_repls: Sequence[BoundPromptReplacement],
     modal_item_count: int,
-) -> Iterable[_PlaceholderInfo]:
+) -> Iterable[PlaceholderInfo]:
     if modal_item_count == 0:
         return

@@ -413,7 +416,7 @@ def _iter_modality_placeholders(
                 continue

             if prompt[start_idx:end_idx] == repl_tokens:
-                yield _PlaceholderInfo(
+                yield PlaceholderInfo(
                     modality=modality,
                     item_idx=item_idx,
                     start_idx=start_idx,
@@ -434,10 +437,10 @@ def _iter_modality_placeholders(


 def _iter_placeholders(
-    mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
+    mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
     prompt: list[int],
     mm_item_counts: Mapping[str, int],
-) -> Iterable[_PlaceholderInfo]:
+) -> Iterable[PlaceholderInfo]:
     """
     For each modality, yield each set of placeholder tokens found in
     :code:`prompt`.
@@ -455,10 +458,10 @@ def _iter_placeholders(


 def find_mm_placeholders(
-    mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
+    mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
     prompt: list[int],
     mm_item_counts: Mapping[str, int],
-) -> Mapping[str, list[_PlaceholderInfo]]:
+) -> Mapping[str, list[PlaceholderInfo]]:
     it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts)
     return dict(full_groupby_modality(it))

@@ -524,29 +527,59 @@ class ProcessingCache:
         self._cache.put(cache_key, output_kwargs)


-class ProcessingMixin:
-    """
-    Contains helper functions to perform processing.
-
-    Not to be confused with :class:`transformers.ProcessorMixin`.
-    """
-    ctx: InputProcessingContext
-
-    def _get_tokenizer(self) -> AnyTokenizer:
+class BaseProcessingInfo:
+    """Base class containing information to perform processing."""
+
+    def __init__(self, ctx: InputProcessingContext) -> None:
+        super().__init__()
+
+        self.ctx = ctx
+
+    @property
+    def model_id(self) -> str:
+        return self.ctx.model_config.model
+
+    def get_tokenizer(self) -> AnyTokenizer:
         return self.ctx.tokenizer

-    def _get_hf_config(self) -> PretrainedConfig:
+    def get_hf_config(self) -> PretrainedConfig:
         return self.ctx.get_hf_config()

-    def _get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
+    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
         """
         Subclasses can override this method to handle
         specific kwargs from model config or user inputs.
         """
         return self.ctx.get_hf_processor(**kwargs)

+    @abstractmethod
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        """
+        Return the maximum supported number of items for each modality.
+
+        A value of `None` means unlimited number of items.
+
+        Omitting a modality from the returned dictionary means that
+        it is not supported at all.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+        """
+        Get the maximum possible number of tokens per data item
+        for each modality.
+
+        The dictionary returned by this method should have the same
+        keys as that returned by :meth:`get_supported_mm_limits`.
+        """
+        raise NotImplementedError
+
+
+_I = TypeVar("_I", bound=BaseProcessingInfo)
+

-class BaseMultiModalProcessor(ProcessingMixin, ABC):
+class BaseMultiModalProcessor(ABC, Generic[_I]):
     """
     Abstract base class to process multi-modal inputs to be used in vLLM.

@@ -554,18 +587,19 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
     """

     def __init__(self,
-                 ctx: InputProcessingContext,
+                 info: _I,
+                 dummy_inputs: "BaseDummyInputsBuilder[_I]",
                  *,
                  cache: Optional[ProcessingCache] = None,
                  enable_sanity_checks: bool = True) -> None:
         super().__init__()

-        self.ctx = ctx
+        self.info = info
+        self.dummy_inputs = dummy_inputs
         self.cache = cache
         self.enable_sanity_checks = enable_sanity_checks

         self.data_parser = self._get_data_parser()
-        self.profiling_info = self._get_profiling_info()

     def __call__(
         self,
@@ -585,13 +619,6 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
         """
         return MultiModalDataParser()

-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        """
-        Get the profiling information to find the worst-case memory usage of
-        the model.
-        """
-        raise NotImplementedError
-
     def _to_mm_items(
         self,
         mm_data: MultiModalDataDict,
@@ -602,7 +629,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
         """
         mm_items = self.data_parser.parse_mm_data(mm_data)

-        mm_limits = self.ctx.get_mm_config().limit_per_prompt
+        mm_limits = self.info.ctx.get_mm_config().limit_per_prompt
         for modality, items in mm_items.items():
             limit = mm_limits.get(modality, 1)
             if len(items) > limit:
@@ -646,19 +673,19 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):

     def _find_mm_placeholders(
         self,
-        mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
+        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
         new_token_ids: list[int],
         mm_item_counts: Mapping[str, int],
-    ) -> Mapping[str, list[_PlaceholderInfo]]:
+    ) -> Mapping[str, list[PlaceholderInfo]]:
         return find_mm_placeholders(mm_prompt_repls, new_token_ids,
                                     mm_item_counts)

     def _get_hf_mm_data(
         self,
         mm_items: MultiModalDataItems,
-    ) -> tuple[dict[str, Any], dict[str, Any]]:
-        processor_data = dict[str, Any]()
-        passthrough_data = dict[str, Any]()
+    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
+        processor_data = dict[str, object]()
+        passthrough_data = dict[str, object]()

         for items in mm_items.values():
             processor_data.update(items.get_processor_data())
@@ -678,8 +705,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
         Call the HF processor on the prompt text and
         associated multi-modal data.
         """
-        return self.ctx.call_hf_processor(
-            self._get_hf_processor(**mm_kwargs),
+        return self.info.ctx.call_hf_processor(
+            self.info.get_hf_processor(**mm_kwargs),
             dict(text=prompt, **mm_data),
             mm_kwargs,
         )
@@ -738,8 +765,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):

         # Some HF processors (e.g. Qwen2-VL) expect corresponding
         # multi-modal tokens to be in the prompt text
-        dummy_inputs = self.profiling_info.get_dummy_processor_inputs(
-            self.ctx.model_config.max_model_len,
+        dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
+            self.info.ctx.model_config.max_model_len,
             mm_missing_counts,
         )

@@ -762,7 +789,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
         caching the results and reusing cached results.
         """
         cache = self.cache
-        model_id = self.ctx.model_config.model
+        model_id = self.info.model_id

         _, passthrough_data = self._get_hf_mm_data(mm_data_items)
         if cache is None or passthrough_data:
@@ -838,8 +865,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
     def _bind_and_group_repls(
         self,
         prompt_repls: list[PromptReplacement],
-    ) -> dict[str, list[_BoundPromptReplacement]]:
-        tokenizer = self._get_tokenizer()
+    ) -> dict[str, list[BoundPromptReplacement]]:
+        tokenizer = self.info.get_tokenizer()

         it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls)
         return dict(full_groupby_modality(it))
@@ -859,10 +886,10 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
     def _apply_prompt_replacements(
         self,
         token_ids: list[int],
-        mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
+        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
         mm_item_counts: Mapping[str, int],
-    ) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]:
-        tokenizer = self._get_tokenizer()
+    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]:
+        tokenizer = self.info.get_tokenizer()

         mm_token_matches = {
             modality: find_token_matches(token_ids, prompt_repls)
@@ -950,7 +977,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):

     def _validate_mm_placeholders(
         self,
-        mm_placeholders: Mapping[str, list[_PlaceholderInfo]],
+        mm_placeholders: Mapping[str, list[PlaceholderInfo]],
         mm_item_counts: Mapping[str, int],
         *,
         allow_missing: bool = False,
@@ -1001,7 +1028,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
         # instead of rehashing.

         if envs.VLLM_USE_V1:
-            model_id = self.ctx.model_config.model
+            model_id = self.info.model_id
             mm_hashes = {
                 modality: [
                     MultiModalHasher.hash_kwargs(model_id=model_id,
@@ -1046,7 +1073,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
             allow_missing=True,
         )

-        mm_missing_repls = dict[str, list[_BoundPromptReplacement]]()
+        mm_missing_repls = dict[str, list[BoundPromptReplacement]]()
         for modality, missing_repl_count in mm_missing_repl_counts.items():
             if missing_repl_count == 0:
                 mm_missing_repls[modality] = []
@@ -1059,7 +1086,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
         # If HF processor already inserts placeholder tokens,
         # there is no need for us to insert them
         if all(len(repls) == 0 for repls in mm_missing_repls.items()):
-            tokenizer = self._get_tokenizer()
+            tokenizer = self.info.get_tokenizer()
             prompt_text = decode_tokens(tokenizer, prompt_ids)
             mm_placeholders = hf_mm_placeholders
         else:
@@ -1090,79 +1117,3 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
             mm_hashes=mm_hashes,
             mm_placeholders=mm_placeholder_ranges,
         )
-
-    def _get_dummy_mm_inputs(
-        self,
-        seq_len: int,
-        mm_counts: Mapping[str, int],
-    ) -> MultiModalInputsV2:
-        profiling = self.profiling_info
-        processor_inputs = profiling.get_dummy_processor_inputs(
-            seq_len, mm_counts)
-
-        return self.apply(
-            prompt_text=processor_inputs.prompt_text,
-            mm_data=processor_inputs.mm_data,
-            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
-        )
-
-    def get_dummy_data(self, seq_len: int) -> DummyData:
-        # Avoid circular import
-        from vllm.sequence import SequenceData
-
-        profiling = self.profiling_info
-        mm_counts = profiling.get_mm_limits()
-        mm_max_tokens_per_item = profiling.get_mm_max_tokens_per_item(seq_len)
-        if mm_counts.keys() != mm_max_tokens_per_item.keys():
-            raise AssertionError(
-                "The keys returned by `get_supported_mm_limits`"
-                f"({set(mm_counts.keys())}) should be the same as those "
-                "returned by `get_mm_max_tokens_per_item` "
-                f"({set(mm_max_tokens_per_item.keys())})")
-
-        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
-        prompt_token_ids = mm_inputs["prompt_token_ids"]
-        placeholders_by_modality = mm_inputs["mm_placeholders"]
-
-        total_placeholders_by_modality = {
-            modality: sum(item["length"] for item in placeholders)
-            for modality, placeholders in placeholders_by_modality.items()
-        }
-        expected_placeholders_by_modality = {
-            modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
-            for modality in placeholders_by_modality
-        }
-        if total_placeholders_by_modality != expected_placeholders_by_modality:
-            raise AssertionError(
-                f"The processed dummy data has a total of "
-                f"{total_placeholders_by_modality} placeholder tokens, which "
-                f"is not the expected {expected_placeholders_by_modality} "
-                "tokens.")
-
-        total_len = len(prompt_token_ids)
-
-        # V0 does not support chunked prefill.
-        if total_len > seq_len and not envs.VLLM_USE_V1:
-            logger.warning(
-                "The context length (%d) of the model is too short "
-                "to hold the multi-modal embeddings in the worst case "
-                "(%d tokens in total, out of which %s are reserved for "
-                "multi-modal embeddings). This may cause certain multi-modal "
-                "inputs to fail during inference, even when the input text is "
-                "short. To avoid this, you should increase `max_model_len`, "
-                "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
-                total_len, total_placeholders_by_modality)
-
-            return DummyData(
-                seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
-                multi_modal_data=None,
-                multi_modal_placeholders=None,
-            )
-
-        prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))
-
-        return DummyData(
-            seq_data=SequenceData.from_seqs(prompt_token_ids),
-            multi_modal_data=mm_inputs["mm_kwargs"],
-            multi_modal_placeholders=placeholders_by_modality,
-        )
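With the processing.py changes above, a processor is no longer handed a raw context; it is built from a BaseProcessingInfo plus a dummy-inputs builder, and the context is reached through processor.info.ctx. A hedged sketch of the construction path implied by the diff, using the Qwen2-Audio classes from earlier in this commit; `ctx` is assumed to be an existing InputProcessingContext and is not created here.

# Hedged sketch of the new construction path (illustrative, not a test or API doc).
info = Qwen2AudioProcessingInfo(ctx)               # wraps ctx: tokenizer, HF config/processor
dummy_inputs = Qwen2AudioDummyInputsBuilder(info)  # builder only needs the info object
processor = Qwen2AudioMultiModalProcessor(info, dummy_inputs, cache=None)

tokenizer = processor.info.get_tokenizer()         # replaces the old self._get_tokenizer()
model_id = processor.info.model_id                 # new convenience property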
@ -1,16 +1,18 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from vllm.inputs import InputProcessingContext
|
import vllm.envs as envs
|
||||||
|
from vllm.inputs import DummyData
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
from .inputs import MultiModalDataDict
|
from .inputs import MultiModalDataDict, MultiModalInputsV2
|
||||||
|
from .processing import BaseMultiModalProcessor, BaseProcessingInfo
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -23,39 +25,19 @@ class ProcessorInputs:
|
|||||||
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
|
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class BaseProfilingInfo(ABC):
|
_I = TypeVar("_I", bound=BaseProcessingInfo)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDummyInputsBuilder(ABC, Generic[_I]):
|
||||||
"""
|
"""
|
||||||
Abstract base class that provides the information necessary to profile
|
Abstract base class that constructs the dummy data to profile
|
||||||
multi-modal models.
|
multi-modal models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, ctx: InputProcessingContext) -> None:
|
def __init__(self, info: _I) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.ctx = ctx
|
self.info = info
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
|
||||||
"""
|
|
||||||
Return the maximum supported number of items for each modality.
|
|
||||||
|
|
||||||
A value of `None` means unlimited number of items.
|
|
||||||
|
|
||||||
Omitting a modality from the returned dictionary means that
|
|
||||||
it is not supported at all.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
|
||||||
"""
|
|
||||||
Get the maximum possible number of tokens per data item
|
|
||||||
for each modality.
|
|
||||||
|
|
||||||
The dictionary returned by this method should have the same
|
|
||||||
keys as that returned by :meth:`get_supported_mm_limits`.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_dummy_processor_inputs(
|
def get_dummy_processor_inputs(
|
||||||
@ -64,8 +46,8 @@ class BaseProfilingInfo(ABC):
|
|||||||
mm_counts: Mapping[str, int],
|
mm_counts: Mapping[str, int],
|
||||||
) -> ProcessorInputs:
|
) -> ProcessorInputs:
|
||||||
"""
|
"""
|
||||||
Build the multi-modal portion of the input which, after processing,
|
Build the input which, after processing, results in
|
||||||
results in `mm_max_tokens` in :meth:`get_mm_max_tokens_per_item`.
|
`self.info.get_mm_max_tokens_per_item()` placeholder tokens.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -99,11 +81,33 @@ class BaseProfilingInfo(ABC):
|
|||||||
video = np.zeros((num_frames, width, height, 3))
|
video = np.zeros((num_frames, width, height, 3))
|
||||||
return [video] * num_videos
|
return [video] * num_videos
|
||||||
|
|
||||||
def get_mm_limits(self) -> Mapping[str, int]:
|
|
||||||
mm_config = self.ctx.get_mm_config()
|
class MultiModalProfiler(Generic[_I]):
|
||||||
|
"""
|
||||||
|
Contains code for running memory profiling for multi-modal models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
processor: BaseMultiModalProcessor[_I],
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.processor = processor
|
||||||
|
|
||||||
|
@property
|
||||||
|
def processing_info(self) -> BaseProcessingInfo:
|
||||||
|
return self.processor.info
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
|
||||||
|
return self.processor.dummy_inputs
|
||||||
|
|
||||||
|
def _get_mm_limits(self) -> Mapping[str, int]:
|
||||||
|
mm_config = self.processing_info.ctx.get_mm_config()
|
||||||
mm_limit_per_prompt = mm_config.limit_per_prompt
|
mm_limit_per_prompt = mm_config.limit_per_prompt
|
||||||
|
|
||||||
supported_mm_limits = self.get_supported_mm_limits()
|
supported_mm_limits = self.processing_info.get_supported_mm_limits()
|
||||||
|
|
||||||
mm_limits = {
|
mm_limits = {
|
||||||
modality: mm_limit_per_prompt.get(modality, 1)
|
modality: mm_limit_per_prompt.get(modality, 1)
|
||||||
@ -119,3 +123,81 @@ class BaseProfilingInfo(ABC):
                     f"at most {supported_limit} {modality} items.")
 
         return mm_limits
+
+    def _get_dummy_mm_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> MultiModalInputsV2:
+        factory = self.dummy_inputs
+        processor_inputs = factory.get_dummy_processor_inputs(
+            seq_len, mm_counts)
+
+        return self.processor.apply(
+            prompt_text=processor_inputs.prompt_text,
+            mm_data=processor_inputs.mm_data,
+            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
+        )
+
+    def get_dummy_data(self, seq_len: int) -> DummyData:
+        # Avoid circular import
+        from vllm.sequence import SequenceData
+
+        mm_counts = self._get_mm_limits()
+
+        info = self.processing_info
+        mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(seq_len)
+
+        if mm_counts.keys() != mm_max_tokens_per_item.keys():
+            raise AssertionError(
+                "The keys returned by `get_supported_mm_limits`"
+                f"({set(mm_counts.keys())}) should be the same as those "
+                "returned by `get_mm_max_tokens_per_item` "
+                f"({set(mm_max_tokens_per_item.keys())})")
+
+        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
+        prompt_token_ids = mm_inputs["prompt_token_ids"]
+        placeholders_by_modality = mm_inputs["mm_placeholders"]
+
+        total_placeholders_by_modality = {
+            modality: sum(item["length"] for item in placeholders)
+            for modality, placeholders in placeholders_by_modality.items()
+        }
+        expected_placeholders_by_modality = {
+            modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
+            for modality in placeholders_by_modality
+        }
+        if total_placeholders_by_modality != expected_placeholders_by_modality:
+            raise AssertionError(
+                f"The processed dummy data has a total of "
+                f"{total_placeholders_by_modality} placeholder tokens, which "
+                f"is not the expected {expected_placeholders_by_modality} "
+                "tokens.")
+
+        total_len = len(prompt_token_ids)
+
+        # V0 does not support chunked prefill.
+        if total_len > seq_len and not envs.VLLM_USE_V1:
+            logger.warning(
+                "The context length (%d) of the model is too short "
+                "to hold the multi-modal embeddings in the worst case "
+                "(%d tokens in total, out of which %s are reserved for "
+                "multi-modal embeddings). This may cause certain multi-modal "
+                "inputs to fail during inference, even when the input text is "
+                "short. To avoid this, you should increase `max_model_len`, "
+                "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
+                total_len, total_placeholders_by_modality)
+
+            return DummyData(
+                seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
+                multi_modal_data=None,
+                multi_modal_placeholders=None,
+            )
+
+        prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))
+
+        return DummyData(
+            seq_data=SequenceData.from_seqs(prompt_token_ids),
+            multi_modal_data=mm_inputs["mm_kwargs"],
+            multi_modal_placeholders=placeholders_by_modality,
+        )
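For illustration, the consistency check added above reduces to per-modality arithmetic: the dummy prompt must contain exactly (max tokens per item) * (item count) placeholder tokens for each modality. A minimal sketch with made-up numbers (the real values come from `get_mm_max_tokens_per_item` and `_get_mm_limits`):

# Made-up numbers; only the formula matches the check above.
mm_counts = {"image": 2}                  # dummy items per modality
mm_max_tokens_per_item = {"image": 2928}  # worst-case tokens per item

expected_placeholders_by_modality = {
    modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
    for modality in mm_counts
}
assert expected_placeholders_by_modality == {"image": 5856}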
@ -1,7 +1,8 @@
 import functools
 from collections import UserDict
-from typing import (TYPE_CHECKING, Any, Dict, Mapping, Optional, Protocol,
-                    Sequence, Type, TypeVar)
+from dataclasses import dataclass
+from typing import (TYPE_CHECKING, Any, Dict, Generic, Mapping, Optional,
+                    Protocol, Sequence, Type, TypeVar)
 
 import torch.nn as nn
 
@ -14,7 +15,9 @@ from .audio import AudioPlugin
 from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
 from .image import ImagePlugin
 from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
-from .processing import BaseMultiModalProcessor, ProcessingCache
+from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
+                         ProcessingCache)
+from .profiling import BaseDummyInputsBuilder
 from .utils import cached_get_tokenizer
 from .video import VideoPlugin
 
@ -27,20 +30,59 @@ logger = init_logger(__name__)
 MM_CACHE_SIZE = 256
 
 N = TypeVar("N", bound=Type[nn.Module])
+_I = TypeVar("_I", bound=BaseProcessingInfo)
+_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
 
 
-class MultiModalProcessorFactory(Protocol):
+class ProcessingInfoFactory(Protocol[_I_co]):
     """Constructs a :class:`MultiModalProcessor` instance from the context."""
 
     def __call__(
         self,
         ctx: InputProcessingContext,
+    ) -> _I_co:
+        ...
+
+
+class DummyInputsBuilderFactory(Protocol[_I]):
+    """
+    Constructs a :class:`BaseDummyInputsBuilder` instance from the context.
+    """
+
+    def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]:
+        ...
+
+
+class MultiModalProcessorFactory(Protocol[_I]):
+    """Constructs a :class:`MultiModalProcessor` instance from the context."""
+
+    def __call__(
+        self,
+        info: _I,
+        dummy_inputs: BaseDummyInputsBuilder[_I],
         *,
         cache: Optional[ProcessingCache] = None,
-    ) -> BaseMultiModalProcessor:
+    ) -> BaseMultiModalProcessor[_I]:
         ...
 
 
+@dataclass(frozen=True)
+class _ProcessorFactories(Generic[_I]):
+    info: ProcessingInfoFactory[_I]
+    processor: MultiModalProcessorFactory[_I]
+    dummy_inputs: DummyInputsBuilderFactory[_I]
+
+    def build_processor(
+        self,
+        ctx: InputProcessingContext,
+        *,
+        cache: Optional[ProcessingCache] = None,
+    ):
+        info = self.info(ctx)
+        dummy_inputs_builder = self.dummy_inputs(info)
+        return self.processor(info, dummy_inputs_builder, cache=cache)
+
+
 class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]):
     """
     Wraps `_limits_by_model` for a more informative error message
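For illustration, a standalone toy that mirrors the factory wiring above without importing vLLM (the class names are stand-ins, not the real base classes): the info object is built from the context, the dummy-inputs builder from the info, and the processor from both.

# Toy stand-ins only; the real protocol and base classes live in
# vllm.multimodal.processing and vllm.multimodal.profiling.
from dataclasses import dataclass


class Ctx:  # stand-in for InputProcessingContext
    pass


class Info:  # stand-in for a BaseProcessingInfo subclass
    def __init__(self, ctx: Ctx) -> None:
        self.ctx = ctx


class DummyBuilder:  # stand-in for a BaseDummyInputsBuilder subclass
    def __init__(self, info: Info) -> None:
        self.info = info


class Processor:  # stand-in for a BaseMultiModalProcessor subclass
    def __init__(self, info: Info, dummy_inputs: DummyBuilder, *, cache=None):
        self.info = info
        self.dummy_inputs = dummy_inputs
        self.cache = cache


@dataclass(frozen=True)
class Factories:  # stand-in for _ProcessorFactories
    info: type
    processor: type
    dummy_inputs: type

    def build_processor(self, ctx: Ctx, *, cache=None) -> Processor:
        info = self.info(ctx)
        return self.processor(info, self.dummy_inputs(info), cache=cache)


factories = Factories(info=Info, processor=Processor, dummy_inputs=DummyBuilder)
processor = factories.build_processor(Ctx())
assert isinstance(processor.dummy_inputs, DummyBuilder)

At runtime a class object is itself a callable that returns an instance, so concrete classes can be passed directly as these factories.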
@ -71,7 +113,7 @@ class MultiModalRegistry:
         self._plugins = {p.get_data_key(): p for p in plugins}
 
         self._processor_factories = ClassRegistry[nn.Module,
-                                                  MultiModalProcessorFactory]()
+                                                  _ProcessorFactories]()
 
         # This is used for non-multimodal models
         self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}
@ -224,7 +266,7 @@ class MultiModalRegistry:
             tokenizer = cached_get_tokenizer(model_config.tokenizer)
             processor = self.create_processor(model_config, tokenizer)
             seq_len = model_config.max_model_len
-            return processor.profiling_info.get_mm_max_tokens_per_item(seq_len)
+            return processor.info.get_mm_max_tokens_per_item(seq_len)
 
         return {
             key: plugin.get_max_multimodal_tokens(model_config)
@ -315,7 +357,10 @@ class MultiModalRegistry:
 
     def register_processor(
         self,
-        factory: MultiModalProcessorFactory,
+        processor: MultiModalProcessorFactory[_I],
+        *,
+        info: ProcessingInfoFactory[_I],
+        dummy_inputs: DummyInputsBuilderFactory[_I],
     ):
         """
         Register a multi-modal processor to a model class. The processor
@ -336,7 +381,11 @@ class MultiModalRegistry:
                     "registered to %s. It is overwritten by the new one.",
                     model_cls, self)
 
-            self._processor_factories[model_cls] = factory
+            self._processor_factories[model_cls] = _ProcessorFactories(
+                info=info,
+                dummy_inputs=dummy_inputs,
+                processor=processor,
+            )
 
             return model_cls
 
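For illustration, with the new signature a registration passes the processor positionally and the info/dummy-inputs factories as keyword arguments, typically as a class decorator. This is a sketch of the call shape only; `MyProcessingInfo`, `MyDummyInputsBuilder`, `MyMultiModalProcessor` and the model class are hypothetical names, not identifiers from this patch, and they would have to be defined as subclasses of the corresponding base classes for the snippet to run.

# Hypothetical subclasses of BaseProcessingInfo, BaseDummyInputsBuilder and
# BaseMultiModalProcessor are assumed to exist; only the registration call
# shape is taken from the diff above.
@MULTIMODAL_REGISTRY.register_processor(
    MyMultiModalProcessor,
    info=MyProcessingInfo,
    dummy_inputs=MyDummyInputsBuilder,
)
class MyModelForConditionalGeneration(nn.Module):
    ...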
@ -359,15 +408,15 @@ class MultiModalRegistry:
         self,
         model_config: "ModelConfig",
         tokenizer: AnyTokenizer,
-    ) -> BaseMultiModalProcessor:
+    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
         """
         Create a multi-modal processor for a specific model and tokenizer.
         """
         model_cls = self._get_model_cls(model_config)
-        processor_factory = self._processor_factories[model_cls]
+        factories = self._processor_factories[model_cls]
 
         ctx = InputProcessingContext(model_config, tokenizer)
         cache = (None if model_config.disable_mm_preprocessor_cache else
                  self._processing_cache)
 
-        return processor_factory(ctx, cache=cache)
+        return factories.build_processor(ctx, cache=cache)