Add patch merger (#14957)
parent 166a168b0f
commit d20b0c139c
@@ -28,7 +28,7 @@ pyzmq
 msgspec
 gguf == 0.10.0
 importlib_metadata
-mistral_common[opencv] >= 1.5.0
+mistral_common[opencv] >= 1.5.4
 pyyaml
 six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
 setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
@@ -15,7 +15,7 @@ pydantic >= 2.8
 torch
 py-cpuinfo
 transformers
-mistral_common >= 1.5.0
+mistral_common >= 1.5.4
 aiohttp
 starlette
 openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
@@ -27,7 +27,7 @@ torchaudio==2.6.0
 torchvision==0.21.0
 transformers_stream_generator # required for qwen-vl test
 matplotlib # required for qwen-vl test
-mistral_common[opencv] >= 1.5.0 # required for pixtral test
+mistral_common[opencv] >= 1.5.4 # required for pixtral test
 datamodel_code_generator # required for minicpm3 test
 lm-eval[api]==0.4.4 # required for model evaluation test
 transformers==4.48.2
@@ -56,6 +56,8 @@ try:
 except ImportError:
     USE_XFORMERS_OPS = False
 
+PATCH_MERGE = "patch_merge"
+
 
 class PixtralImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
@@ -155,7 +157,6 @@ class PixtralProcessorAdapter:
 
         for image in images:
             image_inputs = self.image_processor(ImageChunk(image=image))
 
             image_processed = torch.tensor(image_inputs.image)
             image_tokens = torch.tensor(image_inputs.tokens)
 
@@ -353,6 +354,27 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         )
 
         self.vision_encoder = VisionTransformer(self.vision_args)
 
+        if self.vision_args.add_pre_mm_projector_layer_norm:
+            self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size,
+                                                 eps=1e-5)
+
+        if self.vision_args.mm_projector_id == PATCH_MERGE:
+            self.patch_merger = PatchMerger(
+                vision_encoder_dim=self.vision_args.hidden_size,
+                spatial_merge_size=self.vision_args.spatial_merge_size,
+                use_mlp_bias=False,
+            )
+        if self.vision_args.add_pre_mm_projector_layer_norm:
+            self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size,
+                                                 eps=1e-5)
+
+        if self.vision_args.mm_projector_id == PATCH_MERGE:
+            self.patch_merger = PatchMerger(
+                vision_encoder_dim=self.vision_args.hidden_size,
+                spatial_merge_size=self.vision_args.spatial_merge_size,
+                use_mlp_bias=False,
+            )
         self.vision_language_adapter = VisionLanguageAdapter(
             self.vision_args, dim=config.text_config.hidden_size)
+
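The two branches above are gated purely by the vision config. A minimal sketch of the flag combinations involved (illustrative values only; the defaults come from the VisionEncoderArgs fields added further below, and spatial_merge_size=2 is just an assumed example, not a value taken from this diff):

# Illustrative VisionEncoderArgs flag combinations (hypothetical, not from a real config).
plain_pixtral = dict(
    add_pre_mm_projector_layer_norm=False,  # default: no pre-projector RMSNorm
    mm_projector_id="",                     # default: no PatchMerger is built
    spatial_merge_size=1,                   # default: no spatial merging
)
with_patch_merger = dict(
    add_pre_mm_projector_layer_norm=True,   # builds self.pre_mm_projector_norm
    mm_projector_id="patch_merge",          # == PATCH_MERGE, builds self.patch_merger
    spatial_merge_size=2,                   # assumed merge factor for illustration
)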
@@ -398,13 +420,25 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         image_input: PixtralImagePixelInputs,
     ) -> tuple[torch.Tensor, ...]:
         images = image_input["images"]
 
         image_features = self.vision_encoder(images)
         feature_sizes = [
             image_feature.shape[0] for image_feature in image_features
         ]
 
-        image_embeds = self.vision_language_adapter(torch.cat(image_features))
+        image_features = torch.cat(image_features)
+        if self.vision_args.add_pre_mm_projector_layer_norm:
+            image_features = self.pre_mm_projector_norm(image_features)
+        if self.vision_args.mm_projector_id == PATCH_MERGE:
+            patch_size = self.vision_args.patch_size
+            spatial_merge_size_square = self.vision_args.spatial_merge_size**2
+            img_patch_dims = [(img.shape[1] // patch_size,
+                               img.shape[2] // patch_size) for img in images]
+            feature_sizes = [
+                feature_size // spatial_merge_size_square
+                for feature_size in feature_sizes
+            ]
+            image_features = self.patch_merger(image_features,
+                                               image_sizes=img_patch_dims)
+        image_embeds = self.vision_language_adapter(image_features)
         image_embeds = torch.split(image_embeds, feature_sizes)
         return image_embeds
 
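For intuition on the bookkeeping above, here is a small self-contained sketch (hypothetical image and patch sizes, not values from the diff): each image contributes (height // patch_size) * (width // patch_size) patch tokens, and the patch merger shrinks that count by spatial_merge_size ** 2, which is why feature_sizes is divided before the final torch.split.

# Hypothetical sizes to illustrate the feature-size arithmetic.
patch_size = 16           # assumed ViT patch size
spatial_merge_size = 2    # assumed merge factor

img_h, img_w = 512, 256                                       # one image, in pixels
grid_h, grid_w = img_h // patch_size, img_w // patch_size     # (32, 16) patch grid
n_patch_tokens = grid_h * grid_w                              # 512 encoder tokens

# After PatchMerger, every 2x2 block of patch tokens becomes a single token,
# so the per-image split size is divided by spatial_merge_size ** 2.
merged_tokens = n_patch_tokens // spatial_merge_size**2       # 128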
@@ -524,8 +558,19 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
             return weight[0].startswith("vision_language_adapter")
 
+        def is_patch_merger(weight: Tuple[str, torch.Tensor]):
+            return weight[0].startswith("patch_merger")
+
+        def is_pre_mm_projector_norm(weight: Tuple[str, torch.Tensor]):
+            return weight[0].startswith("pre_mm_projector_norm")
+
         # Get references to parameters for direct loading
         vision_encoder_dict = dict(self.vision_encoder.named_parameters())
+        patch_merger_dict = dict(self.patch_merger.named_parameters(
+        )) if self.vision_args.mm_projector_id == PATCH_MERGE else dict()
+        pre_mm_projector_norm_dict = dict(
+            self.pre_mm_projector_norm.named_parameters(
+            )) if self.vision_args.add_pre_mm_projector_layer_norm else dict()
         vision_lang_adapter_dict = dict(
             self.vision_language_adapter.named_parameters())
 
@@ -538,6 +583,18 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
                     param = vision_encoder_dict[trimmed_name]
                     with torch.no_grad():
                         default_weight_loader(param, w)
+                elif is_patch_merger((name, w)):
+                    # Load vision patch merger weights directly
+                    trimmed_name = '.'.join(name.split(".")[1:])
+                    param = patch_merger_dict[trimmed_name]
+                    with torch.no_grad():
+                        default_weight_loader(param, w)
+                elif is_pre_mm_projector_norm((name, w)):
+                    # Load vision pre_mm_projector_norm weights directly
+                    trimmed_name = '.'.join(name.split(".")[1:])
+                    param = pre_mm_projector_norm_dict[trimmed_name]
+                    with torch.no_grad():
+                        default_weight_loader(param, w)
                 elif is_vision_lang_adapter_weights((name, w)):
                     # Load vision-language adapter weights directly
                     trimmed_name = '.'.join(name.split(".")[1:])
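The routing above keys off the first component of each checkpoint tensor name; a minimal sketch of the trimming step (the weight name below is illustrative, not taken from an actual checkpoint):

# Hypothetical checkpoint key handled by the is_patch_merger branch above.
name = "patch_merger.merging_layer.weight"
trimmed_name = '.'.join(name.split(".")[1:])   # -> "merging_layer.weight"
# trimmed_name is then looked up in patch_merger_dict, i.e. the mapping built
# from self.patch_merger.named_parameters(), and loaded in place.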
@@ -566,6 +623,9 @@ class VisionEncoderArgs:
     rope_theta: float  # for rope-2D
     image_token_id: int
     adapter_bias: bool = True
+    spatial_merge_size: int = 1
+    add_pre_mm_projector_layer_norm: bool = False
+    mm_projector_id: str = ""
 
 
 def _reshape_for_broadcast(freqs_cis: torch.Tensor,
@@ -843,6 +903,104 @@ class VisionLanguageAdapter(nn.Module):
         return self.w_out(self.gelu(self.w_in(x)))
 
 
+class PatchMerger(nn.Module):
+    """
+    Learned merging of spatial_merge_size ** 2 patches
+    """
+
+    def __init__(
+        self,
+        vision_encoder_dim: int,
+        spatial_merge_size: int,
+        use_mlp_bias: bool = False,
+    ) -> None:
+        super().__init__()
+
+        mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2)
+
+        self.spatial_merge_size = spatial_merge_size
+        self.mlp_input_dim = mlp_input_dim
+
+        self.merging_layer = nn.Linear(
+            mlp_input_dim,
+            vision_encoder_dim,
+            bias=use_mlp_bias,
+        )
+
+    def forward(self, x: torch.Tensor,
+                image_sizes: list[tuple[int, int]]) -> torch.Tensor:
+        # image_sizes specified in tokens
+        assert sum([h * w for h, w in image_sizes]) == len(x)
+
+        # x is (N, vision_encoder_dim)
+        x = self.permute(x, image_sizes)
+
+        # x is (N / spatial_merge_size ** 2, vision_encoder_dim * spatial_merge_size ** 2)
+        x = self.merging_layer(x)
+
+        # x is (N / spatial_merge_size ** 2, vision_encoder_dim)
+        return x
+
+    def permute(
+        self,
+        x: torch.Tensor,
+        image_sizes: list[tuple[int, int]],
+    ) -> torch.Tensor:
+        """
+        Args:
+            x: (N, D) where N is flattened and concatenated patch tokens
+                for all images
+            image_sizes: list of tuple of (height, width) in tokens for
+                each image
+        Returns:
+            image_features: reorders patch tokens so each grid of
+                (spatial_merge_size, spatial_merge_size) is contiguous.
+                now (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2)
+        """
+
+        sub_grids = get_sub_grids(
+            x=x,
+            image_sizes=image_sizes,
+            spatial_merge_size=self.spatial_merge_size
+        )  # list of [d x sub_grid_size x sub_grid_size x n_patches]
+        permuted_tensor: list[torch.Tensor] = []
+        for grid in sub_grids:
+            n_patches = grid.shape[-1]
+            permuted_tensor.append(grid.view(-1, n_patches).t(
+            ))  # n_patches x d * sub_grid_size * sub_grid_size
+        return torch.cat(
+            permuted_tensor, dim=0
+        )  # (N / spatial_merge_size ** 2, d * spatial_merge_size ** 2)
+
+
+def get_sub_grids(
+    x: torch.Tensor,
+    image_sizes: list[tuple[int, int]],
+    spatial_merge_size: int,
+) -> list[torch.Tensor]:
+    # image_sizes specified in tokens
+    tokens_per_image = [h * w for h, w in image_sizes]
+    d = x.shape[-1]
+    all_img_sub_grids: list[torch.Tensor] = []
+    sub_grid_size = spatial_merge_size
+
+    for image_index, image_tokens in enumerate(x.split(tokens_per_image)):
+        # Reshape image_tokens into a 2D grid
+        h, w = image_sizes[image_index]
+        image_grid = image_tokens.view(h, w, d).permute(
+            2, 0, 1)[None, :, :, :]  # 1 x d x h x w
+        sub_grids = torch.nn.functional.unfold(image_grid,
+                                               kernel_size=sub_grid_size,
+                                               stride=sub_grid_size)
+        sub_grids = sub_grids.view(
+            1, d, sub_grid_size, sub_grid_size,
+            -1)  # 1 x d x sub_grid_size x sub_grid_size x n_patches
+
+        all_img_sub_grids.append(sub_grids[0])
+
+    return all_img_sub_grids
+
+
 #### HF Transformers version of Pixtral ####
 # Based off https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py
 # This model follows the Llava family, meaning image embeddings are placed
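To make the reordering in get_sub_grids and PatchMerger.permute concrete, here is a small standalone sketch (toy dimensions chosen purely for illustration) showing how unfold regroups a patch grid into contiguous spatial_merge_size x spatial_merge_size blocks before the merging linear layer sees them:

import torch

d, h, w, merge = 3, 4, 4, 2     # toy feature dim and a 4x4 grid of patch tokens
x = torch.arange(h * w * d, dtype=torch.float32).view(h, w, d)

grid = x.permute(2, 0, 1)[None]                  # 1 x d x h x w
sub = torch.nn.functional.unfold(grid, kernel_size=merge, stride=merge)
sub = sub.view(1, d, merge, merge, -1)           # 1 x d x 2 x 2 x n_blocks
n_blocks = sub.shape[-1]                         # (h // merge) * (w // merge) = 4

merged = sub[0].view(-1, n_blocks).t()           # n_blocks x (d * merge * merge)
print(merged.shape)                              # torch.Size([4, 12])

Each row of merged then holds one 2x2 block of patch features flattened together, matching the (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2) shape that merging_layer consumes.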