# coding=utf-8 # Copyright 2024 FLMR Authors, The Hugging Face Team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch FLMR model for Knowledge-intensive Visual Question Answering.""" import copy import os import pathlib import string from dataclasses import dataclass from typing import Optional, Tuple, Union import torch import torch.distributed as dist from torch import Tensor, nn from torch.utils.cpp_extension import load from transformers import AutoModel, AutoConfig from transformers.modeling_outputs import BaseModelOutputWithPooling from transformers.modeling_utils import PreTrainedModel from transformers.utils import ( ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ) from transformers.models.bert.modeling_bert import BertModel from transformers.models.clip import CLIPVisionModel from .configuration_flmr import FLMRConfig, FLMRTextConfig, FLMRVisionConfig from .tokenization_flmr import FLMRQueryEncoderTokenizer, FLMRContextEncoderTokenizer from .tokenization_flmr_fast import FLMRQueryEncoderTokenizerFast, FLMRContextEncoderTokenizerFast from .flmr_utils import ( colbert_score, colbert_score_reduce, get_rank, get_world_size, ) logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "FLMRConfig" _CHECKPOINT_FOR_DOC = "LinWeizheDragon/PreFLMR_ViT-L" FLMR_PRETRAINED_MODEL_ARCHIVE_LIST = [ "LinWeizheDragon/PreFLMR_ViT-L", "LinWeizheDragon/FLMR", # See all FLMR models at https://huggingface.co/models?filter=flmr ] ########## # Outputs ########## @dataclass class FLMRContextEncoderOutput(ModelOutput): """ Class for outputs of the `doc()` function of [`FLMRModelForRetrieval`]. Args: pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`): The FLMR encoder outputs the *pooler_output* that corresponds to the embedding of the first token of the context representation. This output can be used to embed questions for nearest neighbors queries with query embeddings. late_interaction_output (`torch.FloatTensor` of shape `(batch_size, context_embedding_length, embeddings_size)`): The FLMR encoder outputs the *late_interaction_output* that corresponds to the question representation. The embeddings of all tokens are included for late interaction retrieval. This output is to be used to embed contexts for late-interaction retrieval with query embeddings. context_mask (`torch.FloatTensor` of shape `(batch_size, context_embedding_length)`): The FLMR encoder outputs the *context_mask* that corresponds to the mask of the context representation. text_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*): Tuple of elements containing the attention weights of the text encoder's layers. Each element is a tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`. text_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*): Tuple of elements containing the hidden states of the text encoder at each layer plus the initial embedding outputs. Each tensor has a shape of `(batch_size, sequence_length, hidden_size)`. vision_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*): Tuple of elements containing the attention weights of the vision encoder's layers. Each element is a tensor of shape `(batch_size, num_heads, vision_sequence_length, vision_sequence_length)`. vision_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*): Tuple of elements containing the hidden states of the vision encoder at each layer plus the initial embedding outputs. Each tensor has a shape of `(batch_size, vision_sequence_length, hidden_size)`. transformer_mapping_network_attentions (`Tuple[torch.FloatTensor]`, *optional*): Tuple of elements containing the attention weights of the transformer mapping network's layers. Each element is a tensor of shape `(batch_size, num_heads, mapping_sequence_length, mapping_sequence_length)`. transformer_mapping_network_hidden_states (`Tuple[torch.FloatTensor]`, *optional*): Tuple of elements containing the hidden states of the transformer mapping network at each layer plus the initial embedding outputs. Each tensor has a shape of `(batch_size, mapping_sequence_length, hidden_size)`. """ pooler_output: torch.FloatTensor late_interaction_output: torch.FloatTensor = None context_mask: torch.FloatTensor = None text_encoder_attentions: Optional[Tuple[Tensor]] = None text_encoder_hidden_states: Optional[Tuple[Tensor]] = None vision_encoder_attentions: Optional[Tuple[Tensor]] = None vision_encoder_hidden_states: Optional[Tuple[Tensor]] = None transformer_mapping_network_attentions: Optional[Tuple[Tensor]] = None transformer_mapping_network_hidden_states: Optional[Tuple[Tensor]] = None @dataclass class FLMRQueryEncoderOutput(ModelOutput): """ Class for outputs of the `query()` function of [`FLMRModelForRetrieval.query()`]. Args: pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`): The FLMR encoder outputs the *pooler_output* that corresponds to the embedding of the first token of the query representation. This output can be used to embed questions for nearest neighbors queries with context embeddings. late_interaction_output (`torch.FloatTensor` of shape `(batch_size, query_embedding_length, embeddings_size)`): The FLMR encoder outputs the *late_interaction_output* that corresponds to the question representation. The embeddings of all tokens are included for late interaction retrieval. This output is to be used to embed questions for late-interaction retrieval with context embeddings. text_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*): Tuple of elements containing the attention weights of the text encoder's layers. Each element is a tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`. text_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*): Tuple of elements containing the hidden states of the text encoder at each layer plus the initial embedding outputs. Each tensor has a shape of `(batch_size, sequence_length, hidden_size)`. vision_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*): Tuple of elements containing the attention weights of the vision encoder's layers. Each element is a tensor of shape `(batch_size, num_heads, vision_sequence_length, vision_sequence_length)`. vision_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*): Tuple of elements containing the hidden states of the vision encoder at each layer plus the initial embedding outputs. Each tensor has a shape of `(batch_size, vision_sequence_length, hidden_size)`. transformer_mapping_network_attentions (`Tuple[torch.FloatTensor]`, *optional*): Tuple of elements containing the attention weights of the transformer mapping network's layers. Each element is a tensor of shape `(batch_size, num_heads, mapping_sequence_length, mapping_sequence_length)`. transformer_mapping_network_hidden_states (`Tuple[torch.FloatTensor]`, *optional*): Tuple of elements containing the hidden states of the transformer mapping network at each layer plus the initial embedding outputs. Each tensor has a shape of `(batch_size, mapping_sequence_length, hidden_size)`. """ pooler_output: torch.FloatTensor late_interaction_output: torch.FloatTensor = None text_encoder_attentions: Optional[Tuple[Tensor]] = None text_encoder_hidden_states: Optional[Tuple[Tensor]] = None vision_encoder_attentions: Optional[Tuple[Tensor]] = None vision_encoder_hidden_states: Optional[Tuple[Tensor]] = None transformer_mapping_network_attentions: Optional[Tuple[Tensor]] = None transformer_mapping_network_hidden_states: Optional[Tuple[Tensor]] = None @dataclass class FLMRModelForRetrievalOutput(ModelOutput): """ Class for outputs of [`FLMRModelForRetrieval.query()`]. Args: loss (`torch.FloatTensor`): contrastive loss of the input queries and positive and negative examples. This output is to be used in model training. scores (`torch.FloatTensor` of shape `(batch_size, num_positive_examples + num_negative_examples)`): The FLMR model outputs the *scores* that corresponds to the late-interaction scores of the input query and context. Each query is associated with `num_positive_examples` positive examples and `num_negative_examples` negative examples, and the scores are the late-interaction scores of the query and these examples. in_batch_negative_loss (`torch.FloatTensor` of shape `(batch_size, query_embedding_length, embeddings_size)`): The FLMR model outputs the *in_batch_negative_loss* which computes contrastive loss that includes in-batch negatives. For each positive example, all other examples in the batch except itself are considered negative examples in computing the contrastive loss. This improves ultimate performance in practice. This output is to be used in model training. query_late_interaction_output (`torch.FloatTensor` of shape `(batch_size, query_embedding_length, embeddings_size)`): The FLMR model outputs the *query_late_interaction_output* that corresponds to the late-interaction representations of the input query. context_late_interaction_output (`torch.FloatTensor` of shape `(batch_size, context_embedding_length, embeddings_size)`): The FLMR model outputs the *context_late_interaction_output* that corresponds to the late-interaction representations of the input context. query_attentions (`Tuple[Tuple[Tensor]]`, *optional*): Tuple of elements containing the attention weights of the query's layers. There are three sub-tuples in this tuple, corresponding to the attentions of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`, with `sequence_length` being the sequence length in the corresponding encoder. query_hidden_states (`Tuple[Tuple[Tensor]]`, *optional*): Tuple of elements containing the hidden states of the query's layers. There are three sub-tuples in this tuple, corresponding to the hidden states of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, sequence_length, hidden_size)`, with `sequence_length` being the sequence length in the corresponding encoder. context_attentions (`Tuple[Tuple[Tensor]]`, *optional*): Tuple of elements containing the attention weights of the context's layers. There are three sub-tuples in this tuple, corresponding to the attentions of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`, with `sequence_length` being the sequence length in the corresponding encoder. context_hidden_states (`Tuple[Tuple[Tensor]]`, *optional*): Tuple of elements containing the hidden states of the context's layers. There are three sub-tuples in this tuple, corresponding to the hidden states of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, sequence_length, hidden_size)`, with `sequence_length` being the sequence length in the corresponding encoder. """ loss: torch.FloatTensor scores: torch.FloatTensor = None in_batch_negative_loss: torch.FloatTensor = None query_late_interaction_output: torch.FloatTensor = None context_late_interaction_output: torch.FloatTensor = None query_attentions: Optional[Tuple[Tuple[Tensor]]] = None query_hidden_states: Optional[Tuple[Tuple[Tensor]]] = None context_attentions: Optional[Tuple[Tuple[Tensor]]] = None context_hidden_states: Optional[Tuple[Tuple[Tensor]]] = None class FLMRPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) ################## # PreTrainedModel ################## class FLMRPretrainedModelForRetrieval(FLMRPreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = FLMRConfig load_tf_weights = None base_model_prefix = "flmr" ############### # Actual Models ############### FLMR_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`FLMRConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. query_tokenizer ([`FLMRQueryEncoderTokenizer`], *optional*): The tokenizer used for tokenizing the query. The query tokenizer can be initialized with `FLMRQueryEncoderTokenizer.from_pretrained(pretrained_model_name_or_path)`. context_tokenizer ([`FLMRContextEncoderTokenizer`], *optional*): The tokenizer used for tokenizing the context. The context tokenizer can be initialized with `FLMRContextEncoderTokenizer.from_pretrained(pretrained_model_name_or_path)`. """ FLMR_MODEL_INPUTS_DOCSTRING = r""" Args: query_input_ids (`torch.LongTensor` of shape `(batch_size, query_length)`): Indices of input query tokens in the vocabulary. To match pretraining, FLMR input sequence should be formatted with [CLS] and Q marker tokens as follows: [CLS] [unused0] using the provided image, obtain documents that address the subsequent question : what is the capital of france? [SEP] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] ... FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than the left. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) query_attention_mask (`torch.FloatTensor` of shape `(batch_size, query_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) query_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*): Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. query_image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*): Image features are required when `query_pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass and the extract image features are directly given to the FLMR model. Image features can be obtained using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details. context_input_ids (`torch.LongTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`): Indices of input context tokens in the vocabulary. To match pretraining, FLMR input sequence should be formatted with [CLS] and D marker tokens as follows: [CLS] [unused1] paris is the capital of france. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] ... FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than the left. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details. context_attention_mask (`torch.FloatTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details. context_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*): Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. context_image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*): Image features are required when `context_pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass and the extract image features are directly given to the FLMR model. Image features can be obtained using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details. use_in_batch_negatives (`bool`, *optional*): Whether or not to use in-batch negatives. If `True`, the contrastive loss includes in-batch negatives. For each positive example, all other examples in the batch except itself are considered negative examples in computing the contrastive loss. This improves ultimate performance in practice. This input is to be used in model training. in_batch_negatives_from_all_gpus (`bool`, *optional*): Whether or not to use in-batch negatives from all GPUs. If `True`, the contrastive loss includes in-batch negatives from all GPUs. This input is to be used in model training. num_negative_examples (`int`, *optional*): The number of negative examples in the batch. For example, if `num_negative_examples` is 4, the batch size of `context_input_ids` and `context_attention_mask` is `batch_size * 5`. query_concat_output_from_vision_encoder (`bool`, *optional*): Whether or not to concatenate the output from the vision encoder to the final query late-interaction representations. If `True`, the output from the vision encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models. query_concat_output_from_text_encoder (`bool`, *optional*): Whether or not to concatenate the output from the text encoder to the final query late-interaction representations. If `True`, the output from the text encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models. This argument can be set to `False` when performing mapping network pretraining as in FLMR and PreFLMR, in which case the output from the text encoder is not concatenated to the final query representations. context_concat_output_from_vision_encoder (`bool`, *optional*): Whether or not to concatenate the output from the vision encoder to the final context late-interaction representations. If `True`, the output from the vision encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `False` for FLMR and PreFLMR -style models since the context vision encoder is not used. This can be set to `True` to additionally encode the context images with the vision encoder when context images are provided. context_concat_output_from_text_encoder (`bool`, *optional*): Whether or not to concatenate the output from the text encoder to the final context late-interaction representations. If `True`, the output from the text encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `*_attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `*_hidden_states` under returned tensors for more detail. """ FLMR_MODEL_QUERY_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, query_length)`): Indices of input query tokens in the vocabulary. To match pretraining, FLMR input sequence should be formatted with [CLS] and Q marker tokens as follows: [CLS] [unused0] using the provided image, obtain documents that address the subsequent question : what is the capital of france? [SEP] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] ... FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than the left. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.FloatTensor` of shape `(batch_size, query_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*): Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*): Image features are required when `pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass and the extract image features are directly given to the FLMR model. Image features can be obtained using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details. concat_output_from_vision_encoder (`bool`, *optional*): Whether or not to concatenate the output from the vision encoder to the final query late-interaction representations. If `True`, the output from the vision encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models. concat_output_from_text_encoder (`bool`, *optional*): Whether or not to concatenate the output from the text encoder to the final query late-interaction representations. If `True`, the output from the text encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models. This argument can be set to `False` when performing mapping network pretraining as in FLMR and PreFLMR, in which case the output from the text encoder is not concatenated to the final query representations. """ FLMR_MODEL_CONTEXT_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`): Indices of input context tokens in the vocabulary. To match pretraining, FLMR input sequence should be formatted with [CLS] and D marker tokens as follows: [CLS] [unused1] paris is the capital of france. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] ... FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than the left. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details. attention_mask (`torch.FloatTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*): Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*): Image features are required when `pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass and the extract image features are directly given to the FLMR model. Image features can be obtained using [`CLIPVisionModel`]. See [`CLIPVisionModel .__call__`] for details. concat_output_from_vision_encoder (`bool`, *optional*): Whether or not to concatenate the output from the vision encoder to the final context late-interaction representations. If `True`, the output from the vision encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `False` for FLMR and PreFLMR -style models since the context vision encoder is not used. This can be set to `True` to additionally encode the context images with the vision encoder when context images are provided. concat_output_from_text_encoder (`bool`, *optional*): Whether or not to concatenate the output from the text encoder to the final context late-interaction representations. If `True`, the output from the text encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models. keep_dims (`bool`, *optional*): Whether or not to keep the dimensions of the output. If `True`, the output is returned with the same dimensions as the input. If `False`, the output is returned with the batch size of the input and the context length. This input is to be used in model training. return_mask (`bool`, *optional*): Whether or not to return the mask of the context representation. If `True`, the mask of the context representation is returned. This input is to be used in model training. """ FLMR_TEXT_ENCODERS_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`FLMRTextConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ # Modified from transformers.models.dpr.modeling_dpr with DPR -> FLMR FLMR_TEXT_ENCODERS_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. To match pretraining, FLMR input sequence should be formatted with [CLS] and [SEP] tokens as follows: (a) For sequence pairs (for a pair title+text for example): ``` tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 ``` (b) For single sequences (for a question for example): ``` tokens: [CLS] the dog is hairy . [SEP] token_type_ids: 0 0 0 0 0 0 0 ``` FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than the left. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`: - 0 corresponds to a *sentence A* token, - 1 corresponds to a *sentence B* token. [What are token type IDs?](../glossary#token-type-ids) inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ FLMR_VISION_ENCODERS_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`FLMRVisionConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ # Modified from transformers.models.clip.modeling_clip with CLIP -> FLMR FLMR_VISION_ENCODERS_INPUTS_DOCSTRING = r""" Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ class FLMRMultiLayerPerceptron(nn.Module): """ A simple multi-layer perceptron with an activation function. This can be used as the mapping network in the FLMR model. """ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.model(x) def __init__(self, sizes, bias=True, act=nn.Tanh): super(FLMRMultiLayerPerceptron, self).__init__() layers = [] for i in range(len(sizes) - 1): layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias)) if i < len(sizes) - 2: layers.append(act()) self.model = nn.Sequential(*layers) @add_start_docstrings( "The bare FLMR model that can be used to generate late-interaction embeddings for both multi-modal queries and documents. ", FLMR_START_DOCSTRING, ) class FLMRModelForRetrieval(FLMRPretrainedModelForRetrieval): _keys_to_ignore_on_load_unexpected = [r"cls"] main_input_name = "query_input_ids" _tied_weights_keys = [] # Added dynamically at initialization depending on the architecture def __init__(self, config: FLMRConfig, query_tokenizer=None, context_tokenizer=None): super().__init__(config) self.config = config self.vision_model_version = config.vision_model_version self.context_text_encoder = FLMRTextModel(config.text_config) self.context_text_encoder_linear = nn.Linear(config.text_config.hidden_size, config.dim, bias=False) self.query_tokenizer = query_tokenizer self.context_tokenizer = context_tokenizer if self.query_tokenizer is None: logger.warning( "query_tokenizer is not provided. A tokenizer is initialized from `bert-base-uncased`. Please pass in an FLMRQueryEncoderTokenizer instance if you need to extend the vocabulary beyond the existing ones in the bert tokenizer." ) # initialize a FLMRQueryEncoderTokenizer self.query_tokenizer = FLMRQueryEncoderTokenizer.from_pretrained("bert-base-uncased") if self.context_tokenizer is None: logger.warning( "context_tokenizer is not provided. A tokenizer is initialized from `bert-base-uncased`. Please pass in an FLMRContextEncoderTokenizer instance if you need to extend the vocabulary beyond the existing ones in the bert tokenizer." ) # initialize a FLMRContextEncoderTokenizer self.context_tokenizer = FLMRContextEncoderTokenizer.from_pretrained("bert-base-uncased") self.mapping_network_prefix_length = self.config.mapping_network_prefix_length self.vision_encoder_embedding_size = self.config.vision_config.hidden_size self.text_encoder_embedding_size = self.config.text_config.hidden_size self.late_interaction_embedding_size = self.config.dim if self.config.use_vision_encoder: self.context_vision_projection = FLMRMultiLayerPerceptron( ( self.vision_encoder_embedding_size, (self.late_interaction_embedding_size * self.mapping_network_prefix_length) // 2, self.late_interaction_embedding_size * self.mapping_network_prefix_length, ) ) if self.config.use_vision_encoder: self.context_vision_encoder = FLMRVisionModel(config.vision_config) if self.config.use_transformer_mapping_network: # This is a PreFLMR style model transformer_mapping_config_base = self.config.transformer_mapping_config_base try: from transformers import BertConfig from transformers.models.bert.modeling_bert import BertEncoder except Exception as e: raise ImportError(f"Failed to import BertConfig and BertEncoder from transformers. {e}") transformer_mapping_config = BertConfig.from_pretrained(transformer_mapping_config_base) assert ( self.config.text_config.hidden_size == transformer_mapping_config.hidden_size ), f"hidden_size {self.config.text_config.hidden_size} != transformer_mapping_config.hidden_size {transformer_mapping_config.hidden_size}. To use cross attention, the dimensions must match." # shallow transformer transformer_mapping_config.num_hidden_layers = self.config.transformer_mapping_num_hidden_layers # add cross attention transformer_mapping_config.is_decoder = True transformer_mapping_config.add_cross_attention = True # The linear layer from vision encoder to transformer input self.transformer_mapping_input_linear = nn.Linear( self.vision_encoder_embedding_size, transformer_mapping_config.hidden_size ) # The transformer encoder self.transformer_mapping_network = BertEncoder(transformer_mapping_config) # The linear layer from transformer output to FLMR dim self.transformer_mapping_output_linear = nn.Linear( transformer_mapping_config.hidden_size, self.late_interaction_embedding_size ) if self.config.separate_query_and_context_text_encoder: self.query_text_encoder = copy.deepcopy(self.context_text_encoder) self.query_text_encoder_linear = copy.deepcopy(self.context_text_encoder_linear) else: self.query_text_encoder = self.context_text_encoder self.query_text_encoder_linear = self.context_text_encoder_linear self._tied_weights_keys += ["context_text_encoder", "context_text_encoder_linear"] if self.config.use_vision_encoder: if self.config.separate_query_and_context_vision_encoder: self.query_vision_encoder = copy.deepcopy(self.context_vision_encoder) self.query_vision_projection = copy.deepcopy(self.context_vision_projection) else: self.query_vision_encoder = self.context_vision_encoder self.query_vision_projection = self.context_vision_projection self._tied_weights_keys += ["context_vision_encoder", "context_vision_projection"] if self.config.load_cpu_extension: try: FLMRModelForRetrieval.try_load_torch_extensions() except Exception as e: raise(f"Unable to load `segmented_maxsim.cpp`. hf-hub does not download this file automatically. Please download it manually from `https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/blob/main/segmented_maxsim.cpp` and put it under the same folder as the model file.\n {e}") if self.config.mask_punctuation: self.skiplist = { w: True for symbol in string.punctuation for w in [symbol, self.context_tokenizer.encode(symbol, add_special_tokens=False)[0]] } if self.config.mask_instruction_token is not None: self.mask_instruction = True # obtain the token id of the instruction token self.instruction_token_id = self.query_tokenizer.encode( self.config.mask_instruction_token, add_special_tokens=False )[0] else: self.mask_instruction = False self.loss_fn = torch.nn.CrossEntropyLoss() # Initialize weights and apply final processing self.post_init() @property def use_gpu(self): return self.device.type == "cuda" @classmethod def from_pretrained(self, name_or_path, **kwargs): obj = super().from_pretrained(name_or_path, **kwargs) return obj @classmethod def try_load_torch_extensions(cls): if hasattr(cls, "loaded_extensions"): return logger.info( "Loading segmented_maxsim_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)..." ) segmented_maxsim_cpp = load( name="segmented_maxsim_cpp", sources=[ os.path.join(pathlib.Path(__file__).parent.resolve(), "segmented_maxsim.cpp"), ], extra_cflags=["-O3"], verbose=os.getenv("COLBERT_LOAD_TORCH_EXTENSION_VERBOSE", "False") == "True", ) cls.segmented_maxsim = segmented_maxsim_cpp.segmented_maxsim_cpp cls.loaded_extensions = True def query_mask(self, input_ids, skiplist): if not self.mask_instruction: return self.mask(input_ids, skiplist) # find the position of end of instruction in input_ids # mask the tokens before the position sep_id = self.instruction_token_id sep_positions = torch.argmax((input_ids == sep_id).int(), dim=1).tolist() # if any of the positions is lower than 1, set to 1 for i, x in enumerate(sep_positions): if x < 1: sep_positions[i] = 1 logger.error(f"can not find the separator in the input_ids: {input_ids[i].tolist()}") mask = [ [ (x not in skiplist) and (x != 0) and (index > sep_positions[seq_index] or index < 2) for index, x in enumerate(d) ] for seq_index, d in enumerate(input_ids.cpu().tolist()) ] return mask @add_start_docstrings_to_model_forward(FLMR_MODEL_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=FLMRModelForRetrievalOutput, config_class=_CONFIG_FOR_DOC) def forward( self, query_input_ids: Optional[torch.Tensor] = None, query_attention_mask: Optional[torch.Tensor] = None, query_pixel_values: Optional[torch.Tensor] = None, query_image_features: Optional[torch.Tensor] = None, context_input_ids: Optional[torch.Tensor] = None, context_attention_mask: Optional[torch.Tensor] = None, context_pixel_values: Optional[torch.Tensor] = None, context_image_features: Optional[torch.Tensor] = None, use_in_batch_negatives: bool = True, in_batch_negatives_from_all_gpus: bool = False, num_negative_examples: int = 1, query_concat_output_from_vision_encoder: Optional[Union[bool, list]] = None, query_concat_output_from_text_encoder: Optional[Union[bool, list]] = None, context_concat_output_from_vision_encoder: Optional[Union[bool, list]] = None, context_concat_output_from_text_encoder: Optional[Union[bool, list]] = None, return_dict: bool = None, output_attentions: bool = None, output_hidden_states: bool = None, ) -> Union[FLMRModelForRetrievalOutput, Tuple[Tensor, ...]]: r""" Return: Examples: ```python >>> import torch >>> from transformers import FLMRQueryEncoderTokenizer, FLMRContextEncoderTokenizer, FLMRModelForRetrieval, AutoImageProcessor >>> checkpoint_path = "LinWeizheDragon/PreFLMR_ViT-L" >>> image_processor_name = "openai/clip-vit-large-patch14" >>> query_tokenizer = FLMRQueryEncoderTokenizer.from_pretrained(checkpoint_path, subfolder="query_tokenizer") >>> context_tokenizer = FLMRContextEncoderTokenizer.from_pretrained(checkpoint_path, subfolder="context_tokenizer") >>> model = FLMRModelForRetrieval.from_pretrained(checkpoint_path, query_tokenizer=query_tokenizer, context_tokenizer=context_tokenizer, ) >>> image_processor = AutoImageProcessor.from_pretrained(image_processor_name) >>> Q_encoding = query_tokenizer(["Using the provided image, obtain documents that address the subsequent question: What is the capital of France?", "Extract documents linked to the question provided in conjunction with the image: What is the capital of China?"]) >>> D_encoding = context_tokenizer(["Paris is the capital of France.", "Beijing is the capital of China.", "Paris is the capital of France.", "Beijing is the capital of China."]) >>> Q_pixel_values = torch.zeros(2, 3, 224, 224) >>> inputs = dict( query_input_ids=Q_encoding['input_ids'], query_attention_mask=Q_encoding['attention_mask'], query_pixel_values=Q_pixel_values, context_input_ids=D_encoding['input_ids'], context_attention_mask=D_encoding['attention_mask'], use_in_batch_negatives=True, ) >>> model.forward(**inputs) FLMRModelForRetrievalOutput(loss=tensor(4.5000, device='cuda:0', dtype=torch.float16, grad_fn=), scores=tensor([[44.2188, 40.6562], [39.4375, 48.4062]], device='cuda:0', dtype=torch.float16, grad_fn=), in_batch_negative_loss=tensor(5.1994, device='cuda:0', grad_fn=), query_late_interaction_output=tensor(...), context_late_interaction_output=tensor(...) ``` """ if query_concat_output_from_vision_encoder is None: query_concat_output_from_vision_encoder = self.config.query_concat_output_from_vision_encoder if query_concat_output_from_text_encoder is None: query_concat_output_from_text_encoder = self.config.query_concat_output_from_text_encoder if context_concat_output_from_vision_encoder is None: context_concat_output_from_vision_encoder = self.config.context_concat_output_from_vision_encoder if context_concat_output_from_text_encoder is None: context_concat_output_from_text_encoder = self.config.context_concat_output_from_text_encoder output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict query_outputs = self.query( input_ids=query_input_ids, attention_mask=query_attention_mask, pixel_values=query_pixel_values, image_features=query_image_features, concat_output_from_vision_encoder=query_concat_output_from_vision_encoder, concat_output_from_text_encoder=query_concat_output_from_text_encoder, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) Q = query_outputs.late_interaction_output context_outputs = self.doc( input_ids=context_input_ids, attention_mask=context_attention_mask, pixel_values=context_pixel_values, image_features=context_image_features, concat_output_from_vision_encoder=context_concat_output_from_vision_encoder, concat_output_from_text_encoder=context_concat_output_from_text_encoder, keep_dims=True, return_mask=True, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) D, D_mask = context_outputs.late_interaction_output, context_outputs.context_mask # Gather tensors from other GPUs if in_batch_negatives_from_all_gpus: Q, D, D_mask = self.gather_tensors_from_other_gpus(Q, D, D_mask) # Repeat each query encoding for every corresponding document. Q_duplicated = Q.repeat_interleave(num_negative_examples + 1, dim=0).contiguous() scores = self.score(Q_duplicated, D, D_mask) # Use contrastive learning batch_size = query_input_ids.shape[0] scores = scores.view(-1, num_negative_examples + 1) labels = torch.zeros(batch_size, dtype=torch.long, device=self.device) if use_in_batch_negatives: ib_loss = self.compute_ib_loss_new(Q, D, D_mask) loss = ib_loss else: loss = self.loss_fn(scores, labels) ib_loss = None if output_attentions: query_attentions = ( query_outputs.text_encoder_attentions if query_outputs.text_encoder_attentions is not None else None, query_outputs.vision_encoder_attentions if query_outputs.vision_encoder_attentions is not None else None, query_outputs.transformer_mapping_network_attentions if query_outputs.transformer_mapping_network_attentions is not None else None, ) context_attentions = ( context_outputs.text_encoder_attentions if context_outputs.text_encoder_attentions is not None else None, context_outputs.vision_encoder_attentions if context_outputs.vision_encoder_attentions is not None else None, context_outputs.transformer_mapping_network_attentions if context_outputs.transformer_mapping_network_attentions is not None else None, ) else: query_attentions = None context_attentions = None if output_hidden_states: query_hidden_states = ( query_outputs.text_encoder_hidden_states if query_outputs.text_encoder_hidden_states is not None else None, query_outputs.vision_encoder_hidden_states if query_outputs.vision_encoder_hidden_states is not None else None, query_outputs.transformer_mapping_network_hidden_states if query_outputs.transformer_mapping_network_hidden_states is not None else None, ) context_hidden_states = ( context_outputs.text_encoder_hidden_states if context_outputs.text_encoder_hidden_states is not None else None, context_outputs.vision_encoder_hidden_states if context_outputs.vision_encoder_hidden_states is not None else None, context_outputs.transformer_mapping_network_hidden_states if context_outputs.transformer_mapping_network_hidden_states is not None else None, ) else: query_hidden_states = None context_hidden_states = None if not return_dict: if output_attentions and output_hidden_states: return ( loss, scores, ib_loss, query_outputs.late_interaction_output, context_outputs.late_interaction_output, query_attentions, query_hidden_states, context_attentions, context_hidden_states, ) elif output_attentions: return ( loss, scores, ib_loss, query_outputs.late_interaction_output, context_outputs.late_interaction_output, query_attentions, context_attentions, ) elif output_hidden_states: return ( loss, scores, ib_loss, query_outputs.late_interaction_output, context_outputs.late_interaction_output, query_hidden_states, context_hidden_states, ) else: return ( loss, scores, ib_loss, query_outputs.late_interaction_output, context_outputs.late_interaction_output, ) return FLMRModelForRetrievalOutput( loss=loss, scores=scores, in_batch_negative_loss=ib_loss, query_late_interaction_output=query_outputs.late_interaction_output, context_late_interaction_output=context_outputs.late_interaction_output, query_attentions=query_attentions if output_attentions else None, query_hidden_states=query_hidden_states if output_hidden_states else None, context_attentions=context_attentions if output_attentions else None, context_hidden_states=context_hidden_states if output_hidden_states else None, ) def compute_ib_loss_new(self, Q: torch.Tensor, D: torch.Tensor, D_mask: torch.Tensor) -> torch.Tensor: # Q: batch_size x q_len x dim # D: batch_size*n_docs x i_len x dim # D_mask: batch_size*n_docs x i_len x dim # 1 x batch_size*n_docs x i_len x dim matmul batch_size x 1 x q_len x dim # = batch_size x batch_size*n_docs x i_len x q_len scores = (D.float().unsqueeze(0) @ Q.float().permute(0, 2, 1).unsqueeze(1)).flatten( 0, 1 ) # query-major unsqueeze scores = colbert_score_reduce(scores, D_mask.repeat(Q.size(0), 1, 1)) in_batch_scores = scores.reshape(Q.size(0), -1) batch_size = Q.shape[0] batch_size_with_pos_and_neg = D.shape[0] num_pos_and_neg = batch_size_with_pos_and_neg // batch_size # batch_size x dim matmul dim x (num_pos+num_neg)*batch_size # --> batch_size x (num_pos+num_neg)*batch_size in_batch_labels = torch.zeros(batch_size, batch_size_with_pos_and_neg).to(scores.device) step = num_pos_and_neg for i in range(batch_size): in_batch_labels[i, step * i] = 1 # print('in_batch_labels', in_batch_labels) in_batch_labels = torch.argmax(in_batch_labels, dim=1) # print('in_batch_labels', in_batch_labels) loss = self.loss_fn(in_batch_scores, in_batch_labels) return loss def gather_tensors_from_other_gpus(self, query_embeddings, item_embeddings, item_mask): # print("get rank", get_rank()) # print("get world size", get_world_size()) # Gather embeddings from other GPUs n_nodes = get_world_size() if n_nodes == 1: return query_embeddings, item_embeddings, item_mask # Create placeholder to hold embeddings passed from other ranks global_query_embeddings_placeholder = [ torch.zeros(*query_embeddings.shape, dtype=query_embeddings.dtype).to(query_embeddings.device) for _ in range(n_nodes) ] global_item_embeddings_placeholder = [ torch.zeros(*item_embeddings.shape, dtype=item_embeddings.dtype).to(item_embeddings.device) for _ in range(n_nodes) ] global_item_mask_placeholder = [ torch.zeros(*item_mask.shape, dtype=item_mask.dtype).to(item_mask.device) for _ in range(n_nodes) ] dist.all_gather(global_query_embeddings_placeholder, query_embeddings.detach()) dist.all_gather(global_item_embeddings_placeholder, item_embeddings.detach()) dist.all_gather(global_item_mask_placeholder, item_mask.detach()) global_query_embeddings = [] global_item_embeddings = [] global_item_mask = [] # print(f"rank {get_rank()} global_query_embeddings", global_query_embeddings) # print(f"rank {get_rank()} global_item_embeddings", global_item_embeddings) # input() current_rank = get_rank() for rank_index, remote_q_embeddings in enumerate(global_query_embeddings_placeholder): # We append the embeddings from other GPUs if this embedding does not require gradients if rank_index != current_rank: global_query_embeddings.append(remote_q_embeddings) else: global_query_embeddings.append(query_embeddings) for rank_index, remote_item_embeddings in enumerate(global_item_embeddings_placeholder): # We append the embeddings from other GPUs if this embedding does not require gradients if rank_index != current_rank: global_item_embeddings.append(remote_item_embeddings) else: global_item_embeddings.append(item_embeddings) for rank_index, remote_item_mask in enumerate(global_item_mask_placeholder): # We append the embeddings from other GPUs if this embedding does not require gradients if rank_index != current_rank: global_item_mask.append(remote_item_mask) else: global_item_mask.append(item_mask) # Replace the previous variables with gathered tensors query_embeddings = torch.cat(global_query_embeddings) item_embeddings = torch.cat(global_item_embeddings) item_mask = torch.cat(global_item_mask) return query_embeddings, item_embeddings, item_mask @add_start_docstrings_to_model_forward(FLMR_MODEL_QUERY_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=FLMRQueryEncoderOutput, config_class=_CONFIG_FOR_DOC) def query( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, pixel_values: Optional[torch.Tensor] = None, image_features: Optional[torch.Tensor] = None, concat_output_from_vision_encoder: Optional[Union[bool, list]] = None, concat_output_from_text_encoder: Optional[Union[bool, list]] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, ): r""" Returns: """ if concat_output_from_vision_encoder is None: concat_output_from_vision_encoder = self.config.query_concat_output_from_vision_encoder if concat_output_from_text_encoder is None: concat_output_from_text_encoder = self.config.query_concat_output_from_text_encoder output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) input_modality = [] if pixel_values is not None or image_features is not None: input_modality.append("image") if input_ids is not None and attention_mask is not None: input_modality.append("text") text_encoder_outputs = None vision_encoder_outputs = None transformer_mapping_outputs = None if "image" in input_modality: assert ( pixel_values is not None or image_features is not None ), "pixel_values or image_features must be provided if image modality is used" assert ( pixel_values is None or image_features is None ), "pixel_values and image_features cannot be provided at the same time" if "text" in input_modality: assert ( input_ids is not None and attention_mask is not None ), "input_ids and attention_mask must be provided if text modality is used" # Forward the text encoder input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device) text_encoder_outputs = self.query_text_encoder(input_ids, attention_mask=attention_mask) text_encoder_hidden_states = text_encoder_outputs[0] text_embeddings = self.query_text_encoder_linear(text_encoder_hidden_states) mask = torch.tensor(self.query_mask(input_ids, skiplist=self.config.query_mask_input_ids_skip_list), device=self.device).unsqueeze(2).float() text_embeddings = text_embeddings * mask if "image" in input_modality: if pixel_values is not None: batch_size = pixel_values.shape[0] # Forward the vision encoder pixel_values = pixel_values.to(self.device) if len(pixel_values.shape) == 5: # Multiple ROIs are provided # merge the first two dimensions pixel_values = pixel_values.reshape( -1, pixel_values.shape[2], pixel_values.shape[3], pixel_values.shape[4] ) vision_encoder_outputs = self.query_vision_encoder(pixel_values, output_hidden_states=True) vision_embeddings = vision_encoder_outputs.last_hidden_state[:, 0] if image_features is not None: batch_size = image_features.shape[0] vision_embeddings = image_features.to(self.device) # Forward the vision projection / mapping network vision_embeddings = self.query_vision_projection(vision_embeddings) vision_embeddings = vision_embeddings.view(batch_size, -1, self.late_interaction_embedding_size) if self.config.use_transformer_mapping_network: # select the second last layer vision_second_last_layer_hidden_states = vision_encoder_outputs.hidden_states[-2][:, 1:] # transformer_mapping transformer_mapping_input_features = self.transformer_mapping_input_linear( vision_second_last_layer_hidden_states ) # Cross attention only attends to the first 32 tokens encoder_mask = torch.ones_like(mask).to(mask.device, dtype=mask.dtype) if len(self.config.query_mask_input_ids_skip_list) > 0: encoder_mask[torch.isin(input_ids, torch.tensor(self.config.query_mask_input_ids_skip_list))] = 0 cross_attention_length = self.config.transformer_mapping_cross_attention_length if text_encoder_hidden_states.shape[1] > cross_attention_length: text_encoder_hidden_states = text_encoder_hidden_states[:, :cross_attention_length] encoder_mask = encoder_mask[:, :cross_attention_length] # Obtain cross attention mask encoder_extended_attention_mask = self.invert_attention_mask(encoder_mask.squeeze(-1)) # Pass through the transformer mapping transformer_mapping_outputs = self.transformer_mapping_network( transformer_mapping_input_features, encoder_hidden_states=text_encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, ) transformer_mapping_output_features = transformer_mapping_outputs.last_hidden_state # Convert the dimension to FLMR dim transformer_mapping_output_features = self.transformer_mapping_output_linear( transformer_mapping_output_features ) # Merge with the vision embeddings vision_embeddings = torch.cat([vision_embeddings, transformer_mapping_output_features], dim=1) if concat_output_from_vision_encoder and concat_output_from_text_encoder: Q = torch.cat([text_embeddings, vision_embeddings], dim=1) if isinstance(concat_output_from_vision_encoder, list) or isinstance(concat_output_from_text_encoder, list): # When lists are passed in, mask the output accordingly assert isinstance(concat_output_from_vision_encoder, list) and isinstance(concat_output_from_text_encoder, list), "concat_output_from_vision_encoder and concat_output_from_text_encoder must be of the same type." # obtain the size of each output text_size = text_embeddings.shape[1] vision_size = vision_embeddings.shape[1] # Prepare the mask concat_output_mask = torch.zeros_like(Q).to(Q.device) # Mask the late interaction outputs concat_output_mask[:, :text_size] = torch.tensor(concat_output_from_text_encoder).bool().unsqueeze(-1).unsqueeze(-1) concat_output_mask[:, text_size:] = torch.tensor(concat_output_from_vision_encoder).bool().unsqueeze(-1).unsqueeze(-1) Q = Q * concat_output_mask elif concat_output_from_vision_encoder: Q = vision_embeddings elif concat_output_from_text_encoder: Q = text_embeddings vision_encoder_attentions = ( vision_encoder_outputs.attentions if vision_encoder_outputs is not None and hasattr(vision_encoder_outputs, "attentions") and output_attentions else None ) vision_encoder_hidden_states = ( vision_encoder_outputs.hidden_states if vision_encoder_outputs is not None and hasattr(vision_encoder_outputs, "hidden_states") and output_hidden_states else None ) text_encoder_attentions = ( text_encoder_outputs.attentions if text_encoder_outputs is not None and hasattr(text_encoder_outputs, "attentions") and output_attentions else None ) text_encoder_hidden_states = ( text_encoder_outputs.hidden_states if text_encoder_outputs is not None and hasattr(text_encoder_outputs, "hidden_states") and output_hidden_states else None ) transformer_mapping_network_attentions = ( transformer_mapping_outputs.attentions if transformer_mapping_outputs is not None and hasattr(transformer_mapping_outputs, "attentions") and output_attentions else None ) transformer_mapping_network_hidden_states = ( transformer_mapping_outputs.hidden_states if transformer_mapping_outputs is not None and hasattr(transformer_mapping_outputs, "hidden_states") and output_hidden_states else None ) return FLMRQueryEncoderOutput( pooler_output=Q[:, 0, :], late_interaction_output=torch.nn.functional.normalize(Q, p=2, dim=2), vision_encoder_attentions=vision_encoder_attentions, vision_encoder_hidden_states=vision_encoder_hidden_states, text_encoder_attentions=text_encoder_attentions, text_encoder_hidden_states=text_encoder_hidden_states, transformer_mapping_network_attentions=transformer_mapping_network_attentions, transformer_mapping_network_hidden_states=transformer_mapping_network_hidden_states, ) @add_start_docstrings_to_model_forward(FLMR_MODEL_CONTEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=FLMRContextEncoderOutput, config_class=_CONFIG_FOR_DOC) def doc( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, pixel_values: Optional[torch.Tensor] = None, image_features: Optional[torch.Tensor] = None, concat_output_from_vision_encoder: Optional[bool] = None, concat_output_from_text_encoder: Optional[bool] = None, keep_dims: Optional[bool] = True, return_mask: Optional[bool] = True, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, ): r""" Returns: """ assert keep_dims in [True, False] if concat_output_from_vision_encoder is None: concat_output_from_vision_encoder = self.config.context_concat_output_from_vision_encoder if concat_output_from_text_encoder is None: concat_output_from_text_encoder = self.config.context_concat_output_from_text_encoder output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) input_modality = [] if pixel_values is not None or image_features is not None: input_modality.append("image") if input_ids is not None and attention_mask is not None: input_modality.append("text") text_encoder_outputs = None vision_encoder_outputs = None if "image" in input_modality: assert ( pixel_values is not None or image_features is not None ), "pixel_values or image_features must be provided if image modality is used" assert ( pixel_values is None or image_features is None ), "pixel_values and image_features cannot be provided at the same time" if "text" in input_modality: assert ( input_ids is not None and attention_mask is not None ), "input_ids and attention_mask must be provided if text modality is used" # Forward the text encoder input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device) text_encoder_outputs = self.context_text_encoder(input_ids, attention_mask=attention_mask) text_embeddings = text_encoder_outputs[0] text_embeddings = self.context_text_encoder_linear(text_embeddings) mask = torch.tensor(self.mask(input_ids, skiplist=self.skiplist), device=self.device).unsqueeze(2).float() text_embeddings = text_embeddings * mask if "image" in input_modality: if pixel_values is not None: # Forward the vision encoder pixel_values = pixel_values.to(self.device) vision_encoder_outputs = self.context_vision_encoder(pixel_values) vision_embeddings = vision_encoder_outputs.last_hidden_state[:, 0] if image_features is not None: vision_embeddings = image_features.to(self.device) batch_size = vision_embeddings.shape[0] # Forward the vision projection / mapping network vision_embeddings = self.context_vision_projection(vision_embeddings) vision_embeddings = vision_embeddings.view( -1, self.mapping_network_prefix_length, self.late_interaction_embedding_size ) image_mask = torch.ones(batch_size, vision_embeddings.shape[1], 1).to(self.device) if concat_output_from_vision_encoder and concat_output_from_text_encoder: # Note: vision embeddings must be in the front since the ColBERT engine only indexes embeddings up to number of 1's in the mask # TODO: fix the engine to support masks with discontinuous 0 and 1. D = torch.cat([vision_embeddings, text_embeddings], dim=1) # concatenate the mask mask = torch.cat([image_mask, mask], dim=1) elif concat_output_from_vision_encoder: D = vision_embeddings mask = image_mask elif concat_output_from_text_encoder: D = text_embeddings mask = mask D = torch.nn.functional.normalize(D, p=2, dim=2) if self.use_gpu: D = D.half() if keep_dims is False: D, mask = D.cpu(), mask.bool().cpu().squeeze(-1) D = [d[mask[idx]] for idx, d in enumerate(D)] vision_encoder_attentions = ( vision_encoder_outputs.attentions if vision_encoder_outputs is not None and hasattr(vision_encoder_outputs, "attentions") and output_attentions else None ) vision_encoder_hidden_states = ( vision_encoder_outputs.hidden_states if vision_encoder_outputs is not None and hasattr(vision_encoder_outputs, "hidden_states") and output_hidden_states else None ) text_encoder_attentions = ( text_encoder_outputs.attentions if text_encoder_outputs is not None and hasattr(text_encoder_outputs, "attentions") and output_attentions else None ) text_encoder_hidden_states = ( text_encoder_outputs.hidden_states if text_encoder_outputs is not None and hasattr(text_encoder_outputs, "hidden_states") and output_hidden_states else None ) return FLMRContextEncoderOutput( pooler_output=D[:, 0, :], late_interaction_output=D, context_mask=mask.bool() if return_mask else None, vision_encoder_attentions=vision_encoder_attentions, vision_encoder_hidden_states=vision_encoder_hidden_states, text_encoder_attentions=text_encoder_attentions, text_encoder_hidden_states=text_encoder_hidden_states, ) def score(self, Q, D_padded, D_mask): # assert self.colbert_config.similarity == 'cosine' # if self.colbert_config.similarity == 'l2': # assert self.colbert_config.interaction == 'colbert' # return (-1.0 * ((Q.unsqueeze(2) - D_padded.unsqueeze(1))**2).sum(-1)).max(-1).values.sum(-1) return colbert_score(Q, D_padded, D_mask, use_gpu=self.use_gpu) def mask(self, input_ids, skiplist): mask = [[(x not in skiplist) and (x != 0) for x in d] for d in input_ids.cpu().tolist()] return mask @add_start_docstrings( "The bare FLMR text encoder that can be used to generate late-interaction embeddings for texts in queries and contexts. This model is based on a `BertModel`. It can be used like a `BertModel` model for encoding text.", FLMR_TEXT_ENCODERS_START_DOCSTRING, ) class FLMRTextModel(FLMRPreTrainedModel): base_model_prefix = "flmr_text_model" config_class = FLMRTextConfig def __init__(self, config: FLMRTextConfig, *args, **kwargs): super().__init__(config) if config.text_encoder_base_model == "bert-base-uncased": self.bert_model = BertModel(config, add_pooling_layer=True) else: self.bert_model = AutoModel.from_pretrained(config.text_encoder_base_model, *args, **kwargs) if self.bert_model.config.hidden_size <= 0: raise ValueError("Encoder hidden_size can't be zero") self.projection_dim = config.projection_dim if self.projection_dim > 0: self.encode_proj = nn.Linear(self.bert_model.config.hidden_size, config.projection_dim) # Initialize weights and apply final processing self.post_init() self.text_model = self.bert_model @add_start_docstrings_to_model_forward(FLMR_TEXT_ENCODERS_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=FLMRTextConfig) def forward( self, input_ids: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None, token_type_ids: Optional[Tensor] = None, inputs_embeds: Optional[Tensor] = None, output_attentions: bool = None, output_hidden_states: bool = None, return_dict: bool = None, ) -> Union[BaseModelOutputWithPooling, Tuple[Tensor, ...]]: r""" Returns: """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.text_model( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = outputs[0] pooled_output = sequence_output[:, 0, :] if self.projection_dim > 0: pooled_output = self.encode_proj(pooled_output) if not return_dict: return (sequence_output, pooled_output) + outputs[2:] return BaseModelOutputWithPooling( last_hidden_state=sequence_output, pooler_output=pooled_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) @property def embeddings_size(self) -> int: if self.projection_dim > 0: return self.encode_proj.out_features return self.text_model.config.hidden_size @add_start_docstrings( "The bare FLMR vision encoder that can be used to generate late-interaction embeddings for images in queries and contexts. This model is based on a `CLIPVisionModel`. It can be used like a `CLIPVisionModel` model for encoding images.", FLMR_VISION_ENCODERS_START_DOCSTRING, ) class FLMRVisionModel(FLMRPreTrainedModel): base_model_prefix = "vision_model" config_class = FLMRVisionConfig main_input_name = "pixel_values" _no_split_modules = ["CLIPEncoderLayer"] def __init__(self, config: FLMRVisionConfig): super().__init__(config) self.vision_model = CLIPVisionModel(config) self.post_init() def get_input_embeddings(self) -> nn.Module: return self.vision_model.vision_model.embeddings.patch_embedding @add_start_docstrings_to_model_forward(FLMR_VISION_ENCODERS_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=FLMRVisionConfig) def forward( self, pixel_values: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: r""" Returns: Examples: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, FLMRVisionModel >>> model = FLMRVisionModel.from_pretrained("openai/clip-vit-base-patch32") >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor(images=image, return_tensors="pt") >>> outputs = model(**inputs) >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled CLS states ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict return self.vision_model( pixel_values=pixel_values, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, )