LinWeizheDragon
commited on
Commit
•
66ae8fc
1
Parent(s):
4430fe1
Upload folder using huggingface_hub
Browse files- .gitattributes +2 -0
- README.md +84 -0
- config.json +56 -0
- configuration_flmr.py +397 -0
- context_tokenizer/sentencepiece.bpe.model +3 -0
- context_tokenizer/special_tokens_map.json +51 -0
- context_tokenizer/tokenizer.json +3 -0
- context_tokenizer/tokenizer_config.json +55 -0
- flmr_utils.py +77 -0
- model.safetensors +3 -0
- modeling_flmr.py +1527 -0
- query_tokenizer/sentencepiece.bpe.model +3 -0
- query_tokenizer/special_tokens_map.json +51 -0
- query_tokenizer/tokenizer.json +3 -0
- query_tokenizer/tokenizer_config.json +55 -0
- segmented_maxsim.cpp +97 -0
- tokenization_flmr.py +432 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
context_tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
37 |
+
query_tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
library_name: transformers
|
3 |
+
license: mit
|
4 |
+
language:
|
5 |
+
- en
|
6 |
+
- zh
|
7 |
+
tags:
|
8 |
+
- retrieval
|
9 |
+
- multi-modal
|
10 |
+
- knowledge-based visual question answering
|
11 |
+
- FLMR
|
12 |
+
- PreFLMR
|
13 |
+
---
|
14 |
+
|
15 |
+
# PreFLMR model card
|
16 |
+
|
17 |
+
PreFLMR is an open-source model for multimodal knowledge retrieval. It is a transformer-based model that uses a combination of text and image inputs to retrieve relevant documents from a large corpus.
|
18 |
+
|
19 |
+
## Model Details
|
20 |
+
|
21 |
+
PreFLMR_ViT-L_ENCN is based on PreFLMR_ViT-L, and the text_encoder is replaced with [bge-m3](https://huggingface.co/BAAI/bge-m3) for training. The training dataset includes [Chinese](https://huggingface.co/datasets/BByrneLab/multi_task_multi_modal_knowledge_retrieval_benchmark_M2KR_CN) and [English](https://huggingface.co/datasets/BByrneLab/multi_task_multi_modal_knowledge_retrieval_benchmark_M2KR) datasets.
|
22 |
+
|
23 |
+
### Model Description
|
24 |
+
|
25 |
+
- **Model type:** FLMRModelForRetrieval
|
26 |
+
- **Language(s) (NLP):** English Chinese
|
27 |
+
- **License:** MIT License
|
28 |
+
|
29 |
+
### Paper and resources for more detail
|
30 |
+
|
31 |
+
- **Blog Post for quick overview:** https://www.jinghong-chen.net/preflmr-sota-open-sourced-multi/
|
32 |
+
- **Paper:** https://arxiv.org/abs/2402.08327
|
33 |
+
- **Gradio Demo:** https://u60544-b8d4-53eaa55d.westx.seetacloud.com:8443/
|
34 |
+
- **Repository:** https://github.com/LinWeizheDragon/FLMR
|
35 |
+
- **Project Page:** https://preflmr.github.io/
|
36 |
+
|
37 |
+
## Uses
|
38 |
+
|
39 |
+
### Direct Use
|
40 |
+
|
41 |
+
This model can be used directly to retrieve documents from a large corpus using a combination of text and image input queries. The retrieval usage can be found in the [official implementation](https://github.com/LinWeizheDragon/FLMR).
|
42 |
+
|
43 |
+
### Downstream Use
|
44 |
+
|
45 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
46 |
+
|
47 |
+
This model can be used combined with language models to create a retrieval-augmented language model. The use for Knowledge-based VQA can be found in [RAVQA](https://github.com/linweizhedragon/retrieval-augmented-visual-question-answering)
|
48 |
+
|
49 |
+
## How to Get Started with the Model
|
50 |
+
|
51 |
+
For details of training, indexing, and performing retrieval, please refer to [here](https://github.com/LinWeizheDragon/FLMR).
|
52 |
+
|
53 |
+
## Training datasets
|
54 |
+
|
55 |
+
The model is pre-trained on three types of tasks with a total of nine datasets:
|
56 |
+
1. Image to Text retrieval: WIT, KVQA, and CC3M
|
57 |
+
2. Question to Text retrieval: MSMARCO
|
58 |
+
3. Image & Question to Text retrieval: LLaVA, OVEN, OKVQA, Infoseek and E-VQA
|
59 |
+
|
60 |
+
These datasets were converted to retrieval format. For details on the dataset split and conversion process, please refer to the paper [PreFLMR: Scaling Up Fine-Grained Late-Interaction Multi-modal Retrievers](https://arxiv.org/abs/2402.08327). We will release the proprocessed datasets soon.
|
61 |
+
|
62 |
+
|
63 |
+
## Evaluation datasets
|
64 |
+
We evaluate our models on WIT, LLaVA, OVEN, KVQA, IGLUE (subset of WIT), Infoseek, E-VQA, OKVQA and MSMARCO.
|
65 |
+
| Model | Vision Encoder | Text Encoder | Checkpoint Name | No. Param. | WIT(EN) | WIT(CN) | LLaVA(EN) | LLaVA(CN) | OVEN(EN) | OVEN(CN) | KVQA(EN) | KVQA(EN) | IGLUE(EN) | Infoseek(EN) | Infoseek(CN) | EVQA(EN) | EVQA(CN) | OKVQA(EN) | OKVQA(CN) | MSMARCO(EN) | MSMARCO(CN) |
|
66 |
+
| ------- | :------------- | ------------ | ------------------------------------------------------------ | ---------- | ------- | ------- | --------- | --------- | -------- | -------- | -------- | -------- | --------- | ------------ | ------------ | -------- | -------- | --------- | --------- | ----------- | ----------- |
|
67 |
+
| PreFLMR | ViT-L | Base-v2 | [LinWeizheDragon/PreFLMR_ViT-L](https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L) | 543M | 60.5 | 10.9 | 71.8 | 3.2 | 59.8 | 6.6 | 43.6 | 3.2 | 69.2 | 57.9 | 7.9 | 70.8 | 2.8 | 68.5 | 2.1 | 78.7 | 10.3 |
|
68 |
+
| PreFLMR | Vit-L_ENCN | bge-m3 | [LinWeizheDragon/PreFLMR_ViT-L_ENCN](https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L_ENCN) | 883M | 60.8 | 83.4 | 71.11 | 58.93 | 60.8 | 58.83 | 41.05 | 37.27 | | 41.91 | 39.70 | 57.97 | 46.64 | 13.87 | 13.32 | 82.6 | 82.33 |
|
69 |
+
|
70 |
+
For the evaluation metrics, WIT uses Recall@10, IGLUE uses Recall@1, and all the rest datasets use Recall@5.
|
71 |
+
|
72 |
+
|
73 |
+
## Citation
|
74 |
+
|
75 |
+
**BibTeX:**
|
76 |
+
```
|
77 |
+
@article{Lin_Mei_Chen_Byrne_2024,
|
78 |
+
title={PreFLMR: Scaling Up Fine-Grained Late-Interaction Multi-modal Retrievers},
|
79 |
+
url={http://arxiv.org/abs/2402.08327},
|
80 |
+
number={arXiv:2402.08327},
|
81 |
+
publisher={arXiv},
|
82 |
+
author={Lin, Weizhe and Mei, Jingbiao and Chen, Jinghong and Byrne, Bill},
|
83 |
+
year={2024}}
|
84 |
+
```
|
config.json
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "LinWeizheDragon/PreFLMR_ViT-L_ENCN",
|
3 |
+
"architectures": [
|
4 |
+
"FLMRModelForRetrieval"
|
5 |
+
],
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "configuration_flmr.FLMRConfig",
|
8 |
+
"AutoModel": "modeling_flmr.FLMRModelForRetrieval"
|
9 |
+
},
|
10 |
+
"context_concat_output_from_text_encoder": true,
|
11 |
+
"context_concat_output_from_vision_encoder": false,
|
12 |
+
"dim": 128,
|
13 |
+
"initializer_range": 0.02,
|
14 |
+
"load_cpu_extension": true,
|
15 |
+
"mapping_network_prefix_length": 32,
|
16 |
+
"mask_instruction_token": "\n:",
|
17 |
+
"mask_punctuation": true,
|
18 |
+
"model_type": "flmr",
|
19 |
+
"query_concat_output_from_text_encoder": true,
|
20 |
+
"query_concat_output_from_vision_encoder": true,
|
21 |
+
"query_mask_input_ids_skip_list": [
|
22 |
+
1
|
23 |
+
],
|
24 |
+
"separate_query_and_context_text_encoder": false,
|
25 |
+
"separate_query_and_context_vision_encoder": false,
|
26 |
+
"text_config": {
|
27 |
+
"architectures": [
|
28 |
+
"XLMRobertaModel"
|
29 |
+
],
|
30 |
+
"gradient_checkpointing": false,
|
31 |
+
"hidden_size": 1024,
|
32 |
+
"model_type": "flmr_text_model",
|
33 |
+
"projection_dim": 1024,
|
34 |
+
"text_encoder_base_model": "BAAI/bge-m3",
|
35 |
+
"query_maxlen": 64,
|
36 |
+
"use_cache": true
|
37 |
+
},
|
38 |
+
"torch_dtype": "float32",
|
39 |
+
"transformer_mapping_config_base": "google-bert/bert-large-uncased",
|
40 |
+
"transformer_mapping_cross_attention_length": 32,
|
41 |
+
"transformer_mapping_num_hidden_layers": 1,
|
42 |
+
"transformers_version": "4.44.2",
|
43 |
+
"use_transformer_mapping_network": false,
|
44 |
+
"use_vision_encoder": true,
|
45 |
+
"vision_config": {
|
46 |
+
"dropout": 0.0,
|
47 |
+
"hidden_size": 1024,
|
48 |
+
"intermediate_size": 4096,
|
49 |
+
"model_type": "flmr_vision_model",
|
50 |
+
"num_attention_heads": 16,
|
51 |
+
"num_hidden_layers": 24,
|
52 |
+
"patch_size": 14,
|
53 |
+
"projection_dim": 768
|
54 |
+
},
|
55 |
+
"vision_model_version": "openai/clip-vit-large-patch14"
|
56 |
+
}
|
configuration_flmr.py
ADDED
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2010, FLMR authors, The Hugging Face Team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" FLMR model configuration"""
|
16 |
+
|
17 |
+
import os
|
18 |
+
from typing import Union
|
19 |
+
|
20 |
+
from transformers.configuration_utils import PretrainedConfig
|
21 |
+
from transformers.utils import logging
|
22 |
+
|
23 |
+
|
24 |
+
logger = logging.get_logger(__name__)
|
25 |
+
|
26 |
+
FLMR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
27 |
+
"LinWeizheDragon/PreFLMR_ViT-L": "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/config.json",
|
28 |
+
"LinWeizheDragon/FLMR": "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/config.json",
|
29 |
+
}
|
30 |
+
|
31 |
+
|
32 |
+
# Modified from transformers.models.clip.configuration_clip.CLIPVisionConfig with CLIP -> FLMR
|
33 |
+
class FLMRVisionConfig(PretrainedConfig):
|
34 |
+
r"""
|
35 |
+
This is the configuration class to store the configuration of a [`FLMRVisionModel`]. It is used to instantiate a
|
36 |
+
FLMR vision encoder according to the specified arguments, defining the model architecture. Instantiating a
|
37 |
+
configuration with the defaults will yield a similar configuration to that of the vision encoder of the FLMR
|
38 |
+
[openai/flmr-vit-base-patch32](https://huggingface.co/openai/flmr-vit-base-patch32) architecture.
|
39 |
+
|
40 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
41 |
+
documentation from [`PretrainedConfig`] for more information.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
45 |
+
Dimensionality of the encoder layers and the pooler layer.
|
46 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
47 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
48 |
+
projection_dim (`int`, *optional*, defaults to 512):
|
49 |
+
Dimentionality of text and vision projection layers.
|
50 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
51 |
+
Number of hidden layers in the Transformer encoder.
|
52 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
53 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
54 |
+
num_channels (`int`, *optional*, defaults to 3):
|
55 |
+
The number of input channels.
|
56 |
+
image_size (`int`, *optional*, defaults to 224):
|
57 |
+
The size (resolution) of each image.
|
58 |
+
patch_size (`int`, *optional*, defaults to 32):
|
59 |
+
The size (resolution) of each patch.
|
60 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
|
61 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
62 |
+
`"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
|
63 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
|
64 |
+
The epsilon used by the layer normalization layers.
|
65 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
66 |
+
The dropout ratio for the attention probabilities.
|
67 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
68 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
69 |
+
initializer_factor (`float`, *optional*, defaults to 1.0):
|
70 |
+
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
|
71 |
+
testing).
|
72 |
+
|
73 |
+
Example:
|
74 |
+
|
75 |
+
```python
|
76 |
+
>>> from transformers import FLMRVisionConfig, FLMRVisionModel
|
77 |
+
|
78 |
+
>>> # Initializing a FLMRVisionConfig with LinWeizheDragon/FLMR style configuration
|
79 |
+
>>> configuration = FLMRVisionConfig()
|
80 |
+
|
81 |
+
>>> # Initializing a FLMRVisionModel (with random weights) from the LinWeizheDragon/FLMR style configuration
|
82 |
+
>>> model = FLMRVisionModel(configuration)
|
83 |
+
|
84 |
+
>>> # Accessing the model configuration
|
85 |
+
>>> configuration = model.config
|
86 |
+
```"""
|
87 |
+
|
88 |
+
model_type = "flmr_vision_model"
|
89 |
+
|
90 |
+
def __init__(
|
91 |
+
self,
|
92 |
+
hidden_size=768,
|
93 |
+
intermediate_size=3072,
|
94 |
+
projection_dim=512,
|
95 |
+
num_hidden_layers=12,
|
96 |
+
num_attention_heads=12,
|
97 |
+
num_channels=3,
|
98 |
+
image_size=224,
|
99 |
+
patch_size=32,
|
100 |
+
hidden_act="quick_gelu",
|
101 |
+
layer_norm_eps=1e-5,
|
102 |
+
attention_dropout=0.0,
|
103 |
+
initializer_range=0.02,
|
104 |
+
initializer_factor=1.0,
|
105 |
+
**kwargs,
|
106 |
+
):
|
107 |
+
super().__init__(**kwargs)
|
108 |
+
|
109 |
+
self.hidden_size = hidden_size
|
110 |
+
self.intermediate_size = intermediate_size
|
111 |
+
self.projection_dim = projection_dim
|
112 |
+
self.num_hidden_layers = num_hidden_layers
|
113 |
+
self.num_attention_heads = num_attention_heads
|
114 |
+
self.num_channels = num_channels
|
115 |
+
self.patch_size = patch_size
|
116 |
+
self.image_size = image_size
|
117 |
+
self.initializer_range = initializer_range
|
118 |
+
self.initializer_factor = initializer_factor
|
119 |
+
self.attention_dropout = attention_dropout
|
120 |
+
self.layer_norm_eps = layer_norm_eps
|
121 |
+
self.hidden_act = hidden_act
|
122 |
+
|
123 |
+
@classmethod
|
124 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
125 |
+
cls._set_token_in_kwargs(kwargs)
|
126 |
+
|
127 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
128 |
+
|
129 |
+
# get the vision config dict if we are loading from a CLIPConfig
|
130 |
+
if config_dict.get("model_type") == "clip":
|
131 |
+
config_dict = config_dict["vision_config"]
|
132 |
+
|
133 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
134 |
+
logger.warning(
|
135 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
136 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
137 |
+
)
|
138 |
+
|
139 |
+
return cls.from_dict(config_dict, **kwargs)
|
140 |
+
|
141 |
+
|
142 |
+
# Modified from transformers.models.dpr.configuration_dpr.DPRConfig with DPR -> FLMR
|
143 |
+
class FLMRTextConfig(PretrainedConfig):
|
144 |
+
r"""
|
145 |
+
[`FLMRTextConfig`] is the configuration class to store the configuration of a *FLMRTextModel*.
|
146 |
+
|
147 |
+
This is the configuration class to store the configuration of a [`FLMRTextModel`]. It is used to instantiate the components of the FLMR model according to the specified arguments,
|
148 |
+
defining the model component architectures. Instantiating a configuration with the defaults will yield a similar
|
149 |
+
configuration to that of the DPRContextEncoder
|
150 |
+
[facebook/dpr-ctx_encoder-single-nq-base](https://huggingface.co/facebook/dpr-ctx_encoder-single-nq-base)
|
151 |
+
architecture.
|
152 |
+
|
153 |
+
This class is a subclass of [`BertConfig`]. Please check the superclass for the documentation of all kwargs.
|
154 |
+
|
155 |
+
Args:
|
156 |
+
vocab_size (`int`, *optional*, defaults to 30522):
|
157 |
+
Vocabulary size of the FLMR model. Defines the different tokens that can be represented by the *inputs_ids*
|
158 |
+
passed to the forward method of [`BertModel`].
|
159 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
160 |
+
Dimensionality of the encoder layers and the pooler layer.
|
161 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
162 |
+
Number of hidden layers in the Transformer encoder.
|
163 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
164 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
165 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
166 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
167 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
168 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
169 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
170 |
+
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
|
171 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
172 |
+
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
|
173 |
+
The dropout ratio for the attention probabilities.
|
174 |
+
max_position_embeddings (`int`, *optional*, defaults to 512):
|
175 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
176 |
+
just in case (e.g., 512 or 1024 or 2048).
|
177 |
+
type_vocab_size (`int`, *optional*, defaults to 2):
|
178 |
+
The vocabulary size of the *token_type_ids* passed into [`BertModel`].
|
179 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
180 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
181 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
182 |
+
The epsilon used by the layer normalization layers.
|
183 |
+
pad_token_id (`int`, *optional*, defaults to 0):
|
184 |
+
Padding token id.
|
185 |
+
position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
|
186 |
+
Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
|
187 |
+
positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
|
188 |
+
[Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
|
189 |
+
For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
|
190 |
+
with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
|
191 |
+
projection_dim (`int`, *optional*, defaults to 0):
|
192 |
+
Dimension of the projection for the context and question encoders. If it is set to zero (default), then no
|
193 |
+
projection is done.
|
194 |
+
text_encoder_base_model (`str`, *optional*, defaults to `"bert-base-uncased"`):
|
195 |
+
The text_encoder flmr based on.
|
196 |
+
query_maxlen (`int`, *optional*, defaults to 32)
|
197 |
+
The max_length for query tokenizer encoding.
|
198 |
+
|
199 |
+
Example:
|
200 |
+
|
201 |
+
```python
|
202 |
+
>>> from transformers import FLMRTextConfig, FLMRTextModel
|
203 |
+
|
204 |
+
>>> # Initializing a FLMR LinWeizheDragon/FLMR style configuration
|
205 |
+
>>> configuration = FLMRTextConfig()
|
206 |
+
|
207 |
+
>>> # Initializing a model (with random weights) from the LinWeizheDragon/FLMR style configuration
|
208 |
+
>>> model = FLMRTextModel(configuration)
|
209 |
+
|
210 |
+
>>> # Accessing the model configuration
|
211 |
+
>>> configuration = model.config
|
212 |
+
```"""
|
213 |
+
|
214 |
+
model_type = "flmr_text_model"
|
215 |
+
|
216 |
+
def __init__(
|
217 |
+
self,
|
218 |
+
vocab_size=30522,
|
219 |
+
hidden_size=768,
|
220 |
+
num_hidden_layers=12,
|
221 |
+
num_attention_heads=12,
|
222 |
+
intermediate_size=3072,
|
223 |
+
hidden_act="gelu",
|
224 |
+
hidden_dropout_prob=0.1,
|
225 |
+
attention_probs_dropout_prob=0.1,
|
226 |
+
max_position_embeddings=512,
|
227 |
+
type_vocab_size=2,
|
228 |
+
initializer_range=0.02,
|
229 |
+
layer_norm_eps=1e-12,
|
230 |
+
pad_token_id=0,
|
231 |
+
position_embedding_type="absolute",
|
232 |
+
projection_dim: int = 0,
|
233 |
+
text_encoder_base_model="bert-base-uncased",
|
234 |
+
query_maxlen: int = 32,
|
235 |
+
**kwargs,
|
236 |
+
):
|
237 |
+
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
238 |
+
|
239 |
+
self.vocab_size = vocab_size
|
240 |
+
self.hidden_size = hidden_size
|
241 |
+
self.num_hidden_layers = num_hidden_layers
|
242 |
+
self.num_attention_heads = num_attention_heads
|
243 |
+
self.hidden_act = hidden_act
|
244 |
+
self.intermediate_size = intermediate_size
|
245 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
246 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
247 |
+
self.max_position_embeddings = max_position_embeddings
|
248 |
+
self.type_vocab_size = type_vocab_size
|
249 |
+
self.initializer_range = initializer_range
|
250 |
+
self.layer_norm_eps = layer_norm_eps
|
251 |
+
self.projection_dim = projection_dim
|
252 |
+
self.text_encoder_base_model = text_encoder_base_model
|
253 |
+
self.position_embedding_type = position_embedding_type
|
254 |
+
self.query_maxlen = query_maxlen
|
255 |
+
|
256 |
+
|
257 |
+
class FLMRConfig(PretrainedConfig):
|
258 |
+
r"""
|
259 |
+
[`FLMRConfig`] is the configuration class to store the configuration of a *FLMRModelForRetrieval*.
|
260 |
+
This is the configuration class to store the configuration of a [`FLMRModelForRetrieval`]. It is used to instantiate the components of the FLMR model according to the specified arguments,
|
261 |
+
defining the model component architectures. Instantiating a configuration with the defaults will yield a similar
|
262 |
+
configuration to that of the FLMR
|
263 |
+
[LinWeizheDragon/PreFLMR_ViT-G](https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-G)
|
264 |
+
architecture.
|
265 |
+
|
266 |
+
Args:
|
267 |
+
vision_config (`FLMRVisionConfig`, *optional*):
|
268 |
+
Configuration for the vision encoder.
|
269 |
+
text_config (`FLMRTextConfig`, *optional*):
|
270 |
+
Configuration for the text encoder.
|
271 |
+
mask_punctuation (`bool`, *optional*, defaults to `True`):
|
272 |
+
Whether to mask punctuation tokens in the input.
|
273 |
+
mapping_network_prefix_length (`int`, *optional*, defaults to 32):
|
274 |
+
The output length of the linear mapping network.
|
275 |
+
dim (`int`, *optional*, defaults to 128):
|
276 |
+
The late-interaction dimension of the model. The output of the text encoder, vision encoder, transformer mapping network should all be projected to this dimension for late-interaction scoring.
|
277 |
+
use_vision_encoder (`bool`, *optional*, defaults to `True`):
|
278 |
+
Whether to load the vision encoder. When no vision encoder is loaded, `image_features` should be used in the forward pass rather than `pixel_values`.
|
279 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
280 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
281 |
+
separate_query_and_context_text_encoder (`bool`, *optional*, defaults to `False`):
|
282 |
+
Whether to use separate text encoders for query and context.
|
283 |
+
separate_query_and_context_vision_encoder (`bool`, *optional*, defaults to `False`):
|
284 |
+
Whether to use separate vision encoders for query and context.
|
285 |
+
query_concat_output_from_vision_encoder (`bool`, *optional*, defaults to `True`):
|
286 |
+
Whether to concatenate the output from the vision encoder to the output from the text encoder for the query.
|
287 |
+
query_concat_output_from_text_encoder (`bool`, *optional*, defaults to `True`):
|
288 |
+
Whether to concatenate the output from the text encoder to the output from the vision encoder for the query.
|
289 |
+
context_concat_output_from_vision_encoder (`bool`, *optional*, defaults to `False`):
|
290 |
+
Whether to concatenate the output from the vision encoder to the output from the text encoder for the context.
|
291 |
+
context_concat_output_from_text_encoder (`bool`, *optional*, defaults to `True`):
|
292 |
+
Whether to concatenate the output from the text encoder to the output from the vision encoder for the context.
|
293 |
+
use_transformer_mapping_network (`bool`, *optional*, defaults to `False`):
|
294 |
+
Whether to add a transformer mapping network to map the features from the vision encoder to the embedding space. This option is used in PreFLMR.
|
295 |
+
transformer_mapping_config_base (`str`, *optional*):
|
296 |
+
The base configuration for the transformer mapping network. This option is used in PreFLMR. An example of this argument is `bert-base-uncased`.
|
297 |
+
transformer_mapping_num_hidden_layers (`int`, *optional*):
|
298 |
+
The number of hidden layers in the transformer mapping network. This option is used in PreFLMR.
|
299 |
+
load_cpu_extension (`bool`, *optional*, defaults to `False`):
|
300 |
+
Whether to load the CPU extension. Only set this to `True` if a CPU is used in training and inference. In any case, GPU is recommended for training and inference.
|
301 |
+
mask_instruction_token (`str`, *optional*):
|
302 |
+
The token that indicates the end of the input instruction. All tokens before this token (the first one in a sequence) will be masked. This option is used in PreFLMR.
|
303 |
+
transformer_mapping_cross_attention_length (`int`, *optional*, defaults to 32):
|
304 |
+
The length of the cross attention in the transformer mapping network. This option is used in PreFLMR.
|
305 |
+
vision_model_version (`str`, *optional*, defaults to `"openai/clip-vit-base-patch32"`):
|
306 |
+
The version of the vision model being used in this FLMR model.
|
307 |
+
This option is used in performing retrieval only. Though it does not affect the model architecture, it is highly recommended to set this argument so that it properly reflects the version of the vision model being used in the FLMR model. This arugment will be saved in the model configuration, and it can be read by the indexing engine. The indexing engine will use this argument to initialize an image processor, which can process the input image files. Find more details under `examples/research_projects/flmr-retrieval`.
|
308 |
+
query_mask_input_ids_skip_list (`List`, *optional*, defaults to `[]`):
|
309 |
+
The input_ids need to skip when execute query_mask.
|
310 |
+
|
311 |
+
Example:
|
312 |
+
|
313 |
+
```python
|
314 |
+
>>> from transformers import FLMRConfig, FLMRModelForRetrieval
|
315 |
+
|
316 |
+
>>> # Initializing a FLMR LinWeizheDragon/FLMR style configuration
|
317 |
+
>>> configuration = FLMRConfig()
|
318 |
+
|
319 |
+
>>> # Initializing a model (with random weights) from the FLMR style configuration
|
320 |
+
>>> model = FLMRModelForRetrieval(configuration)
|
321 |
+
|
322 |
+
>>> # Accessing the model configuration
|
323 |
+
>>> configuration = model.config
|
324 |
+
```"""
|
325 |
+
|
326 |
+
model_type = "flmr"
|
327 |
+
|
328 |
+
def __init__(
|
329 |
+
self,
|
330 |
+
vision_config: FLMRVisionConfig = None,
|
331 |
+
text_config: FLMRTextConfig = None,
|
332 |
+
mask_punctuation: bool = True,
|
333 |
+
mapping_network_prefix_length: int = 32,
|
334 |
+
dim: int = 128,
|
335 |
+
use_vision_encoder: bool = True,
|
336 |
+
initializer_range: float = 0.02,
|
337 |
+
separate_query_and_context_text_encoder: bool = False,
|
338 |
+
separate_query_and_context_vision_encoder: bool = False,
|
339 |
+
query_concat_output_from_vision_encoder: bool = True,
|
340 |
+
query_concat_output_from_text_encoder: bool = True,
|
341 |
+
context_concat_output_from_vision_encoder: bool = False,
|
342 |
+
context_concat_output_from_text_encoder: bool = True,
|
343 |
+
use_transformer_mapping_network: bool = False,
|
344 |
+
transformer_mapping_config_base: str = None,
|
345 |
+
transformer_mapping_num_hidden_layers: int = None,
|
346 |
+
load_cpu_extension: bool = False,
|
347 |
+
mask_instruction_token: str = None,
|
348 |
+
transformer_mapping_cross_attention_length: int = 32,
|
349 |
+
vision_model_version: str = "openai/clip-vit-base-patch32",
|
350 |
+
query_mask_input_ids_skip_list: list = [],
|
351 |
+
**kwargs,
|
352 |
+
):
|
353 |
+
super().__init__(**kwargs)
|
354 |
+
|
355 |
+
if vision_config is None:
|
356 |
+
vision_config = {}
|
357 |
+
if text_config is None:
|
358 |
+
text_config = {}
|
359 |
+
|
360 |
+
if not isinstance(vision_config, FLMRVisionConfig):
|
361 |
+
vision_config = FLMRVisionConfig(**vision_config)
|
362 |
+
if not isinstance(text_config, FLMRTextConfig):
|
363 |
+
text_config = FLMRTextConfig(**text_config)
|
364 |
+
|
365 |
+
self.vision_config = vision_config
|
366 |
+
self.text_config = text_config
|
367 |
+
self.dim = dim
|
368 |
+
self.initializer_range = initializer_range
|
369 |
+
self.mask_punctuation = mask_punctuation
|
370 |
+
self.mapping_network_prefix_length = mapping_network_prefix_length
|
371 |
+
self.use_vision_encoder = use_vision_encoder
|
372 |
+
self.separate_query_and_context_text_encoder = separate_query_and_context_text_encoder
|
373 |
+
self.separate_query_and_context_vision_encoder = separate_query_and_context_vision_encoder
|
374 |
+
self.query_concat_output_from_vision_encoder = query_concat_output_from_vision_encoder
|
375 |
+
self.query_concat_output_from_text_encoder = query_concat_output_from_text_encoder
|
376 |
+
self.context_concat_output_from_vision_encoder = context_concat_output_from_vision_encoder
|
377 |
+
self.context_concat_output_from_text_encoder = context_concat_output_from_text_encoder
|
378 |
+
self.use_transformer_mapping_network = use_transformer_mapping_network
|
379 |
+
self.transformer_mapping_config_base = transformer_mapping_config_base
|
380 |
+
self.transformer_mapping_num_hidden_layers = transformer_mapping_num_hidden_layers
|
381 |
+
self.load_cpu_extension = load_cpu_extension
|
382 |
+
self.mask_instruction_token = mask_instruction_token
|
383 |
+
self.transformer_mapping_cross_attention_length = transformer_mapping_cross_attention_length
|
384 |
+
self.vision_model_version = vision_model_version
|
385 |
+
self.query_mask_input_ids_skip_list = query_mask_input_ids_skip_list
|
386 |
+
|
387 |
+
@classmethod
|
388 |
+
def from_text_vision_configs(cls, text_config: FLMRTextConfig, vision_config: FLMRVisionConfig, **kwargs):
|
389 |
+
r"""
|
390 |
+
Instantiate a [`FLMRConfig`] (or a derived class) from FLMR text model configuration and FLMR vision model
|
391 |
+
configuration.
|
392 |
+
|
393 |
+
Returns:
|
394 |
+
[`FLMRConfig`]: An instance of a configuration object
|
395 |
+
"""
|
396 |
+
|
397 |
+
return cls(text_config=text_config, vision_config=vision_config, **kwargs)
|
context_tokenizer/sentencepiece.bpe.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cfc8146abe2a0488e9e2a0c56de7952f7c11ab059eca145a0a727afce0db2865
|
3 |
+
size 5069051
|
context_tokenizer/special_tokens_map.json
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": {
|
3 |
+
"content": "<s>",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": false,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
},
|
9 |
+
"cls_token": {
|
10 |
+
"content": "<s>",
|
11 |
+
"lstrip": false,
|
12 |
+
"normalized": false,
|
13 |
+
"rstrip": false,
|
14 |
+
"single_word": false
|
15 |
+
},
|
16 |
+
"eos_token": {
|
17 |
+
"content": "</s>",
|
18 |
+
"lstrip": false,
|
19 |
+
"normalized": false,
|
20 |
+
"rstrip": false,
|
21 |
+
"single_word": false
|
22 |
+
},
|
23 |
+
"mask_token": {
|
24 |
+
"content": "<mask>",
|
25 |
+
"lstrip": true,
|
26 |
+
"normalized": false,
|
27 |
+
"rstrip": false,
|
28 |
+
"single_word": false
|
29 |
+
},
|
30 |
+
"pad_token": {
|
31 |
+
"content": "<pad>",
|
32 |
+
"lstrip": false,
|
33 |
+
"normalized": false,
|
34 |
+
"rstrip": false,
|
35 |
+
"single_word": false
|
36 |
+
},
|
37 |
+
"sep_token": {
|
38 |
+
"content": "</s>",
|
39 |
+
"lstrip": false,
|
40 |
+
"normalized": false,
|
41 |
+
"rstrip": false,
|
42 |
+
"single_word": false
|
43 |
+
},
|
44 |
+
"unk_token": {
|
45 |
+
"content": "<unk>",
|
46 |
+
"lstrip": false,
|
47 |
+
"normalized": false,
|
48 |
+
"rstrip": false,
|
49 |
+
"single_word": false
|
50 |
+
}
|
51 |
+
}
|
context_tokenizer/tokenizer.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:249df0778f236f6ece390de0de746838ef25b9d6954b68c2ee71249e0a9d8fd4
|
3 |
+
size 17082799
|
context_tokenizer/tokenizer_config.json
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"added_tokens_decoder": {
|
3 |
+
"0": {
|
4 |
+
"content": "<s>",
|
5 |
+
"lstrip": false,
|
6 |
+
"normalized": false,
|
7 |
+
"rstrip": false,
|
8 |
+
"single_word": false,
|
9 |
+
"special": true
|
10 |
+
},
|
11 |
+
"1": {
|
12 |
+
"content": "<pad>",
|
13 |
+
"lstrip": false,
|
14 |
+
"normalized": false,
|
15 |
+
"rstrip": false,
|
16 |
+
"single_word": false,
|
17 |
+
"special": true
|
18 |
+
},
|
19 |
+
"2": {
|
20 |
+
"content": "</s>",
|
21 |
+
"lstrip": false,
|
22 |
+
"normalized": false,
|
23 |
+
"rstrip": false,
|
24 |
+
"single_word": false,
|
25 |
+
"special": true
|
26 |
+
},
|
27 |
+
"3": {
|
28 |
+
"content": "<unk>",
|
29 |
+
"lstrip": false,
|
30 |
+
"normalized": false,
|
31 |
+
"rstrip": false,
|
32 |
+
"single_word": false,
|
33 |
+
"special": true
|
34 |
+
},
|
35 |
+
"250001": {
|
36 |
+
"content": "<mask>",
|
37 |
+
"lstrip": true,
|
38 |
+
"normalized": false,
|
39 |
+
"rstrip": false,
|
40 |
+
"single_word": false,
|
41 |
+
"special": true
|
42 |
+
}
|
43 |
+
},
|
44 |
+
"bos_token": "<s>",
|
45 |
+
"clean_up_tokenization_spaces": true,
|
46 |
+
"cls_token": "<s>",
|
47 |
+
"eos_token": "</s>",
|
48 |
+
"mask_token": "<mask>",
|
49 |
+
"model_max_length": 8192,
|
50 |
+
"pad_token": "<pad>",
|
51 |
+
"sep_token": "</s>",
|
52 |
+
"sp_model_kwargs": {},
|
53 |
+
"tokenizer_class": "XLMRobertaTokenizer",
|
54 |
+
"unk_token": "<unk>"
|
55 |
+
}
|
flmr_utils.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file contains utility functions for the FLMR model. Some of these functions are adapted from the original ColBERT codebase.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.distributed as dist
|
7 |
+
|
8 |
+
|
9 |
+
def get_rank():
|
10 |
+
return dist.get_rank()
|
11 |
+
|
12 |
+
|
13 |
+
def get_world_size():
|
14 |
+
return dist.get_world_size()
|
15 |
+
|
16 |
+
|
17 |
+
def get_default_group():
|
18 |
+
return dist.group.WORLD
|
19 |
+
|
20 |
+
|
21 |
+
# TODO: The masking below might also be applicable in the kNN part
|
22 |
+
def colbert_score_reduce(scores_padded, D_mask):
|
23 |
+
# print('D_mask', D_mask.shape, D_mask)
|
24 |
+
D_padding = ~D_mask.view(scores_padded.size(0), scores_padded.size(1)).bool()
|
25 |
+
# print('D_padding', D_padding.shape, D_padding)
|
26 |
+
# print(D_padding[0].tolist())
|
27 |
+
scores_padded[D_padding] = -9999
|
28 |
+
scores = scores_padded.max(1).values
|
29 |
+
|
30 |
+
return scores.sum(-1)
|
31 |
+
|
32 |
+
|
33 |
+
def colbert_score(Q, D_padded, D_mask, use_gpu=False):
|
34 |
+
"""
|
35 |
+
Supply sizes Q = (1 | num_docs, *, dim) and D = (num_docs, *, dim).
|
36 |
+
If Q.size(0) is 1, the matrix will be compared with all passages.
|
37 |
+
Otherwise, each query matrix will be compared against the *aligned* passage.
|
38 |
+
|
39 |
+
EVENTUALLY: Consider masking with -inf for the maxsim (or enforcing a ReLU).
|
40 |
+
"""
|
41 |
+
if use_gpu:
|
42 |
+
Q, D_padded, D_mask = Q.cuda(), D_padded.cuda(), D_mask.cuda()
|
43 |
+
assert Q.dim() == 3, Q.size()
|
44 |
+
assert D_padded.dim() == 3, D_padded.size()
|
45 |
+
assert Q.size(0) in [1, D_padded.size(0)]
|
46 |
+
|
47 |
+
scores = D_padded @ Q.to(dtype=D_padded.dtype).permute(0, 2, 1)
|
48 |
+
|
49 |
+
return colbert_score_reduce(scores, D_mask)
|
50 |
+
|
51 |
+
|
52 |
+
def _sort_by_length(ids, mask, bsize, *args):
|
53 |
+
if ids.size(0) <= bsize:
|
54 |
+
return ids, mask, torch.arange(ids.size(0))
|
55 |
+
|
56 |
+
indices = mask.sum(-1).sort().indices
|
57 |
+
reverse_indices = indices.sort().indices
|
58 |
+
|
59 |
+
return_array = [ids[indices], mask[indices]]
|
60 |
+
for arg in args:
|
61 |
+
if isinstance(arg, torch.Tensor):
|
62 |
+
return_array.append(arg[indices])
|
63 |
+
else:
|
64 |
+
# arg is a list, and we want to sort the list according to indices
|
65 |
+
return_array.append([arg[i] for i in indices])
|
66 |
+
|
67 |
+
return *return_array, reverse_indices
|
68 |
+
|
69 |
+
|
70 |
+
def _split_into_batches(ids, mask, bsize, *args):
|
71 |
+
batches = []
|
72 |
+
for offset in range(0, ids.size(0), bsize):
|
73 |
+
batch = [ids[offset : offset + bsize], mask[offset : offset + bsize]]
|
74 |
+
for arg in args:
|
75 |
+
batch.append(arg[offset : offset + bsize])
|
76 |
+
batches.append(batch)
|
77 |
+
return batches
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:59cf447aaa66331417b1b346d8fdde1c21f737884364140dafe8f518e2a01ec6
|
3 |
+
size 3530548760
|
modeling_flmr.py
ADDED
@@ -0,0 +1,1527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 FLMR Authors, The Hugging Face Team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch FLMR model for Knowledge-intensive Visual Question Answering."""
|
16 |
+
|
17 |
+
|
18 |
+
import copy
|
19 |
+
import os
|
20 |
+
import pathlib
|
21 |
+
import string
|
22 |
+
from dataclasses import dataclass
|
23 |
+
from typing import Optional, Tuple, Union
|
24 |
+
|
25 |
+
import torch
|
26 |
+
import torch.distributed as dist
|
27 |
+
from torch import Tensor, nn
|
28 |
+
from torch.utils.cpp_extension import load
|
29 |
+
|
30 |
+
from transformers import AutoModel, AutoConfig
|
31 |
+
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
32 |
+
from transformers.modeling_utils import PreTrainedModel
|
33 |
+
from transformers.utils import (
|
34 |
+
ModelOutput,
|
35 |
+
add_start_docstrings,
|
36 |
+
add_start_docstrings_to_model_forward,
|
37 |
+
logging,
|
38 |
+
replace_return_docstrings,
|
39 |
+
)
|
40 |
+
from transformers.models.bert.modeling_bert import BertModel
|
41 |
+
from transformers.models.clip import CLIPVisionModel
|
42 |
+
from .configuration_flmr import FLMRConfig, FLMRTextConfig, FLMRVisionConfig
|
43 |
+
from .tokenization_flmr import FLMRQueryEncoderTokenizer, FLMRContextEncoderTokenizer
|
44 |
+
from .tokenization_flmr_fast import FLMRQueryEncoderTokenizerFast, FLMRContextEncoderTokenizerFast
|
45 |
+
from .flmr_utils import (
|
46 |
+
colbert_score,
|
47 |
+
colbert_score_reduce,
|
48 |
+
get_rank,
|
49 |
+
get_world_size,
|
50 |
+
)
|
51 |
+
|
52 |
+
|
53 |
+
logger = logging.get_logger(__name__)
|
54 |
+
|
55 |
+
_CONFIG_FOR_DOC = "FLMRConfig"
|
56 |
+
_CHECKPOINT_FOR_DOC = "LinWeizheDragon/PreFLMR_ViT-L"
|
57 |
+
|
58 |
+
|
59 |
+
FLMR_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
60 |
+
"LinWeizheDragon/PreFLMR_ViT-L",
|
61 |
+
"LinWeizheDragon/FLMR",
|
62 |
+
# See all FLMR models at https://huggingface.co/models?filter=flmr
|
63 |
+
]
|
64 |
+
|
65 |
+
|
66 |
+
##########
|
67 |
+
# Outputs
|
68 |
+
##########
|
69 |
+
|
70 |
+
|
71 |
+
@dataclass
|
72 |
+
class FLMRContextEncoderOutput(ModelOutput):
|
73 |
+
"""
|
74 |
+
Class for outputs of the `doc()` function of [`FLMRModelForRetrieval`].
|
75 |
+
|
76 |
+
Args:
|
77 |
+
pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`):
|
78 |
+
The FLMR encoder outputs the *pooler_output* that corresponds to the embedding of the first token of the context representation.
|
79 |
+
This output can be used to embed questions for nearest neighbors queries with query embeddings.
|
80 |
+
late_interaction_output (`torch.FloatTensor` of shape `(batch_size, context_embedding_length, embeddings_size)`):
|
81 |
+
The FLMR encoder outputs the *late_interaction_output* that corresponds to the question representation. The embeddings of all tokens are included for late interaction retrieval.
|
82 |
+
This output is to be used to embed contexts for late-interaction retrieval with query embeddings.
|
83 |
+
context_mask (`torch.FloatTensor` of shape `(batch_size, context_embedding_length)`):
|
84 |
+
The FLMR encoder outputs the *context_mask* that corresponds to the mask of the context representation.
|
85 |
+
text_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*):
|
86 |
+
Tuple of elements containing the attention weights of the text encoder's layers. Each element is a
|
87 |
+
tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
|
88 |
+
text_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
|
89 |
+
Tuple of elements containing the hidden states of the text encoder at each layer plus the initial embedding
|
90 |
+
outputs. Each tensor has a shape of `(batch_size, sequence_length, hidden_size)`.
|
91 |
+
vision_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*):
|
92 |
+
Tuple of elements containing the attention weights of the vision encoder's layers. Each element is a
|
93 |
+
tensor of shape `(batch_size, num_heads, vision_sequence_length, vision_sequence_length)`.
|
94 |
+
vision_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
|
95 |
+
Tuple of elements containing the hidden states of the vision encoder at each layer plus the initial embedding
|
96 |
+
outputs. Each tensor has a shape of `(batch_size, vision_sequence_length, hidden_size)`.
|
97 |
+
transformer_mapping_network_attentions (`Tuple[torch.FloatTensor]`, *optional*):
|
98 |
+
Tuple of elements containing the attention weights of the transformer mapping network's layers. Each element
|
99 |
+
is a tensor of shape `(batch_size, num_heads, mapping_sequence_length, mapping_sequence_length)`.
|
100 |
+
transformer_mapping_network_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
|
101 |
+
Tuple of elements containing the hidden states of the transformer mapping network at each layer plus the
|
102 |
+
initial embedding outputs. Each tensor has a shape of `(batch_size, mapping_sequence_length, hidden_size)`.
|
103 |
+
"""
|
104 |
+
|
105 |
+
pooler_output: torch.FloatTensor
|
106 |
+
late_interaction_output: torch.FloatTensor = None
|
107 |
+
context_mask: torch.FloatTensor = None
|
108 |
+
text_encoder_attentions: Optional[Tuple[Tensor]] = None
|
109 |
+
text_encoder_hidden_states: Optional[Tuple[Tensor]] = None
|
110 |
+
vision_encoder_attentions: Optional[Tuple[Tensor]] = None
|
111 |
+
vision_encoder_hidden_states: Optional[Tuple[Tensor]] = None
|
112 |
+
transformer_mapping_network_attentions: Optional[Tuple[Tensor]] = None
|
113 |
+
transformer_mapping_network_hidden_states: Optional[Tuple[Tensor]] = None
|
114 |
+
|
115 |
+
|
116 |
+
@dataclass
|
117 |
+
class FLMRQueryEncoderOutput(ModelOutput):
|
118 |
+
"""
|
119 |
+
Class for outputs of the `query()` function of [`FLMRModelForRetrieval.query()`].
|
120 |
+
|
121 |
+
Args:
|
122 |
+
pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`):
|
123 |
+
The FLMR encoder outputs the *pooler_output* that corresponds to the embedding of the first token of the query representation.
|
124 |
+
This output can be used to embed questions for nearest neighbors queries with context embeddings.
|
125 |
+
late_interaction_output (`torch.FloatTensor` of shape `(batch_size, query_embedding_length, embeddings_size)`):
|
126 |
+
The FLMR encoder outputs the *late_interaction_output* that corresponds to the question representation. The embeddings of all tokens are included for late interaction retrieval.
|
127 |
+
This output is to be used to embed questions for late-interaction retrieval with context embeddings.
|
128 |
+
text_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*):
|
129 |
+
Tuple of elements containing the attention weights of the text encoder's layers. Each element is a
|
130 |
+
tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
|
131 |
+
text_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
|
132 |
+
Tuple of elements containing the hidden states of the text encoder at each layer plus the initial embedding
|
133 |
+
outputs. Each tensor has a shape of `(batch_size, sequence_length, hidden_size)`.
|
134 |
+
vision_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*):
|
135 |
+
Tuple of elements containing the attention weights of the vision encoder's layers. Each element is a
|
136 |
+
tensor of shape `(batch_size, num_heads, vision_sequence_length, vision_sequence_length)`.
|
137 |
+
vision_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
|
138 |
+
Tuple of elements containing the hidden states of the vision encoder at each layer plus the initial embedding
|
139 |
+
outputs. Each tensor has a shape of `(batch_size, vision_sequence_length, hidden_size)`.
|
140 |
+
transformer_mapping_network_attentions (`Tuple[torch.FloatTensor]`, *optional*):
|
141 |
+
Tuple of elements containing the attention weights of the transformer mapping network's layers. Each element
|
142 |
+
is a tensor of shape `(batch_size, num_heads, mapping_sequence_length, mapping_sequence_length)`.
|
143 |
+
transformer_mapping_network_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
|
144 |
+
Tuple of elements containing the hidden states of the transformer mapping network at each layer plus the
|
145 |
+
initial embedding outputs. Each tensor has a shape of `(batch_size, mapping_sequence_length, hidden_size)`.
|
146 |
+
"""
|
147 |
+
|
148 |
+
pooler_output: torch.FloatTensor
|
149 |
+
late_interaction_output: torch.FloatTensor = None
|
150 |
+
text_encoder_attentions: Optional[Tuple[Tensor]] = None
|
151 |
+
text_encoder_hidden_states: Optional[Tuple[Tensor]] = None
|
152 |
+
vision_encoder_attentions: Optional[Tuple[Tensor]] = None
|
153 |
+
vision_encoder_hidden_states: Optional[Tuple[Tensor]] = None
|
154 |
+
transformer_mapping_network_attentions: Optional[Tuple[Tensor]] = None
|
155 |
+
transformer_mapping_network_hidden_states: Optional[Tuple[Tensor]] = None
|
156 |
+
|
157 |
+
|
158 |
+
@dataclass
|
159 |
+
class FLMRModelForRetrievalOutput(ModelOutput):
|
160 |
+
"""
|
161 |
+
Class for outputs of [`FLMRModelForRetrieval.query()`].
|
162 |
+
|
163 |
+
Args:
|
164 |
+
loss (`torch.FloatTensor`):
|
165 |
+
contrastive loss of the input queries and positive and negative examples. This output is to be used in model training.
|
166 |
+
scores (`torch.FloatTensor` of shape `(batch_size, num_positive_examples + num_negative_examples)`):
|
167 |
+
The FLMR model outputs the *scores* that corresponds to the late-interaction scores of the input query and context. Each query is associated with `num_positive_examples` positive examples and `num_negative_examples` negative examples, and the scores are the late-interaction scores of the query and these examples.
|
168 |
+
in_batch_negative_loss (`torch.FloatTensor` of shape `(batch_size, query_embedding_length, embeddings_size)`):
|
169 |
+
The FLMR model outputs the *in_batch_negative_loss* which computes contrastive loss that includes in-batch negatives. For each positive example, all other examples in the batch except itself are considered negative examples in computing the contrastive loss. This improves ultimate performance in practice. This output is to be used in model training.
|
170 |
+
query_late_interaction_output (`torch.FloatTensor` of shape `(batch_size, query_embedding_length, embeddings_size)`):
|
171 |
+
The FLMR model outputs the *query_late_interaction_output* that corresponds to the late-interaction representations of the input query.
|
172 |
+
context_late_interaction_output (`torch.FloatTensor` of shape `(batch_size, context_embedding_length, embeddings_size)`):
|
173 |
+
The FLMR model outputs the *context_late_interaction_output* that corresponds to the late-interaction representations of the input context.
|
174 |
+
query_attentions (`Tuple[Tuple[Tensor]]`, *optional*):
|
175 |
+
Tuple of elements containing the attention weights of the query's layers. There are three sub-tuples in this tuple, corresponding to the attentions of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`, with `sequence_length` being the sequence length in the corresponding encoder.
|
176 |
+
query_hidden_states (`Tuple[Tuple[Tensor]]`, *optional*):
|
177 |
+
Tuple of elements containing the hidden states of the query's layers. There are three sub-tuples in this tuple, corresponding to the hidden states of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, sequence_length, hidden_size)`, with `sequence_length` being the sequence length in the corresponding encoder.
|
178 |
+
context_attentions (`Tuple[Tuple[Tensor]]`, *optional*):
|
179 |
+
Tuple of elements containing the attention weights of the context's layers. There are three sub-tuples in this tuple, corresponding to the attentions of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`, with `sequence_length` being the sequence length in the corresponding encoder.
|
180 |
+
context_hidden_states (`Tuple[Tuple[Tensor]]`, *optional*):
|
181 |
+
Tuple of elements containing the hidden states of the context's layers. There are three sub-tuples in this tuple, corresponding to the hidden states of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, sequence_length, hidden_size)`, with `sequence_length` being the sequence length in the corresponding encoder.
|
182 |
+
"""
|
183 |
+
|
184 |
+
loss: torch.FloatTensor
|
185 |
+
scores: torch.FloatTensor = None
|
186 |
+
in_batch_negative_loss: torch.FloatTensor = None
|
187 |
+
query_late_interaction_output: torch.FloatTensor = None
|
188 |
+
context_late_interaction_output: torch.FloatTensor = None
|
189 |
+
query_attentions: Optional[Tuple[Tuple[Tensor]]] = None
|
190 |
+
query_hidden_states: Optional[Tuple[Tuple[Tensor]]] = None
|
191 |
+
context_attentions: Optional[Tuple[Tuple[Tensor]]] = None
|
192 |
+
context_hidden_states: Optional[Tuple[Tuple[Tensor]]] = None
|
193 |
+
|
194 |
+
|
195 |
+
class FLMRPreTrainedModel(PreTrainedModel):
|
196 |
+
def _init_weights(self, module):
|
197 |
+
"""Initialize the weights"""
|
198 |
+
if isinstance(module, nn.Linear):
|
199 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
200 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
201 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
202 |
+
if module.bias is not None:
|
203 |
+
module.bias.data.zero_()
|
204 |
+
elif isinstance(module, nn.Embedding):
|
205 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
206 |
+
if module.padding_idx is not None:
|
207 |
+
module.weight.data[module.padding_idx].zero_()
|
208 |
+
elif isinstance(module, nn.LayerNorm):
|
209 |
+
module.bias.data.zero_()
|
210 |
+
module.weight.data.fill_(1.0)
|
211 |
+
|
212 |
+
|
213 |
+
##################
|
214 |
+
# PreTrainedModel
|
215 |
+
##################
|
216 |
+
|
217 |
+
|
218 |
+
class FLMRPretrainedModelForRetrieval(FLMRPreTrainedModel):
|
219 |
+
"""
|
220 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
221 |
+
models.
|
222 |
+
"""
|
223 |
+
|
224 |
+
config_class = FLMRConfig
|
225 |
+
load_tf_weights = None
|
226 |
+
base_model_prefix = "flmr"
|
227 |
+
|
228 |
+
|
229 |
+
###############
|
230 |
+
# Actual Models
|
231 |
+
###############
|
232 |
+
|
233 |
+
|
234 |
+
FLMR_START_DOCSTRING = r"""
|
235 |
+
|
236 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
237 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
238 |
+
etc.)
|
239 |
+
|
240 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
241 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
242 |
+
and behavior.
|
243 |
+
|
244 |
+
Parameters:
|
245 |
+
config ([`FLMRConfig`]): Model configuration class with all the parameters of the model.
|
246 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
247 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
248 |
+
query_tokenizer ([`FLMRQueryEncoderTokenizer`], *optional*): The tokenizer used for tokenizing the query.
|
249 |
+
The query tokenizer can be initialized with `FLMRQueryEncoderTokenizer.from_pretrained(pretrained_model_name_or_path)`.
|
250 |
+
context_tokenizer ([`FLMRContextEncoderTokenizer`], *optional*): The tokenizer used for tokenizing the context.
|
251 |
+
The context tokenizer can be initialized with `FLMRContextEncoderTokenizer.from_pretrained(pretrained_model_name_or_path)`.
|
252 |
+
"""
|
253 |
+
|
254 |
+
|
255 |
+
FLMR_MODEL_INPUTS_DOCSTRING = r"""
|
256 |
+
Args:
|
257 |
+
query_input_ids (`torch.LongTensor` of shape `(batch_size, query_length)`):
|
258 |
+
Indices of input query tokens in the vocabulary. To match pretraining, FLMR input sequence should be
|
259 |
+
formatted with [CLS] and Q marker tokens as follows:
|
260 |
+
[CLS] [unused0] using the provided image, obtain documents that address the subsequent question : what is the capital of france? [SEP] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] ...
|
261 |
+
|
262 |
+
FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
|
263 |
+
rather than the left.
|
264 |
+
|
265 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
266 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
267 |
+
|
268 |
+
[What are input IDs?](../glossary#input-ids)
|
269 |
+
query_attention_mask (`torch.FloatTensor` of shape `(batch_size, query_length)`, *optional*):
|
270 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
271 |
+
|
272 |
+
- 1 for tokens that are **not masked**,
|
273 |
+
- 0 for tokens that are **masked**.
|
274 |
+
|
275 |
+
[What are attention masks?](../glossary#attention-mask)
|
276 |
+
query_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
|
277 |
+
Pixel values. Pixel values can be obtained using
|
278 |
+
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
279 |
+
query_image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*):
|
280 |
+
Image features are required when `query_pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass and the extract image features are directly given to the FLMR model. Image features can be obtained
|
281 |
+
using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details.
|
282 |
+
context_input_ids (`torch.LongTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`):
|
283 |
+
Indices of input context tokens in the vocabulary. To match pretraining, FLMR input sequence should be
|
284 |
+
formatted with [CLS] and D marker tokens as follows:
|
285 |
+
[CLS] [unused1] paris is the capital of france. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] ...
|
286 |
+
|
287 |
+
FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
|
288 |
+
rather than the left.
|
289 |
+
|
290 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
291 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
292 |
+
|
293 |
+
[What are input IDs?](../glossary#input-ids)
|
294 |
+
|
295 |
+
The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details.
|
296 |
+
|
297 |
+
context_attention_mask (`torch.FloatTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`, *optional*):
|
298 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
299 |
+
|
300 |
+
- 1 for tokens that are **not masked**,
|
301 |
+
- 0 for tokens that are **masked**.
|
302 |
+
|
303 |
+
[What are attention masks?](../glossary#attention-mask)
|
304 |
+
|
305 |
+
The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details.
|
306 |
+
context_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
|
307 |
+
Pixel values. Pixel values can be obtained using
|
308 |
+
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
309 |
+
context_image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*):
|
310 |
+
Image features are required when `context_pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass and the extract image features are directly given to the FLMR model. Image features can be obtained
|
311 |
+
using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details.
|
312 |
+
use_in_batch_negatives (`bool`, *optional*):
|
313 |
+
Whether or not to use in-batch negatives. If `True`, the contrastive loss includes in-batch negatives. For each positive example, all other examples in the batch except itself are considered negative examples in computing the contrastive loss. This improves ultimate performance in practice. This input is to be used in model training.
|
314 |
+
in_batch_negatives_from_all_gpus (`bool`, *optional*):
|
315 |
+
Whether or not to use in-batch negatives from all GPUs. If `True`, the contrastive loss includes in-batch negatives from all GPUs. This input is to be used in model training.
|
316 |
+
num_negative_examples (`int`, *optional*):
|
317 |
+
The number of negative examples in the batch. For example, if `num_negative_examples` is 4, the batch size of `context_input_ids` and `context_attention_mask` is `batch_size * 5`.
|
318 |
+
query_concat_output_from_vision_encoder (`bool`, *optional*):
|
319 |
+
Whether or not to concatenate the output from the vision encoder to the final query late-interaction representations. If `True`, the output from the vision encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models.
|
320 |
+
query_concat_output_from_text_encoder (`bool`, *optional*):
|
321 |
+
Whether or not to concatenate the output from the text encoder to the final query late-interaction representations. If `True`, the output from the text encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models.
|
322 |
+
|
323 |
+
This argument can be set to `False` when performing mapping network pretraining as in FLMR and PreFLMR, in which case the output from the text encoder is not concatenated to the final query representations.
|
324 |
+
context_concat_output_from_vision_encoder (`bool`, *optional*):
|
325 |
+
Whether or not to concatenate the output from the vision encoder to the final context late-interaction representations. If `True`, the output from the vision encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `False` for FLMR and PreFLMR -style models since the context vision encoder is not used.
|
326 |
+
|
327 |
+
This can be set to `True` to additionally encode the context images with the vision encoder when context images are provided.
|
328 |
+
context_concat_output_from_text_encoder (`bool`, *optional*):
|
329 |
+
Whether or not to concatenate the output from the text encoder to the final context late-interaction representations. If `True`, the output from the text encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models.
|
330 |
+
return_dict (`bool`, *optional*):
|
331 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
332 |
+
output_attentions (`bool`, *optional*):
|
333 |
+
Whether or not to return the attentions tensors of all attention layers. See `*_attentions` under returned
|
334 |
+
tensors for more detail.
|
335 |
+
output_hidden_states (`bool`, *optional*):
|
336 |
+
Whether or not to return the hidden states of all layers. See `*_hidden_states` under returned tensors for more detail.
|
337 |
+
"""
|
338 |
+
|
339 |
+
|
340 |
+
FLMR_MODEL_QUERY_INPUTS_DOCSTRING = r"""
|
341 |
+
Args:
|
342 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, query_length)`):
|
343 |
+
Indices of input query tokens in the vocabulary. To match pretraining, FLMR input sequence should be
|
344 |
+
formatted with [CLS] and Q marker tokens as follows:
|
345 |
+
[CLS] [unused0] using the provided image, obtain documents that address the subsequent question : what is the capital of france? [SEP] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] ...
|
346 |
+
|
347 |
+
FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
|
348 |
+
rather than the left.
|
349 |
+
|
350 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
351 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
352 |
+
|
353 |
+
[What are input IDs?](../glossary#input-ids)
|
354 |
+
attention_mask (`torch.FloatTensor` of shape `(batch_size, query_length)`, *optional*):
|
355 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
356 |
+
|
357 |
+
- 1 for tokens that are **not masked**,
|
358 |
+
- 0 for tokens that are **masked**.
|
359 |
+
|
360 |
+
[What are attention masks?](../glossary#attention-mask)
|
361 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
|
362 |
+
Pixel values. Pixel values can be obtained using
|
363 |
+
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
364 |
+
image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*):
|
365 |
+
Image features are required when `pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass and the extract image features are directly given to the FLMR model. Image features can be obtained
|
366 |
+
using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details.
|
367 |
+
concat_output_from_vision_encoder (`bool`, *optional*):
|
368 |
+
Whether or not to concatenate the output from the vision encoder to the final query late-interaction representations. If `True`, the output from the vision encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models.
|
369 |
+
concat_output_from_text_encoder (`bool`, *optional*):
|
370 |
+
Whether or not to concatenate the output from the text encoder to the final query late-interaction representations. If `True`, the output from the text encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models.
|
371 |
+
|
372 |
+
This argument can be set to `False` when performing mapping network pretraining as in FLMR and PreFLMR, in which case the output from the text encoder is not concatenated to the final query representations.
|
373 |
+
"""
|
374 |
+
|
375 |
+
|
376 |
+
FLMR_MODEL_CONTEXT_INPUTS_DOCSTRING = r"""
|
377 |
+
Args:
|
378 |
+
input_ids (`torch.LongTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`):
|
379 |
+
Indices of input context tokens in the vocabulary. To match pretraining, FLMR input sequence should be
|
380 |
+
formatted with [CLS] and D marker tokens as follows:
|
381 |
+
[CLS] [unused1] paris is the capital of france. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] ...
|
382 |
+
|
383 |
+
FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
|
384 |
+
rather than the left.
|
385 |
+
|
386 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
387 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
388 |
+
|
389 |
+
[What are input IDs?](../glossary#input-ids)
|
390 |
+
|
391 |
+
The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details.
|
392 |
+
attention_mask (`torch.FloatTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`, *optional*):
|
393 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
394 |
+
|
395 |
+
- 1 for tokens that are **not masked**,
|
396 |
+
- 0 for tokens that are **masked**.
|
397 |
+
|
398 |
+
[What are attention masks?](../glossary#attention-mask)
|
399 |
+
|
400 |
+
The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details.
|
401 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
|
402 |
+
Pixel values. Pixel values can be obtained using
|
403 |
+
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
404 |
+
image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*):
|
405 |
+
Image features are required when `pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass and the extract image features are directly given to the FLMR model. Image features can be obtained
|
406 |
+
using [`CLIPVisionModel`]. See [`CLIPVisionModel
|
407 |
+
.__call__`] for details.
|
408 |
+
concat_output_from_vision_encoder (`bool`, *optional*):
|
409 |
+
Whether or not to concatenate the output from the vision encoder to the final context late-interaction representations. If `True`, the output from the vision encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `False` for FLMR and PreFLMR -style models since the context vision encoder is not used.
|
410 |
+
|
411 |
+
This can be set to `True` to additionally encode the context images with the vision encoder when context images are provided.
|
412 |
+
concat_output_from_text_encoder (`bool`, *optional*):
|
413 |
+
Whether or not to concatenate the output from the text encoder to the final context late-interaction representations. If `True`, the output from the text encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models.
|
414 |
+
keep_dims (`bool`, *optional*):
|
415 |
+
Whether or not to keep the dimensions of the output. If `True`, the output is returned with the same dimensions as the input. If `False`, the output is returned with the batch size of the input and the context length. This input is to be used in model training.
|
416 |
+
return_mask (`bool`, *optional*):
|
417 |
+
Whether or not to return the mask of the context representation. If `True`, the mask of the context representation is returned. This input is to be used in model training.
|
418 |
+
"""
|
419 |
+
|
420 |
+
|
421 |
+
FLMR_TEXT_ENCODERS_START_DOCSTRING = r"""
|
422 |
+
|
423 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
424 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
425 |
+
etc.)
|
426 |
+
|
427 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
428 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
429 |
+
and behavior.
|
430 |
+
|
431 |
+
Parameters:
|
432 |
+
config ([`FLMRTextConfig`]): Model configuration class with all the parameters of the model.
|
433 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
434 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
435 |
+
"""
|
436 |
+
|
437 |
+
|
438 |
+
# Modified from transformers.models.dpr.modeling_dpr with DPR -> FLMR
|
439 |
+
FLMR_TEXT_ENCODERS_INPUTS_DOCSTRING = r"""
|
440 |
+
Args:
|
441 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
442 |
+
Indices of input sequence tokens in the vocabulary. To match pretraining, FLMR input sequence should be
|
443 |
+
formatted with [CLS] and [SEP] tokens as follows:
|
444 |
+
|
445 |
+
(a) For sequence pairs (for a pair title+text for example):
|
446 |
+
|
447 |
+
```
|
448 |
+
tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
|
449 |
+
token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
|
450 |
+
```
|
451 |
+
|
452 |
+
(b) For single sequences (for a question for example):
|
453 |
+
|
454 |
+
```
|
455 |
+
tokens: [CLS] the dog is hairy . [SEP]
|
456 |
+
token_type_ids: 0 0 0 0 0 0 0
|
457 |
+
```
|
458 |
+
|
459 |
+
FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
|
460 |
+
rather than the left.
|
461 |
+
|
462 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
463 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
464 |
+
|
465 |
+
[What are input IDs?](../glossary#input-ids)
|
466 |
+
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
467 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
468 |
+
|
469 |
+
- 1 for tokens that are **not masked**,
|
470 |
+
- 0 for tokens that are **masked**.
|
471 |
+
|
472 |
+
[What are attention masks?](../glossary#attention-mask)
|
473 |
+
token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
474 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
475 |
+
1]`:
|
476 |
+
|
477 |
+
- 0 corresponds to a *sentence A* token,
|
478 |
+
- 1 corresponds to a *sentence B* token.
|
479 |
+
|
480 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
481 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
482 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
483 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
484 |
+
model's internal embedding lookup matrix.
|
485 |
+
output_attentions (`bool`, *optional*):
|
486 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
487 |
+
tensors for more detail.
|
488 |
+
output_hidden_states (`bool`, *optional*):
|
489 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
490 |
+
more detail.
|
491 |
+
return_dict (`bool`, *optional*):
|
492 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
493 |
+
"""
|
494 |
+
|
495 |
+
FLMR_VISION_ENCODERS_START_DOCSTRING = r"""
|
496 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
497 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
498 |
+
etc.)
|
499 |
+
|
500 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
501 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
502 |
+
and behavior.
|
503 |
+
|
504 |
+
Parameters:
|
505 |
+
config ([`FLMRVisionConfig`]): Model configuration class with all the parameters of the model.
|
506 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
507 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
508 |
+
"""
|
509 |
+
|
510 |
+
# Modified from transformers.models.clip.modeling_clip with CLIP -> FLMR
|
511 |
+
FLMR_VISION_ENCODERS_INPUTS_DOCSTRING = r"""
|
512 |
+
Args:
|
513 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
514 |
+
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
|
515 |
+
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
|
516 |
+
output_attentions (`bool`, *optional*):
|
517 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
518 |
+
tensors for more detail.
|
519 |
+
output_hidden_states (`bool`, *optional*):
|
520 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
521 |
+
more detail.
|
522 |
+
return_dict (`bool`, *optional*):
|
523 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
524 |
+
"""
|
525 |
+
|
526 |
+
|
527 |
+
class FLMRMultiLayerPerceptron(nn.Module):
|
528 |
+
"""
|
529 |
+
A simple multi-layer perceptron with an activation function. This can be used as the mapping network in the FLMR model.
|
530 |
+
"""
|
531 |
+
|
532 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
533 |
+
return self.model(x)
|
534 |
+
|
535 |
+
def __init__(self, sizes, bias=True, act=nn.Tanh):
|
536 |
+
super(FLMRMultiLayerPerceptron, self).__init__()
|
537 |
+
layers = []
|
538 |
+
for i in range(len(sizes) - 1):
|
539 |
+
layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
|
540 |
+
if i < len(sizes) - 2:
|
541 |
+
layers.append(act())
|
542 |
+
self.model = nn.Sequential(*layers)
|
543 |
+
|
544 |
+
|
545 |
+
@add_start_docstrings(
|
546 |
+
"The bare FLMR model that can be used to generate late-interaction embeddings for both multi-modal queries and documents. ",
|
547 |
+
FLMR_START_DOCSTRING,
|
548 |
+
)
|
549 |
+
class FLMRModelForRetrieval(FLMRPretrainedModelForRetrieval):
|
550 |
+
_keys_to_ignore_on_load_unexpected = [r"cls"]
|
551 |
+
main_input_name = "query_input_ids"
|
552 |
+
_tied_weights_keys = [] # Added dynamically at initialization depending on the architecture
|
553 |
+
|
554 |
+
def __init__(self, config: FLMRConfig, query_tokenizer=None, context_tokenizer=None):
|
555 |
+
super().__init__(config)
|
556 |
+
self.config = config
|
557 |
+
self.vision_model_version = config.vision_model_version
|
558 |
+
|
559 |
+
self.context_text_encoder = FLMRTextModel(config.text_config)
|
560 |
+
self.context_text_encoder_linear = nn.Linear(config.text_config.hidden_size, config.dim, bias=False)
|
561 |
+
|
562 |
+
self.query_tokenizer = query_tokenizer
|
563 |
+
self.context_tokenizer = context_tokenizer
|
564 |
+
|
565 |
+
if self.query_tokenizer is None:
|
566 |
+
logger.warning(
|
567 |
+
"query_tokenizer is not provided. A tokenizer is initialized from `bert-base-uncased`. Please pass in an FLMRQueryEncoderTokenizer instance if you need to extend the vocabulary beyond the existing ones in the bert tokenizer."
|
568 |
+
)
|
569 |
+
|
570 |
+
# initialize a FLMRQueryEncoderTokenizer
|
571 |
+
self.query_tokenizer = FLMRQueryEncoderTokenizer.from_pretrained("bert-base-uncased")
|
572 |
+
|
573 |
+
if self.context_tokenizer is None:
|
574 |
+
logger.warning(
|
575 |
+
"context_tokenizer is not provided. A tokenizer is initialized from `bert-base-uncased`. Please pass in an FLMRContextEncoderTokenizer instance if you need to extend the vocabulary beyond the existing ones in the bert tokenizer."
|
576 |
+
)
|
577 |
+
|
578 |
+
# initialize a FLMRContextEncoderTokenizer
|
579 |
+
self.context_tokenizer = FLMRContextEncoderTokenizer.from_pretrained("bert-base-uncased")
|
580 |
+
|
581 |
+
self.mapping_network_prefix_length = self.config.mapping_network_prefix_length
|
582 |
+
self.vision_encoder_embedding_size = self.config.vision_config.hidden_size
|
583 |
+
self.text_encoder_embedding_size = self.config.text_config.hidden_size
|
584 |
+
self.late_interaction_embedding_size = self.config.dim
|
585 |
+
|
586 |
+
if self.config.use_vision_encoder:
|
587 |
+
self.context_vision_projection = FLMRMultiLayerPerceptron(
|
588 |
+
(
|
589 |
+
self.vision_encoder_embedding_size,
|
590 |
+
(self.late_interaction_embedding_size * self.mapping_network_prefix_length) // 2,
|
591 |
+
self.late_interaction_embedding_size * self.mapping_network_prefix_length,
|
592 |
+
)
|
593 |
+
)
|
594 |
+
|
595 |
+
if self.config.use_vision_encoder:
|
596 |
+
self.context_vision_encoder = FLMRVisionModel(config.vision_config)
|
597 |
+
|
598 |
+
if self.config.use_transformer_mapping_network:
|
599 |
+
# This is a PreFLMR style model
|
600 |
+
transformer_mapping_config_base = self.config.transformer_mapping_config_base
|
601 |
+
try:
|
602 |
+
from transformers import BertConfig
|
603 |
+
from transformers.models.bert.modeling_bert import BertEncoder
|
604 |
+
except Exception as e:
|
605 |
+
raise ImportError(f"Failed to import BertConfig and BertEncoder from transformers. {e}")
|
606 |
+
|
607 |
+
transformer_mapping_config = BertConfig.from_pretrained(transformer_mapping_config_base)
|
608 |
+
|
609 |
+
assert (
|
610 |
+
self.config.text_config.hidden_size == transformer_mapping_config.hidden_size
|
611 |
+
), f"hidden_size {self.config.text_config.hidden_size} != transformer_mapping_config.hidden_size {transformer_mapping_config.hidden_size}. To use cross attention, the dimensions must match."
|
612 |
+
# shallow transformer
|
613 |
+
transformer_mapping_config.num_hidden_layers = self.config.transformer_mapping_num_hidden_layers
|
614 |
+
# add cross attention
|
615 |
+
transformer_mapping_config.is_decoder = True
|
616 |
+
transformer_mapping_config.add_cross_attention = True
|
617 |
+
|
618 |
+
# The linear layer from vision encoder to transformer input
|
619 |
+
self.transformer_mapping_input_linear = nn.Linear(
|
620 |
+
self.vision_encoder_embedding_size, transformer_mapping_config.hidden_size
|
621 |
+
)
|
622 |
+
|
623 |
+
# The transformer encoder
|
624 |
+
self.transformer_mapping_network = BertEncoder(transformer_mapping_config)
|
625 |
+
|
626 |
+
# The linear layer from transformer output to FLMR dim
|
627 |
+
self.transformer_mapping_output_linear = nn.Linear(
|
628 |
+
transformer_mapping_config.hidden_size, self.late_interaction_embedding_size
|
629 |
+
)
|
630 |
+
|
631 |
+
if self.config.separate_query_and_context_text_encoder:
|
632 |
+
self.query_text_encoder = copy.deepcopy(self.context_text_encoder)
|
633 |
+
self.query_text_encoder_linear = copy.deepcopy(self.context_text_encoder_linear)
|
634 |
+
else:
|
635 |
+
self.query_text_encoder = self.context_text_encoder
|
636 |
+
self.query_text_encoder_linear = self.context_text_encoder_linear
|
637 |
+
self._tied_weights_keys += ["context_text_encoder", "context_text_encoder_linear"]
|
638 |
+
|
639 |
+
if self.config.use_vision_encoder:
|
640 |
+
if self.config.separate_query_and_context_vision_encoder:
|
641 |
+
self.query_vision_encoder = copy.deepcopy(self.context_vision_encoder)
|
642 |
+
self.query_vision_projection = copy.deepcopy(self.context_vision_projection)
|
643 |
+
else:
|
644 |
+
self.query_vision_encoder = self.context_vision_encoder
|
645 |
+
self.query_vision_projection = self.context_vision_projection
|
646 |
+
self._tied_weights_keys += ["context_vision_encoder", "context_vision_projection"]
|
647 |
+
|
648 |
+
if self.config.load_cpu_extension:
|
649 |
+
try:
|
650 |
+
FLMRModelForRetrieval.try_load_torch_extensions()
|
651 |
+
except Exception as e:
|
652 |
+
raise(f"Unable to load `segmented_maxsim.cpp`. hf-hub does not download this file automatically. Please download it manually from `https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/blob/main/segmented_maxsim.cpp` and put it under the same folder as the model file.\n {e}")
|
653 |
+
|
654 |
+
if self.config.mask_punctuation:
|
655 |
+
self.skiplist = {
|
656 |
+
w: True
|
657 |
+
for symbol in string.punctuation
|
658 |
+
for w in [symbol, self.context_tokenizer.encode(symbol, add_special_tokens=False)[0]]
|
659 |
+
}
|
660 |
+
|
661 |
+
if self.config.mask_instruction_token is not None:
|
662 |
+
self.mask_instruction = True
|
663 |
+
# obtain the token id of the instruction token
|
664 |
+
self.instruction_token_id = self.query_tokenizer.encode(
|
665 |
+
self.config.mask_instruction_token, add_special_tokens=False
|
666 |
+
)[0]
|
667 |
+
else:
|
668 |
+
self.mask_instruction = False
|
669 |
+
|
670 |
+
self.loss_fn = torch.nn.CrossEntropyLoss()
|
671 |
+
|
672 |
+
# Initialize weights and apply final processing
|
673 |
+
self.post_init()
|
674 |
+
|
675 |
+
@property
|
676 |
+
def use_gpu(self):
|
677 |
+
return self.device.type == "cuda"
|
678 |
+
|
679 |
+
@classmethod
|
680 |
+
def from_pretrained(self, name_or_path, **kwargs):
|
681 |
+
obj = super().from_pretrained(name_or_path, **kwargs)
|
682 |
+
return obj
|
683 |
+
|
684 |
+
@classmethod
|
685 |
+
def try_load_torch_extensions(cls):
|
686 |
+
if hasattr(cls, "loaded_extensions"):
|
687 |
+
return
|
688 |
+
|
689 |
+
logger.info(
|
690 |
+
"Loading segmented_maxsim_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)..."
|
691 |
+
)
|
692 |
+
segmented_maxsim_cpp = load(
|
693 |
+
name="segmented_maxsim_cpp",
|
694 |
+
sources=[
|
695 |
+
os.path.join(pathlib.Path(__file__).parent.resolve(), "segmented_maxsim.cpp"),
|
696 |
+
],
|
697 |
+
extra_cflags=["-O3"],
|
698 |
+
verbose=os.getenv("COLBERT_LOAD_TORCH_EXTENSION_VERBOSE", "False") == "True",
|
699 |
+
)
|
700 |
+
cls.segmented_maxsim = segmented_maxsim_cpp.segmented_maxsim_cpp
|
701 |
+
|
702 |
+
cls.loaded_extensions = True
|
703 |
+
|
704 |
+
def query_mask(self, input_ids, skiplist):
|
705 |
+
if not self.mask_instruction:
|
706 |
+
return self.mask(input_ids, skiplist)
|
707 |
+
|
708 |
+
# find the position of end of instruction in input_ids
|
709 |
+
# mask the tokens before the position
|
710 |
+
sep_id = self.instruction_token_id
|
711 |
+
sep_positions = torch.argmax((input_ids == sep_id).int(), dim=1).tolist()
|
712 |
+
# if any of the positions is lower than 1, set to 1
|
713 |
+
for i, x in enumerate(sep_positions):
|
714 |
+
if x < 1:
|
715 |
+
sep_positions[i] = 1
|
716 |
+
logger.error(f"can not find the separator in the input_ids: {input_ids[i].tolist()}")
|
717 |
+
mask = [
|
718 |
+
[
|
719 |
+
(x not in skiplist) and (x != 0) and (index > sep_positions[seq_index] or index < 2)
|
720 |
+
for index, x in enumerate(d)
|
721 |
+
]
|
722 |
+
for seq_index, d in enumerate(input_ids.cpu().tolist())
|
723 |
+
]
|
724 |
+
return mask
|
725 |
+
|
726 |
+
@add_start_docstrings_to_model_forward(FLMR_MODEL_INPUTS_DOCSTRING)
|
727 |
+
@replace_return_docstrings(output_type=FLMRModelForRetrievalOutput, config_class=_CONFIG_FOR_DOC)
|
728 |
+
def forward(
|
729 |
+
self,
|
730 |
+
query_input_ids: Optional[torch.Tensor] = None,
|
731 |
+
query_attention_mask: Optional[torch.Tensor] = None,
|
732 |
+
query_pixel_values: Optional[torch.Tensor] = None,
|
733 |
+
query_image_features: Optional[torch.Tensor] = None,
|
734 |
+
context_input_ids: Optional[torch.Tensor] = None,
|
735 |
+
context_attention_mask: Optional[torch.Tensor] = None,
|
736 |
+
context_pixel_values: Optional[torch.Tensor] = None,
|
737 |
+
context_image_features: Optional[torch.Tensor] = None,
|
738 |
+
use_in_batch_negatives: bool = True,
|
739 |
+
in_batch_negatives_from_all_gpus: bool = False,
|
740 |
+
num_negative_examples: int = 1,
|
741 |
+
query_concat_output_from_vision_encoder: Optional[Union[bool, list]] = None,
|
742 |
+
query_concat_output_from_text_encoder: Optional[Union[bool, list]] = None,
|
743 |
+
context_concat_output_from_vision_encoder: Optional[Union[bool, list]] = None,
|
744 |
+
context_concat_output_from_text_encoder: Optional[Union[bool, list]] = None,
|
745 |
+
return_dict: bool = None,
|
746 |
+
output_attentions: bool = None,
|
747 |
+
output_hidden_states: bool = None,
|
748 |
+
) -> Union[FLMRModelForRetrievalOutput, Tuple[Tensor, ...]]:
|
749 |
+
r"""
|
750 |
+
Return:
|
751 |
+
|
752 |
+
Examples:
|
753 |
+
|
754 |
+
```python
|
755 |
+
>>> import torch
|
756 |
+
>>> from transformers import FLMRQueryEncoderTokenizer, FLMRContextEncoderTokenizer, FLMRModelForRetrieval, AutoImageProcessor
|
757 |
+
|
758 |
+
>>> checkpoint_path = "LinWeizheDragon/PreFLMR_ViT-L"
|
759 |
+
>>> image_processor_name = "openai/clip-vit-large-patch14"
|
760 |
+
>>> query_tokenizer = FLMRQueryEncoderTokenizer.from_pretrained(checkpoint_path, subfolder="query_tokenizer")
|
761 |
+
>>> context_tokenizer = FLMRContextEncoderTokenizer.from_pretrained(checkpoint_path, subfolder="context_tokenizer")
|
762 |
+
|
763 |
+
>>> model = FLMRModelForRetrieval.from_pretrained(checkpoint_path,
|
764 |
+
query_tokenizer=query_tokenizer,
|
765 |
+
context_tokenizer=context_tokenizer,
|
766 |
+
)
|
767 |
+
>>> image_processor = AutoImageProcessor.from_pretrained(image_processor_name)
|
768 |
+
|
769 |
+
>>> Q_encoding = query_tokenizer(["Using the provided image, obtain documents that address the subsequent question: What is the capital of France?", "Extract documents linked to the question provided in conjunction with the image: What is the capital of China?"])
|
770 |
+
>>> D_encoding = context_tokenizer(["Paris is the capital of France.", "Beijing is the capital of China.",
|
771 |
+
"Paris is the capital of France.", "Beijing is the capital of China."])
|
772 |
+
>>> Q_pixel_values = torch.zeros(2, 3, 224, 224)
|
773 |
+
>>> inputs = dict(
|
774 |
+
query_input_ids=Q_encoding['input_ids'],
|
775 |
+
query_attention_mask=Q_encoding['attention_mask'],
|
776 |
+
query_pixel_values=Q_pixel_values,
|
777 |
+
context_input_ids=D_encoding['input_ids'],
|
778 |
+
context_attention_mask=D_encoding['attention_mask'],
|
779 |
+
use_in_batch_negatives=True,
|
780 |
+
)
|
781 |
+
|
782 |
+
>>> model.forward(**inputs)
|
783 |
+
FLMRModelForRetrievalOutput(loss=tensor(4.5000, device='cuda:0', dtype=torch.float16,
|
784 |
+
grad_fn=<NllLossBackward0>), scores=tensor([[44.2188, 40.6562],
|
785 |
+
[39.4375, 48.4062]], device='cuda:0', dtype=torch.float16,
|
786 |
+
grad_fn=<ViewBackward0>), in_batch_negative_loss=tensor(5.1994, device='cuda:0', grad_fn=<NllLossBackward0>), query_late_interaction_output=tensor(...), context_late_interaction_output=tensor(...)
|
787 |
+
```
|
788 |
+
"""
|
789 |
+
|
790 |
+
if query_concat_output_from_vision_encoder is None:
|
791 |
+
query_concat_output_from_vision_encoder = self.config.query_concat_output_from_vision_encoder
|
792 |
+
|
793 |
+
if query_concat_output_from_text_encoder is None:
|
794 |
+
query_concat_output_from_text_encoder = self.config.query_concat_output_from_text_encoder
|
795 |
+
|
796 |
+
if context_concat_output_from_vision_encoder is None:
|
797 |
+
context_concat_output_from_vision_encoder = self.config.context_concat_output_from_vision_encoder
|
798 |
+
|
799 |
+
if context_concat_output_from_text_encoder is None:
|
800 |
+
context_concat_output_from_text_encoder = self.config.context_concat_output_from_text_encoder
|
801 |
+
|
802 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
803 |
+
output_hidden_states = (
|
804 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
805 |
+
)
|
806 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
807 |
+
|
808 |
+
query_outputs = self.query(
|
809 |
+
input_ids=query_input_ids,
|
810 |
+
attention_mask=query_attention_mask,
|
811 |
+
pixel_values=query_pixel_values,
|
812 |
+
image_features=query_image_features,
|
813 |
+
concat_output_from_vision_encoder=query_concat_output_from_vision_encoder,
|
814 |
+
concat_output_from_text_encoder=query_concat_output_from_text_encoder,
|
815 |
+
output_attentions=output_attentions,
|
816 |
+
output_hidden_states=output_hidden_states,
|
817 |
+
)
|
818 |
+
Q = query_outputs.late_interaction_output
|
819 |
+
|
820 |
+
context_outputs = self.doc(
|
821 |
+
input_ids=context_input_ids,
|
822 |
+
attention_mask=context_attention_mask,
|
823 |
+
pixel_values=context_pixel_values,
|
824 |
+
image_features=context_image_features,
|
825 |
+
concat_output_from_vision_encoder=context_concat_output_from_vision_encoder,
|
826 |
+
concat_output_from_text_encoder=context_concat_output_from_text_encoder,
|
827 |
+
keep_dims=True,
|
828 |
+
return_mask=True,
|
829 |
+
output_attentions=output_attentions,
|
830 |
+
output_hidden_states=output_hidden_states,
|
831 |
+
)
|
832 |
+
D, D_mask = context_outputs.late_interaction_output, context_outputs.context_mask
|
833 |
+
|
834 |
+
# Gather tensors from other GPUs
|
835 |
+
if in_batch_negatives_from_all_gpus:
|
836 |
+
Q, D, D_mask = self.gather_tensors_from_other_gpus(Q, D, D_mask)
|
837 |
+
# Repeat each query encoding for every corresponding document.
|
838 |
+
Q_duplicated = Q.repeat_interleave(num_negative_examples + 1, dim=0).contiguous()
|
839 |
+
|
840 |
+
scores = self.score(Q_duplicated, D, D_mask)
|
841 |
+
|
842 |
+
# Use contrastive learning
|
843 |
+
batch_size = query_input_ids.shape[0]
|
844 |
+
scores = scores.view(-1, num_negative_examples + 1)
|
845 |
+
labels = torch.zeros(batch_size, dtype=torch.long, device=self.device)
|
846 |
+
|
847 |
+
|
848 |
+
if use_in_batch_negatives:
|
849 |
+
ib_loss = self.compute_ib_loss_new(Q, D, D_mask)
|
850 |
+
loss = ib_loss
|
851 |
+
else:
|
852 |
+
loss = self.loss_fn(scores, labels)
|
853 |
+
ib_loss = None
|
854 |
+
|
855 |
+
if output_attentions:
|
856 |
+
query_attentions = (
|
857 |
+
query_outputs.text_encoder_attentions if query_outputs.text_encoder_attentions is not None else None,
|
858 |
+
query_outputs.vision_encoder_attentions
|
859 |
+
if query_outputs.vision_encoder_attentions is not None
|
860 |
+
else None,
|
861 |
+
query_outputs.transformer_mapping_network_attentions
|
862 |
+
if query_outputs.transformer_mapping_network_attentions is not None
|
863 |
+
else None,
|
864 |
+
)
|
865 |
+
context_attentions = (
|
866 |
+
context_outputs.text_encoder_attentions
|
867 |
+
if context_outputs.text_encoder_attentions is not None
|
868 |
+
else None,
|
869 |
+
context_outputs.vision_encoder_attentions
|
870 |
+
if context_outputs.vision_encoder_attentions is not None
|
871 |
+
else None,
|
872 |
+
context_outputs.transformer_mapping_network_attentions
|
873 |
+
if context_outputs.transformer_mapping_network_attentions is not None
|
874 |
+
else None,
|
875 |
+
)
|
876 |
+
else:
|
877 |
+
query_attentions = None
|
878 |
+
context_attentions = None
|
879 |
+
|
880 |
+
if output_hidden_states:
|
881 |
+
query_hidden_states = (
|
882 |
+
query_outputs.text_encoder_hidden_states
|
883 |
+
if query_outputs.text_encoder_hidden_states is not None
|
884 |
+
else None,
|
885 |
+
query_outputs.vision_encoder_hidden_states
|
886 |
+
if query_outputs.vision_encoder_hidden_states is not None
|
887 |
+
else None,
|
888 |
+
query_outputs.transformer_mapping_network_hidden_states
|
889 |
+
if query_outputs.transformer_mapping_network_hidden_states is not None
|
890 |
+
else None,
|
891 |
+
)
|
892 |
+
context_hidden_states = (
|
893 |
+
context_outputs.text_encoder_hidden_states
|
894 |
+
if context_outputs.text_encoder_hidden_states is not None
|
895 |
+
else None,
|
896 |
+
context_outputs.vision_encoder_hidden_states
|
897 |
+
if context_outputs.vision_encoder_hidden_states is not None
|
898 |
+
else None,
|
899 |
+
context_outputs.transformer_mapping_network_hidden_states
|
900 |
+
if context_outputs.transformer_mapping_network_hidden_states is not None
|
901 |
+
else None,
|
902 |
+
)
|
903 |
+
else:
|
904 |
+
query_hidden_states = None
|
905 |
+
context_hidden_states = None
|
906 |
+
|
907 |
+
if not return_dict:
|
908 |
+
if output_attentions and output_hidden_states:
|
909 |
+
return (
|
910 |
+
loss,
|
911 |
+
scores,
|
912 |
+
ib_loss,
|
913 |
+
query_outputs.late_interaction_output,
|
914 |
+
context_outputs.late_interaction_output,
|
915 |
+
query_attentions,
|
916 |
+
query_hidden_states,
|
917 |
+
context_attentions,
|
918 |
+
context_hidden_states,
|
919 |
+
)
|
920 |
+
elif output_attentions:
|
921 |
+
return (
|
922 |
+
loss,
|
923 |
+
scores,
|
924 |
+
ib_loss,
|
925 |
+
query_outputs.late_interaction_output,
|
926 |
+
context_outputs.late_interaction_output,
|
927 |
+
query_attentions,
|
928 |
+
context_attentions,
|
929 |
+
)
|
930 |
+
elif output_hidden_states:
|
931 |
+
return (
|
932 |
+
loss,
|
933 |
+
scores,
|
934 |
+
ib_loss,
|
935 |
+
query_outputs.late_interaction_output,
|
936 |
+
context_outputs.late_interaction_output,
|
937 |
+
query_hidden_states,
|
938 |
+
context_hidden_states,
|
939 |
+
)
|
940 |
+
else:
|
941 |
+
return (
|
942 |
+
loss,
|
943 |
+
scores,
|
944 |
+
ib_loss,
|
945 |
+
query_outputs.late_interaction_output,
|
946 |
+
context_outputs.late_interaction_output,
|
947 |
+
)
|
948 |
+
|
949 |
+
return FLMRModelForRetrievalOutput(
|
950 |
+
loss=loss,
|
951 |
+
scores=scores,
|
952 |
+
in_batch_negative_loss=ib_loss,
|
953 |
+
query_late_interaction_output=query_outputs.late_interaction_output,
|
954 |
+
context_late_interaction_output=context_outputs.late_interaction_output,
|
955 |
+
query_attentions=query_attentions if output_attentions else None,
|
956 |
+
query_hidden_states=query_hidden_states if output_hidden_states else None,
|
957 |
+
context_attentions=context_attentions if output_attentions else None,
|
958 |
+
context_hidden_states=context_hidden_states if output_hidden_states else None,
|
959 |
+
)
|
960 |
+
|
961 |
+
def compute_ib_loss_new(self, Q: torch.Tensor, D: torch.Tensor, D_mask: torch.Tensor) -> torch.Tensor:
|
962 |
+
# Q: batch_size x q_len x dim
|
963 |
+
# D: batch_size*n_docs x i_len x dim
|
964 |
+
# D_mask: batch_size*n_docs x i_len x dim
|
965 |
+
# 1 x batch_size*n_docs x i_len x dim matmul batch_size x 1 x q_len x dim
|
966 |
+
# = batch_size x batch_size*n_docs x i_len x q_len
|
967 |
+
|
968 |
+
scores = (D.float().unsqueeze(0) @ Q.float().permute(0, 2, 1).unsqueeze(1)).flatten(
|
969 |
+
0, 1
|
970 |
+
) # query-major unsqueeze
|
971 |
+
scores = colbert_score_reduce(scores, D_mask.repeat(Q.size(0), 1, 1))
|
972 |
+
|
973 |
+
in_batch_scores = scores.reshape(Q.size(0), -1)
|
974 |
+
|
975 |
+
batch_size = Q.shape[0]
|
976 |
+
batch_size_with_pos_and_neg = D.shape[0]
|
977 |
+
num_pos_and_neg = batch_size_with_pos_and_neg // batch_size
|
978 |
+
|
979 |
+
# batch_size x dim matmul dim x (num_pos+num_neg)*batch_size
|
980 |
+
# --> batch_size x (num_pos+num_neg)*batch_size
|
981 |
+
in_batch_labels = torch.zeros(batch_size, batch_size_with_pos_and_neg).to(scores.device)
|
982 |
+
step = num_pos_and_neg
|
983 |
+
for i in range(batch_size):
|
984 |
+
in_batch_labels[i, step * i] = 1
|
985 |
+
# print('in_batch_labels', in_batch_labels)
|
986 |
+
in_batch_labels = torch.argmax(in_batch_labels, dim=1)
|
987 |
+
# print('in_batch_labels', in_batch_labels)
|
988 |
+
|
989 |
+
loss = self.loss_fn(in_batch_scores, in_batch_labels)
|
990 |
+
|
991 |
+
return loss
|
992 |
+
|
993 |
+
def gather_tensors_from_other_gpus(self, query_embeddings, item_embeddings, item_mask):
|
994 |
+
# print("get rank", get_rank())
|
995 |
+
# print("get world size", get_world_size())
|
996 |
+
# Gather embeddings from other GPUs
|
997 |
+
n_nodes = get_world_size()
|
998 |
+
if n_nodes == 1:
|
999 |
+
return query_embeddings, item_embeddings, item_mask
|
1000 |
+
# Create placeholder to hold embeddings passed from other ranks
|
1001 |
+
global_query_embeddings_placeholder = [
|
1002 |
+
torch.zeros(*query_embeddings.shape, dtype=query_embeddings.dtype).to(query_embeddings.device)
|
1003 |
+
for _ in range(n_nodes)
|
1004 |
+
]
|
1005 |
+
global_item_embeddings_placeholder = [
|
1006 |
+
torch.zeros(*item_embeddings.shape, dtype=item_embeddings.dtype).to(item_embeddings.device)
|
1007 |
+
for _ in range(n_nodes)
|
1008 |
+
]
|
1009 |
+
global_item_mask_placeholder = [
|
1010 |
+
torch.zeros(*item_mask.shape, dtype=item_mask.dtype).to(item_mask.device) for _ in range(n_nodes)
|
1011 |
+
]
|
1012 |
+
dist.all_gather(global_query_embeddings_placeholder, query_embeddings.detach())
|
1013 |
+
dist.all_gather(global_item_embeddings_placeholder, item_embeddings.detach())
|
1014 |
+
dist.all_gather(global_item_mask_placeholder, item_mask.detach())
|
1015 |
+
|
1016 |
+
global_query_embeddings = []
|
1017 |
+
global_item_embeddings = []
|
1018 |
+
global_item_mask = []
|
1019 |
+
# print(f"rank {get_rank()} global_query_embeddings", global_query_embeddings)
|
1020 |
+
# print(f"rank {get_rank()} global_item_embeddings", global_item_embeddings)
|
1021 |
+
# input()
|
1022 |
+
current_rank = get_rank()
|
1023 |
+
for rank_index, remote_q_embeddings in enumerate(global_query_embeddings_placeholder):
|
1024 |
+
# We append the embeddings from other GPUs if this embedding does not require gradients
|
1025 |
+
if rank_index != current_rank:
|
1026 |
+
global_query_embeddings.append(remote_q_embeddings)
|
1027 |
+
else:
|
1028 |
+
global_query_embeddings.append(query_embeddings)
|
1029 |
+
|
1030 |
+
for rank_index, remote_item_embeddings in enumerate(global_item_embeddings_placeholder):
|
1031 |
+
# We append the embeddings from other GPUs if this embedding does not require gradients
|
1032 |
+
if rank_index != current_rank:
|
1033 |
+
global_item_embeddings.append(remote_item_embeddings)
|
1034 |
+
else:
|
1035 |
+
global_item_embeddings.append(item_embeddings)
|
1036 |
+
|
1037 |
+
for rank_index, remote_item_mask in enumerate(global_item_mask_placeholder):
|
1038 |
+
# We append the embeddings from other GPUs if this embedding does not require gradients
|
1039 |
+
if rank_index != current_rank:
|
1040 |
+
global_item_mask.append(remote_item_mask)
|
1041 |
+
else:
|
1042 |
+
global_item_mask.append(item_mask)
|
1043 |
+
|
1044 |
+
# Replace the previous variables with gathered tensors
|
1045 |
+
query_embeddings = torch.cat(global_query_embeddings)
|
1046 |
+
item_embeddings = torch.cat(global_item_embeddings)
|
1047 |
+
item_mask = torch.cat(global_item_mask)
|
1048 |
+
|
1049 |
+
return query_embeddings, item_embeddings, item_mask
|
1050 |
+
|
1051 |
+
@add_start_docstrings_to_model_forward(FLMR_MODEL_QUERY_INPUTS_DOCSTRING)
|
1052 |
+
@replace_return_docstrings(output_type=FLMRQueryEncoderOutput, config_class=_CONFIG_FOR_DOC)
|
1053 |
+
def query(
|
1054 |
+
self,
|
1055 |
+
input_ids: torch.Tensor,
|
1056 |
+
attention_mask: torch.Tensor,
|
1057 |
+
pixel_values: Optional[torch.Tensor] = None,
|
1058 |
+
image_features: Optional[torch.Tensor] = None,
|
1059 |
+
concat_output_from_vision_encoder: Optional[Union[bool, list]] = None,
|
1060 |
+
concat_output_from_text_encoder: Optional[Union[bool, list]] = None,
|
1061 |
+
output_attentions: Optional[bool] = None,
|
1062 |
+
output_hidden_states: Optional[bool] = None,
|
1063 |
+
):
|
1064 |
+
r"""
|
1065 |
+
Returns:
|
1066 |
+
|
1067 |
+
"""
|
1068 |
+
|
1069 |
+
if concat_output_from_vision_encoder is None:
|
1070 |
+
concat_output_from_vision_encoder = self.config.query_concat_output_from_vision_encoder
|
1071 |
+
|
1072 |
+
if concat_output_from_text_encoder is None:
|
1073 |
+
concat_output_from_text_encoder = self.config.query_concat_output_from_text_encoder
|
1074 |
+
|
1075 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1076 |
+
output_hidden_states = (
|
1077 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1078 |
+
)
|
1079 |
+
|
1080 |
+
input_modality = []
|
1081 |
+
if pixel_values is not None or image_features is not None:
|
1082 |
+
input_modality.append("image")
|
1083 |
+
if input_ids is not None and attention_mask is not None:
|
1084 |
+
input_modality.append("text")
|
1085 |
+
|
1086 |
+
text_encoder_outputs = None
|
1087 |
+
vision_encoder_outputs = None
|
1088 |
+
transformer_mapping_outputs = None
|
1089 |
+
|
1090 |
+
if "image" in input_modality:
|
1091 |
+
assert (
|
1092 |
+
pixel_values is not None or image_features is not None
|
1093 |
+
), "pixel_values or image_features must be provided if image modality is used"
|
1094 |
+
assert (
|
1095 |
+
pixel_values is None or image_features is None
|
1096 |
+
), "pixel_values and image_features cannot be provided at the same time"
|
1097 |
+
|
1098 |
+
if "text" in input_modality:
|
1099 |
+
assert (
|
1100 |
+
input_ids is not None and attention_mask is not None
|
1101 |
+
), "input_ids and attention_mask must be provided if text modality is used"
|
1102 |
+
# Forward the text encoder
|
1103 |
+
input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
|
1104 |
+
text_encoder_outputs = self.query_text_encoder(input_ids, attention_mask=attention_mask)
|
1105 |
+
text_encoder_hidden_states = text_encoder_outputs[0]
|
1106 |
+
text_embeddings = self.query_text_encoder_linear(text_encoder_hidden_states)
|
1107 |
+
mask = torch.tensor(self.query_mask(input_ids, skiplist=self.config.query_mask_input_ids_skip_list), device=self.device).unsqueeze(2).float()
|
1108 |
+
|
1109 |
+
text_embeddings = text_embeddings * mask
|
1110 |
+
|
1111 |
+
if "image" in input_modality:
|
1112 |
+
if pixel_values is not None:
|
1113 |
+
batch_size = pixel_values.shape[0]
|
1114 |
+
# Forward the vision encoder
|
1115 |
+
pixel_values = pixel_values.to(self.device)
|
1116 |
+
if len(pixel_values.shape) == 5:
|
1117 |
+
# Multiple ROIs are provided
|
1118 |
+
# merge the first two dimensions
|
1119 |
+
pixel_values = pixel_values.reshape(
|
1120 |
+
-1, pixel_values.shape[2], pixel_values.shape[3], pixel_values.shape[4]
|
1121 |
+
)
|
1122 |
+
vision_encoder_outputs = self.query_vision_encoder(pixel_values, output_hidden_states=True)
|
1123 |
+
vision_embeddings = vision_encoder_outputs.last_hidden_state[:, 0]
|
1124 |
+
|
1125 |
+
if image_features is not None:
|
1126 |
+
batch_size = image_features.shape[0]
|
1127 |
+
vision_embeddings = image_features.to(self.device)
|
1128 |
+
|
1129 |
+
# Forward the vision projection / mapping network
|
1130 |
+
vision_embeddings = self.query_vision_projection(vision_embeddings)
|
1131 |
+
vision_embeddings = vision_embeddings.view(batch_size, -1, self.late_interaction_embedding_size)
|
1132 |
+
|
1133 |
+
if self.config.use_transformer_mapping_network:
|
1134 |
+
# select the second last layer
|
1135 |
+
vision_second_last_layer_hidden_states = vision_encoder_outputs.hidden_states[-2][:, 1:]
|
1136 |
+
# transformer_mapping
|
1137 |
+
transformer_mapping_input_features = self.transformer_mapping_input_linear(
|
1138 |
+
vision_second_last_layer_hidden_states
|
1139 |
+
)
|
1140 |
+
|
1141 |
+
# Cross attention only attends to the first 32 tokens
|
1142 |
+
encoder_mask = torch.ones_like(mask).to(mask.device, dtype=mask.dtype)
|
1143 |
+
if len(self.config.query_mask_input_ids_skip_list) > 0:
|
1144 |
+
encoder_mask[torch.isin(input_ids, torch.tensor(self.config.query_mask_input_ids_skip_list))] = 0
|
1145 |
+
cross_attention_length = self.config.transformer_mapping_cross_attention_length
|
1146 |
+
if text_encoder_hidden_states.shape[1] > cross_attention_length:
|
1147 |
+
text_encoder_hidden_states = text_encoder_hidden_states[:, :cross_attention_length]
|
1148 |
+
encoder_mask = encoder_mask[:, :cross_attention_length]
|
1149 |
+
|
1150 |
+
# Obtain cross attention mask
|
1151 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_mask.squeeze(-1))
|
1152 |
+
# Pass through the transformer mapping
|
1153 |
+
transformer_mapping_outputs = self.transformer_mapping_network(
|
1154 |
+
transformer_mapping_input_features,
|
1155 |
+
encoder_hidden_states=text_encoder_hidden_states,
|
1156 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
1157 |
+
)
|
1158 |
+
transformer_mapping_output_features = transformer_mapping_outputs.last_hidden_state
|
1159 |
+
# Convert the dimension to FLMR dim
|
1160 |
+
transformer_mapping_output_features = self.transformer_mapping_output_linear(
|
1161 |
+
transformer_mapping_output_features
|
1162 |
+
)
|
1163 |
+
# Merge with the vision embeddings
|
1164 |
+
vision_embeddings = torch.cat([vision_embeddings, transformer_mapping_output_features], dim=1)
|
1165 |
+
|
1166 |
+
if concat_output_from_vision_encoder and concat_output_from_text_encoder:
|
1167 |
+
Q = torch.cat([text_embeddings, vision_embeddings], dim=1)
|
1168 |
+
if isinstance(concat_output_from_vision_encoder, list) or isinstance(concat_output_from_text_encoder, list):
|
1169 |
+
# When lists are passed in, mask the output accordingly
|
1170 |
+
assert isinstance(concat_output_from_vision_encoder, list) and isinstance(concat_output_from_text_encoder, list), "concat_output_from_vision_encoder and concat_output_from_text_encoder must be of the same type."
|
1171 |
+
# obtain the size of each output
|
1172 |
+
text_size = text_embeddings.shape[1]
|
1173 |
+
vision_size = vision_embeddings.shape[1]
|
1174 |
+
|
1175 |
+
# Prepare the mask
|
1176 |
+
concat_output_mask = torch.zeros_like(Q).to(Q.device)
|
1177 |
+
|
1178 |
+
# Mask the late interaction outputs
|
1179 |
+
concat_output_mask[:, :text_size] = torch.tensor(concat_output_from_text_encoder).bool().unsqueeze(-1).unsqueeze(-1)
|
1180 |
+
concat_output_mask[:, text_size:] = torch.tensor(concat_output_from_vision_encoder).bool().unsqueeze(-1).unsqueeze(-1)
|
1181 |
+
|
1182 |
+
Q = Q * concat_output_mask
|
1183 |
+
|
1184 |
+
elif concat_output_from_vision_encoder:
|
1185 |
+
Q = vision_embeddings
|
1186 |
+
elif concat_output_from_text_encoder:
|
1187 |
+
Q = text_embeddings
|
1188 |
+
|
1189 |
+
vision_encoder_attentions = (
|
1190 |
+
vision_encoder_outputs.attentions
|
1191 |
+
if vision_encoder_outputs is not None
|
1192 |
+
and hasattr(vision_encoder_outputs, "attentions")
|
1193 |
+
and output_attentions
|
1194 |
+
else None
|
1195 |
+
)
|
1196 |
+
vision_encoder_hidden_states = (
|
1197 |
+
vision_encoder_outputs.hidden_states
|
1198 |
+
if vision_encoder_outputs is not None
|
1199 |
+
and hasattr(vision_encoder_outputs, "hidden_states")
|
1200 |
+
and output_hidden_states
|
1201 |
+
else None
|
1202 |
+
)
|
1203 |
+
text_encoder_attentions = (
|
1204 |
+
text_encoder_outputs.attentions
|
1205 |
+
if text_encoder_outputs is not None and hasattr(text_encoder_outputs, "attentions") and output_attentions
|
1206 |
+
else None
|
1207 |
+
)
|
1208 |
+
text_encoder_hidden_states = (
|
1209 |
+
text_encoder_outputs.hidden_states
|
1210 |
+
if text_encoder_outputs is not None
|
1211 |
+
and hasattr(text_encoder_outputs, "hidden_states")
|
1212 |
+
and output_hidden_states
|
1213 |
+
else None
|
1214 |
+
)
|
1215 |
+
transformer_mapping_network_attentions = (
|
1216 |
+
transformer_mapping_outputs.attentions
|
1217 |
+
if transformer_mapping_outputs is not None
|
1218 |
+
and hasattr(transformer_mapping_outputs, "attentions")
|
1219 |
+
and output_attentions
|
1220 |
+
else None
|
1221 |
+
)
|
1222 |
+
transformer_mapping_network_hidden_states = (
|
1223 |
+
transformer_mapping_outputs.hidden_states
|
1224 |
+
if transformer_mapping_outputs is not None
|
1225 |
+
and hasattr(transformer_mapping_outputs, "hidden_states")
|
1226 |
+
and output_hidden_states
|
1227 |
+
else None
|
1228 |
+
)
|
1229 |
+
|
1230 |
+
return FLMRQueryEncoderOutput(
|
1231 |
+
pooler_output=Q[:, 0, :],
|
1232 |
+
late_interaction_output=torch.nn.functional.normalize(Q, p=2, dim=2),
|
1233 |
+
vision_encoder_attentions=vision_encoder_attentions,
|
1234 |
+
vision_encoder_hidden_states=vision_encoder_hidden_states,
|
1235 |
+
text_encoder_attentions=text_encoder_attentions,
|
1236 |
+
text_encoder_hidden_states=text_encoder_hidden_states,
|
1237 |
+
transformer_mapping_network_attentions=transformer_mapping_network_attentions,
|
1238 |
+
transformer_mapping_network_hidden_states=transformer_mapping_network_hidden_states,
|
1239 |
+
)
|
1240 |
+
|
1241 |
+
@add_start_docstrings_to_model_forward(FLMR_MODEL_CONTEXT_INPUTS_DOCSTRING)
|
1242 |
+
@replace_return_docstrings(output_type=FLMRContextEncoderOutput, config_class=_CONFIG_FOR_DOC)
|
1243 |
+
def doc(
|
1244 |
+
self,
|
1245 |
+
input_ids: torch.Tensor,
|
1246 |
+
attention_mask: torch.Tensor,
|
1247 |
+
pixel_values: Optional[torch.Tensor] = None,
|
1248 |
+
image_features: Optional[torch.Tensor] = None,
|
1249 |
+
concat_output_from_vision_encoder: Optional[bool] = None,
|
1250 |
+
concat_output_from_text_encoder: Optional[bool] = None,
|
1251 |
+
keep_dims: Optional[bool] = True,
|
1252 |
+
return_mask: Optional[bool] = True,
|
1253 |
+
output_attentions: Optional[bool] = None,
|
1254 |
+
output_hidden_states: Optional[bool] = None,
|
1255 |
+
):
|
1256 |
+
r"""
|
1257 |
+
Returns:
|
1258 |
+
|
1259 |
+
"""
|
1260 |
+
assert keep_dims in [True, False]
|
1261 |
+
|
1262 |
+
if concat_output_from_vision_encoder is None:
|
1263 |
+
concat_output_from_vision_encoder = self.config.context_concat_output_from_vision_encoder
|
1264 |
+
|
1265 |
+
if concat_output_from_text_encoder is None:
|
1266 |
+
concat_output_from_text_encoder = self.config.context_concat_output_from_text_encoder
|
1267 |
+
|
1268 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1269 |
+
output_hidden_states = (
|
1270 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1271 |
+
)
|
1272 |
+
|
1273 |
+
input_modality = []
|
1274 |
+
if pixel_values is not None or image_features is not None:
|
1275 |
+
input_modality.append("image")
|
1276 |
+
if input_ids is not None and attention_mask is not None:
|
1277 |
+
input_modality.append("text")
|
1278 |
+
|
1279 |
+
text_encoder_outputs = None
|
1280 |
+
vision_encoder_outputs = None
|
1281 |
+
|
1282 |
+
if "image" in input_modality:
|
1283 |
+
assert (
|
1284 |
+
pixel_values is not None or image_features is not None
|
1285 |
+
), "pixel_values or image_features must be provided if image modality is used"
|
1286 |
+
assert (
|
1287 |
+
pixel_values is None or image_features is None
|
1288 |
+
), "pixel_values and image_features cannot be provided at the same time"
|
1289 |
+
|
1290 |
+
if "text" in input_modality:
|
1291 |
+
assert (
|
1292 |
+
input_ids is not None and attention_mask is not None
|
1293 |
+
), "input_ids and attention_mask must be provided if text modality is used"
|
1294 |
+
# Forward the text encoder
|
1295 |
+
input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
|
1296 |
+
text_encoder_outputs = self.context_text_encoder(input_ids, attention_mask=attention_mask)
|
1297 |
+
text_embeddings = text_encoder_outputs[0]
|
1298 |
+
text_embeddings = self.context_text_encoder_linear(text_embeddings)
|
1299 |
+
|
1300 |
+
mask = torch.tensor(self.mask(input_ids, skiplist=self.skiplist), device=self.device).unsqueeze(2).float()
|
1301 |
+
text_embeddings = text_embeddings * mask
|
1302 |
+
|
1303 |
+
if "image" in input_modality:
|
1304 |
+
if pixel_values is not None:
|
1305 |
+
# Forward the vision encoder
|
1306 |
+
pixel_values = pixel_values.to(self.device)
|
1307 |
+
vision_encoder_outputs = self.context_vision_encoder(pixel_values)
|
1308 |
+
vision_embeddings = vision_encoder_outputs.last_hidden_state[:, 0]
|
1309 |
+
|
1310 |
+
if image_features is not None:
|
1311 |
+
vision_embeddings = image_features.to(self.device)
|
1312 |
+
|
1313 |
+
batch_size = vision_embeddings.shape[0]
|
1314 |
+
|
1315 |
+
# Forward the vision projection / mapping network
|
1316 |
+
vision_embeddings = self.context_vision_projection(vision_embeddings)
|
1317 |
+
vision_embeddings = vision_embeddings.view(
|
1318 |
+
-1, self.mapping_network_prefix_length, self.late_interaction_embedding_size
|
1319 |
+
)
|
1320 |
+
|
1321 |
+
image_mask = torch.ones(batch_size, vision_embeddings.shape[1], 1).to(self.device)
|
1322 |
+
|
1323 |
+
if concat_output_from_vision_encoder and concat_output_from_text_encoder:
|
1324 |
+
# Note: vision embeddings must be in the front since the ColBERT engine only indexes embeddings up to number of 1's in the mask
|
1325 |
+
# TODO: fix the engine to support masks with discontinuous 0 and 1.
|
1326 |
+
D = torch.cat([vision_embeddings, text_embeddings], dim=1)
|
1327 |
+
# concatenate the mask
|
1328 |
+
mask = torch.cat([image_mask, mask], dim=1)
|
1329 |
+
elif concat_output_from_vision_encoder:
|
1330 |
+
D = vision_embeddings
|
1331 |
+
mask = image_mask
|
1332 |
+
elif concat_output_from_text_encoder:
|
1333 |
+
D = text_embeddings
|
1334 |
+
mask = mask
|
1335 |
+
|
1336 |
+
D = torch.nn.functional.normalize(D, p=2, dim=2)
|
1337 |
+
|
1338 |
+
if self.use_gpu:
|
1339 |
+
D = D.half()
|
1340 |
+
|
1341 |
+
if keep_dims is False:
|
1342 |
+
D, mask = D.cpu(), mask.bool().cpu().squeeze(-1)
|
1343 |
+
D = [d[mask[idx]] for idx, d in enumerate(D)]
|
1344 |
+
|
1345 |
+
vision_encoder_attentions = (
|
1346 |
+
vision_encoder_outputs.attentions
|
1347 |
+
if vision_encoder_outputs is not None
|
1348 |
+
and hasattr(vision_encoder_outputs, "attentions")
|
1349 |
+
and output_attentions
|
1350 |
+
else None
|
1351 |
+
)
|
1352 |
+
vision_encoder_hidden_states = (
|
1353 |
+
vision_encoder_outputs.hidden_states
|
1354 |
+
if vision_encoder_outputs is not None
|
1355 |
+
and hasattr(vision_encoder_outputs, "hidden_states")
|
1356 |
+
and output_hidden_states
|
1357 |
+
else None
|
1358 |
+
)
|
1359 |
+
text_encoder_attentions = (
|
1360 |
+
text_encoder_outputs.attentions
|
1361 |
+
if text_encoder_outputs is not None and hasattr(text_encoder_outputs, "attentions") and output_attentions
|
1362 |
+
else None
|
1363 |
+
)
|
1364 |
+
text_encoder_hidden_states = (
|
1365 |
+
text_encoder_outputs.hidden_states
|
1366 |
+
if text_encoder_outputs is not None
|
1367 |
+
and hasattr(text_encoder_outputs, "hidden_states")
|
1368 |
+
and output_hidden_states
|
1369 |
+
else None
|
1370 |
+
)
|
1371 |
+
|
1372 |
+
return FLMRContextEncoderOutput(
|
1373 |
+
pooler_output=D[:, 0, :],
|
1374 |
+
late_interaction_output=D,
|
1375 |
+
context_mask=mask.bool() if return_mask else None,
|
1376 |
+
vision_encoder_attentions=vision_encoder_attentions,
|
1377 |
+
vision_encoder_hidden_states=vision_encoder_hidden_states,
|
1378 |
+
text_encoder_attentions=text_encoder_attentions,
|
1379 |
+
text_encoder_hidden_states=text_encoder_hidden_states,
|
1380 |
+
)
|
1381 |
+
|
1382 |
+
def score(self, Q, D_padded, D_mask):
|
1383 |
+
# assert self.colbert_config.similarity == 'cosine'
|
1384 |
+
# if self.colbert_config.similarity == 'l2':
|
1385 |
+
# assert self.colbert_config.interaction == 'colbert'
|
1386 |
+
# return (-1.0 * ((Q.unsqueeze(2) - D_padded.unsqueeze(1))**2).sum(-1)).max(-1).values.sum(-1)
|
1387 |
+
return colbert_score(Q, D_padded, D_mask, use_gpu=self.use_gpu)
|
1388 |
+
|
1389 |
+
def mask(self, input_ids, skiplist):
|
1390 |
+
mask = [[(x not in skiplist) and (x != 0) for x in d] for d in input_ids.cpu().tolist()]
|
1391 |
+
return mask
|
1392 |
+
|
1393 |
+
|
1394 |
+
@add_start_docstrings(
|
1395 |
+
"The bare FLMR text encoder that can be used to generate late-interaction embeddings for texts in queries and contexts. This model is based on a `BertModel`. It can be used like a `BertModel` model for encoding text.",
|
1396 |
+
FLMR_TEXT_ENCODERS_START_DOCSTRING,
|
1397 |
+
)
|
1398 |
+
class FLMRTextModel(FLMRPreTrainedModel):
|
1399 |
+
base_model_prefix = "flmr_text_model"
|
1400 |
+
config_class = FLMRTextConfig
|
1401 |
+
|
1402 |
+
def __init__(self, config: FLMRTextConfig, *args, **kwargs):
|
1403 |
+
super().__init__(config)
|
1404 |
+
if config.text_encoder_base_model == "bert-base-uncased":
|
1405 |
+
self.bert_model = BertModel(config, add_pooling_layer=True)
|
1406 |
+
else:
|
1407 |
+
self.bert_model = AutoModel.from_pretrained(config.text_encoder_base_model, *args, **kwargs)
|
1408 |
+
if self.bert_model.config.hidden_size <= 0:
|
1409 |
+
raise ValueError("Encoder hidden_size can't be zero")
|
1410 |
+
self.projection_dim = config.projection_dim
|
1411 |
+
if self.projection_dim > 0:
|
1412 |
+
self.encode_proj = nn.Linear(self.bert_model.config.hidden_size, config.projection_dim)
|
1413 |
+
# Initialize weights and apply final processing
|
1414 |
+
self.post_init()
|
1415 |
+
self.text_model = self.bert_model
|
1416 |
+
|
1417 |
+
@add_start_docstrings_to_model_forward(FLMR_TEXT_ENCODERS_INPUTS_DOCSTRING)
|
1418 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=FLMRTextConfig)
|
1419 |
+
def forward(
|
1420 |
+
self,
|
1421 |
+
input_ids: Optional[Tensor] = None,
|
1422 |
+
attention_mask: Optional[Tensor] = None,
|
1423 |
+
token_type_ids: Optional[Tensor] = None,
|
1424 |
+
inputs_embeds: Optional[Tensor] = None,
|
1425 |
+
output_attentions: bool = None,
|
1426 |
+
output_hidden_states: bool = None,
|
1427 |
+
return_dict: bool = None,
|
1428 |
+
) -> Union[BaseModelOutputWithPooling, Tuple[Tensor, ...]]:
|
1429 |
+
r"""
|
1430 |
+
Returns:
|
1431 |
+
|
1432 |
+
"""
|
1433 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1434 |
+
output_hidden_states = (
|
1435 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1436 |
+
)
|
1437 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1438 |
+
|
1439 |
+
outputs = self.text_model(
|
1440 |
+
input_ids=input_ids,
|
1441 |
+
attention_mask=attention_mask,
|
1442 |
+
token_type_ids=token_type_ids,
|
1443 |
+
inputs_embeds=inputs_embeds,
|
1444 |
+
output_attentions=output_attentions,
|
1445 |
+
output_hidden_states=output_hidden_states,
|
1446 |
+
return_dict=return_dict,
|
1447 |
+
)
|
1448 |
+
sequence_output = outputs[0]
|
1449 |
+
pooled_output = sequence_output[:, 0, :]
|
1450 |
+
|
1451 |
+
if self.projection_dim > 0:
|
1452 |
+
pooled_output = self.encode_proj(pooled_output)
|
1453 |
+
|
1454 |
+
if not return_dict:
|
1455 |
+
return (sequence_output, pooled_output) + outputs[2:]
|
1456 |
+
|
1457 |
+
return BaseModelOutputWithPooling(
|
1458 |
+
last_hidden_state=sequence_output,
|
1459 |
+
pooler_output=pooled_output,
|
1460 |
+
hidden_states=outputs.hidden_states,
|
1461 |
+
attentions=outputs.attentions,
|
1462 |
+
)
|
1463 |
+
|
1464 |
+
@property
|
1465 |
+
def embeddings_size(self) -> int:
|
1466 |
+
if self.projection_dim > 0:
|
1467 |
+
return self.encode_proj.out_features
|
1468 |
+
return self.text_model.config.hidden_size
|
1469 |
+
|
1470 |
+
|
1471 |
+
@add_start_docstrings(
|
1472 |
+
"The bare FLMR vision encoder that can be used to generate late-interaction embeddings for images in queries and contexts. This model is based on a `CLIPVisionModel`. It can be used like a `CLIPVisionModel` model for encoding images.",
|
1473 |
+
FLMR_VISION_ENCODERS_START_DOCSTRING,
|
1474 |
+
)
|
1475 |
+
class FLMRVisionModel(FLMRPreTrainedModel):
|
1476 |
+
base_model_prefix = "vision_model"
|
1477 |
+
config_class = FLMRVisionConfig
|
1478 |
+
main_input_name = "pixel_values"
|
1479 |
+
_no_split_modules = ["CLIPEncoderLayer"]
|
1480 |
+
|
1481 |
+
def __init__(self, config: FLMRVisionConfig):
|
1482 |
+
super().__init__(config)
|
1483 |
+
self.vision_model = CLIPVisionModel(config)
|
1484 |
+
self.post_init()
|
1485 |
+
|
1486 |
+
def get_input_embeddings(self) -> nn.Module:
|
1487 |
+
return self.vision_model.vision_model.embeddings.patch_embedding
|
1488 |
+
|
1489 |
+
@add_start_docstrings_to_model_forward(FLMR_VISION_ENCODERS_INPUTS_DOCSTRING)
|
1490 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=FLMRVisionConfig)
|
1491 |
+
def forward(
|
1492 |
+
self,
|
1493 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
1494 |
+
output_attentions: Optional[bool] = None,
|
1495 |
+
output_hidden_states: Optional[bool] = None,
|
1496 |
+
return_dict: Optional[bool] = None,
|
1497 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
1498 |
+
r"""
|
1499 |
+
Returns:
|
1500 |
+
|
1501 |
+
Examples:
|
1502 |
+
|
1503 |
+
```python
|
1504 |
+
>>> from PIL import Image
|
1505 |
+
>>> import requests
|
1506 |
+
>>> from transformers import AutoProcessor, FLMRVisionModel
|
1507 |
+
|
1508 |
+
>>> model = FLMRVisionModel.from_pretrained("openai/clip-vit-base-patch32")
|
1509 |
+
>>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
1510 |
+
|
1511 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
1512 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1513 |
+
|
1514 |
+
>>> inputs = processor(images=image, return_tensors="pt")
|
1515 |
+
|
1516 |
+
>>> outputs = model(**inputs)
|
1517 |
+
>>> last_hidden_state = outputs.last_hidden_state
|
1518 |
+
>>> pooled_output = outputs.pooler_output # pooled CLS states
|
1519 |
+
```"""
|
1520 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1521 |
+
|
1522 |
+
return self.vision_model(
|
1523 |
+
pixel_values=pixel_values,
|
1524 |
+
output_attentions=output_attentions,
|
1525 |
+
output_hidden_states=output_hidden_states,
|
1526 |
+
return_dict=return_dict,
|
1527 |
+
)
|
query_tokenizer/sentencepiece.bpe.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cfc8146abe2a0488e9e2a0c56de7952f7c11ab059eca145a0a727afce0db2865
|
3 |
+
size 5069051
|
query_tokenizer/special_tokens_map.json
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": {
|
3 |
+
"content": "<s>",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": false,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
},
|
9 |
+
"cls_token": {
|
10 |
+
"content": "<s>",
|
11 |
+
"lstrip": false,
|
12 |
+
"normalized": false,
|
13 |
+
"rstrip": false,
|
14 |
+
"single_word": false
|
15 |
+
},
|
16 |
+
"eos_token": {
|
17 |
+
"content": "</s>",
|
18 |
+
"lstrip": false,
|
19 |
+
"normalized": false,
|
20 |
+
"rstrip": false,
|
21 |
+
"single_word": false
|
22 |
+
},
|
23 |
+
"mask_token": {
|
24 |
+
"content": "<mask>",
|
25 |
+
"lstrip": true,
|
26 |
+
"normalized": false,
|
27 |
+
"rstrip": false,
|
28 |
+
"single_word": false
|
29 |
+
},
|
30 |
+
"pad_token": {
|
31 |
+
"content": "<pad>",
|
32 |
+
"lstrip": false,
|
33 |
+
"normalized": false,
|
34 |
+
"rstrip": false,
|
35 |
+
"single_word": false
|
36 |
+
},
|
37 |
+
"sep_token": {
|
38 |
+
"content": "</s>",
|
39 |
+
"lstrip": false,
|
40 |
+
"normalized": false,
|
41 |
+
"rstrip": false,
|
42 |
+
"single_word": false
|
43 |
+
},
|
44 |
+
"unk_token": {
|
45 |
+
"content": "<unk>",
|
46 |
+
"lstrip": false,
|
47 |
+
"normalized": false,
|
48 |
+
"rstrip": false,
|
49 |
+
"single_word": false
|
50 |
+
}
|
51 |
+
}
|
query_tokenizer/tokenizer.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:249df0778f236f6ece390de0de746838ef25b9d6954b68c2ee71249e0a9d8fd4
|
3 |
+
size 17082799
|
query_tokenizer/tokenizer_config.json
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"added_tokens_decoder": {
|
3 |
+
"0": {
|
4 |
+
"content": "<s>",
|
5 |
+
"lstrip": false,
|
6 |
+
"normalized": false,
|
7 |
+
"rstrip": false,
|
8 |
+
"single_word": false,
|
9 |
+
"special": true
|
10 |
+
},
|
11 |
+
"1": {
|
12 |
+
"content": "<pad>",
|
13 |
+
"lstrip": false,
|
14 |
+
"normalized": false,
|
15 |
+
"rstrip": false,
|
16 |
+
"single_word": false,
|
17 |
+
"special": true
|
18 |
+
},
|
19 |
+
"2": {
|
20 |
+
"content": "</s>",
|
21 |
+
"lstrip": false,
|
22 |
+
"normalized": false,
|
23 |
+
"rstrip": false,
|
24 |
+
"single_word": false,
|
25 |
+
"special": true
|
26 |
+
},
|
27 |
+
"3": {
|
28 |
+
"content": "<unk>",
|
29 |
+
"lstrip": false,
|
30 |
+
"normalized": false,
|
31 |
+
"rstrip": false,
|
32 |
+
"single_word": false,
|
33 |
+
"special": true
|
34 |
+
},
|
35 |
+
"250001": {
|
36 |
+
"content": "<mask>",
|
37 |
+
"lstrip": true,
|
38 |
+
"normalized": false,
|
39 |
+
"rstrip": false,
|
40 |
+
"single_word": false,
|
41 |
+
"special": true
|
42 |
+
}
|
43 |
+
},
|
44 |
+
"bos_token": "<s>",
|
45 |
+
"clean_up_tokenization_spaces": true,
|
46 |
+
"cls_token": "<s>",
|
47 |
+
"eos_token": "</s>",
|
48 |
+
"mask_token": "<mask>",
|
49 |
+
"model_max_length": 8192,
|
50 |
+
"pad_token": "<pad>",
|
51 |
+
"sep_token": "</s>",
|
52 |
+
"sp_model_kwargs": {},
|
53 |
+
"tokenizer_class": "XLMRobertaTokenizer",
|
54 |
+
"unk_token": "<unk>"
|
55 |
+
}
|
segmented_maxsim.cpp
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <pthread.h>
|
2 |
+
#include <torch/extension.h>
|
3 |
+
|
4 |
+
#include <algorithm>
|
5 |
+
#include <numeric>
|
6 |
+
|
7 |
+
typedef struct {
|
8 |
+
int tid;
|
9 |
+
int nthreads;
|
10 |
+
|
11 |
+
int ndocs;
|
12 |
+
int ndoc_vectors;
|
13 |
+
int nquery_vectors;
|
14 |
+
|
15 |
+
int64_t* lengths;
|
16 |
+
float* scores;
|
17 |
+
int64_t* offsets;
|
18 |
+
|
19 |
+
float* max_scores;
|
20 |
+
} max_args_t;
|
21 |
+
|
22 |
+
void* max(void* args) {
|
23 |
+
max_args_t* max_args = (max_args_t*)args;
|
24 |
+
|
25 |
+
int ndocs_per_thread =
|
26 |
+
std::ceil(((float)max_args->ndocs) / max_args->nthreads);
|
27 |
+
int start = max_args->tid * ndocs_per_thread;
|
28 |
+
int end = std::min((max_args->tid + 1) * ndocs_per_thread, max_args->ndocs);
|
29 |
+
|
30 |
+
auto max_scores_offset =
|
31 |
+
max_args->max_scores + (start * max_args->nquery_vectors);
|
32 |
+
auto scores_offset =
|
33 |
+
max_args->scores + (max_args->offsets[start] * max_args->nquery_vectors);
|
34 |
+
|
35 |
+
for (int i = start; i < end; i++) {
|
36 |
+
for (int j = 0; j < max_args->lengths[i]; j++) {
|
37 |
+
std::transform(max_scores_offset,
|
38 |
+
max_scores_offset + max_args->nquery_vectors,
|
39 |
+
scores_offset, max_scores_offset,
|
40 |
+
[](float a, float b) { return std::max(a, b); });
|
41 |
+
scores_offset += max_args->nquery_vectors;
|
42 |
+
}
|
43 |
+
max_scores_offset += max_args->nquery_vectors;
|
44 |
+
}
|
45 |
+
|
46 |
+
return NULL;
|
47 |
+
}
|
48 |
+
|
49 |
+
torch::Tensor segmented_maxsim(const torch::Tensor scores,
|
50 |
+
const torch::Tensor lengths) {
|
51 |
+
auto lengths_a = lengths.data_ptr<int64_t>();
|
52 |
+
auto scores_a = scores.data_ptr<float>();
|
53 |
+
auto ndocs = lengths.size(0);
|
54 |
+
auto ndoc_vectors = scores.size(0);
|
55 |
+
auto nquery_vectors = scores.size(1);
|
56 |
+
auto nthreads = at::get_num_threads();
|
57 |
+
|
58 |
+
torch::Tensor max_scores =
|
59 |
+
torch::zeros({ndocs, nquery_vectors}, scores.options());
|
60 |
+
|
61 |
+
int64_t offsets[ndocs + 1];
|
62 |
+
offsets[0] = 0;
|
63 |
+
std::partial_sum(lengths_a, lengths_a + ndocs, offsets + 1);
|
64 |
+
|
65 |
+
pthread_t threads[nthreads];
|
66 |
+
max_args_t args[nthreads];
|
67 |
+
|
68 |
+
for (int i = 0; i < nthreads; i++) {
|
69 |
+
args[i].tid = i;
|
70 |
+
args[i].nthreads = nthreads;
|
71 |
+
|
72 |
+
args[i].ndocs = ndocs;
|
73 |
+
args[i].ndoc_vectors = ndoc_vectors;
|
74 |
+
args[i].nquery_vectors = nquery_vectors;
|
75 |
+
|
76 |
+
args[i].lengths = lengths_a;
|
77 |
+
args[i].scores = scores_a;
|
78 |
+
args[i].offsets = offsets;
|
79 |
+
|
80 |
+
args[i].max_scores = max_scores.data_ptr<float>();
|
81 |
+
|
82 |
+
int rc = pthread_create(&threads[i], NULL, max, (void*)&args[i]);
|
83 |
+
if (rc) {
|
84 |
+
fprintf(stderr, "Unable to create thread %d: %d\n", i, rc);
|
85 |
+
}
|
86 |
+
}
|
87 |
+
|
88 |
+
for (int i = 0; i < nthreads; i++) {
|
89 |
+
pthread_join(threads[i], NULL);
|
90 |
+
}
|
91 |
+
|
92 |
+
return max_scores.sum(1);
|
93 |
+
}
|
94 |
+
|
95 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
96 |
+
m.def("segmented_maxsim_cpp", &segmented_maxsim, "Segmented MaxSim");
|
97 |
+
}
|
tokenization_flmr.py
ADDED
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The HuggingFace Inc. team, The Hugging Face Team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Tokenization classes for FLMR."""
|
16 |
+
|
17 |
+
|
18 |
+
from typing import List, Optional, Union
|
19 |
+
|
20 |
+
from transformers.utils import TensorType, logging
|
21 |
+
from transformers.models.bert.tokenization_bert import BertTokenizer
|
22 |
+
from transformers import AutoTokenizer
|
23 |
+
from .configuration_flmr import FLMRTextConfig
|
24 |
+
|
25 |
+
logger = logging.get_logger(__name__)
|
26 |
+
|
27 |
+
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer_config.json"}
|
28 |
+
|
29 |
+
CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
|
30 |
+
"vocab_file": {
|
31 |
+
"LinWeizheDragon/PreFLMR_ViT-L": (
|
32 |
+
"https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/vocab.txt"
|
33 |
+
),
|
34 |
+
"LinWeizheDragon/FLMR": (
|
35 |
+
"https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/vocab.txt"
|
36 |
+
),
|
37 |
+
},
|
38 |
+
"tokenizer_file": {
|
39 |
+
"LinWeizheDragon/PreFLMR_ViT-L": (
|
40 |
+
"https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/tokenizer_config.json"
|
41 |
+
),
|
42 |
+
"LinWeizheDragon/FLMR": (
|
43 |
+
"https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/tokenizer_config.json"
|
44 |
+
),
|
45 |
+
},
|
46 |
+
}
|
47 |
+
QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
|
48 |
+
"vocab_file": {
|
49 |
+
"LinWeizheDragon/PreFLMR_ViT-L": (
|
50 |
+
"https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/vocab.txt"
|
51 |
+
),
|
52 |
+
"LinWeizheDragon/FLMR": ("https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/vocab.txt"),
|
53 |
+
},
|
54 |
+
"tokenizer_file": {
|
55 |
+
"LinWeizheDragon/PreFLMR_ViT-L": (
|
56 |
+
"https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/tokenizer_config.json"
|
57 |
+
),
|
58 |
+
"LinWeizheDragon/FLMR": (
|
59 |
+
"https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/tokenizer_config.json"
|
60 |
+
),
|
61 |
+
},
|
62 |
+
}
|
63 |
+
|
64 |
+
|
65 |
+
CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
66 |
+
"LinWeizheDragon/PreFLMR_ViT-L": 512,
|
67 |
+
"LinWeizheDragon/FLMR": 512,
|
68 |
+
}
|
69 |
+
QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
70 |
+
"LinWeizheDragon/PreFLMR_ViT-L": 512,
|
71 |
+
"LinWeizheDragon/FLMR": 512,
|
72 |
+
}
|
73 |
+
|
74 |
+
|
75 |
+
CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
|
76 |
+
"LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
|
77 |
+
"LinWeizheDragon/FLMR": {"do_lower_case": True},
|
78 |
+
}
|
79 |
+
QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
|
80 |
+
"LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
|
81 |
+
"LinWeizheDragon/FLMR": {"do_lower_case": True},
|
82 |
+
}
|
83 |
+
|
84 |
+
|
85 |
+
# Modified from colbert.modeling.tokenization
|
86 |
+
class FLMRBertContextEncoderTokenizer(BertTokenizer):
|
87 |
+
r"""
|
88 |
+
Construct a FLMRContextEncoder tokenizer.
|
89 |
+
|
90 |
+
[`FLMRContextEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
|
91 |
+
splitting and wordpiece.
|
92 |
+
|
93 |
+
Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
|
94 |
+
"""
|
95 |
+
|
96 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
97 |
+
pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
|
98 |
+
max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
99 |
+
pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION
|
100 |
+
|
101 |
+
def __init__(
|
102 |
+
self,
|
103 |
+
doc_maxlen: Optional[int] = 512,
|
104 |
+
**kwargs,
|
105 |
+
):
|
106 |
+
super().__init__(
|
107 |
+
doc_maxlen=doc_maxlen,
|
108 |
+
**kwargs,
|
109 |
+
)
|
110 |
+
|
111 |
+
self.doc_maxlen = doc_maxlen
|
112 |
+
self.D_marker_token, self.D_marker_token_id = "[D]", self.convert_tokens_to_ids("[unused1]")
|
113 |
+
|
114 |
+
def __call__(
|
115 |
+
self,
|
116 |
+
text: List[str],
|
117 |
+
padding: Optional[Union[str, bool]] = "max_length",
|
118 |
+
truncation: Optional[Union[bool, str]] = "longest_first",
|
119 |
+
max_length: Optional[int] = 512,
|
120 |
+
return_tensors: Optional[Union[str, TensorType]] = "pt",
|
121 |
+
**kwargs,
|
122 |
+
):
|
123 |
+
# add placehold for the [D] marker
|
124 |
+
text = [". " + x for x in text]
|
125 |
+
|
126 |
+
if max_length > self.doc_maxlen:
|
127 |
+
# can not exceed the pre-set length
|
128 |
+
max_length = self.doc_maxlen
|
129 |
+
|
130 |
+
encoding = super().__call__(
|
131 |
+
text,
|
132 |
+
padding=padding,
|
133 |
+
truncation=truncation,
|
134 |
+
return_tensors=return_tensors,
|
135 |
+
max_length=max_length,
|
136 |
+
**kwargs,
|
137 |
+
)
|
138 |
+
|
139 |
+
ids, mask = encoding["input_ids"], encoding["attention_mask"]
|
140 |
+
|
141 |
+
# postprocess for the [D] marker
|
142 |
+
ids[:, 1] = self.D_marker_token_id
|
143 |
+
|
144 |
+
# if bsize:
|
145 |
+
# # This bsize function is used in the original ColBERT codebase to split inputs into multiple batches
|
146 |
+
# if image_features is not None:
|
147 |
+
# ids, mask, image_features, reverse_indices = _sort_by_length(ids, mask, bsize, image_features=image_features)
|
148 |
+
# batches = _split_into_batches(ids, mask, bsize, image_features=image_features)
|
149 |
+
# else:
|
150 |
+
# ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
|
151 |
+
# batches = _split_into_batches(ids, mask, bsize)
|
152 |
+
|
153 |
+
# return batches, reverse_indices
|
154 |
+
|
155 |
+
encoding["input_ids"] = ids
|
156 |
+
encoding["attention_mask"] = mask
|
157 |
+
|
158 |
+
return encoding
|
159 |
+
|
160 |
+
|
161 |
+
# Modified from colbert.modeling.tokenization
|
162 |
+
class FLMRBertQueryEncoderTokenizer(BertTokenizer):
|
163 |
+
r"""
|
164 |
+
Constructs a FLMRQueryEncoder tokenizer.
|
165 |
+
|
166 |
+
[`FLMRQueryEncoder`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
|
167 |
+
splitting and wordpiece.
|
168 |
+
|
169 |
+
Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
|
170 |
+
"""
|
171 |
+
|
172 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
173 |
+
pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
|
174 |
+
max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
175 |
+
pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION
|
176 |
+
|
177 |
+
def __init__(
|
178 |
+
self,
|
179 |
+
*args,
|
180 |
+
query_maxlen: Optional[int] = 32,
|
181 |
+
attend_to_mask_tokens: Optional[bool] = False,
|
182 |
+
**kwargs,
|
183 |
+
):
|
184 |
+
super().__init__(
|
185 |
+
*args,
|
186 |
+
query_maxlen=query_maxlen,
|
187 |
+
attend_to_mask_tokens=attend_to_mask_tokens,
|
188 |
+
**kwargs,
|
189 |
+
)
|
190 |
+
|
191 |
+
self.query_maxlen = query_maxlen
|
192 |
+
self.background_maxlen = 512 - self.query_maxlen + 1 # FIXME: Make this configurable
|
193 |
+
self.attend_to_mask_tokens = attend_to_mask_tokens
|
194 |
+
|
195 |
+
self.Q_marker_token, self.Q_marker_token_id = "[Q]", self.convert_tokens_to_ids("[unused0]")
|
196 |
+
|
197 |
+
def __call__(
|
198 |
+
self,
|
199 |
+
text: Union[str, List[str]],
|
200 |
+
padding: Optional[Union[str, bool]] = "max_length",
|
201 |
+
truncation: Optional[Union[bool, str]] = True,
|
202 |
+
max_length: Optional[int] = None,
|
203 |
+
return_tensors: Optional[Union[str, TensorType]] = "pt",
|
204 |
+
**kwargs,
|
205 |
+
):
|
206 |
+
if isinstance(text, str):
|
207 |
+
# convert to list if input is a single string
|
208 |
+
text = [text]
|
209 |
+
|
210 |
+
# add placehold for the [Q] marker
|
211 |
+
text = [". " + x for x in text]
|
212 |
+
|
213 |
+
if max_length is not None:
|
214 |
+
# use user specified max_length
|
215 |
+
pass
|
216 |
+
else:
|
217 |
+
# use default max length
|
218 |
+
max_length = self.query_maxlen
|
219 |
+
|
220 |
+
encoding = super().__call__(
|
221 |
+
text,
|
222 |
+
padding=padding,
|
223 |
+
truncation=truncation,
|
224 |
+
return_tensors=return_tensors,
|
225 |
+
max_length=max_length,
|
226 |
+
**kwargs,
|
227 |
+
)
|
228 |
+
|
229 |
+
ids, mask = encoding["input_ids"], encoding["attention_mask"]
|
230 |
+
|
231 |
+
# postprocess for the [Q] marker and the [MASK] augmentation
|
232 |
+
ids[:, 1] = self.Q_marker_token_id
|
233 |
+
ids[ids == self.pad_token_id] = self.mask_token_id
|
234 |
+
|
235 |
+
if self.attend_to_mask_tokens:
|
236 |
+
# When attend_to_mask_tokens is True, we want to attend to the [MASK] tokens
|
237 |
+
mask[ids == self.mask_token_id] = 1
|
238 |
+
assert mask.sum().item() == mask.size(0) * mask.size(1), mask
|
239 |
+
|
240 |
+
return {"input_ids": ids, "attention_mask": mask}
|
241 |
+
|
242 |
+
class FLMRAutoContextEncoderTokenizer:
|
243 |
+
r"""
|
244 |
+
Construct a ContextEncoderTokenizer tokenizer with AutoTokenizer.
|
245 |
+
|
246 |
+
[`FLMRAutoContextEncoderTokenizer`] is identical to [`AutoTokenizer`] and runs end-to-end tokenization: punctuation
|
247 |
+
splitting and wordpiece.
|
248 |
+
|
249 |
+
Refer to superclass [`AutoTokenizer`] for usage examples and documentation concerning parameters.
|
250 |
+
"""
|
251 |
+
|
252 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
253 |
+
pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
|
254 |
+
max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
255 |
+
pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION
|
256 |
+
|
257 |
+
def __init__(
|
258 |
+
self,
|
259 |
+
*args,
|
260 |
+
doc_maxlen: Optional[int] = 512,
|
261 |
+
**kwargs,
|
262 |
+
):
|
263 |
+
self.doc_maxlen = doc_maxlen
|
264 |
+
self.tokenizer = AutoTokenizer.from_pretrained(*args, **kwargs)
|
265 |
+
self.additional_special_tokens = self.tokenizer.additional_special_tokens
|
266 |
+
|
267 |
+
|
268 |
+
def __call__(
|
269 |
+
self,
|
270 |
+
text: List[str],
|
271 |
+
padding: Optional[Union[str, bool]] = "max_length",
|
272 |
+
truncation: Optional[Union[bool, str]] = "longest_first",
|
273 |
+
max_length: Optional[int] = 512,
|
274 |
+
return_tensors: Optional[Union[str, TensorType]] = "pt",
|
275 |
+
**kwargs,
|
276 |
+
):
|
277 |
+
# add placehold for the [D] marker
|
278 |
+
text = [". " + x for x in text]
|
279 |
+
|
280 |
+
if max_length > self.doc_maxlen:
|
281 |
+
# can not exceed the pre-set length
|
282 |
+
max_length = self.doc_maxlen
|
283 |
+
|
284 |
+
encoding = self.tokenizer(
|
285 |
+
text,
|
286 |
+
padding=padding,
|
287 |
+
truncation=True,
|
288 |
+
return_tensors=return_tensors,
|
289 |
+
max_length=max_length,
|
290 |
+
**kwargs,
|
291 |
+
)
|
292 |
+
|
293 |
+
ids, mask = encoding["input_ids"], encoding["attention_mask"]
|
294 |
+
|
295 |
+
encoding["input_ids"] = ids
|
296 |
+
encoding["attention_mask"] = mask
|
297 |
+
|
298 |
+
return encoding
|
299 |
+
|
300 |
+
def encode(self, text, text_pair=None, add_special_tokens=True, **kwargs):
|
301 |
+
return self.tokenizer.encode(text, text_pair, add_special_tokens, **kwargs)
|
302 |
+
|
303 |
+
def add_special_tokens(self, token, **kwargs):
|
304 |
+
return self.tokenizer.add_special_tokens(token, **kwargs)
|
305 |
+
|
306 |
+
def save_pretrained(self, path):
|
307 |
+
self.tokenizer.save_pretrained(path)
|
308 |
+
|
309 |
+
|
310 |
+
# Modified from colbert.modeling.tokenization
|
311 |
+
class FLMRAutoQueryEncoderTokenizer:
|
312 |
+
r"""
|
313 |
+
Constructs a QueryEncoderTokenizer tokenizer with AutoTokenizer.
|
314 |
+
|
315 |
+
[`FLMRAutoQueryEncoderTokenizer`] is identical to [`AutoTokenizer`] and runs end-to-end tokenization: punctuation
|
316 |
+
splitting and wordpiece.
|
317 |
+
|
318 |
+
Refer to superclass [`AutoTokenizer`] for usage examples and documentation concerning parameters.
|
319 |
+
"""
|
320 |
+
|
321 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
322 |
+
pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
|
323 |
+
max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
324 |
+
pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION
|
325 |
+
|
326 |
+
def __init__(
|
327 |
+
self,
|
328 |
+
*args,
|
329 |
+
query_maxlen: Optional[int] = 32,
|
330 |
+
attend_to_mask_tokens: Optional[bool] = False,
|
331 |
+
**kwargs,
|
332 |
+
):
|
333 |
+
self.tokenizer = AutoTokenizer.from_pretrained(*args, **kwargs)
|
334 |
+
self.additional_special_tokens = self.tokenizer.additional_special_tokens
|
335 |
+
self.query_maxlen = query_maxlen
|
336 |
+
self.background_maxlen = 512 - self.query_maxlen + 1 # FIXME: Make this configurable
|
337 |
+
self.attend_to_mask_tokens = attend_to_mask_tokens
|
338 |
+
|
339 |
+
def __call__(
|
340 |
+
self,
|
341 |
+
text: Union[str, List[str]],
|
342 |
+
padding: Optional[Union[str, bool]] = "max_length",
|
343 |
+
truncation: Optional[Union[bool, str]] = True,
|
344 |
+
max_length: Optional[int] = None,
|
345 |
+
return_tensors: Optional[Union[str, TensorType]] = "pt",
|
346 |
+
**kwargs,
|
347 |
+
):
|
348 |
+
if isinstance(text, str):
|
349 |
+
# convert to list if input is a single string
|
350 |
+
text = [text]
|
351 |
+
|
352 |
+
# add placehold for the [Q] marker
|
353 |
+
text = [". " + x for x in text]
|
354 |
+
|
355 |
+
if max_length is not None:
|
356 |
+
# use user specified max_length
|
357 |
+
pass
|
358 |
+
else:
|
359 |
+
# use default max length
|
360 |
+
max_length = self.query_maxlen
|
361 |
+
|
362 |
+
encoding = self.tokenizer(
|
363 |
+
text,
|
364 |
+
padding=padding,
|
365 |
+
truncation=True,
|
366 |
+
return_tensors=return_tensors,
|
367 |
+
max_length=max_length,
|
368 |
+
**kwargs,
|
369 |
+
)
|
370 |
+
|
371 |
+
ids, mask = encoding["input_ids"], encoding["attention_mask"]
|
372 |
+
|
373 |
+
if self.attend_to_mask_tokens:
|
374 |
+
# When attend_to_mask_tokens is True, we want to attend to the [MASK] tokens
|
375 |
+
mask[ids == self.mask_token_id] = 1
|
376 |
+
assert mask.sum().item() == mask.size(0) * mask.size(1), mask
|
377 |
+
|
378 |
+
return {"input_ids": ids, "attention_mask": mask}
|
379 |
+
|
380 |
+
def encode(self, text, text_pair=None, add_special_tokens=True, **kwargs):
|
381 |
+
return self.tokenizer.encode(text, text_pair, add_special_tokens, **kwargs)
|
382 |
+
|
383 |
+
def add_special_tokens(self, token, **kwargs):
|
384 |
+
return self.tokenizer.add_special_tokens(token, **kwargs)
|
385 |
+
|
386 |
+
def save_pretrained(self, path):
|
387 |
+
self.tokenizer.save_pretrained(path)
|
388 |
+
|
389 |
+
|
390 |
+
class FLMRContextEncoderTokenizer:
|
391 |
+
r"""
|
392 |
+
Constructs a FLMRContextEncoderTokenizer tokenizer.
|
393 |
+
|
394 |
+
[`FLMRContextEncoderTokenizer`] is identical to [`BertTokenizer`] or [`AutoTokenizer`], depends on whether
|
395 |
+
the tokenizer is initialized from bert.
|
396 |
+
"""
|
397 |
+
def __init__(self) -> None:
|
398 |
+
pass
|
399 |
+
|
400 |
+
@classmethod
|
401 |
+
def from_pretrained(
|
402 |
+
cls,
|
403 |
+
*args,
|
404 |
+
text_config: Optional[FLMRTextConfig] = None,
|
405 |
+
**kwargs,
|
406 |
+
):
|
407 |
+
if text_config.text_encoder_base_model == "bert-base-uncased":
|
408 |
+
return FLMRBertContextEncoderTokenizer.from_pretrained(*args, **kwargs)
|
409 |
+
else:
|
410 |
+
return FLMRAutoContextEncoderTokenizer(*args, **kwargs)
|
411 |
+
|
412 |
+
class FLMRQueryEncoderTokenizer:
|
413 |
+
r"""
|
414 |
+
Constructs a FLMRContextEncoderTokenizer tokenizer.
|
415 |
+
|
416 |
+
[`FLMRContextEncoderTokenizer`] is identical to [`BertTokenizer`] or [`AutoTokenizer`], depends on whether
|
417 |
+
the tokenizer is initialized from bert.
|
418 |
+
"""
|
419 |
+
def __init__(self) -> None:
|
420 |
+
pass
|
421 |
+
|
422 |
+
@classmethod
|
423 |
+
def from_pretrained(
|
424 |
+
cls,
|
425 |
+
*args,
|
426 |
+
text_config: Optional[FLMRTextConfig] = None,
|
427 |
+
**kwargs,
|
428 |
+
):
|
429 |
+
if text_config.text_encoder_base_model == "bert-base-uncased":
|
430 |
+
return FLMRBertQueryEncoderTokenizer.from_pretrained(*args, **kwargs)
|
431 |
+
else:
|
432 |
+
return FLMRAutoQueryEncoderTokenizer(*args, query_maxlen=text_config.query_maxlen, **kwargs)
|