LinWeizheDragon committed on
Commit
66ae8fc
1 Parent(s): 4430fe1

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ context_tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
37
+ query_tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,84 @@
1
+ ---
2
+ library_name: transformers
3
+ license: mit
4
+ language:
5
+ - en
6
+ - zh
7
+ tags:
8
+ - retrieval
9
+ - multi-modal
10
+ - knowledge-based visual question answering
11
+ - FLMR
12
+ - PreFLMR
13
+ ---
14
+
15
+ # PreFLMR model card
16
+
17
+ PreFLMR is an open-source model for multimodal knowledge retrieval. It is a transformer-based model that uses a combination of text and image inputs to retrieve relevant documents from a large corpus.
18
+
19
+ ## Model Details
20
+
21
+ PreFLMR_ViT-L_ENCN is based on PreFLMR_ViT-L, with the text encoder replaced by [bge-m3](https://huggingface.co/BAAI/bge-m3) for training. The training data includes both the [Chinese](https://huggingface.co/datasets/BByrneLab/multi_task_multi_modal_knowledge_retrieval_benchmark_M2KR_CN) and [English](https://huggingface.co/datasets/BByrneLab/multi_task_multi_modal_knowledge_retrieval_benchmark_M2KR) M2KR datasets.
22
+
23
+ ### Model Description
24
+
25
+ - **Model type:** FLMRModelForRetrieval
26
+ - **Language(s) (NLP):** English, Chinese
27
+ - **License:** MIT License
28
+
29
+ ### Paper and resources for more details
30
+
31
+ - **Blog Post for quick overview:** https://www.jinghong-chen.net/preflmr-sota-open-sourced-multi/
32
+ - **Paper:** https://arxiv.org/abs/2402.08327
33
+ - **Gradio Demo:** https://u60544-b8d4-53eaa55d.westx.seetacloud.com:8443/
34
+ - **Repository:** https://github.com/LinWeizheDragon/FLMR
35
+ - **Project Page:** https://preflmr.github.io/
36
+
37
+ ## Uses
38
+
39
+ ### Direct Use
40
+
41
+ This model can be used directly to retrieve documents from a large corpus using a combination of text and image queries. Retrieval usage examples can be found in the [official implementation](https://github.com/LinWeizheDragon/FLMR).
42
+
43
+ ### Downstream Use
44
+
45
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
46
+
47
+ This model can be combined with language models to create a retrieval-augmented language model. Its use for knowledge-based VQA can be found in [RAVQA](https://github.com/linweizhedragon/retrieval-augmented-visual-question-answering).
48
+
49
+ ## How to Get Started with the Model
50
+
51
+ For details on training, indexing, and performing retrieval, please refer to the [official repository](https://github.com/LinWeizheDragon/FLMR). A minimal loading sketch is shown below.
52
+
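+ The sketch below is illustrative only and has not been tested; it assumes the custom modeling code shipped in this repository (registered through `auto_map` in `config.json`) is acceptable to execute via `trust_remote_code=True`. The full training, indexing, and retrieval pipeline lives in the official repository above.
+
+ ```python
+ # Minimal loading sketch (untested); refer to the official FLMR repository for
+ # the authoritative training, indexing, and retrieval instructions.
+ from transformers import AutoConfig, AutoModel
+
+ checkpoint = "LinWeizheDragon/PreFLMR_ViT-L_ENCN"
+ config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
+ model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)
+
+ # The query and context tokenizers are stored in the `query_tokenizer/` and
+ # `context_tokenizer/` subfolders of this repository.
+ ```
+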
53
+ ## Training datasets
54
+
55
+ The model is pre-trained on three types of tasks with a total of nine datasets:
56
+ 1. Image to Text retrieval: WIT, KVQA, and CC3M
57
+ 2. Question to Text retrieval: MSMARCO
58
+ 3. Image & Question to Text retrieval: LLaVA, OVEN, OKVQA, Infoseek and E-VQA
59
+
60
+ These datasets were converted to retrieval format. For details on the dataset splits and conversion process, please refer to the paper [PreFLMR: Scaling Up Fine-Grained Late-Interaction Multi-modal Retrievers](https://arxiv.org/abs/2402.08327). We will release the preprocessed datasets soon.
61
+
62
+
63
+ ## Evaluation datasets
64
+ We evaluate our models on WIT, LLaVA, OVEN, KVQA, IGLUE (subset of WIT), Infoseek, E-VQA, OKVQA and MSMARCO.
65
+ | Model | Vision Encoder | Text Encoder | Checkpoint Name | No. Param. | WIT(EN) | WIT(CN) | LLaVA(EN) | LLaVA(CN) | OVEN(EN) | OVEN(CN) | KVQA(EN) | KVQA(CN) | IGLUE(EN) | Infoseek(EN) | Infoseek(CN) | EVQA(EN) | EVQA(CN) | OKVQA(EN) | OKVQA(CN) | MSMARCO(EN) | MSMARCO(CN) |
66
+ | ------- | :------------- | ------------ | ------------------------------------------------------------ | ---------- | ------- | ------- | --------- | --------- | -------- | -------- | -------- | -------- | --------- | ------------ | ------------ | -------- | -------- | --------- | --------- | ----------- | ----------- |
67
+ | PreFLMR | ViT-L | Base-v2 | [LinWeizheDragon/PreFLMR_ViT-L](https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L) | 543M | 60.5 | 10.9 | 71.8 | 3.2 | 59.8 | 6.6 | 43.6 | 3.2 | 69.2 | 57.9 | 7.9 | 70.8 | 2.8 | 68.5 | 2.1 | 78.7 | 10.3 |
68
+ | PreFLMR | ViT-L_ENCN | bge-m3 | [LinWeizheDragon/PreFLMR_ViT-L_ENCN](https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L_ENCN) | 883M | 60.8 | 83.4 | 71.11 | 58.93 | 60.8 | 58.83 | 41.05 | 37.27 | | 41.91 | 39.70 | 57.97 | 46.64 | 13.87 | 13.32 | 82.6 | 82.33 |
69
+
70
+ For the evaluation metrics, WIT uses Recall@10, IGLUE uses Recall@1, and all remaining datasets use Recall@5.
71
+
72
+
73
+ ## Citation
74
+
75
+ **BibTeX:**
76
+ ```
77
+ @article{Lin_Mei_Chen_Byrne_2024,
78
+ title={PreFLMR: Scaling Up Fine-Grained Late-Interaction Multi-modal Retrievers},
79
+ url={http://arxiv.org/abs/2402.08327},
80
+ number={arXiv:2402.08327},
81
+ publisher={arXiv},
82
+ author={Lin, Weizhe and Mei, Jingbiao and Chen, Jinghong and Byrne, Bill},
83
+ year={2024}}
84
+ ```
config.json ADDED
@@ -0,0 +1,56 @@
1
+ {
2
+ "_name_or_path": "LinWeizheDragon/PreFLMR_ViT-L_ENCN",
3
+ "architectures": [
4
+ "FLMRModelForRetrieval"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_flmr.FLMRConfig",
8
+ "AutoModel": "modeling_flmr.FLMRModelForRetrieval"
9
+ },
10
+ "context_concat_output_from_text_encoder": true,
11
+ "context_concat_output_from_vision_encoder": false,
12
+ "dim": 128,
13
+ "initializer_range": 0.02,
14
+ "load_cpu_extension": true,
15
+ "mapping_network_prefix_length": 32,
16
+ "mask_instruction_token": "\n:",
17
+ "mask_punctuation": true,
18
+ "model_type": "flmr",
19
+ "query_concat_output_from_text_encoder": true,
20
+ "query_concat_output_from_vision_encoder": true,
21
+ "query_mask_input_ids_skip_list": [
22
+ 1
23
+ ],
24
+ "separate_query_and_context_text_encoder": false,
25
+ "separate_query_and_context_vision_encoder": false,
26
+ "text_config": {
27
+ "architectures": [
28
+ "XLMRobertaModel"
29
+ ],
30
+ "gradient_checkpointing": false,
31
+ "hidden_size": 1024,
32
+ "model_type": "flmr_text_model",
33
+ "projection_dim": 1024,
34
+ "text_encoder_base_model": "BAAI/bge-m3",
35
+ "query_maxlen": 64,
36
+ "use_cache": true
37
+ },
38
+ "torch_dtype": "float32",
39
+ "transformer_mapping_config_base": "google-bert/bert-large-uncased",
40
+ "transformer_mapping_cross_attention_length": 32,
41
+ "transformer_mapping_num_hidden_layers": 1,
42
+ "transformers_version": "4.44.2",
43
+ "use_transformer_mapping_network": false,
44
+ "use_vision_encoder": true,
45
+ "vision_config": {
46
+ "dropout": 0.0,
47
+ "hidden_size": 1024,
48
+ "intermediate_size": 4096,
49
+ "model_type": "flmr_vision_model",
50
+ "num_attention_heads": 16,
51
+ "num_hidden_layers": 24,
52
+ "patch_size": 14,
53
+ "projection_dim": 768
54
+ },
55
+ "vision_model_version": "openai/clip-vit-large-patch14"
56
+ }
configuration_flmr.py ADDED
@@ -0,0 +1,397 @@
1
+ # coding=utf-8
2
+ # Copyright 2010, FLMR authors, The Hugging Face Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ FLMR model configuration"""
16
+
17
+ import os
18
+ from typing import Union
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ FLMR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
+ "LinWeizheDragon/PreFLMR_ViT-L": "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/config.json",
28
+ "LinWeizheDragon/FLMR": "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/config.json",
29
+ }
30
+
31
+
32
+ # Modified from transformers.models.clip.configuration_clip.CLIPVisionConfig with CLIP -> FLMR
33
+ class FLMRVisionConfig(PretrainedConfig):
34
+ r"""
35
+ This is the configuration class to store the configuration of a [`FLMRVisionModel`]. It is used to instantiate a
36
+ FLMR vision encoder according to the specified arguments, defining the model architecture. Instantiating a
37
+ configuration with the defaults will yield a similar configuration to that of the vision encoder used by FLMR, the CLIP
38
+ [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) architecture.
39
+
40
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
41
+ documentation from [`PretrainedConfig`] for more information.
42
+
43
+ Args:
44
+ hidden_size (`int`, *optional*, defaults to 768):
45
+ Dimensionality of the encoder layers and the pooler layer.
46
+ intermediate_size (`int`, *optional*, defaults to 3072):
47
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
48
+ projection_dim (`int`, *optional*, defaults to 512):
49
+ Dimensionality of text and vision projection layers.
50
+ num_hidden_layers (`int`, *optional*, defaults to 12):
51
+ Number of hidden layers in the Transformer encoder.
52
+ num_attention_heads (`int`, *optional*, defaults to 12):
53
+ Number of attention heads for each attention layer in the Transformer encoder.
54
+ num_channels (`int`, *optional*, defaults to 3):
55
+ The number of input channels.
56
+ image_size (`int`, *optional*, defaults to 224):
57
+ The size (resolution) of each image.
58
+ patch_size (`int`, *optional*, defaults to 32):
59
+ The size (resolution) of each patch.
60
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
61
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
62
+ `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
63
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
64
+ The epsilon used by the layer normalization layers.
65
+ attention_dropout (`float`, *optional*, defaults to 0.0):
66
+ The dropout ratio for the attention probabilities.
67
+ initializer_range (`float`, *optional*, defaults to 0.02):
68
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
69
+ initializer_factor (`float`, *optional*, defaults to 1.0):
70
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
71
+ testing).
72
+
73
+ Example:
74
+
75
+ ```python
76
+ >>> from transformers import FLMRVisionConfig, FLMRVisionModel
77
+
78
+ >>> # Initializing a FLMRVisionConfig with LinWeizheDragon/FLMR style configuration
79
+ >>> configuration = FLMRVisionConfig()
80
+
81
+ >>> # Initializing a FLMRVisionModel (with random weights) from the LinWeizheDragon/FLMR style configuration
82
+ >>> model = FLMRVisionModel(configuration)
83
+
84
+ >>> # Accessing the model configuration
85
+ >>> configuration = model.config
86
+ ```"""
87
+
88
+ model_type = "flmr_vision_model"
89
+
90
+ def __init__(
91
+ self,
92
+ hidden_size=768,
93
+ intermediate_size=3072,
94
+ projection_dim=512,
95
+ num_hidden_layers=12,
96
+ num_attention_heads=12,
97
+ num_channels=3,
98
+ image_size=224,
99
+ patch_size=32,
100
+ hidden_act="quick_gelu",
101
+ layer_norm_eps=1e-5,
102
+ attention_dropout=0.0,
103
+ initializer_range=0.02,
104
+ initializer_factor=1.0,
105
+ **kwargs,
106
+ ):
107
+ super().__init__(**kwargs)
108
+
109
+ self.hidden_size = hidden_size
110
+ self.intermediate_size = intermediate_size
111
+ self.projection_dim = projection_dim
112
+ self.num_hidden_layers = num_hidden_layers
113
+ self.num_attention_heads = num_attention_heads
114
+ self.num_channels = num_channels
115
+ self.patch_size = patch_size
116
+ self.image_size = image_size
117
+ self.initializer_range = initializer_range
118
+ self.initializer_factor = initializer_factor
119
+ self.attention_dropout = attention_dropout
120
+ self.layer_norm_eps = layer_norm_eps
121
+ self.hidden_act = hidden_act
122
+
123
+ @classmethod
124
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
125
+ cls._set_token_in_kwargs(kwargs)
126
+
127
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
128
+
129
+ # get the vision config dict if we are loading from a CLIPConfig
130
+ if config_dict.get("model_type") == "clip":
131
+ config_dict = config_dict["vision_config"]
132
+
133
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
134
+ logger.warning(
135
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
136
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
137
+ )
138
+
139
+ return cls.from_dict(config_dict, **kwargs)
140
+
141
+
142
+ # Modified from transformers.models.dpr.configuration_dpr.DPRConfig with DPR -> FLMR
143
+ class FLMRTextConfig(PretrainedConfig):
144
+ r"""
145
+ [`FLMRTextConfig`] is the configuration class to store the configuration of a *FLMRTextModel*.
146
+
147
+ This is the configuration class to store the configuration of a [`FLMRTextModel`]. It is used to instantiate the components of the FLMR model according to the specified arguments,
148
+ defining the model component architectures. Instantiating a configuration with the defaults will yield a similar
149
+ configuration to that of the DPRContextEncoder
150
+ [facebook/dpr-ctx_encoder-single-nq-base](https://huggingface.co/facebook/dpr-ctx_encoder-single-nq-base)
151
+ architecture.
152
+
153
+ This class is a subclass of [`BertConfig`]. Please check the superclass for the documentation of all kwargs.
154
+
155
+ Args:
156
+ vocab_size (`int`, *optional*, defaults to 30522):
157
+ Vocabulary size of the FLMR model. Defines the different tokens that can be represented by the *inputs_ids*
158
+ passed to the forward method of [`BertModel`].
159
+ hidden_size (`int`, *optional*, defaults to 768):
160
+ Dimensionality of the encoder layers and the pooler layer.
161
+ num_hidden_layers (`int`, *optional*, defaults to 12):
162
+ Number of hidden layers in the Transformer encoder.
163
+ num_attention_heads (`int`, *optional*, defaults to 12):
164
+ Number of attention heads for each attention layer in the Transformer encoder.
165
+ intermediate_size (`int`, *optional*, defaults to 3072):
166
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
167
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
168
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
169
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
170
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
171
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
172
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
173
+ The dropout ratio for the attention probabilities.
174
+ max_position_embeddings (`int`, *optional*, defaults to 512):
175
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
176
+ just in case (e.g., 512 or 1024 or 2048).
177
+ type_vocab_size (`int`, *optional*, defaults to 2):
178
+ The vocabulary size of the *token_type_ids* passed into [`BertModel`].
179
+ initializer_range (`float`, *optional*, defaults to 0.02):
180
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
181
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
182
+ The epsilon used by the layer normalization layers.
183
+ pad_token_id (`int`, *optional*, defaults to 0):
184
+ Padding token id.
185
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
186
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
187
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
188
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
189
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
190
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
191
+ projection_dim (`int`, *optional*, defaults to 0):
192
+ Dimension of the projection for the context and question encoders. If it is set to zero (default), then no
193
+ projection is done.
194
+ text_encoder_base_model (`str`, *optional*, defaults to `"bert-base-uncased"`):
195
+ The base text encoder that the FLMR text encoder is built on.
196
+ query_maxlen (`int`, *optional*, defaults to 32):
197
+ The maximum length used when encoding queries with the query tokenizer.
198
+
199
+ Example:
200
+
201
+ ```python
202
+ >>> from transformers import FLMRTextConfig, FLMRTextModel
203
+
204
+ >>> # Initializing a FLMR LinWeizheDragon/FLMR style configuration
205
+ >>> configuration = FLMRTextConfig()
206
+
207
+ >>> # Initializing a model (with random weights) from the LinWeizheDragon/FLMR style configuration
208
+ >>> model = FLMRTextModel(configuration)
209
+
210
+ >>> # Accessing the model configuration
211
+ >>> configuration = model.config
212
+ ```"""
213
+
214
+ model_type = "flmr_text_model"
215
+
216
+ def __init__(
217
+ self,
218
+ vocab_size=30522,
219
+ hidden_size=768,
220
+ num_hidden_layers=12,
221
+ num_attention_heads=12,
222
+ intermediate_size=3072,
223
+ hidden_act="gelu",
224
+ hidden_dropout_prob=0.1,
225
+ attention_probs_dropout_prob=0.1,
226
+ max_position_embeddings=512,
227
+ type_vocab_size=2,
228
+ initializer_range=0.02,
229
+ layer_norm_eps=1e-12,
230
+ pad_token_id=0,
231
+ position_embedding_type="absolute",
232
+ projection_dim: int = 0,
233
+ text_encoder_base_model="bert-base-uncased",
234
+ query_maxlen: int = 32,
235
+ **kwargs,
236
+ ):
237
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
238
+
239
+ self.vocab_size = vocab_size
240
+ self.hidden_size = hidden_size
241
+ self.num_hidden_layers = num_hidden_layers
242
+ self.num_attention_heads = num_attention_heads
243
+ self.hidden_act = hidden_act
244
+ self.intermediate_size = intermediate_size
245
+ self.hidden_dropout_prob = hidden_dropout_prob
246
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
247
+ self.max_position_embeddings = max_position_embeddings
248
+ self.type_vocab_size = type_vocab_size
249
+ self.initializer_range = initializer_range
250
+ self.layer_norm_eps = layer_norm_eps
251
+ self.projection_dim = projection_dim
252
+ self.text_encoder_base_model = text_encoder_base_model
253
+ self.position_embedding_type = position_embedding_type
254
+ self.query_maxlen = query_maxlen
255
+
256
+
257
+ class FLMRConfig(PretrainedConfig):
258
+ r"""
259
+ [`FLMRConfig`] is the configuration class to store the configuration of a *FLMRModelForRetrieval*.
260
+ This is the configuration class to store the configuration of a [`FLMRModelForRetrieval`]. It is used to instantiate the components of the FLMR model according to the specified arguments,
261
+ defining the model component architectures. Instantiating a configuration with the defaults will yield a similar
262
+ configuration to that of the FLMR
263
+ [LinWeizheDragon/PreFLMR_ViT-G](https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-G)
264
+ architecture.
265
+
266
+ Args:
267
+ vision_config (`FLMRVisionConfig`, *optional*):
268
+ Configuration for the vision encoder.
269
+ text_config (`FLMRTextConfig`, *optional*):
270
+ Configuration for the text encoder.
271
+ mask_punctuation (`bool`, *optional*, defaults to `True`):
272
+ Whether to mask punctuation tokens in the input.
273
+ mapping_network_prefix_length (`int`, *optional*, defaults to 32):
274
+ The output length of the linear mapping network.
275
+ dim (`int`, *optional*, defaults to 128):
276
+ The late-interaction dimension of the model. The outputs of the text encoder, vision encoder, and transformer mapping network are all projected to this dimension for late-interaction scoring.
277
+ use_vision_encoder (`bool`, *optional*, defaults to `True`):
278
+ Whether to load the vision encoder. When no vision encoder is loaded, `image_features` should be used in the forward pass rather than `pixel_values`.
279
+ initializer_range (`float`, *optional*, defaults to 0.02):
280
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
281
+ separate_query_and_context_text_encoder (`bool`, *optional*, defaults to `False`):
282
+ Whether to use separate text encoders for query and context.
283
+ separate_query_and_context_vision_encoder (`bool`, *optional*, defaults to `False`):
284
+ Whether to use separate vision encoders for query and context.
285
+ query_concat_output_from_vision_encoder (`bool`, *optional*, defaults to `True`):
286
+ Whether to concatenate the output from the vision encoder to the output from the text encoder for the query.
287
+ query_concat_output_from_text_encoder (`bool`, *optional*, defaults to `True`):
288
+ Whether to concatenate the output from the text encoder to the output from the vision encoder for the query.
289
+ context_concat_output_from_vision_encoder (`bool`, *optional*, defaults to `False`):
290
+ Whether to concatenate the output from the vision encoder to the output from the text encoder for the context.
291
+ context_concat_output_from_text_encoder (`bool`, *optional*, defaults to `True`):
292
+ Whether to concatenate the output from the text encoder to the output from the vision encoder for the context.
293
+ use_transformer_mapping_network (`bool`, *optional*, defaults to `False`):
294
+ Whether to add a transformer mapping network to map the features from the vision encoder to the embedding space. This option is used in PreFLMR.
295
+ transformer_mapping_config_base (`str`, *optional*):
296
+ The base configuration for the transformer mapping network. This option is used in PreFLMR. An example of this argument is `bert-base-uncased`.
297
+ transformer_mapping_num_hidden_layers (`int`, *optional*):
298
+ The number of hidden layers in the transformer mapping network. This option is used in PreFLMR.
299
+ load_cpu_extension (`bool`, *optional*, defaults to `False`):
300
+ Whether to load the CPU extension. Only set this to `True` if a CPU is used in training and inference. In any case, GPU is recommended for training and inference.
301
+ mask_instruction_token (`str`, *optional*):
302
+ The token that indicates the end of the input instruction. All tokens before this token (the first one in a sequence) will be masked. This option is used in PreFLMR.
303
+ transformer_mapping_cross_attention_length (`int`, *optional*, defaults to 32):
304
+ The length of the cross attention in the transformer mapping network. This option is used in PreFLMR.
305
+ vision_model_version (`str`, *optional*, defaults to `"openai/clip-vit-base-patch32"`):
306
+ The version of the vision model being used in this FLMR model.
307
+ This option is used in performing retrieval only. Though it does not affect the model architecture, it is highly recommended to set this argument so that it properly reflects the version of the vision model being used in the FLMR model. This argument will be saved in the model configuration, and it can be read by the indexing engine. The indexing engine will use this argument to initialize an image processor, which can process the input image files. Find more details under `examples/research_projects/flmr-retrieval`.
308
+ query_mask_input_ids_skip_list (`List`, *optional*, defaults to `[]`):
309
+ The input IDs to skip when applying the query mask.
310
+
311
+ Example:
312
+
313
+ ```python
314
+ >>> from transformers import FLMRConfig, FLMRModelForRetrieval
315
+
316
+ >>> # Initializing a FLMR LinWeizheDragon/FLMR style configuration
317
+ >>> configuration = FLMRConfig()
318
+
319
+ >>> # Initializing a model (with random weights) from the FLMR style configuration
320
+ >>> model = FLMRModelForRetrieval(configuration)
321
+
322
+ >>> # Accessing the model configuration
323
+ >>> configuration = model.config
324
+ ```"""
325
+
326
+ model_type = "flmr"
327
+
328
+ def __init__(
329
+ self,
330
+ vision_config: FLMRVisionConfig = None,
331
+ text_config: FLMRTextConfig = None,
332
+ mask_punctuation: bool = True,
333
+ mapping_network_prefix_length: int = 32,
334
+ dim: int = 128,
335
+ use_vision_encoder: bool = True,
336
+ initializer_range: float = 0.02,
337
+ separate_query_and_context_text_encoder: bool = False,
338
+ separate_query_and_context_vision_encoder: bool = False,
339
+ query_concat_output_from_vision_encoder: bool = True,
340
+ query_concat_output_from_text_encoder: bool = True,
341
+ context_concat_output_from_vision_encoder: bool = False,
342
+ context_concat_output_from_text_encoder: bool = True,
343
+ use_transformer_mapping_network: bool = False,
344
+ transformer_mapping_config_base: str = None,
345
+ transformer_mapping_num_hidden_layers: int = None,
346
+ load_cpu_extension: bool = False,
347
+ mask_instruction_token: str = None,
348
+ transformer_mapping_cross_attention_length: int = 32,
349
+ vision_model_version: str = "openai/clip-vit-base-patch32",
350
+ query_mask_input_ids_skip_list: list = [],
351
+ **kwargs,
352
+ ):
353
+ super().__init__(**kwargs)
354
+
355
+ if vision_config is None:
356
+ vision_config = {}
357
+ if text_config is None:
358
+ text_config = {}
359
+
360
+ if not isinstance(vision_config, FLMRVisionConfig):
361
+ vision_config = FLMRVisionConfig(**vision_config)
362
+ if not isinstance(text_config, FLMRTextConfig):
363
+ text_config = FLMRTextConfig(**text_config)
364
+
365
+ self.vision_config = vision_config
366
+ self.text_config = text_config
367
+ self.dim = dim
368
+ self.initializer_range = initializer_range
369
+ self.mask_punctuation = mask_punctuation
370
+ self.mapping_network_prefix_length = mapping_network_prefix_length
371
+ self.use_vision_encoder = use_vision_encoder
372
+ self.separate_query_and_context_text_encoder = separate_query_and_context_text_encoder
373
+ self.separate_query_and_context_vision_encoder = separate_query_and_context_vision_encoder
374
+ self.query_concat_output_from_vision_encoder = query_concat_output_from_vision_encoder
375
+ self.query_concat_output_from_text_encoder = query_concat_output_from_text_encoder
376
+ self.context_concat_output_from_vision_encoder = context_concat_output_from_vision_encoder
377
+ self.context_concat_output_from_text_encoder = context_concat_output_from_text_encoder
378
+ self.use_transformer_mapping_network = use_transformer_mapping_network
379
+ self.transformer_mapping_config_base = transformer_mapping_config_base
380
+ self.transformer_mapping_num_hidden_layers = transformer_mapping_num_hidden_layers
381
+ self.load_cpu_extension = load_cpu_extension
382
+ self.mask_instruction_token = mask_instruction_token
383
+ self.transformer_mapping_cross_attention_length = transformer_mapping_cross_attention_length
384
+ self.vision_model_version = vision_model_version
385
+ self.query_mask_input_ids_skip_list = query_mask_input_ids_skip_list
386
+
387
+ @classmethod
388
+ def from_text_vision_configs(cls, text_config: FLMRTextConfig, vision_config: FLMRVisionConfig, **kwargs):
389
+ r"""
390
+ Instantiate a [`FLMRConfig`] (or a derived class) from FLMR text model configuration and FLMR vision model
391
+ configuration.
392
+
393
+ Returns:
394
+ [`FLMRConfig`]: An instance of a configuration object
395
+ """
396
+
397
+ return cls(text_config=text_config, vision_config=vision_config, **kwargs)
context_tokenizer/sentencepiece.bpe.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cfc8146abe2a0488e9e2a0c56de7952f7c11ab059eca145a0a727afce0db2865
3
+ size 5069051
context_tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "cls_token": {
10
+ "content": "<s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "eos_token": {
17
+ "content": "</s>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "mask_token": {
24
+ "content": "<mask>",
25
+ "lstrip": true,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": {
31
+ "content": "<pad>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "sep_token": {
38
+ "content": "</s>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "<unk>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
context_tokenizer/tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:249df0778f236f6ece390de0de746838ef25b9d6954b68c2ee71249e0a9d8fd4
3
+ size 17082799
context_tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,55 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<s>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<pad>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "</s>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "250001": {
36
+ "content": "<mask>",
37
+ "lstrip": true,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "bos_token": "<s>",
45
+ "clean_up_tokenization_spaces": true,
46
+ "cls_token": "<s>",
47
+ "eos_token": "</s>",
48
+ "mask_token": "<mask>",
49
+ "model_max_length": 8192,
50
+ "pad_token": "<pad>",
51
+ "sep_token": "</s>",
52
+ "sp_model_kwargs": {},
53
+ "tokenizer_class": "XLMRobertaTokenizer",
54
+ "unk_token": "<unk>"
55
+ }
flmr_utils.py ADDED
@@ -0,0 +1,77 @@
1
+ """
2
+ This file contains utility functions for the FLMR model. Some of these functions are adapted from the original ColBERT codebase.
3
+ """
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+
9
+ def get_rank():
10
+ return dist.get_rank()
11
+
12
+
13
+ def get_world_size():
14
+ return dist.get_world_size()
15
+
16
+
17
+ def get_default_group():
18
+ return dist.group.WORLD
19
+
20
+
21
+ # TODO: The masking below might also be applicable in the kNN part
22
+ def colbert_score_reduce(scores_padded, D_mask):
23
+ # print('D_mask', D_mask.shape, D_mask)
24
+ D_padding = ~D_mask.view(scores_padded.size(0), scores_padded.size(1)).bool()
25
+ # print('D_padding', D_padding.shape, D_padding)
26
+ # print(D_padding[0].tolist())
27
+ scores_padded[D_padding] = -9999
28
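+ # MaxSim: for each query token, keep its best-matching document-token score;
+ # these per-query-token maxima are summed over query tokens in the return below.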
+ scores = scores_padded.max(1).values
29
+
30
+ return scores.sum(-1)
31
+
32
+
33
+ def colbert_score(Q, D_padded, D_mask, use_gpu=False):
34
+ """
35
+ Supply sizes Q = (1 | num_docs, *, dim) and D = (num_docs, *, dim).
36
+ If Q.size(0) is 1, the matrix will be compared with all passages.
37
+ Otherwise, each query matrix will be compared against the *aligned* passage.
38
+
39
+ EVENTUALLY: Consider masking with -inf for the maxsim (or enforcing a ReLU).
40
+ """
41
+ if use_gpu:
42
+ Q, D_padded, D_mask = Q.cuda(), D_padded.cuda(), D_mask.cuda()
43
+ assert Q.dim() == 3, Q.size()
44
+ assert D_padded.dim() == 3, D_padded.size()
45
+ assert Q.size(0) in [1, D_padded.size(0)]
46
+
47
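+ # Token-level similarity matrix: (num_docs, doc_len, dim) @ (1 or num_docs, dim, query_len)
+ # -> (num_docs, doc_len, query_len), broadcasting over documents when Q has batch size 1.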
+ scores = D_padded @ Q.to(dtype=D_padded.dtype).permute(0, 2, 1)
48
+
49
+ return colbert_score_reduce(scores, D_mask)
50
+
51
+
52
+ def _sort_by_length(ids, mask, bsize, *args):
53
+ if ids.size(0) <= bsize:
54
+ return ids, mask, torch.arange(ids.size(0))
55
+
56
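+ # Sort examples by their true (unpadded) length so each batch groups sequences of
+ # similar length; reverse_indices lets the caller restore the original order.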
+ indices = mask.sum(-1).sort().indices
57
+ reverse_indices = indices.sort().indices
58
+
59
+ return_array = [ids[indices], mask[indices]]
60
+ for arg in args:
61
+ if isinstance(arg, torch.Tensor):
62
+ return_array.append(arg[indices])
63
+ else:
64
+ # arg is a list, and we want to sort the list according to indices
65
+ return_array.append([arg[i] for i in indices])
66
+
67
+ return *return_array, reverse_indices
68
+
69
+
70
+ def _split_into_batches(ids, mask, bsize, *args):
71
+ batches = []
72
+ for offset in range(0, ids.size(0), bsize):
73
+ batch = [ids[offset : offset + bsize], mask[offset : offset + bsize]]
74
+ for arg in args:
75
+ batch.append(arg[offset : offset + bsize])
76
+ batches.append(batch)
77
+ return batches
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59cf447aaa66331417b1b346d8fdde1c21f737884364140dafe8f518e2a01ec6
3
+ size 3530548760
modeling_flmr.py ADDED
@@ -0,0 +1,1527 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 FLMR Authors, The Hugging Face Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch FLMR model for Knowledge-intensive Visual Question Answering."""
16
+
17
+
18
+ import copy
19
+ import os
20
+ import pathlib
21
+ import string
22
+ from dataclasses import dataclass
23
+ from typing import Optional, Tuple, Union
24
+
25
+ import torch
26
+ import torch.distributed as dist
27
+ from torch import Tensor, nn
28
+ from torch.utils.cpp_extension import load
29
+
30
+ from transformers import AutoModel, AutoConfig
31
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
32
+ from transformers.modeling_utils import PreTrainedModel
33
+ from transformers.utils import (
34
+ ModelOutput,
35
+ add_start_docstrings,
36
+ add_start_docstrings_to_model_forward,
37
+ logging,
38
+ replace_return_docstrings,
39
+ )
40
+ from transformers.models.bert.modeling_bert import BertModel
41
+ from transformers.models.clip import CLIPVisionModel
42
+ from .configuration_flmr import FLMRConfig, FLMRTextConfig, FLMRVisionConfig
43
+ from .tokenization_flmr import FLMRQueryEncoderTokenizer, FLMRContextEncoderTokenizer
44
+ from .tokenization_flmr_fast import FLMRQueryEncoderTokenizerFast, FLMRContextEncoderTokenizerFast
45
+ from .flmr_utils import (
46
+ colbert_score,
47
+ colbert_score_reduce,
48
+ get_rank,
49
+ get_world_size,
50
+ )
51
+
52
+
53
+ logger = logging.get_logger(__name__)
54
+
55
+ _CONFIG_FOR_DOC = "FLMRConfig"
56
+ _CHECKPOINT_FOR_DOC = "LinWeizheDragon/PreFLMR_ViT-L"
57
+
58
+
59
+ FLMR_PRETRAINED_MODEL_ARCHIVE_LIST = [
60
+ "LinWeizheDragon/PreFLMR_ViT-L",
61
+ "LinWeizheDragon/FLMR",
62
+ # See all FLMR models at https://huggingface.co/models?filter=flmr
63
+ ]
64
+
65
+
66
+ ##########
67
+ # Outputs
68
+ ##########
69
+
70
+
71
+ @dataclass
72
+ class FLMRContextEncoderOutput(ModelOutput):
73
+ """
74
+ Class for outputs of the `doc()` function of [`FLMRModelForRetrieval`].
75
+
76
+ Args:
77
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`):
78
+ The FLMR encoder outputs the *pooler_output* that corresponds to the embedding of the first token of the context representation.
79
+ This output can be used to embed contexts for nearest-neighbor queries with query embeddings.
80
+ late_interaction_output (`torch.FloatTensor` of shape `(batch_size, context_embedding_length, embeddings_size)`):
81
+ The FLMR encoder outputs the *late_interaction_output* that corresponds to the context representation. The embeddings of all tokens are included for late interaction retrieval.
82
+ This output is to be used to embed contexts for late-interaction retrieval with query embeddings.
83
+ context_mask (`torch.FloatTensor` of shape `(batch_size, context_embedding_length)`):
84
+ The FLMR encoder outputs the *context_mask* that corresponds to the mask of the context representation.
85
+ text_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*):
86
+ Tuple of elements containing the attention weights of the text encoder's layers. Each element is a
87
+ tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
88
+ text_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
89
+ Tuple of elements containing the hidden states of the text encoder at each layer plus the initial embedding
90
+ outputs. Each tensor has a shape of `(batch_size, sequence_length, hidden_size)`.
91
+ vision_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*):
92
+ Tuple of elements containing the attention weights of the vision encoder's layers. Each element is a
93
+ tensor of shape `(batch_size, num_heads, vision_sequence_length, vision_sequence_length)`.
94
+ vision_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
95
+ Tuple of elements containing the hidden states of the vision encoder at each layer plus the initial embedding
96
+ outputs. Each tensor has a shape of `(batch_size, vision_sequence_length, hidden_size)`.
97
+ transformer_mapping_network_attentions (`Tuple[torch.FloatTensor]`, *optional*):
98
+ Tuple of elements containing the attention weights of the transformer mapping network's layers. Each element
99
+ is a tensor of shape `(batch_size, num_heads, mapping_sequence_length, mapping_sequence_length)`.
100
+ transformer_mapping_network_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
101
+ Tuple of elements containing the hidden states of the transformer mapping network at each layer plus the
102
+ initial embedding outputs. Each tensor has a shape of `(batch_size, mapping_sequence_length, hidden_size)`.
103
+ """
104
+
105
+ pooler_output: torch.FloatTensor
106
+ late_interaction_output: torch.FloatTensor = None
107
+ context_mask: torch.FloatTensor = None
108
+ text_encoder_attentions: Optional[Tuple[Tensor]] = None
109
+ text_encoder_hidden_states: Optional[Tuple[Tensor]] = None
110
+ vision_encoder_attentions: Optional[Tuple[Tensor]] = None
111
+ vision_encoder_hidden_states: Optional[Tuple[Tensor]] = None
112
+ transformer_mapping_network_attentions: Optional[Tuple[Tensor]] = None
113
+ transformer_mapping_network_hidden_states: Optional[Tuple[Tensor]] = None
114
+
115
+
116
+ @dataclass
117
+ class FLMRQueryEncoderOutput(ModelOutput):
118
+ """
119
+ Class for outputs of the `query()` function of [`FLMRModelForRetrieval`].
120
+
121
+ Args:
122
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`):
123
+ The FLMR encoder outputs the *pooler_output* that corresponds to the embedding of the first token of the query representation.
124
+ This output can be used to embed questions for nearest neighbors queries with context embeddings.
125
+ late_interaction_output (`torch.FloatTensor` of shape `(batch_size, query_embedding_length, embeddings_size)`):
126
+ The FLMR encoder outputs the *late_interaction_output* that corresponds to the question representation. The embeddings of all tokens are included for late interaction retrieval.
127
+ This output is to be used to embed questions for late-interaction retrieval with context embeddings.
128
+ text_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*):
129
+ Tuple of elements containing the attention weights of the text encoder's layers. Each element is a
130
+ tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
131
+ text_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
132
+ Tuple of elements containing the hidden states of the text encoder at each layer plus the initial embedding
133
+ outputs. Each tensor has a shape of `(batch_size, sequence_length, hidden_size)`.
134
+ vision_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*):
135
+ Tuple of elements containing the attention weights of the vision encoder's layers. Each element is a
136
+ tensor of shape `(batch_size, num_heads, vision_sequence_length, vision_sequence_length)`.
137
+ vision_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
138
+ Tuple of elements containing the hidden states of the vision encoder at each layer plus the initial embedding
139
+ outputs. Each tensor has a shape of `(batch_size, vision_sequence_length, hidden_size)`.
140
+ transformer_mapping_network_attentions (`Tuple[torch.FloatTensor]`, *optional*):
141
+ Tuple of elements containing the attention weights of the transformer mapping network's layers. Each element
142
+ is a tensor of shape `(batch_size, num_heads, mapping_sequence_length, mapping_sequence_length)`.
143
+ transformer_mapping_network_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
144
+ Tuple of elements containing the hidden states of the transformer mapping network at each layer plus the
145
+ initial embedding outputs. Each tensor has a shape of `(batch_size, mapping_sequence_length, hidden_size)`.
146
+ """
147
+
148
+ pooler_output: torch.FloatTensor
149
+ late_interaction_output: torch.FloatTensor = None
150
+ text_encoder_attentions: Optional[Tuple[Tensor]] = None
151
+ text_encoder_hidden_states: Optional[Tuple[Tensor]] = None
152
+ vision_encoder_attentions: Optional[Tuple[Tensor]] = None
153
+ vision_encoder_hidden_states: Optional[Tuple[Tensor]] = None
154
+ transformer_mapping_network_attentions: Optional[Tuple[Tensor]] = None
155
+ transformer_mapping_network_hidden_states: Optional[Tuple[Tensor]] = None
156
+
157
+
158
+ @dataclass
159
+ class FLMRModelForRetrievalOutput(ModelOutput):
160
+ """
161
+ Class for outputs of the forward pass of [`FLMRModelForRetrieval`].
162
+
163
+ Args:
164
+ loss (`torch.FloatTensor`):
165
+ contrastive loss of the input queries and positive and negative examples. This output is to be used in model training.
166
+ scores (`torch.FloatTensor` of shape `(batch_size, num_positive_examples + num_negative_examples)`):
167
+ The FLMR model outputs *scores* that correspond to the late-interaction scores of the input query and context. Each query is associated with `num_positive_examples` positive examples and `num_negative_examples` negative examples, and the scores are the late-interaction scores of the query and these examples.
168
+ in_batch_negative_loss (`torch.FloatTensor` of shape `(batch_size, query_embedding_length, embeddings_size)`):
169
+ The FLMR model outputs the *in_batch_negative_loss* which computes contrastive loss that includes in-batch negatives. For each positive example, all other examples in the batch except itself are considered negative examples in computing the contrastive loss. This improves ultimate performance in practice. This output is to be used in model training.
170
+ query_late_interaction_output (`torch.FloatTensor` of shape `(batch_size, query_embedding_length, embeddings_size)`):
171
+ The FLMR model outputs the *query_late_interaction_output* that corresponds to the late-interaction representations of the input query.
172
+ context_late_interaction_output (`torch.FloatTensor` of shape `(batch_size, context_embedding_length, embeddings_size)`):
173
+ The FLMR model outputs the *context_late_interaction_output* that corresponds to the late-interaction representations of the input context.
174
+ query_attentions (`Tuple[Tuple[Tensor]]`, *optional*):
175
+ Tuple of elements containing the attention weights of the query's layers. There are three sub-tuples in this tuple, corresponding to the attentions of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`, with `sequence_length` being the sequence length in the corresponding encoder.
176
+ query_hidden_states (`Tuple[Tuple[Tensor]]`, *optional*):
177
+ Tuple of elements containing the hidden states of the query's layers. There are three sub-tuples in this tuple, corresponding to the hidden states of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, sequence_length, hidden_size)`, with `sequence_length` being the sequence length in the corresponding encoder.
178
+ context_attentions (`Tuple[Tuple[Tensor]]`, *optional*):
179
+ Tuple of elements containing the attention weights of the context's layers. There are three sub-tuples in this tuple, corresponding to the attentions of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`, with `sequence_length` being the sequence length in the corresponding encoder.
180
+ context_hidden_states (`Tuple[Tuple[Tensor]]`, *optional*):
181
+ Tuple of elements containing the hidden states of the context's layers. There are three sub-tuples in this tuple, corresponding to the hidden states of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, sequence_length, hidden_size)`, with `sequence_length` being the sequence length in the corresponding encoder.
182
+ """
183
+
184
+ loss: torch.FloatTensor
185
+ scores: torch.FloatTensor = None
186
+ in_batch_negative_loss: torch.FloatTensor = None
187
+ query_late_interaction_output: torch.FloatTensor = None
188
+ context_late_interaction_output: torch.FloatTensor = None
189
+ query_attentions: Optional[Tuple[Tuple[Tensor]]] = None
190
+ query_hidden_states: Optional[Tuple[Tuple[Tensor]]] = None
191
+ context_attentions: Optional[Tuple[Tuple[Tensor]]] = None
192
+ context_hidden_states: Optional[Tuple[Tuple[Tensor]]] = None
193
+
194
+
195
+ class FLMRPreTrainedModel(PreTrainedModel):
196
+ def _init_weights(self, module):
197
+ """Initialize the weights"""
198
+ if isinstance(module, nn.Linear):
199
+ # Slightly different from the TF version which uses truncated_normal for initialization
200
+ # cf https://github.com/pytorch/pytorch/pull/5617
201
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
202
+ if module.bias is not None:
203
+ module.bias.data.zero_()
204
+ elif isinstance(module, nn.Embedding):
205
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
206
+ if module.padding_idx is not None:
207
+ module.weight.data[module.padding_idx].zero_()
208
+ elif isinstance(module, nn.LayerNorm):
209
+ module.bias.data.zero_()
210
+ module.weight.data.fill_(1.0)
211
+
212
+
213
+ ##################
214
+ # PreTrainedModel
215
+ ##################
216
+
217
+
218
+ class FLMRPretrainedModelForRetrieval(FLMRPreTrainedModel):
219
+ """
220
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
221
+ models.
222
+ """
223
+
224
+ config_class = FLMRConfig
225
+ load_tf_weights = None
226
+ base_model_prefix = "flmr"
227
+
228
+
229
+ ###############
230
+ # Actual Models
231
+ ###############
232
+
233
+
234
+ FLMR_START_DOCSTRING = r"""
235
+
236
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
237
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
238
+ etc.)
239
+
240
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
241
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
242
+ and behavior.
243
+
244
+ Parameters:
245
+ config ([`FLMRConfig`]): Model configuration class with all the parameters of the model.
246
+ Initializing with a config file does not load the weights associated with the model, only the
247
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
248
+ query_tokenizer ([`FLMRQueryEncoderTokenizer`], *optional*): The tokenizer used for tokenizing the query.
249
+ The query tokenizer can be initialized with `FLMRQueryEncoderTokenizer.from_pretrained(pretrained_model_name_or_path)`.
250
+ context_tokenizer ([`FLMRContextEncoderTokenizer`], *optional*): The tokenizer used for tokenizing the context.
251
+ The context tokenizer can be initialized with `FLMRContextEncoderTokenizer.from_pretrained(pretrained_model_name_or_path)`.
252
+ """
253
+
254
+
255
+ FLMR_MODEL_INPUTS_DOCSTRING = r"""
256
+ Args:
257
+ query_input_ids (`torch.LongTensor` of shape `(batch_size, query_length)`):
258
+ Indices of input query tokens in the vocabulary. To match pretraining, FLMR input sequence should be
259
+ formatted with [CLS] and Q marker tokens as follows:
260
+ [CLS] [unused0] using the provided image, obtain documents that address the subsequent question : what is the capital of france? [SEP] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] ...
261
+
262
+ FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
263
+ rather than the left.
264
+
265
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
266
+ [`PreTrainedTokenizer.__call__`] for details.
267
+
268
+ [What are input IDs?](../glossary#input-ids)
269
+ query_attention_mask (`torch.FloatTensor` of shape `(batch_size, query_length)`, *optional*):
270
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
271
+
272
+ - 1 for tokens that are **not masked**,
273
+ - 0 for tokens that are **masked**.
274
+
275
+ [What are attention masks?](../glossary#attention-mask)
276
+ query_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
277
+ Pixel values. Pixel values can be obtained using
278
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
279
+ query_image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*):
280
+ Image features are required when `query_pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass, and the extracted image features are given directly to the FLMR model. Image features can be obtained
281
+ using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details.
282
+ context_input_ids (`torch.LongTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`):
283
+ Indices of input context tokens in the vocabulary. To match pretraining, FLMR input sequence should be
284
+ formatted with [CLS] and D marker tokens as follows:
285
+ [CLS] [unused1] paris is the capital of france. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] ...
286
+
287
+ FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
288
+ rather than the left.
289
+
290
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
291
+ [`PreTrainedTokenizer.__call__`] for details.
292
+
293
+ [What are input IDs?](../glossary#input-ids)
294
+
295
+ The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details.
296
+
297
+ context_attention_mask (`torch.FloatTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`, *optional*):
298
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
299
+
300
+ - 1 for tokens that are **not masked**,
301
+ - 0 for tokens that are **masked**.
302
+
303
+ [What are attention masks?](../glossary#attention-mask)
304
+
305
+ The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details.
306
+ context_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
307
+ Pixel values. Pixel values can be obtained using
308
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
309
+ context_image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*):
310
+ Image features are required when `context_pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass, and the extracted image features are passed directly to the FLMR model. Image features can be obtained
311
+ using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details.
312
+ use_in_batch_negatives (`bool`, *optional*):
313
+ Whether or not to use in-batch negatives. If `True`, the contrastive loss includes in-batch negatives: for each positive example, every other example in the batch is treated as a negative when computing the contrastive loss. This typically improves final retrieval performance. This input is to be used in model training.
314
+ in_batch_negatives_from_all_gpus (`bool`, *optional*):
315
+ Whether or not to use in-batch negatives from all GPUs. If `True`, the contrastive loss includes in-batch negatives from all GPUs. This input is to be used in model training.
316
+ num_negative_examples (`int`, *optional*):
317
+ The number of negative examples in the batch. For example, if `num_negative_examples` is 4, the batch size of `context_input_ids` and `context_attention_mask` is `batch_size * 5`.
318
+ query_concat_output_from_vision_encoder (`bool`, *optional*):
319
+ Whether or not to concatenate the output from the vision encoder to the final query late-interaction representations. If `True`, the output from the vision encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR- and PreFLMR-style models.
320
+ query_concat_output_from_text_encoder (`bool`, *optional*):
321
+ Whether or not to concatenate the output from the text encoder to the final query late-interaction representations. If `True`, the output from the text encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR- and PreFLMR-style models.
322
+
323
+ This argument can be set to `False` when performing mapping network pretraining as in FLMR and PreFLMR, in which case the output from the text encoder is not concatenated to the final query representations.
324
+ context_concat_output_from_vision_encoder (`bool`, *optional*):
325
+ Whether or not to concatenate the output from the vision encoder to the final context late-interaction representations. If `True`, the output from the vision encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `False` for FLMR- and PreFLMR-style models since the context vision encoder is not used.
326
+
327
+ This can be set to `True` to additionally encode the context images with the vision encoder when context images are provided.
328
+ context_concat_output_from_text_encoder (`bool`, *optional*):
329
+ Whether or not to concatenate the output from the text encoder to the final context late-interaction representations. If `True`, the output from the text encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR- and PreFLMR-style models.
330
+ return_dict (`bool`, *optional*):
331
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
332
+ output_attentions (`bool`, *optional*):
333
+ Whether or not to return the attentions tensors of all attention layers. See `*_attentions` under returned
334
+ tensors for more detail.
335
+ output_hidden_states (`bool`, *optional*):
336
+ Whether or not to return the hidden states of all layers. See `*_hidden_states` under returned tensors for more detail.
337
+ """
338
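+ # Illustrative sketch (editorial hedge: `d*_pos` / `d*_neg` below are hypothetical, pre-tokenized passages,
+ # not names from this file): with `num_negative_examples=1` and queries [q0, q1], the context tensors stack
+ # one positive passage followed by its negative(s) for each query,
+ #     context_input_ids ~ torch.cat([d0_pos, d0_neg, d1_pos, d1_neg], dim=0)
+ # so their leading dimension is batch_size * (1 + num_negative_examples), and the positive passage always
+ # occupies the first slot of each per-query group (this is what the zero labels in `forward` assume).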
+
339
+
340
+ FLMR_MODEL_QUERY_INPUTS_DOCSTRING = r"""
341
+ Args:
342
+ input_ids (`torch.LongTensor` of shape `(batch_size, query_length)`):
343
+ Indices of input query tokens in the vocabulary. To match pretraining, FLMR input sequence should be
344
+ formatted with [CLS] and Q marker tokens as follows:
345
+ [CLS] [unused0] using the provided image, obtain documents that address the subsequent question : what is the capital of france? [SEP] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] ...
346
+
347
+ FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
348
+ rather than the left.
349
+
350
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
351
+ [`PreTrainedTokenizer.__call__`] for details.
352
+
353
+ [What are input IDs?](../glossary#input-ids)
354
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, query_length)`, *optional*):
355
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
356
+
357
+ - 1 for tokens that are **not masked**,
358
+ - 0 for tokens that are **masked**.
359
+
360
+ [What are attention masks?](../glossary#attention-mask)
361
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
362
+ Pixel values. Pixel values can be obtained using
363
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
364
+ image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*):
365
+ Image features are required when `pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass, and the extracted image features are passed directly to the FLMR model. Image features can be obtained
366
+ using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details.
367
+ concat_output_from_vision_encoder (`bool`, *optional*):
368
+ Whether or not to concatenate the output from the vision encoder to the final query late-interaction representations. If `True`, the output from the vision encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR- and PreFLMR-style models.
369
+ concat_output_from_text_encoder (`bool`, *optional*):
370
+ Whether or not to concatenate the output from the text encoder to the final query late-interaction representations. If `True`, the output from the text encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR- and PreFLMR-style models.
371
+
372
+ This argument can be set to `False` when performing mapping network pretraining as in FLMR and PreFLMR, in which case the output from the text encoder is not concatenated to the final query representations.
373
+ """
374
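+ # Usage sketch for the query encoder (hedged: variable names such as `model`, `Q_encoding` and
+ # `Q_pixel_values` are assumed to be prepared as in the `forward` docstring example further below):
+ #     query_out = model.query(
+ #         input_ids=Q_encoding["input_ids"],
+ #         attention_mask=Q_encoding["attention_mask"],
+ #         pixel_values=Q_pixel_values,
+ #     )
+ #     # query_out.late_interaction_output: roughly (batch_size, num_text_tokens + num_vision_tokens, config.dim)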
+
375
+
376
+ FLMR_MODEL_CONTEXT_INPUTS_DOCSTRING = r"""
377
+ Args:
378
+ input_ids (`torch.LongTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`):
379
+ Indices of input context tokens in the vocabulary. To match pretraining, FLMR input sequence should be
380
+ formatted with [CLS] and D marker tokens as follows:
381
+ [CLS] [unused1] paris is the capital of france. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] ...
382
+
383
+ FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
384
+ rather than the left.
385
+
386
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
387
+ [`PreTrainedTokenizer.__call__`] for details.
388
+
389
+ [What are input IDs?](../glossary#input-ids)
390
+
391
+ The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details.
392
+ attention_mask (`torch.FloatTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`, *optional*):
393
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
394
+
395
+ - 1 for tokens that are **not masked**,
396
+ - 0 for tokens that are **masked**.
397
+
398
+ [What are attention masks?](../glossary#attention-mask)
399
+
400
+ The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details.
401
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
402
+ Pixel values. Pixel values can be obtained using
403
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
404
+ image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*):
405
+ Image features are required when `pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass, and the extracted image features are passed directly to the FLMR model. Image features can be obtained
406
+ using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details.
408
+ concat_output_from_vision_encoder (`bool`, *optional*):
409
+ Whether or not to concatenate the output from the vision encoder to the final context late-interaction representations. If `True`, the output from the vision encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `False` for FLMR- and PreFLMR-style models since the context vision encoder is not used.
410
+
411
+ This can be set to `True` to additionally encode the context images with the vision encoder when context images are provided.
412
+ concat_output_from_text_encoder (`bool`, *optional*):
413
+ Whether or not to concatenate the output from the text encoder to the final context late-interaction representations. If `True`, the output from the text encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR- and PreFLMR-style models.
414
+ keep_dims (`bool`, *optional*):
415
+ Whether or not to keep the padded dimensions of the output. If `True`, the late-interaction embeddings are returned as a single padded tensor with the same leading dimensions as the input. If `False`, padded positions are removed and the embeddings are returned as a list of per-example tensors. This input is to be used in model training.
416
+ return_mask (`bool`, *optional*):
417
+ Whether or not to return the mask of the context representation. If `True`, the mask of the context representation is returned. This input is to be used in model training.
418
+ """
419
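+ # Usage sketch for the context encoder (hedged: `D_encoding` is assumed to be produced by the
+ # context tokenizer as in the `forward` docstring example further below):
+ #     doc_out = model.doc(
+ #         input_ids=D_encoding["input_ids"],
+ #         attention_mask=D_encoding["attention_mask"],
+ #         keep_dims=True,
+ #         return_mask=True,
+ #     )
+ #     # doc_out.late_interaction_output and doc_out.context_mask feed directly into late-interaction scoring.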
+
420
+
421
+ FLMR_TEXT_ENCODERS_START_DOCSTRING = r"""
422
+
423
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
424
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
425
+ etc.)
426
+
427
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
428
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
429
+ and behavior.
430
+
431
+ Parameters:
432
+ config ([`FLMRTextConfig`]): Model configuration class with all the parameters of the model.
433
+ Initializing with a config file does not load the weights associated with the model, only the
434
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
435
+ """
436
+
437
+
438
+ # Modified from transformers.models.dpr.modeling_dpr with DPR -> FLMR
439
+ FLMR_TEXT_ENCODERS_INPUTS_DOCSTRING = r"""
440
+ Args:
441
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
442
+ Indices of input sequence tokens in the vocabulary. To match pretraining, FLMR input sequence should be
443
+ formatted with [CLS] and [SEP] tokens as follows:
444
+
445
+ (a) For sequence pairs (for a pair title+text for example):
446
+
447
+ ```
448
+ tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
449
+ token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
450
+ ```
451
+
452
+ (b) For single sequences (for a question for example):
453
+
454
+ ```
455
+ tokens: [CLS] the dog is hairy . [SEP]
456
+ token_type_ids: 0 0 0 0 0 0 0
457
+ ```
458
+
459
+ FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
460
+ rather than the left.
461
+
462
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
463
+ [`PreTrainedTokenizer.__call__`] for details.
464
+
465
+ [What are input IDs?](../glossary#input-ids)
466
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
467
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
468
+
469
+ - 1 for tokens that are **not masked**,
470
+ - 0 for tokens that are **masked**.
471
+
472
+ [What are attention masks?](../glossary#attention-mask)
473
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
474
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
475
+ 1]`:
476
+
477
+ - 0 corresponds to a *sentence A* token,
478
+ - 1 corresponds to a *sentence B* token.
479
+
480
+ [What are token type IDs?](../glossary#token-type-ids)
481
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
482
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
483
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
484
+ model's internal embedding lookup matrix.
485
+ output_attentions (`bool`, *optional*):
486
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
487
+ tensors for more detail.
488
+ output_hidden_states (`bool`, *optional*):
489
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
490
+ more detail.
491
+ return_dict (`bool`, *optional*):
492
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
493
+ """
494
+
495
+ FLMR_VISION_ENCODERS_START_DOCSTRING = r"""
496
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
497
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
498
+ etc.)
499
+
500
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
501
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
502
+ and behavior.
503
+
504
+ Parameters:
505
+ config ([`FLMRVisionConfig`]): Model configuration class with all the parameters of the model.
506
+ Initializing with a config file does not load the weights associated with the model, only the
507
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
508
+ """
509
+
510
+ # Modified from transformers.models.clip.modeling_clip with CLIP -> FLMR
511
+ FLMR_VISION_ENCODERS_INPUTS_DOCSTRING = r"""
512
+ Args:
513
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
514
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
515
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
516
+ output_attentions (`bool`, *optional*):
517
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
518
+ tensors for more detail.
519
+ output_hidden_states (`bool`, *optional*):
520
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
521
+ more detail.
522
+ return_dict (`bool`, *optional*):
523
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
524
+ """
525
+
526
+
527
+ class FLMRMultiLayerPerceptron(nn.Module):
528
+ """
529
+ A simple multi-layer perceptron with an activation function. This can be used as the mapping network in the FLMR model.
530
+ """
531
+
532
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
533
+ return self.model(x)
534
+
535
+ def __init__(self, sizes, bias=True, act=nn.Tanh):
536
+ super(FLMRMultiLayerPerceptron, self).__init__()
537
+ layers = []
538
+ for i in range(len(sizes) - 1):
539
+ layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
540
+ if i < len(sizes) - 2:
541
+ layers.append(act())
542
+ self.model = nn.Sequential(*layers)
543
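+ # Minimal usage sketch (illustrative sizes only, mirroring how the vision projection is built below):
+ # map one vision embedding to a sequence of prefix vectors for late interaction.
+ #     mlp = FLMRMultiLayerPerceptron((768, (128 * 32) // 2, 128 * 32))
+ #     prefix = mlp(torch.randn(4, 768)).view(4, 32, 128)   # (batch, prefix_length, late_interaction_dim)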
+
544
+
545
+ @add_start_docstrings(
546
+ "The bare FLMR model that can be used to generate late-interaction embeddings for both multi-modal queries and documents. ",
547
+ FLMR_START_DOCSTRING,
548
+ )
549
+ class FLMRModelForRetrieval(FLMRPretrainedModelForRetrieval):
550
+ _keys_to_ignore_on_load_unexpected = [r"cls"]
551
+ main_input_name = "query_input_ids"
552
+ _tied_weights_keys = [] # Added dynamically at initialization depending on the architecture
553
+
554
+ def __init__(self, config: FLMRConfig, query_tokenizer=None, context_tokenizer=None):
555
+ super().__init__(config)
556
+ self.config = config
557
+ self.vision_model_version = config.vision_model_version
558
+
559
+ self.context_text_encoder = FLMRTextModel(config.text_config)
560
+ self.context_text_encoder_linear = nn.Linear(config.text_config.hidden_size, config.dim, bias=False)
561
+
562
+ self.query_tokenizer = query_tokenizer
563
+ self.context_tokenizer = context_tokenizer
564
+
565
+ if self.query_tokenizer is None:
566
+ logger.warning(
567
+ "query_tokenizer is not provided. A tokenizer is initialized from `bert-base-uncased`. Please pass in an FLMRQueryEncoderTokenizer instance if you need to extend the vocabulary beyond the existing ones in the bert tokenizer."
568
+ )
569
+
570
+ # initialize a FLMRQueryEncoderTokenizer
571
+ self.query_tokenizer = FLMRQueryEncoderTokenizer.from_pretrained("bert-base-uncased")
572
+
573
+ if self.context_tokenizer is None:
574
+ logger.warning(
575
+ "context_tokenizer is not provided. A tokenizer is initialized from `bert-base-uncased`. Please pass in an FLMRContextEncoderTokenizer instance if you need to extend the vocabulary beyond the existing ones in the bert tokenizer."
576
+ )
577
+
578
+ # initialize a FLMRContextEncoderTokenizer
579
+ self.context_tokenizer = FLMRContextEncoderTokenizer.from_pretrained("bert-base-uncased")
580
+
581
+ self.mapping_network_prefix_length = self.config.mapping_network_prefix_length
582
+ self.vision_encoder_embedding_size = self.config.vision_config.hidden_size
583
+ self.text_encoder_embedding_size = self.config.text_config.hidden_size
584
+ self.late_interaction_embedding_size = self.config.dim
585
+
586
+ if self.config.use_vision_encoder:
587
+ self.context_vision_projection = FLMRMultiLayerPerceptron(
588
+ (
589
+ self.vision_encoder_embedding_size,
590
+ (self.late_interaction_embedding_size * self.mapping_network_prefix_length) // 2,
591
+ self.late_interaction_embedding_size * self.mapping_network_prefix_length,
592
+ )
593
+ )
594
+
595
+ if self.config.use_vision_encoder:
596
+ self.context_vision_encoder = FLMRVisionModel(config.vision_config)
597
+
598
+ if self.config.use_transformer_mapping_network:
599
+ # This is a PreFLMR style model
600
+ transformer_mapping_config_base = self.config.transformer_mapping_config_base
601
+ try:
602
+ from transformers import BertConfig
603
+ from transformers.models.bert.modeling_bert import BertEncoder
604
+ except Exception as e:
605
+ raise ImportError(f"Failed to import BertConfig and BertEncoder from transformers. {e}")
606
+
607
+ transformer_mapping_config = BertConfig.from_pretrained(transformer_mapping_config_base)
608
+
609
+ assert (
610
+ self.config.text_config.hidden_size == transformer_mapping_config.hidden_size
611
+ ), f"hidden_size {self.config.text_config.hidden_size} != transformer_mapping_config.hidden_size {transformer_mapping_config.hidden_size}. To use cross attention, the dimensions must match."
612
+ # shallow transformer
613
+ transformer_mapping_config.num_hidden_layers = self.config.transformer_mapping_num_hidden_layers
614
+ # add cross attention
615
+ transformer_mapping_config.is_decoder = True
616
+ transformer_mapping_config.add_cross_attention = True
617
+
618
+ # The linear layer from vision encoder to transformer input
619
+ self.transformer_mapping_input_linear = nn.Linear(
620
+ self.vision_encoder_embedding_size, transformer_mapping_config.hidden_size
621
+ )
622
+
623
+ # The transformer encoder
624
+ self.transformer_mapping_network = BertEncoder(transformer_mapping_config)
625
+
626
+ # The linear layer from transformer output to FLMR dim
627
+ self.transformer_mapping_output_linear = nn.Linear(
628
+ transformer_mapping_config.hidden_size, self.late_interaction_embedding_size
629
+ )
630
+
631
+ if self.config.separate_query_and_context_text_encoder:
632
+ self.query_text_encoder = copy.deepcopy(self.context_text_encoder)
633
+ self.query_text_encoder_linear = copy.deepcopy(self.context_text_encoder_linear)
634
+ else:
635
+ self.query_text_encoder = self.context_text_encoder
636
+ self.query_text_encoder_linear = self.context_text_encoder_linear
637
+ self._tied_weights_keys += ["context_text_encoder", "context_text_encoder_linear"]
638
+
639
+ if self.config.use_vision_encoder:
640
+ if self.config.separate_query_and_context_vision_encoder:
641
+ self.query_vision_encoder = copy.deepcopy(self.context_vision_encoder)
642
+ self.query_vision_projection = copy.deepcopy(self.context_vision_projection)
643
+ else:
644
+ self.query_vision_encoder = self.context_vision_encoder
645
+ self.query_vision_projection = self.context_vision_projection
646
+ self._tied_weights_keys += ["context_vision_encoder", "context_vision_projection"]
647
+
648
+ if self.config.load_cpu_extension:
649
+ try:
650
+ FLMRModelForRetrieval.try_load_torch_extensions()
651
+ except Exception as e:
652
+ raise(f"Unable to load `segmented_maxsim.cpp`. hf-hub does not download this file automatically. Please download it manually from `https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/blob/main/segmented_maxsim.cpp` and put it under the same folder as the model file.\n {e}")
653
+
654
+ if self.config.mask_punctuation:
655
+ self.skiplist = {
656
+ w: True
657
+ for symbol in string.punctuation
658
+ for w in [symbol, self.context_tokenizer.encode(symbol, add_special_tokens=False)[0]]
659
+ }
660
+
661
+ if self.config.mask_instruction_token is not None:
662
+ self.mask_instruction = True
663
+ # obtain the token id of the instruction token
664
+ self.instruction_token_id = self.query_tokenizer.encode(
665
+ self.config.mask_instruction_token, add_special_tokens=False
666
+ )[0]
667
+ else:
668
+ self.mask_instruction = False
669
+
670
+ self.loss_fn = torch.nn.CrossEntropyLoss()
671
+
672
+ # Initialize weights and apply final processing
673
+ self.post_init()
674
+
675
+ @property
676
+ def use_gpu(self):
677
+ return self.device.type == "cuda"
678
+
679
+ @classmethod
680
+ def from_pretrained(cls, name_or_path, **kwargs):
681
+ obj = super().from_pretrained(name_or_path, **kwargs)
682
+ return obj
683
+
684
+ @classmethod
685
+ def try_load_torch_extensions(cls):
686
+ if hasattr(cls, "loaded_extensions"):
687
+ return
688
+
689
+ logger.info(
690
+ "Loading segmented_maxsim_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)..."
691
+ )
692
+ segmented_maxsim_cpp = load(
693
+ name="segmented_maxsim_cpp",
694
+ sources=[
695
+ os.path.join(pathlib.Path(__file__).parent.resolve(), "segmented_maxsim.cpp"),
696
+ ],
697
+ extra_cflags=["-O3"],
698
+ verbose=os.getenv("COLBERT_LOAD_TORCH_EXTENSION_VERBOSE", "False") == "True",
699
+ )
700
+ cls.segmented_maxsim = segmented_maxsim_cpp.segmented_maxsim_cpp
701
+
702
+ cls.loaded_extensions = True
703
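+ # Note (editorial, hedged): `segmented_maxsim.cpp` is JIT-compiled from the folder containing this
+ # file via the `load` helper (assumed to be `torch.utils.cpp_extension.load`), so a working C++
+ # toolchain is required; setting COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True surfaces compiler output.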
+
704
+ def query_mask(self, input_ids, skiplist):
705
+ if not self.mask_instruction:
706
+ return self.mask(input_ids, skiplist)
707
+
708
+ # find the position of end of instruction in input_ids
709
+ # mask the tokens before the position
710
+ sep_id = self.instruction_token_id
711
+ sep_positions = torch.argmax((input_ids == sep_id).int(), dim=1).tolist()
712
+ # if any of the positions is lower than 1, set to 1
713
+ for i, x in enumerate(sep_positions):
714
+ if x < 1:
715
+ sep_positions[i] = 1
716
+ logger.error(f"can not find the separator in the input_ids: {input_ids[i].tolist()}")
717
+ mask = [
718
+ [
719
+ (x not in skiplist) and (x != 0) and (index > sep_positions[seq_index] or index < 2)
720
+ for index, x in enumerate(d)
721
+ ]
722
+ for seq_index, d in enumerate(input_ids.cpu().tolist())
723
+ ]
724
+ return mask
725
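+ # Example (illustrative, assuming `mask_instruction_token` is the ":" separator used in the query
+ # instructions): for a query like
+ #     [CLS] [unused0] <instruction tokens> : <question tokens> [SEP] [MASK] ...
+ # positions 0-1 and every token after the ":" keep mask == True, while the instruction tokens
+ # themselves are zeroed out of the late-interaction representation.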
+
726
+ @add_start_docstrings_to_model_forward(FLMR_MODEL_INPUTS_DOCSTRING)
727
+ @replace_return_docstrings(output_type=FLMRModelForRetrievalOutput, config_class=_CONFIG_FOR_DOC)
728
+ def forward(
729
+ self,
730
+ query_input_ids: Optional[torch.Tensor] = None,
731
+ query_attention_mask: Optional[torch.Tensor] = None,
732
+ query_pixel_values: Optional[torch.Tensor] = None,
733
+ query_image_features: Optional[torch.Tensor] = None,
734
+ context_input_ids: Optional[torch.Tensor] = None,
735
+ context_attention_mask: Optional[torch.Tensor] = None,
736
+ context_pixel_values: Optional[torch.Tensor] = None,
737
+ context_image_features: Optional[torch.Tensor] = None,
738
+ use_in_batch_negatives: bool = True,
739
+ in_batch_negatives_from_all_gpus: bool = False,
740
+ num_negative_examples: int = 1,
741
+ query_concat_output_from_vision_encoder: Optional[Union[bool, list]] = None,
742
+ query_concat_output_from_text_encoder: Optional[Union[bool, list]] = None,
743
+ context_concat_output_from_vision_encoder: Optional[Union[bool, list]] = None,
744
+ context_concat_output_from_text_encoder: Optional[Union[bool, list]] = None,
745
+ return_dict: bool = None,
746
+ output_attentions: bool = None,
747
+ output_hidden_states: bool = None,
748
+ ) -> Union[FLMRModelForRetrievalOutput, Tuple[Tensor, ...]]:
749
+ r"""
750
+ Return:
751
+
752
+ Examples:
753
+
754
+ ```python
755
+ >>> import torch
756
+ >>> from transformers import FLMRQueryEncoderTokenizer, FLMRContextEncoderTokenizer, FLMRModelForRetrieval, AutoImageProcessor
757
+
758
+ >>> checkpoint_path = "LinWeizheDragon/PreFLMR_ViT-L"
759
+ >>> image_processor_name = "openai/clip-vit-large-patch14"
760
+ >>> query_tokenizer = FLMRQueryEncoderTokenizer.from_pretrained(checkpoint_path, subfolder="query_tokenizer")
761
+ >>> context_tokenizer = FLMRContextEncoderTokenizer.from_pretrained(checkpoint_path, subfolder="context_tokenizer")
762
+
763
+ >>> model = FLMRModelForRetrieval.from_pretrained(checkpoint_path,
764
+ ...                                               query_tokenizer=query_tokenizer,
765
+ ...                                               context_tokenizer=context_tokenizer,
766
+ ...                                               )
767
+ >>> image_processor = AutoImageProcessor.from_pretrained(image_processor_name)
768
+
769
+ >>> Q_encoding = query_tokenizer(["Using the provided image, obtain documents that address the subsequent question: What is the capital of France?", "Extract documents linked to the question provided in conjunction with the image: What is the capital of China?"])
770
+ >>> D_encoding = context_tokenizer(["Paris is the capital of France.", "Beijing is the capital of China.",
771
+ "Paris is the capital of France.", "Beijing is the capital of China."])
772
+ >>> Q_pixel_values = torch.zeros(2, 3, 224, 224)
773
+ >>> inputs = dict(
774
+ ...     query_input_ids=Q_encoding['input_ids'],
775
+ ...     query_attention_mask=Q_encoding['attention_mask'],
776
+ ...     query_pixel_values=Q_pixel_values,
777
+ ...     context_input_ids=D_encoding['input_ids'],
778
+ ...     context_attention_mask=D_encoding['attention_mask'],
779
+ ...     use_in_batch_negatives=True,
780
+ ... )
781
+
782
+ >>> model.forward(**inputs)
783
+ FLMRModelForRetrievalOutput(loss=tensor(4.5000, device='cuda:0', dtype=torch.float16,
784
+ grad_fn=<NllLossBackward0>), scores=tensor([[44.2188, 40.6562],
785
+ [39.4375, 48.4062]], device='cuda:0', dtype=torch.float16,
786
+ grad_fn=<ViewBackward0>), in_batch_negative_loss=tensor(5.1994, device='cuda:0', grad_fn=<NllLossBackward0>), query_late_interaction_output=tensor(...), context_late_interaction_output=tensor(...)
787
+ ```
788
+ """
789
+
790
+ if query_concat_output_from_vision_encoder is None:
791
+ query_concat_output_from_vision_encoder = self.config.query_concat_output_from_vision_encoder
792
+
793
+ if query_concat_output_from_text_encoder is None:
794
+ query_concat_output_from_text_encoder = self.config.query_concat_output_from_text_encoder
795
+
796
+ if context_concat_output_from_vision_encoder is None:
797
+ context_concat_output_from_vision_encoder = self.config.context_concat_output_from_vision_encoder
798
+
799
+ if context_concat_output_from_text_encoder is None:
800
+ context_concat_output_from_text_encoder = self.config.context_concat_output_from_text_encoder
801
+
802
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
803
+ output_hidden_states = (
804
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
805
+ )
806
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
807
+
808
+ query_outputs = self.query(
809
+ input_ids=query_input_ids,
810
+ attention_mask=query_attention_mask,
811
+ pixel_values=query_pixel_values,
812
+ image_features=query_image_features,
813
+ concat_output_from_vision_encoder=query_concat_output_from_vision_encoder,
814
+ concat_output_from_text_encoder=query_concat_output_from_text_encoder,
815
+ output_attentions=output_attentions,
816
+ output_hidden_states=output_hidden_states,
817
+ )
818
+ Q = query_outputs.late_interaction_output
819
+
820
+ context_outputs = self.doc(
821
+ input_ids=context_input_ids,
822
+ attention_mask=context_attention_mask,
823
+ pixel_values=context_pixel_values,
824
+ image_features=context_image_features,
825
+ concat_output_from_vision_encoder=context_concat_output_from_vision_encoder,
826
+ concat_output_from_text_encoder=context_concat_output_from_text_encoder,
827
+ keep_dims=True,
828
+ return_mask=True,
829
+ output_attentions=output_attentions,
830
+ output_hidden_states=output_hidden_states,
831
+ )
832
+ D, D_mask = context_outputs.late_interaction_output, context_outputs.context_mask
833
+
834
+ # Gather tensors from other GPUs
835
+ if in_batch_negatives_from_all_gpus:
836
+ Q, D, D_mask = self.gather_tensors_from_other_gpus(Q, D, D_mask)
837
+ # Repeat each query encoding for every corresponding document.
838
+ Q_duplicated = Q.repeat_interleave(num_negative_examples + 1, dim=0).contiguous()
839
+
840
+ scores = self.score(Q_duplicated, D, D_mask)
841
+
842
+ # Use contrastive learning
843
+ batch_size = query_input_ids.shape[0]
844
+ scores = scores.view(-1, num_negative_examples + 1)
845
+ labels = torch.zeros(batch_size, dtype=torch.long, device=self.device)
846
+
847
+
848
+ if use_in_batch_negatives:
849
+ ib_loss = self.compute_ib_loss_new(Q, D, D_mask)
850
+ loss = ib_loss
851
+ else:
852
+ loss = self.loss_fn(scores, labels)
853
+ ib_loss = None
854
+
855
+ if output_attentions:
856
+ query_attentions = (
857
+ query_outputs.text_encoder_attentions if query_outputs.text_encoder_attentions is not None else None,
858
+ query_outputs.vision_encoder_attentions
859
+ if query_outputs.vision_encoder_attentions is not None
860
+ else None,
861
+ query_outputs.transformer_mapping_network_attentions
862
+ if query_outputs.transformer_mapping_network_attentions is not None
863
+ else None,
864
+ )
865
+ context_attentions = (
866
+ context_outputs.text_encoder_attentions
867
+ if context_outputs.text_encoder_attentions is not None
868
+ else None,
869
+ context_outputs.vision_encoder_attentions
870
+ if context_outputs.vision_encoder_attentions is not None
871
+ else None,
872
+ context_outputs.transformer_mapping_network_attentions
873
+ if context_outputs.transformer_mapping_network_attentions is not None
874
+ else None,
875
+ )
876
+ else:
877
+ query_attentions = None
878
+ context_attentions = None
879
+
880
+ if output_hidden_states:
881
+ query_hidden_states = (
882
+ query_outputs.text_encoder_hidden_states
883
+ if query_outputs.text_encoder_hidden_states is not None
884
+ else None,
885
+ query_outputs.vision_encoder_hidden_states
886
+ if query_outputs.vision_encoder_hidden_states is not None
887
+ else None,
888
+ query_outputs.transformer_mapping_network_hidden_states
889
+ if query_outputs.transformer_mapping_network_hidden_states is not None
890
+ else None,
891
+ )
892
+ context_hidden_states = (
893
+ context_outputs.text_encoder_hidden_states
894
+ if context_outputs.text_encoder_hidden_states is not None
895
+ else None,
896
+ context_outputs.vision_encoder_hidden_states
897
+ if context_outputs.vision_encoder_hidden_states is not None
898
+ else None,
899
+ context_outputs.transformer_mapping_network_hidden_states
900
+ if context_outputs.transformer_mapping_network_hidden_states is not None
901
+ else None,
902
+ )
903
+ else:
904
+ query_hidden_states = None
905
+ context_hidden_states = None
906
+
907
+ if not return_dict:
908
+ if output_attentions and output_hidden_states:
909
+ return (
910
+ loss,
911
+ scores,
912
+ ib_loss,
913
+ query_outputs.late_interaction_output,
914
+ context_outputs.late_interaction_output,
915
+ query_attentions,
916
+ query_hidden_states,
917
+ context_attentions,
918
+ context_hidden_states,
919
+ )
920
+ elif output_attentions:
921
+ return (
922
+ loss,
923
+ scores,
924
+ ib_loss,
925
+ query_outputs.late_interaction_output,
926
+ context_outputs.late_interaction_output,
927
+ query_attentions,
928
+ context_attentions,
929
+ )
930
+ elif output_hidden_states:
931
+ return (
932
+ loss,
933
+ scores,
934
+ ib_loss,
935
+ query_outputs.late_interaction_output,
936
+ context_outputs.late_interaction_output,
937
+ query_hidden_states,
938
+ context_hidden_states,
939
+ )
940
+ else:
941
+ return (
942
+ loss,
943
+ scores,
944
+ ib_loss,
945
+ query_outputs.late_interaction_output,
946
+ context_outputs.late_interaction_output,
947
+ )
948
+
949
+ return FLMRModelForRetrievalOutput(
950
+ loss=loss,
951
+ scores=scores,
952
+ in_batch_negative_loss=ib_loss,
953
+ query_late_interaction_output=query_outputs.late_interaction_output,
954
+ context_late_interaction_output=context_outputs.late_interaction_output,
955
+ query_attentions=query_attentions if output_attentions else None,
956
+ query_hidden_states=query_hidden_states if output_hidden_states else None,
957
+ context_attentions=context_attentions if output_attentions else None,
958
+ context_hidden_states=context_hidden_states if output_hidden_states else None,
959
+ )
960
+
961
+ def compute_ib_loss_new(self, Q: torch.Tensor, D: torch.Tensor, D_mask: torch.Tensor) -> torch.Tensor:
962
+ # Q: batch_size x q_len x dim
963
+ # D: batch_size*n_docs x i_len x dim
964
+ # D_mask: batch_size*n_docs x i_len x 1
965
+ # 1 x batch_size*n_docs x i_len x dim matmul batch_size x 1 x q_len x dim
966
+ # = batch_size x batch_size*n_docs x i_len x q_len
967
+
968
+ scores = (D.float().unsqueeze(0) @ Q.float().permute(0, 2, 1).unsqueeze(1)).flatten(
969
+ 0, 1
970
+ ) # query-major unsqueeze
971
+ scores = colbert_score_reduce(scores, D_mask.repeat(Q.size(0), 1, 1))
972
+
973
+ in_batch_scores = scores.reshape(Q.size(0), -1)
974
+
975
+ batch_size = Q.shape[0]
976
+ batch_size_with_pos_and_neg = D.shape[0]
977
+ num_pos_and_neg = batch_size_with_pos_and_neg // batch_size
978
+
979
+ # batch_size x dim matmul dim x (num_pos+num_neg)*batch_size
980
+ # --> batch_size x (num_pos+num_neg)*batch_size
981
+ in_batch_labels = torch.zeros(batch_size, batch_size_with_pos_and_neg).to(scores.device)
982
+ step = num_pos_and_neg
983
+ for i in range(batch_size):
984
+ in_batch_labels[i, step * i] = 1
985
+ # print('in_batch_labels', in_batch_labels)
986
+ in_batch_labels = torch.argmax(in_batch_labels, dim=1)
987
+ # print('in_batch_labels', in_batch_labels)
988
+
989
+ loss = self.loss_fn(in_batch_scores, in_batch_labels)
990
+
991
+ return loss
992
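+ # Worked example (illustrative): with batch_size=2 and one negative per query (num_pos_and_neg=2),
+ # `in_batch_scores` has shape (2, 4) and the positive document for query i sits at column 2 * i,
+ # so `in_batch_labels` resolves to tensor([0, 2]).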
+
993
+ def gather_tensors_from_other_gpus(self, query_embeddings, item_embeddings, item_mask):
994
+ # print("get rank", get_rank())
995
+ # print("get world size", get_world_size())
996
+ # Gather embeddings from other GPUs
997
+ n_nodes = get_world_size()
998
+ if n_nodes == 1:
999
+ return query_embeddings, item_embeddings, item_mask
1000
+ # Create placeholder to hold embeddings passed from other ranks
1001
+ global_query_embeddings_placeholder = [
1002
+ torch.zeros(*query_embeddings.shape, dtype=query_embeddings.dtype).to(query_embeddings.device)
1003
+ for _ in range(n_nodes)
1004
+ ]
1005
+ global_item_embeddings_placeholder = [
1006
+ torch.zeros(*item_embeddings.shape, dtype=item_embeddings.dtype).to(item_embeddings.device)
1007
+ for _ in range(n_nodes)
1008
+ ]
1009
+ global_item_mask_placeholder = [
1010
+ torch.zeros(*item_mask.shape, dtype=item_mask.dtype).to(item_mask.device) for _ in range(n_nodes)
1011
+ ]
1012
+ dist.all_gather(global_query_embeddings_placeholder, query_embeddings.detach())
1013
+ dist.all_gather(global_item_embeddings_placeholder, item_embeddings.detach())
1014
+ dist.all_gather(global_item_mask_placeholder, item_mask.detach())
1015
+
1016
+ global_query_embeddings = []
1017
+ global_item_embeddings = []
1018
+ global_item_mask = []
1019
+ # print(f"rank {get_rank()} global_query_embeddings", global_query_embeddings)
1020
+ # print(f"rank {get_rank()} global_item_embeddings", global_item_embeddings)
1021
+ # input()
1022
+ current_rank = get_rank()
1023
+ for rank_index, remote_q_embeddings in enumerate(global_query_embeddings_placeholder):
1024
+ # We append the embeddings from other GPUs if this embedding does not require gradients
1025
+ if rank_index != current_rank:
1026
+ global_query_embeddings.append(remote_q_embeddings)
1027
+ else:
1028
+ global_query_embeddings.append(query_embeddings)
1029
+
1030
+ for rank_index, remote_item_embeddings in enumerate(global_item_embeddings_placeholder):
1031
+ # We append the embeddings from other GPUs if this embedding does not require gradients
1032
+ if rank_index != current_rank:
1033
+ global_item_embeddings.append(remote_item_embeddings)
1034
+ else:
1035
+ global_item_embeddings.append(item_embeddings)
1036
+
1037
+ for rank_index, remote_item_mask in enumerate(global_item_mask_placeholder):
1038
+ # We append the embeddings from other GPUs if this embedding does not require gradients
1039
+ if rank_index != current_rank:
1040
+ global_item_mask.append(remote_item_mask)
1041
+ else:
1042
+ global_item_mask.append(item_mask)
1043
+
1044
+ # Replace the previous variables with gathered tensors
1045
+ query_embeddings = torch.cat(global_query_embeddings)
1046
+ item_embeddings = torch.cat(global_item_embeddings)
1047
+ item_mask = torch.cat(global_item_mask)
1048
+
1049
+ return query_embeddings, item_embeddings, item_mask
1050
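+ # Note: tensors gathered from other ranks are detached before `all_gather`, so gradients flow only
+ # through the local shard that is swapped back in above; the remote embeddings act purely as extra
+ # negatives for the contrastive loss.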
+
1051
+ @add_start_docstrings_to_model_forward(FLMR_MODEL_QUERY_INPUTS_DOCSTRING)
1052
+ @replace_return_docstrings(output_type=FLMRQueryEncoderOutput, config_class=_CONFIG_FOR_DOC)
1053
+ def query(
1054
+ self,
1055
+ input_ids: torch.Tensor,
1056
+ attention_mask: torch.Tensor,
1057
+ pixel_values: Optional[torch.Tensor] = None,
1058
+ image_features: Optional[torch.Tensor] = None,
1059
+ concat_output_from_vision_encoder: Optional[Union[bool, list]] = None,
1060
+ concat_output_from_text_encoder: Optional[Union[bool, list]] = None,
1061
+ output_attentions: Optional[bool] = None,
1062
+ output_hidden_states: Optional[bool] = None,
1063
+ ):
1064
+ r"""
1065
+ Returns:
1066
+
1067
+ """
1068
+
1069
+ if concat_output_from_vision_encoder is None:
1070
+ concat_output_from_vision_encoder = self.config.query_concat_output_from_vision_encoder
1071
+
1072
+ if concat_output_from_text_encoder is None:
1073
+ concat_output_from_text_encoder = self.config.query_concat_output_from_text_encoder
1074
+
1075
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1076
+ output_hidden_states = (
1077
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1078
+ )
1079
+
1080
+ input_modality = []
1081
+ if pixel_values is not None or image_features is not None:
1082
+ input_modality.append("image")
1083
+ if input_ids is not None and attention_mask is not None:
1084
+ input_modality.append("text")
1085
+
1086
+ text_encoder_outputs = None
1087
+ vision_encoder_outputs = None
1088
+ transformer_mapping_outputs = None
1089
+
1090
+ if "image" in input_modality:
1091
+ assert (
1092
+ pixel_values is not None or image_features is not None
1093
+ ), "pixel_values or image_features must be provided if image modality is used"
1094
+ assert (
1095
+ pixel_values is None or image_features is None
1096
+ ), "pixel_values and image_features cannot be provided at the same time"
1097
+
1098
+ if "text" in input_modality:
1099
+ assert (
1100
+ input_ids is not None and attention_mask is not None
1101
+ ), "input_ids and attention_mask must be provided if text modality is used"
1102
+ # Forward the text encoder
1103
+ input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
1104
+ text_encoder_outputs = self.query_text_encoder(input_ids, attention_mask=attention_mask)
1105
+ text_encoder_hidden_states = text_encoder_outputs[0]
1106
+ text_embeddings = self.query_text_encoder_linear(text_encoder_hidden_states)
1107
+ mask = torch.tensor(self.query_mask(input_ids, skiplist=self.config.query_mask_input_ids_skip_list), device=self.device).unsqueeze(2).float()
1108
+
1109
+ text_embeddings = text_embeddings * mask
1110
+
1111
+ if "image" in input_modality:
1112
+ if pixel_values is not None:
1113
+ batch_size = pixel_values.shape[0]
1114
+ # Forward the vision encoder
1115
+ pixel_values = pixel_values.to(self.device)
1116
+ if len(pixel_values.shape) == 5:
1117
+ # Multiple ROIs are provided
1118
+ # merge the first two dimensions
1119
+ pixel_values = pixel_values.reshape(
1120
+ -1, pixel_values.shape[2], pixel_values.shape[3], pixel_values.shape[4]
1121
+ )
1122
+ vision_encoder_outputs = self.query_vision_encoder(pixel_values, output_hidden_states=True)
1123
+ vision_embeddings = vision_encoder_outputs.last_hidden_state[:, 0]
1124
+
1125
+ if image_features is not None:
1126
+ batch_size = image_features.shape[0]
1127
+ vision_embeddings = image_features.to(self.device)
1128
+
1129
+ # Forward the vision projection / mapping network
1130
+ vision_embeddings = self.query_vision_projection(vision_embeddings)
1131
+ vision_embeddings = vision_embeddings.view(batch_size, -1, self.late_interaction_embedding_size)
1132
+
1133
+ if self.config.use_transformer_mapping_network:
1134
+ # select the second last layer
1135
+ vision_second_last_layer_hidden_states = vision_encoder_outputs.hidden_states[-2][:, 1:]
1136
+ # transformer_mapping
1137
+ transformer_mapping_input_features = self.transformer_mapping_input_linear(
1138
+ vision_second_last_layer_hidden_states
1139
+ )
1140
+
1141
+ # Cross attention only attends to the first `transformer_mapping_cross_attention_length` tokens
1142
+ encoder_mask = torch.ones_like(mask).to(mask.device, dtype=mask.dtype)
1143
+ if len(self.config.query_mask_input_ids_skip_list) > 0:
1144
+ encoder_mask[torch.isin(input_ids, torch.tensor(self.config.query_mask_input_ids_skip_list))] = 0
1145
+ cross_attention_length = self.config.transformer_mapping_cross_attention_length
1146
+ if text_encoder_hidden_states.shape[1] > cross_attention_length:
1147
+ text_encoder_hidden_states = text_encoder_hidden_states[:, :cross_attention_length]
1148
+ encoder_mask = encoder_mask[:, :cross_attention_length]
1149
+
1150
+ # Obtain cross attention mask
1151
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_mask.squeeze(-1))
1152
+ # Pass through the transformer mapping
1153
+ transformer_mapping_outputs = self.transformer_mapping_network(
1154
+ transformer_mapping_input_features,
1155
+ encoder_hidden_states=text_encoder_hidden_states,
1156
+ encoder_attention_mask=encoder_extended_attention_mask,
1157
+ )
1158
+ transformer_mapping_output_features = transformer_mapping_outputs.last_hidden_state
1159
+ # Convert the dimension to FLMR dim
1160
+ transformer_mapping_output_features = self.transformer_mapping_output_linear(
1161
+ transformer_mapping_output_features
1162
+ )
1163
+ # Merge with the vision embeddings
1164
+ vision_embeddings = torch.cat([vision_embeddings, transformer_mapping_output_features], dim=1)
1165
+
1166
+ if concat_output_from_vision_encoder and concat_output_from_text_encoder:
1167
+ Q = torch.cat([text_embeddings, vision_embeddings], dim=1)
1168
+ if isinstance(concat_output_from_vision_encoder, list) or isinstance(concat_output_from_text_encoder, list):
1169
+ # When lists are passed in, mask the output accordingly
1170
+ assert isinstance(concat_output_from_vision_encoder, list) and isinstance(concat_output_from_text_encoder, list), "concat_output_from_vision_encoder and concat_output_from_text_encoder must be of the same type."
1171
+ # obtain the size of each output
1172
+ text_size = text_embeddings.shape[1]
1173
+ vision_size = vision_embeddings.shape[1]
1174
+
1175
+ # Prepare the mask
1176
+ concat_output_mask = torch.zeros_like(Q).to(Q.device)
1177
+
1178
+ # Mask the late interaction outputs
1179
+ concat_output_mask[:, :text_size] = torch.tensor(concat_output_from_text_encoder).bool().unsqueeze(-1).unsqueeze(-1)
1180
+ concat_output_mask[:, text_size:] = torch.tensor(concat_output_from_vision_encoder).bool().unsqueeze(-1).unsqueeze(-1)
1181
+
1182
+ Q = Q * concat_output_mask
1183
+
1184
+ elif concat_output_from_vision_encoder:
1185
+ Q = vision_embeddings
1186
+ elif concat_output_from_text_encoder:
1187
+ Q = text_embeddings
1188
+
1189
+ vision_encoder_attentions = (
1190
+ vision_encoder_outputs.attentions
1191
+ if vision_encoder_outputs is not None
1192
+ and hasattr(vision_encoder_outputs, "attentions")
1193
+ and output_attentions
1194
+ else None
1195
+ )
1196
+ vision_encoder_hidden_states = (
1197
+ vision_encoder_outputs.hidden_states
1198
+ if vision_encoder_outputs is not None
1199
+ and hasattr(vision_encoder_outputs, "hidden_states")
1200
+ and output_hidden_states
1201
+ else None
1202
+ )
1203
+ text_encoder_attentions = (
1204
+ text_encoder_outputs.attentions
1205
+ if text_encoder_outputs is not None and hasattr(text_encoder_outputs, "attentions") and output_attentions
1206
+ else None
1207
+ )
1208
+ text_encoder_hidden_states = (
1209
+ text_encoder_outputs.hidden_states
1210
+ if text_encoder_outputs is not None
1211
+ and hasattr(text_encoder_outputs, "hidden_states")
1212
+ and output_hidden_states
1213
+ else None
1214
+ )
1215
+ transformer_mapping_network_attentions = (
1216
+ transformer_mapping_outputs.attentions
1217
+ if transformer_mapping_outputs is not None
1218
+ and hasattr(transformer_mapping_outputs, "attentions")
1219
+ and output_attentions
1220
+ else None
1221
+ )
1222
+ transformer_mapping_network_hidden_states = (
1223
+ transformer_mapping_outputs.hidden_states
1224
+ if transformer_mapping_outputs is not None
1225
+ and hasattr(transformer_mapping_outputs, "hidden_states")
1226
+ and output_hidden_states
1227
+ else None
1228
+ )
1229
+
1230
+ return FLMRQueryEncoderOutput(
1231
+ pooler_output=Q[:, 0, :],
1232
+ late_interaction_output=torch.nn.functional.normalize(Q, p=2, dim=2),
1233
+ vision_encoder_attentions=vision_encoder_attentions,
1234
+ vision_encoder_hidden_states=vision_encoder_hidden_states,
1235
+ text_encoder_attentions=text_encoder_attentions,
1236
+ text_encoder_hidden_states=text_encoder_hidden_states,
1237
+ transformer_mapping_network_attentions=transformer_mapping_network_attentions,
1238
+ transformer_mapping_network_hidden_states=transformer_mapping_network_hidden_states,
1239
+ )
1240
+
1241
+ @add_start_docstrings_to_model_forward(FLMR_MODEL_CONTEXT_INPUTS_DOCSTRING)
1242
+ @replace_return_docstrings(output_type=FLMRContextEncoderOutput, config_class=_CONFIG_FOR_DOC)
1243
+ def doc(
1244
+ self,
1245
+ input_ids: torch.Tensor,
1246
+ attention_mask: torch.Tensor,
1247
+ pixel_values: Optional[torch.Tensor] = None,
1248
+ image_features: Optional[torch.Tensor] = None,
1249
+ concat_output_from_vision_encoder: Optional[bool] = None,
1250
+ concat_output_from_text_encoder: Optional[bool] = None,
1251
+ keep_dims: Optional[bool] = True,
1252
+ return_mask: Optional[bool] = True,
1253
+ output_attentions: Optional[bool] = None,
1254
+ output_hidden_states: Optional[bool] = None,
1255
+ ):
1256
+ r"""
1257
+ Returns:
1258
+
1259
+ """
1260
+ assert keep_dims in [True, False]
1261
+
1262
+ if concat_output_from_vision_encoder is None:
1263
+ concat_output_from_vision_encoder = self.config.context_concat_output_from_vision_encoder
1264
+
1265
+ if concat_output_from_text_encoder is None:
1266
+ concat_output_from_text_encoder = self.config.context_concat_output_from_text_encoder
1267
+
1268
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1269
+ output_hidden_states = (
1270
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1271
+ )
1272
+
1273
+ input_modality = []
1274
+ if pixel_values is not None or image_features is not None:
1275
+ input_modality.append("image")
1276
+ if input_ids is not None and attention_mask is not None:
1277
+ input_modality.append("text")
1278
+
1279
+ text_encoder_outputs = None
1280
+ vision_encoder_outputs = None
1281
+
1282
+ if "image" in input_modality:
1283
+ assert (
1284
+ pixel_values is not None or image_features is not None
1285
+ ), "pixel_values or image_features must be provided if image modality is used"
1286
+ assert (
1287
+ pixel_values is None or image_features is None
1288
+ ), "pixel_values and image_features cannot be provided at the same time"
1289
+
1290
+ if "text" in input_modality:
1291
+ assert (
1292
+ input_ids is not None and attention_mask is not None
1293
+ ), "input_ids and attention_mask must be provided if text modality is used"
1294
+ # Forward the text encoder
1295
+ input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
1296
+ text_encoder_outputs = self.context_text_encoder(input_ids, attention_mask=attention_mask)
1297
+ text_embeddings = text_encoder_outputs[0]
1298
+ text_embeddings = self.context_text_encoder_linear(text_embeddings)
1299
+
1300
+ mask = torch.tensor(self.mask(input_ids, skiplist=self.skiplist), device=self.device).unsqueeze(2).float()
1301
+ text_embeddings = text_embeddings * mask
1302
+
1303
+ if "image" in input_modality:
1304
+ if pixel_values is not None:
1305
+ # Forward the vision encoder
1306
+ pixel_values = pixel_values.to(self.device)
1307
+ vision_encoder_outputs = self.context_vision_encoder(pixel_values)
1308
+ vision_embeddings = vision_encoder_outputs.last_hidden_state[:, 0]
1309
+
1310
+ if image_features is not None:
1311
+ vision_embeddings = image_features.to(self.device)
1312
+
1313
+ batch_size = vision_embeddings.shape[0]
1314
+
1315
+ # Forward the vision projection / mapping network
1316
+ vision_embeddings = self.context_vision_projection(vision_embeddings)
1317
+ vision_embeddings = vision_embeddings.view(
1318
+ -1, self.mapping_network_prefix_length, self.late_interaction_embedding_size
1319
+ )
1320
+
1321
+ image_mask = torch.ones(batch_size, vision_embeddings.shape[1], 1).to(self.device)
1322
+
1323
+ if concat_output_from_vision_encoder and concat_output_from_text_encoder:
1324
+ # Note: vision embeddings must be in the front since the ColBERT engine only indexes embeddings up to the number of 1's in the mask
1325
+ # TODO: fix the engine to support masks with discontinuous 0 and 1.
1326
+ D = torch.cat([vision_embeddings, text_embeddings], dim=1)
1327
+ # concatenate the mask
1328
+ mask = torch.cat([image_mask, mask], dim=1)
1329
+ elif concat_output_from_vision_encoder:
1330
+ D = vision_embeddings
1331
+ mask = image_mask
1332
+ elif concat_output_from_text_encoder:
1333
+ D = text_embeddings
1334
+ mask = mask
1335
+
1336
+ D = torch.nn.functional.normalize(D, p=2, dim=2)
1337
+
1338
+ if self.use_gpu:
1339
+ D = D.half()
1340
+
1341
+ if keep_dims is False:
1342
+ D, mask = D.cpu(), mask.bool().cpu().squeeze(-1)
1343
+ D = [d[mask[idx]] for idx, d in enumerate(D)]
1344
+
1345
+ vision_encoder_attentions = (
1346
+ vision_encoder_outputs.attentions
1347
+ if vision_encoder_outputs is not None
1348
+ and hasattr(vision_encoder_outputs, "attentions")
1349
+ and output_attentions
1350
+ else None
1351
+ )
1352
+ vision_encoder_hidden_states = (
1353
+ vision_encoder_outputs.hidden_states
1354
+ if vision_encoder_outputs is not None
1355
+ and hasattr(vision_encoder_outputs, "hidden_states")
1356
+ and output_hidden_states
1357
+ else None
1358
+ )
1359
+ text_encoder_attentions = (
1360
+ text_encoder_outputs.attentions
1361
+ if text_encoder_outputs is not None and hasattr(text_encoder_outputs, "attentions") and output_attentions
1362
+ else None
1363
+ )
1364
+ text_encoder_hidden_states = (
1365
+ text_encoder_outputs.hidden_states
1366
+ if text_encoder_outputs is not None
1367
+ and hasattr(text_encoder_outputs, "hidden_states")
1368
+ and output_hidden_states
1369
+ else None
1370
+ )
1371
+
1372
+ return FLMRContextEncoderOutput(
1373
+ pooler_output=D[:, 0, :],
1374
+ late_interaction_output=D,
1375
+ context_mask=mask.bool() if return_mask else None,
1376
+ vision_encoder_attentions=vision_encoder_attentions,
1377
+ vision_encoder_hidden_states=vision_encoder_hidden_states,
1378
+ text_encoder_attentions=text_encoder_attentions,
1379
+ text_encoder_hidden_states=text_encoder_hidden_states,
1380
+ )
1381
+
1382
+ def score(self, Q, D_padded, D_mask):
1383
+ # assert self.colbert_config.similarity == 'cosine'
1384
+ # if self.colbert_config.similarity == 'l2':
1385
+ # assert self.colbert_config.interaction == 'colbert'
1386
+ # return (-1.0 * ((Q.unsqueeze(2) - D_padded.unsqueeze(1))**2).sum(-1)).max(-1).values.sum(-1)
1387
+ return colbert_score(Q, D_padded, D_mask, use_gpu=self.use_gpu)
1388
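+ # Conceptual sketch only (the real reduction lives in `colbert_score` / the `segmented_maxsim`
+ # extension): late-interaction scoring takes, for each query token, the maximum similarity over all
+ # document tokens, then sums the per-token maxima:
+ #     sim = Q @ D_padded.transpose(1, 2)                               # (batch, q_len, d_len)
+ #     sim = sim.masked_fill(~D_mask.transpose(1, 2).bool(), float("-inf"))
+ #     score = sim.max(dim=-1).values.sum(dim=-1)                       # (batch,)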
+
1389
+ def mask(self, input_ids, skiplist):
1390
+ mask = [[(x not in skiplist) and (x != 0) for x in d] for d in input_ids.cpu().tolist()]
1391
+ return mask
1392
+
1393
+
1394
+ @add_start_docstrings(
1395
+ "The bare FLMR text encoder that can be used to generate late-interaction embeddings for texts in queries and contexts. This model is based on a `BertModel`. It can be used like a `BertModel` model for encoding text.",
1396
+ FLMR_TEXT_ENCODERS_START_DOCSTRING,
1397
+ )
1398
+ class FLMRTextModel(FLMRPreTrainedModel):
1399
+ base_model_prefix = "flmr_text_model"
1400
+ config_class = FLMRTextConfig
1401
+
1402
+ def __init__(self, config: FLMRTextConfig, *args, **kwargs):
1403
+ super().__init__(config)
1404
+ if config.text_encoder_base_model == "bert-base-uncased":
1405
+ self.bert_model = BertModel(config, add_pooling_layer=True)
1406
+ else:
1407
+ self.bert_model = AutoModel.from_pretrained(config.text_encoder_base_model, *args, **kwargs)
1408
+ if self.bert_model.config.hidden_size <= 0:
1409
+ raise ValueError("Encoder hidden_size can't be zero")
1410
+ self.projection_dim = config.projection_dim
1411
+ if self.projection_dim > 0:
1412
+ self.encode_proj = nn.Linear(self.bert_model.config.hidden_size, config.projection_dim)
1413
+ # Initialize weights and apply final processing
1414
+ self.post_init()
1415
+ self.text_model = self.bert_model
1416
+
1417
+ @add_start_docstrings_to_model_forward(FLMR_TEXT_ENCODERS_INPUTS_DOCSTRING)
1418
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=FLMRTextConfig)
1419
+ def forward(
1420
+ self,
1421
+ input_ids: Optional[Tensor] = None,
1422
+ attention_mask: Optional[Tensor] = None,
1423
+ token_type_ids: Optional[Tensor] = None,
1424
+ inputs_embeds: Optional[Tensor] = None,
1425
+ output_attentions: bool = None,
1426
+ output_hidden_states: bool = None,
1427
+ return_dict: bool = None,
1428
+ ) -> Union[BaseModelOutputWithPooling, Tuple[Tensor, ...]]:
1429
+ r"""
1430
+ Returns:
1431
+
1432
+ """
1433
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1434
+ output_hidden_states = (
1435
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1436
+ )
1437
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1438
+
1439
+ outputs = self.text_model(
1440
+ input_ids=input_ids,
1441
+ attention_mask=attention_mask,
1442
+ token_type_ids=token_type_ids,
1443
+ inputs_embeds=inputs_embeds,
1444
+ output_attentions=output_attentions,
1445
+ output_hidden_states=output_hidden_states,
1446
+ return_dict=return_dict,
1447
+ )
1448
+ sequence_output = outputs[0]
1449
+ pooled_output = sequence_output[:, 0, :]
1450
+
1451
+ if self.projection_dim > 0:
1452
+ pooled_output = self.encode_proj(pooled_output)
1453
+
1454
+ if not return_dict:
1455
+ return (sequence_output, pooled_output) + outputs[2:]
1456
+
1457
+ return BaseModelOutputWithPooling(
1458
+ last_hidden_state=sequence_output,
1459
+ pooler_output=pooled_output,
1460
+ hidden_states=outputs.hidden_states,
1461
+ attentions=outputs.attentions,
1462
+ )
1463
+
1464
+ @property
1465
+ def embeddings_size(self) -> int:
1466
+ if self.projection_dim > 0:
1467
+ return self.encode_proj.out_features
1468
+ return self.text_model.config.hidden_size
1469
+
1470
+
1471
+ @add_start_docstrings(
1472
+ "The bare FLMR vision encoder that can be used to generate late-interaction embeddings for images in queries and contexts. This model is based on a `CLIPVisionModel`. It can be used like a `CLIPVisionModel` model for encoding images.",
1473
+ FLMR_VISION_ENCODERS_START_DOCSTRING,
1474
+ )
1475
+ class FLMRVisionModel(FLMRPreTrainedModel):
1476
+ base_model_prefix = "vision_model"
1477
+ config_class = FLMRVisionConfig
1478
+ main_input_name = "pixel_values"
1479
+ _no_split_modules = ["CLIPEncoderLayer"]
1480
+
1481
+ def __init__(self, config: FLMRVisionConfig):
1482
+ super().__init__(config)
1483
+ self.vision_model = CLIPVisionModel(config)
1484
+ self.post_init()
1485
+
1486
+ def get_input_embeddings(self) -> nn.Module:
1487
+ return self.vision_model.vision_model.embeddings.patch_embedding
1488
+
1489
+ @add_start_docstrings_to_model_forward(FLMR_VISION_ENCODERS_INPUTS_DOCSTRING)
1490
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=FLMRVisionConfig)
1491
+ def forward(
1492
+ self,
1493
+ pixel_values: Optional[torch.FloatTensor] = None,
1494
+ output_attentions: Optional[bool] = None,
1495
+ output_hidden_states: Optional[bool] = None,
1496
+ return_dict: Optional[bool] = None,
1497
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1498
+ r"""
1499
+ Returns:
1500
+
1501
+ Examples:
1502
+
1503
+ ```python
1504
+ >>> from PIL import Image
1505
+ >>> import requests
1506
+ >>> from transformers import AutoProcessor, FLMRVisionModel
1507
+
1508
+ >>> model = FLMRVisionModel.from_pretrained("openai/clip-vit-base-patch32")
1509
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1510
+
1511
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1512
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1513
+
1514
+ >>> inputs = processor(images=image, return_tensors="pt")
1515
+
1516
+ >>> outputs = model(**inputs)
1517
+ >>> last_hidden_state = outputs.last_hidden_state
1518
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
1519
+ ```"""
1520
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1521
+
1522
+ return self.vision_model(
1523
+ pixel_values=pixel_values,
1524
+ output_attentions=output_attentions,
1525
+ output_hidden_states=output_hidden_states,
1526
+ return_dict=return_dict,
1527
+ )
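Earlier in this file, the query/context forward pass concatenates the vision and text token embeddings along the sequence dimension, concatenates the corresponding masks so they stay aligned, and L2-normalises the result before late interaction. A condensed sketch of that assembly step (tensor shapes are assumptions for illustration):

```python
import torch
import torch.nn.functional as F

def assemble_late_interaction_embeddings(vision_embeddings, text_embeddings, image_mask, text_mask):
    # vision_embeddings: [batch, n_image_tokens, dim]; text_embeddings: [batch, n_text_tokens, dim]
    # image_mask / text_mask: [batch, n_tokens, 1], with 1 for tokens that participate in retrieval
    D = torch.cat([vision_embeddings, text_embeddings], dim=1)  # token-level concatenation
    mask = torch.cat([image_mask, text_mask], dim=1)            # keep the mask aligned with D
    D = F.normalize(D, p=2, dim=2)                              # unit-norm per token embedding
    return D, mask
```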
query_tokenizer/sentencepiece.bpe.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cfc8146abe2a0488e9e2a0c56de7952f7c11ab059eca145a0a727afce0db2865
3
+ size 5069051
query_tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "cls_token": {
10
+ "content": "<s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "eos_token": {
17
+ "content": "</s>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "mask_token": {
24
+ "content": "<mask>",
25
+ "lstrip": true,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": {
31
+ "content": "<pad>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "sep_token": {
38
+ "content": "</s>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "<unk>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
query_tokenizer/tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:249df0778f236f6ece390de0de746838ef25b9d6954b68c2ee71249e0a9d8fd4
3
+ size 17082799
query_tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,55 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<s>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<pad>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "</s>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "250001": {
36
+ "content": "<mask>",
37
+ "lstrip": true,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "bos_token": "<s>",
45
+ "clean_up_tokenization_spaces": true,
46
+ "cls_token": "<s>",
47
+ "eos_token": "</s>",
48
+ "mask_token": "<mask>",
49
+ "model_max_length": 8192,
50
+ "pad_token": "<pad>",
51
+ "sep_token": "</s>",
52
+ "sp_model_kwargs": {},
53
+ "tokenizer_class": "XLMRobertaTokenizer",
54
+ "unk_token": "<unk>"
55
+ }
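Per this configuration, the query tokenizer is an `XLMRobertaTokenizer` with `model_max_length` 8192, which matches the bge-m3 text encoder used by this checkpoint. A minimal loading sketch; the repository id and `subfolder` layout below are assumptions based on this upload:

```python
from transformers import AutoTokenizer

# Hypothetical repository id; adjust to the actual location of this upload.
query_tokenizer = AutoTokenizer.from_pretrained(
    "LinWeizheDragon/PreFLMR_ViT-L_ENCN", subfolder="query_tokenizer"
)
print(type(query_tokenizer).__name__)    # expected: XLMRobertaTokenizer or XLMRobertaTokenizerFast
print(query_tokenizer.model_max_length)  # expected: 8192
```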
segmented_maxsim.cpp ADDED
@@ -0,0 +1,97 @@
1
+ #include <pthread.h>
2
+ #include <torch/extension.h>
3
+
4
+ #include <algorithm>
5
+ #include <numeric>
6
+
7
+ typedef struct {
8
+ int tid;
9
+ int nthreads;
10
+
11
+ int ndocs;
12
+ int ndoc_vectors;
13
+ int nquery_vectors;
14
+
15
+ int64_t* lengths;
16
+ float* scores;
17
+ int64_t* offsets;
18
+
19
+ float* max_scores;
20
+ } max_args_t;
21
+
22
+ void* max(void* args) {
23
+ max_args_t* max_args = (max_args_t*)args;
24
+
25
+ int ndocs_per_thread =
26
+ std::ceil(((float)max_args->ndocs) / max_args->nthreads);
27
+ int start = max_args->tid * ndocs_per_thread;
28
+ int end = std::min((max_args->tid + 1) * ndocs_per_thread, max_args->ndocs);
29
+
30
+ auto max_scores_offset =
31
+ max_args->max_scores + (start * max_args->nquery_vectors);
32
+ auto scores_offset =
33
+ max_args->scores + (max_args->offsets[start] * max_args->nquery_vectors);
34
+
35
+ for (int i = start; i < end; i++) {
36
+ for (int j = 0; j < max_args->lengths[i]; j++) {
37
+ std::transform(max_scores_offset,
38
+ max_scores_offset + max_args->nquery_vectors,
39
+ scores_offset, max_scores_offset,
40
+ [](float a, float b) { return std::max(a, b); });
41
+ scores_offset += max_args->nquery_vectors;
42
+ }
43
+ max_scores_offset += max_args->nquery_vectors;
44
+ }
45
+
46
+ return NULL;
47
+ }
48
+
49
+ torch::Tensor segmented_maxsim(const torch::Tensor scores,
50
+ const torch::Tensor lengths) {
51
+ auto lengths_a = lengths.data_ptr<int64_t>();
52
+ auto scores_a = scores.data_ptr<float>();
53
+ auto ndocs = lengths.size(0);
54
+ auto ndoc_vectors = scores.size(0);
55
+ auto nquery_vectors = scores.size(1);
56
+ auto nthreads = at::get_num_threads();
57
+
58
+ torch::Tensor max_scores =
59
+ torch::zeros({ndocs, nquery_vectors}, scores.options());
60
+
61
+ int64_t offsets[ndocs + 1];
62
+ offsets[0] = 0;
63
+ std::partial_sum(lengths_a, lengths_a + ndocs, offsets + 1);
64
+
65
+ pthread_t threads[nthreads];
66
+ max_args_t args[nthreads];
67
+
68
+ for (int i = 0; i < nthreads; i++) {
69
+ args[i].tid = i;
70
+ args[i].nthreads = nthreads;
71
+
72
+ args[i].ndocs = ndocs;
73
+ args[i].ndoc_vectors = ndoc_vectors;
74
+ args[i].nquery_vectors = nquery_vectors;
75
+
76
+ args[i].lengths = lengths_a;
77
+ args[i].scores = scores_a;
78
+ args[i].offsets = offsets;
79
+
80
+ args[i].max_scores = max_scores.data_ptr<float>();
81
+
82
+ int rc = pthread_create(&threads[i], NULL, max, (void*)&args[i]);
83
+ if (rc) {
84
+ fprintf(stderr, "Unable to create thread %d: %d\n", i, rc);
85
+ }
86
+ }
87
+
88
+ for (int i = 0; i < nthreads; i++) {
89
+ pthread_join(threads[i], NULL);
90
+ }
91
+
92
+ return max_scores.sum(1);
93
+ }
94
+
95
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
96
+ m.def("segmented_maxsim_cpp", &segmented_maxsim, "Segmented MaxSim");
97
+ }
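For reference, a plain PyTorch sketch of what `segmented_maxsim_cpp` computes: `scores` packs the similarities of all document vectors against all query vectors, `lengths` gives the number of rows belonging to each document, and the output is one MaxSim score per document. The C++ extension parallelises this loop with pthreads; the sketch below only documents the semantics:

```python
import torch

def segmented_maxsim_reference(scores: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    # scores:  [sum(lengths), nquery_vectors]  packed doc-vector vs. query-vector similarities
    # lengths: [ndocs]                         number of document vectors per document
    per_doc = []
    for chunk in scores.split(lengths.tolist(), dim=0):
        if chunk.numel() == 0:
            per_doc.append(scores.new_zeros(()))           # empty documents score 0, as in the extension
        else:
            per_doc.append(chunk.max(dim=0).values.sum())  # MaxSim over doc vectors, summed over query vectors
    return torch.stack(per_doc)
```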
tokenization_flmr.py ADDED
@@ -0,0 +1,432 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team, The Hugging Face Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for FLMR."""
16
+
17
+
18
+ from typing import List, Optional, Union
19
+
20
+ from transformers.utils import TensorType, logging
21
+ from transformers.models.bert.tokenization_bert import BertTokenizer
22
+ from transformers import AutoTokenizer
23
+ from .configuration_flmr import FLMRTextConfig
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer_config.json"}
28
+
29
+ CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
30
+ "vocab_file": {
31
+ "LinWeizheDragon/PreFLMR_ViT-L": (
32
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/vocab.txt"
33
+ ),
34
+ "LinWeizheDragon/FLMR": (
35
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/vocab.txt"
36
+ ),
37
+ },
38
+ "tokenizer_file": {
39
+ "LinWeizheDragon/PreFLMR_ViT-L": (
40
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/tokenizer_config.json"
41
+ ),
42
+ "LinWeizheDragon/FLMR": (
43
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/tokenizer_config.json"
44
+ ),
45
+ },
46
+ }
47
+ QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
48
+ "vocab_file": {
49
+ "LinWeizheDragon/PreFLMR_ViT-L": (
50
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/vocab.txt"
51
+ ),
52
+ "LinWeizheDragon/FLMR": ("https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/vocab.txt"),
53
+ },
54
+ "tokenizer_file": {
55
+ "LinWeizheDragon/PreFLMR_ViT-L": (
56
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/tokenizer_config.json"
57
+ ),
58
+ "LinWeizheDragon/FLMR": (
59
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/tokenizer_config.json"
60
+ ),
61
+ },
62
+ }
63
+
64
+
65
+ CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
66
+ "LinWeizheDragon/PreFLMR_ViT-L": 512,
67
+ "LinWeizheDragon/FLMR": 512,
68
+ }
69
+ QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
70
+ "LinWeizheDragon/PreFLMR_ViT-L": 512,
71
+ "LinWeizheDragon/FLMR": 512,
72
+ }
73
+
74
+
75
+ CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
76
+ "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
77
+ "LinWeizheDragon/FLMR": {"do_lower_case": True},
78
+ }
79
+ QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
80
+ "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
81
+ "LinWeizheDragon/FLMR": {"do_lower_case": True},
82
+ }
83
+
84
+
85
+ # Modified from colbert.modeling.tokenization
86
+ class FLMRBertContextEncoderTokenizer(BertTokenizer):
87
+ r"""
88
+ Construct a FLMRBertContextEncoderTokenizer.
89
+
90
+ [`FLMRBertContextEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
91
+ splitting and wordpiece.
92
+
93
+ Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
94
+ """
95
+
96
+ vocab_files_names = VOCAB_FILES_NAMES
97
+ pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
98
+ max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
99
+ pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION
100
+
101
+ def __init__(
102
+ self,
103
+ doc_maxlen: Optional[int] = 512,
104
+ **kwargs,
105
+ ):
106
+ super().__init__(
107
+ doc_maxlen=doc_maxlen,
108
+ **kwargs,
109
+ )
110
+
111
+ self.doc_maxlen = doc_maxlen
112
+ self.D_marker_token, self.D_marker_token_id = "[D]", self.convert_tokens_to_ids("[unused1]")
113
+
114
+ def __call__(
115
+ self,
116
+ text: List[str],
117
+ padding: Optional[Union[str, bool]] = "max_length",
118
+ truncation: Optional[Union[bool, str]] = "longest_first",
119
+ max_length: Optional[int] = 512,
120
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
121
+ **kwargs,
122
+ ):
123
+ # add a placeholder for the [D] marker
124
+ text = [". " + x for x in text]
125
+
126
+ if max_length > self.doc_maxlen:
127
+ # cannot exceed the pre-set maximum document length
128
+ max_length = self.doc_maxlen
129
+
130
+ encoding = super().__call__(
131
+ text,
132
+ padding=padding,
133
+ truncation=truncation,
134
+ return_tensors=return_tensors,
135
+ max_length=max_length,
136
+ **kwargs,
137
+ )
138
+
139
+ ids, mask = encoding["input_ids"], encoding["attention_mask"]
140
+
141
+ # postprocess for the [D] marker
142
+ ids[:, 1] = self.D_marker_token_id
143
+
144
+ # if bsize:
145
+ # # This bsize function is used in the original ColBERT codebase to split inputs into multiple batches
146
+ # if image_features is not None:
147
+ # ids, mask, image_features, reverse_indices = _sort_by_length(ids, mask, bsize, image_features=image_features)
148
+ # batches = _split_into_batches(ids, mask, bsize, image_features=image_features)
149
+ # else:
150
+ # ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
151
+ # batches = _split_into_batches(ids, mask, bsize)
152
+
153
+ # return batches, reverse_indices
154
+
155
+ encoding["input_ids"] = ids
156
+ encoding["attention_mask"] = mask
157
+
158
+ return encoding
159
+
160
+
161
+ # Modified from colbert.modeling.tokenization
162
+ class FLMRBertQueryEncoderTokenizer(BertTokenizer):
163
+ r"""
164
+ Constructs a FLMRBertQueryEncoderTokenizer.
165
+
166
+ [`FLMRBertQueryEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
167
+ splitting and wordpiece.
168
+
169
+ Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
170
+ """
171
+
172
+ vocab_files_names = VOCAB_FILES_NAMES
173
+ pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
174
+ max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
175
+ pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION
176
+
177
+ def __init__(
178
+ self,
179
+ *args,
180
+ query_maxlen: Optional[int] = 32,
181
+ attend_to_mask_tokens: Optional[bool] = False,
182
+ **kwargs,
183
+ ):
184
+ super().__init__(
185
+ *args,
186
+ query_maxlen=query_maxlen,
187
+ attend_to_mask_tokens=attend_to_mask_tokens,
188
+ **kwargs,
189
+ )
190
+
191
+ self.query_maxlen = query_maxlen
192
+ self.background_maxlen = 512 - self.query_maxlen + 1 # FIXME: Make this configurable
193
+ self.attend_to_mask_tokens = attend_to_mask_tokens
194
+
195
+ self.Q_marker_token, self.Q_marker_token_id = "[Q]", self.convert_tokens_to_ids("[unused0]")
196
+
197
+ def __call__(
198
+ self,
199
+ text: Union[str, List[str]],
200
+ padding: Optional[Union[str, bool]] = "max_length",
201
+ truncation: Optional[Union[bool, str]] = True,
202
+ max_length: Optional[int] = None,
203
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
204
+ **kwargs,
205
+ ):
206
+ if isinstance(text, str):
207
+ # convert to list if input is a single string
208
+ text = [text]
209
+
210
+ # add a placeholder for the [Q] marker
211
+ text = [". " + x for x in text]
212
+
213
+ if max_length is not None:
214
+ # use user specified max_length
215
+ pass
216
+ else:
217
+ # use default max length
218
+ max_length = self.query_maxlen
219
+
220
+ encoding = super().__call__(
221
+ text,
222
+ padding=padding,
223
+ truncation=truncation,
224
+ return_tensors=return_tensors,
225
+ max_length=max_length,
226
+ **kwargs,
227
+ )
228
+
229
+ ids, mask = encoding["input_ids"], encoding["attention_mask"]
230
+
231
+ # postprocess for the [Q] marker and the [MASK] augmentation
232
+ ids[:, 1] = self.Q_marker_token_id
233
+ ids[ids == self.pad_token_id] = self.mask_token_id
234
+
235
+ if self.attend_to_mask_tokens:
236
+ # When attend_to_mask_tokens is True, we want to attend to the [MASK] tokens
237
+ mask[ids == self.mask_token_id] = 1
238
+ assert mask.sum().item() == mask.size(0) * mask.size(1), mask
239
+
240
+ return {"input_ids": ids, "attention_mask": mask}
241
+
242
+ class FLMRAutoContextEncoderTokenizer:
243
+ r"""
244
+ Construct a ContextEncoderTokenizer tokenizer with AutoTokenizer.
245
+
246
+ [`FLMRAutoContextEncoderTokenizer`] wraps [`AutoTokenizer`] and delegates end-to-end tokenization to the
247
+ underlying tokenizer of the configured text encoder.
248
+
249
+ Refer to [`AutoTokenizer`] for usage examples and documentation concerning parameters.
250
+ """
251
+
252
+ vocab_files_names = VOCAB_FILES_NAMES
253
+ pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
254
+ max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
255
+ pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION
256
+
257
+ def __init__(
258
+ self,
259
+ *args,
260
+ doc_maxlen: Optional[int] = 512,
261
+ **kwargs,
262
+ ):
263
+ self.doc_maxlen = doc_maxlen
264
+ self.tokenizer = AutoTokenizer.from_pretrained(*args, **kwargs)
265
+ self.additional_special_tokens = self.tokenizer.additional_special_tokens
266
+
267
+
268
+ def __call__(
269
+ self,
270
+ text: List[str],
271
+ padding: Optional[Union[str, bool]] = "max_length",
272
+ truncation: Optional[Union[bool, str]] = "longest_first",
273
+ max_length: Optional[int] = 512,
274
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
275
+ **kwargs,
276
+ ):
277
+ # add a placeholder for the [D] marker
278
+ text = [". " + x for x in text]
279
+
280
+ if max_length > self.doc_maxlen:
281
+ # cannot exceed the pre-set maximum document length
282
+ max_length = self.doc_maxlen
283
+
284
+ encoding = self.tokenizer(
285
+ text,
286
+ padding=padding,
287
+ truncation=True,
288
+ return_tensors=return_tensors,
289
+ max_length=max_length,
290
+ **kwargs,
291
+ )
292
+
293
+ ids, mask = encoding["input_ids"], encoding["attention_mask"]
294
+
295
+ encoding["input_ids"] = ids
296
+ encoding["attention_mask"] = mask
297
+
298
+ return encoding
299
+
300
+ def encode(self, text, text_pair=None, add_special_tokens=True, **kwargs):
301
+ return self.tokenizer.encode(text, text_pair, add_special_tokens, **kwargs)
302
+
303
+ def add_special_tokens(self, token, **kwargs):
304
+ return self.tokenizer.add_special_tokens(token, **kwargs)
305
+
306
+ def save_pretrained(self, path):
307
+ self.tokenizer.save_pretrained(path)
308
+
309
+
310
+ # Modified from colbert.modeling.tokenization
311
+ class FLMRAutoQueryEncoderTokenizer:
312
+ r"""
313
+ Constructs a QueryEncoderTokenizer tokenizer with AutoTokenizer.
314
+
315
+ [`FLMRAutoQueryEncoderTokenizer`] wraps [`AutoTokenizer`] and delegates end-to-end tokenization to the
316
+ underlying tokenizer of the configured text encoder.
317
+
318
+ Refer to [`AutoTokenizer`] for usage examples and documentation concerning parameters.
319
+ """
320
+
321
+ vocab_files_names = VOCAB_FILES_NAMES
322
+ pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
323
+ max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
324
+ pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION
325
+
326
+ def __init__(
327
+ self,
328
+ *args,
329
+ query_maxlen: Optional[int] = 32,
330
+ attend_to_mask_tokens: Optional[bool] = False,
331
+ **kwargs,
332
+ ):
333
+ self.tokenizer = AutoTokenizer.from_pretrained(*args, **kwargs)
334
+ self.additional_special_tokens = self.tokenizer.additional_special_tokens
335
+ self.query_maxlen = query_maxlen
336
+ self.background_maxlen = 512 - self.query_maxlen + 1 # FIXME: Make this configurable
337
+ self.attend_to_mask_tokens = attend_to_mask_tokens
338
+
339
+ def __call__(
340
+ self,
341
+ text: Union[str, List[str]],
342
+ padding: Optional[Union[str, bool]] = "max_length",
343
+ truncation: Optional[Union[bool, str]] = True,
344
+ max_length: Optional[int] = None,
345
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
346
+ **kwargs,
347
+ ):
348
+ if isinstance(text, str):
349
+ # convert to list if input is a single string
350
+ text = [text]
351
+
352
+ # add a placeholder for the [Q] marker
353
+ text = [". " + x for x in text]
354
+
355
+ if max_length is not None:
356
+ # use user specified max_length
357
+ pass
358
+ else:
359
+ # use default max length
360
+ max_length = self.query_maxlen
361
+
362
+ encoding = self.tokenizer(
363
+ text,
364
+ padding=padding,
365
+ truncation=True,
366
+ return_tensors=return_tensors,
367
+ max_length=max_length,
368
+ **kwargs,
369
+ )
370
+
371
+ ids, mask = encoding["input_ids"], encoding["attention_mask"]
372
+
373
+ if self.attend_to_mask_tokens:
374
+ # When attend_to_mask_tokens is True, we want to attend to the [MASK] tokens
375
+ mask[ids == self.tokenizer.mask_token_id] = 1
376
+ assert mask.sum().item() == mask.size(0) * mask.size(1), mask
377
+
378
+ return {"input_ids": ids, "attention_mask": mask}
379
+
380
+ def encode(self, text, text_pair=None, add_special_tokens=True, **kwargs):
381
+ return self.tokenizer.encode(text, text_pair, add_special_tokens, **kwargs)
382
+
383
+ def add_special_tokens(self, token, **kwargs):
384
+ return self.tokenizer.add_special_tokens(token, **kwargs)
385
+
386
+ def save_pretrained(self, path):
387
+ self.tokenizer.save_pretrained(path)
388
+
389
+
390
+ class FLMRContextEncoderTokenizer:
391
+ r"""
392
+ Constructs a FLMRContextEncoderTokenizer tokenizer.
393
+
394
+ [`FLMRContextEncoderTokenizer`] resolves to a [`BertTokenizer`]- or [`AutoTokenizer`]-based tokenizer, depending on whether
395
+ the underlying text encoder is initialized from BERT.
396
+ """
397
+ def __init__(self) -> None:
398
+ pass
399
+
400
+ @classmethod
401
+ def from_pretrained(
402
+ cls,
403
+ *args,
404
+ text_config: Optional[FLMRTextConfig] = None,
405
+ **kwargs,
406
+ ):
407
+ if text_config.text_encoder_base_model == "bert-base-uncased":
408
+ return FLMRBertContextEncoderTokenizer.from_pretrained(*args, **kwargs)
409
+ else:
410
+ return FLMRAutoContextEncoderTokenizer(*args, **kwargs)
411
+
412
+ class FLMRQueryEncoderTokenizer:
413
+ r"""
414
+ Constructs a FLMRQueryEncoderTokenizer.
415
+
416
+ [`FLMRQueryEncoderTokenizer`] resolves to a [`BertTokenizer`]- or [`AutoTokenizer`]-based tokenizer, depending on whether
417
+ the underlying text encoder is initialized from BERT.
418
+ """
419
+ def __init__(self) -> None:
420
+ pass
421
+
422
+ @classmethod
423
+ def from_pretrained(
424
+ cls,
425
+ *args,
426
+ text_config: Optional[FLMRTextConfig] = None,
427
+ **kwargs,
428
+ ):
429
+ if text_config.text_encoder_base_model == "bert-base-uncased":
430
+ return FLMRBertQueryEncoderTokenizer.from_pretrained(*args, **kwargs)
431
+ else:
432
+ return FLMRAutoQueryEncoderTokenizer(*args, query_maxlen=text_config.query_maxlen, **kwargs)
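A minimal usage sketch of the dispatcher classes above, assuming a non-BERT `text_encoder_base_model` such as `BAAI/bge-m3` so that the `FLMRAuto*` wrappers are selected; the `FLMRTextConfig` field names are assumptions based on how they are read in this file, and the modules are assumed importable (e.g. from a local copy of this repository):

```python
from configuration_flmr import FLMRTextConfig
from tokenization_flmr import FLMRContextEncoderTokenizer, FLMRQueryEncoderTokenizer

# Assumed constructor arguments; both attributes are read by the from_pretrained dispatchers above.
text_config = FLMRTextConfig(text_encoder_base_model="BAAI/bge-m3", query_maxlen=32)

query_tokenizer = FLMRQueryEncoderTokenizer.from_pretrained("BAAI/bge-m3", text_config=text_config)
context_tokenizer = FLMRContextEncoderTokenizer.from_pretrained("BAAI/bge-m3", text_config=text_config)

queries = query_tokenizer(["What landmark is shown in this image?"])
contexts = context_tokenizer(["The Eiffel Tower is a wrought-iron tower in Paris."])
print(queries["input_ids"].shape)   # [1, 32]: padded/truncated to query_maxlen
print(contexts["input_ids"].shape)  # [1, 512]: padded/truncated to doc_maxlen
```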