HUANG-Stephanie committed on
Commit 9ff79dc
1 parent: c4d37d5

Upload 88 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. colpali-main/.env.dist +5 -0
  2. colpali-main/.gitattributes +31 -0
  3. colpali-main/.gitignore +179 -0
  4. colpali-main/.python-version +1 -0
  5. colpali-main/LICENSE +21 -0
  6. colpali-main/README.md +222 -0
  7. colpali-main/colpali_engine/__init__.py +0 -0
  8. colpali-main/colpali_engine/__pycache__/__init__.cpython-310.pyc +0 -0
  9. colpali-main/colpali_engine/dataset/__init__.py +0 -0
  10. colpali-main/colpali_engine/dataset/custom_collator.py +244 -0
  11. colpali-main/colpali_engine/dataset/hf_dataset_names.py +52 -0
  12. colpali-main/colpali_engine/evaluation/__init__.py +1 -0
  13. colpali-main/colpali_engine/evaluation/eval_manager.py +178 -0
  14. colpali-main/colpali_engine/interpretability/__init__.py +4 -0
  15. colpali-main/colpali_engine/interpretability/gen_interpretability_plots.py +113 -0
  16. colpali-main/colpali_engine/interpretability/plot_utils.py +131 -0
  17. colpali-main/colpali_engine/interpretability/processor.py +116 -0
  18. colpali-main/colpali_engine/interpretability/torch_utils.py +60 -0
  19. colpali-main/colpali_engine/interpretability/vit_configs.py +23 -0
  20. colpali-main/colpali_engine/loss/__init__.py +1 -0
  21. colpali-main/colpali_engine/loss/colbert_loss.py +122 -0
  22. colpali-main/colpali_engine/models/__init__.py +0 -0
  23. colpali-main/colpali_engine/models/__pycache__/__init__.cpython-310.pyc +0 -0
  24. colpali-main/colpali_engine/models/__pycache__/paligemma_colbert_architecture.cpython-310.pyc +0 -0
  25. colpali-main/colpali_engine/models/clip_baselines.py +144 -0
  26. colpali-main/colpali_engine/models/colbert_architectures.py +177 -0
  27. colpali-main/colpali_engine/models/idefics_colbert_architecture.py +57 -0
  28. colpali-main/colpali_engine/models/paligemma_colbert_architecture.py +191 -0
  29. colpali-main/colpali_engine/trainer/__init__.py +0 -0
  30. colpali-main/colpali_engine/trainer/__pycache__/__init__.cpython-310.pyc +0 -0
  31. colpali-main/colpali_engine/trainer/__pycache__/retrieval_evaluator.cpython-310.pyc +0 -0
  32. colpali-main/colpali_engine/trainer/contrastive_trainer.py +64 -0
  33. colpali-main/colpali_engine/trainer/retrieval_evaluator.py +72 -0
  34. colpali-main/colpali_engine/utils/__init__.py +0 -0
  35. colpali-main/colpali_engine/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  36. colpali-main/colpali_engine/utils/__pycache__/colpali_processing_utils.cpython-310.pyc +0 -0
  37. colpali-main/colpali_engine/utils/__pycache__/image_from_page_utils.cpython-310.pyc +0 -0
  38. colpali-main/colpali_engine/utils/colidefics_processing_utils.py +53 -0
  39. colpali-main/colpali_engine/utils/colpali_processing_utils.py +36 -0
  40. colpali-main/colpali_engine/utils/dataset_transformation.py +158 -0
  41. colpali-main/colpali_engine/utils/gpu_stats.py +24 -0
  42. colpali-main/colpali_engine/utils/image_from_page_utils.py +21 -0
  43. colpali-main/colpali_engine/utils/image_utils.py +64 -0
  44. colpali-main/colpali_engine/utils/iter_utils.py +42 -0
  45. colpali-main/colpali_engine/utils/pdf_utils.py +87 -0
  46. colpali-main/colpali_engine/utils/plot_utils.py +6 -0
  47. colpali-main/colpali_engine/utils/torch_utils.py +18 -0
  48. colpali-main/colpali_engine/utils/train_colpali_engine_models.py +247 -0
  49. colpali-main/colpali_engine/utils/wrapper.py +83 -0
  50. colpali-main/demo/README.md +6 -0
colpali-main/.env.dist ADDED
@@ -0,0 +1,5 @@
+ HF_TOKEN=
+ HF_DATASETS_CACHE=
+
+ VERTEX_PROJECT=
+ VERTEX_LOCATION=
colpali-main/.gitattributes ADDED
@@ -0,0 +1,31 @@
+ *.jsonl filter=lfs diff=lfs merge=lfs -text
+ *.csv filter=lfs diff=lfs merge=lfs -text
+ *.ipynb filter=lfs diff=lfs merge=lfs -text
+
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
colpali-main/.gitignore ADDED
@@ -0,0 +1,179 @@
+ # Custom
+ .DS_Store
+ .env
+ .litellm_cache/
+ data/litellm_cache_captionning/
+ .idea
+ .venv/
+ colbert/models/
+ logs/
+ data/downloaded_datasets/rimes_raw_dataset/
+ models/
+ !colpali_engine/models
+ data/
+ !*/configs/data/
+ data_dir/
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+ notebooks/*.png
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
colpali-main/.python-version ADDED
@@ -0,0 +1 @@
+ 3.11.6
colpali-main/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Manuel Faysse, Hugues Sibille, Tony Wu
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
colpali-main/README.md ADDED
@@ -0,0 +1,222 @@
+ # ColPali: Efficient Document Retrieval with Vision Language Models
+
+ [[Blog]](https://huggingface.co/blog/manu/colpali)
+ [[Paper]](https://arxiv.org/abs/2407.01449)
+ [[ColPali Model card]](https://huggingface.co/vidore/colpali)
+ [[ViDoRe Benchmark]](https://huggingface.co/vidore)
+ <!---[[Colab example]]()-->
+ [[HuggingFace Demo]](https://huggingface.co/spaces/manu/ColPali-demo)
+
+ ## Associated Paper
+
+ **ColPali: Efficient Document Retrieval with Vision Language Models**
+ Manuel Faysse, Hugues Sibille, Tony Wu, Bilel Omrani, Gautier Viaud, Céline Hudelot, Pierre Colombo
+
+ This repository contains the code for training custom ColBERT retriever models.
+ Notably, we train ColBERT models with LLMs (decoders) as well as image-language models!
+
+ ## Installation
+
+ ### From git
+ ```bash
+ pip install git+https://github.com/illuin-tech/colpali
+ ```
+
+ ### From source
+ ```bash
+ git clone https://github.com/illuin-tech/colpali
+ cd colpali
+ pip install -r requirements.txt
+ ```
+
+ ## Usage
+
+ Example usage of the model is shown in the `scripts` directory.
+
+ ```bash
+ # hackable example script to adapt
+ python scripts/infer/run_inference_with_python.py
+ ```
+
+ ```python
+ import torch
+ import typer
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+ from transformers import AutoProcessor
+ from PIL import Image
+
+ from colpali_engine.models.paligemma_colbert_architecture import ColPali
+ from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
+ from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
+ from colpali_engine.utils.image_from_page_utils import load_from_dataset
+
+
+ def main() -> None:
+     """Example script to run inference with ColPali"""
+     # Load model
+     model_name = "vidore/colpali"
+     model = ColPali.from_pretrained("google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, device_map="cuda").eval()
+     model.load_adapter(model_name)
+     processor = AutoProcessor.from_pretrained(model_name)
+
+     # select images -> load_from_pdf(<pdf_path>), load_from_image_urls(["<url_1>"]), load_from_dataset(<path>)
+     images = load_from_dataset("vidore/docvqa_test_subsampled")
+     queries = ["From which university does James V. Fiorca come?", "Who is the Japanese prime minister?"]
+
+     # run inference - docs
+     dataloader = DataLoader(
+         images,
+         batch_size=4,
+         shuffle=False,
+         collate_fn=lambda x: process_images(processor, x),
+     )
+     ds = []
+     for batch_doc in tqdm(dataloader):
+         with torch.no_grad():
+             batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
+             embeddings_doc = model(**batch_doc)
+         ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
+
+     # run inference - queries
+     dataloader = DataLoader(
+         queries,
+         batch_size=4,
+         shuffle=False,
+         collate_fn=lambda x: process_queries(processor, x, Image.new("RGB", (448, 448), (255, 255, 255))),
+     )
+
+     qs = []
+     for batch_query in dataloader:
+         with torch.no_grad():
+             batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
+             embeddings_query = model(**batch_query)
+         qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
+
+     # run evaluation
+     retriever_evaluator = CustomEvaluator(is_multi_vector=True)
+     scores = retriever_evaluator.evaluate(qs, ds)
+     print(scores.argmax(axis=1))
+
+
+ if __name__ == "__main__":
+     typer.run(main)
+ ```
+
+ Details are also given in the model card for the base ColPali model on Hugging Face: [ColPali Model card](https://huggingface.co/vidore/colpali).
+
+ ## Training
+
+ ```bash
+ USE_LOCAL_DATASET=0 python scripts/train/train_colbert.py scripts/configs/siglip/train_siglip_model_debug.yaml
+ ```
+
+ or
+
+ ```bash
+ accelerate launch scripts/train/train_colbert.py scripts/configs/train_colidefics_model.yaml
+ ```
+
+ ### Configurations
+ All training arguments can be set through a configuration file.
+ The configuration file is a YAML file that contains all the arguments for a training run.
+
+ It is deserialized into the following dataclass:
+
+ ```python
+ @dataclass
+ class ColModelTrainingConfig:
+     model: PreTrainedModel
+     tr_args: TrainingArguments = None
+     output_dir: str = None
+     max_length: int = 256
+     run_eval: bool = True
+     run_train: bool = True
+     peft_config: Optional[LoraConfig] = None
+     add_suffix: bool = False
+     processor: Idefics2Processor = None
+     tokenizer: PreTrainedTokenizer = None
+     loss_func: Optional[Callable] = ColbertLoss()
+     dataset_loading_func: Optional[Callable] = None
+     eval_dataset_loader: Optional[Dict[str, Callable]] = None
+     pretrained_peft_model_name_or_path: Optional[str] = None
+ ```
+
+ ### Example
+
+ An example configuration file is:
+
+ ```yaml
+ config:
+   (): colpali_engine.utils.train_colpali_engine_models.ColModelTrainingConfig
+   output_dir: !path ../../../models/without_tabfquad/train_colpali-3b-mix-448
+   processor:
+     () : colpali_engine.utils.wrapper.AutoProcessorWrapper
+     pretrained_model_name_or_path: "./models/paligemma-3b-mix-448"
+     max_length: 50
+   model:
+     (): colpali_engine.utils.wrapper.AutoColModelWrapper
+     pretrained_model_name_or_path: "./models/paligemma-3b-mix-448"
+     training_objective: "colbertv1"
+     # attn_implementation: "eager"
+     torch_dtype: !ext torch.bfloat16
+     # device_map: "auto"
+     # quantization_config:
+     #   (): transformers.BitsAndBytesConfig
+     #   load_in_4bit: true
+     #   bnb_4bit_quant_type: "nf4"
+     #   bnb_4bit_compute_dtype: "bfloat16"
+     #   bnb_4bit_use_double_quant: true
+
+   dataset_loading_func: !ext colpali_engine.utils.dataset_transformation.load_train_set
+   eval_dataset_loader: !import ../data/test_data.yaml
+
+   max_length: 50
+   run_eval: true
+   add_suffix: true
+   loss_func:
+     (): colpali_engine.loss.colbert_loss.ColbertPairwiseCELoss
+   tr_args: !import ../tr_args/default_tr_args.yaml
+   peft_config:
+     (): peft.LoraConfig
+     r: 32
+     lora_alpha: 32
+     lora_dropout: 0.1
+     init_lora_weights: "gaussian"
+     bias: "none"
+     task_type: "FEATURE_EXTRACTION"
+     target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+     # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
+ ```
+
+ #### Local training
+
+ ```bash
+ USE_LOCAL_DATASET=0 python scripts/train/train_colbert.py scripts/configs/siglip/train_siglip_model_debug.yaml
+ ```
+
+ #### SLURM
+
+ ```bash
+ sbatch --nodes=1 --cpus-per-task=16 --mem-per-cpu=32GB --time=20:00:00 --gres=gpu:1 -p gpua100 --job-name=colidefics --output=colidefics.out --error=colidefics.err --wrap="accelerate launch scripts/train/train_colbert.py scripts/configs/train_colidefics_model.yaml"
+
+ sbatch --nodes=1 --time=5:00:00 -A cad15443 --gres=gpu:8 --constraint=MI250 --job-name=colpali --wrap="python scripts/train/train_colbert.py scripts/configs/train_colpali_model.yaml"
+ ```
+
+ ## CITATION
+
+ ```bibtex
+ @misc{faysse2024colpaliefficientdocumentretrieval,
+       title={ColPali: Efficient Document Retrieval with Vision Language Models},
+       author={Manuel Faysse and Hugues Sibille and Tony Wu and Bilel Omrani and Gautier Viaud and Céline Hudelot and Pierre Colombo},
+       year={2024},
+       eprint={2407.01449},
+       archivePrefix={arXiv},
+       primaryClass={cs.IR},
+       url={https://arxiv.org/abs/2407.01449},
+ }
+ ```
colpali-main/colpali_engine/__init__.py ADDED
File without changes
colpali-main/colpali_engine/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (163 Bytes).
colpali-main/colpali_engine/dataset/__init__.py ADDED
File without changes
colpali-main/colpali_engine/dataset/custom_collator.py ADDED
@@ -0,0 +1,244 @@
+ from transformers import PreTrainedTokenizer, ProcessorMixin
+
+
+ class CustomCollator:
+     def __init__(
+         self,
+         processor: ProcessorMixin = None,
+         tokenizer: PreTrainedTokenizer = None,
+         max_length: int = 2048,
+         add_suffix: bool = False,
+     ):
+         self.processor = processor
+         self.tokenizer = tokenizer
+         self.image_token_id = None
+         self.max_length = max_length
+         self.suffix = ""
+         if add_suffix:
+             self.suffix = "\n" * 10
+
+         if tokenizer is None and processor is None:
+             raise ValueError("Either processor or tokenizer should be provided.")
+
+         if self.processor is not None:
+             if self.processor.__class__.__name__ != "SiglipProcessor":
+                 self.image_token_id = self.processor.tokenizer.additional_special_tokens_ids[
+                     self.processor.tokenizer.additional_special_tokens.index("<image>")
+                 ]
+
+             if self.tokenizer is not None:
+                 raise ValueError("Only one of processor or tokenizer should be provided.")
+
+         if self.tokenizer and self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+
+     def __call__(self, examples):
+         if self.processor is None:
+             return self.forward_text(examples)
+         if self.processor.__class__.__name__ == "Idefics2Processor":
+             return self.forward_vision_idefics(examples)
+         if self.processor.__class__.__name__ == "PaliGemmaProcessor":
+             return self.forward_vision_pali(examples)
+         if self.processor.__class__.__name__ == "SiglipProcessor":
+             return self.forward_vision_siglip(examples)
+         raise ValueError("Processor not supported")
+
+     def forward_text(self, examples):
+         texts_doc = []
+         texts_query = []
+         for example in examples:
+             text_query = example["query"] + self.suffix
+             text_doc = example["doc"]
+
+             texts_doc.append(text_doc.strip())
+             texts_query.append(text_query.strip())
+
+         batch_doc = self.tokenizer(
+             texts_doc, max_length=self.max_length, padding="longest", truncation=True, return_tensors="pt"
+         )
+         batch_query = self.tokenizer(
+             texts_query, max_length=self.max_length, padding="longest", truncation=True, return_tensors="pt"
+         )
+
+         # prefix each key with "doc_" or "query_" to avoid key conflicts
+         batch_doc = {f"doc_{k}": v for k, v in batch_doc.items()}
+         batch_query = {f"query_{k}": v for k, v in batch_query.items()}
+         batch_doc.update(batch_query)
+
+         return batch_doc
+
+     def forward_vision_idefics(self, examples):
+         texts_doc = []
+         texts_query = []
+         images = []
+         for example in examples:
+             image = example["image"]
+
+             text_query = None
+             if example["query"] is not None:
+                 query = example["query"]
+                 messages_query = [
+                     {
+                         "role": "user",
+                         "content": [
+                             {
+                                 "type": "text",
+                                 "text": f"Question: {query}<end_of_utterance><end_of_utterance><end_of_utterance><end_of_utterance><end_of_utterance>",
+                             },
+                         ],
+                     },
+                 ]
+                 text_query = self.processor.apply_chat_template(messages_query, add_generation_prompt=False).strip()
+
+             messages_doc = [
+                 {
+                     "role": "user",
+                     "content": [
+                         {"type": "text", "text": "Describe the image."},
+                         {"type": "image"},
+                     ],
+                 },
+             ]
+
+             text_doc = self.processor.apply_chat_template(messages_doc, add_generation_prompt=False)
+
+             texts_doc.append(text_doc.strip())
+             texts_query.append(text_query)
+             images.append([image])
+
+         batch_doc = self.processor(
+             text=texts_doc, images=images, return_tensors="pt", padding="longest", max_length=self.max_length
+         )
+
+         batch_query = None
+         if all([t is None for t in texts_query]):
+             print("All queries are None. Returning None for all queries.")
+         elif any([t is None for t in texts_query]):
+             raise ValueError("Some queries are None. This collator does not support None queries yet.")
+         else:
+             batch_query = self.processor(
+                 text=texts_query, return_tensors="pt", padding="longest", max_length=self.max_length
+             )
+
+         # prefix each key with "doc_" or "query_" to avoid key conflicts
+         batch_doc = {f"doc_{k}": v for k, v in batch_doc.items()}
+
+         if batch_query is not None:
+             batch_query = {f"query_{k}": v for k, v in batch_query.items()}
+             batch_doc.update(batch_query)
+
+         return batch_doc
+
+     def forward_vision_pali(self, examples):
+         texts_doc = []
+         texts_query = []
+         images = []
+         for example in examples:
+
+             if example["image"] is None:
+                 raise ValueError("Image is None - This collator does not support None images yet.")
+
+             image = example["image"].convert("RGB")
+             images.append(image)
+             texts_doc.append("Describe the image.")
+
+             if example["query"] is None:
+                 texts_query.append(None)
+             else:
+                 query = example["query"]
+                 query = f"Question: {query}<unused0><unused0><unused0><unused0><unused0>"
+                 texts_query.append(query)
+
+         batch_doc = self.processor(
+             text=texts_doc,
+             images=images,
+             return_tensors="pt",
+             padding="longest",
+             max_length=self.max_length + self.processor.image_seq_length,
+         )
+
+         batch_query = None
+         # check if some but not all queries are None
+         if all([t is None for t in texts_query]):
+             print("All queries are None. Returning None for all queries.")
+         elif any([t is None for t in texts_query]):
+             raise ValueError("Some queries are None. This collator does not support None queries yet.")
+         else:
+             batch_query = self.processor(
+                 images=images,  # NOTE: the image is not used in batch_query but it is required for calling the processor
+                 text=texts_query,
+                 return_tensors="pt",
+                 padding="longest",
+                 max_length=self.max_length + self.processor.image_seq_length,
+             )
+             del batch_query["pixel_values"]
+             batch_query["input_ids"] = batch_query["input_ids"][..., self.processor.image_seq_length :]
+             batch_query["attention_mask"] = batch_query["attention_mask"][..., self.processor.image_seq_length :]
+
+         # prefix each key with "doc_" or "query_" to avoid key conflicts
+         batch_doc = {f"doc_{k}": v for k, v in batch_doc.items()}
+
+         if batch_query is not None:
+             batch_query = {f"query_{k}": v for k, v in batch_query.items()}
+             batch_doc.update(batch_query)
+
+         return batch_doc
+
+     def forward_vision_siglip(self, examples):
+         texts_doc = []
+         texts_query = []
+         images = []
+         for example in examples:
+
+             if example["image"] is None:
+                 raise ValueError("Image is None - This collator does not support None images yet.")
+
+             image = example["image"].convert("RGB")
+             images.append(image)
+             texts_doc.append("Describe the image.")
+
+             if example["query"] is None:
+                 texts_query.append(None)
+             else:
+                 query = f"Question: {example['query']}"
+                 texts_query.append(query)
+
+         batch_doc = self.processor(
+             text=texts_doc,
+             images=images,
+             return_tensors="pt",
+             padding="max_length",
+             truncation=True,
+         )
+
+         batch_query = None
+         # check if some but not all queries are None
+         if all([t is None for t in texts_query]):
+             # print("All queries are None.")
+             pass
+         elif any([t is None for t in texts_query]):
+             raise ValueError("Some queries are None. This collator does not support None queries yet.")
+         else:
+             batch_query = self.processor(
+                 images=images,
+                 text=texts_query,
+                 return_tensors="pt",
+                 padding="max_length",
+                 max_length=self.max_length,
+                 truncation=True,
+             )
+             del batch_query["pixel_values"]
+
+         # prefix each key with "doc_" or "query_" to avoid key conflicts
+         batch_doc = {f"doc_{k}": v for k, v in batch_doc.items()}
+
+         if batch_query is not None:
+             batch_query = {f"query_{k}": v for k, v in batch_query.items()}
+             batch_doc.update(batch_query)
+             # add attention mask for queries
+             batch_doc["query_attention_mask"] = batch_doc["query_input_ids"].ne(0).long()
+
+         # add attention mask for docs
+         batch_doc["doc_attention_mask"] = batch_doc["doc_input_ids"].ne(0).long()
+
+         return batch_doc
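A minimal sketch of how this collator could be plugged into a PyTorch `DataLoader` for the text-only path; the tokenizer checkpoint and the example records are illustrative assumptions, while the `query`/`doc` keys and the `doc_`/`query_` key prefixes come from `forward_text` above.

```python
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from colpali_engine.dataset.custom_collator import CustomCollator

# Any HF tokenizer works for the text-only path (this checkpoint name is illustrative).
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = CustomCollator(tokenizer=tokenizer, max_length=256)

# Hypothetical records: forward_text only requires "query" and "doc" keys.
examples = [
    {"query": "Who wrote the report?", "doc": "The 2023 annual report was written by the finance team."},
    {"query": "What was the Q3 revenue?", "doc": "Revenue reached 4.2M EUR in Q3."},
]

loader = DataLoader(examples, batch_size=2, collate_fn=collator)
batch = next(iter(loader))
# Keys are prefixed so query and document tensors live side by side in one batch dict,
# e.g. doc_input_ids / doc_attention_mask / query_input_ids / query_attention_mask.
print(sorted(batch.keys()))
```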
colpali-main/colpali_engine/dataset/hf_dataset_names.py ADDED
@@ -0,0 +1,52 @@
+ from enum import Enum
+
+
+ class TrainDatasets(Enum):
+     """
+     Dataset names for the training datasets used in HuggingFace Datasets.
+     """
+
+     government_reports = "vidore/syntheticDocQA_government_reports_train"
+     healthcare_industry = "vidore/syntheticDocQA_healthcare_industry_train"
+     energy = "vidore/syntheticDocQA_energy_train"
+     artificial_intelligence = "vidore/syntheticDocQA_artificial_intelligence_train"
+     arxivqa = "vidore/arxivqa_train"
+     docvqa = "vidore/docvqa_train"
+     infovqa = "vidore/infovqa_train"
+     tatqa = "vidore/tatqa_train"
+
+     @staticmethod
+     def get_synthetic_datasets():
+         return [
+             TrainDatasets.government_reports,
+             TrainDatasets.healthcare_industry,
+             TrainDatasets.energy,
+             TrainDatasets.artificial_intelligence,
+         ]
+
+
+ class TestImagesDirpath(Enum):
+     """
+     Dataset names for the test datasets used in HuggingFace Datasets.
+     """
+
+     government_reports = "data/government_reports"
+     healthcare_industry = "data/healthcare_industry"
+     energy = "data/energy"
+     artificial_intelligence = "data/scrapped_pdfs_split/pages_extracted/artificial_intelligence_test"
+     arxivqa = "data/arxivqa"
+     docvqa = "data/docvqa"
+     infovqa = "data/infovqa"
+     tatqa = "data/tatqa"
+
+
+ class CaptionedSyntheticDatasets(Enum):
+     """
+     Dataset names for the captioned synthetic datasets used in HuggingFace Datasets.
+     """
+
+     shift = "vidore/baseline_cap_shiftproject_test"
+
+
+ class SyntheticDocQATest(Enum):
+     shift = "vidore/shiftproject_test"
colpali-main/colpali_engine/evaluation/__init__.py ADDED
@@ -0,0 +1 @@
+ from .eval_manager import EvalManager
colpali-main/colpali_engine/evaluation/eval_manager.py ADDED
@@ -0,0 +1,178 @@
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Any, ClassVar, Dict, Optional
+
+ import pandas as pd
+
+
+ class EvalManager:
+     """
+     Stores evaluation results for various datasets and metrics.
+
+     The data is stored in a pandas DataFrame with a MultiIndex for columns.
+     The first level of the MultiIndex is the dataset name and the second level is the metric name.
+
+     Usage:
+     >>> evaluator = EvalManager.from_dir("data/evaluation_results/")
+     >>> print(evaluator.data)
+
+     """
+
+     model_col: ClassVar[str] = "model"
+     dataset_col: ClassVar[str] = "dataset"
+     metric_col: ClassVar[str] = "metric"
+
+     def __init__(self, data: Optional[pd.DataFrame] = None):
+         if data is None:
+             data = pd.DataFrame()
+         self._df = data
+         self._df.index = self._df.index.rename(EvalManager.model_col)
+
+     def __str__(self) -> str:
+         return self.data.__str__()
+
+     @staticmethod
+     def from_dict(data: Dict[Any, Any]) -> EvalManager:
+         """
+         Load evaluation results from a dictionary.
+
+         Expected format:
+         {
+             "model1": pd.read_json(path1).T.stack(),
+             "model2": pd.read_json(path2).T.stack(),
+         }
+
+         """
+         df = pd.DataFrame.from_dict(data, orient="index")
+         return EvalManager(df)
+
+     @staticmethod
+     def from_json(path: str | Path) -> EvalManager:
+         datapath = Path(path)
+         if not datapath.is_file():
+             raise FileNotFoundError(f"{path} is not a file")
+         data = {}
+         data[datapath.stem] = pd.read_json(datapath).T.stack()  # pylint: disable=no-member
+         return EvalManager.from_dict(data)
+
+     @staticmethod
+     def from_dir(datadir: str | Path) -> EvalManager:
+         datadir_ = Path(datadir)
+         if not datadir_.is_dir():
+             raise FileNotFoundError(f"{datadir} is not a directory")
+
+         eval_files = list(datadir_.glob("*.json"))
+
+         data = {}
+
+         for filepath in eval_files:
+             data[filepath.stem] = pd.read_json(filepath).T.stack()  # pylint: disable=no-member
+
+         return EvalManager.from_dict(data)
+
+     @staticmethod
+     def from_csv(path: str | Path) -> EvalManager:
+         """
+         Load evaluation results from a CSV file.
+         """
+         try:
+             df = pd.read_csv(path, index_col=0, header=[0, 1])
+             return EvalManager(df)
+         except Exception as e:
+             print(f"Error loading {path}: {e}")
+             raise e
+
+     @property
+     def data(self) -> pd.DataFrame:
+         """
+         Returns the evaluation results as a pandas DataFrame.
+         """
+         return self._df.copy()
+
+     @property
+     def models(self) -> pd.Index:
+         """
+         Returns the models for which there are evaluation results.
+         """
+         return self.data.index
+
+     @property
+     def datasets(self) -> pd.Index:
+         """
+         Returns the datasets for which there are evaluation results.
+         """
+         return self.data.columns.get_level_values(0).unique()
+
+     @property
+     def metrics(self) -> pd.Index:
+         """
+         Returns the metrics for which there are evaluation results.
+         """
+         return self.data.columns.get_level_values(1)
+
+     @staticmethod
+     def melt(df: pd.DataFrame) -> pd.DataFrame:
+         """
+         Melt a suitable DataFrame (e.g. returned by `get_df_for_dataset` and
+         `get_df_for_metric`) into a 'long' format.
+         """
+         return df.T.reset_index(names=[EvalManager.dataset_col, EvalManager.metric_col]).melt(
+             id_vars=[EvalManager.dataset_col, EvalManager.metric_col],
+             var_name=EvalManager.model_col,
+             value_name="score",
+         )
+
+     @property
+     def melted(self) -> pd.DataFrame:
+         """
+         Returns the evaluation results as a 'melted' DataFrame.
+         Useful for plotting with seaborn.
+         """
+         return EvalManager.melt(self.data)
+
+     def get_df_for_model(self, model: str) -> pd.DataFrame:
+         if model not in self.data.index:
+             raise ValueError(f"Model {model} not found in the evaluation results")
+         return self.data.loc[[model], :]  # type: ignore
+
+     def get_df_for_dataset(self, dataset: str) -> pd.DataFrame:
+         if dataset not in self.datasets:
+             raise ValueError(f"Dataset {dataset} not found in the evaluation results")
+         return self.data.loc[:, (dataset, slice(None))]  # type: ignore
+
+     def get_df_for_metric(self, metric: str) -> pd.DataFrame:
+         if metric not in self.metrics:
+             raise ValueError(f"Metric {metric} not found in the evaluation results")
+         return self.data.loc[:, (slice(None), metric)]  # type: ignore
+
+     def sort_by_dataset(self, ascending: bool = True) -> EvalManager:
+         """
+         Sort the evaluation results by dataset name.
+         """
+         df = self.data.T.sort_index(level=0, ascending=ascending).T
+         return EvalManager(df)
+
+     def sort_by_metric(self, ascending: bool = True) -> EvalManager:
+         """
+         Sort the evaluation results by metric name.
+         """
+         df = self.data.T.sort_index(level=1, ascending=ascending).T
+         return EvalManager(df)
+
+     def sort_columns(self, ascending: bool = True) -> EvalManager:
+         """
+         Sort the evaluation results by dataset name and then by metric name.
+         """
+         df = self.data.T.sort_index(level=[0, 1], ascending=ascending).T
+         return EvalManager(df)
+
+     def to_csv(self, path: str | Path):
+         """
+         Save the evaluation results to a CSV file.
+
+         Using `EvalManager.from_csv(path_to_saved_csv)` will load the evaluation results back into memory.
+         """
+         savepath = Path(path)
+         savepath.parent.mkdir(parents=True, exist_ok=True)
+         self.data.to_csv(savepath)
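A usage sketch for `EvalManager`, assuming a directory with one `<model>.json` file per model mapping dataset names to per-metric scores; the directory path and the metric name below are placeholders, everything else comes from the class above.

```python
from colpali_engine.evaluation import EvalManager

manager = EvalManager.from_dir("data/evaluation_results/")  # one JSON file per model (path is a placeholder)

print(manager.models)    # index: one row per model
print(manager.datasets)  # first column level of the MultiIndex
print(manager.metrics)   # second column level of the MultiIndex

# Slice a single metric across all datasets and models (metric name is hypothetical),
# then persist the full table so it can be reloaded with EvalManager.from_csv.
ndcg = manager.get_df_for_metric("ndcg_at_5")
manager.to_csv("data/evaluation_results/summary.csv")
```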
colpali-main/colpali_engine/interpretability/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .plot_utils import *
+ from .processor import *
+ from .torch_utils import *
+ from .vit_configs import *
colpali-main/colpali_engine/interpretability/gen_interpretability_plots.py ADDED
@@ -0,0 +1,113 @@
+ import pprint
+ from dataclasses import asdict, dataclass
+ from pathlib import Path
+ from uuid import uuid4
+
+ import matplotlib.pyplot as plt
+ import torch
+ from einops import rearrange
+ from PIL import Image
+ from tqdm import trange
+
+ from colpali_engine.interpretability.plot_utils import plot_patches
+ from colpali_engine.interpretability.processor import ColPaliProcessor
+ from colpali_engine.interpretability.torch_utils import normalize_attention_map_per_query_token
+ from colpali_engine.interpretability.vit_configs import VIT_CONFIG
+ from colpali_engine.models.paligemma_colbert_architecture import ColPali
+
+ OUTDIR_INTERPRETABILITY = Path("outputs/interpretability")
+
+
+ @dataclass
+ class InterpretabilityInput:
+     query: str
+     image: Image.Image
+     start_idx_token: int
+     end_idx_token: int
+
+
+ def generate_interpretability_plots(
+     model: ColPali,
+     processor: ColPaliProcessor,
+     query: str,
+     image: Image.Image,
+     savedir: str | Path | None = None,
+     add_special_prompt_to_doc: bool = True,
+ ) -> None:
+
+     # Sanity checks
+     if len(model.active_adapters()) != 1:
+         raise ValueError("The model must have exactly one active adapter.")
+
+     if model.config.name_or_path not in VIT_CONFIG:
+         raise ValueError("The model must be referred to in the VIT_CONFIG dictionary.")
+     vit_config = VIT_CONFIG[model.config.name_or_path]
+
+     # Handle savepath
+     if not savedir:
+         savedir = OUTDIR_INTERPRETABILITY / str(uuid4())
+         print(f"No savepath provided. Results will be saved to: `{savedir}`.")
+     elif isinstance(savedir, str):
+         savedir = Path(savedir)
+     savedir.mkdir(parents=True, exist_ok=True)
+
+     # Resize the image to square
+     input_image_square = image.resize((vit_config.resolution, vit_config.resolution))
+
+     # Preprocess the inputs
+     input_text_processed = processor.process_text(query).to(model.device)
+     input_image_processed = processor.process_image(image, add_special_prompt=add_special_prompt_to_doc).to(
+         model.device
+     )
+
+     # Forward pass
+     with torch.no_grad():
+         output_text = model.forward(**asdict(input_text_processed))  # (1, n_text_tokens, hidden_dim)
+
+     # NOTE: `output_image` will have shape:
+     # (1, n_patch_x * n_patch_y, hidden_dim) if `add_special_prompt_to_doc` is False
+     # (1, n_patch_x * n_patch_y + n_special_tokens, hidden_dim) if `add_special_prompt_to_doc` is True
+     with torch.no_grad():
+         output_image = model.forward(**asdict(input_image_processed))
+
+     if add_special_prompt_to_doc:  # remove the special tokens
+         output_image = output_image[
+             :, : processor.processor.image_seq_length, :
+         ]  # (1, n_patch_x * n_patch_y, hidden_dim)
+
+     output_image = rearrange(
+         output_image, "b (h w) c -> b h w c", h=vit_config.n_patch_per_dim, w=vit_config.n_patch_per_dim
+     )  # (1, n_patch_x, n_patch_y, hidden_dim)
+
+     # Get the unnormalized attention map
+     attention_map = torch.einsum(
+         "bnk,bijk->bnij", output_text, output_image
+     )  # (1, n_text_tokens, n_patch_x, n_patch_y)
+     attention_map_normalized = normalize_attention_map_per_query_token(
+         attention_map
+     )  # (1, n_text_tokens, n_patch_x, n_patch_y)
+     attention_map_normalized = attention_map_normalized.float()
+
+     # Get text token information
+     n_tokens = input_text_processed.input_ids.size(1)
+     text_tokens = processor.tokenizer.tokenize(processor.decode(input_text_processed.input_ids[0]))
+     print("Text tokens:")
+     pprint.pprint(text_tokens)
+     print("\n")
+
+     for token_idx in trange(1, n_tokens - 1, desc="Iterating over tokens..."):  # exclude the <bos> and the "\n" tokens
+         fig, axis = plot_patches(
+             input_image_square,
+             vit_config.patch_size,
+             vit_config.resolution,
+             patch_opacities=attention_map_normalized[0, token_idx, :, :],
+             style="dark_background",
+         )
+
+         fig.suptitle(f"Token #{token_idx}: `{text_tokens[token_idx]}`", color="white", fontsize=14)
+         savepath = savedir / f"token_{token_idx}.png"
+         fig.savefig(savepath)
+         print(f"Saved attention map for token {token_idx} (`{text_tokens[token_idx]}`) to `{savepath}`.\n")
+         plt.close(fig)
+
+     return
colpali-main/colpali_engine/interpretability/plot_utils.py ADDED
@@ -0,0 +1,131 @@
+ from typing import Any, Dict, Optional, Tuple, cast
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import numpy.typing as npt
+ import seaborn as sns
+ import torch
+ from PIL import Image
+
+ MAX_OPACITY = 255
+
+
+ def plot_patches(
+     img: Image.Image,
+     patch_size: int,
+     image_resolution: int,
+     patch_opacities: Optional[npt.NDArray | torch.Tensor] = None,
+     figsize: Tuple[int, int] = (8, 8),
+     style: Dict[str, Any] | str | None = None,
+ ) -> Tuple[plt.Figure, plt.Axes]:
+     """
+     Plot patches of a square image.
+     Set `style` to "dark_background" if your image has a light background.
+     """
+
+     # Get the number of patches
+     if image_resolution % patch_size != 0:
+         raise ValueError("The image resolution must be divisible by the patch size.")
+     num_patches = image_resolution // patch_size
+
+     # Default style
+     if style is None:
+         style = {}
+
+     # Sanity checks
+     if patch_opacities is not None:
+         if isinstance(patch_opacities, torch.Tensor):
+             patch_opacities = cast(npt.NDArray, patch_opacities.cpu().numpy())
+         if patch_opacities.shape != (num_patches, num_patches):
+             raise ValueError("The shape of the patch_opacities tensor is not correct.")
+         if not np.all((0 <= patch_opacities) & (patch_opacities <= 1)):
+             raise ValueError("The patch_opacities tensor must have values between 0 and 1.")
+
+     # If the image is not square, raise an error
+     if img.size[0] != img.size[1]:
+         raise ValueError("The image must be square.")
+
+     # Get the image as a numpy array
+     img_array = np.array(img.convert("RGBA"))  # (H, W, C) where the last channel is the alpha channel
+
+     # Create a figure
+     with plt.style.context(style):
+         fig, axis = plt.subplots(num_patches, num_patches, figsize=figsize)
+
+         # Plot the patches
+         for i in range(num_patches):
+             for j in range(num_patches):
+                 patch = img_array[i * patch_size : (i + 1) * patch_size, j * patch_size : (j + 1) * patch_size, :]
+                 # Set the opacity of the patch
+                 if patch_opacities is not None:
+                     patch[:, :, -1] = round(patch_opacities[i, j] * MAX_OPACITY)
+                 axis[i, j].imshow(patch)
+                 axis[i, j].axis("off")
+
+         fig.subplots_adjust(wspace=0.1, hspace=0.1)
+
+         fig.tight_layout()
+
+     return fig, axis
+
+
+ def plot_attention_heatmap(
+     img: Image.Image,
+     patch_size: int,
+     image_resolution: int,
+     attention_map: npt.NDArray | torch.Tensor,
+     figsize: Tuple[int, int] = (8, 8),
+     style: Dict[str, Any] | str | None = None,
+     show_colorbar: bool = False,
+     show_axes: bool = False,
+ ) -> Tuple[plt.Figure, plt.Axes]:
+     """
+     Plot a heatmap of the attention map over the image.
+     The image must be square and `attention_map` must be normalized between 0 and 1.
+     """
+
+     # Get the number of patches
+     if image_resolution % patch_size != 0:
+         raise ValueError("The image resolution must be divisible by the patch size.")
+     num_patches = image_resolution // patch_size
+
+     # Default style
+     if style is None:
+         style = {}
+
+     # Sanity checks
+     if isinstance(attention_map, torch.Tensor):
+         attention_map = cast(npt.NDArray, attention_map.cpu().numpy())
+     if attention_map.shape != (num_patches, num_patches):
+         raise ValueError("The shape of the attention_map tensor is not correct.")
+     if not np.all((0 <= attention_map) & (attention_map <= 1)):
+         raise ValueError("The attention_map tensor must have values between 0 and 1.")
+
+     # If the image is not square, raise an error
+     if img.size[0] != img.size[1]:
+         raise ValueError("The image must be square.")
+
+     # Get the image as a numpy array
+     img_array = np.array(img.convert("RGBA"))  # (H, W, C) where the last channel is the alpha channel
+
+     # Get the attention map as a numpy array
+     attention_map_image = Image.fromarray((attention_map * 255).astype("uint8")).resize(
+         img.size, Image.Resampling.BICUBIC
+     )
+
+     # Create a figure
+     with plt.style.context(style):
+         fig, ax = plt.subplots(figsize=figsize)
+         ax.imshow(img_array)
+         im = ax.imshow(
+             attention_map_image,
+             cmap=sns.color_palette("mako", as_cmap=True),
+             alpha=0.5,
+         )
+         if show_colorbar:
+             fig.colorbar(im)
+         if not show_axes:
+             ax.set_axis_off()
+         fig.tight_layout()
+
+     return fig, ax
colpali-main/colpali_engine/interpretability/processor.py ADDED
@@ -0,0 +1,116 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import List, cast
+
+ import torch
+ from PIL import Image
+ from transformers import LlamaTokenizerFast, PaliGemmaProcessor
+
+
+ @dataclass
+ class ColPaliTextInput:
+     input_ids: torch.Tensor
+     attention_mask: torch.Tensor
+
+     def to(self, device: torch.device) -> ColPaliTextInput:
+         return ColPaliTextInput(
+             input_ids=self.input_ids.to(device),
+             attention_mask=self.attention_mask.to(device),
+         )
+
+
+ @dataclass
+ class ColPaliImageInput:
+     input_ids: torch.Tensor
+     pixel_values: torch.Tensor
+     attention_mask: torch.Tensor
+
+     def to(self, device: str | torch.device) -> ColPaliImageInput:
+         return ColPaliImageInput(
+             input_ids=self.input_ids.to(device),
+             pixel_values=self.pixel_values.to(device),
+             attention_mask=self.attention_mask.to(device),
+         )
+
+
+ class ColPaliProcessor:
+     def __init__(self, processor: PaliGemmaProcessor):
+         self.processor = processor
+         self.tokenizer = cast(LlamaTokenizerFast, self.processor.tokenizer)  # type: ignore
+
+     @staticmethod
+     def from_pretrained(model_name: str) -> ColPaliProcessor:
+         return ColPaliProcessor(processor=cast(PaliGemmaProcessor, PaliGemmaProcessor.from_pretrained(model_name)))
+
+     def process_text(
+         self,
+         text: str | List[str],
+         padding: str = "longest",
+         return_tensors: str = "pt",
+         add_special_tokens: bool = True,
+     ) -> ColPaliTextInput:
+         """
+         Process text inputs for the model.
+         If `add_special_tokens` is True (default), the text will be prepended with the <bos> token and appended with " \n".
+         """
+         if add_special_tokens:
+             if isinstance(text, str):
+                 text = self.tokenizer.bos_token + text + "\n"
+             elif isinstance(text, list):
+                 text = [self.tokenizer.bos_token + t + "\n" for t in text]
+             else:
+                 raise ValueError("text must be a string or a list of strings.")
+
+         batch_output = self.tokenizer(
+             text, padding=padding, return_tensors=return_tensors, add_special_tokens=add_special_tokens
+         )
+
+         return ColPaliTextInput(
+             input_ids=cast(torch.Tensor, batch_output["input_ids"]),
+             attention_mask=cast(torch.Tensor, batch_output["attention_mask"]),
+         )
+
+     def process_image(
+         self,
+         image: Image.Image | List[Image.Image],
+         padding: str = "longest",
+         do_convert_rgb: bool = True,
+         return_tensors: str = "pt",
+         add_special_prompt: bool = True,
+     ) -> ColPaliImageInput:
+         # NOTE: The special prompt was used at training time,
+         special_prompt = "Describe the image." if add_special_prompt else None
+         if isinstance(image, Image.Image):
+             text_input = [special_prompt]
+         elif isinstance(image, list):
+             text_input = [special_prompt] * len(image)
+         else:
+             raise ValueError("image must be a PIL Image or a list of PIL Images.")
+
+         batch_output = self.processor(
+             text=text_input,
+             images=image,
+             padding=padding,
+             do_convert_rgb=do_convert_rgb,
+             return_tensors=return_tensors,
+         )
+
+         if add_special_prompt:
+             return ColPaliImageInput(
+                 input_ids=batch_output["input_ids"],
+                 pixel_values=batch_output["pixel_values"],
+                 attention_mask=batch_output["attention_mask"],
+             )
+         else:
+             return ColPaliImageInput(
+                 input_ids=batch_output["input_ids"][:, : self.processor.image_seq_length],
+                 pixel_values=batch_output["pixel_values"][:, : self.processor.image_seq_length],
+                 attention_mask=batch_output["attention_mask"][:, : self.processor.image_seq_length],
+             )
+
+     def decode(self, *args, **kwargs):
+         return self.tokenizer.decode(*args, **kwargs)
+
+     def batch_decode(self, *args, **kwargs):
+         return self.tokenizer.batch_decode(*args, **kwargs)
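A short sketch of the wrapper above. The checkpoint name follows `VIT_CONFIG` and the README; fetching it needs Hub access, so treat this as illustrative rather than something this commit runs.

```python
from PIL import Image

from colpali_engine.interpretability.processor import ColPaliProcessor

# Checkpoint name taken from VIT_CONFIG / the README; downloading it requires Hub access.
processor = ColPaliProcessor.from_pretrained("google/paligemma-3b-mix-448")

# Text goes through the underlying tokenizer with <bos> prepended and "\n" appended.
text_input = processor.process_text("Which year had the highest revenue?")
print(text_input.input_ids.shape, text_input.attention_mask.shape)

# Images are paired with the "Describe the image." prompt used at training time.
image_input = processor.process_image(Image.new("RGB", (448, 448), (255, 255, 255)))
print(image_input.input_ids.shape, image_input.pixel_values.shape)
```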
colpali-main/colpali_engine/interpretability/torch_utils.py ADDED
@@ -0,0 +1,60 @@
+ import logging
+
+ import torch
+
+ logger = logging.getLogger(__name__)
+
+ EPSILON = 1e-10
+
+
+ def normalize_attention_map_per_query_token(x: torch.Tensor) -> torch.Tensor:
+     """
+     Normalizes the attention map for ColPali for each query token.
+     The output tensor will have values in the range [0, 1] and the
+     same shape as the input tensor.
+
+     Args:
+         x: The attention map tensor of shape (batch_size, n_text_tokens, n_patch_x, n_patch_y).
+     """
+     if x.ndim != 4:
+         raise ValueError("The input tensor must have 4 dimensions.")
+
+     # Compute the minimum values along the last two dimensions (n_patch_x, n_patch_y)
+     min_vals = x.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[0]
+
+     # Compute the maximum values along the last two dimensions (n_patch_x, n_patch_y)
+     max_vals = x.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0]
+
+     # Normalize the tensor
+     x_normalized = (x - min_vals) / (max_vals - min_vals + EPSILON)  # Adding a small epsilon to avoid division by zero
+
+     return x_normalized
+
+
+ def normalize_attention_map_per_query(x: torch.Tensor) -> torch.Tensor:
+     """
+     Normalizes the attention map for ColPali for each query.
+     The output tensor will have values in the range [0, 1] and the
+     same shape as the input tensor.
+
+     Args:
+         x: The attention map tensor of shape (batch_size, n_text_tokens, n_patch_x, n_patch_y).
+     """
+     # Log warning
+     logger.warning(
+         "This function should not be used for ColPali because it doesn't make sense to normalize the attention map across the text tokens."
+     )
+
+     if x.ndim != 4:
+         raise ValueError("The input tensor must have 4 dimensions.")
+
+     # Compute the minimum values along the last three dimensions (n_text_tokens, n_patch_x, n_patch_y)
+     min_vals = x.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[0].min(dim=-3, keepdim=True)[0]
+
+     # Compute the maximum values along the last three dimensions (n_text_tokens, n_patch_x, n_patch_y)
+     max_vals = x.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0].max(dim=-3, keepdim=True)[0]
+
+     # Normalize the tensor
+     x_normalized = (x - min_vals) / (max_vals - min_vals + EPSILON)  # Adding a small epsilon to avoid division by zero
+
+     return x_normalized
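A quick shape and range check for the per-query-token normalization; the tensor sizes are illustrative (the 32x32 grid matches the PaliGemma-448 patch layout from `vit_configs.py`).

```python
import torch

from colpali_engine.interpretability.torch_utils import normalize_attention_map_per_query_token

# (batch_size, n_text_tokens, n_patch_x, n_patch_y) with illustrative sizes.
attention_map = torch.randn(1, 16, 32, 32)
normalized = normalize_attention_map_per_query_token(attention_map)

assert normalized.shape == attention_map.shape
# Each query token is min-max scaled over its own patch grid, so values land in [0, 1].
assert normalized.min() >= 0.0 and normalized.max() <= 1.0
```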
colpali-main/colpali_engine/interpretability/vit_configs.py ADDED
@@ -0,0 +1,23 @@
+ from dataclasses import dataclass
+ from typing import Dict
+
+
+ @dataclass
+ class ViTConfig:
+     patch_size: int
+     resolution: int
+
+     @property
+     def n_patch_per_dim(self) -> int:
+         if self.resolution % self.patch_size != 0:
+             raise ValueError(f"Resolution {self.resolution} is not divisible by patch size {self.patch_size}")
+         return self.resolution // self.patch_size
+
+
+ VIT_CONFIG: Dict[str, ViTConfig] = {
+     "google/siglip-so400m-patch14-384": ViTConfig(patch_size=14, resolution=384),
+     "timm/ViT-SO400M-14-SigLIP-384": ViTConfig(patch_size=14, resolution=384),
+     "google/paligemma-3b-mix-448": ViTConfig(
+         patch_size=14, resolution=448
+     ),  # based on "timm/ViT-SO400M-14-SigLIP-384" with increased resolution
+ }
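As a concrete check of the arithmetic behind `n_patch_per_dim`, using the PaliGemma entry defined above:

```python
from colpali_engine.interpretability.vit_configs import VIT_CONFIG

config = VIT_CONFIG["google/paligemma-3b-mix-448"]
print(config.n_patch_per_dim)       # 448 // 14 = 32 patches per side
print(config.n_patch_per_dim ** 2)  # 32 * 32 = 1024 patch positions per page
```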
colpali-main/colpali_engine/loss/__init__.py ADDED
@@ -0,0 +1 @@
+ from .colbert_loss import ColbertLoss
colpali-main/colpali_engine/loss/colbert_loss.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch.nn import CrossEntropyLoss
4
+
5
+
6
+ class BiEncoderLoss(torch.nn.Module):
7
+ def __init__(self):
8
+ super().__init__()
9
+ self.ce_loss = CrossEntropyLoss()
10
+ # self.pooling_strategy = pooling_strategy
11
+
12
+ def forward(self, query_embeddings, doc_embeddings):
13
+ """
14
+ query_embeddings: (batch_size, dim)
15
+ doc_embeddings: (batch_size, dim)
16
+ """
17
+
18
+ scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
19
+
20
+ loss_rowwise = self.ce_loss(scores, torch.arange(scores.shape[0], device=scores.device))
21
+ # loss_columnwise = self.ce_loss(scores.T, torch.arange(scores.shape[1], device=scores.device))
22
+ # loss = (loss_rowwise + loss_columnwise) / 2
23
+ return loss_rowwise
24
+
25
+
26
+ class ColbertLoss(torch.nn.Module):
27
+ def __init__(self):
28
+ super().__init__()
29
+ self.ce_loss = CrossEntropyLoss()
30
+
31
+ def forward(self, query_embeddings, doc_embeddings):
32
+ """
33
+ query_embeddings: (batch_size, num_query_tokens, dim)
34
+ doc_embeddings: (batch_size, num_doc_tokens, dim)
35
+ """
36
+
37
+ scores = torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings).max(dim=3)[0].sum(dim=2)
38
+
39
+ # scores = torch.zeros((query_embeddings.shape[0], doc_embeddings.shape[0]), device=query_embeddings.device)
40
+ # for i in range(query_embeddings.shape[0]):
41
+ # for j in range(doc_embeddings.shape[0]):
42
+ # # step 1 - dot product --> (s1,s2)
43
+ # q2d_scores = torch.matmul(query_embeddings[i], doc_embeddings[j].T)
44
+ # # step 2 -> max on doc --> (s1)
45
+ # q_scores = torch.max(q2d_scores, dim=1)[0]
46
+ # # step 3 --> sum the max score --> (1)
47
+ # sum_q_score = torch.sum(q_scores)
48
+ # # step 4 --> assert is scalar
49
+ # scores[i, j] = sum_q_score
50
+
51
+ # assert (scores_einsum - scores < 0.0001).all().item()
52
+
53
+ loss_rowwise = self.ce_loss(scores, torch.arange(scores.shape[0], device=scores.device))
54
+ # TODO: comparing between queries might not make sense since it's a sum over the length of the query
55
+ # loss_columnwise = self.ce_loss(scores.T, torch.arange(scores.shape[1], device=scores.device))
56
+ # loss = (loss_rowwise + loss_columnwise) / 2
57
+ return loss_rowwise
58
+
59
+
60
+ class ColbertPairwiseCELoss(torch.nn.Module):
61
+ def __init__(self):
62
+ super().__init__()
63
+ self.ce_loss = CrossEntropyLoss()
64
+
65
+ def forward(self, query_embeddings, doc_embeddings):
66
+ """
67
+ query_embeddings: (batch_size, num_query_tokens, dim)
68
+ doc_embeddings: (batch_size, num_doc_tokens, dim)
69
+
70
+ Positive scores are the diagonal of the scores matrix.
71
+ """
72
+
73
+ # Compute the ColBERT scores
74
+ scores = (
75
+ torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings).max(dim=3)[0].sum(dim=2)
76
+ ) # (batch_size, batch_size)
77
+
78
+ # Positive scores are the diagonal of the scores matrix.
79
+ pos_scores = scores.diagonal() # (batch_size,)
80
+
81
+ # Negative score for a given query is the maximum of the scores against all other pages.
82
+ # NOTE: We exclude the diagonal from the max operation by subtracting a large constant from it,
83
+ # so the hardest negative for each query is always taken from a different page.
84
+ neg_scores = scores - torch.eye(scores.shape[0], device=scores.device) * 1e6 # (batch_size, batch_size)
85
+ neg_scores = neg_scores.max(dim=1)[0] # (batch_size,)
86
+
87
+ # Compute the loss
88
+ # The loss is computed as the negative log of the softmax of the positive scores
89
+ # relative to the negative scores.
90
+ # This can be simplified to log-sum-exp of negative scores minus the positive score
91
+ # for numerical stability.
92
+ # torch.vstack((pos_scores, neg_scores)).T.softmax(1)[:, 0].log()*(-1)
93
+ loss = F.softplus(neg_scores - pos_scores).mean()
94
+
95
+ return loss
96
+
97
+
98
+ class BiPairwiseCELoss(torch.nn.Module):
99
+ def __init__(self):
100
+ super().__init__()
101
+ self.ce_loss = CrossEntropyLoss()
102
+
103
+ def forward(self, query_embeddings, doc_embeddings):
104
+ """
105
+ query_embeddings: (batch_size, dim)
106
+ doc_embeddings: (batch_size, dim)
107
+ """
108
+
109
+ scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
110
+
111
+ pos_scores = scores.diagonal()
112
+ neg_scores = scores - torch.eye(scores.shape[0], device=scores.device) * 1e6
113
+ neg_scores = neg_scores.max(dim=1)[0]
114
+
115
+ # Compute the loss
116
+ # The loss is computed as the negative log of the softmax of the positive scores
117
+ # relative to the negative scores.
118
+ # This can be simplified to log-sum-exp of negative scores minus the positive score
119
+ # for numerical stability.
120
+ loss = F.softplus(neg_scores - pos_scores).mean()
121
+
122
+ return loss
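For reference, a small self-contained sketch of the two multi-vector losses above on random embeddings (shapes follow the docstrings; the values are dummies):

import torch
from colpali_engine.loss.colbert_loss import ColbertLoss, ColbertPairwiseCELoss

query_embeddings = torch.randn(4, 16, 128)   # (batch_size, num_query_tokens, dim)
doc_embeddings = torch.randn(4, 1030, 128)   # (batch_size, num_doc_tokens, dim)

# In-batch cross-entropy over late-interaction (MaxSim) scores.
print(ColbertLoss()(query_embeddings, doc_embeddings))
# softplus(hardest in-batch negative - positive), using the same scoring function.
print(ColbertPairwiseCELoss()(query_embeddings, doc_embeddings))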
colpali-main/colpali_engine/models/__init__.py ADDED
File without changes
colpali-main/colpali_engine/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (170 Bytes). View file
 
colpali-main/colpali_engine/models/__pycache__/paligemma_colbert_architecture.cpython-310.pyc ADDED
Binary file (4.87 kB). View file
 
colpali-main/colpali_engine/models/clip_baselines.py ADDED
@@ -0,0 +1,144 @@
1
+ import os
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from transformers import SiglipModel
6
+
7
+
8
+ class SigLIP(SiglipModel):
9
+ def forward(self, *args, **kwargs):
10
+ """
11
+ Forward pass through the SigLIP vision or text encoder (single-vector bi-encoder baseline)
12
+
13
+ Args:
14
+ - input_ids (torch.LongTensor): The input tokens tensor.
15
+ - attention_mask (torch.LongTensor): The attention mask tensor.
16
+
17
+ Returns:
18
+ - torch.Tensor: Pooled embeddings of shape (batch_size, dim)
19
+ """
20
+ return self.forward_branch(*args, **kwargs)
21
+
22
+ def forward_branch(
23
+ self,
24
+ input_ids: Optional[torch.LongTensor] = None,
25
+ pixel_values: Optional[torch.FloatTensor] = None,
26
+ attention_mask: Optional[torch.Tensor] = None,
27
+ position_ids: Optional[torch.LongTensor] = None,
28
+ return_loss: Optional[bool] = None,
29
+ output_attentions: Optional[bool] = None,
30
+ output_hidden_states: Optional[bool] = None,
31
+ return_dict: Optional[bool] = None,
32
+ interpolate_pos_encoding: bool = False,
33
+ ):
34
+
35
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
36
+ output_hidden_states = (
37
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
38
+ )
39
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
40
+
41
+ if pixel_values is not None:
42
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
43
+
44
+ outputs = self.vision_model(
45
+ pixel_values=pixel_values.to(dtype=self.dtype),
46
+ output_attentions=output_attentions,
47
+ output_hidden_states=output_hidden_states,
48
+ return_dict=return_dict,
49
+ interpolate_pos_encoding=interpolate_pos_encoding,
50
+ )
51
+
52
+ else:
53
+ outputs = self.text_model(
54
+ input_ids=input_ids,
55
+ attention_mask=attention_mask,
56
+ position_ids=position_ids,
57
+ output_attentions=output_attentions,
58
+ output_hidden_states=output_hidden_states,
59
+ return_dict=return_dict,
60
+ )
61
+
62
+ embeds = outputs[1]
63
+
64
+ # normalized features
65
+ embeds = embeds / embeds.norm(p=2, dim=-1, keepdim=True)
66
+ return embeds
67
+
68
+
69
+ class ColSigLIP(SiglipModel):
70
+ def __init__(self, config):
71
+ super(ColSigLIP, self).__init__(config=config)
72
+ self.dim = 128
73
+ self.custom_vision_proj = torch.nn.Linear(self.config.vision_config.hidden_size, self.dim)
74
+ self.custom_text_proj = torch.nn.Linear(self.config.text_config.hidden_size, self.dim)
75
+ self.main_input_name = "doc_input_ids"
76
+
77
+ def forward(self, *args, **kwargs):
78
+ """
79
+ Forward pass through the SigLIP vision or text encoder and the linear projection layer for dimensionality reduction
80
+
81
+ Args:
82
+ - input_ids (torch.LongTensor): The input tokens tensor.
83
+ - attention_mask (torch.LongTensor): The attention mask tensor.
84
+
85
+ Returns:
86
+ - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
87
+ """
88
+ return self.forward_branch(*args, **kwargs)
89
+
90
+ def forward_branch(
91
+ self,
92
+ input_ids: Optional[torch.LongTensor] = None,
93
+ pixel_values: Optional[torch.FloatTensor] = None,
94
+ attention_mask: Optional[torch.Tensor] = None,
95
+ position_ids: Optional[torch.LongTensor] = None,
96
+ return_loss: Optional[bool] = None,
97
+ output_attentions: Optional[bool] = None,
98
+ output_hidden_states: Optional[bool] = None,
99
+ return_dict: Optional[bool] = None,
100
+ interpolate_pos_encoding: bool = False,
101
+ ):
102
+
103
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
104
+ output_hidden_states = (
105
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
106
+ )
107
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
108
+
109
+ if pixel_values is not None:
110
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
111
+
112
+ outputs = self.vision_model(
113
+ pixel_values=pixel_values.to(dtype=self.dtype),
114
+ output_attentions=output_attentions,
115
+ output_hidden_states=output_hidden_states,
116
+ return_dict=return_dict,
117
+ interpolate_pos_encoding=interpolate_pos_encoding,
118
+ )
119
+
120
+ last_hidden_states = outputs.last_hidden_state
121
+
122
+ proj = self.custom_vision_proj(last_hidden_states)
123
+ # normalize l2 norm
124
+ proj = proj / proj.norm(dim=-1, keepdim=True)
125
+
126
+ else:
127
+ outputs = self.text_model(
128
+ input_ids=input_ids,
129
+ attention_mask=attention_mask,
130
+ position_ids=position_ids,
131
+ output_attentions=output_attentions,
132
+ output_hidden_states=output_hidden_states,
133
+ return_dict=return_dict,
134
+ )
135
+
136
+ last_hidden_states = outputs.last_hidden_state
137
+
138
+ proj = self.custom_text_proj(last_hidden_states)
139
+ # normalize l2 norm
140
+ proj = proj / proj.norm(dim=-1, keepdim=True)
141
+ proj = proj * attention_mask.unsqueeze(-1)
142
+
143
+ # normalized features
144
+ return proj
colpali-main/colpali_engine/models/colbert_architectures.py ADDED
@@ -0,0 +1,177 @@
1
+ from torch import nn
2
+ from transformers import (
3
+ BertModel,
4
+ BertPreTrainedModel,
5
+ CamembertModel,
6
+ CamembertPreTrainedModel,
7
+ LlamaModel,
8
+ LlamaPreTrainedModel,
9
+ XLMRobertaModel,
10
+ XLMRobertaPreTrainedModel,
11
+ )
12
+
13
+
14
+ class ColCamembert(CamembertPreTrainedModel):
15
+ def __init__(self, config):
16
+ super(ColCamembert, self).__init__(config=config)
17
+ self.roberta: CamembertPreTrainedModel = CamembertModel(config)
18
+ self.dim = 128
19
+ self.linear = nn.Linear(self.roberta.config.hidden_size, self.dim)
20
+ self.main_input_name = "doc_input_ids"
21
+
22
+ def forward(self, *args, **kwargs):
23
+ """
24
+ Forward pass through Camembert and the linear layer for dimensionality reduction
25
+
26
+ Args:
27
+ - input_ids (torch.LongTensor): The input tokens tensor.
28
+ - attention_mask (torch.LongTensor): The attention mask tensor.
29
+
30
+ Returns:
31
+ - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
32
+ """
33
+ outputs = self.roberta(*args, **kwargs)
34
+ last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size)
35
+ proj = self.linear(last_hidden_states)
36
+ # normalize l2 norm
37
+ proj = proj / proj.norm(dim=-1, keepdim=True)
38
+ proj = proj * kwargs["attention_mask"].unsqueeze(-1)
39
+ return proj
40
+
41
+
42
+ class ColXLMRoBERTa(XLMRobertaPreTrainedModel):
43
+ def __init__(self, config):
44
+ super(ColXLMRoBERTa, self).__init__(config=config)
45
+ self.roberta: XLMRobertaPreTrainedModel = XLMRobertaModel(config)
46
+ self.dim = 128
47
+ self.linear = nn.Linear(self.roberta.config.hidden_size, self.dim)
48
+ self.main_input_name = "doc_input_ids"
49
+
50
+ def forward(self, *args, **kwargs):
51
+ """
52
+ Forward pass through Roberta and the linear layer for dimensionality reduction
53
+
54
+ Args:
55
+ - input_ids (torch.LongTensor): The input tokens tensor.
56
+ - attention_mask (torch.LongTensor): The attention mask tensor.
57
+
58
+ Returns:
59
+ - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
60
+ """
61
+ outputs = self.roberta(*args, **kwargs)
62
+ last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size)
63
+ proj = self.linear(last_hidden_states)
64
+ # normalize l2 norm
65
+ proj = proj / proj.norm(dim=-1, keepdim=True)
66
+ proj = proj * kwargs["attention_mask"].unsqueeze(-1)
67
+ return proj
68
+
69
+
70
+ class BiXLMRoBERTa(XLMRobertaPreTrainedModel):
71
+ def __init__(self, config):
72
+ super(BiXLMRoBERTa, self).__init__(config=config)
73
+ self.roberta: XLMRobertaPreTrainedModel = XLMRobertaModel(config)
74
+ self.main_input_name = "doc_input_ids"
75
+
76
+ def forward(self, *args, **kwargs):
77
+ """
78
+ Forward pass through RoBERTa with mean pooling over the attention mask
79
+
80
+ Args:
81
+ - input_ids (torch.LongTensor): The input tokens tensor.
82
+ - attention_mask (torch.LongTensor): The attention mask tensor.
83
+
84
+ Returns:
85
+ - torch.Tensor: Pooled embeddings of shape (batch_size, dim)
86
+ """
87
+ outputs = self.roberta(*args, **kwargs)
88
+ last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size)
89
+ # pooling - mean tokens that have attention mask == 1
90
+ proj = last_hidden_states * kwargs["attention_mask"].unsqueeze(-1)
91
+ proj = proj.sum(dim=1) / kwargs["attention_mask"].sum(dim=1, keepdim=True)
92
+ # normalize l2 norm
93
+ proj = proj / proj.norm(dim=-1, keepdim=True)
94
+ return proj
95
+
96
+
97
+ class ColBERT(BertPreTrainedModel):
98
+ def __init__(self, config):
99
+ super(ColBERT, self).__init__(config=config)
100
+ self.bert: BertModel = BertModel(config)
101
+ self.dim = 128
102
+ self.linear = nn.Linear(self.bert.config.hidden_size, self.dim)
103
+ self.main_input_name = "doc_input_ids"
104
+
105
+ def forward(self, *args, **kwargs):
106
+ """
107
+ Forward pass through BERT and the linear layer for dimensionality reduction
108
+
109
+ Args:
110
+ - input_ids (torch.LongTensor): The input tokens tensor.
111
+ - attention_mask (torch.LongTensor): The attention mask tensor.
112
+
113
+ Returns:
114
+ - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
115
+ """
116
+ outputs = self.bert(*args, **kwargs)
117
+ last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size)
118
+ proj = self.linear(last_hidden_states)
119
+ # normalize l2 norm
120
+ proj = proj / proj.norm(dim=-1, keepdim=True)
121
+ proj = proj * kwargs["attention_mask"].unsqueeze(-1)
122
+ return proj
123
+
124
+
125
+ class BiBERT(BertPreTrainedModel):
126
+ def __init__(self, config):
127
+ super(BiBERT, self).__init__(config=config)
128
+ self.bert: BertModel = BertModel(config)
129
+ self.main_input_name = "doc_input_ids"
130
+
131
+ def forward(self, *args, **kwargs):
132
+ """
133
+ Forward pass through BERT with mean pooling over the attention mask
134
+
135
+ Args:
136
+ - input_ids (torch.LongTensor): The input tokens tensor.
137
+ - attention_mask (torch.LongTensor): The attention mask tensor.
138
+
139
+ Returns:
140
+ - torch.Tensor: Pooled embeddings of shape (batch_size, dim)
141
+ """
142
+ outputs = self.bert(*args, **kwargs)
143
+ last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size)
144
+ # pooling - mean tokens that have attention mask == 1
145
+ proj = last_hidden_states * kwargs["attention_mask"].unsqueeze(-1)
146
+ proj = proj.sum(dim=1) / kwargs["attention_mask"].sum(dim=1, keepdim=True)
147
+ # normalize l2 norm
148
+ proj = proj / proj.norm(dim=-1, keepdim=True)
149
+ return proj
150
+
151
+
152
+ class ColLlama(LlamaPreTrainedModel):
153
+ def __init__(self, config):
154
+ super(ColLlama, self).__init__(config=config)
155
+ self.model: LlamaModel = LlamaModel(config)
156
+ self.dim = 128
157
+ self.linear = nn.Linear(self.model.config.hidden_size, self.dim)
158
+ self.main_input_name = "doc_input_ids"
159
+
160
+ def forward(self, *args, **kwargs):
161
+ """
162
+ Forward pass through Llama and the linear layer for dimensionality reduction
163
+
164
+ Args:
165
+ - input_ids (torch.LongTensor): The input tokens tensor.
166
+ - attention_mask (torch.LongTensor): The attention mask tensor.
167
+
168
+ Returns:
169
+ - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
170
+ """
171
+ outputs = self.model(*args, **kwargs)
172
+ last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size)
173
+ proj = self.linear(last_hidden_states)
174
+ # normalize l2 norm
175
+ proj = proj / proj.norm(dim=-1, keepdim=True)
176
+ proj = proj * kwargs["attention_mask"].unsqueeze(-1)
177
+ return proj
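A usage sketch for the text-only ColBERT variant above; "bert-base-uncased" is just an assumed illustrative backbone, and the 128-dim projection head stays randomly initialized until the model is trained:

import torch
from transformers import AutoTokenizer
from colpali_engine.models.colbert_architectures import ColBERT

backbone = "bert-base-uncased"  # assumed checkpoint, any BERT-style model should work
tokenizer = AutoTokenizer.from_pretrained(backbone)
model = ColBERT.from_pretrained(backbone)  # loads BERT weights; the linear head is freshly initialized

batch = tokenizer(["What is late interaction?"], return_tensors="pt", padding=True)
with torch.no_grad():
    embeddings = model(**batch)  # (1, num_tokens, 128), L2-normalized and masked
print(embeddings.shape)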
colpali-main/colpali_engine/models/idefics_colbert_architecture.py ADDED
@@ -0,0 +1,57 @@
1
+ from torch import nn
2
+ from transformers import Idefics2Model, Idefics2PreTrainedModel
3
+
4
+
5
+ class BiIdefics(Idefics2PreTrainedModel):
6
+ def __init__(self, config):
7
+ super(BiIdefics, self).__init__(config=config)
8
+ self.model: Idefics2Model = Idefics2Model(config)
9
+ self.pooling_strategy = "last"
10
+ self.main_input_name = "doc_input_ids"
11
+
12
+ def forward(self, *args, **kwargs):
13
+ """
14
+ Forward pass through Idefics2 with last-token pooling
15
+
16
+ Args:
17
+ - input_ids (torch.LongTensor): The input tokens tensor.
18
+ - attention_mask (torch.LongTensor): The attention mask tensor.
19
+
20
+ Returns:
21
+ - torch.Tensor: Pooled embeddings of shape (batch_size, dim)
22
+ """
23
+ outputs = self.model(*args, **kwargs)
24
+ last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size)
25
+ # pooling - last token
26
+ proj = last_hidden_states[:, -1, :]
27
+ # normalize l2 norm
28
+ proj = proj / proj.norm(dim=-1, keepdim=True)
29
+ return proj
30
+
31
+
32
+ class ColIdefics(Idefics2PreTrainedModel):
33
+ def __init__(self, config):
34
+ super(ColIdefics, self).__init__(config=config)
35
+ self.model: Idefics2Model = Idefics2Model(config)
36
+ self.dim = 128
37
+ self.linear = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
38
+ self.main_input_name = "doc_input_ids"
39
+
40
+ def forward(self, *args, **kwargs):
41
+ """
42
+ Forward pass through Idefics2 and the linear layer for dimensionality reduction
43
+
44
+ Args:
45
+ - input_ids (torch.LongTensor): The input tokens tensor.
46
+ - attention_mask (torch.LongTensor): The attention mask tensor.
47
+
48
+ Returns:
49
+ - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
50
+ """
51
+ outputs = self.model(*args, **kwargs)
52
+ last_hidden_states = outputs[0] # (batch_size, sequence_length, hidden_size)
53
+ proj = self.linear(last_hidden_states)
54
+ # normalize l2 norm
55
+ proj = proj / proj.norm(dim=-1, keepdim=True)
56
+ proj = proj * kwargs["attention_mask"].unsqueeze(-1)
57
+ return proj
colpali-main/colpali_engine/models/paligemma_colbert_architecture.py ADDED
@@ -0,0 +1,191 @@
1
+ import torch
2
+ from torch import nn
3
+ from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration, PaliGemmaPreTrainedModel
4
+
5
+
6
+ class BiPaliLast(PaliGemmaPreTrainedModel):
7
+ def __init__(self, config):
8
+ super(BiPaliLast, self).__init__(config=config)
9
+ self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
10
+ self.pooling_strategy = "last"
11
+ self.main_input_name = "doc_input_ids"
12
+
13
+ def forward(self, *args, **kwargs):
14
+ """
15
+ Forward pass through PaliGemma with last-token pooling
16
+
17
+ Args:
18
+ - input_ids (torch.LongTensor): The input tokens tensor.
19
+ - attention_mask (torch.LongTensor): The attention mask tensor.
20
+
21
+ Returns:
22
+ - torch.Tensor: Pooled embeddings of shape (batch_size, dim)
23
+ """
24
+ outputs = self.model(*args, output_hidden_states=True, **kwargs)
25
+ last_hidden_states = outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size)
26
+ # pooling - last token
27
+ proj = last_hidden_states[:, -1, :]
28
+ # normalize l2 norm
29
+ proj = proj / proj.norm(dim=-1, keepdim=True)
30
+ return proj
31
+
32
+
33
+ class BiPaliMean(PaliGemmaPreTrainedModel):
34
+ def __init__(self, config):
35
+ super(BiPaliMean, self).__init__(config=config)
36
+ self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
37
+ self.pooling_strategy = "mean"
38
+ self.main_input_name = "doc_input_ids"
39
+
40
+ def forward(self, *args, **kwargs):
41
+ """
42
+ Forward pass through PaliGemma with mean pooling over the attention mask
43
+
44
+ Args:
45
+ - input_ids (torch.LongTensor): The input tokens tensor.
46
+ - attention_mask (torch.LongTensor): The attention mask tensor.
47
+
48
+ Returns:
49
+ - torch.Tensor: Pooled embeddings of shape (batch_size, dim)
50
+ """
51
+ outputs = self.model(*args, output_hidden_states=True, **kwargs)
52
+ last_hidden_states = outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size)
53
+ # pooling - mean over tokens where attention_mask == 1
54
+ proj = torch.sum(last_hidden_states * kwargs["attention_mask"].unsqueeze(-1), dim=1) / torch.sum(
55
+ kwargs["attention_mask"], dim=1, keepdim=True
56
+ )
57
+ proj = proj / proj.norm(dim=-1, keepdim=True)
58
+ return proj
59
+
60
+
61
+ class ColPali(PaliGemmaPreTrainedModel):
62
+ def __init__(self, config):
63
+ super(ColPali, self).__init__(config=config)
64
+ self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
65
+ self.dim = 128
66
+ self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
67
+ self.main_input_name = "doc_input_ids"
68
+
69
+ def forward(self, *args, **kwargs):
70
+ """
71
+ Forward pass through PaliGemma and the linear layer for dimensionality reduction
72
+
73
+ Args:
74
+ - input_ids (torch.LongTensor): The input tokens tensor.
75
+ - attention_mask (torch.LongTensor): The attention mask tensor.
76
+
77
+ Returns:
78
+ - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
79
+ """
80
+ outputs = self.model(*args, output_hidden_states=True, **kwargs)
81
+ last_hidden_states = outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size)
82
+ proj = self.custom_text_proj(last_hidden_states)
83
+ # normalize l2 norm
84
+ proj = proj / proj.norm(dim=-1, keepdim=True)
85
+ proj = proj * kwargs["attention_mask"].unsqueeze(-1)
86
+ return proj
87
+
88
+
89
+ class ColNewSiglip(PaliGemmaPreTrainedModel):
90
+ def __init__(self, config):
91
+ super(ColNewSiglip, self).__init__(config=config)
92
+ self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
93
+ self.dim = 128
94
+ self.custom_image_proj = nn.Linear(self.model.config.vision_config.projection_dim, self.dim)
95
+ self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
96
+ self.main_input_name = "doc_input_ids"
97
+
98
+ def forward(self, *args, **kwargs):
99
+ """
100
+ Forward pass through the PaliGemma vision tower (documents) or language model (queries) and the matching projection layer
101
+
102
+ Args:
103
+ - input_ids (torch.LongTensor): The input tokens tensor.
104
+ - attention_mask (torch.LongTensor): The attention mask tensor.
105
+
106
+ Returns:
107
+ - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
108
+ """
109
+ # outputs = self.model(*args, output_hidden_states=True, **kwargs)
110
+ if "pixel_values" in kwargs:
111
+ image_features = self.vision_model_output(*args, **kwargs)
112
+ # print(f"Doc: {image_features.shape}")
113
+ proj = self.custom_image_proj(image_features)
114
+ # print(f"Doc proj: {proj.shape}")
115
+ proj = proj / proj.norm(dim=-1, keepdim=True)
116
+ else:
117
+ outputs = self.model(*args, output_hidden_states=True, **kwargs)
118
+ last_hidden_states = outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size)
119
+ # print(f"Query: {last_hidden_states.shape}")
120
+ proj = self.custom_text_proj(last_hidden_states)
121
+ # print(f"Query proj: {proj.shape}")
122
+ # normalize l2 norm
123
+ proj = proj / proj.norm(dim=-1, keepdim=True)
124
+ proj = proj * kwargs["attention_mask"].unsqueeze(-1)
125
+ return proj
126
+
127
+ def vision_model_output(self, input_ids: torch.LongTensor = None, pixel_values: torch.FloatTensor = None, **kwargs):
128
+
129
+ inputs_embeds = self.model.get_input_embeddings()(input_ids)
130
+ # 2. Merge text and images
131
+ if pixel_values is not None and input_ids.shape[1] != 1:
132
+ image_outputs = self.model.vision_tower(pixel_values.to(inputs_embeds.dtype))
133
+ selected_image_feature = image_outputs.last_hidden_state
134
+ image_features = self.model.multi_modal_projector(selected_image_feature)
135
+
136
+ return image_features
137
+
138
+ raise ValueError("pixel_values is None or input_ids.shape[1] == 1")
139
+
140
+
141
+ class BiNewSiglip(PaliGemmaPreTrainedModel):
142
+ def __init__(self, config):
143
+ super(BiNewSiglip, self).__init__(config=config)
144
+ self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
145
+ self.main_input_name = "doc_input_ids"
146
+
147
+ def forward(self, *args, **kwargs):
148
+ """
149
+ Forward pass through the PaliGemma vision tower (documents) or language model (queries), with mean pooling
150
+
151
+ Args:
152
+ - input_ids (torch.LongTensor): The input tokens tensor.
153
+ - attention_mask (torch.LongTensor): The attention mask tensor.
154
+
155
+ Returns:
156
+ - torch.Tensor: Pooled embeddings of shape (batch_size, dim)
157
+ """
158
+ # outputs = self.model(*args, output_hidden_states=True, **kwargs)
159
+ if "pixel_values" in kwargs:
160
+ image_features = self.vision_model_output(*args, **kwargs)
161
+ # print(f"Doc: {image_features.shape}")
162
+ # pool image features
163
+ proj = torch.mean(image_features, dim=1)
164
+ # print(f"Doc proj: {proj.shape}")
165
+ norm = proj.norm(dim=-1, keepdim=True)
166
+ proj = proj / norm
167
+ else:
168
+ outputs = self.model(*args, output_hidden_states=True, **kwargs)
169
+ last_hidden_states = outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size)
170
+ # pooling - mean over tokens where attention_mask == 1
171
+
172
+ proj = torch.sum(last_hidden_states * kwargs["attention_mask"].unsqueeze(-1), dim=1) / torch.sum(
173
+ kwargs["attention_mask"], dim=1, keepdim=True
174
+ )
175
+ # print(f"Query proj: {proj.shape}")
176
+ norm = proj.norm(dim=-1, keepdim=True)
177
+ proj = proj / norm
178
+ return proj
179
+
180
+ def vision_model_output(self, input_ids: torch.LongTensor = None, pixel_values: torch.FloatTensor = None, **kwargs):
181
+
182
+ inputs_embeds = self.model.get_input_embeddings()(input_ids)
183
+ # 2. Merge text and images
184
+ if pixel_values is not None and input_ids.shape[1] != 1:
185
+ image_outputs = self.model.vision_tower(pixel_values.to(inputs_embeds.dtype))
186
+ selected_image_feature = image_outputs.last_hidden_state
187
+ image_features = self.model.multi_modal_projector(selected_image_feature)
188
+
189
+ return image_features
190
+
191
+ raise ValueError("pixel_values is None or input_ids.shape[1] == 1")
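For orientation, an end-to-end retrieval sketch built on the ColPali class above together with the processing utils added later in this commit. The checkpoint names (PaliGemma base weights plus a "vidore/colpali" adapter) and the blank test images are illustrative assumptions, and loading the adapter requires peft:

import torch
from PIL import Image
from transformers import AutoProcessor
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.utils.colpali_processing_utils import process_images, process_queries

model = ColPali.from_pretrained("google/paligemma-3b-mix-448").eval()  # assumed base checkpoint
model.load_adapter("vidore/colpali")                                   # assumed trained adapter
processor = AutoProcessor.from_pretrained("google/paligemma-3b-mix-448")

pages = [Image.new("RGB", (448, 448), "white")]             # stand-ins for real document pages
mock_image = Image.new("RGB", (448, 448), (255, 255, 255))  # placeholder required by process_queries

with torch.no_grad():
    doc_emb = model(**process_images(processor, pages))
    query_emb = model(**process_queries(processor, ["What is ColPali?"], mock_image))

# Late-interaction (MaxSim) scoring, as in CustomEvaluator.evaluate_colbert.
scores = torch.einsum("bnd,csd->bcns", query_emb, doc_emb).max(dim=3)[0].sum(dim=2)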
colpali-main/colpali_engine/trainer/__init__.py ADDED
File without changes
colpali-main/colpali_engine/trainer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (171 Bytes). View file
 
colpali-main/colpali_engine/trainer/__pycache__/retrieval_evaluator.cpython-310.pyc ADDED
Binary file (3.18 kB). View file
 
colpali-main/colpali_engine/trainer/contrastive_trainer.py ADDED
@@ -0,0 +1,64 @@
1
+ import torch
2
+ from transformers import Trainer
3
+
4
+
5
+ class ContrastiveTrainer(Trainer):
6
+ def __init__(self, loss_func, is_vision_model, *args, **kwargs):
7
+ super().__init__(*args, **kwargs)
8
+ self.loss_func = loss_func
9
+ self.is_vision_model = is_vision_model
10
+
11
+ def compute_loss(self, model, inputs, return_outputs=False):
12
+ query_outputs = model(input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"])
13
+ if self.is_vision_model:
14
+ if "doc_pixel_attention_mask" not in inputs:
15
+ doc_outputs = model(
16
+ input_ids=inputs["doc_input_ids"],
17
+ attention_mask=inputs["doc_attention_mask"],
18
+ pixel_values=inputs["doc_pixel_values"],
19
+ )
20
+ else:
21
+ doc_outputs = model(
22
+ input_ids=inputs["doc_input_ids"],
23
+ attention_mask=inputs["doc_attention_mask"],
24
+ pixel_values=inputs["doc_pixel_values"],
25
+ pixel_attention_mask=inputs["doc_pixel_attention_mask"],
26
+ )
27
+ else:
28
+ doc_outputs = model(input_ids=inputs["doc_input_ids"], attention_mask=inputs["doc_attention_mask"])
29
+
30
+ loss = self.loss_func(query_outputs, doc_outputs)
31
+ return (loss, (query_outputs, doc_outputs)) if return_outputs else loss
32
+
33
+ def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=True):
34
+ """This function is used to generate predictions and return the loss for the given inputs."""
35
+ if not prediction_loss_only:
36
+ raise ValueError("prediction_step is only called with prediction_loss_only=True")
37
+
38
+ with torch.no_grad():
39
+ if self.is_vision_model:
40
+ if "doc_pixel_attention_mask" not in inputs:
41
+ doc_outputs = model(
42
+ input_ids=inputs["doc_input_ids"],
43
+ attention_mask=inputs["doc_attention_mask"],
44
+ pixel_values=inputs["doc_pixel_values"],
45
+ )
46
+ else:
47
+ doc_outputs = model(
48
+ input_ids=inputs["doc_input_ids"],
49
+ attention_mask=inputs["doc_attention_mask"],
50
+ pixel_values=inputs["doc_pixel_values"],
51
+ pixel_attention_mask=inputs["doc_pixel_attention_mask"],
52
+ )
53
+ query_outputs = model(
54
+ input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"]
55
+ )
56
+ else:
57
+
58
+ query_outputs = model(
59
+ input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"]
60
+ )
61
+ doc_outputs = model(input_ids=inputs["doc_input_ids"], attention_mask=inputs["doc_attention_mask"])
62
+
63
+ loss = self.loss_func(query_outputs, doc_outputs)
64
+ return loss, None, None
colpali-main/colpali_engine/trainer/retrieval_evaluator.py ADDED
@@ -0,0 +1,72 @@
1
+ import torch
2
+ from mteb.evaluation.evaluators import RetrievalEvaluator
3
+
4
+
5
+ class CustomEvaluator:
6
+ def __init__(self, is_multi_vector=False):
7
+ self.is_multi_vector = is_multi_vector
8
+ self.mteb_evaluator = RetrievalEvaluator()
9
+
10
+ def evaluate(self, qs, ps):
11
+ if self.is_multi_vector:
12
+ scores = self.evaluate_colbert(qs, ps)
13
+ else:
14
+ scores = self.evaluate_biencoder(qs, ps)
15
+
16
+ assert scores.shape[0] == len(qs)
17
+
18
+ arg_score = scores.argmax(dim=1)
19
+ # compare to arange
20
+ accuracy = (arg_score == torch.arange(scores.shape[0], device=scores.device)).sum().item() / scores.shape[0]
21
+ print(arg_score)
22
+ print(f"Top 1 Accuracy (verif): {accuracy}")
23
+
24
+ # cast to numpy
25
+ # scores = scores.cpu().numpy()
26
+ scores = scores.to(torch.float32).cpu().numpy()
27
+ return scores
28
+
29
+ def compute_metrics(self, relevant_docs, results, **kwargs):
30
+ # wrap mteb package
31
+
32
+ ndcg, _map, recall, precision, naucs = self.mteb_evaluator.evaluate(
33
+ relevant_docs,
34
+ results,
35
+ self.mteb_evaluator.k_values,
36
+ ignore_identical_ids=kwargs.get("ignore_identical_ids", True),
37
+ )
38
+ mrr = self.mteb_evaluator.evaluate_custom(relevant_docs, results, self.mteb_evaluator.k_values, "mrr")
39
+ scores = {
40
+ **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
41
+ **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
42
+ **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()},
43
+ **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()},
44
+ **{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr[0].items()},
45
+ **{f"naucs_at_{k.split('@')[1]}": v for (k, v) in naucs.items()},
46
+ }
47
+ return scores
48
+
49
+ def evaluate_colbert(self, qs, ps, batch_size=128) -> torch.Tensor:
50
+ scores = []
51
+ for i in range(0, len(qs), batch_size):
52
+ scores_batch = []
53
+ qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to(
54
+ "cpu"
55
+ )
56
+ for j in range(0, len(ps), batch_size):
57
+ ps_batch = torch.nn.utils.rnn.pad_sequence(
58
+ ps[j : j + batch_size], batch_first=True, padding_value=0
59
+ ).to("cpu")
60
+ scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2))
61
+ scores_batch = torch.cat(scores_batch, dim=1).cpu()
62
+ scores.append(scores_batch)
63
+ scores = torch.cat(scores, dim=0)
64
+ return scores
65
+
66
+ def evaluate_biencoder(self, qs, ps) -> torch.Tensor:
67
+
68
+ qs = torch.stack(qs)
69
+ ps = torch.stack(ps)
70
+
71
+ scores = torch.einsum("bd,cd->bc", qs, ps)
72
+ return scores
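A quick sketch of the evaluator above on dummy embeddings (the mteb package, which backs compute_metrics, must be installed):

import torch
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator

evaluator = CustomEvaluator(is_multi_vector=True)
qs = [torch.randn(16, 128) for _ in range(8)]    # one (num_query_tokens, dim) tensor per query
ps = [torch.randn(1030, 128) for _ in range(8)]  # one (num_doc_tokens, dim) tensor per page
scores = evaluator.evaluate(qs, ps)              # (8, 8) late-interaction score matrix; prints top-1 accuracy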
colpali-main/colpali_engine/utils/__init__.py ADDED
File without changes
colpali-main/colpali_engine/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (169 Bytes). View file
 
colpali-main/colpali_engine/utils/__pycache__/colpali_processing_utils.cpython-310.pyc ADDED
Binary file (1.2 kB). View file
 
colpali-main/colpali_engine/utils/__pycache__/image_from_page_utils.cpython-310.pyc ADDED
Binary file (998 Bytes). View file
 
colpali-main/colpali_engine/utils/colidefics_processing_utils.py ADDED
@@ -0,0 +1,53 @@
1
+ # Utils for processing images and queries for ColPali
2
+
3
+ def process_images(processor, images, max_length: int = 50):
4
+ texts_doc = []
5
+ images = [image.convert("RGB") for image in images]
6
+
7
+ for _ in images:
8
+ messages_doc = [
9
+ {
10
+ "role": "user",
11
+ "content": [
12
+ {"type": "text", "text": "Describe the image."},
13
+ {"type": "image"},
14
+ ],
15
+ },
16
+ ]
17
+
18
+ text_doc = processor.apply_chat_template(messages_doc, add_generation_prompt=False)
19
+ texts_doc.append(text_doc.strip())
20
+
21
+ batch_doc = processor(
22
+ text=texts_doc,
23
+ images=images,
24
+ return_tensors="pt",
25
+ padding="longest",
26
+ )
27
+ return batch_doc
28
+
29
+
30
+ def process_queries(processor, queries, mock_image, max_length: int = 50):
31
+ texts_query = []
32
+ for query in queries:
33
+ messages_query = [
34
+ {
35
+ "role": "user",
36
+ "content": [
37
+ {
38
+ "type": "text",
39
+ "text": f"Question: {query}<end_of_utterance><end_of_utterance><end_of_utterance><end_of_utterance><end_of_utterance>",
40
+ },
41
+ ],
42
+ },
43
+ ]
44
+ text_query = processor.apply_chat_template(messages_query, add_generation_prompt=False).strip()
45
+ texts_query.append(text_query)
46
+
47
+ batch_query = processor(
48
+ text=texts_query,
49
+ return_tensors="pt",
50
+ padding="longest",
51
+ max_length=max_length,
52
+ )
53
+ return batch_query
colpali-main/colpali_engine/utils/colpali_processing_utils.py ADDED
@@ -0,0 +1,36 @@
1
+ # Utils for processing images and queries for ColPali
2
+
3
+
4
+ def process_images(processor, images, max_length: int = 50):
5
+ texts_doc = ["Describe the image."] * len(images)
6
+ images = [image.convert("RGB") for image in images]
7
+
8
+ batch_doc = processor(
9
+ text=texts_doc,
10
+ images=images,
11
+ return_tensors="pt",
12
+ padding="longest",
13
+ max_length=max_length + processor.image_seq_length,
14
+ )
15
+ return batch_doc
16
+
17
+
18
+ def process_queries(processor, queries, mock_image, max_length: int = 50):
19
+ texts_query = []
20
+ for query in queries:
21
+ query = f"Question: {query}<unused0><unused0><unused0><unused0><unused0>"
22
+ texts_query.append(query)
23
+
24
+ batch_query = processor(
25
+ images=[mock_image.convert("RGB")] * len(texts_query),
26
+ # NOTE: the image is not used in batch_query but it is required for calling the processor
27
+ text=texts_query,
28
+ return_tensors="pt",
29
+ padding="longest",
30
+ max_length=max_length + processor.image_seq_length,
31
+ )
32
+ del batch_query["pixel_values"]
33
+
34
+ batch_query["input_ids"] = batch_query["input_ids"][..., processor.image_seq_length :]
35
+ batch_query["attention_mask"] = batch_query["attention_mask"][..., processor.image_seq_length :]
36
+ return batch_query
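A small sketch of the two helpers above; the processor checkpoint and the blank images are illustrative stand-ins:

from PIL import Image
from transformers import AutoProcessor
from colpali_engine.utils.colpali_processing_utils import process_images, process_queries

processor = AutoProcessor.from_pretrained("google/paligemma-3b-mix-448")  # assumed PaliGemma processor
page = Image.new("RGB", (448, 448), "white")                              # stand-in document page
mock_image = Image.new("RGB", (448, 448), (255, 255, 255))                # only needed to satisfy the processor call

batch_doc = process_images(processor, [page])                               # keeps pixel_values for the page
batch_query = process_queries(processor, ["What is ColPali?"], mock_image)  # image tokens stripped, text only
print(batch_doc["pixel_values"].shape, batch_query["input_ids"].shape)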
colpali-main/colpali_engine/utils/dataset_transformation.py ADDED
@@ -0,0 +1,158 @@
1
+ import os
2
+
3
+ from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
4
+
5
+ USE_LOCAL_DATASET = os.environ.get("USE_LOCAL_DATASET", "1") == "1"
6
+
7
+
8
+ def add_metadata_column(dataset, column_name, value):
9
+ def add_source(example):
10
+ example[column_name] = value
11
+ return example
12
+
13
+ return dataset.map(add_source)
14
+
15
+
16
+ def load_train_set() -> DatasetDict:
17
+
18
+ ds_paths = [
19
+ "infovqa_train",
20
+ "docvqa_train",
21
+ "arxivqa_train",
22
+ "tatdqa_train",
23
+ "syntheticDocQA_government_reports_train",
24
+ "syntheticDocQA_healthcare_industry_train",
25
+ "syntheticDocQA_artificial_intelligence_train",
26
+ "syntheticDocQA_energy_train",
27
+ ]
28
+ base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
29
+ ds_tot = []
30
+ for path in ds_paths:
31
+ cpath = base_path + path
32
+ ds = load_dataset(cpath, split="train")
33
+ if "arxivqa" in path:
34
+ # subsample 10k
35
+ ds = ds.shuffle(42).select(range(10000))
36
+ ds_tot.append(ds)
37
+
38
+ dataset = concatenate_datasets(ds_tot)
39
+ dataset = dataset.shuffle(seed=42)
40
+ # split into train and test
41
+ dataset_eval = dataset.select(range(500))
42
+ dataset = dataset.select(range(500, len(dataset)))
43
+ ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
44
+ return ds_dict
45
+
46
+
47
+ def load_train_set_with_tabfquad() -> DatasetDict:
48
+
49
+ ds_paths = [
50
+ "infovqa_train",
51
+ "docvqa_train",
52
+ "arxivqa_train",
53
+ "tatdqa_train",
54
+ "tabfquad_train_subsampled",
55
+ "syntheticDocQA_government_reports_train",
56
+ "syntheticDocQA_healthcare_industry_train",
57
+ "syntheticDocQA_artificial_intelligence_train",
58
+ "syntheticDocQA_energy_train",
59
+ ]
60
+ base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
61
+ ds_tot = []
62
+ for path in ds_paths:
63
+ cpath = base_path + path
64
+ ds = load_dataset(cpath, split="train")
65
+ if "arxivqa" in path:
66
+ # subsample 10k
67
+ ds = ds.shuffle(42).select(range(10000))
68
+ ds_tot.append(ds)
69
+
70
+ dataset = concatenate_datasets(ds_tot)
71
+ dataset = dataset.shuffle(seed=42)
72
+ # split into train and test
73
+ dataset_eval = dataset.select(range(500))
74
+ dataset = dataset.select(range(500, len(dataset)))
75
+ ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
76
+ return ds_dict
77
+
78
+
79
+ def load_train_set_with_docmatix() -> DatasetDict:
80
+
81
+ ds_paths = [
82
+ "infovqa_train",
83
+ "docvqa_train",
84
+ "arxivqa_train",
85
+ "tatdqa_train",
86
+ "tabfquad_train_subsampled",
87
+ "syntheticDocQA_government_reports_train",
88
+ "syntheticDocQA_healthcare_industry_train",
89
+ "syntheticDocQA_artificial_intelligence_train",
90
+ "syntheticDocQA_energy_train",
91
+ "Docmatix_filtered_train",
92
+ ]
93
+ base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
94
+ ds_tot = []
95
+ for path in ds_paths:
96
+ cpath = base_path + path
97
+ ds = load_dataset(cpath, split="train")
98
+ if "arxivqa" in path:
99
+ # subsample 10k
100
+ ds = ds.shuffle(42).select(range(10000))
101
+ ds_tot.append(ds)
102
+
103
+ dataset = concatenate_datasets(ds_tot)
104
+ dataset = dataset.shuffle(seed=42)
105
+ # split into train and test
106
+ dataset_eval = dataset.select(range(500))
107
+ dataset = dataset.select(range(500, len(dataset)))
108
+ ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
109
+ return ds_dict
110
+
111
+
112
+ def load_docvqa_dataset() -> DatasetDict:
113
+ if USE_LOCAL_DATASET:
114
+ dataset_doc = load_dataset("./data_dir/DocVQA", "DocVQA", split="validation")
115
+ dataset_doc_eval = load_dataset("./data_dir/DocVQA", "DocVQA", split="test")
116
+ dataset_info = load_dataset("./data_dir/DocVQA", "InfographicVQA", split="validation")
117
+ dataset_info_eval = load_dataset("./data_dir/DocVQA", "InfographicVQA", split="test")
118
+ else:
119
+ dataset_doc = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation")
120
+ dataset_doc_eval = load_dataset("lmms-lab/DocVQA", "DocVQA", split="test")
121
+ dataset_info = load_dataset("lmms-lab/DocVQA", "InfographicVQA", split="validation")
122
+ dataset_info_eval = load_dataset("lmms-lab/DocVQA", "InfographicVQA", split="test")
123
+
124
+ # concatenate the two datasets
125
+ dataset = concatenate_datasets([dataset_doc, dataset_info])
126
+ dataset_eval = concatenate_datasets([dataset_doc_eval, dataset_info_eval])
127
+ # sample 200 examples from the eval dataset
128
+ dataset_eval = dataset_eval.shuffle(seed=42).select(range(200))
129
+
130
+ # rename question as query
131
+ dataset = dataset.rename_column("question", "query")
132
+ dataset_eval = dataset_eval.rename_column("question", "query")
133
+
134
+ # create new column image_filename that corresponds to ucsf_document_id if not None, else image_url
135
+ dataset = dataset.map(
136
+ lambda x: {"image_filename": x["ucsf_document_id"] if x["ucsf_document_id"] is not None else x["image_url"]}
137
+ )
138
+ dataset_eval = dataset_eval.map(
139
+ lambda x: {"image_filename": x["ucsf_document_id"] if x["ucsf_document_id"] is not None else x["image_url"]}
140
+ )
141
+
142
+ ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
143
+
144
+ return ds_dict
145
+
146
+
147
+ class TestSetFactory:
148
+ def __init__(self, dataset_path):
149
+ self.dataset_path = dataset_path
150
+
151
+ def __call__(self, *args, **kwargs):
152
+ dataset = load_dataset(self.dataset_path, split="test")
153
+ return dataset
154
+
155
+
156
+ if __name__ == "__main__":
157
+ ds = TestSetFactory("vidore/tabfquad_test_subsampled")()
158
+ print(ds)
colpali-main/colpali_engine/utils/gpu_stats.py ADDED
@@ -0,0 +1,24 @@
1
+ # cond import
2
+ try:
3
+ from pynvml import *
4
+
5
+ def print_gpu_utilization():
6
+ nvmlInit()
7
+ handle = nvmlDeviceGetHandleByIndex(0)
8
+ info = nvmlDeviceGetMemoryInfo(handle)
9
+ print(f"GPU memory occupied: {info.used // 1024 ** 2} MB.")
10
+
11
+ def print_summary(result):
12
+ print(f"Time: {result.metrics['train_runtime']:.2f}")
13
+ print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
14
+ print_gpu_utilization()
15
+
16
+ except ImportError:
17
+ print("pynvml not found. GPU stats will not be printed.")
18
+
19
+ def print_summary(result):
20
+ print(f"Time: {result.metrics['train_runtime']:.2f}")
21
+ print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
22
+
23
+ def print_gpu_utilization():
24
+ pass
colpali-main/colpali_engine/utils/image_from_page_utils.py ADDED
@@ -0,0 +1,21 @@
1
+ import requests
2
+ from PIL import Image
3
+
4
+
5
+ def load_from_pdf(pdf_path: str):
6
+ from pdf2image import convert_from_path
7
+
8
+ images = convert_from_path(pdf_path)
9
+ return images
10
+
11
+
12
+ def load_from_image_urls(urls: str):
13
+ images = [Image.open(requests.get(url, stream=True).raw) for url in urls]
14
+ return images
15
+
16
+
17
+ def load_from_dataset(dataset):
18
+ from datasets import load_dataset
19
+
20
+ dataset = load_dataset(dataset, split="test")
21
+ return dataset["image"]
colpali-main/colpali_engine/utils/image_utils.py ADDED
@@ -0,0 +1,64 @@
1
+ """
2
+ Utility functions for working with images.
3
+ """
4
+
5
+ import base64
6
+ import io
7
+
8
+ from PIL import Image
9
+
10
+
11
+ def scale_image(image: Image.Image, new_height: int = 1024) -> Image.Image:
12
+ """
13
+ Scale an image to a new height while maintaining the aspect ratio.
14
+ """
15
+ # Calculate the scaling factor
16
+ width, height = image.size
17
+ aspect_ratio = width / height
18
+ new_width = int(new_height * aspect_ratio)
19
+
20
+ # Resize the image
21
+ scaled_image = image.resize((new_width, new_height))
22
+
23
+ return scaled_image
24
+
25
+
26
+ def scale_to_max_dimension(image: Image.Image, max_dimension: int = 1024) -> Image.Image:
27
+ """
28
+ Scale an image to a maximum dimension while maintaining the aspect ratio.
29
+ """
30
+ # Get the dimensions of the image
31
+ width, height = image.size
32
+
33
+ max_original_dimension = max(width, height)
34
+
35
+ if max_original_dimension < max_dimension:
36
+ return image
37
+
38
+ # Calculate the scaling factor
39
+ aspect_ratio = max_dimension / max_original_dimension
40
+ new_width = int(width * aspect_ratio)
41
+ new_height = int(height * aspect_ratio)
42
+
43
+ # Resize the image
44
+ scaled_image = image.resize((new_width, new_height))
45
+
46
+ return scaled_image
47
+
48
+
49
+ def get_base64_image(img: str | Image.Image, add_url_prefix: bool = True) -> str:
50
+ """
51
+ Convert an image (from a filepath or a PIL.Image object) to a JPEG-base64 string.
52
+ """
53
+ if isinstance(img, str):
54
+ img = Image.open(img)
55
+ elif isinstance(img, Image.Image):
56
+ pass
57
+ else:
58
+ raise ValueError("`img` must be a path to an image or a PIL Image object.")
59
+
60
+ buffered = io.BytesIO()
61
+ img.save(buffered, format="jpeg")
62
+ b64_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
63
+
64
+ return f"data:image/jpeg;base64,{b64_data}" if add_url_prefix else b64_data
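Usage sketch for the helpers above on a synthetic image:

from PIL import Image
from colpali_engine.utils.image_utils import get_base64_image, scale_to_max_dimension

image = Image.new("RGB", (2000, 1200), "white")                # synthetic page
thumbnail = scale_to_max_dimension(image, max_dimension=1024)  # longest side becomes 1024 px
data_uri = get_base64_image(thumbnail)                         # "data:image/jpeg;base64,..."
print(thumbnail.size, data_uri[:30])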
colpali-main/colpali_engine/utils/iter_utils.py ADDED
@@ -0,0 +1,42 @@
1
+ import sys
2
+
3
+
4
+ def islice(iterable, *args):
5
+ """
6
+ Yield a slice of an iterable.
7
+ >>> islice('ABCDEFG', 2) → A B
8
+ >>> islice('ABCDEFG', 2, 4) → C D
9
+ >>> islice('ABCDEFG', 2, None) → C D E F G
10
+ >>> islice('ABCDEFG', 0, None, 2) → A C E G
11
+ """
12
+ s = slice(*args)
13
+ start, stop, step = s.start or 0, s.stop or sys.maxsize, s.step or 1
14
+ it = iter(range(start, stop, step))
15
+ try:
16
+ nexti = next(it)
17
+ except StopIteration:
18
+ # Consume *iterable* up to the *start* position.
19
+ for i, element in zip(range(start), iterable):
20
+ pass
21
+ return
22
+ try:
23
+ for i, element in enumerate(iterable):
24
+ if i == nexti:
25
+ yield element
26
+ nexti = next(it)
27
+ except StopIteration:
28
+ # Consume to *stop*.
29
+ for i, element in zip(range(i + 1, stop), iterable):
30
+ pass
31
+
32
+
33
+ def batched(iterable, n: int):
34
+ """
35
+ Yield batches of n elements from an iterable.
36
+ >>> batched('ABCDEFG', 3) → ABC DEF G
37
+ """
38
+ if n < 1:
39
+ raise ValueError("n must be at least one")
40
+ it = iter(iterable)
41
+ while batch := tuple(islice(it, n)):
42
+ yield batch
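The two helpers above mirror itertools.islice and the classic itertools "batched" recipe; a quick check:

from colpali_engine.utils.iter_utils import batched, islice

print(list(islice("ABCDEFG", 2, 4)))  # ['C', 'D']
for batch in batched("ABCDEFG", 3):
    print(batch)                      # ('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)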
colpali-main/colpali_engine/utils/pdf_utils.py ADDED
@@ -0,0 +1,87 @@
1
+ import glob
2
+ import os
3
+ import random
4
+ from pathlib import Path
5
+
6
+ from pdf2image import convert_from_path
7
+ from tqdm import tqdm
8
+
9
+ random.seed(42)
10
+
11
+
12
+ def convert_pdf_to_images(pdf_file: str, save_folder: str):
13
+ """
14
+ Convert each page of a pdf to a jpg image and save them in a folder.
15
+
16
+ Args:
17
+ - pdf_file (str): path to the pdf file
18
+ - save_folder (str): path to the folder where the images will be saved
19
+
20
+ """
21
+ images = convert_from_path(pdf_file)
22
+
23
+ for i, image in enumerate(images):
24
+ if not os.path.exists(save_folder):
25
+ os.makedirs(save_folder)
26
+ image.save(os.path.join(save_folder, f"page_{i+1}.jpg"), "JPEG")
27
+
28
+
29
+ def convert_all_pdfs_to_images(path_to_folder: str, n_samples: int = 0):
30
+ """
31
+ Convert all pdfs in a folder and its subfolder to images and save them in a folder.
32
+ It will sample n_samples pdf files in each subfolder, giving control over how many pdf files are converted per subfolder.
33
+
34
+
35
+ Args:
36
+ - path_to_folder (str): path to the folder containing the pdf files
37
+ - n_samples (int): number of pdf files to sample in each subfolder
38
+
39
+ directory structure:
40
+ - path_to_folder
41
+ - subfolder1
42
+ - pdf1
43
+ - pdf2
44
+ - ...
45
+ - subfolder2
46
+ - pdf1
47
+ - pdf2
48
+ - ...
49
+ - ...
50
+
51
+ """
52
+ # sample n_samples pdf files from each subfolder
53
+ sub_dirs = [d for d in os.listdir(path_to_folder) if os.path.isdir(os.path.join(path_to_folder, d))]
54
+
55
+ sampled_files = []
56
+
57
+ for sub_dir in sub_dirs:
58
+ pdf_files = glob.glob(os.path.join(path_to_folder, sub_dir, "*.pdf"))
59
+
60
+ if (n_samples == 0) or (len(pdf_files) <= n_samples):
61
+ print(f"Taking all pdf files in {sub_dir}")
62
+ sampled_files.extend(pdf_files)
63
+
64
+ else:
65
+ print(f"Taking {n_samples} pdf files in {sub_dir}")
66
+ sampled_files.extend(random.sample(pdf_files, n_samples))
67
+
68
+ pdf_files = [str(file) for file in sampled_files]
69
+
70
+ # Create an empty text file that will contain the file paths of the corrupted pdf files
71
+ dirpath_corrupted = Path(path_to_folder) / "corrupted_pdf_files.txt"
72
+ dirpath_corrupted.parent.mkdir(parents=True, exist_ok=True)
73
+
74
+ with dirpath_corrupted.open("w") as f:
75
+ with tqdm(total=len(pdf_files)) as pbar:
76
+ for pdf_file in pdf_files:
77
+ pbar.set_description(f"Processing {pdf_file}")
78
+ save_folder = os.path.join("pages_extracted", *Path(pdf_file).parts[-2:])
79
+ if not os.path.exists(os.path.join(path_to_folder, save_folder)):
80
+ try:
81
+ convert_pdf_to_images(pdf_file, os.path.join(path_to_folder, save_folder))
82
+ except Exception as e:
83
+ print(f"Error converting {pdf_file}: {e}")
84
+ f.write(pdf_file)
85
+ f.write("\n")
86
+ pbar.update(1)
87
+ return
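Assuming a local folder laid out as in the docstring above (one sub-folder of PDFs per document category, with pdf2image/poppler available), the conversion entry point is:

from colpali_engine.utils.pdf_utils import convert_all_pdfs_to_images

# Hypothetical layout: ./data_dir/pdfs/<category>/<file>.pdf
# Pages are written to ./data_dir/pdfs/pages_extracted/<category>/<file>.pdf/page_N.jpg,
# and unreadable PDFs are logged to ./data_dir/pdfs/corrupted_pdf_files.txt.
convert_all_pdfs_to_images("./data_dir/pdfs", n_samples=10)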
colpali-main/colpali_engine/utils/plot_utils.py ADDED
@@ -0,0 +1,6 @@
1
+ import seaborn as sns
2
+
3
+
4
+ def setup_seaborn():
5
+ sns.set_style("white")
6
+ sns.set_context("paper", font_scale=2)
colpali-main/colpali_engine/utils/torch_utils.py ADDED
@@ -0,0 +1,18 @@
1
+ """
2
+ Torch utility functions.
3
+ """
4
+
5
+ import torch
6
+
7
+
8
+ def get_torch_device() -> str:
9
+ """
10
+ Returns the device to be used for torch tensors.
11
+ """
12
+ if torch.cuda.is_available():
13
+ device = "cuda:0"
14
+ elif torch.backends.mps.is_available(): # for Apple Silicon
15
+ device = "mps"
16
+ else:
17
+ device = "cpu"
18
+ return device
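Usage is a one-liner; get_torch_device prefers CUDA, then Apple's MPS backend, then CPU:

import torch
from colpali_engine.utils.torch_utils import get_torch_device

device = get_torch_device()          # "cuda:0", "mps" or "cpu"
x = torch.ones(2, 2, device=device)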
colpali-main/colpali_engine/utils/train_colpali_engine_models.py ADDED
@@ -0,0 +1,247 @@
1
+ # HuggingFace trainer
2
+ import json
3
+ import os
4
+ from dataclasses import dataclass
5
+ from typing import Callable, Dict, Optional
6
+
7
+ import torch
8
+ from datasets import concatenate_datasets
9
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
10
+ from torch.utils.data import DataLoader
11
+ from tqdm import tqdm
12
+ from transformers import AutoTokenizer, Idefics2Processor, PreTrainedModel, PreTrainedTokenizer, TrainingArguments
13
+
14
+ from colpali_engine.dataset.custom_collator import CustomCollator
15
+ from colpali_engine.loss.colbert_loss import BiEncoderLoss, BiPairwiseCELoss, ColbertLoss, ColbertPairwiseCELoss
16
+ from colpali_engine.trainer.contrastive_trainer import ContrastiveTrainer
17
+ from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
18
+ from colpali_engine.utils.gpu_stats import print_gpu_utilization, print_summary
19
+
20
+
21
+ @dataclass
22
+ class ColModelTrainingConfig:
23
+ model: PreTrainedModel
24
+ tr_args: TrainingArguments = None
25
+ output_dir: str = None
26
+ max_length: int = 256
27
+ run_eval: bool = True
28
+ run_train: bool = True
29
+ peft_config: Optional[LoraConfig] = None
30
+ add_suffix: bool = False
31
+ processor: Idefics2Processor = None
32
+ tokenizer: PreTrainedTokenizer = None
33
+ loss_func: Optional[Callable] = ColbertLoss()
34
+ dataset_loading_func: Optional[Callable] = None
35
+ eval_dataset_loader: Optional[Dict[str, Callable]] = None
36
+ pretrained_peft_model_name_or_path: Optional[str] = None
37
+
38
+ def __post_init__(self):
39
+ if self.output_dir is None:
40
+ sanitized_name = str(self.model.name_or_path).replace("/", "_")
41
+ self.output_dir = f"./models/{sanitized_name}"
42
+
43
+ if self.tr_args is None:
44
+ self.tr_args = TrainingArguments(output_dir=self.output_dir)
45
+ elif self.tr_args.output_dir is None:
46
+ self.tr_args.output_dir = self.output_dir
47
+
48
+ # cast if string
49
+ if isinstance(self.tr_args.learning_rate, str):
50
+ self.tr_args.learning_rate = float(self.tr_args.learning_rate)
51
+ self.tr_args.remove_unused_columns = False
52
+
53
+ if self.processor is None and self.tokenizer is None:
54
+ print("Using textual model tokenization")
55
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model.name_or_path)
56
+
57
+ if self.pretrained_peft_model_name_or_path is not None:
58
+ self.model.load_adapter(self.pretrained_peft_model_name_or_path)
59
+ print(f"Loaded pretrained adapter from {self.pretrained_peft_model_name_or_path}")
60
+
61
+ if self.peft_config is not None:
62
+ print("Configurating PEFT model")
63
+ if self.processor is None:
64
+ # Might be deprecated - use the "else" branch
65
+ self.model = prepare_model_for_kbit_training(self.model) # use_gradient_checkpointing=True
66
+ # self.model.enable_input_require_grads()
67
+ self.model = get_peft_model(self.model, self.peft_config)
68
+ self.model.print_trainable_parameters()
69
+ else:
70
+ # Ugly debugging hack
71
+ # if self.model.model.config.text_config.vocab_size == 32000:
72
+ # print("DEBUG: Resizing token embeddings - This should not happen in a real scenario!")
73
+ # self.model.model.text_model.resize_token_embeddings(32003)
74
+ # self.model.model.vision_model.encoder.layers = self.model.model.vision_model.encoder.layers[0:2]
75
+ # self.model.enable_input_require_grads()
76
+ if self.pretrained_peft_model_name_or_path is None:
77
+ self.model.add_adapter(self.peft_config)
78
+ self.model.enable_adapters()
79
+ else:
80
+ print(f"Adapter already loaded from {self.pretrained_peft_model_name_or_path}. Not overwriting.")
81
+
82
+ print_gpu_utilization()
83
+
84
+
85
+ class ColModelTraining:
86
+ def __init__(self, config: ColModelTrainingConfig) -> None:
87
+ self.config = config
88
+ self.model = self.config.model
89
+ self.dataset = self.config.dataset_loading_func()
90
+ self.collator = CustomCollator(
91
+ processor=self.config.processor, tokenizer=self.config.tokenizer, max_length=self.config.max_length
92
+ )
93
+ self.current_git_hash = os.popen("git rev-parse HEAD").read().strip()
94
+ self.retriever_evaluator = CustomEvaluator(
95
+ is_multi_vector=(
96
+ isinstance(self.config.loss_func, ColbertLoss)
97
+ or isinstance(self.config.loss_func, ColbertPairwiseCELoss)
98
+ )
99
+ )
100
+
101
+ def train(self) -> None:
102
+
103
+ trainer = ContrastiveTrainer(
104
+ model=self.model,
105
+ train_dataset=self.dataset["train"],
106
+ eval_dataset=self.dataset["test"],
107
+ args=self.config.tr_args,
108
+ data_collator=self.collator,
109
+ loss_func=self.config.loss_func,
110
+ is_vision_model=self.config.processor is not None,
111
+ )
112
+ trainer.args.remove_unused_columns = False
113
+
114
+ result = trainer.train()
115
+ print_summary(result)
116
+
117
+ def eval_dataset(self, test_dataset):
118
+
119
+ self.model.eval()
120
+
121
+ # # debug
122
+ # if len(test_dataset) > 200:
123
+ # test_dataset = test_dataset.select(range(0, 100))
124
+
125
+ idx_with_query = [idx for idx, sample in enumerate(test_dataset["query"]) if sample is not None]
126
+ idx_without_query = [idx for idx, sample in enumerate(test_dataset["query"]) if sample is None]
127
+
128
+ dataloader_with_query = DataLoader(
129
+ test_dataset.select(idx_with_query),
130
+ batch_size=self.config.tr_args.per_device_eval_batch_size,
131
+ shuffle=False,
132
+ collate_fn=self.collator,
133
+ )
134
+ dataloader_without_query = DataLoader(
135
+ test_dataset.select(idx_without_query),
136
+ batch_size=self.config.tr_args.per_device_eval_batch_size,
137
+ shuffle=False,
138
+ collate_fn=self.collator,
139
+ )
140
+
141
+ # dataset is ordered so that non-null queries come first
142
+ test_dataset = concatenate_datasets(
143
+ [test_dataset.select(idx_with_query), test_dataset.select(idx_without_query)]
144
+ )
145
+
146
+ relevant_docs = {}
147
+ docidx_2_docid = {}
148
+ qsidx_2_query = []
149
+ for idx, sample in enumerate(test_dataset):
150
+ doc_id = sample["image_filename"] if "image_filename" in sample else str(hash(sample["doc"]))
151
+ # query_id = sample["query_id"] if "query_id" in sample else str(hash(sample["query"]))
152
+ if sample["query"] is not None:
153
+ relevant_docs[str(idx)] = {doc_id: 1}
154
+ qsidx_2_query.append(str(idx))
155
+ docidx_2_docid[str(idx)] = doc_id
156
+
157
+ qs = []
158
+ ps = []
159
+
160
+ device = self.model.device
161
+ with (torch.no_grad()):
162
+ for dataloader in [dataloader_with_query, dataloader_without_query]:
163
+ for batch in tqdm(dataloader):
164
+ if "doc_pixel_values" not in batch:
165
+ doc = self.model(
166
+ input_ids=batch["doc_input_ids"].to(device),
167
+ attention_mask=batch["doc_attention_mask"].to(device),
168
+ )
169
+
170
+ else:
171
+ if "doc_pixel_attention_mask" in batch:
172
+ doc = self.model(
173
+ input_ids=batch["doc_input_ids"].to(device),
174
+ attention_mask=batch["doc_attention_mask"].to(device),
175
+ pixel_values=batch["doc_pixel_values"].to(device),
176
+ pixel_attention_mask=batch["doc_pixel_attention_mask"].to(device),
177
+ )
178
+ else:
179
+ doc = self.model(
180
+ input_ids=batch["doc_input_ids"].to(device),
181
+ attention_mask=batch["doc_attention_mask"].to(device),
182
+ pixel_values=batch["doc_pixel_values"].to(device),
183
+ )
184
+
185
+ ps.extend(list(torch.unbind(doc.to("cpu"))))
186
+
187
+ if "query_input_ids" in batch:
188
+ query = self.model(
189
+ input_ids=batch["query_input_ids"].to(device),
190
+ attention_mask=batch["query_attention_mask"].to(device),
191
+ )
192
+ # variable len
193
+ qs.extend(list(torch.unbind(query.to("cpu"))))
194
+
195
+ print("Embeddings computed, evaluating")
196
+ scores = self.retriever_evaluator.evaluate(qs, ps)
197
+ # scores is 2d array of shape (n_queries, n_docs)
198
+ # turn it into a dict
199
+ results = {}
200
+ assert scores.shape[0] == len(qsidx_2_query)
201
+ for idx, scores_per_query in enumerate(scores):
202
+ results[qsidx_2_query[idx]] = {
203
+ docidx_2_docid[str(docidx)]: float(score) for docidx, score in enumerate(scores_per_query)
204
+ }
205
+
206
+ # evaluate
207
+ metrics = self.retriever_evaluator.compute_metrics(relevant_docs, results)
208
+ print(metrics)
209
+ return metrics
210
+
211
+ def eval(self) -> None:
212
+
213
+ print("Evaluating on validation set")
214
+ metrics = self.eval_dataset(self.dataset["test"])
215
+ print(f"Metrics for validation set: {metrics}")
216
+ all_metrics = {"validation_set": metrics}
217
+
218
+ if self.config.eval_dataset_loader is not None:
219
+ for test_name, test_dataset_loading_func in self.config.eval_dataset_loader.items():
220
+ print(f"Evaluating {test_name}")
221
+ test_ds = test_dataset_loading_func()
222
+ metrics = self.eval_dataset(test_ds)
223
+ all_metrics[test_name] = metrics
224
+ print(f"Metrics for {test_name}: {metrics}")
225
+
226
+ # checkpoint dumps
227
+ with open(f"{self.config.output_dir}/results.json", "w") as f:
228
+ json.dump(all_metrics, f)
229
+
230
+ # save results as json
231
+ with open(f"{self.config.output_dir}/results.json", "w") as f:
232
+ json.dump(all_metrics, f)
233
+
234
+ def save(self, config_file):
235
+ # save model
236
+ self.model.save_pretrained(self.config.output_dir)
237
+ if self.config.tokenizer is not None:
238
+ self.config.tokenizer.save_pretrained(self.config.output_dir)
239
+ if self.config.processor is not None:
240
+ self.config.processor.save_pretrained(self.config.output_dir) # save config
241
+
242
+ # copy-paste the yml file with os
243
+ os.system(f"cp {config_file} {self.config.output_dir}/training_config.yml")
244
+
245
+ # save git hash of the commit at beginning of training
246
+ with open(f"{self.config.output_dir}/git_hash.txt", "w") as f:
247
+ f.write(self.current_git_hash)
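For orientation, a minimal sketch of how the class above is typically driven end to end. The config object and the `config_file` path are illustrative assumptions, since the config-loading code is not part of this excerpt:

```python
# Illustrative driver only: assumes a ColModelTrainingConfig has already been
# instantiated (e.g. parsed from a YAML file) with model, processor, tokenizer,
# dataset_loading_func, loss_func and tr_args populated.
from colpali_engine.utils.train_colpali_engine_models import ColModelTraining


def run_training(config, config_file="training_config.yml"):
    app = ColModelTraining(config)  # builds collator and evaluator, records git hash
    app.train()                     # contrastive training via ContrastiveTrainer
    app.eval()                      # validation + optional extra eval sets -> results.json
    app.save(config_file)           # weights, tokenizer/processor, config copy, git hash
```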
colpali-main/colpali_engine/utils/wrapper.py ADDED
@@ -0,0 +1,83 @@
+ import importlib.util
+
+ from colpali_engine.models.clip_baselines import ColSigLIP, SigLIP
+ from colpali_engine.models.colbert_architectures import (
+     BiBERT,
+     BiXLMRoBERTa,
+     ColBERT,
+     ColCamembert,
+     ColLlama,
+     ColXLMRoBERTa,
+ )
+ from colpali_engine.models.idefics_colbert_architecture import BiIdefics, ColIdefics
+ from colpali_engine.models.paligemma_colbert_architecture import (
+     BiNewSiglip,
+     BiPaliLast,
+     BiPaliMean,
+     ColNewSiglip,
+     ColPali,
+ )
+
+ if importlib.util.find_spec("transformers") is not None:
+     from transformers import AutoProcessor, AutoTokenizer
+     from transformers.tokenization_utils import PreTrainedTokenizer
+
+     class AutoProcessorWrapper:
+         def __new__(cls, *args, **kwargs):
+             return AutoProcessor.from_pretrained(*args, **kwargs)
+
+     class AutoTokenizerWrapper(PreTrainedTokenizer):
+         def __new__(cls, *args, **kwargs):
+             return AutoTokenizer.from_pretrained(*args, **kwargs)
+
+     class AutoColModelWrapper:
+         def __new__(cls, *args, **kwargs):
+             pretrained_model_name_or_path = None
+             if args:
+                 pretrained_model_name_or_path = args[0]
+             elif kwargs:
+                 pretrained_model_name_or_path = kwargs["pretrained_model_name_or_path"]
+
+             training_objective = kwargs.pop("training_objective", "colbertv1")
+
+             if "camembert" in pretrained_model_name_or_path:
+                 return ColCamembert.from_pretrained(*args, **kwargs)
+             elif "xlm-roberta" in pretrained_model_name_or_path:
+                 if training_objective == "biencoder":
+                     return BiXLMRoBERTa.from_pretrained(*args, **kwargs)
+                 return ColXLMRoBERTa.from_pretrained(*args, **kwargs)
+             elif (
+                 "llama" in pretrained_model_name_or_path.lower() or "croissant" in pretrained_model_name_or_path.lower()
+             ):
+                 return ColLlama.from_pretrained(*args, **kwargs)
+             elif "idefics2" in pretrained_model_name_or_path:
+                 if training_objective == "biencoder":
+                     return BiIdefics.from_pretrained(*args, **kwargs)
+                 return ColIdefics.from_pretrained(*args, **kwargs)
+             elif "siglip" in pretrained_model_name_or_path:
+                 if training_objective == "biencoder_mean":
+                     return SigLIP.from_pretrained(*args, **kwargs)
+                 elif training_objective == "colbertv1":
+                     return ColSigLIP.from_pretrained(*args, **kwargs)
+                 else:
+                     raise ValueError(f"Training objective {training_objective} not recognized")
+             elif "paligemma" in pretrained_model_name_or_path:
+                 if training_objective == "biencoder_mean":
+                     return BiPaliMean.from_pretrained(*args, **kwargs)
+                 elif training_objective == "biencoder_last":
+                     return BiPaliLast.from_pretrained(*args, **kwargs)
+                 elif training_objective == "biencoder_mean_vision":
+                     return BiNewSiglip.from_pretrained(*args, **kwargs)
+                 elif training_objective == "colbertv1_vision":
+                     return ColNewSiglip.from_pretrained(*args, **kwargs)
+                 elif training_objective == "colbertv1":
+                     return ColPali.from_pretrained(*args, **kwargs)
+                 else:
+                     raise ValueError(f"Training objective {training_objective} not recognized")
+             else:
+                 if training_objective == "biencoder":
+                     return BiBERT.from_pretrained(*args, **kwargs)
+                 return ColBERT.from_pretrained(*args, **kwargs)
+
+ else:
+     raise ModuleNotFoundError("Transformers must be loaded")
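For orientation, a hedged sketch of how this dispatcher is typically used; the checkpoint name and dtype below are illustrative assumptions, not values defined in this file:

```python
import torch

from colpali_engine.utils.wrapper import AutoColModelWrapper, AutoProcessorWrapper

# Hypothetical checkpoint name, used only to illustrate the dispatch logic:
# "paligemma" in the name routes to ColPali under the default "colbertv1" objective.
model = AutoColModelWrapper(
    "google/paligemma-3b-mix-448",
    training_objective="colbertv1",
    torch_dtype=torch.bfloat16,
)
processor = AutoProcessorWrapper("google/paligemma-3b-mix-448")
```

Passing an unknown `training_objective` for the siglip or paligemma families raises a `ValueError`, and importing the module without `transformers` installed raises `ModuleNotFoundError`.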
colpali-main/demo/README.md ADDED
@@ -0,0 +1,6 @@
+ ---
+ title: cvquest-colpali
+ app_file: app.py
+ sdk: gradio
+ sdk_version: 4.39.0
+ ---
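The front matter above only declares the Hugging Face Space metadata; the actual `demo/app.py` shipped in this commit is not reproduced in this excerpt. A minimal placeholder consistent with the `sdk: gradio` setting might look like the following (the `search` stub is purely hypothetical):

```python
import gradio as gr


def search(query: str) -> str:
    # Hypothetical stub: a real app would embed the query with ColPali and
    # rank pre-computed page embeddings instead of echoing the input.
    return f"Results for: {query}"


demo = gr.Interface(fn=search, inputs="text", outputs="text", title="cvquest-colpali")

if __name__ == "__main__":
    demo.launch()
```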