|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" PyTorch FLMR model for Knowledge-intensive Visual Question Answering.""" |
|
|
|
|
|
import copy |
|
import os |
|
import pathlib |
|
import string |
|
from dataclasses import dataclass |
|
from typing import Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.distributed as dist |
|
from torch import Tensor, nn |
|
from torch.utils.cpp_extension import load |
|
|
|
from transformers import AutoModel, AutoConfig |
|
from transformers.modeling_outputs import BaseModelOutputWithPooling |
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.utils import ( |
|
ModelOutput, |
|
add_start_docstrings, |
|
add_start_docstrings_to_model_forward, |
|
logging, |
|
replace_return_docstrings, |
|
) |
|
from transformers.models.bert.modeling_bert import BertModel |
|
from transformers.models.clip import CLIPVisionModel |
|
from .configuration_flmr import FLMRConfig, FLMRTextConfig, FLMRVisionConfig |
|
from .tokenization_flmr import FLMRQueryEncoderTokenizer, FLMRContextEncoderTokenizer |
|
from .tokenization_flmr_fast import FLMRQueryEncoderTokenizerFast, FLMRContextEncoderTokenizerFast |
|
from .flmr_utils import ( |
|
colbert_score, |
|
colbert_score_reduce, |
|
get_rank, |
|
get_world_size, |
|
) |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
_CONFIG_FOR_DOC = "FLMRConfig" |
|
_CHECKPOINT_FOR_DOC = "LinWeizheDragon/PreFLMR_ViT-L" |
|
|
|
|
|
FLMR_PRETRAINED_MODEL_ARCHIVE_LIST = [ |
|
"LinWeizheDragon/PreFLMR_ViT-L", |
|
"LinWeizheDragon/FLMR", |
|
|
|
] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
class FLMRContextEncoderOutput(ModelOutput): |
|
""" |
|
Class for outputs of the `doc()` function of [`FLMRModelForRetrieval`]. |
|
|
|
Args: |
|
pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`): |
|
The FLMR encoder outputs the *pooler_output* that corresponds to the embedding of the first token of the context representation. |
|
This output can be used to embed questions for nearest neighbors queries with query embeddings. |
|
late_interaction_output (`torch.FloatTensor` of shape `(batch_size, context_embedding_length, embeddings_size)`): |
|
The FLMR encoder outputs the *late_interaction_output* that corresponds to the question representation. The embeddings of all tokens are included for late interaction retrieval. |
|
This output is to be used to embed contexts for late-interaction retrieval with query embeddings. |
|
context_mask (`torch.FloatTensor` of shape `(batch_size, context_embedding_length)`): |
|
The FLMR encoder outputs the *context_mask* that corresponds to the mask of the context representation. |
|
text_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*): |
|
Tuple of elements containing the attention weights of the text encoder's layers. Each element is a |
|
tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`. |
|
text_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*): |
|
Tuple of elements containing the hidden states of the text encoder at each layer plus the initial embedding |
|
outputs. Each tensor has a shape of `(batch_size, sequence_length, hidden_size)`. |
|
vision_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*): |
|
Tuple of elements containing the attention weights of the vision encoder's layers. Each element is a |
|
tensor of shape `(batch_size, num_heads, vision_sequence_length, vision_sequence_length)`. |
|
vision_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*): |
|
Tuple of elements containing the hidden states of the vision encoder at each layer plus the initial embedding |
|
outputs. Each tensor has a shape of `(batch_size, vision_sequence_length, hidden_size)`. |
|
transformer_mapping_network_attentions (`Tuple[torch.FloatTensor]`, *optional*): |
|
Tuple of elements containing the attention weights of the transformer mapping network's layers. Each element |
|
is a tensor of shape `(batch_size, num_heads, mapping_sequence_length, mapping_sequence_length)`. |
|
transformer_mapping_network_hidden_states (`Tuple[torch.FloatTensor]`, *optional*): |
|
Tuple of elements containing the hidden states of the transformer mapping network at each layer plus the |
|
initial embedding outputs. Each tensor has a shape of `(batch_size, mapping_sequence_length, hidden_size)`. |
|
""" |
|
|
|
pooler_output: torch.FloatTensor |
|
late_interaction_output: torch.FloatTensor = None |
|
context_mask: torch.FloatTensor = None |
|
text_encoder_attentions: Optional[Tuple[Tensor]] = None |
|
text_encoder_hidden_states: Optional[Tuple[Tensor]] = None |
|
vision_encoder_attentions: Optional[Tuple[Tensor]] = None |
|
vision_encoder_hidden_states: Optional[Tuple[Tensor]] = None |
|
transformer_mapping_network_attentions: Optional[Tuple[Tensor]] = None |
|
transformer_mapping_network_hidden_states: Optional[Tuple[Tensor]] = None |
|
|
|
|
|
@dataclass |
|
class FLMRQueryEncoderOutput(ModelOutput): |
|
""" |
|
Class for outputs of the `query()` function of [`FLMRModelForRetrieval.query()`]. |
|
|
|
Args: |
|
pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`): |
|
The FLMR encoder outputs the *pooler_output* that corresponds to the embedding of the first token of the query representation. |
|
This output can be used to embed questions for nearest neighbors queries with context embeddings. |
|
late_interaction_output (`torch.FloatTensor` of shape `(batch_size, query_embedding_length, embeddings_size)`): |
|
The FLMR encoder outputs the *late_interaction_output* that corresponds to the question representation. The embeddings of all tokens are included for late interaction retrieval. |
|
This output is to be used to embed questions for late-interaction retrieval with context embeddings. |
|
text_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*): |
|
Tuple of elements containing the attention weights of the text encoder's layers. Each element is a |
|
tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`. |
|
text_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*): |
|
Tuple of elements containing the hidden states of the text encoder at each layer plus the initial embedding |
|
outputs. Each tensor has a shape of `(batch_size, sequence_length, hidden_size)`. |
|
vision_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*): |
|
Tuple of elements containing the attention weights of the vision encoder's layers. Each element is a |
|
tensor of shape `(batch_size, num_heads, vision_sequence_length, vision_sequence_length)`. |
|
vision_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*): |
|
Tuple of elements containing the hidden states of the vision encoder at each layer plus the initial embedding |
|
outputs. Each tensor has a shape of `(batch_size, vision_sequence_length, hidden_size)`. |
|
transformer_mapping_network_attentions (`Tuple[torch.FloatTensor]`, *optional*): |
|
Tuple of elements containing the attention weights of the transformer mapping network's layers. Each element |
|
is a tensor of shape `(batch_size, num_heads, mapping_sequence_length, mapping_sequence_length)`. |
|
transformer_mapping_network_hidden_states (`Tuple[torch.FloatTensor]`, *optional*): |
|
Tuple of elements containing the hidden states of the transformer mapping network at each layer plus the |
|
initial embedding outputs. Each tensor has a shape of `(batch_size, mapping_sequence_length, hidden_size)`. |
|
""" |
|
|
|
pooler_output: torch.FloatTensor |
|
late_interaction_output: torch.FloatTensor = None |
|
text_encoder_attentions: Optional[Tuple[Tensor]] = None |
|
text_encoder_hidden_states: Optional[Tuple[Tensor]] = None |
|
vision_encoder_attentions: Optional[Tuple[Tensor]] = None |
|
vision_encoder_hidden_states: Optional[Tuple[Tensor]] = None |
|
transformer_mapping_network_attentions: Optional[Tuple[Tensor]] = None |
|
transformer_mapping_network_hidden_states: Optional[Tuple[Tensor]] = None |
|
|
|
|
|
@dataclass |
|
class FLMRModelForRetrievalOutput(ModelOutput): |
|
""" |
|
Class for outputs of [`FLMRModelForRetrieval.query()`]. |
|
|
|
Args: |
|
loss (`torch.FloatTensor`): |
|
contrastive loss of the input queries and positive and negative examples. This output is to be used in model training. |
|
scores (`torch.FloatTensor` of shape `(batch_size, num_positive_examples + num_negative_examples)`): |
|
The FLMR model outputs the *scores* that corresponds to the late-interaction scores of the input query and context. Each query is associated with `num_positive_examples` positive examples and `num_negative_examples` negative examples, and the scores are the late-interaction scores of the query and these examples. |
|
in_batch_negative_loss (`torch.FloatTensor` of shape `(batch_size, query_embedding_length, embeddings_size)`): |
|
The FLMR model outputs the *in_batch_negative_loss* which computes contrastive loss that includes in-batch negatives. For each positive example, all other examples in the batch except itself are considered negative examples in computing the contrastive loss. This improves ultimate performance in practice. This output is to be used in model training. |
|
query_late_interaction_output (`torch.FloatTensor` of shape `(batch_size, query_embedding_length, embeddings_size)`): |
|
The FLMR model outputs the *query_late_interaction_output* that corresponds to the late-interaction representations of the input query. |
|
context_late_interaction_output (`torch.FloatTensor` of shape `(batch_size, context_embedding_length, embeddings_size)`): |
|
The FLMR model outputs the *context_late_interaction_output* that corresponds to the late-interaction representations of the input context. |
|
query_attentions (`Tuple[Tuple[Tensor]]`, *optional*): |
|
Tuple of elements containing the attention weights of the query's layers. There are three sub-tuples in this tuple, corresponding to the attentions of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`, with `sequence_length` being the sequence length in the corresponding encoder. |
|
query_hidden_states (`Tuple[Tuple[Tensor]]`, *optional*): |
|
Tuple of elements containing the hidden states of the query's layers. There are three sub-tuples in this tuple, corresponding to the hidden states of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, sequence_length, hidden_size)`, with `sequence_length` being the sequence length in the corresponding encoder. |
|
context_attentions (`Tuple[Tuple[Tensor]]`, *optional*): |
|
Tuple of elements containing the attention weights of the context's layers. There are three sub-tuples in this tuple, corresponding to the attentions of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`, with `sequence_length` being the sequence length in the corresponding encoder. |
|
context_hidden_states (`Tuple[Tuple[Tensor]]`, *optional*): |
|
Tuple of elements containing the hidden states of the context's layers. There are three sub-tuples in this tuple, corresponding to the hidden states of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, sequence_length, hidden_size)`, with `sequence_length` being the sequence length in the corresponding encoder. |
|
""" |
|
|
|
loss: torch.FloatTensor |
|
scores: torch.FloatTensor = None |
|
in_batch_negative_loss: torch.FloatTensor = None |
|
query_late_interaction_output: torch.FloatTensor = None |
|
context_late_interaction_output: torch.FloatTensor = None |
|
query_attentions: Optional[Tuple[Tuple[Tensor]]] = None |
|
query_hidden_states: Optional[Tuple[Tuple[Tensor]]] = None |
|
context_attentions: Optional[Tuple[Tuple[Tensor]]] = None |
|
context_hidden_states: Optional[Tuple[Tuple[Tensor]]] = None |
|
|
|
|
|
class FLMRPreTrainedModel(PreTrainedModel): |
|
def _init_weights(self, module): |
|
"""Initialize the weights""" |
|
if isinstance(module, nn.Linear): |
|
|
|
|
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) |
|
if module.bias is not None: |
|
module.bias.data.zero_() |
|
elif isinstance(module, nn.Embedding): |
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) |
|
if module.padding_idx is not None: |
|
module.weight.data[module.padding_idx].zero_() |
|
elif isinstance(module, nn.LayerNorm): |
|
module.bias.data.zero_() |
|
module.weight.data.fill_(1.0) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FLMRPretrainedModelForRetrieval(FLMRPreTrainedModel): |
|
""" |
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained |
|
models. |
|
""" |
|
|
|
config_class = FLMRConfig |
|
load_tf_weights = None |
|
base_model_prefix = "flmr" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
FLMR_START_DOCSTRING = r""" |
|
|
|
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the |
|
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads |
|
etc.) |
|
|
|
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. |
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage |
|
and behavior. |
|
|
|
Parameters: |
|
config ([`FLMRConfig`]): Model configuration class with all the parameters of the model. |
|
Initializing with a config file does not load the weights associated with the model, only the |
|
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. |
|
query_tokenizer ([`FLMRQueryEncoderTokenizer`], *optional*): The tokenizer used for tokenizing the query. |
|
The query tokenizer can be initialized with `FLMRQueryEncoderTokenizer.from_pretrained(pretrained_model_name_or_path)`. |
|
context_tokenizer ([`FLMRContextEncoderTokenizer`], *optional*): The tokenizer used for tokenizing the context. |
|
The context tokenizer can be initialized with `FLMRContextEncoderTokenizer.from_pretrained(pretrained_model_name_or_path)`. |
|
""" |
|
|
|
|
|
FLMR_MODEL_INPUTS_DOCSTRING = r""" |
|
Args: |
|
query_input_ids (`torch.LongTensor` of shape `(batch_size, query_length)`): |
|
Indices of input query tokens in the vocabulary. To match pretraining, FLMR input sequence should be |
|
formatted with [CLS] and Q marker tokens as follows: |
|
[CLS] [unused0] using the provided image, obtain documents that address the subsequent question : what is the capital of france? [SEP] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] ... |
|
|
|
FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right |
|
rather than the left. |
|
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
|
[`PreTrainedTokenizer.__call__`] for details. |
|
|
|
[What are input IDs?](../glossary#input-ids) |
|
query_attention_mask (`torch.FloatTensor` of shape `(batch_size, query_length)`, *optional*): |
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
query_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*): |
|
Pixel values. Pixel values can be obtained using |
|
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. |
|
query_image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*): |
|
Image features are required when `query_pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass and the extract image features are directly given to the FLMR model. Image features can be obtained |
|
using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details. |
|
context_input_ids (`torch.LongTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`): |
|
Indices of input context tokens in the vocabulary. To match pretraining, FLMR input sequence should be |
|
formatted with [CLS] and D marker tokens as follows: |
|
[CLS] [unused1] paris is the capital of france. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] ... |
|
|
|
FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right |
|
rather than the left. |
|
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
|
[`PreTrainedTokenizer.__call__`] for details. |
|
|
|
[What are input IDs?](../glossary#input-ids) |
|
|
|
The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details. |
|
|
|
context_attention_mask (`torch.FloatTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`, *optional*): |
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
|
|
The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details. |
|
context_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*): |
|
Pixel values. Pixel values can be obtained using |
|
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. |
|
context_image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*): |
|
Image features are required when `context_pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass and the extract image features are directly given to the FLMR model. Image features can be obtained |
|
using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details. |
|
use_in_batch_negatives (`bool`, *optional*): |
|
Whether or not to use in-batch negatives. If `True`, the contrastive loss includes in-batch negatives. For each positive example, all other examples in the batch except itself are considered negative examples in computing the contrastive loss. This improves ultimate performance in practice. This input is to be used in model training. |
|
in_batch_negatives_from_all_gpus (`bool`, *optional*): |
|
Whether or not to use in-batch negatives from all GPUs. If `True`, the contrastive loss includes in-batch negatives from all GPUs. This input is to be used in model training. |
|
num_negative_examples (`int`, *optional*): |
|
The number of negative examples in the batch. For example, if `num_negative_examples` is 4, the batch size of `context_input_ids` and `context_attention_mask` is `batch_size * 5`. |
|
query_concat_output_from_vision_encoder (`bool`, *optional*): |
|
Whether or not to concatenate the output from the vision encoder to the final query late-interaction representations. If `True`, the output from the vision encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models. |
|
query_concat_output_from_text_encoder (`bool`, *optional*): |
|
Whether or not to concatenate the output from the text encoder to the final query late-interaction representations. If `True`, the output from the text encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models. |
|
|
|
This argument can be set to `False` when performing mapping network pretraining as in FLMR and PreFLMR, in which case the output from the text encoder is not concatenated to the final query representations. |
|
context_concat_output_from_vision_encoder (`bool`, *optional*): |
|
Whether or not to concatenate the output from the vision encoder to the final context late-interaction representations. If `True`, the output from the vision encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `False` for FLMR and PreFLMR -style models since the context vision encoder is not used. |
|
|
|
This can be set to `True` to additionally encode the context images with the vision encoder when context images are provided. |
|
context_concat_output_from_text_encoder (`bool`, *optional*): |
|
Whether or not to concatenate the output from the text encoder to the final context late-interaction representations. If `True`, the output from the text encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models. |
|
return_dict (`bool`, *optional*): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `*_attentions` under returned |
|
tensors for more detail. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. See `*_hidden_states` under returned tensors for more detail. |
|
""" |
|
|
|
|
|
FLMR_MODEL_QUERY_INPUTS_DOCSTRING = r""" |
|
Args: |
|
input_ids (`torch.LongTensor` of shape `(batch_size, query_length)`): |
|
Indices of input query tokens in the vocabulary. To match pretraining, FLMR input sequence should be |
|
formatted with [CLS] and Q marker tokens as follows: |
|
[CLS] [unused0] using the provided image, obtain documents that address the subsequent question : what is the capital of france? [SEP] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] ... |
|
|
|
FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right |
|
rather than the left. |
|
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
|
[`PreTrainedTokenizer.__call__`] for details. |
|
|
|
[What are input IDs?](../glossary#input-ids) |
|
attention_mask (`torch.FloatTensor` of shape `(batch_size, query_length)`, *optional*): |
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*): |
|
Pixel values. Pixel values can be obtained using |
|
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. |
|
image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*): |
|
Image features are required when `pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass and the extract image features are directly given to the FLMR model. Image features can be obtained |
|
using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details. |
|
concat_output_from_vision_encoder (`bool`, *optional*): |
|
Whether or not to concatenate the output from the vision encoder to the final query late-interaction representations. If `True`, the output from the vision encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models. |
|
concat_output_from_text_encoder (`bool`, *optional*): |
|
Whether or not to concatenate the output from the text encoder to the final query late-interaction representations. If `True`, the output from the text encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models. |
|
|
|
This argument can be set to `False` when performing mapping network pretraining as in FLMR and PreFLMR, in which case the output from the text encoder is not concatenated to the final query representations. |
|
""" |
|
|
|
|
|
FLMR_MODEL_CONTEXT_INPUTS_DOCSTRING = r""" |
|
Args: |
|
input_ids (`torch.LongTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`): |
|
Indices of input context tokens in the vocabulary. To match pretraining, FLMR input sequence should be |
|
formatted with [CLS] and D marker tokens as follows: |
|
[CLS] [unused1] paris is the capital of france. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] ... |
|
|
|
FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right |
|
rather than the left. |
|
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
|
[`PreTrainedTokenizer.__call__`] for details. |
|
|
|
[What are input IDs?](../glossary#input-ids) |
|
|
|
The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details. |
|
attention_mask (`torch.FloatTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`, *optional*): |
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
|
|
The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details. |
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*): |
|
Pixel values. Pixel values can be obtained using |
|
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. |
|
image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*): |
|
Image features are required when `pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass and the extract image features are directly given to the FLMR model. Image features can be obtained |
|
using [`CLIPVisionModel`]. See [`CLIPVisionModel |
|
.__call__`] for details. |
|
concat_output_from_vision_encoder (`bool`, *optional*): |
|
Whether or not to concatenate the output from the vision encoder to the final context late-interaction representations. If `True`, the output from the vision encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `False` for FLMR and PreFLMR -style models since the context vision encoder is not used. |
|
|
|
This can be set to `True` to additionally encode the context images with the vision encoder when context images are provided. |
|
concat_output_from_text_encoder (`bool`, *optional*): |
|
Whether or not to concatenate the output from the text encoder to the final context late-interaction representations. If `True`, the output from the text encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models. |
|
keep_dims (`bool`, *optional*): |
|
Whether or not to keep the dimensions of the output. If `True`, the output is returned with the same dimensions as the input. If `False`, the output is returned with the batch size of the input and the context length. This input is to be used in model training. |
|
return_mask (`bool`, *optional*): |
|
Whether or not to return the mask of the context representation. If `True`, the mask of the context representation is returned. This input is to be used in model training. |
|
""" |
|
|
|
|
|
FLMR_TEXT_ENCODERS_START_DOCSTRING = r""" |
|
|
|
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the |
|
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads |
|
etc.) |
|
|
|
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. |
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage |
|
and behavior. |
|
|
|
Parameters: |
|
config ([`FLMRTextConfig`]): Model configuration class with all the parameters of the model. |
|
Initializing with a config file does not load the weights associated with the model, only the |
|
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. |
|
""" |
|
|
|
|
|
|
|
FLMR_TEXT_ENCODERS_INPUTS_DOCSTRING = r""" |
|
Args: |
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
|
Indices of input sequence tokens in the vocabulary. To match pretraining, FLMR input sequence should be |
|
formatted with [CLS] and [SEP] tokens as follows: |
|
|
|
(a) For sequence pairs (for a pair title+text for example): |
|
|
|
``` |
|
tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] |
|
token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 |
|
``` |
|
|
|
(b) For single sequences (for a question for example): |
|
|
|
``` |
|
tokens: [CLS] the dog is hairy . [SEP] |
|
token_type_ids: 0 0 0 0 0 0 0 |
|
``` |
|
|
|
FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right |
|
rather than the left. |
|
|
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
|
[`PreTrainedTokenizer.__call__`] for details. |
|
|
|
[What are input IDs?](../glossary#input-ids) |
|
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, |
|
1]`: |
|
|
|
- 0 corresponds to a *sentence A* token, |
|
- 1 corresponds to a *sentence B* token. |
|
|
|
[What are token type IDs?](../glossary#token-type-ids) |
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): |
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This |
|
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the |
|
model's internal embedding lookup matrix. |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
|
tensors for more detail. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
|
more detail. |
|
return_dict (`bool`, *optional*): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
""" |
|
|
|
FLMR_VISION_ENCODERS_START_DOCSTRING = r""" |
|
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the |
|
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads |
|
etc.) |
|
|
|
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. |
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage |
|
and behavior. |
|
|
|
Parameters: |
|
config ([`FLMRVisionConfig`]): Model configuration class with all the parameters of the model. |
|
Initializing with a config file does not load the weights associated with the model, only the |
|
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. |
|
""" |
|
|
|
|
|
FLMR_VISION_ENCODERS_INPUTS_DOCSTRING = r""" |
|
Args: |
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): |
|
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using |
|
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
|
tensors for more detail. |
|
output_hidden_states (`bool`, *optional*): |
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
|
more detail. |
|
return_dict (`bool`, *optional*): |
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
""" |
|
|
|
|
|
class FLMRMultiLayerPerceptron(nn.Module): |
|
""" |
|
A simple multi-layer perceptron with an activation function. This can be used as the mapping network in the FLMR model. |
|
""" |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
return self.model(x) |
|
|
|
def __init__(self, sizes, bias=True, act=nn.Tanh): |
|
super(FLMRMultiLayerPerceptron, self).__init__() |
|
layers = [] |
|
for i in range(len(sizes) - 1): |
|
layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias)) |
|
if i < len(sizes) - 2: |
|
layers.append(act()) |
|
self.model = nn.Sequential(*layers) |
|
|
|
|
|
@add_start_docstrings( |
|
"The bare FLMR model that can be used to generate late-interaction embeddings for both multi-modal queries and documents. ", |
|
FLMR_START_DOCSTRING, |
|
) |
|
class FLMRModelForRetrieval(FLMRPretrainedModelForRetrieval): |
|
_keys_to_ignore_on_load_unexpected = [r"cls"] |
|
main_input_name = "query_input_ids" |
|
_tied_weights_keys = [] |
|
|
|
def __init__(self, config: FLMRConfig, query_tokenizer=None, context_tokenizer=None): |
|
super().__init__(config) |
|
self.config = config |
|
self.vision_model_version = config.vision_model_version |
|
|
|
self.context_text_encoder = FLMRTextModel(config.text_config) |
|
self.context_text_encoder_linear = nn.Linear(config.text_config.hidden_size, config.dim, bias=False) |
|
|
|
self.query_tokenizer = query_tokenizer |
|
self.context_tokenizer = context_tokenizer |
|
|
|
if self.query_tokenizer is None: |
|
logger.warning( |
|
"query_tokenizer is not provided. A tokenizer is initialized from `bert-base-uncased`. Please pass in an FLMRQueryEncoderTokenizer instance if you need to extend the vocabulary beyond the existing ones in the bert tokenizer." |
|
) |
|
|
|
|
|
self.query_tokenizer = FLMRQueryEncoderTokenizer.from_pretrained("bert-base-uncased") |
|
|
|
if self.context_tokenizer is None: |
|
logger.warning( |
|
"context_tokenizer is not provided. A tokenizer is initialized from `bert-base-uncased`. Please pass in an FLMRContextEncoderTokenizer instance if you need to extend the vocabulary beyond the existing ones in the bert tokenizer." |
|
) |
|
|
|
|
|
self.context_tokenizer = FLMRContextEncoderTokenizer.from_pretrained("bert-base-uncased") |
|
|
|
self.mapping_network_prefix_length = self.config.mapping_network_prefix_length |
|
self.vision_encoder_embedding_size = self.config.vision_config.hidden_size |
|
self.text_encoder_embedding_size = self.config.text_config.hidden_size |
|
self.late_interaction_embedding_size = self.config.dim |
|
|
|
if self.config.use_vision_encoder: |
|
self.context_vision_projection = FLMRMultiLayerPerceptron( |
|
( |
|
self.vision_encoder_embedding_size, |
|
(self.late_interaction_embedding_size * self.mapping_network_prefix_length) // 2, |
|
self.late_interaction_embedding_size * self.mapping_network_prefix_length, |
|
) |
|
) |
|
|
|
if self.config.use_vision_encoder: |
|
self.context_vision_encoder = FLMRVisionModel(config.vision_config) |
|
|
|
if self.config.use_transformer_mapping_network: |
|
|
|
transformer_mapping_config_base = self.config.transformer_mapping_config_base |
|
try: |
|
from transformers import BertConfig |
|
from transformers.models.bert.modeling_bert import BertEncoder |
|
except Exception as e: |
|
raise ImportError(f"Failed to import BertConfig and BertEncoder from transformers. {e}") |
|
|
|
transformer_mapping_config = BertConfig.from_pretrained(transformer_mapping_config_base) |
|
|
|
assert ( |
|
self.config.text_config.hidden_size == transformer_mapping_config.hidden_size |
|
), f"hidden_size {self.config.text_config.hidden_size} != transformer_mapping_config.hidden_size {transformer_mapping_config.hidden_size}. To use cross attention, the dimensions must match." |
|
|
|
transformer_mapping_config.num_hidden_layers = self.config.transformer_mapping_num_hidden_layers |
|
|
|
transformer_mapping_config.is_decoder = True |
|
transformer_mapping_config.add_cross_attention = True |
|
|
|
|
|
self.transformer_mapping_input_linear = nn.Linear( |
|
self.vision_encoder_embedding_size, transformer_mapping_config.hidden_size |
|
) |
|
|
|
|
|
self.transformer_mapping_network = BertEncoder(transformer_mapping_config) |
|
|
|
|
|
self.transformer_mapping_output_linear = nn.Linear( |
|
transformer_mapping_config.hidden_size, self.late_interaction_embedding_size |
|
) |
|
|
|
if self.config.separate_query_and_context_text_encoder: |
|
self.query_text_encoder = copy.deepcopy(self.context_text_encoder) |
|
self.query_text_encoder_linear = copy.deepcopy(self.context_text_encoder_linear) |
|
else: |
|
self.query_text_encoder = self.context_text_encoder |
|
self.query_text_encoder_linear = self.context_text_encoder_linear |
|
self._tied_weights_keys += ["context_text_encoder", "context_text_encoder_linear"] |
|
|
|
if self.config.use_vision_encoder: |
|
if self.config.separate_query_and_context_vision_encoder: |
|
self.query_vision_encoder = copy.deepcopy(self.context_vision_encoder) |
|
self.query_vision_projection = copy.deepcopy(self.context_vision_projection) |
|
else: |
|
self.query_vision_encoder = self.context_vision_encoder |
|
self.query_vision_projection = self.context_vision_projection |
|
self._tied_weights_keys += ["context_vision_encoder", "context_vision_projection"] |
|
|
|
if self.config.load_cpu_extension: |
|
try: |
|
FLMRModelForRetrieval.try_load_torch_extensions() |
|
except Exception as e: |
|
raise(f"Unable to load `segmented_maxsim.cpp`. hf-hub does not download this file automatically. Please download it manually from `https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/blob/main/segmented_maxsim.cpp` and put it under the same folder as the model file.\n {e}") |
|
|
|
if self.config.mask_punctuation: |
|
self.skiplist = { |
|
w: True |
|
for symbol in string.punctuation |
|
for w in [symbol, self.context_tokenizer.encode(symbol, add_special_tokens=False)[0]] |
|
} |
|
|
|
if self.config.mask_instruction_token is not None: |
|
self.mask_instruction = True |
|
|
|
self.instruction_token_id = self.query_tokenizer.encode( |
|
self.config.mask_instruction_token, add_special_tokens=False |
|
)[0] |
|
else: |
|
self.mask_instruction = False |
|
|
|
self.loss_fn = torch.nn.CrossEntropyLoss() |
|
|
|
|
|
self.post_init() |
|
|
|
@property |
|
def use_gpu(self): |
|
return self.device.type == "cuda" |
|
|
|
@classmethod |
|
def from_pretrained(self, name_or_path, **kwargs): |
|
obj = super().from_pretrained(name_or_path, **kwargs) |
|
return obj |
|
|
|
@classmethod |
|
def try_load_torch_extensions(cls): |
|
if hasattr(cls, "loaded_extensions"): |
|
return |
|
|
|
logger.info( |
|
"Loading segmented_maxsim_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)..." |
|
) |
|
segmented_maxsim_cpp = load( |
|
name="segmented_maxsim_cpp", |
|
sources=[ |
|
os.path.join(pathlib.Path(__file__).parent.resolve(), "segmented_maxsim.cpp"), |
|
], |
|
extra_cflags=["-O3"], |
|
verbose=os.getenv("COLBERT_LOAD_TORCH_EXTENSION_VERBOSE", "False") == "True", |
|
) |
|
cls.segmented_maxsim = segmented_maxsim_cpp.segmented_maxsim_cpp |
|
|
|
cls.loaded_extensions = True |
|
|
|
def query_mask(self, input_ids, skiplist): |
|
if not self.mask_instruction: |
|
return self.mask(input_ids, skiplist) |
|
|
|
|
|
|
|
sep_id = self.instruction_token_id |
|
sep_positions = torch.argmax((input_ids == sep_id).int(), dim=1).tolist() |
|
|
|
for i, x in enumerate(sep_positions): |
|
if x < 1: |
|
sep_positions[i] = 1 |
|
logger.error(f"can not find the separator in the input_ids: {input_ids[i].tolist()}") |
|
mask = [ |
|
[ |
|
(x not in skiplist) and (x != 0) and (index > sep_positions[seq_index] or index < 2) |
|
for index, x in enumerate(d) |
|
] |
|
for seq_index, d in enumerate(input_ids.cpu().tolist()) |
|
] |
|
return mask |
|
|
|
@add_start_docstrings_to_model_forward(FLMR_MODEL_INPUTS_DOCSTRING) |
|
@replace_return_docstrings(output_type=FLMRModelForRetrievalOutput, config_class=_CONFIG_FOR_DOC) |
|
def forward( |
|
self, |
|
query_input_ids: Optional[torch.Tensor] = None, |
|
query_attention_mask: Optional[torch.Tensor] = None, |
|
query_pixel_values: Optional[torch.Tensor] = None, |
|
query_image_features: Optional[torch.Tensor] = None, |
|
context_input_ids: Optional[torch.Tensor] = None, |
|
context_attention_mask: Optional[torch.Tensor] = None, |
|
context_pixel_values: Optional[torch.Tensor] = None, |
|
context_image_features: Optional[torch.Tensor] = None, |
|
use_in_batch_negatives: bool = True, |
|
in_batch_negatives_from_all_gpus: bool = False, |
|
num_negative_examples: int = 1, |
|
query_concat_output_from_vision_encoder: Optional[Union[bool, list]] = None, |
|
query_concat_output_from_text_encoder: Optional[Union[bool, list]] = None, |
|
context_concat_output_from_vision_encoder: Optional[Union[bool, list]] = None, |
|
context_concat_output_from_text_encoder: Optional[Union[bool, list]] = None, |
|
return_dict: bool = None, |
|
output_attentions: bool = None, |
|
output_hidden_states: bool = None, |
|
) -> Union[FLMRModelForRetrievalOutput, Tuple[Tensor, ...]]: |
|
r""" |
|
Return: |
|
|
|
Examples: |
|
|
|
```python |
|
>>> import torch |
|
>>> from transformers import FLMRQueryEncoderTokenizer, FLMRContextEncoderTokenizer, FLMRModelForRetrieval, AutoImageProcessor |
|
|
|
>>> checkpoint_path = "LinWeizheDragon/PreFLMR_ViT-L" |
|
>>> image_processor_name = "openai/clip-vit-large-patch14" |
|
>>> query_tokenizer = FLMRQueryEncoderTokenizer.from_pretrained(checkpoint_path, subfolder="query_tokenizer") |
|
>>> context_tokenizer = FLMRContextEncoderTokenizer.from_pretrained(checkpoint_path, subfolder="context_tokenizer") |
|
|
|
>>> model = FLMRModelForRetrieval.from_pretrained(checkpoint_path, |
|
query_tokenizer=query_tokenizer, |
|
context_tokenizer=context_tokenizer, |
|
) |
|
>>> image_processor = AutoImageProcessor.from_pretrained(image_processor_name) |
|
|
|
>>> Q_encoding = query_tokenizer(["Using the provided image, obtain documents that address the subsequent question: What is the capital of France?", "Extract documents linked to the question provided in conjunction with the image: What is the capital of China?"]) |
|
>>> D_encoding = context_tokenizer(["Paris is the capital of France.", "Beijing is the capital of China.", |
|
"Paris is the capital of France.", "Beijing is the capital of China."]) |
|
>>> Q_pixel_values = torch.zeros(2, 3, 224, 224) |
|
>>> inputs = dict( |
|
query_input_ids=Q_encoding['input_ids'], |
|
query_attention_mask=Q_encoding['attention_mask'], |
|
query_pixel_values=Q_pixel_values, |
|
context_input_ids=D_encoding['input_ids'], |
|
context_attention_mask=D_encoding['attention_mask'], |
|
use_in_batch_negatives=True, |
|
) |
|
|
|
>>> model.forward(**inputs) |
|
FLMRModelForRetrievalOutput(loss=tensor(4.5000, device='cuda:0', dtype=torch.float16, |
|
grad_fn=<NllLossBackward0>), scores=tensor([[44.2188, 40.6562], |
|
[39.4375, 48.4062]], device='cuda:0', dtype=torch.float16, |
|
grad_fn=<ViewBackward0>), in_batch_negative_loss=tensor(5.1994, device='cuda:0', grad_fn=<NllLossBackward0>), query_late_interaction_output=tensor(...), context_late_interaction_output=tensor(...) |
|
``` |
|
""" |
|
|
|
if query_concat_output_from_vision_encoder is None: |
|
query_concat_output_from_vision_encoder = self.config.query_concat_output_from_vision_encoder |
|
|
|
if query_concat_output_from_text_encoder is None: |
|
query_concat_output_from_text_encoder = self.config.query_concat_output_from_text_encoder |
|
|
|
if context_concat_output_from_vision_encoder is None: |
|
context_concat_output_from_vision_encoder = self.config.context_concat_output_from_vision_encoder |
|
|
|
if context_concat_output_from_text_encoder is None: |
|
context_concat_output_from_text_encoder = self.config.context_concat_output_from_text_encoder |
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
query_outputs = self.query( |
|
input_ids=query_input_ids, |
|
attention_mask=query_attention_mask, |
|
pixel_values=query_pixel_values, |
|
image_features=query_image_features, |
|
concat_output_from_vision_encoder=query_concat_output_from_vision_encoder, |
|
concat_output_from_text_encoder=query_concat_output_from_text_encoder, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
Q = query_outputs.late_interaction_output |
|
|
|
context_outputs = self.doc( |
|
input_ids=context_input_ids, |
|
attention_mask=context_attention_mask, |
|
pixel_values=context_pixel_values, |
|
image_features=context_image_features, |
|
concat_output_from_vision_encoder=context_concat_output_from_vision_encoder, |
|
concat_output_from_text_encoder=context_concat_output_from_text_encoder, |
|
keep_dims=True, |
|
return_mask=True, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
D, D_mask = context_outputs.late_interaction_output, context_outputs.context_mask |
|
|
|
|
|
if in_batch_negatives_from_all_gpus: |
|
Q, D, D_mask = self.gather_tensors_from_other_gpus(Q, D, D_mask) |
|
|
|
Q_duplicated = Q.repeat_interleave(num_negative_examples + 1, dim=0).contiguous() |
|
|
|
scores = self.score(Q_duplicated, D, D_mask) |
|
|
|
|
|
batch_size = query_input_ids.shape[0] |
|
scores = scores.view(-1, num_negative_examples + 1) |
|
labels = torch.zeros(batch_size, dtype=torch.long, device=self.device) |
|
|
|
|
|
if use_in_batch_negatives: |
|
ib_loss = self.compute_ib_loss_new(Q, D, D_mask) |
|
loss = ib_loss |
|
else: |
|
loss = self.loss_fn(scores, labels) |
|
ib_loss = None |
|
|
|
if output_attentions: |
|
query_attentions = ( |
|
query_outputs.text_encoder_attentions if query_outputs.text_encoder_attentions is not None else None, |
|
query_outputs.vision_encoder_attentions |
|
if query_outputs.vision_encoder_attentions is not None |
|
else None, |
|
query_outputs.transformer_mapping_network_attentions |
|
if query_outputs.transformer_mapping_network_attentions is not None |
|
else None, |
|
) |
|
context_attentions = ( |
|
context_outputs.text_encoder_attentions |
|
if context_outputs.text_encoder_attentions is not None |
|
else None, |
|
context_outputs.vision_encoder_attentions |
|
if context_outputs.vision_encoder_attentions is not None |
|
else None, |
|
context_outputs.transformer_mapping_network_attentions |
|
if context_outputs.transformer_mapping_network_attentions is not None |
|
else None, |
|
) |
|
else: |
|
query_attentions = None |
|
context_attentions = None |
|
|
|
if output_hidden_states: |
|
query_hidden_states = ( |
|
query_outputs.text_encoder_hidden_states |
|
if query_outputs.text_encoder_hidden_states is not None |
|
else None, |
|
query_outputs.vision_encoder_hidden_states |
|
if query_outputs.vision_encoder_hidden_states is not None |
|
else None, |
|
query_outputs.transformer_mapping_network_hidden_states |
|
if query_outputs.transformer_mapping_network_hidden_states is not None |
|
else None, |
|
) |
|
context_hidden_states = ( |
|
context_outputs.text_encoder_hidden_states |
|
if context_outputs.text_encoder_hidden_states is not None |
|
else None, |
|
context_outputs.vision_encoder_hidden_states |
|
if context_outputs.vision_encoder_hidden_states is not None |
|
else None, |
|
context_outputs.transformer_mapping_network_hidden_states |
|
if context_outputs.transformer_mapping_network_hidden_states is not None |
|
else None, |
|
) |
|
else: |
|
query_hidden_states = None |
|
context_hidden_states = None |
|
|
|
if not return_dict: |
|
if output_attentions and output_hidden_states: |
|
return ( |
|
loss, |
|
scores, |
|
ib_loss, |
|
query_outputs.late_interaction_output, |
|
context_outputs.late_interaction_output, |
|
query_attentions, |
|
query_hidden_states, |
|
context_attentions, |
|
context_hidden_states, |
|
) |
|
elif output_attentions: |
|
return ( |
|
loss, |
|
scores, |
|
ib_loss, |
|
query_outputs.late_interaction_output, |
|
context_outputs.late_interaction_output, |
|
query_attentions, |
|
context_attentions, |
|
) |
|
elif output_hidden_states: |
|
return ( |
|
loss, |
|
scores, |
|
ib_loss, |
|
query_outputs.late_interaction_output, |
|
context_outputs.late_interaction_output, |
|
query_hidden_states, |
|
context_hidden_states, |
|
) |
|
else: |
|
return ( |
|
loss, |
|
scores, |
|
ib_loss, |
|
query_outputs.late_interaction_output, |
|
context_outputs.late_interaction_output, |
|
) |
|
|
|
return FLMRModelForRetrievalOutput( |
|
loss=loss, |
|
scores=scores, |
|
in_batch_negative_loss=ib_loss, |
|
query_late_interaction_output=query_outputs.late_interaction_output, |
|
context_late_interaction_output=context_outputs.late_interaction_output, |
|
query_attentions=query_attentions if output_attentions else None, |
|
query_hidden_states=query_hidden_states if output_hidden_states else None, |
|
context_attentions=context_attentions if output_attentions else None, |
|
context_hidden_states=context_hidden_states if output_hidden_states else None, |
|
) |
|
|
|
def compute_ib_loss_new(self, Q: torch.Tensor, D: torch.Tensor, D_mask: torch.Tensor) -> torch.Tensor: |
|
|
|
|
|
|
|
|
|
|
|
|
|
scores = (D.float().unsqueeze(0) @ Q.float().permute(0, 2, 1).unsqueeze(1)).flatten( |
|
0, 1 |
|
) |
|
scores = colbert_score_reduce(scores, D_mask.repeat(Q.size(0), 1, 1)) |
|
|
|
in_batch_scores = scores.reshape(Q.size(0), -1) |
|
|
|
batch_size = Q.shape[0] |
|
batch_size_with_pos_and_neg = D.shape[0] |
|
num_pos_and_neg = batch_size_with_pos_and_neg // batch_size |
|
|
|
|
|
|
|
in_batch_labels = torch.zeros(batch_size, batch_size_with_pos_and_neg).to(scores.device) |
|
step = num_pos_and_neg |
|
for i in range(batch_size): |
|
in_batch_labels[i, step * i] = 1 |
|
|
|
in_batch_labels = torch.argmax(in_batch_labels, dim=1) |
|
|
|
|
|
loss = self.loss_fn(in_batch_scores, in_batch_labels) |
|
|
|
return loss |
|
|
|
def gather_tensors_from_other_gpus(self, query_embeddings, item_embeddings, item_mask): |
|
|
|
|
|
|
|
n_nodes = get_world_size() |
|
if n_nodes == 1: |
|
return query_embeddings, item_embeddings, item_mask |
|
|
|
global_query_embeddings_placeholder = [ |
|
torch.zeros(*query_embeddings.shape, dtype=query_embeddings.dtype).to(query_embeddings.device) |
|
for _ in range(n_nodes) |
|
] |
|
global_item_embeddings_placeholder = [ |
|
torch.zeros(*item_embeddings.shape, dtype=item_embeddings.dtype).to(item_embeddings.device) |
|
for _ in range(n_nodes) |
|
] |
|
global_item_mask_placeholder = [ |
|
torch.zeros(*item_mask.shape, dtype=item_mask.dtype).to(item_mask.device) for _ in range(n_nodes) |
|
] |
|
dist.all_gather(global_query_embeddings_placeholder, query_embeddings.detach()) |
|
dist.all_gather(global_item_embeddings_placeholder, item_embeddings.detach()) |
|
dist.all_gather(global_item_mask_placeholder, item_mask.detach()) |
|
|
|
global_query_embeddings = [] |
|
global_item_embeddings = [] |
|
global_item_mask = [] |
|
|
|
|
|
|
|
current_rank = get_rank() |
|
for rank_index, remote_q_embeddings in enumerate(global_query_embeddings_placeholder): |
|
|
|
if rank_index != current_rank: |
|
global_query_embeddings.append(remote_q_embeddings) |
|
else: |
|
global_query_embeddings.append(query_embeddings) |
|
|
|
for rank_index, remote_item_embeddings in enumerate(global_item_embeddings_placeholder): |
|
|
|
if rank_index != current_rank: |
|
global_item_embeddings.append(remote_item_embeddings) |
|
else: |
|
global_item_embeddings.append(item_embeddings) |
|
|
|
for rank_index, remote_item_mask in enumerate(global_item_mask_placeholder): |
|
|
|
if rank_index != current_rank: |
|
global_item_mask.append(remote_item_mask) |
|
else: |
|
global_item_mask.append(item_mask) |
|
|
|
|
|
query_embeddings = torch.cat(global_query_embeddings) |
|
item_embeddings = torch.cat(global_item_embeddings) |
|
item_mask = torch.cat(global_item_mask) |
|
|
|
return query_embeddings, item_embeddings, item_mask |
|
|
|
@add_start_docstrings_to_model_forward(FLMR_MODEL_QUERY_INPUTS_DOCSTRING) |
|
@replace_return_docstrings(output_type=FLMRQueryEncoderOutput, config_class=_CONFIG_FOR_DOC) |
|
def query( |
|
self, |
|
input_ids: torch.Tensor, |
|
attention_mask: torch.Tensor, |
|
pixel_values: Optional[torch.Tensor] = None, |
|
image_features: Optional[torch.Tensor] = None, |
|
concat_output_from_vision_encoder: Optional[Union[bool, list]] = None, |
|
concat_output_from_text_encoder: Optional[Union[bool, list]] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
): |
|
r""" |
|
Returns: |
|
|
|
""" |
|
|
|
if concat_output_from_vision_encoder is None: |
|
concat_output_from_vision_encoder = self.config.query_concat_output_from_vision_encoder |
|
|
|
if concat_output_from_text_encoder is None: |
|
concat_output_from_text_encoder = self.config.query_concat_output_from_text_encoder |
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
|
|
input_modality = [] |
|
if pixel_values is not None or image_features is not None: |
|
input_modality.append("image") |
|
if input_ids is not None and attention_mask is not None: |
|
input_modality.append("text") |
|
|
|
text_encoder_outputs = None |
|
vision_encoder_outputs = None |
|
transformer_mapping_outputs = None |
|
|
|
if "image" in input_modality: |
|
assert ( |
|
pixel_values is not None or image_features is not None |
|
), "pixel_values or image_features must be provided if image modality is used" |
|
assert ( |
|
pixel_values is None or image_features is None |
|
), "pixel_values and image_features cannot be provided at the same time" |
|
|
|
if "text" in input_modality: |
|
assert ( |
|
input_ids is not None and attention_mask is not None |
|
), "input_ids and attention_mask must be provided if text modality is used" |
|
|
|
input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device) |
|
text_encoder_outputs = self.query_text_encoder(input_ids, attention_mask=attention_mask) |
|
text_encoder_hidden_states = text_encoder_outputs[0] |
|
text_embeddings = self.query_text_encoder_linear(text_encoder_hidden_states) |
|
mask = torch.tensor(self.query_mask(input_ids, skiplist=self.config.query_mask_input_ids_skip_list), device=self.device).unsqueeze(2).float() |
|
|
|
text_embeddings = text_embeddings * mask |
|
|
|
if "image" in input_modality: |
|
if pixel_values is not None: |
|
batch_size = pixel_values.shape[0] |
|
|
|
pixel_values = pixel_values.to(self.device) |
|
if len(pixel_values.shape) == 5: |
|
|
|
|
|
pixel_values = pixel_values.reshape( |
|
-1, pixel_values.shape[2], pixel_values.shape[3], pixel_values.shape[4] |
|
) |
|
vision_encoder_outputs = self.query_vision_encoder(pixel_values, output_hidden_states=True) |
|
vision_embeddings = vision_encoder_outputs.last_hidden_state[:, 0] |
|
|
|
if image_features is not None: |
|
batch_size = image_features.shape[0] |
|
vision_embeddings = image_features.to(self.device) |
|
|
|
|
|
vision_embeddings = self.query_vision_projection(vision_embeddings) |
|
vision_embeddings = vision_embeddings.view(batch_size, -1, self.late_interaction_embedding_size) |
|
|
|
if self.config.use_transformer_mapping_network: |
|
|
|
vision_second_last_layer_hidden_states = vision_encoder_outputs.hidden_states[-2][:, 1:] |
|
|
|
transformer_mapping_input_features = self.transformer_mapping_input_linear( |
|
vision_second_last_layer_hidden_states |
|
) |
|
|
|
|
|
encoder_mask = torch.ones_like(mask).to(mask.device, dtype=mask.dtype) |
|
if len(self.config.query_mask_input_ids_skip_list) > 0: |
|
encoder_mask[torch.isin(input_ids, torch.tensor(self.config.query_mask_input_ids_skip_list))] = 0 |
|
cross_attention_length = self.config.transformer_mapping_cross_attention_length |
|
if text_encoder_hidden_states.shape[1] > cross_attention_length: |
|
text_encoder_hidden_states = text_encoder_hidden_states[:, :cross_attention_length] |
|
encoder_mask = encoder_mask[:, :cross_attention_length] |
|
|
|
|
|
encoder_extended_attention_mask = self.invert_attention_mask(encoder_mask.squeeze(-1)) |
|
|
|
transformer_mapping_outputs = self.transformer_mapping_network( |
|
transformer_mapping_input_features, |
|
encoder_hidden_states=text_encoder_hidden_states, |
|
encoder_attention_mask=encoder_extended_attention_mask, |
|
) |
|
transformer_mapping_output_features = transformer_mapping_outputs.last_hidden_state |
|
|
|
transformer_mapping_output_features = self.transformer_mapping_output_linear( |
|
transformer_mapping_output_features |
|
) |
|
|
|
vision_embeddings = torch.cat([vision_embeddings, transformer_mapping_output_features], dim=1) |
|
|
|
if concat_output_from_vision_encoder and concat_output_from_text_encoder: |
|
Q = torch.cat([text_embeddings, vision_embeddings], dim=1) |
|
if isinstance(concat_output_from_vision_encoder, list) or isinstance(concat_output_from_text_encoder, list): |
|
|
|
assert isinstance(concat_output_from_vision_encoder, list) and isinstance(concat_output_from_text_encoder, list), "concat_output_from_vision_encoder and concat_output_from_text_encoder must be of the same type." |
|
|
|
text_size = text_embeddings.shape[1] |
|
vision_size = vision_embeddings.shape[1] |
|
|
|
|
|
concat_output_mask = torch.zeros_like(Q).to(Q.device) |
|
|
|
|
|
concat_output_mask[:, :text_size] = torch.tensor(concat_output_from_text_encoder).bool().unsqueeze(-1).unsqueeze(-1) |
|
concat_output_mask[:, text_size:] = torch.tensor(concat_output_from_vision_encoder).bool().unsqueeze(-1).unsqueeze(-1) |
|
|
|
Q = Q * concat_output_mask |
|
|
|
elif concat_output_from_vision_encoder: |
|
Q = vision_embeddings |
|
elif concat_output_from_text_encoder: |
|
Q = text_embeddings |
|
|
|
vision_encoder_attentions = ( |
|
vision_encoder_outputs.attentions |
|
if vision_encoder_outputs is not None |
|
and hasattr(vision_encoder_outputs, "attentions") |
|
and output_attentions |
|
else None |
|
) |
|
vision_encoder_hidden_states = ( |
|
vision_encoder_outputs.hidden_states |
|
if vision_encoder_outputs is not None |
|
and hasattr(vision_encoder_outputs, "hidden_states") |
|
and output_hidden_states |
|
else None |
|
) |
|
text_encoder_attentions = ( |
|
text_encoder_outputs.attentions |
|
if text_encoder_outputs is not None and hasattr(text_encoder_outputs, "attentions") and output_attentions |
|
else None |
|
) |
|
text_encoder_hidden_states = ( |
|
text_encoder_outputs.hidden_states |
|
if text_encoder_outputs is not None |
|
and hasattr(text_encoder_outputs, "hidden_states") |
|
and output_hidden_states |
|
else None |
|
) |
|
transformer_mapping_network_attentions = ( |
|
transformer_mapping_outputs.attentions |
|
if transformer_mapping_outputs is not None |
|
and hasattr(transformer_mapping_outputs, "attentions") |
|
and output_attentions |
|
else None |
|
) |
|
transformer_mapping_network_hidden_states = ( |
|
transformer_mapping_outputs.hidden_states |
|
if transformer_mapping_outputs is not None |
|
and hasattr(transformer_mapping_outputs, "hidden_states") |
|
and output_hidden_states |
|
else None |
|
) |
|
|
|
return FLMRQueryEncoderOutput( |
|
pooler_output=Q[:, 0, :], |
|
late_interaction_output=torch.nn.functional.normalize(Q, p=2, dim=2), |
|
vision_encoder_attentions=vision_encoder_attentions, |
|
vision_encoder_hidden_states=vision_encoder_hidden_states, |
|
text_encoder_attentions=text_encoder_attentions, |
|
text_encoder_hidden_states=text_encoder_hidden_states, |
|
transformer_mapping_network_attentions=transformer_mapping_network_attentions, |
|
transformer_mapping_network_hidden_states=transformer_mapping_network_hidden_states, |
|
) |
|
|
|
@add_start_docstrings_to_model_forward(FLMR_MODEL_CONTEXT_INPUTS_DOCSTRING) |
|
@replace_return_docstrings(output_type=FLMRContextEncoderOutput, config_class=_CONFIG_FOR_DOC) |
|
def doc( |
|
self, |
|
input_ids: torch.Tensor, |
|
attention_mask: torch.Tensor, |
|
pixel_values: Optional[torch.Tensor] = None, |
|
image_features: Optional[torch.Tensor] = None, |
|
concat_output_from_vision_encoder: Optional[bool] = None, |
|
concat_output_from_text_encoder: Optional[bool] = None, |
|
keep_dims: Optional[bool] = True, |
|
return_mask: Optional[bool] = True, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
): |
|
r""" |
|
Returns: |
|
|
|
""" |
|
assert keep_dims in [True, False] |
|
|
|
if concat_output_from_vision_encoder is None: |
|
concat_output_from_vision_encoder = self.config.context_concat_output_from_vision_encoder |
|
|
|
if concat_output_from_text_encoder is None: |
|
concat_output_from_text_encoder = self.config.context_concat_output_from_text_encoder |
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
|
|
input_modality = [] |
|
if pixel_values is not None or image_features is not None: |
|
input_modality.append("image") |
|
if input_ids is not None and attention_mask is not None: |
|
input_modality.append("text") |
|
|
|
text_encoder_outputs = None |
|
vision_encoder_outputs = None |
|
|
|
if "image" in input_modality: |
|
assert ( |
|
pixel_values is not None or image_features is not None |
|
), "pixel_values or image_features must be provided if image modality is used" |
|
assert ( |
|
pixel_values is None or image_features is None |
|
), "pixel_values and image_features cannot be provided at the same time" |
|
|
|
if "text" in input_modality: |
|
assert ( |
|
input_ids is not None and attention_mask is not None |
|
), "input_ids and attention_mask must be provided if text modality is used" |
|
|
|
input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device) |
|
text_encoder_outputs = self.context_text_encoder(input_ids, attention_mask=attention_mask) |
|
text_embeddings = text_encoder_outputs[0] |
|
text_embeddings = self.context_text_encoder_linear(text_embeddings) |
|
|
|
mask = torch.tensor(self.mask(input_ids, skiplist=self.skiplist), device=self.device).unsqueeze(2).float() |
|
text_embeddings = text_embeddings * mask |
|
|
|
if "image" in input_modality: |
|
if pixel_values is not None: |
|
|
|
pixel_values = pixel_values.to(self.device) |
|
vision_encoder_outputs = self.context_vision_encoder(pixel_values) |
|
vision_embeddings = vision_encoder_outputs.last_hidden_state[:, 0] |
|
|
|
if image_features is not None: |
|
vision_embeddings = image_features.to(self.device) |
|
|
|
batch_size = vision_embeddings.shape[0] |
|
|
|
|
|
vision_embeddings = self.context_vision_projection(vision_embeddings) |
|
vision_embeddings = vision_embeddings.view( |
|
-1, self.mapping_network_prefix_length, self.late_interaction_embedding_size |
|
) |
|
|
|
image_mask = torch.ones(batch_size, vision_embeddings.shape[1], 1).to(self.device) |
|
|
|
if concat_output_from_vision_encoder and concat_output_from_text_encoder: |
|
|
|
|
|
D = torch.cat([vision_embeddings, text_embeddings], dim=1) |
|
|
|
mask = torch.cat([image_mask, mask], dim=1) |
|
elif concat_output_from_vision_encoder: |
|
D = vision_embeddings |
|
mask = image_mask |
|
elif concat_output_from_text_encoder: |
|
D = text_embeddings |
|
mask = mask |
|
|
|
D = torch.nn.functional.normalize(D, p=2, dim=2) |
|
|
|
if self.use_gpu: |
|
D = D.half() |
|
|
|
if keep_dims is False: |
|
D, mask = D.cpu(), mask.bool().cpu().squeeze(-1) |
|
D = [d[mask[idx]] for idx, d in enumerate(D)] |
|
|
|
vision_encoder_attentions = ( |
|
vision_encoder_outputs.attentions |
|
if vision_encoder_outputs is not None |
|
and hasattr(vision_encoder_outputs, "attentions") |
|
and output_attentions |
|
else None |
|
) |
|
vision_encoder_hidden_states = ( |
|
vision_encoder_outputs.hidden_states |
|
if vision_encoder_outputs is not None |
|
and hasattr(vision_encoder_outputs, "hidden_states") |
|
and output_hidden_states |
|
else None |
|
) |
|
text_encoder_attentions = ( |
|
text_encoder_outputs.attentions |
|
if text_encoder_outputs is not None and hasattr(text_encoder_outputs, "attentions") and output_attentions |
|
else None |
|
) |
|
text_encoder_hidden_states = ( |
|
text_encoder_outputs.hidden_states |
|
if text_encoder_outputs is not None |
|
and hasattr(text_encoder_outputs, "hidden_states") |
|
and output_hidden_states |
|
else None |
|
) |
|
|
|
return FLMRContextEncoderOutput( |
|
pooler_output=D[:, 0, :], |
|
late_interaction_output=D, |
|
context_mask=mask.bool() if return_mask else None, |
|
vision_encoder_attentions=vision_encoder_attentions, |
|
vision_encoder_hidden_states=vision_encoder_hidden_states, |
|
text_encoder_attentions=text_encoder_attentions, |
|
text_encoder_hidden_states=text_encoder_hidden_states, |
|
) |
|
|
|
def score(self, Q, D_padded, D_mask): |
|
|
|
|
|
|
|
|
|
return colbert_score(Q, D_padded, D_mask, use_gpu=self.use_gpu) |
|
|
|
def mask(self, input_ids, skiplist): |
|
mask = [[(x not in skiplist) and (x != 0) for x in d] for d in input_ids.cpu().tolist()] |
|
return mask |
|
|
|
|
|
@add_start_docstrings( |
|
"The bare FLMR text encoder that can be used to generate late-interaction embeddings for texts in queries and contexts. This model is based on a `BertModel`. It can be used like a `BertModel` model for encoding text.", |
|
FLMR_TEXT_ENCODERS_START_DOCSTRING, |
|
) |
|
class FLMRTextModel(FLMRPreTrainedModel): |
|
base_model_prefix = "flmr_text_model" |
|
config_class = FLMRTextConfig |
|
|
|
def __init__(self, config: FLMRTextConfig, *args, **kwargs): |
|
super().__init__(config) |
|
if config.text_encoder_base_model == "bert-base-uncased": |
|
self.bert_model = BertModel(config, add_pooling_layer=True) |
|
else: |
|
self.bert_model = AutoModel.from_pretrained(config.text_encoder_base_model, *args, **kwargs) |
|
if self.bert_model.config.hidden_size <= 0: |
|
raise ValueError("Encoder hidden_size can't be zero") |
|
self.projection_dim = config.projection_dim |
|
if self.projection_dim > 0: |
|
self.encode_proj = nn.Linear(self.bert_model.config.hidden_size, config.projection_dim) |
|
|
|
self.post_init() |
|
self.text_model = self.bert_model |
|
|
|
@add_start_docstrings_to_model_forward(FLMR_TEXT_ENCODERS_INPUTS_DOCSTRING) |
|
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=FLMRTextConfig) |
|
def forward( |
|
self, |
|
input_ids: Optional[Tensor] = None, |
|
attention_mask: Optional[Tensor] = None, |
|
token_type_ids: Optional[Tensor] = None, |
|
inputs_embeds: Optional[Tensor] = None, |
|
output_attentions: bool = None, |
|
output_hidden_states: bool = None, |
|
return_dict: bool = None, |
|
) -> Union[BaseModelOutputWithPooling, Tuple[Tensor, ...]]: |
|
r""" |
|
Returns: |
|
|
|
""" |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
outputs = self.text_model( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
token_type_ids=token_type_ids, |
|
inputs_embeds=inputs_embeds, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
sequence_output = outputs[0] |
|
pooled_output = sequence_output[:, 0, :] |
|
|
|
if self.projection_dim > 0: |
|
pooled_output = self.encode_proj(pooled_output) |
|
|
|
if not return_dict: |
|
return (sequence_output, pooled_output) + outputs[2:] |
|
|
|
return BaseModelOutputWithPooling( |
|
last_hidden_state=sequence_output, |
|
pooler_output=pooled_output, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
@property |
|
def embeddings_size(self) -> int: |
|
if self.projection_dim > 0: |
|
return self.encode_proj.out_features |
|
return self.text_model.config.hidden_size |
|
|
|
|
|
@add_start_docstrings( |
|
"The bare FLMR vision encoder that can be used to generate late-interaction embeddings for images in queries and contexts. This model is based on a `CLIPVisionModel`. It can be used like a `CLIPVisionModel` model for encoding images.", |
|
FLMR_VISION_ENCODERS_START_DOCSTRING, |
|
) |
|
class FLMRVisionModel(FLMRPreTrainedModel): |
|
base_model_prefix = "vision_model" |
|
config_class = FLMRVisionConfig |
|
main_input_name = "pixel_values" |
|
_no_split_modules = ["CLIPEncoderLayer"] |
|
|
|
def __init__(self, config: FLMRVisionConfig): |
|
super().__init__(config) |
|
self.vision_model = CLIPVisionModel(config) |
|
self.post_init() |
|
|
|
def get_input_embeddings(self) -> nn.Module: |
|
return self.vision_model.vision_model.embeddings.patch_embedding |
|
|
|
@add_start_docstrings_to_model_forward(FLMR_VISION_ENCODERS_INPUTS_DOCSTRING) |
|
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=FLMRVisionConfig) |
|
def forward( |
|
self, |
|
pixel_values: Optional[torch.FloatTensor] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[Tuple, BaseModelOutputWithPooling]: |
|
r""" |
|
Returns: |
|
|
|
Examples: |
|
|
|
```python |
|
>>> from PIL import Image |
|
>>> import requests |
|
>>> from transformers import AutoProcessor, FLMRVisionModel |
|
|
|
>>> model = FLMRVisionModel.from_pretrained("openai/clip-vit-base-patch32") |
|
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") |
|
|
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" |
|
>>> image = Image.open(requests.get(url, stream=True).raw) |
|
|
|
>>> inputs = processor(images=image, return_tensors="pt") |
|
|
|
>>> outputs = model(**inputs) |
|
>>> last_hidden_state = outputs.last_hidden_state |
|
>>> pooled_output = outputs.pooler_output # pooled CLS states |
|
```""" |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
return self.vision_model( |
|
pixel_values=pixel_values, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|