# coding=utf-8
# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Reference:
# * transformers/models/dinov2/modeling_dinov2.py
# * https://github.com/facebookresearch/DiT/blob/main/models.py#L101
# * https://github.com/3DTopia/OpenLRM/tree/main/openlrm/models/encoders/dinov2
""" PyTorch CLIP model."""

from typing import Dict, List, Optional, Set, Tuple, Union

import torch
import torch.nn as nn

from .modeling_clip import (
    CLIPConfig,
    CLIPTextConfig,
    CLIPVisionConfig,
    CLIPEncoderLayer,
    CLIPTextTransformer,
    CLIPVisionTransformer,
    CLIPModel,
    CLIPVisionEmbeddings,
    CLIPVisionModel,
    CLIPOutput,
    BaseModelOutput,
    BaseModelOutputWithPooling
)


def _clip_contrastive_loss(similarity: torch.Tensor) -> torch.Tensor:
    """Return the symmetric cross-entropy contrastive loss of a similarity matrix.

    Row ``i`` / column ``i`` are treated as the matching pair, so the matrix
    must be square (same number of texts and images in the batch).

    Args:
        similarity: text-to-image logits; its transpose is used for the
            image-to-text direction.

    Returns:
        The mean of the two directional cross-entropy losses (scalar tensor).
    """
    targets = torch.arange(similarity.shape[0], device=similarity.device)
    text_loss = nn.functional.cross_entropy(similarity, targets)
    image_loss = nn.functional.cross_entropy(similarity.t(), targets)
    return (text_loss + image_loss) / 2.0


class ModLN(nn.Module):
    """Modulation applied on top of a LayerNorm output.

    A small MLP (SiLU followed by a linear layer) maps the condition vector to
    a per-channel ``shift`` and ``scale``; the input is transformed as
    ``x * (1 + scale) + shift``. The linear layer is zero-initialised, so a
    freshly constructed module is the identity for any condition.
    """

    def __init__(self, inner_dim: int, mod_dim: int = 32):
        """
        Args:
            inner_dim: channel size of the tokens being modulated.
            mod_dim: channel size of the condition vector.
        """
        super().__init__()
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(mod_dim, inner_dim * 2),
        )

        # Zero-init so that shift == scale == 0 at construction time.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.zeros_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor, condition: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Modulate ``x`` with the shift/scale predicted from ``condition``.

        Args:
            x: [N, M, C_in], M: num of tokens
            condition: [N, C_mod]. When ``None`` no modulation is applied and
                ``x`` is returned unchanged (callers declare the condition as
                optional, so this must not crash).

        Returns:
            Tensor with the same shape as ``x``.
        """
        if condition is None:
            return x

        # unsqueeze(1) adds the token axis so shift/scale broadcast over M.
        shift, scale = self.mlp(condition).unsqueeze(1).chunk(2, dim=-1)
        return x * (1 + scale) + shift


class ConditionalCLIPVisionConfig(CLIPVisionConfig):
    """``CLIPVisionConfig`` extended with the size of the conditioning vector."""

    def __init__(self, modulation_dim: int = 32, *args, **kwargs):
        """
        Args:
            modulation_dim: channel size of the condition fed to every ``ModLN``.
            *args, **kwargs: forwarded unchanged to ``CLIPVisionConfig``.
        """
        super().__init__(*args, **kwargs)
        self.modulation_dim = modulation_dim


class ConditionalCLIPEncoderLayer(CLIPEncoderLayer):
    """This corresponds to the Block class in the original implementation."""

    def __init__(self, config: ConditionalCLIPVisionConfig) -> None:
        super().__init__(config)
        # One modulation per LayerNorm of the parent layer.
        self.mod_norm1 = ModLN(config.hidden_size, config.modulation_dim)
        self.mod_norm2 = ModLN(config.hidden_size, config.modulation_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        condition: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Run one pre-norm transformer block with modulated LayerNorms.

        Args:
            hidden_states: input tokens of the layer.
            attention_mask: forwarded to ``self.self_attn``.
            causal_attention_mask: forwarded to ``self.self_attn``.
            condition: conditioning vector handed to both ``ModLN`` modules.
            output_attentions: when true, the attention weights are appended
                to the returned tuple.

        Returns:
            ``(hidden_states,)`` or ``(hidden_states, attn_weights)``.
        """
        # Attention sub-block: modulated LN -> self-attention -> residual.
        residual = hidden_states
        hidden_states = self.mod_norm1(self.layer_norm1(hidden_states), condition)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        # Feed-forward sub-block: modulated LN -> MLP -> residual.
        residual = hidden_states
        hidden_states = self.mod_norm2(self.layer_norm2(hidden_states), condition)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class ConditionalCLIPEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``ConditionalCLIPEncoderLayer`` modules."""

    def __init__(self, config: ConditionalCLIPVisionConfig) -> None:
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [ConditionalCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        condition: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        """Run every layer in sequence, passing the same ``condition`` to each.

        Args:
            inputs_embeds: embedded tokens fed to the first layer.
            attention_mask: forwarded to every layer.
            causal_attention_mask: forwarded to every layer.
            output_attentions: collect per-layer attention weights; defaults
                to ``config.output_attentions``.
            output_hidden_states: collect the input of every layer plus the
                final output; defaults to ``config.output_hidden_states``.
            condition: conditioning vector forwarded to every layer.
            return_dict: return a ``BaseModelOutput`` instead of a tuple;
                defaults to ``config.use_return_dict``.

        Returns:
            ``BaseModelOutput`` or a tuple of its non-``None`` fields.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # NOTE(review): `_gradient_checkpointing_func` is not defined in
                # this class; presumably the owning pretrained model installs it
                # when checkpointing is enabled - confirm.
                # Everything is passed positionally: reentrant torch
                # checkpointing raises on keyword arguments.
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    condition,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    condition=condition,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )


class ConditionalCLIPVisionTransformer(CLIPVisionTransformer):
    """CLIP vision transformer whose encoder is a ``ConditionalCLIPEncoder``."""

    def __init__(self, config: ConditionalCLIPVisionConfig):
        super().__init__(config)
        self.config = config
        embed_dim = config.hidden_size

        # Sub-modules are (re)built explicitly; the encoder is the conditional one.
        self.embeddings = CLIPVisionEmbeddings(config)
        # Attribute name (including its spelling) is kept as-is: it is part of
        # the module's parameter names.
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = ConditionalCLIPEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        condition: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Embed the image, run the conditional encoder and pool the first token.

        Args:
            pixel_values: input images; required.
            condition: conditioning vector forwarded to the encoder.
            output_attentions: defaults to ``config.output_attentions``.
            output_hidden_states: defaults to ``config.output_hidden_states``.
            return_dict: defaults to ``config.use_return_dict``.

        Returns:
            ``BaseModelOutputWithPooling`` or the equivalent tuple
            ``(last_hidden_state, pooled_output, *extras)``.

        Raises:
            ValueError: if ``pixel_values`` is ``None``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layrnorm(hidden_states)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            condition=condition,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        # Pool by taking the token at index 0, then normalise it.
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class ConditionalCLIPVisionModel(CLIPVisionModel):
    """Stand-alone vision model wrapping ``ConditionalCLIPVisionTransformer``."""

    config_class = ConditionalCLIPVisionConfig

    def __init__(self, config: ConditionalCLIPVisionConfig):
        super().__init__(config)
        self.vision_model = ConditionalCLIPVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        condition: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Delegate to ``self.vision_model`` with ``return_dict`` resolved from the config."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        return self.vision_model(
            pixel_values=pixel_values,
            condition=condition,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class ConditionalCLIPModel(CLIPModel):
    """CLIP text/vision model whose vision tower accepts a conditioning vector."""

    config_class = CLIPConfig

    def __init__(self, config: CLIPConfig):
        """
        Raises:
            ValueError: if ``config.text_config`` is not a ``CLIPTextConfig`` or
                ``config.vision_config`` is not a ``CLIPVisionConfig``.
        """
        super().__init__(config)

        if not isinstance(config.text_config, CLIPTextConfig):
            raise ValueError(
                "config.text_config is expected to be of type CLIPTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, CLIPVisionConfig):
            raise ValueError(
                "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        self.text_model = CLIPTextTransformer(text_config)
        # The vision tower reads `vision_config.modulation_dim` when building
        # its encoder layers.
        self.vision_model = ConditionalCLIPVisionTransformer(vision_config)

        self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

        # Initialize weights and apply final processing
        self.post_init()

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        condition: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """Return the projected pooled output of the conditional vision tower.

        Args:
            pixel_values: input images.
            condition: conditioning vector forwarded to the vision tower.
            output_attentions: defaults to ``config.output_attentions``.
            output_hidden_states: defaults to ``config.output_hidden_states``.
            return_dict: defaults to ``config.use_return_dict``.

        Returns:
            ``self.visual_projection`` applied to the pooled vision output.
        """
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            condition=condition,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = vision_outputs[1]  # pooled_output
        image_features = self.visual_projection(pooled_output)

        return image_features

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        condition: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPOutput]:
        """Compute image/text embeddings and their scaled cosine-similarity logits.

        Args:
            input_ids: token ids for the text tower.
            pixel_values: input images for the vision tower.
            condition: conditioning vector forwarded to the vision tower.
            attention_mask: forwarded to the text tower.
            position_ids: forwarded to the text tower.
            return_loss: when truthy, the contrastive loss is also computed.
            output_attentions: defaults to ``config.output_attentions``.
            output_hidden_states: defaults to ``config.output_hidden_states``.
            return_dict: defaults to ``config.use_return_dict``.

        Returns:
            ``CLIPOutput``, or a tuple ``(logits_per_image, logits_per_text,
            text_embeds, image_embeds, text_outputs, vision_outputs)`` prefixed
            with the loss when one was computed.
        """
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            condition=condition,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[1]
        image_embeds = self.visual_projection(image_embeds)

        text_embeds = text_outputs[1]
        text_embeds = self.text_projection(text_embeds)

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            # Uses the module-local helper: no `clip_loss` name is imported here.
            loss = _clip_contrastive_loss(logits_per_text)

        if not return_dict:
            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
            return ((loss,) + output) if loss is not None else output

        return CLIPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )