[model] Support for Llava-Next-Video model (#7559)
Co-authored-by: Roger Wang <ywang@roblox.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
parent
efcf946a15
commit
6a512a00df
@ -145,6 +145,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
|||||||
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
|
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
|
||||||
&& apt-get update -y \
|
&& apt-get update -y \
|
||||||
&& apt-get install -y ccache software-properties-common git curl sudo vim python3-pip \
|
&& apt-get install -y ccache software-properties-common git curl sudo vim python3-pip \
|
||||||
|
&& apt-get install -y ffmpeg libsm6 libxext6 libgl1 \
|
||||||
&& add-apt-repository ppa:deadsnakes/ppa \
|
&& add-apt-repository ppa:deadsnakes/ppa \
|
||||||
&& apt-get update -y \
|
&& apt-get update -y \
|
||||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \
|
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \
|
||||||
|
@ -5,6 +5,7 @@ FROM ubuntu:22.04 AS cpu-test-1
|
|||||||
RUN --mount=type=cache,target=/var/cache/apt \
|
RUN --mount=type=cache,target=/var/cache/apt \
|
||||||
apt-get update -y \
|
apt-get update -y \
|
||||||
&& apt-get install -y curl ccache git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \
|
&& apt-get install -y curl ccache git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \
|
||||||
|
&& apt-get install -y ffmpeg libsm6 libxext6 libgl1 \
|
||||||
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
|
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
|
||||||
|
|
||||||
# https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html
|
# https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html
|
||||||
|
@ -6,7 +6,9 @@ FROM $BASE_IMAGE
|
|||||||
RUN echo "Base image is $BASE_IMAGE"
|
RUN echo "Base image is $BASE_IMAGE"
|
||||||
|
|
||||||
# Install some basic utilities
|
# Install some basic utilities
|
||||||
RUN apt-get update && apt-get install python3 python3-pip -y
|
RUN apt-get update \
|
||||||
|
&& apt-get install python3 python3-pip -y \
|
||||||
|
&& apt-get install -y ffmpeg libsm6 libxext6 libgl1
|
||||||
|
|
||||||
### Mount Point ###
|
### Mount Point ###
|
||||||
# When launching the container, mount the code directory to /app
|
# When launching the container, mount the code directory to /app
|
||||||
|
@ -4,7 +4,8 @@
|
|||||||
FROM ubuntu:22.04 AS dev
|
FROM ubuntu:22.04 AS dev
|
||||||
|
|
||||||
RUN apt-get update -y && \
|
RUN apt-get update -y && \
|
||||||
apt-get install -y python3-pip git
|
apt-get install -y python3-pip git && \
|
||||||
|
apt-get install -y ffmpeg libsm6 libxext6 libgl1
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
# copy requirements
|
# copy requirements
|
||||||
|
@ -4,7 +4,7 @@ USER root
|
|||||||
|
|
||||||
ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/"
|
ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/"
|
||||||
|
|
||||||
RUN apt-get update -y && apt-get install -y git wget curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential
|
RUN apt-get update -y && apt-get install -y git wget curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1
|
||||||
|
|
||||||
# Some packages in requirements-cpu are installed here
|
# Some packages in requirements-cpu are installed here
|
||||||
# IBM provides optimized packages for ppc64le processors in the open-ce project for mamba
|
# IBM provides optimized packages for ppc64le processors in the open-ce project for mamba
|
||||||
|
@ -4,6 +4,9 @@ ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:night
|
|||||||
FROM $BASE_IMAGE
|
FROM $BASE_IMAGE
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
|
# Install some basic utilities
|
||||||
|
RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 libgl1
|
||||||
|
|
||||||
# Install the TPU and Pallas dependencies.
|
# Install the TPU and Pallas dependencies.
|
||||||
RUN python3 -m pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
|
RUN python3 -m pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
|
||||||
RUN python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
RUN python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||||
|
@ -9,8 +9,7 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO
|
|||||||
chmod 644 /usr/share/keyrings/intel-graphics.gpg
|
chmod 644 /usr/share/keyrings/intel-graphics.gpg
|
||||||
|
|
||||||
RUN apt-get update -y \
|
RUN apt-get update -y \
|
||||||
&& apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip
|
&& apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip ffmpeg libsm6 libxext6 libgl1
|
||||||
|
|
||||||
COPY ./ /workspace/vllm
|
COPY ./ /workspace/vllm
|
||||||
|
|
||||||
WORKDIR /workspace/vllm
|
WORKDIR /workspace/vllm
|
||||||
|
@ -99,6 +99,7 @@ autodoc_mock_imports = [
|
|||||||
"aiohttp",
|
"aiohttp",
|
||||||
"compressed_tensors",
|
"compressed_tensors",
|
||||||
"cpuinfo",
|
"cpuinfo",
|
||||||
|
"cv2",
|
||||||
"torch",
|
"torch",
|
||||||
"transformers",
|
"transformers",
|
||||||
"psutil",
|
"psutil",
|
||||||
|
@ -227,6 +227,11 @@ Multimodal Language Models
|
|||||||
- Image\ :sup:`E+`
|
- Image\ :sup:`E+`
|
||||||
- :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
|
- :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
|
||||||
-
|
-
|
||||||
|
* - :code:`LlavaNextVideoForConditionalGeneration`
|
||||||
|
- LLaVA-NeXT-Video
|
||||||
|
- Video
|
||||||
|
- :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. (see note)
|
||||||
|
-
|
||||||
* - :code:`MiniCPMV`
|
* - :code:`MiniCPMV`
|
||||||
- MiniCPM-V
|
- MiniCPM-V
|
||||||
- Image\ :sup:`+`
|
- Image\ :sup:`+`
|
||||||
@ -260,6 +265,15 @@ Multimodal Language Models
|
|||||||
For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
|
For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
|
||||||
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
|
For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630
|
||||||
|
|
||||||
|
For :code:`LLaVA-NeXT-Video`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now.
|
||||||
|
This can be installed by running the following command:
|
||||||
|
|
||||||
|
|
||||||
|
.. code-block:: bash
|
||||||
|
|
||||||
|
pip install git+https://github.com/huggingface/transformers.git@21fac7abba2a37fae86106f87fcf9974fd1e3830
|
||||||
|
|
||||||
|
|
||||||
----
|
----
|
||||||
|
|
||||||
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
|
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
|
||||||
|
@ -9,12 +9,9 @@ from transformers import AutoTokenizer
|
|||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.assets.image import ImageAsset
|
from vllm.assets.image import ImageAsset
|
||||||
|
from vllm.assets.video import VideoAsset
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
# Input image and question
|
|
||||||
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
|
|
||||||
question = "What is the content of this image?"
|
|
||||||
|
|
||||||
|
|
||||||
# LLaVA-1.5
|
# LLaVA-1.5
|
||||||
def run_llava(question):
|
def run_llava(question):
|
||||||
@ -30,7 +27,16 @@ def run_llava(question):
|
|||||||
def run_llava_next(question):
|
def run_llava_next(question):
|
||||||
|
|
||||||
prompt = f"[INST] <image>\n{question} [/INST]"
|
prompt = f"[INST] <image>\n{question} [/INST]"
|
||||||
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf")
|
llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192)
|
||||||
|
stop_token_ids = None
|
||||||
|
return llm, prompt, stop_token_ids
|
||||||
|
|
||||||
|
|
||||||
|
# LlaVA-NeXT-Video
|
||||||
|
# Currently only support for video input
|
||||||
|
def run_llava_next_video(question):
|
||||||
|
prompt = f"USER: <video>\n{question} ASSISTANT:"
|
||||||
|
llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf", max_model_len=8192)
|
||||||
stop_token_ids = None
|
stop_token_ids = None
|
||||||
return llm, prompt, stop_token_ids
|
return llm, prompt, stop_token_ids
|
||||||
|
|
||||||
@ -176,6 +182,7 @@ def run_qwen_vl(question):
|
|||||||
model_example_map = {
|
model_example_map = {
|
||||||
"llava": run_llava,
|
"llava": run_llava,
|
||||||
"llava-next": run_llava_next,
|
"llava-next": run_llava_next,
|
||||||
|
"llava-next-video": run_llava_next_video,
|
||||||
"fuyu": run_fuyu,
|
"fuyu": run_fuyu,
|
||||||
"phi3_v": run_phi3v,
|
"phi3_v": run_phi3v,
|
||||||
"paligemma": run_paligemma,
|
"paligemma": run_paligemma,
|
||||||
@ -187,11 +194,49 @@ model_example_map = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_multi_modal_input(args):
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"data": image or video,
|
||||||
|
"question": question,
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
if args.modality == "image":
|
||||||
|
# Input image and question
|
||||||
|
image = ImageAsset("cherry_blossom") \
|
||||||
|
.pil_image.convert("RGB")
|
||||||
|
img_question = "What is the content of this image?"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"data": image,
|
||||||
|
"question": img_question,
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.modality == "video":
|
||||||
|
# Input video and question
|
||||||
|
video = VideoAsset(name="sample_demo_1.mp4",
|
||||||
|
num_frames=args.num_frames).np_ndarrays
|
||||||
|
vid_question = "Why is this video funny?"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"data": video,
|
||||||
|
"question": vid_question,
|
||||||
|
}
|
||||||
|
|
||||||
|
msg = f"Modality {args.modality} is not supported."
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
model = args.model_type
|
model = args.model_type
|
||||||
if model not in model_example_map:
|
if model not in model_example_map:
|
||||||
raise ValueError(f"Model type {model} is not supported.")
|
raise ValueError(f"Model type {model} is not supported.")
|
||||||
|
|
||||||
|
modality = args.modality
|
||||||
|
mm_input = get_multi_modal_input(args)
|
||||||
|
data = mm_input["data"]
|
||||||
|
question = mm_input["question"]
|
||||||
|
|
||||||
llm, prompt, stop_token_ids = model_example_map[model](question)
|
llm, prompt, stop_token_ids = model_example_map[model](question)
|
||||||
|
|
||||||
# We set temperature to 0.2 so that outputs can be different
|
# We set temperature to 0.2 so that outputs can be different
|
||||||
@ -206,7 +251,7 @@ def main(args):
|
|||||||
inputs = {
|
inputs = {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"multi_modal_data": {
|
"multi_modal_data": {
|
||||||
"image": image
|
modality: data
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -215,7 +260,7 @@ def main(args):
|
|||||||
inputs = [{
|
inputs = [{
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"multi_modal_data": {
|
"multi_modal_data": {
|
||||||
"image": image
|
modality: data
|
||||||
},
|
},
|
||||||
} for _ in range(args.num_prompts)]
|
} for _ in range(args.num_prompts)]
|
||||||
|
|
||||||
@ -238,8 +283,15 @@ if __name__ == "__main__":
|
|||||||
help='Huggingface "model_type".')
|
help='Huggingface "model_type".')
|
||||||
parser.add_argument('--num-prompts',
|
parser.add_argument('--num-prompts',
|
||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=4,
|
||||||
help='Number of prompts to run.')
|
help='Number of prompts to run.')
|
||||||
|
parser.add_argument('--modality',
|
||||||
|
type=str,
|
||||||
|
default="image",
|
||||||
|
help='Modality of the input.')
|
||||||
|
parser.add_argument('--num-frames',
|
||||||
|
type=int,
|
||||||
|
default=16,
|
||||||
|
help='Number of frames to extract from the video.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
@ -11,6 +11,7 @@ awscli
|
|||||||
einops # required for MPT, qwen-vl and Mamba
|
einops # required for MPT, qwen-vl and Mamba
|
||||||
httpx
|
httpx
|
||||||
librosa # required for audio test
|
librosa # required for audio test
|
||||||
|
opencv-python # required for video test
|
||||||
peft
|
peft
|
||||||
requests
|
requests
|
||||||
ray[adag]>=2.35
|
ray[adag]>=2.35
|
||||||
|
1
setup.py
1
setup.py
@ -505,6 +505,7 @@ setup(
|
|||||||
ext_modules=ext_modules,
|
ext_modules=ext_modules,
|
||||||
extras_require={
|
extras_require={
|
||||||
"tensorizer": ["tensorizer>=2.9.0"],
|
"tensorizer": ["tensorizer>=2.9.0"],
|
||||||
|
"video": ["opencv-python"], # Required for video processing
|
||||||
"audio": ["librosa", "soundfile"] # Required for audio processing
|
"audio": ["librosa", "soundfile"] # Required for audio processing
|
||||||
},
|
},
|
||||||
cmdclass={"build_ext": cmake_build_ext} if len(ext_modules) > 0 else {},
|
cmdclass={"build_ext": cmake_build_ext} if len(ext_modules) > 0 else {},
|
||||||
|
@ -21,6 +21,7 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
|
|||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.assets.image import ImageAsset
|
from vllm.assets.image import ImageAsset
|
||||||
|
from vllm.assets.video import VideoAsset
|
||||||
from vllm.config import TokenizerPoolConfig
|
from vllm.config import TokenizerPoolConfig
|
||||||
from vllm.connections import global_http_connection
|
from vllm.connections import global_http_connection
|
||||||
from vllm.distributed import (destroy_distributed_environment,
|
from vllm.distributed import (destroy_distributed_environment,
|
||||||
@ -44,6 +45,7 @@ _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
|
|||||||
PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]
|
PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]
|
||||||
PromptAudioInput = Union[List[Tuple[np.ndarray, int]],
|
PromptAudioInput = Union[List[Tuple[np.ndarray, int]],
|
||||||
List[List[Tuple[np.ndarray, int]]]]
|
List[List[Tuple[np.ndarray, int]]]]
|
||||||
|
PromptVideoInput = Union[List[np.ndarray], List[List[np.ndarray]]]
|
||||||
|
|
||||||
|
|
||||||
def _read_prompts(filename: str) -> List[str]:
|
def _read_prompts(filename: str) -> List[str]:
|
||||||
@ -85,8 +87,35 @@ class _ImageAssets(_ImageAssetsBase):
|
|||||||
return [prompts["stop_sign"], prompts["cherry_blossom"]]
|
return [prompts["stop_sign"], prompts["cherry_blossom"]]
|
||||||
|
|
||||||
|
|
||||||
|
class _VideoAssetPrompts(TypedDict):
|
||||||
|
sample_demo_1: str
|
||||||
|
|
||||||
|
|
||||||
|
if sys.version_info < (3, 9):
|
||||||
|
# UserList cannot be subscripted
|
||||||
|
class _VideoAssetsBase(UserList):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
|
||||||
|
class _VideoAssetsBase(UserList[VideoAsset]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class _VideoAssets(_VideoAssetsBase):
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__([
|
||||||
|
VideoAsset("sample_demo_1.mp4"),
|
||||||
|
])
|
||||||
|
|
||||||
|
def prompts(self, prompts: _VideoAssetPrompts) -> List[str]:
|
||||||
|
return [prompts["sample_demo_1"]]
|
||||||
|
|
||||||
|
|
||||||
IMAGE_ASSETS = _ImageAssets()
|
IMAGE_ASSETS = _ImageAssets()
|
||||||
"""Singleton instance of :class:`_ImageAssets`."""
|
"""Singleton instance of :class:`_ImageAssets`."""
|
||||||
|
VIDEO_ASSETS = _VideoAssets()
|
||||||
|
"""Singleton instance of :class:`_VideoAssets`."""
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
@ -202,6 +231,11 @@ def image_assets() -> _ImageAssets:
|
|||||||
return IMAGE_ASSETS
|
return IMAGE_ASSETS
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def video_assets() -> _VideoAssets:
|
||||||
|
return VIDEO_ASSETS
|
||||||
|
|
||||||
|
|
||||||
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)
|
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)
|
||||||
|
|
||||||
|
|
||||||
@ -279,6 +313,7 @@ class HfRunner:
|
|||||||
self,
|
self,
|
||||||
prompts: List[str],
|
prompts: List[str],
|
||||||
images: Optional[PromptImageInput] = None,
|
images: Optional[PromptImageInput] = None,
|
||||||
|
videos: Optional[List[np.ndarray]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[Tuple[List[List[int]], List[str]]]:
|
) -> List[Tuple[List[List[int]], List[str]]]:
|
||||||
if images:
|
if images:
|
||||||
@ -292,6 +327,8 @@ class HfRunner:
|
|||||||
}
|
}
|
||||||
if images is not None and images[i] is not None:
|
if images is not None and images[i] is not None:
|
||||||
processor_kwargs["images"] = images[i]
|
processor_kwargs["images"] = images[i]
|
||||||
|
if videos is not None and videos[i] is not None:
|
||||||
|
processor_kwargs["videos"] = videos[i]
|
||||||
|
|
||||||
inputs = self.processor(**processor_kwargs)
|
inputs = self.processor(**processor_kwargs)
|
||||||
inputs = self.postprocess_inputs(inputs)
|
inputs = self.postprocess_inputs(inputs)
|
||||||
@ -352,6 +389,7 @@ class HfRunner:
|
|||||||
prompts: List[str],
|
prompts: List[str],
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
images: Optional[PromptImageInput] = None,
|
images: Optional[PromptImageInput] = None,
|
||||||
|
videos: Optional[List[np.ndarray]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[List[torch.Tensor]]:
|
) -> List[List[torch.Tensor]]:
|
||||||
all_logprobs: List[List[torch.Tensor]] = []
|
all_logprobs: List[List[torch.Tensor]] = []
|
||||||
@ -362,6 +400,8 @@ class HfRunner:
|
|||||||
}
|
}
|
||||||
if images is not None and images[i] is not None:
|
if images is not None and images[i] is not None:
|
||||||
processor_kwargs["images"] = images[i]
|
processor_kwargs["images"] = images[i]
|
||||||
|
if videos is not None and videos[i] is not None:
|
||||||
|
processor_kwargs["videos"] = videos[i]
|
||||||
|
|
||||||
inputs = self.processor(**processor_kwargs)
|
inputs = self.processor(**processor_kwargs)
|
||||||
inputs = self.postprocess_inputs(inputs)
|
inputs = self.postprocess_inputs(inputs)
|
||||||
@ -435,6 +475,7 @@ class HfRunner:
|
|||||||
num_logprobs: int,
|
num_logprobs: int,
|
||||||
images: Optional[PromptImageInput] = None,
|
images: Optional[PromptImageInput] = None,
|
||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
|
videos: Optional[List[np.ndarray]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
|
) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
|
||||||
all_logprobs: List[List[Dict[int, float]]] = []
|
all_logprobs: List[List[Dict[int, float]]] = []
|
||||||
@ -454,6 +495,8 @@ class HfRunner:
|
|||||||
processor_kwargs["audio"] = audio
|
processor_kwargs["audio"] = audio
|
||||||
processor_kwargs["sampling_rate"] = sr
|
processor_kwargs["sampling_rate"] = sr
|
||||||
|
|
||||||
|
if videos is not None:
|
||||||
|
processor_kwargs["videos"] = videos[i]
|
||||||
inputs = self.processor(**processor_kwargs)
|
inputs = self.processor(**processor_kwargs)
|
||||||
inputs = self.postprocess_inputs(inputs)
|
inputs = self.postprocess_inputs(inputs)
|
||||||
|
|
||||||
@ -634,12 +677,16 @@ class VllmRunner:
|
|||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
images: Optional[PromptImageInput] = None,
|
images: Optional[PromptImageInput] = None,
|
||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
|
videos: Optional[PromptVideoInput] = None,
|
||||||
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
|
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
|
||||||
assert sampling_params.logprobs is not None
|
assert sampling_params.logprobs is not None
|
||||||
|
|
||||||
if images is not None:
|
if images is not None:
|
||||||
assert len(prompts) == len(images)
|
assert len(prompts) == len(images)
|
||||||
|
|
||||||
|
if videos is not None:
|
||||||
|
assert len(prompts) == len(videos)
|
||||||
|
|
||||||
inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
|
inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
|
||||||
if images is not None:
|
if images is not None:
|
||||||
for i, image in enumerate(images):
|
for i, image in enumerate(images):
|
||||||
@ -649,6 +696,11 @@ class VllmRunner:
|
|||||||
for i, audio in enumerate(audios):
|
for i, audio in enumerate(audios):
|
||||||
inputs[i]["multi_modal_data"] = {"audio": audio}
|
inputs[i]["multi_modal_data"] = {"audio": audio}
|
||||||
|
|
||||||
|
if videos is not None:
|
||||||
|
for i, video in enumerate(videos):
|
||||||
|
inputs[i]["multi_modal_data"] = {"video": video}
|
||||||
|
print(f"[INPUTS!!!!]: {inputs}, {sampling_params}")
|
||||||
|
|
||||||
req_outputs = self.model.generate(inputs,
|
req_outputs = self.model.generate(inputs,
|
||||||
sampling_params=sampling_params)
|
sampling_params=sampling_params)
|
||||||
return self._final_steps_generate_w_logprobs(req_outputs)
|
return self._final_steps_generate_w_logprobs(req_outputs)
|
||||||
@ -685,6 +737,7 @@ class VllmRunner:
|
|||||||
num_logprobs: int,
|
num_logprobs: int,
|
||||||
images: Optional[PromptImageInput] = None,
|
images: Optional[PromptImageInput] = None,
|
||||||
audios: Optional[PromptAudioInput] = None,
|
audios: Optional[PromptAudioInput] = None,
|
||||||
|
videos: Optional[PromptVideoInput] = None,
|
||||||
stop_token_ids: Optional[List[int]] = None,
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
|
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
|
||||||
greedy_logprobs_params = SamplingParams(temperature=0.0,
|
greedy_logprobs_params = SamplingParams(temperature=0.0,
|
||||||
@ -694,7 +747,8 @@ class VllmRunner:
|
|||||||
outputs = self.generate_w_logprobs(prompts,
|
outputs = self.generate_w_logprobs(prompts,
|
||||||
greedy_logprobs_params,
|
greedy_logprobs_params,
|
||||||
images=images,
|
images=images,
|
||||||
audios=audios)
|
audios=audios,
|
||||||
|
videos=videos)
|
||||||
|
|
||||||
return [(output_ids, output_str, output_logprobs)
|
return [(output_ids, output_str, output_logprobs)
|
||||||
for output_ids, output_str, output_logprobs in outputs]
|
for output_ids, output_str, output_logprobs in outputs]
|
||||||
|
236
tests/models/test_llava_next_video.py
Normal file
236
tests/models/test_llava_next_video.py
Normal file
@ -0,0 +1,236 @@
|
|||||||
|
from typing import List, Optional, Tuple, Type, overload
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import transformers
|
||||||
|
from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
|
||||||
|
|
||||||
|
from vllm.multimodal.utils import (rescale_video_size, resize_video,
|
||||||
|
sample_frames_from_video)
|
||||||
|
from vllm.sequence import SampleLogprobs
|
||||||
|
|
||||||
|
from ..conftest import VIDEO_ASSETS, HfRunner, VllmRunner, _VideoAssets
|
||||||
|
from .utils import check_logprobs_close
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.vlm
|
||||||
|
|
||||||
|
_PREFACE = (
|
||||||
|
"A chat between a curious human and an artificial intelligence assistant. "
|
||||||
|
"The assistant gives helpful, detailed, and polite answers to the human's "
|
||||||
|
"questions.")
|
||||||
|
|
||||||
|
HF_VIDEO_PROMPTS = VIDEO_ASSETS.prompts({
|
||||||
|
"sample_demo_1":
|
||||||
|
f"{_PREFACE}USER: <video>\nWhy is this video funny? ASSISTANT:"
|
||||||
|
})
|
||||||
|
|
||||||
|
models = ["llava-hf/LLaVA-NeXT-Video-7B-hf"]
|
||||||
|
|
||||||
|
|
||||||
|
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
|
||||||
|
Optional[SampleLogprobs]],
|
||||||
|
model: str):
|
||||||
|
"""Sanitize vllm output to be comparable with hf output."""
|
||||||
|
output_ids, output_str, out_logprobs = vllm_output
|
||||||
|
|
||||||
|
config = AutoConfig.from_pretrained(model)
|
||||||
|
video_token_id = config.video_token_index
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model)
|
||||||
|
eos_token_id = tokenizer.eos_token_id
|
||||||
|
|
||||||
|
hf_output_ids = [
|
||||||
|
token_id for idx, token_id in enumerate(output_ids)
|
||||||
|
if token_id != video_token_id or output_ids[idx - 1] != video_token_id
|
||||||
|
]
|
||||||
|
|
||||||
|
assert output_str[0] == " "
|
||||||
|
hf_output_str = output_str[1:]
|
||||||
|
if hf_output_ids[-1] == eos_token_id:
|
||||||
|
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
|
||||||
|
|
||||||
|
return hf_output_ids, hf_output_str, out_logprobs
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def run_test(
|
||||||
|
hf_runner: Type[HfRunner],
|
||||||
|
vllm_runner: Type[VllmRunner],
|
||||||
|
video_assets: _VideoAssets,
|
||||||
|
model: str,
|
||||||
|
*,
|
||||||
|
size_factors: List[float],
|
||||||
|
dtype: str,
|
||||||
|
max_tokens: int,
|
||||||
|
num_logprobs: int,
|
||||||
|
num_frames: int,
|
||||||
|
tensor_parallel_size: int,
|
||||||
|
distributed_executor_backend: Optional[str] = None,
|
||||||
|
):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def run_test(
|
||||||
|
hf_runner: Type[HfRunner],
|
||||||
|
vllm_runner: Type[VllmRunner],
|
||||||
|
video_assets: _VideoAssets,
|
||||||
|
model: str,
|
||||||
|
*,
|
||||||
|
sizes: List[Tuple[int, int]],
|
||||||
|
dtype: str,
|
||||||
|
max_tokens: int,
|
||||||
|
num_logprobs: int,
|
||||||
|
num_frames: int,
|
||||||
|
tensor_parallel_size: int,
|
||||||
|
distributed_executor_backend: Optional[str] = None,
|
||||||
|
):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def run_test(
|
||||||
|
hf_runner: Type[HfRunner],
|
||||||
|
vllm_runner: Type[VllmRunner],
|
||||||
|
video_assets: _VideoAssets,
|
||||||
|
model: str,
|
||||||
|
*,
|
||||||
|
size_factors: Optional[List[float]] = None,
|
||||||
|
sizes: Optional[List[Tuple[int, int]]] = None,
|
||||||
|
dtype: str,
|
||||||
|
max_tokens: int,
|
||||||
|
num_logprobs: int,
|
||||||
|
num_frames: int,
|
||||||
|
tensor_parallel_size: int,
|
||||||
|
distributed_executor_backend: Optional[str] = None,
|
||||||
|
):
|
||||||
|
videos = [
|
||||||
|
sample_frames_from_video(asset.np_ndarrays, num_frames)
|
||||||
|
for asset in video_assets
|
||||||
|
]
|
||||||
|
|
||||||
|
for video in videos:
|
||||||
|
print(video.shape)
|
||||||
|
|
||||||
|
if size_factors is not None:
|
||||||
|
inputs_per_video = [(
|
||||||
|
[prompt for _ in size_factors],
|
||||||
|
[rescale_video_size(video, factor) for factor in size_factors],
|
||||||
|
) for video, prompt in zip(videos, HF_VIDEO_PROMPTS)]
|
||||||
|
elif sizes is not None:
|
||||||
|
inputs_per_video = [(
|
||||||
|
[prompt for _ in sizes],
|
||||||
|
[resize_video(video, size) for size in sizes],
|
||||||
|
) for video, prompt in zip(videos, HF_VIDEO_PROMPTS)]
|
||||||
|
else:
|
||||||
|
raise ValueError("You must provide either `size_factors` or `sizes`")
|
||||||
|
|
||||||
|
# max_model_len should be greater than image_feature_size
|
||||||
|
with vllm_runner(model,
|
||||||
|
dtype=dtype,
|
||||||
|
max_model_len=4096,
|
||||||
|
tensor_parallel_size=tensor_parallel_size,
|
||||||
|
distributed_executor_backend=distributed_executor_backend,
|
||||||
|
enforce_eager=True) as vllm_model:
|
||||||
|
vllm_outputs_per_video = [
|
||||||
|
vllm_model.generate_greedy_logprobs(prompts,
|
||||||
|
max_tokens,
|
||||||
|
num_logprobs=num_logprobs,
|
||||||
|
videos=videos)
|
||||||
|
for prompts, videos in inputs_per_video
|
||||||
|
]
|
||||||
|
|
||||||
|
with hf_runner(model, dtype=dtype,
|
||||||
|
auto_cls=AutoModelForVision2Seq) as hf_model:
|
||||||
|
hf_outputs_per_video = [
|
||||||
|
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||||
|
max_tokens,
|
||||||
|
num_logprobs=num_logprobs,
|
||||||
|
videos=videos)
|
||||||
|
for prompts, videos in inputs_per_video
|
||||||
|
]
|
||||||
|
|
||||||
|
for hf_outputs, vllm_outputs in zip(hf_outputs_per_video,
|
||||||
|
vllm_outputs_per_video):
|
||||||
|
# TODO: Check whether using original CLIPVisionModel can improve
|
||||||
|
# consistency against HF
|
||||||
|
check_logprobs_close(
|
||||||
|
outputs_0_lst=hf_outputs,
|
||||||
|
outputs_1_lst=[
|
||||||
|
vllm_to_hf_output(vllm_output, model)
|
||||||
|
for vllm_output in vllm_outputs
|
||||||
|
],
|
||||||
|
name_0="hf",
|
||||||
|
name_1="vllm",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(transformers.__version__ < "4.45",
|
||||||
|
reason="Waiting for next transformers release")
|
||||||
|
@pytest.mark.parametrize("model", models)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"size_factors",
|
||||||
|
[
|
||||||
|
# No video
|
||||||
|
[],
|
||||||
|
# Single-scale
|
||||||
|
[1.0],
|
||||||
|
# Single-scale, batched
|
||||||
|
[1.0, 1.0, 1.0],
|
||||||
|
# Multi-scale
|
||||||
|
[0.25, 0.5, 1.0],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
|
@pytest.mark.parametrize("max_tokens", [128])
|
||||||
|
@pytest.mark.parametrize("num_logprobs", [5])
|
||||||
|
@pytest.mark.parametrize("num_frames", [16])
|
||||||
|
def test_models(hf_runner, vllm_runner, video_assets, model, size_factors,
|
||||||
|
dtype, max_tokens, num_logprobs, num_frames) -> None:
|
||||||
|
"""Inference result should be the same between hf and vllm.
|
||||||
|
|
||||||
|
All the image fixtures for the test is under tests/videos.
|
||||||
|
For huggingface runner, we provide the np.ndarray as input.
|
||||||
|
For vllm runner, we provide MultiModalDataDict objects
|
||||||
|
and corresponding MultiModalConfig as input.
|
||||||
|
Note, the text input is also adjusted to abide by vllm contract.
|
||||||
|
The text output is sanitized to be able to compare with hf.
|
||||||
|
"""
|
||||||
|
run_test(
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
video_assets,
|
||||||
|
model,
|
||||||
|
size_factors=size_factors,
|
||||||
|
dtype=dtype,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
num_logprobs=num_logprobs,
|
||||||
|
num_frames=num_frames,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(transformers.__version__ < "4.45",
|
||||||
|
reason="Waiting for next transformers release")
|
||||||
|
@pytest.mark.parametrize("model", models)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"sizes",
|
||||||
|
[[(1669, 2560), (2560, 1669), (183, 488), (488, 183)]],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
|
@pytest.mark.parametrize("max_tokens", [128])
|
||||||
|
@pytest.mark.parametrize("num_logprobs", [5])
|
||||||
|
@pytest.mark.parametrize("num_frames", [16])
|
||||||
|
def test_models_fixed_sizes(hf_runner, vllm_runner, video_assets, model, sizes,
|
||||||
|
dtype, max_tokens, num_logprobs,
|
||||||
|
num_frames) -> None:
|
||||||
|
run_test(
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
video_assets,
|
||||||
|
model,
|
||||||
|
sizes=sizes,
|
||||||
|
dtype=dtype,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
num_logprobs=num_logprobs,
|
||||||
|
num_frames=num_frames,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
)
|
85
vllm/assets/video.py
Normal file
85
vllm/assets/video.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import List, Literal
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from vllm.multimodal.utils import (sample_frames_from_video,
|
||||||
|
try_import_video_packages)
|
||||||
|
|
||||||
|
from .base import get_cache_dir
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def download_video_asset(filename: str) -> str:
|
||||||
|
"""
|
||||||
|
Download and open an image from huggingface
|
||||||
|
repo: raushan-testing-hf/videos-test
|
||||||
|
"""
|
||||||
|
video_directory = get_cache_dir() / "video-eample-data"
|
||||||
|
video_directory.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
video_path = video_directory / filename
|
||||||
|
video_path_str = str(video_path)
|
||||||
|
if not video_path.exists():
|
||||||
|
video_path_str = hf_hub_download(
|
||||||
|
repo_id="raushan-testing-hf/videos-test",
|
||||||
|
filename=filename,
|
||||||
|
repo_type="dataset",
|
||||||
|
cache_dir=video_directory,
|
||||||
|
)
|
||||||
|
return video_path_str
|
||||||
|
|
||||||
|
|
||||||
|
def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
|
||||||
|
cv2 = try_import_video_packages()
|
||||||
|
|
||||||
|
cap = cv2.VideoCapture(path)
|
||||||
|
if not cap.isOpened():
|
||||||
|
raise ValueError(f"Could not open video file {path}")
|
||||||
|
|
||||||
|
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||||
|
frames = []
|
||||||
|
for i in range(total_frames):
|
||||||
|
ret, frame = cap.read()
|
||||||
|
if ret:
|
||||||
|
frames.append(frame)
|
||||||
|
cap.release()
|
||||||
|
|
||||||
|
frames = np.stack(frames)
|
||||||
|
frames = sample_frames_from_video(frames, num_frames)
|
||||||
|
if len(frames) < num_frames:
|
||||||
|
raise ValueError(f"Could not read enough frames from video file {path}"
|
||||||
|
f" (expected {num_frames} frames, got {len(frames)})")
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def video_to_pil_images_list(path: str,
|
||||||
|
num_frames: int = -1) -> List[Image.Image]:
|
||||||
|
cv2 = try_import_video_packages()
|
||||||
|
frames = video_to_ndarrays(path, num_frames)
|
||||||
|
return [
|
||||||
|
Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
||||||
|
for frame in frames
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class VideoAsset:
|
||||||
|
name: Literal["sample_demo_1.mp4"]
|
||||||
|
num_frames: int = -1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pil_images(self) -> List[Image.Image]:
|
||||||
|
video_path = download_video_asset(self.name)
|
||||||
|
ret = video_to_pil_images_list(video_path, self.num_frames)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@property
|
||||||
|
def np_ndarrays(self) -> List[npt.NDArray]:
|
||||||
|
video_path = download_video_asset(self.name)
|
||||||
|
ret = video_to_ndarrays(video_path, self.num_frames)
|
||||||
|
return ret
|
@ -80,8 +80,10 @@ _MULTIMODAL_MODELS = {
|
|||||||
"InternVLChatModel": ("internvl", "InternVLChatModel"),
|
"InternVLChatModel": ("internvl", "InternVLChatModel"),
|
||||||
"LlavaForConditionalGeneration":
|
"LlavaForConditionalGeneration":
|
||||||
("llava", "LlavaForConditionalGeneration"),
|
("llava", "LlavaForConditionalGeneration"),
|
||||||
"LlavaNextForConditionalGeneration":
|
"LlavaNextForConditionalGeneration": ("llava_next",
|
||||||
("llava_next", "LlavaNextForConditionalGeneration"),
|
"LlavaNextForConditionalGeneration"),
|
||||||
|
"LlavaNextVideoForConditionalGeneration":
|
||||||
|
("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
|
||||||
"MiniCPMV": ("minicpmv", "MiniCPMV"),
|
"MiniCPMV": ("minicpmv", "MiniCPMV"),
|
||||||
"PaliGemmaForConditionalGeneration": ("paligemma",
|
"PaliGemmaForConditionalGeneration": ("paligemma",
|
||||||
"PaliGemmaForConditionalGeneration"),
|
"PaliGemmaForConditionalGeneration"),
|
||||||
|
471
vllm/model_executor/models/llava_next_video.py
Normal file
471
vllm/model_executor/models/llava_next_video.py
Normal file
@ -0,0 +1,471 @@
|
|||||||
|
import itertools
|
||||||
|
import math
|
||||||
|
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
|
||||||
|
TypedDict, Union)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
|
||||||
|
SiglipVisionConfig)
|
||||||
|
|
||||||
|
from vllm.attention import AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig, MultiModalConfig
|
||||||
|
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.activation import get_act_fn
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig)
|
||||||
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.models.clip import CLIPVisionModel
|
||||||
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
|
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||||
|
repeat_and_pad_placeholder_tokens)
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
from vllm.utils import is_list_of
|
||||||
|
|
||||||
|
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
|
||||||
|
from .interfaces import SupportsMultiModal
|
||||||
|
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
|
||||||
|
dummy_seq_data_for_siglip)
|
||||||
|
from .utils import (filter_weights, init_vllm_registered_model,
|
||||||
|
merge_multimodal_embeddings)
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# For profile run
|
||||||
|
_MAX_FRAMES_PER_VIDEO = 32
|
||||||
|
_MAX_NUM_VIDEOS = 1
|
||||||
|
|
||||||
|
|
||||||
|
class LlavaNextVideoPixelInputs(TypedDict):
|
||||||
|
type: Literal["pixel_values_videos"]
|
||||||
|
data: Union[torch.Tensor, List[torch.Tensor]]
|
||||||
|
"""
|
||||||
|
Shape: `(batch_size, num_frames, num_channels, height, width)`
|
||||||
|
|
||||||
|
Note that `num_frames` may be different for each batch, in which case
|
||||||
|
the data is passed as a list instead of a batched tensor.
|
||||||
|
|
||||||
|
Note that it only supports one video input for one batch.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_llava_next_video_frame_feature_size(
|
||||||
|
hf_config: LlavaNextVideoConfig) -> int:
|
||||||
|
# Support both CLIPVisionConfig and SiglipVisionConfig
|
||||||
|
image_size = hf_config.vision_config.image_size
|
||||||
|
patch_size = hf_config.vision_config.patch_size
|
||||||
|
spatial_pool_stride = hf_config.spatial_pool_stride
|
||||||
|
|
||||||
|
return int((image_size / patch_size / spatial_pool_stride)**2)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_max_llm_tokens(ctx: InputContext) -> int:
|
||||||
|
"""
|
||||||
|
Calculated from the maximum video frames under the context length
|
||||||
|
constraints of the language model.
|
||||||
|
"""
|
||||||
|
hf_text_config = ctx.model_config.hf_text_config
|
||||||
|
model_config = ctx.model_config
|
||||||
|
max_tokens = model_config.max_model_len
|
||||||
|
rope_scaling = model_config.rope_scaling
|
||||||
|
|
||||||
|
if rope_scaling:
|
||||||
|
rope_scaling_factor = hf_text_config.rope_scaling["factor"]
|
||||||
|
else:
|
||||||
|
rope_scaling_factor = 1
|
||||||
|
|
||||||
|
max_tokens *= rope_scaling_factor
|
||||||
|
|
||||||
|
return max_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def get_max_llava_next_video_tokens(ctx: InputContext) -> int:
|
||||||
|
# Currently set to 32 frames
|
||||||
|
# TODO: max_tokens = _get_max_llm_tokens(ctx)
|
||||||
|
hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
|
||||||
|
tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config)
|
||||||
|
return _MAX_FRAMES_PER_VIDEO * tokens_per_frame
|
||||||
|
|
||||||
|
|
||||||
|
def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int,
|
||||||
|
mm_counts: Mapping[str, int]):
|
||||||
|
hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
|
||||||
|
vision_config = hf_config.vision_config
|
||||||
|
|
||||||
|
# TODO: support multiple videos
|
||||||
|
num_videos = mm_counts["video"]
|
||||||
|
if num_videos != _MAX_NUM_VIDEOS:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Only {_MAX_NUM_VIDEOS} videos are supported")
|
||||||
|
|
||||||
|
# TODO: support configuring the number of frames
|
||||||
|
frames_per_video = _MAX_FRAMES_PER_VIDEO
|
||||||
|
# num_images = num_videos * frames_per_video
|
||||||
|
|
||||||
|
# fills the sequence with as longer video data as possible
|
||||||
|
tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config)
|
||||||
|
video_feature_size = frames_per_video * tokens_per_frame
|
||||||
|
|
||||||
|
if isinstance(vision_config, CLIPVisionConfig):
|
||||||
|
seq_data = dummy_seq_data_for_clip(
|
||||||
|
vision_config,
|
||||||
|
seq_len,
|
||||||
|
num_videos,
|
||||||
|
image_token_id=hf_config.video_token_index,
|
||||||
|
image_feature_size_override=video_feature_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
pil_frame = dummy_image_for_clip(vision_config, num_images=1)
|
||||||
|
np_frame = np.array(pil_frame["image"])
|
||||||
|
mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
|
||||||
|
mm_data = {"video": mm_data_per_video}
|
||||||
|
return seq_data, mm_data
|
||||||
|
elif isinstance(vision_config, SiglipVisionConfig):
|
||||||
|
seq_data = dummy_seq_data_for_siglip(
|
||||||
|
vision_config,
|
||||||
|
seq_len,
|
||||||
|
num_videos,
|
||||||
|
image_token_id=hf_config.video_token_index,
|
||||||
|
image_feature_size_override=video_feature_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
pil_frame = dummy_image_for_siglip(vision_config, num_images=1)
|
||||||
|
np_frame = np.array(pil_frame["image"])
|
||||||
|
mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
|
||||||
|
mm_data = {"video": mm_data_per_video}
|
||||||
|
return seq_data, mm_data
|
||||||
|
|
||||||
|
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||||
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def input_processor_for_llava_next_video(ctx: InputContext,
|
||||||
|
llm_inputs: LLMInputs):
|
||||||
|
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||||
|
if multi_modal_data is None or "video" not in multi_modal_data:
|
||||||
|
return llm_inputs
|
||||||
|
video_data = multi_modal_data["video"]
|
||||||
|
|
||||||
|
model_config = ctx.model_config
|
||||||
|
hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
|
||||||
|
vision_config = hf_config.vision_config
|
||||||
|
|
||||||
|
if isinstance(video_data, np.ndarray):
|
||||||
|
# Supports both CLIP and Siglip
|
||||||
|
num_frames = video_data.shape[0]
|
||||||
|
frame_feature_size = \
|
||||||
|
get_llava_next_video_frame_feature_size(hf_config)
|
||||||
|
video_feature_size = num_frames * frame_feature_size
|
||||||
|
|
||||||
|
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
||||||
|
|
||||||
|
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
|
||||||
|
tokenizer,
|
||||||
|
llm_inputs.get("prompt"),
|
||||||
|
llm_inputs["prompt_token_ids"],
|
||||||
|
placeholder_token_id=hf_config.video_token_index,
|
||||||
|
repeat_count=video_feature_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
return LLMInputs(prompt_token_ids=new_token_ids,
|
||||||
|
prompt=new_prompt,
|
||||||
|
multi_modal_data=multi_modal_data)
|
||||||
|
|
||||||
|
elif is_list_of(video_data, np.ndarray):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Processing multiple videos is not supported")
|
||||||
|
|
||||||
|
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||||
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def _init_vision_tower(hf_config: LlavaNextVideoConfig):
|
||||||
|
vision_config = hf_config.vision_config
|
||||||
|
|
||||||
|
# Initialize the vision tower only up to the required feature layer
|
||||||
|
vision_feature_layer = hf_config.vision_feature_layer
|
||||||
|
if vision_feature_layer < 0:
|
||||||
|
num_hidden_layers = hf_config.vision_config.num_hidden_layers \
|
||||||
|
+ vision_feature_layer + 1
|
||||||
|
else:
|
||||||
|
num_hidden_layers = vision_feature_layer + 1
|
||||||
|
|
||||||
|
if isinstance(vision_config, CLIPVisionConfig):
|
||||||
|
return CLIPVisionModel(
|
||||||
|
vision_config,
|
||||||
|
num_hidden_layers_override=num_hidden_layers,
|
||||||
|
)
|
||||||
|
elif isinstance(vision_config, SiglipVisionConfig):
|
||||||
|
return SiglipVisionModel(
|
||||||
|
vision_config,
|
||||||
|
num_hidden_layers_override=num_hidden_layers,
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||||
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
# adopted from transformers modeling_llava_next_video.py
|
||||||
|
class LlavaNextVideoPooler(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
mode = config.spatial_pool_mode
|
||||||
|
stride = config.spatial_pool_stride
|
||||||
|
image_size = config.vision_config.image_size
|
||||||
|
patch_size = config.vision_config.patch_size
|
||||||
|
self.image_size = image_size // patch_size**2
|
||||||
|
|
||||||
|
if mode == "average":
|
||||||
|
self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride)
|
||||||
|
elif mode == "max":
|
||||||
|
self.pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
|
||||||
|
else:
|
||||||
|
# TODO: Support Conv2d pooling layer, need to load weights
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown pooling mode: {mode}. Expected [`average`, `max`]")
|
||||||
|
|
||||||
|
def forward(self, image_features):
|
||||||
|
ori_width = int(
|
||||||
|
math.sqrt(image_features.shape[1] * self.image_size //
|
||||||
|
self.image_size))
|
||||||
|
ori_height = int(ori_width * self.image_size // self.image_size)
|
||||||
|
|
||||||
|
batch_size, _, dim = image_features.shape
|
||||||
|
image_features_spatial = image_features \
|
||||||
|
.view(batch_size, ori_height, ori_height, dim) \
|
||||||
|
.permute(0, 3, 1, 2)
|
||||||
|
image_features_spatial = self.pool(image_features_spatial)
|
||||||
|
|
||||||
|
return image_features_spatial.flatten(2).transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
class LlavaNextMultiModalProjector(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, vision_hidden_size: int, text_hidden_size: int,
|
||||||
|
projector_hidden_act: str):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.linear_1 = nn.Linear(vision_hidden_size,
|
||||||
|
text_hidden_size,
|
||||||
|
bias=True)
|
||||||
|
self.act = get_act_fn(projector_hidden_act)
|
||||||
|
self.linear_2 = nn.Linear(text_hidden_size,
|
||||||
|
text_hidden_size,
|
||||||
|
bias=True)
|
||||||
|
|
||||||
|
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
|
||||||
|
hidden_states = self.linear_1(image_features)
|
||||||
|
hidden_states = self.act(hidden_states)
|
||||||
|
hidden_states = self.linear_2(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
@MULTIMODAL_REGISTRY.register_input_mapper("video")
|
||||||
|
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
|
||||||
|
"video", get_max_llava_next_video_tokens)
|
||||||
|
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next_video)
|
||||||
|
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next_video)
|
||||||
|
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
config: LlavaNextVideoConfig,
|
||||||
|
multimodal_config: MultiModalConfig,
|
||||||
|
cache_config: Optional[CacheConfig] = None,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
self.multimodal_config = multimodal_config
|
||||||
|
|
||||||
|
# Initialize the vision tower only up to the required feature layer
|
||||||
|
self.vision_tower = _init_vision_tower(config)
|
||||||
|
self.multi_modal_projector = LlavaNextMultiModalProjector(
|
||||||
|
vision_hidden_size=config.vision_config.hidden_size,
|
||||||
|
text_hidden_size=config.text_config.hidden_size,
|
||||||
|
projector_hidden_act=config.projector_hidden_act)
|
||||||
|
self.language_model = init_vllm_registered_model(
|
||||||
|
config.text_config, cache_config, quant_config)
|
||||||
|
self.vision_resampler = LlavaNextVideoPooler(config)
|
||||||
|
|
||||||
|
def _validate_video_pixel_values(
|
||||||
|
self, data: Union[torch.Tensor, List[torch.Tensor]]
|
||||||
|
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||||
|
|
||||||
|
h = w = self.config.vision_config.image_size
|
||||||
|
expected_dims = (3, h, w)
|
||||||
|
|
||||||
|
def _validate_shape(d: torch.Tensor):
|
||||||
|
actual_dims = tuple(d.shape[2:])
|
||||||
|
|
||||||
|
if actual_dims != expected_dims:
|
||||||
|
expected_expr = ("num_frames", *map(str, expected_dims))
|
||||||
|
raise ValueError(
|
||||||
|
"The expected shape of pixel values in each video frame "
|
||||||
|
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
|
||||||
|
|
||||||
|
for d in data:
|
||||||
|
_validate_shape(d)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _parse_and_validate_video_input(
|
||||||
|
self, **kwargs: object) -> Optional[LlavaNextVideoPixelInputs]:
|
||||||
|
"""
|
||||||
|
A legal video input should have the following dimensions:
|
||||||
|
{
|
||||||
|
"pixel_values_videos" :
|
||||||
|
List[b, Tensor(nb_frames, nb_channels, height, width)]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
pixel_values = kwargs.pop("pixel_values_videos", None)
|
||||||
|
|
||||||
|
if pixel_values is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not (is_list_of(pixel_values,
|
||||||
|
(torch.Tensor)) # different shape videos
|
||||||
|
or isinstance(pixel_values,
|
||||||
|
torch.Tensor)): # same shape videos
|
||||||
|
raise ValueError("Incorrect type of pixel values. "
|
||||||
|
f"Got type: {type(pixel_values)}")
|
||||||
|
|
||||||
|
return LlavaNextVideoPixelInputs(
|
||||||
|
type="pixel_values_videos",
|
||||||
|
data=pixel_values,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _select_image_features(self, image_features: torch.Tensor, *,
|
||||||
|
strategy: str) -> torch.Tensor:
|
||||||
|
if strategy == "default":
|
||||||
|
return image_features[:, 1:]
|
||||||
|
elif strategy == "full":
|
||||||
|
return image_features
|
||||||
|
|
||||||
|
raise ValueError(f"Unexpected select feature strategy: {strategy}")
|
||||||
|
|
||||||
|
def _video_pixels_to_features(
|
||||||
|
self,
|
||||||
|
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
|
||||||
|
pixel_values: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
# NOTE: we skip the step to select the vision feature layer since
|
||||||
|
# this is already done inside the vision tower
|
||||||
|
image_features = vision_tower(pixel_values)
|
||||||
|
image_features = self._select_image_features(
|
||||||
|
image_features,
|
||||||
|
strategy=self.config.vision_feature_select_strategy,
|
||||||
|
)
|
||||||
|
image_features = self.vision_resampler(image_features)
|
||||||
|
image_features = self.multi_modal_projector(image_features)
|
||||||
|
return image_features
|
||||||
|
|
||||||
|
def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs):
|
||||||
|
assert self.vision_tower is not None
|
||||||
|
|
||||||
|
video_pixels = inputs["data"]
|
||||||
|
|
||||||
|
if isinstance(video_pixels, torch.Tensor):
|
||||||
|
# TODO: support multiple videos per input
|
||||||
|
b, num_videos, num_frames, c, h, w = video_pixels.shape
|
||||||
|
assert (num_videos == 1)
|
||||||
|
stacked_pixels = video_pixels.view(b * num_videos * num_frames, c,
|
||||||
|
h, w)
|
||||||
|
stacked_embeddings = self._video_pixels_to_features(
|
||||||
|
self.vision_tower, stacked_pixels)
|
||||||
|
return stacked_embeddings.view(b, num_frames,
|
||||||
|
*stacked_embeddings.shape[1:])
|
||||||
|
|
||||||
|
elif is_list_of(video_pixels, torch.Tensor):
|
||||||
|
frames_per_videos = [v.shape[0] for v in video_pixels]
|
||||||
|
stacked_pixels = torch.cat(video_pixels, dim=0)
|
||||||
|
stacked_embeddings = self._video_pixels_to_features(
|
||||||
|
self.vision_tower, stacked_pixels)
|
||||||
|
return torch.split(stacked_embeddings, frames_per_videos, dim=0)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported type of video input {type(video_pixels)}")
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
attn_metadata: AttentionMetadata,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
**kwargs: object,
|
||||||
|
) -> SamplerOutput:
|
||||||
|
"""Run forward pass for LlaVA-NeXT-Video.
|
||||||
|
Args:
|
||||||
|
input_ids: Flattened (concatenated) input_ids corresponding to a
|
||||||
|
batch.
|
||||||
|
pixel_values_videos: Pixels in each frames for each input videos.
|
||||||
|
"""
|
||||||
|
video_input = self._parse_and_validate_video_input(**kwargs)
|
||||||
|
|
||||||
|
# merge video embeddings into input embeddings
|
||||||
|
if video_input is not None:
|
||||||
|
video_embeddings = self._process_video_pixels(video_input)
|
||||||
|
inputs_embeds = self.language_model \
|
||||||
|
.model.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
|
inputs_embeds = merge_multimodal_embeddings(
|
||||||
|
input_ids, inputs_embeds, video_embeddings,
|
||||||
|
self.config.video_token_index)
|
||||||
|
|
||||||
|
input_ids = None
|
||||||
|
else:
|
||||||
|
inputs_embeds = None
|
||||||
|
|
||||||
|
hidden_states = self.language_model.model(input_ids,
|
||||||
|
positions,
|
||||||
|
kv_caches,
|
||||||
|
attn_metadata,
|
||||||
|
None,
|
||||||
|
inputs_embeds=inputs_embeds)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
return self.language_model.compute_logits(hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
logits: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> Optional[SamplerOutput]:
|
||||||
|
return self.language_model.sample(logits, sampling_metadata)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
# prepare weight iterators
|
||||||
|
vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
|
||||||
|
weights, 4)
|
||||||
|
|
||||||
|
# load vision encoder
|
||||||
|
vit_weights = filter_weights(vit_weights, "vision_tower")
|
||||||
|
self.vision_tower.load_weights(vit_weights)
|
||||||
|
|
||||||
|
# load mlp projector
|
||||||
|
mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
|
||||||
|
mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
|
||||||
|
for name, loaded_weight in mlp_weights:
|
||||||
|
param = mlp_params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
# load llm backbone
|
||||||
|
llm_weights = filter_weights(llm_weights, "language_model")
|
||||||
|
self.language_model.load_weights(llm_weights)
|
@ -9,6 +9,7 @@ from .audio import AudioPlugin
|
|||||||
from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
|
from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
|
||||||
MultiModalPlugin, MultiModalTokensCalc, NestedTensors)
|
MultiModalPlugin, MultiModalTokensCalc, NestedTensors)
|
||||||
from .image import ImagePlugin
|
from .image import ImagePlugin
|
||||||
|
from .video import VideoPlugin
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -34,7 +35,7 @@ class MultiModalRegistry:
|
|||||||
:class:`~vllm.multimodal.MultiModalPlugin` for each modality.
|
:class:`~vllm.multimodal.MultiModalPlugin` for each modality.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin())
|
DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin())
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -4,6 +4,7 @@ from io import BytesIO
|
|||||||
from typing import Any, List, Optional, Tuple, TypeVar, Union
|
from typing import Any, List, Optional, Tuple, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from vllm.connections import global_http_connection
|
from vllm.connections import global_http_connection
|
||||||
@ -187,6 +188,47 @@ def rescale_image_size(image: Image.Image,
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def try_import_video_packages() -> Any:
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install vllm[video] for video support.") from None
|
||||||
|
return cv2
|
||||||
|
|
||||||
|
|
||||||
|
def resize_video(frames: npt.NDArray, size: Tuple[int, int]) -> npt.NDArray:
|
||||||
|
cv2 = try_import_video_packages()
|
||||||
|
|
||||||
|
num_frames, _, _, channels = frames.shape
|
||||||
|
new_height, new_width = size
|
||||||
|
resized_frames = np.empty((num_frames, new_height, new_width, channels),
|
||||||
|
dtype=frames.dtype)
|
||||||
|
for i, frame in enumerate(frames):
|
||||||
|
resized_frame = cv2.resize(frame, (new_width, new_height))
|
||||||
|
resized_frames[i] = resized_frame
|
||||||
|
return resized_frames
|
||||||
|
|
||||||
|
|
||||||
|
def rescale_video_size(frames: npt.NDArray, size_factor: float) -> npt.NDArray:
|
||||||
|
_, height, width, _ = frames.shape
|
||||||
|
new_height = int(height * size_factor)
|
||||||
|
new_width = int(width * size_factor)
|
||||||
|
|
||||||
|
return resize_video(frames, (new_height, new_width))
|
||||||
|
|
||||||
|
|
||||||
|
def sample_frames_from_video(frames: npt.NDArray,
|
||||||
|
num_frames: int) -> npt.NDArray:
|
||||||
|
total_frames = frames.shape[0]
|
||||||
|
if num_frames == -1:
|
||||||
|
return frames
|
||||||
|
else:
|
||||||
|
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
|
||||||
|
sampled_frames = frames[frame_indices, ...]
|
||||||
|
return sampled_frames
|
||||||
|
|
||||||
|
|
||||||
# Utilities for input processors
|
# Utilities for input processors
|
||||||
_T = TypeVar("_T", str, int)
|
_T = TypeVar("_T", str, int)
|
||||||
|
|
||||||
|
71
vllm/multimodal/video.py
Normal file
71
vllm/multimodal/video.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
from functools import lru_cache
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.inputs.registry import InputContext
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.transformers_utils.image_processor import get_video_processor
|
||||||
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
from vllm.utils import is_list_of
|
||||||
|
|
||||||
|
from .base import MultiModalData, MultiModalInputs
|
||||||
|
from .image import ImagePlugin
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
cached_get_video_processor = lru_cache(get_video_processor)
|
||||||
|
cached_get_tokenizer = lru_cache(get_tokenizer)
|
||||||
|
|
||||||
|
VideoInput = Union[
|
||||||
|
"np.ndarray", # single video input
|
||||||
|
List["np.ndarray"],
|
||||||
|
# TODO: support more types
|
||||||
|
# List[Image.Image], List[List[Image.Image]],
|
||||||
|
# "torch.Tensor",
|
||||||
|
# List["torch.Tensor"],
|
||||||
|
# List[List["np.ndarrray"]],
|
||||||
|
# List[List["torch.Tensor"]],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class VideoPlugin(ImagePlugin):
|
||||||
|
"""Plugin for video data."""
|
||||||
|
|
||||||
|
def get_data_key(self) -> str:
|
||||||
|
return "video"
|
||||||
|
|
||||||
|
def _get_hf_video_processor(self, model_config: ModelConfig):
|
||||||
|
return cached_get_video_processor(
|
||||||
|
model_config.model,
|
||||||
|
trust_remote_code=model_config.trust_remote_code)
|
||||||
|
|
||||||
|
def _default_input_mapper(
|
||||||
|
self,
|
||||||
|
ctx: InputContext,
|
||||||
|
data: MultiModalData[object],
|
||||||
|
) -> MultiModalInputs:
|
||||||
|
model_config = ctx.model_config
|
||||||
|
|
||||||
|
# single video input as np.ndarray
|
||||||
|
if isinstance(data, np.ndarray):
|
||||||
|
video_processor = self._get_hf_video_processor(model_config)
|
||||||
|
if video_processor is None:
|
||||||
|
raise RuntimeError("No HuggingFace processor is available "
|
||||||
|
"to process the image object")
|
||||||
|
try:
|
||||||
|
batch_data = video_processor(data, return_tensors="pt").data
|
||||||
|
except Exception:
|
||||||
|
logger.error("Failed to process image (%s)", data)
|
||||||
|
raise
|
||||||
|
|
||||||
|
return MultiModalInputs(batch_data)
|
||||||
|
elif is_list_of(data, np.ndarray):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Multi video for a prompt is not supported yet")
|
||||||
|
|
||||||
|
raise TypeError(f"Invalid video type: {type(data)}")
|
||||||
|
|
||||||
|
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
|
||||||
|
return 4096
|
@ -1,6 +1,33 @@
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
|
|
||||||
|
def get_video_processor(
|
||||||
|
processor_name: str,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Gets a processor for the given model name via HuggingFace.
|
||||||
|
"""
|
||||||
|
from transformers import AutoProcessor
|
||||||
|
|
||||||
|
try:
|
||||||
|
processor = AutoProcessor.from_pretrained(processor_name)
|
||||||
|
video_processor = processor.video_processor
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
if not trust_remote_code:
|
||||||
|
err_msg = (
|
||||||
|
"Failed to load the processor. If the processor is "
|
||||||
|
"a custom processor not yet available in the HuggingFace "
|
||||||
|
"transformers library, consider setting "
|
||||||
|
"`trust_remote_code=True` in LLM or using the "
|
||||||
|
"`--trust-remote-code` flag in the CLI.")
|
||||||
|
raise RuntimeError(err_msg) from e
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
return video_processor
|
||||||
|
|
||||||
|
|
||||||
def get_image_processor(
|
def get_image_processor(
|
||||||
processor_name: str,
|
processor_name: str,
|
||||||
*args,
|
*args,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user