[Bugfix] Fix missing post_layernorm in CLIP (#8155)

Cyrus Leung 2024-09-10 16:22:50 +08:00 committed by GitHub
parent a1d874224d
commit da1a844e61
2 changed files with 42 additions and 19 deletions


@@ -355,6 +355,19 @@ class CLIPVisionTransformer(nn.Module):
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override)
 
+        if len(self.encoder.layers) > config.num_hidden_layers:
+            raise ValueError(
+                f"The original encoder only has {config.num_hidden_layers} "
+                f"layers, but you requested {len(self.encoder.layers)} layers."
+            )
+        elif len(self.encoder.layers) == config.num_hidden_layers:
+            self.post_layernorm = nn.LayerNorm(embed_dim,
+                                               eps=config.layer_norm_eps)
+        else:
+            # post_layernorm is unused when we extract intermediate features
+            # In this case, we can skip it to conserve memory
+            self.post_layernorm = None
+
     def forward(
         self,
         pixel_values: torch.Tensor,
@@ -364,7 +377,10 @@ class CLIPVisionTransformer(nn.Module):
         hidden_states = self.pre_layrnorm(hidden_states)
         hidden_states = self.encoder(inputs_embeds=hidden_states)
 
-        return hidden_states
+        if self.post_layernorm is None:
+            return hidden_states
+
+        return self.post_layernorm(hidden_states)
 
 
 class CLIPVisionModel(nn.Module):
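Taken together, the two hunks above make the CLIP vision transformer match the checkpoint: when the encoder is built at full depth, post_layernorm is created and applied at the end of forward; when num_hidden_layers_override trims the stack to extract intermediate features, the attribute is set to None and forward returns the raw hidden states. A minimal self-contained sketch of that dispatch (the TinyVisionTower class and its Linear "encoder" layers are illustrative stand-ins, not vLLM code):

from typing import Optional

import torch
from torch import nn


class TinyVisionTower(nn.Module):
    """Illustrative stand-in for CLIPVisionTransformer's new post_layernorm handling."""

    def __init__(self, embed_dim: int, num_hidden_layers: int,
                 num_kept_layers: Optional[int] = None):
        super().__init__()
        num_kept_layers = num_kept_layers or num_hidden_layers
        # Stand-in for the (possibly truncated) encoder stack.
        self.layers = nn.ModuleList(
            nn.Linear(embed_dim, embed_dim) for _ in range(num_kept_layers))

        if len(self.layers) > num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.layers)} layers.")
        elif len(self.layers) == num_hidden_layers:
            # Full depth: the final LayerNorm from the checkpoint is applied.
            self.post_layernorm = nn.LayerNorm(embed_dim)
        else:
            # Truncated for intermediate-feature extraction: skip it entirely.
            self.post_layernorm = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        if self.post_layernorm is None:
            return hidden_states
        return self.post_layernorm(hidden_states)


x = torch.randn(1, 5, 32)
print(TinyVisionTower(32, num_hidden_layers=4)(x).shape)                     # normalized output
print(TinyVisionTower(32, num_hidden_layers=4, num_kept_layers=3)(x).shape)  # raw layer-3 features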
@@ -386,9 +402,12 @@ class CLIPVisionModel(nn.Module):
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override)
 
-    def forward(self, pixel_values: Optional[torch.Tensor] = None):
-        return self.vision_model(pixel_values=pixel_values)
+    @property
+    def _require_post_layernorm(self) -> bool:
+        return self.vision_model.post_layernorm is not None
+
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        return self.vision_model(pixel_values)
 
     @property
     def device(self):
@@ -408,8 +427,10 @@
         for name, loaded_weight in weights:
             # post_layernorm is not needed in CLIPVisionModel
-            if "vision_model.post_layernorm" in name:
+            if ("vision_model.post_layernorm" in name
+                    and not self._require_post_layernorm):
                 continue
 
             # omit layers when num_hidden_layers_override is set
             if "vision_model.encoder.layers." in name:
                 layer_idx = int(name.split(".")[3])
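The load_weights change is the actual bug fix named in the title: previously every vision_model.post_layernorm.* weight in the checkpoint was skipped unconditionally, so even a full-depth CLIPVisionModel ran without its final LayerNorm. The guard now only skips those weights when the module was never built. A rough sketch of the filter (the helper name is illustrative):

def should_skip_weight(name: str, require_post_layernorm: bool) -> bool:
    """Mirror of the new guard: drop post_layernorm weights only when unused."""
    return ("vision_model.post_layernorm" in name
            and not require_post_layernorm)


# Full-depth model: the weight is kept and loaded into nn.LayerNorm.
assert not should_skip_weight("vision_model.post_layernorm.weight", True)
# Truncated model (post_layernorm is None): the weight is dropped.
assert should_skip_weight("vision_model.post_layernorm.bias", False)
# Unrelated weights are never affected.
assert not should_skip_weight("vision_model.encoder.layers.0.mlp.fc1.weight", False)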


@@ -443,27 +443,26 @@ class SiglipVisionTransformer(nn.Module):
         self.config = config
         embed_dim = config.hidden_size
 
-        if (num_hidden_layers_override is None
-                or num_hidden_layers_override == config.num_hidden_layers):
-            self.need_post_layernorm = True
-        elif num_hidden_layers_override > config.num_hidden_layers:
-            raise ValueError(
-                "num_hidden_layers_override cannot be greater than "
-                "num_hidden_layers")
-        else:
-            self.need_post_layernorm = False
-
         self.embeddings = SiglipVisionEmbeddings(config)
         self.encoder = SiglipEncoder(
             config,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
         )
 
-        if self.need_post_layernorm:
+        if len(self.encoder.layers) > config.num_hidden_layers:
+            raise ValueError(
+                f"The original encoder only has {config.num_hidden_layers} "
+                f"layers, but you requested {len(self.encoder.layers)} layers."
+            )
+        elif len(self.encoder.layers) == config.num_hidden_layers:
             self.post_layernorm = nn.LayerNorm(embed_dim,
                                                eps=config.layer_norm_eps)
         else:
-            self.post_layernorm = nn.Identity()
+            # post_layernorm is unused when we extract intermediate features
+            # In this case, we can skip it to conserve memory
+            self.post_layernorm = None
 
         self.use_head = (True if not hasattr(config, "vision_use_head") else
                          config.vision_use_head)
         if self.use_head:
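The SigLIP constructor gets the same treatment, with two differences from before: the decision is now driven by len(self.encoder.layers) instead of the raw num_hidden_layers_override argument (which also covers the override=None case), and the unused nn.Identity() placeholder is replaced by None so forward can skip the call outright. A small sketch of why checking the built stack subsumes the old three-way branch, assuming the encoder trims its layers the way num_hidden_layers_override suggests:

from typing import Optional

from torch import nn


def build_layers(num_hidden_layers: int,
                 override: Optional[int] = None) -> nn.ModuleList:
    # Assumption: the encoder keeps `override` layers when set, else all of them.
    depth = num_hidden_layers if override is None else override
    return nn.ModuleList(nn.Identity() for _ in range(depth))


for override in (None, 27, 26):
    layers = build_layers(27, override)
    needs_post_layernorm = len(layers) == 27
    print(override, needs_post_layernorm)  # None -> True, 27 -> True, 26 -> False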
@@ -482,6 +481,9 @@ class SiglipVisionTransformer(nn.Module):
         encoder_outputs = self.encoder(inputs_embeds=hidden_states)
 
+        if self.post_layernorm is None:
+            return encoder_outputs
+
         last_hidden_state = self.post_layernorm(encoder_outputs)
 
         # TODO: add this back when pooled_output is used in inference
         # if self.use_head:
@@ -512,8 +514,8 @@ class SiglipVisionModel(nn.Module):
         )
 
     @property
-    def need_post_layernorm(self):
-        return self.vision_model.need_post_layernorm
+    def _require_post_layernorm(self) -> bool:
+        return self.vision_model.post_layernorm is not None
 
     def get_input_embeddings(self) -> nn.Module:
         return self.vision_model.embeddings.patch_embedding
@@ -541,7 +543,7 @@ class SiglipVisionModel(nn.Module):
         for name, loaded_weight in weights:
             # post_layernorm is optional in SiglipVisionModel
             if ("vision_model.post_layernorm" in name
-                    and not self.need_post_layernorm):
+                    and not self._require_post_layernorm):
                 continue
 
             # omit layers when num_hidden_layers_override is set
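For context on when the truncated branch is exercised: multimodal wrappers pick a vision feature layer and size the tower accordingly, and that override is what makes post_layernorm optional. A hedged sketch of how such a wrapper is assumed to derive num_hidden_layers_override from an HF-style vision_feature_layer index (the helper name is illustrative, not part of this diff):

def get_num_hidden_layers_override(num_hidden_layers: int,
                                   vision_feature_layer: int) -> int:
    """Negative indices count from the end, as in LLaVA-style configs."""
    if vision_feature_layer < 0:
        return num_hidden_layers + vision_feature_layer + 1
    return vision_feature_layer + 1


# CLIP-L has 24 encoder layers:
print(get_num_hidden_layers_override(24, -1))  # 24 -> full depth, post_layernorm is built
print(get_num_hidden_layers_override(24, -2))  # 23 -> truncated, post_layernorm is None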