Add patch merger (#14957)

Patrick von Platen 2025-03-17 14:47:50 +01:00 committed by GitHub
parent 166a168b0f
commit d20b0c139c
4 changed files with 166 additions and 8 deletions


@@ -28,7 +28,7 @@ pyzmq
msgspec
gguf == 0.10.0
importlib_metadata
-mistral_common[opencv] >= 1.5.0
+mistral_common[opencv] >= 1.5.4
pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12


@@ -15,7 +15,7 @@ pydantic >= 2.8
torch
py-cpuinfo
transformers
-mistral_common >= 1.5.0
+mistral_common >= 1.5.4
aiohttp
starlette
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args


@@ -27,7 +27,7 @@ torchaudio==2.6.0
torchvision==0.21.0
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
-mistral_common[opencv] >= 1.5.0 # required for pixtral test
+mistral_common[opencv] >= 1.5.4 # required for pixtral test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.4 # required for model evaluation test
transformers==4.48.2


@@ -56,6 +56,8 @@ try:
except ImportError:
USE_XFORMERS_OPS = False
PATCH_MERGE = "patch_merge"
class PixtralImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
@@ -155,7 +157,6 @@ class PixtralProcessorAdapter:
for image in images:
image_inputs = self.image_processor(ImageChunk(image=image))
image_processed = torch.tensor(image_inputs.image)
image_tokens = torch.tensor(image_inputs.tokens)
@@ -353,6 +354,27 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
)
self.vision_encoder = VisionTransformer(self.vision_args)
if self.vision_args.add_pre_mm_projector_layer_norm:
self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size,
eps=1e-5)
if self.vision_args.mm_projector_id == PATCH_MERGE:
self.patch_merger = PatchMerger(
vision_encoder_dim=self.vision_args.hidden_size,
spatial_merge_size=self.vision_args.spatial_merge_size,
use_mlp_bias=False,
)
self.vision_language_adapter = VisionLanguageAdapter(
self.vision_args, dim=config.text_config.hidden_size)
@@ -398,13 +420,25 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
image_input: PixtralImagePixelInputs,
) -> tuple[torch.Tensor, ...]:
images = image_input["images"]
image_features = self.vision_encoder(images)
feature_sizes = [
image_feature.shape[0] for image_feature in image_features
]
image_embeds = self.vision_language_adapter(torch.cat(image_features))
image_features = torch.cat(image_features)
if self.vision_args.add_pre_mm_projector_layer_norm:
image_features = self.pre_mm_projector_norm(image_features)
if self.vision_args.mm_projector_id == PATCH_MERGE:
patch_size = self.vision_args.patch_size
spatial_merge_size_square = self.vision_args.spatial_merge_size**2
img_patch_dims = [(img.shape[1] // patch_size,
img.shape[2] // patch_size) for img in images]
feature_sizes = [
feature_size // spatial_merge_size_square
for feature_size in feature_sizes
]
image_features = self.patch_merger(image_features,
image_sizes=img_patch_dims)
image_embeds = self.vision_language_adapter(image_features)
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds
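A quick worked example of the token arithmetic above (all numbers are illustrative, not taken from the diff): with a 16-pixel patch size and spatial_merge_size 2, each per-image feature count shrinks by a factor of four before reaching the vision-language adapter.

patch_size, spatial_merge_size = 16, 2        # illustrative values
img_h, img_w = 512, 768                       # one processed image, in pixels
grid_h, grid_w = img_h // patch_size, img_w // patch_size   # img_patch_dims entry: (32, 48)
feature_size = grid_h * grid_w                               # 1536 tokens from the vision encoder
merged_size = feature_size // spatial_merge_size**2          # 384 tokens after PatchMerger
print((grid_h, grid_w), feature_size, merged_size)           # (32, 48) 1536 384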
@@ -524,8 +558,19 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
return weight[0].startswith("vision_language_adapter")
def is_patch_merger(weight: Tuple[str, torch.Tensor]):
return weight[0].startswith("patch_merger")
def is_pre_mm_projector_norm(weight: Tuple[str, torch.Tensor]):
return weight[0].startswith("pre_mm_projector_norm")
# Get references to parameters for direct loading
vision_encoder_dict = dict(self.vision_encoder.named_parameters())
patch_merger_dict = dict(self.patch_merger.named_parameters(
)) if self.vision_args.mm_projector_id == PATCH_MERGE else dict()
pre_mm_projector_norm_dict = dict(
self.pre_mm_projector_norm.named_parameters(
)) if self.vision_args.add_pre_mm_projector_layer_norm else dict()
vision_lang_adapter_dict = dict(
self.vision_language_adapter.named_parameters())
@@ -538,6 +583,18 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
param = vision_encoder_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
elif is_patch_merger((name, w)):
# Load vision patch merger weights directly
trimmed_name = '.'.join(name.split(".")[1:])
param = patch_merger_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
elif is_pre_mm_projector_norm((name, w)):
# Load vision pre_mm_projector_norm weights directly
trimmed_name = '.'.join(name.split(".")[1:])
param = pre_mm_projector_norm_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
elif is_vision_lang_adapter_weights((name, w)):
# Load vision-language adapter weights directly
trimmed_name = '.'.join(name.split(".")[1:])
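A small sanity check of the routing above (the weight name is illustrative, not from a real checkpoint): the leading module prefix is stripped before the parameter lookup.

name = "patch_merger.merging_layer.weight"      # hypothetical checkpoint key
assert name.startswith("patch_merger")          # is_patch_merger((name, w)) matches
trimmed_name = '.'.join(name.split(".")[1:])    # -> "merging_layer.weight"
# trimmed_name is then looked up in dict(self.patch_merger.named_parameters())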
@@ -566,6 +623,9 @@ class VisionEncoderArgs:
rope_theta: float # for rope-2D
image_token_id: int
adapter_bias: bool = True
spatial_merge_size: int = 1
add_pre_mm_projector_layer_norm: bool = False
mm_projector_id: str = ""
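# [Note, not part of the diff] The three new fields default to a no-op
# configuration; a checkpoint that uses the merger would set something like the
# values below (illustrative only, remaining VisionEncoderArgs fields omitted).
patch_merge_vision_overrides = dict(
    spatial_merge_size=2,                     # collapse each 2x2 block of vision tokens
    add_pre_mm_projector_layer_norm=True,     # RMSNorm before the projector
    mm_projector_id="patch_merge",            # matches PATCH_MERGE, enables PatchMerger
)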
def _reshape_for_broadcast(freqs_cis: torch.Tensor,
@@ -843,6 +903,104 @@ class VisionLanguageAdapter(nn.Module):
return self.w_out(self.gelu(self.w_in(x)))
class PatchMerger(nn.Module):
"""
Learned merging of spatial_merge_size ** 2 patches
"""
def __init__(
self,
vision_encoder_dim: int,
spatial_merge_size: int,
use_mlp_bias: bool = False,
) -> None:
super().__init__()
mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2)
self.spatial_merge_size = spatial_merge_size
self.mlp_input_dim = mlp_input_dim
self.merging_layer = nn.Linear(
mlp_input_dim,
vision_encoder_dim,
bias=use_mlp_bias,
)
def forward(self, x: torch.Tensor,
image_sizes: list[tuple[int, int]]) -> torch.Tensor:
# image_sizes specified in tokens
assert sum([h * w for h, w in image_sizes]) == len(x)
# x is (N, vision_encoder_dim)
x = self.permute(x, image_sizes)
# x is (N / spatial_merge_size ** 2, vision_encoder_dim * spatial_merge_size ** 2)
x = self.merging_layer(x)
# x is (N / spatial_merge_size ** 2, vision_encoder_dim)
return x
def permute(
self,
x: torch.Tensor,
image_sizes: list[tuple[int, int]],
) -> torch.Tensor:
"""
Args:
x: (N, D) where N is flattened and concatenated patch tokens
for all images
image_sizes: list of tuple of (height, width) in tokens for
each image
Returns:
image_features: reorders patch tokens so each grid of
(spatial_merge_size, spatial_merge_size) is contiguous.
now (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2)
"""
sub_grids = get_sub_grids(
x=x,
image_sizes=image_sizes,
spatial_merge_size=self.spatial_merge_size
) # list of [d x sub_grid_size x sub_grid_size x n_patches]
permuted_tensor: list[torch.Tensor] = []
for grid in sub_grids:
n_patches = grid.shape[-1]
permuted_tensor.append(grid.view(-1, n_patches).t(
)) # n_patches x d * sub_grid_size * sub_grid_size
return torch.cat(
permuted_tensor, dim=0
) # (N / spatial_merge_size ** 2, d * spatial_merge_size ** 2)
def get_sub_grids(
x: torch.Tensor,
image_sizes: list[tuple[int, int]],
spatial_merge_size: int,
) -> list[torch.Tensor]:
# image_sizes specified in tokens
tokens_per_image = [h * w for h, w in image_sizes]
d = x.shape[-1]
all_img_sub_grids: list[torch.Tensor] = []
sub_grid_size = spatial_merge_size
for image_index, image_tokens in enumerate(x.split(tokens_per_image)):
# Reshape image_tokens into a 2D grid
h, w = image_sizes[image_index]
image_grid = image_tokens.view(h, w, d).permute(
2, 0, 1)[None, :, :, :] # 1 x d x h x w
sub_grids = torch.nn.functional.unfold(image_grid,
kernel_size=sub_grid_size,
stride=sub_grid_size)
sub_grids = sub_grids.view(
1, d, sub_grid_size, sub_grid_size,
-1) # 1 x d x sub_grid_size x sub_grid_size x n_patches
all_img_sub_grids.append(sub_grids[0])
return all_img_sub_grids
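# [Note, not part of the diff] A minimal shape check for the PatchMerger above.
# It assumes the class is importable from vllm.model_executor.models.pixtral
# (module path is an assumption); the class could equally be copied out verbatim.
import torch
from vllm.model_executor.models.pixtral import PatchMerger

d, merge = 64, 2
merger = PatchMerger(vision_encoder_dim=d, spatial_merge_size=merge)
# Two images with 4x6 and 2x2 token grids: 24 + 4 = 28 patch tokens in total.
x = torch.randn(28, d)
out = merger(x, image_sizes=[(4, 6), (2, 2)])
# Every 2x2 block collapses into a single token: 28 / 2**2 = 7 rows of width d.
assert out.shape == (7, d)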
#### HF Transformers version of Pixtral ####
# Based off https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py
# This model follows the Llava family, meaning image embeddings are placed