# coding=utf-8 # Copyright 2021 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch VLE model.""" from typing import Optional, Tuple, Union import torch from torch import nn from transformers.modeling_utils import PreTrainedModel from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ModelOutput from transformers.models.auto.configuration_auto import AutoConfig from transformers.models.auto.modeling_auto import AutoModel from transformers.models.bert.modeling_bert import BertAttention, BertIntermediate, BertOutput, apply_chunking_to_forward from transformers.models.clip.modeling_clip import CLIPOutput, CLIPVisionConfig, CLIPVisionModel from transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2OnlyMLMHead from .configuration_vle import VLEConfig from dataclasses import dataclass logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "VLEConfig" @dataclass class VLEModelOutput(ModelOutput): pooler_output: torch.FloatTensor = None text_embeds: torch.FloatTensor = None image_embeds: torch.FloatTensor = None @dataclass class VLEForITMOutput(ModelOutput): loss: torch.FloatTensor = None logits: torch.FloatTensor = None @dataclass class VLEForPBCOutput(ModelOutput): loss: torch.FloatTensor = None logits: torch.FloatTensor = None @dataclass class VLEForMLMOutput(ModelOutput): loss: torch.FloatTensor = None logits: torch.FloatTensor = None @dataclass class VLEForVQAOutput(ModelOutput): loss : torch.FloatTensor = None logits: torch.FloatTensor = None class ITMHead(nn.Module): def __init__(self, hidden_size): super().__init__() self.fc = nn.Linear(hidden_size, 2) def forward(self, x): x = self.fc(x) return x def extend_position_embedding(state_dict, patch_size, after): """ modify state_dict in-place for longer position embeddings """ keys = {} for k,v in state_dict.items(): if k.endswith('vision_model.embeddings.position_embedding.weight'): assert k not in keys keys['pe'] = (k,v) if k.endswith('vision_model.embeddings.position_ids'): assert k not in keys keys['pi'] = (k,v) pe_weight = keys['pe'][1] position_length_before = pe_weight.shape[0] embed_dim = pe_weight.shape[1] grid_before = position_length_before - 1 position_length_after = (after // patch_size) ** 2 + 1 grid_after = position_length_after - 1 new_pe_weight = pe_weight[1:].reshape((grid_before,grid_before,-1)) new_pe_weight = torch.nn.functional.interpolate( new_pe_weight.permute(2,0,1).unsqueeze(0), size = (grid_after,grid_after), mode = 'bicubic') new_pe_weight = new_pe_weight.squeeze(0).permute(1,2,0).reshape(grid_after*grid_after, -1) new_pe_weight = torch.cat((pe_weight[0:1],new_pe_weight), dim=0) assert new_pe_weight.shape == (grid_after*grid_after + 1, embed_dim) state_dict[keys['pe'][0]] = new_pe_weight state_dict[keys['pi'][0]] = torch.arange(grid_after*grid_after + 1).unsqueeze(0) return state_dict class Pooler(nn.Module): def __init__(self, hidden_size): super().__init__() self.dense = nn.Linear(hidden_size, hidden_size) self.activation = nn.Tanh() def forward(self, hidden_states): first_token_tensor = hidden_states[:, 0] pooled_output = self.dense(first_token_tensor) pooled_output = self.activation(pooled_output) return pooled_output class BertCrossLayer(nn.Module): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.attention = BertAttention(config) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention self.crossattention = BertAttention(config) self.intermediate = BertIntermediate(config) self.output = BertOutput(config) def forward( self, hidden_states, encoder_hidden_states, attention_mask=None, encoder_attention_mask=None, output_attentions=False, ): # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attn_past_key_value = None #past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask=None, output_attentions=output_attentions, past_key_value=None, ) attention_output = self_attention_outputs[0] # if decoder, the last output is tuple of self-attn cache outputs = self_attention_outputs[1:] # add self attentions if we output attention weights cross_attn_present_key_value = None cross_attention_outputs = self.crossattention( attention_output, attention_mask, None, encoder_hidden_states, encoder_attention_mask, None, output_attentions, ) attention_output = cross_attention_outputs[0] outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs return outputs def feed_forward_chunk(self, attention_output): intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) return layer_output class VLEPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization. """ config_class = VLEConfig base_model_prefix = "vle" supports_gradient_checkpointing = False _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) ''' TODO checkpointing def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, BertEncoder): module.gradient_checkpointing = value ''' class VLEModel(VLEPreTrainedModel): def __init__( self, config: Optional[VLEConfig] = None, vision_model: Optional[PreTrainedModel] = None, text_model: Optional[PreTrainedModel] = None, ): if config is None and (vision_model is None or text_model is None): raise ValueError("Either a configuration or an vision and a text model has to be provided") if config is None: config = VLEConfig(vision_model.config, text_model.config) else: if not isinstance(config, self.config_class): raise ValueError(f"config: {config} has to be of type {self.config_class}") # initialize with config super().__init__(config) if vision_model is None: if isinstance(config.vision_config, CLIPVisionConfig): vision_model = CLIPVisionModel(config.vision_config) else: vision_model = AutoModel.from_config(config.vision_config) if text_model is None: text_model = AutoModel.from_config(config.text_config) self.vision_model = vision_model self.text_model = text_model # make sure that the individual model's config refers to the shared config # so that the updates to the config will be synced self.vision_model.config = self.config.vision_config self.text_model.config = self.config.text_config self.vision_embed_dim = config.vision_config.hidden_size self.text_embed_dim = config.text_config.hidden_size self.coattention_dim = config.hidden_size # add projection layers self.text_projection_layer = nn.Linear(self.text_embed_dim, self.coattention_dim) self.image_projection_layer = nn.Linear(self.vision_embed_dim, self.coattention_dim) #self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value) self.token_type_embeddings = nn.Embedding(config.num_token_types, config.hidden_size) self.cross_modal_image_layers = nn.ModuleList([BertCrossLayer(config) for _ in range(config.num_hidden_layers)]) self.cross_modal_text_layers = nn.ModuleList([BertCrossLayer(config) for _ in range(config.num_hidden_layers)]) self.cross_modal_image_pooler = Pooler(config.hidden_size) self.cross_modal_text_pooler = Pooler(config.hidden_size) # Initialize weights and apply final processing self.token_type_embeddings.apply(self._init_weights) self.cross_modal_image_layers.apply(self._init_weights) self.cross_modal_text_layers.apply(self._init_weights) self.cross_modal_image_pooler.apply(self._init_weights) self.cross_modal_text_pooler.apply(self._init_weights) if hasattr(self,"text_projection_layer"): self.text_projection_layer.apply(self._init_weights) if hasattr(self,"image_projection_layer"): self.image_projection_layer.apply(self._init_weights) def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, patch_ids = None, return_loss: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], VLEModelOutput]: return_dict = return_dict if return_dict is not None else self.config.return_dict vision_outputs = self.vision_model( pixel_values=pixel_values, return_dict=return_dict, ) text_outputs = self.text_model( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, return_dict=return_dict, ) image_embeds = self.vision_model.vision_model.post_layernorm(vision_outputs[0]) # last_hidden_state image_embeds = self.image_projection_layer(image_embeds) text_embeds = text_outputs[0] # last_hidden_state text_embeds = self.text_projection_layer(text_embeds) if patch_ids is not None: raise NotImplementedError #TODO image_masks = torch.ones((image_embeds.size(0), image_embeds.size(1)), dtype=torch.long, device=image_embeds.device) extend_image_masks = self.text_model.get_extended_attention_mask(image_masks, image_masks.size()) image_embeds = image_embeds + self.token_type_embeddings(torch.full_like(image_masks, 1)) # image_token_type_idx=1 TODO use_vcr_token_type_embedding extend_text_masks = self.text_model.get_extended_attention_mask(attention_mask, attention_mask.size()) text_embeds = text_embeds + self.token_type_embeddings(torch.zeros_like(attention_mask)) x, y = text_embeds, image_embeds for text_layer, image_layer in zip(self.cross_modal_text_layers, self.cross_modal_image_layers): x1 = text_layer(x, y, extend_text_masks, extend_image_masks) y1 = image_layer(y, x, extend_image_masks, extend_text_masks) x, y = x1[0], y1[0] text_embeds, image_embeds = x, y text_pooler_output = self.cross_modal_text_pooler(x) image_pooler_output = self.cross_modal_image_pooler(y) pooler_output = torch.cat([text_pooler_output, image_pooler_output], dim=-1) if not return_dict: output = (pooler_output, text_embeds, image_embeds) return output return VLEModelOutput( pooler_output = pooler_output, text_embeds = text_embeds, image_embeds = image_embeds ) @classmethod def from_pretrained(cls, *args, **kwargs): # At the moment fast initialization is not supported # for composite models kwargs["_fast_init"] = False return super().from_pretrained(*args, **kwargs) @classmethod def from_vision_text_pretrained( cls, vision_model_name_or_path: str = None, text_model_name_or_path: str = None, *model_args, **kwargs, ) -> PreTrainedModel: kwargs_vision = { argument[len("vision_") :]: value for argument, value in kwargs.items() if argument.startswith("vision_") } kwargs_text = { argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_") } # remove vision, text kwargs from kwargs for key in kwargs_vision.keys(): del kwargs["vision_" + key] for key in kwargs_text.keys(): del kwargs["text_" + key] # Load and initialize the vision and text model vision_model = kwargs_vision.pop("model", None) if vision_model is None: if vision_model_name_or_path is None: raise ValueError( "If `vision_model` is not defined as an argument, a `vision_model_name_or_path` has to be defined" ) if "config" not in kwargs_vision: vision_config = AutoConfig.from_pretrained(vision_model_name_or_path) if vision_config.model_type == "clip": kwargs_vision["config"] = vision_config.vision_config vision_model = CLIPVisionModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision) else: kwargs_vision["config"] = vision_config vision_model = AutoModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision) text_model = kwargs_text.pop("model", None) if text_model is None: if text_model_name_or_path is None: raise ValueError( "If `text_model` is not defined as an argument, a `text_model_name_or_path` has to be defined" ) if "config" not in kwargs_text: text_config = AutoConfig.from_pretrained(text_model_name_or_path) kwargs_text["config"] = text_config text_model = AutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text) # instantiate config with corresponding kwargs config = VLEConfig(vision_model.config, text_model.config, **kwargs) # init model model = cls(config=config, vision_model=vision_model, text_model=text_model) # the projection layers are always newly initialized when loading the model # using pre-trained vision and text model. logger.warning( "The coattention layers and projection layers are newly initialized. You should probably TRAIN this model on a down-stream task to be" " able to use it for predictions and inference." ) return model def get_text_features( self, input_ids=None, attention_mask=None, position_ids=None, token_type_ids=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): text_outputs = self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, token_type_ids=token_type_ids, #output_attentions=output_attentions, #output_hidden_states=output_hidden_states, return_dict=return_dict, ) return text_outputs[0] # last_hidden_state def get_image_features( self, pixel_values=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): r""" Returns: image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`]. Examples: ```python >>> from PIL import Image >>> import requests >>> from transformers import VLEModel, AutoImageProcessor >>> model = VLEModel.from_pretrained("clip-italian/clip-italian") >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = image_processor(images=image, return_tensors="pt") >>> image_features = model.get_image_features(**inputs) ```""" vision_outputs = self.vision_model( pixel_values=pixel_values, #output_attentions=output_attentions, #output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = self.vision_model.vision_model.post_layernorm(vision_outputs[0]) return last_hidden_state def get_input_embeddings(self): return self.text_model.embeddings.word_embeddings def set_input_embeddings(self, new_embeddings): self.text_model.embeddings.word_embeddings = new_embeddings class VLEForVQA(VLEPreTrainedModel): def __init__( self, config: Optional[VLEConfig] = None, vision_model: Optional[PreTrainedModel] = None, text_model: Optional[PreTrainedModel] = None, ): super().__init__(config) self.vle = VLEModel(config, vision_model, text_model) hidden_size = config.hidden_size self.num_vqa_labels = len(self.config.id2label) self.vqa_classifier = nn.Sequential( nn.Linear(hidden_size * 2, hidden_size * 2), nn.LayerNorm(hidden_size * 2), nn.GELU(), nn.Linear(hidden_size * 2, self.num_vqa_labels), ) self.vqa_classifier.apply(self._init_weights) def forward(self, input_ids: Optional[torch.LongTensor], pixel_values: Optional[torch.FloatTensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, patch_ids = None, vqa_labels = None, vqa_scores = None, return_loss: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], VLEForVQAOutput]: return_dict = return_dict if return_dict is not None else self.config.return_dict vle_output = self.vle( input_ids = input_ids, pixel_values = pixel_values, attention_mask = attention_mask, position_ids = position_ids, token_type_ids = token_type_ids, patch_ids = patch_ids,) pooler_output = vle_output[0] vqa_logits = self.vqa_classifier(pooler_output) vqa_loss = None if return_loss and vqa_labels is not None and vqa_scores is not None: vqa_targets = torch.zeros(len(vqa_logits), self.num_vqa_labels,device=vqa_logits.device) for i, (_label, _score) in enumerate(zip(vqa_labels, vqa_scores)): for l, s in zip(_label, _score): vqa_targets[i, l] = s vqa_loss = F.binary_cross_entropy_with_logits(vqa_logits, vqa_targets) * vqa_targets.shape[1] # https://github.com/jnhwkim/ban-vqa/blob/master/train.py#L19 if not return_dict: output = (vqa_logits,) return ((vqa_loss,) + output) if vqa_loss is not None else output return VLEForVQAOutput( loss = vqa_loss, logits = vqa_logits ) class VLEForITM(VLEPreTrainedModel): def __init__( self, config: Optional[VLEConfig] = None, vision_model: Optional[PreTrainedModel] = None, text_model: Optional[PreTrainedModel] = None, ): super().__init__(config) self.vle = VLEModel(config, vision_model, text_model) hidden_size = config.hidden_size self.itm_score = ITMHead(hidden_size*2) self.itm_score.apply(self._init_weights) def forward(self, input_ids: Optional[torch.LongTensor], pixel_values: Optional[torch.FloatTensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, patch_ids = None, itm_labels = None, return_loss: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], VLEForITMOutput]: return_dict = return_dict if return_dict is not None else self.config.return_dict vle_output = self.vle( input_ids = input_ids, pixel_values = pixel_values, attention_mask = attention_mask, position_ids = position_ids, token_type_ids = token_type_ids, patch_ids = patch_ids,) pooler_output = vle_output[0] itm_logits = self.itm_score(pooler_output) itm_loss = None if return_loss and itm_labels is not None: itm_loss = nn.functional.cross_entropy(itm_logits, torch.tensor(itm_labels).long().to(itm_logits.device)) if not return_dict: output = (itm_logits,) return ((itm_loss,) + output) if itm_loss is not None else output return VLEForITMOutput(loss = itm_loss, logits = itm_logits) class VLEForPBC(VLEPreTrainedModel): def __init__( self, config: Optional[VLEConfig] = None, vision_model: Optional[PreTrainedModel] = None, text_model: Optional[PreTrainedModel] = None, ): super().__init__(config) self.vle = VLEModel(config, vision_model, text_model) hidden_size = config.hidden_size self.pbc_classifier = nn.Sequential( nn.Linear(hidden_size, hidden_size), nn.LayerNorm(hidden_size), nn.GELU(), nn.Linear(hidden_size, 2), ) self.pbc_classifier.apply(self._init_weights) def forward(self, input_ids: Optional[torch.LongTensor], pixel_values: Optional[torch.FloatTensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, patch_ids = None, pbc_labels = None, return_loss: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], VLEForPBCOutput]: return_dict = return_dict if return_dict is not None else self.config.return_dict vle_output = self.vle( input_ids = input_ids, pixel_values = pixel_values, attention_mask = attention_mask, position_ids = position_ids, token_type_ids = token_type_ids, patch_ids = patch_ids,) image_embeds = vle_output['image_embeds'] pbc_logits = self.pbc_classifier(image_embeds[:,1:,:]) pbc_loss = None if return_loss and pbc_labels is not None: pbc_loss = F.cross_entropy(pbc_logits, torch.tensor(pbc_labels).long().to(pbc_logits.device)) if not return_dict: output = (pbc_logits,) return ((pbc_loss,) + output) if pbc_loss is not None else output return VLEForPBCOutput(loss = pbc_loss, logits = pbc_logits) class VLEForMLM(VLEPreTrainedModel): _keys_to_ignore_on_load_missing = [r"mlm_score.1.predictions.decoder.weight",r"mlm_score.1.predictions.decoder.bias"] def __init__( self, config: Optional[VLEConfig] = None, vision_model: Optional[PreTrainedModel] = None, text_model: Optional[PreTrainedModel] = None, ): super().__init__(config) self.vle = VLEModel(config, vision_model, text_model) hidden_size = config.hidden_size mlm_head = DebertaV2OnlyMLMHead(self.config.text_config) mlm_transform = nn.Linear(hidden_size, self.config.text_config.hidden_size) self.mlm_score = nn.Sequential( mlm_transform, mlm_head, ) def forward(self, input_ids: Optional[torch.LongTensor], pixel_values: Optional[torch.FloatTensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, patch_ids = None, mlm_labels = None, return_loss: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], VLEForMLMOutput]: return_dict = return_dict if return_dict is not None else self.config.return_dict vle_output = self.vle( input_ids = input_ids, pixel_values = pixel_values, attention_mask = attention_mask, position_ids = position_ids, token_type_ids = token_type_ids, patch_ids = patch_ids,) text_feats = vle_output.text_embeds mlm_logits = self.mlm_score(text_feats) mlm_loss = None if return_loss and mlm_labels is not None: mlm_loss = F.cross_entropy( mlm_logits.view(-1, self.config.text_config.vocab_size), mlm_labels.view(-1), ignore_index=-100, ) if not return_dict: output = (mlm_logits,) return ((mlm_loss,) + output) if mlm_loss is not None else output return VLEForMLMOutput(loss = mlm_loss, logits = mlm_logits) def get_output_embeddings(self): return self.mlm_score[1].predictions.decoder def set_output_embeddings(self, new_embeddings): self.mlm_score[1].predictions.decoder = new_embeddings