diff --git a/requirements/common.txt b/requirements/common.txt
index bb021d9e..8d910868 100644
--- a/requirements/common.txt
+++ b/requirements/common.txt
@@ -28,7 +28,7 @@ pyzmq
 msgspec
 gguf == 0.10.0
 importlib_metadata
-mistral_common[opencv] >= 1.5.0
+mistral_common[opencv] >= 1.5.4
 pyyaml
 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
 setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
diff --git a/requirements/docs.txt b/requirements/docs.txt
index 7a9b921a..416ca503 100644
--- a/requirements/docs.txt
+++ b/requirements/docs.txt
@@ -15,7 +15,7 @@ pydantic >= 2.8
 torch
 py-cpuinfo
 transformers
-mistral_common >= 1.5.0
+mistral_common >= 1.5.4
 aiohttp
 starlette
 openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
diff --git a/requirements/test.in b/requirements/test.in
index c171e8d4..faa4564e 100644
--- a/requirements/test.in
+++ b/requirements/test.in
@@ -27,7 +27,7 @@ torchaudio==2.6.0
 torchvision==0.21.0
 transformers_stream_generator # required for qwen-vl test
 matplotlib # required for qwen-vl test
-mistral_common[opencv] >= 1.5.0 # required for pixtral test
+mistral_common[opencv] >= 1.5.4 # required for pixtral test
 datamodel_code_generator # required for minicpm3 test
 lm-eval[api]==0.4.4 # required for model evaluation test
 transformers==4.48.2
@@ -40,4 +40,4 @@ tritonclient==2.51.0
 
 numpy < 2.0.0
 runai-model-streamer==0.11.0
-runai-model-streamer-s3==0.11.0
\ No newline at end of file
+runai-model-streamer-s3==0.11.0
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index fff63005..8e545432 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -56,6 +56,8 @@ try:
 except ImportError:
     USE_XFORMERS_OPS = False
 
+PATCH_MERGE = "patch_merge"
+
 
 class PixtralImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
@@ -155,7 +157,6 @@ class PixtralProcessorAdapter:
 
         for image in images:
             image_inputs = self.image_processor(ImageChunk(image=image))
-
             image_processed = torch.tensor(image_inputs.image)
             image_tokens = torch.tensor(image_inputs.tokens)
 
@@ -353,6 +354,17 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         )
         self.vision_encoder = VisionTransformer(self.vision_args)
+
+        if self.vision_args.add_pre_mm_projector_layer_norm:
+            self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size,
+                                                 eps=1e-5)
+
+        if self.vision_args.mm_projector_id == PATCH_MERGE:
+            self.patch_merger = PatchMerger(
+                vision_encoder_dim=self.vision_args.hidden_size,
+                spatial_merge_size=self.vision_args.spatial_merge_size,
+                use_mlp_bias=False,
+            )
 
         self.vision_language_adapter = VisionLanguageAdapter(
             self.vision_args, dim=config.text_config.hidden_size)
@@ -398,13 +420,25 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         image_input: PixtralImagePixelInputs,
     ) -> tuple[torch.Tensor, ...]:
         images = image_input["images"]
-
         image_features = self.vision_encoder(images)
         feature_sizes = [
             image_feature.shape[0] for image_feature in image_features
         ]
-
-        image_embeds = self.vision_language_adapter(torch.cat(image_features))
+        image_features = torch.cat(image_features)
+        if self.vision_args.add_pre_mm_projector_layer_norm:
+            image_features = self.pre_mm_projector_norm(image_features)
+        if self.vision_args.mm_projector_id == PATCH_MERGE:
+            patch_size = self.vision_args.patch_size
+            spatial_merge_size_square = self.vision_args.spatial_merge_size**2
+            img_patch_dims = [(img.shape[1] // patch_size,
+                               img.shape[2] // patch_size) for img in images]
+            feature_sizes = [
+                feature_size // spatial_merge_size_square
+                for feature_size in feature_sizes
+            ]
+            image_features = self.patch_merger(image_features,
+                                               image_sizes=img_patch_dims)
+        image_embeds = self.vision_language_adapter(image_features)
         image_embeds = torch.split(image_embeds, feature_sizes)
         return image_embeds
 
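On the `_process_image_input` hunk above: each image contributes `(height // patch_size) * (width // patch_size)` patch tokens to the flattened feature tensor, and the patch merger collapses every `spatial_merge_size x spatial_merge_size` neighborhood into one token, which is why `feature_sizes` is divided by `spatial_merge_size ** 2` before `torch.split`. A minimal sketch of that bookkeeping, with made-up sizes rather than values from any real checkpoint:

```python
# Shape bookkeeping for the patch-merge path; all sizes are illustrative.
patch_size = 16
spatial_merge_size = 2

# A single 1024x768-pixel image -> a 64x48 grid of patches.
img_height, img_width = 1024, 768
h_patches = img_height // patch_size  # 64
w_patches = img_width // patch_size   # 48

tokens_before_merge = h_patches * w_patches                        # 3072
tokens_after_merge = tokens_before_merge // spatial_merge_size**2  # 768

# This mirrors the feature_sizes adjustment in the diff: without it,
# torch.split would slice the merged features at the wrong offsets.
assert tokens_after_merge * spatial_merge_size**2 == tokens_before_merge
```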
@@ -524,8 +558,19 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
             return weight[0].startswith("vision_language_adapter")
 
+        def is_patch_merger(weight: Tuple[str, torch.Tensor]):
+            return weight[0].startswith("patch_merger")
+
+        def is_pre_mm_projector_norm(weight: Tuple[str, torch.Tensor]):
+            return weight[0].startswith("pre_mm_projector_norm")
+
         # Get references to parameters for direct loading
         vision_encoder_dict = dict(self.vision_encoder.named_parameters())
+        patch_merger_dict = dict(self.patch_merger.named_parameters(
+        )) if self.vision_args.mm_projector_id == PATCH_MERGE else dict()
+        pre_mm_projector_norm_dict = dict(
+            self.pre_mm_projector_norm.named_parameters(
+            )) if self.vision_args.add_pre_mm_projector_layer_norm else dict()
         vision_lang_adapter_dict = dict(
             self.vision_language_adapter.named_parameters())
 
@@ -538,6 +583,18 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
                 param = vision_encoder_dict[trimmed_name]
                 with torch.no_grad():
                     default_weight_loader(param, w)
+            elif is_patch_merger((name, w)):
+                # Load vision patch merger weights directly
+                trimmed_name = '.'.join(name.split(".")[1:])
+                param = patch_merger_dict[trimmed_name]
+                with torch.no_grad():
+                    default_weight_loader(param, w)
+            elif is_pre_mm_projector_norm((name, w)):
+                # Load vision pre_mm_projector_norm weights directly
+                trimmed_name = '.'.join(name.split(".")[1:])
+                param = pre_mm_projector_norm_dict[trimmed_name]
+                with torch.no_grad():
+                    default_weight_loader(param, w)
             elif is_vision_lang_adapter_weights((name, w)):
                 # Load vision-language adapter weights directly
                 trimmed_name = '.'.join(name.split(".")[1:])
@@ -566,6 +623,9 @@ class VisionEncoderArgs:
     rope_theta: float  # for rope-2D
     image_token_id: int
     adapter_bias: bool = True
+    spatial_merge_size: int = 1
+    add_pre_mm_projector_layer_norm: bool = False
+    mm_projector_id: str = ""
 
 
 def _reshape_for_broadcast(freqs_cis: torch.Tensor,
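The `load_weights` additions above reuse the file's existing routing pattern: build one `named_parameters()` dict per submodule, test each checkpoint name with a `startswith` predicate, and strip the prefix before the lookup. A standalone sketch of that pattern, with hypothetical module names; only the routing logic mirrors the diff:

```python
import torch
import torch.nn as nn

# Hypothetical submodules standing in for patch_merger and
# pre_mm_projector_norm; the shapes are arbitrary.
modules = nn.ModuleDict({
    "patch_merger": nn.Linear(16, 4, bias=False),
    "pre_mm_projector_norm": nn.LayerNorm(4),
})
param_dicts = {
    prefix: dict(module.named_parameters())
    for prefix, module in modules.items()
}

# Fake checkpoint entries named "<submodule>.<param>".
checkpoint = {
    "patch_merger.weight": torch.zeros(4, 16),
    "pre_mm_projector_norm.weight": torch.ones(4),
    "pre_mm_projector_norm.bias": torch.zeros(4),
}

for name, w in checkpoint.items():
    # Route by prefix, then trim it (same idea as
    # '.'.join(name.split(".")[1:]) in the diff).
    prefix, trimmed_name = name.split(".", 1)
    param = param_dicts[prefix][trimmed_name]
    with torch.no_grad():
        param.copy_(w)  # stand-in for default_weight_loader(param, w)
```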
@@ -843,6 +903,104 @@ class VisionLanguageAdapter(nn.Module):
         return self.w_out(self.gelu(self.w_in(x)))
 
 
+class PatchMerger(nn.Module):
+    """
+    Learned merging of spatial_merge_size ** 2 patches
+    """
+
+    def __init__(
+        self,
+        vision_encoder_dim: int,
+        spatial_merge_size: int,
+        use_mlp_bias: bool = False,
+    ) -> None:
+        super().__init__()
+
+        mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2)
+
+        self.spatial_merge_size = spatial_merge_size
+        self.mlp_input_dim = mlp_input_dim
+
+        self.merging_layer = nn.Linear(
+            mlp_input_dim,
+            vision_encoder_dim,
+            bias=use_mlp_bias,
+        )
+
+    def forward(self, x: torch.Tensor,
+                image_sizes: list[tuple[int, int]]) -> torch.Tensor:
+        # image_sizes specified in tokens
+        assert sum([h * w for h, w in image_sizes]) == len(x)
+
+        # x is (N, vision_encoder_dim)
+        x = self.permute(x, image_sizes)
+
+        # x is (N / spatial_merge_size ** 2, vision_encoder_dim * spatial_merge_size ** 2)
+        x = self.merging_layer(x)
+
+        # x is (N / spatial_merge_size ** 2, vision_encoder_dim)
+        return x
+
+    def permute(
+        self,
+        x: torch.Tensor,
+        image_sizes: list[tuple[int, int]],
+    ) -> torch.Tensor:
+        """
+        Args:
+            x: (N, D) where N is flattened and concatenated patch tokens
+                for all images
+            image_sizes: list of tuple of (height, width) in tokens for
+                each image
+        Returns:
+            image_features: reorders patch tokens so each grid of
+                (spatial_merge_size, spatial_merge_size) is contiguous.
+                now (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2)
+        """
+
+        sub_grids = get_sub_grids(
+            x=x,
+            image_sizes=image_sizes,
+            spatial_merge_size=self.spatial_merge_size
+        )  # list of [d x sub_grid_size x sub_grid_size x n_patches]
+        permuted_tensor: list[torch.Tensor] = []
+        for grid in sub_grids:
+            n_patches = grid.shape[-1]
+            permuted_tensor.append(grid.view(-1, n_patches).t(
+            ))  # n_patches x d * sub_grid_size * sub_grid_size
+        return torch.cat(
+            permuted_tensor, dim=0
+        )  # (N / spatial_merge_size ** 2, d * spatial_merge_size ** 2)
+
+
+def get_sub_grids(
+    x: torch.Tensor,
+    image_sizes: list[tuple[int, int]],
+    spatial_merge_size: int,
+) -> list[torch.Tensor]:
+    # image_sizes specified in tokens
+    tokens_per_image = [h * w for h, w in image_sizes]
+    d = x.shape[-1]
+    all_img_sub_grids: list[torch.Tensor] = []
+    sub_grid_size = spatial_merge_size
+
+    for image_index, image_tokens in enumerate(x.split(tokens_per_image)):
+        # Reshape image_tokens into a 2D grid
+        h, w = image_sizes[image_index]
+        image_grid = image_tokens.view(h, w, d).permute(
+            2, 0, 1)[None, :, :, :]  # 1 x d x h x w
+        sub_grids = torch.nn.functional.unfold(image_grid,
+                                               kernel_size=sub_grid_size,
+                                               stride=sub_grid_size)
+        sub_grids = sub_grids.view(
+            1, d, sub_grid_size, sub_grid_size,
+            -1)  # 1 x d x sub_grid_size x sub_grid_size x n_patches
+
+        all_img_sub_grids.append(sub_grids[0])
+
+    return all_img_sub_grids
+
+
 #### HF Transformers version of Pixtral ####
 # Based off https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py
 # This model follows the Llava family, meaning image embeddings are placed
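For intuition, here is a toy invocation of the `PatchMerger` added in this diff, assuming the class and `get_sub_grids` above are in scope; the dimensions are illustrative, not model values. `permute` regroups the flattened tokens so that each `spatial_merge_size x spatial_merge_size` neighborhood becomes one row of width `D * spatial_merge_size ** 2` (via `unfold` with matching kernel and stride), and the linear merging layer projects each row back to `D`:

```python
import torch

# Toy dimensions: D=8 features, 2x2 merging, one image of 4x6 patches.
merger = PatchMerger(vision_encoder_dim=8, spatial_merge_size=2)

image_sizes = [(4, 6)]      # (height, width) in patch tokens
x = torch.randn(4 * 6, 8)   # flattened patch tokens, shape (N, D) = (24, 8)

# permute alone: each 2x2 neighborhood becomes one row of width 8 * 4 = 32.
grouped = merger.permute(x, image_sizes=image_sizes)
assert grouped.shape == (6, 32)  # N / spatial_merge_size**2 rows

# Full forward pass: merged tokens are projected back to D dimensions.
out = merger(x, image_sizes=image_sizes)
assert out.shape == (6, 8)
```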