[Model] Fix Phi-3.5-vision-instruct 'num_crops' issue (#7710)
parent 7937009a7e
commit df1a21131d
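In short: `num_crops` for Phi-3-vision / Phi-3.5-vision is defined in the checkpoint's image processor config (`preprocessor_config.json`), not in `config.json`, so the old `getattr(hf_config, "num_crops", 16)` lookup always fell back to the default of 16. This commit threads the HF image processor config through `ModelConfig` and `InputContext` and reads `num_crops` from it instead. A minimal sketch of the failure mode (the value 4 is illustrative, not taken from the diff):

from types import SimpleNamespace

text_config = SimpleNamespace(model_type="phi3_v")  # config.json: no `num_crops` attribute
image_processor_config = {"num_crops": 4}           # preprocessor_config.json: defines it

print(getattr(text_config, "num_crops", 16))        # 16 -- old lookup, default always wins
print(image_processor_config.get("num_crops", 16))  # 4  -- new lookup against the dict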
docs/source/models/supported_models.rst
@@ -225,9 +225,9 @@ Multimodal Language Models
     - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc.
     -
   * - :code:`Phi3VForCausalLM`
-    - Phi-3-Vision
+    - Phi-3-Vision, Phi-3.5-Vision
     - Image
-    - :code:`microsoft/Phi-3-vision-128k-instruct`, etc.
+    - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct`, etc.
     -
   * - :code:`MiniCPMV`
     - MiniCPM-V
tests/models/test_phi3v.py
@@ -21,7 +21,7 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n",
 })
 
-models = ["microsoft/Phi-3-vision-128k-instruct"]
+models = ["microsoft/Phi-3.5-vision-instruct"]
 
 
 def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
vllm/config.py
@@ -13,7 +13,9 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import current_platform
 from vllm.tracing import is_otel_available, otel_import_error_traceback
-from vllm.transformers_utils.config import get_config, get_hf_text_config
+from vllm.transformers_utils.config import (get_config,
+                                            get_hf_image_processor_config,
+                                            get_hf_text_config)
 from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes,
                         cuda_device_count_stateless, get_cpu_memory, is_cpu,
                         is_hip, is_neuron, is_openvino, is_xpu,
@@ -167,6 +169,8 @@ class ModelConfig:
         self.hf_config = get_config(self.model, trust_remote_code, revision,
                                     code_revision, rope_scaling, rope_theta)
         self.hf_text_config = get_hf_text_config(self.hf_config)
+        self.hf_image_processor_config = get_hf_image_processor_config(
+            self.model, revision)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
 
         # Choose a default enforce_eager value if the user did not specify
vllm/inputs/registry.py
@@ -2,8 +2,8 @@ import functools
 from array import array
 from collections import UserDict
 from dataclasses import dataclass
-from typing import (TYPE_CHECKING, Callable, Dict, Mapping, Optional, Protocol,
-                    Tuple, Type)
+from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional,
+                    Protocol, Tuple, Type)
 
 from torch import nn
 from transformers import PretrainedConfig
@@ -55,6 +55,13 @@ class InputContext:
 
         return hf_config
 
+    def get_hf_image_processor_config(self) -> Dict[str, Any]:
+        """
+        Get the HuggingFace image processor configuration of the model.
+        """
+
+        return self.model_config.hf_image_processor_config
+
 
 N = TypeVar("N", bound=Type[nn.Module])
 
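A runnable sketch of the new accessor in isolation; the `Stub*` classes below are hypothetical stand-ins for `ModelConfig` and `InputContext`, mimicking only the attribute the added method touches:

from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class StubModelConfig:
    # Stand-in for vllm.config.ModelConfig after this commit.
    hf_image_processor_config: Dict[str, Any] = field(
        default_factory=lambda: {"num_crops": 4})

@dataclass
class StubInputContext:
    # Stand-in for vllm.inputs.registry.InputContext; same method body as the diff.
    model_config: StubModelConfig

    def get_hf_image_processor_config(self) -> Dict[str, Any]:
        return self.model_config.hf_image_processor_config

ctx = StubInputContext(StubModelConfig())
print(ctx.get_hf_image_processor_config().get("num_crops", 16))  # -> 4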
vllm/model_executor/models/phi3v.py
@@ -15,8 +15,8 @@
 # limitations under the License.
 import re
 from functools import lru_cache
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
-                    TypedDict, Union)
+from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
+                    Tuple, TypedDict, Union)
 
 import numpy as np
 import torch
@@ -324,12 +324,12 @@ def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
 
 # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
 def get_phi3v_image_feature_size(
-    hf_config: PretrainedConfig,
+    hf_config: Dict[str, Any],
     *,
     input_height: int,
     input_width: int,
 ) -> int:
-    num_crops = getattr(hf_config, "num_crops", 16)
+    num_crops = hf_config.get("num_crops", 16)
     new_width, new_height = _calc_hd_transform_size(width=input_width,
                                                     height=input_height,
                                                     hd_num=num_crops)
@@ -341,7 +341,7 @@ def get_phi3v_image_feature_size(
 def get_max_phi3v_image_tokens(ctx: InputContext):
 
     return get_phi3v_image_feature_size(
-        ctx.get_hf_config(),
+        ctx.get_hf_image_processor_config(),
         input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
         input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
     )
@@ -395,7 +395,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
         return llm_inputs
 
     model_config = ctx.model_config
-    hf_config = ctx.get_hf_config()
+    hf_config = ctx.get_hf_image_processor_config()
 
     image_data = multi_modal_data["image"]
     if isinstance(image_data, Image.Image):
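With the signature change, the Phi-3V feature-size helper takes a plain dict, so it can be exercised directly. A usage sketch assuming a vLLM build that includes this commit; the dimensions and `num_crops` value are illustrative:

from vllm.model_executor.models.phi3v import get_phi3v_image_feature_size

cfg = {"num_crops": 4}  # e.g. what a Phi-3.5-vision preprocessor config might supply
print(get_phi3v_image_feature_size(cfg, input_height=1008, input_width=1344))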
vllm/transformers_utils/config.py
@@ -1,8 +1,10 @@
 import contextlib
 from pathlib import Path
-from typing import Dict, Optional, Type, Union
+from typing import Any, Dict, Optional, Type, Union
 
 from transformers import GenerationConfig, PretrainedConfig
+from transformers.models.auto.image_processing_auto import (
+    get_image_processor_config)
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
 
@@ -98,6 +100,17 @@ def get_config(
     return config
 
 
+def get_hf_image_processor_config(
+    model: Union[str, Path],
+    revision: Optional[str] = None,
+    **kwargs,
+) -> Dict[str, Any]:
+    # Separate model folder from file path for GGUF models
+    if Path(model).is_file() and Path(model).suffix == ".gguf":
+        model = Path(model).parent
+    return get_image_processor_config(model, revision=revision, **kwargs)
+
+
 def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
     No op for pure text models.
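A usage sketch for the new helper, assuming a vLLM build with this commit and network access to the Hugging Face Hub:

from vllm.transformers_utils.config import get_hf_image_processor_config

cfg = get_hf_image_processor_config("microsoft/Phi-3.5-vision-instruct")
print(cfg.get("num_crops", 16))  # should now reflect the checkpoint's own setting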