Spaces: (Sleeping)

Commit 9ff79dc • Upload 88 files
Committed by HUANG-Stephanie
Parent(s): c4d37d5

This view is limited to 50 files because it contains too many changes. See raw diff.
- colpali-main/.env.dist +5 -0
- colpali-main/.gitattributes +31 -0
- colpali-main/.gitignore +179 -0
- colpali-main/.python-version +1 -0
- colpali-main/LICENSE +21 -0
- colpali-main/README.md +222 -0
- colpali-main/colpali_engine/__init__.py +0 -0
- colpali-main/colpali_engine/__pycache__/__init__.cpython-310.pyc +0 -0
- colpali-main/colpali_engine/dataset/__init__.py +0 -0
- colpali-main/colpali_engine/dataset/custom_collator.py +244 -0
- colpali-main/colpali_engine/dataset/hf_dataset_names.py +52 -0
- colpali-main/colpali_engine/evaluation/__init__.py +1 -0
- colpali-main/colpali_engine/evaluation/eval_manager.py +178 -0
- colpali-main/colpali_engine/interpretability/__init__.py +4 -0
- colpali-main/colpali_engine/interpretability/gen_interpretability_plots.py +113 -0
- colpali-main/colpali_engine/interpretability/plot_utils.py +131 -0
- colpali-main/colpali_engine/interpretability/processor.py +116 -0
- colpali-main/colpali_engine/interpretability/torch_utils.py +60 -0
- colpali-main/colpali_engine/interpretability/vit_configs.py +23 -0
- colpali-main/colpali_engine/loss/__init__.py +1 -0
- colpali-main/colpali_engine/loss/colbert_loss.py +122 -0
- colpali-main/colpali_engine/models/__init__.py +0 -0
- colpali-main/colpali_engine/models/__pycache__/__init__.cpython-310.pyc +0 -0
- colpali-main/colpali_engine/models/__pycache__/paligemma_colbert_architecture.cpython-310.pyc +0 -0
- colpali-main/colpali_engine/models/clip_baselines.py +144 -0
- colpali-main/colpali_engine/models/colbert_architectures.py +177 -0
- colpali-main/colpali_engine/models/idefics_colbert_architecture.py +57 -0
- colpali-main/colpali_engine/models/paligemma_colbert_architecture.py +191 -0
- colpali-main/colpali_engine/trainer/__init__.py +0 -0
- colpali-main/colpali_engine/trainer/__pycache__/__init__.cpython-310.pyc +0 -0
- colpali-main/colpali_engine/trainer/__pycache__/retrieval_evaluator.cpython-310.pyc +0 -0
- colpali-main/colpali_engine/trainer/contrastive_trainer.py +64 -0
- colpali-main/colpali_engine/trainer/retrieval_evaluator.py +72 -0
- colpali-main/colpali_engine/utils/__init__.py +0 -0
- colpali-main/colpali_engine/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- colpali-main/colpali_engine/utils/__pycache__/colpali_processing_utils.cpython-310.pyc +0 -0
- colpali-main/colpali_engine/utils/__pycache__/image_from_page_utils.cpython-310.pyc +0 -0
- colpali-main/colpali_engine/utils/colidefics_processing_utils.py +53 -0
- colpali-main/colpali_engine/utils/colpali_processing_utils.py +36 -0
- colpali-main/colpali_engine/utils/dataset_transformation.py +158 -0
- colpali-main/colpali_engine/utils/gpu_stats.py +24 -0
- colpali-main/colpali_engine/utils/image_from_page_utils.py +21 -0
- colpali-main/colpali_engine/utils/image_utils.py +64 -0
- colpali-main/colpali_engine/utils/iter_utils.py +42 -0
- colpali-main/colpali_engine/utils/pdf_utils.py +87 -0
- colpali-main/colpali_engine/utils/plot_utils.py +6 -0
- colpali-main/colpali_engine/utils/torch_utils.py +18 -0
- colpali-main/colpali_engine/utils/train_colpali_engine_models.py +247 -0
- colpali-main/colpali_engine/utils/wrapper.py +83 -0
- colpali-main/demo/README.md +6 -0

colpali-main/.env.dist
ADDED
@@ -0,0 +1,5 @@
HF_TOKEN=
HF_DATASETS_CACHE=

VERTEX_PROJECT=
VERTEX_LOCATION=
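
These variables are read from the environment at runtime. As a minimal sketch only (assuming the `python-dotenv` package, which is not confirmed by this diff), a local `.env` copied from `.env.dist` could be loaded like this:

```python
# Hypothetical sketch (not part of this commit): loading the variables with python-dotenv.
import os

from dotenv import load_dotenv

load_dotenv()  # reads a local .env file (copied from .env.dist) into the process environment

hf_token = os.environ.get("HF_TOKEN")                 # Hugging Face Hub authentication token
datasets_cache = os.environ.get("HF_DATASETS_CACHE")  # optional custom cache directory for `datasets`
```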

colpali-main/.gitattributes
ADDED
@@ -0,0 +1,31 @@
*.jsonl filter=lfs diff=lfs merge=lfs -text
*.csv filter=lfs diff=lfs merge=lfs -text
*.ipynb filter=lfs diff=lfs merge=lfs -text

*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bin.* filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zstandard filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text

colpali-main/.gitignore
ADDED
@@ -0,0 +1,179 @@
# Custom
.DS_Store
.env
.litellm_cache/
data/litellm_cache_captionning/
.idea
.venv/
colbert/models/
logs/
data/downloaded_datasets/rimes_raw_dataset/
models/
!colpali_engine/models
data/
!*/configs/data/
data_dir/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints
notebooks/*.png

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

colpali-main/.python-version
ADDED
@@ -0,0 +1 @@
3.11.6

colpali-main/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Manuel Faysse, Hugues Sibille, Tony Wu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

colpali-main/README.md
ADDED
@@ -0,0 +1,222 @@
# ColPali: Efficient Document Retrieval with Vision Language Models

[[Blog]](https://huggingface.co/blog/manu/colpali)
[[Paper]](https://arxiv.org/abs/2407.01449)
[[ColPali Model card]](https://huggingface.co/vidore/colpali)
[[ViDoRe Benchmark]](https://huggingface.co/vidore)
<!---[[Colab example]]()-->
[[HuggingFace Demo]](https://huggingface.co/spaces/manu/ColPali-demo)

## Associated Paper

**ColPali: Efficient Document Retrieval with Vision Language Models**
Manuel Faysse, Hugues Sibille, Tony Wu, Bilel Omrani, Gautier Viaud, Céline Hudelot, Pierre Colombo

This repository contains the code for training custom ColBERT retriever models.
Notably, we train ColBERT with LLMs (decoders) as well as image-language models!

## Installation

### From git
```bash
pip install git+https://github.com/illuin-tech/colpali
```

### From source
```bash
git clone https://github.com/illuin-tech/colpali
cd colpali
pip install -r requirements.txt
```

## Usage

Example usage of the model is shown in the `scripts` directory.

```bash
# hackable example script to adapt
python scripts/infer/run_inference_with_python.py
```

```python
import torch
import typer
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor
from PIL import Image

from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
from colpali_engine.utils.image_from_page_utils import load_from_dataset


def main() -> None:
    """Example script to run inference with ColPali"""
    # Load model
    model_name = "vidore/colpali"
    model = ColPali.from_pretrained("google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, device_map="cuda").eval()
    model.load_adapter(model_name)
    processor = AutoProcessor.from_pretrained(model_name)

    # select images -> load_from_pdf(<pdf_path>), load_from_image_urls(["<url_1>"]), load_from_dataset(<path>)
    images = load_from_dataset("vidore/docvqa_test_subsampled")
    queries = ["From which university does James V. Fiorca come ?", "Who is the japanese prime minister?"]

    # run inference - docs
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: process_images(processor, x),
    )
    ds = []
    for batch_doc in tqdm(dataloader):
        with torch.no_grad():
            batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))

    # run inference - queries
    dataloader = DataLoader(
        queries,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: process_queries(processor, x, Image.new("RGB", (448, 448), (255, 255, 255))),
    )

    qs = []
    for batch_query in dataloader:
        with torch.no_grad():
            batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
            embeddings_query = model(**batch_query)
        qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

    # run evaluation
    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)
    print(scores.argmax(axis=1))


if __name__ == "__main__":
    typer.run(main)
```

Details are also given in the model card for the base ColPali model on HuggingFace: [ColPali Model card](https://huggingface.co/vidore/colpali).

## Training

```bash
USE_LOCAL_DATASET=0 python scripts/train/train_colbert.py scripts/configs/siglip/train_siglip_model_debug.yaml
```

or

```bash
accelerate launch scripts/train/train_colbert.py scripts/configs/train_colidefics_model.yaml
```

### Configurations

All training arguments can be set through a configuration file.
The configuration file is a yaml file that contains all the arguments for training.

The construction is as follows:

```python
@dataclass
class ColModelTrainingConfig:
    model: PreTrainedModel
    tr_args: TrainingArguments = None
    output_dir: str = None
    max_length: int = 256
    run_eval: bool = True
    run_train: bool = True
    peft_config: Optional[LoraConfig] = None
    add_suffix: bool = False
    processor: Idefics2Processor = None
    tokenizer: PreTrainedTokenizer = None
    loss_func: Optional[Callable] = ColbertLoss()
    dataset_loading_func: Optional[Callable] = None
    eval_dataset_loader: Optional[Dict[str, Callable]] = None
    pretrained_peft_model_name_or_path: Optional[str] = None
```

### Example

An example configuration file is:

```yaml
config:
  (): colpali_engine.utils.train_colpali_engine_models.ColModelTrainingConfig
  output_dir: !path ../../../models/without_tabfquad/train_colpali-3b-mix-448
  processor:
    () : colpali_engine.utils.wrapper.AutoProcessorWrapper
    pretrained_model_name_or_path: "./models/paligemma-3b-mix-448"
    max_length: 50
  model:
    (): colpali_engine.utils.wrapper.AutoColModelWrapper
    pretrained_model_name_or_path: "./models/paligemma-3b-mix-448"
    training_objective: "colbertv1"
    # attn_implementation: "eager"
    torch_dtype: !ext torch.bfloat16
    # device_map: "auto"
    # quantization_config:
    #   (): transformers.BitsAndBytesConfig
    #   load_in_4bit: true
    #   bnb_4bit_quant_type: "nf4"
    #   bnb_4bit_compute_dtype: "bfloat16"
    #   bnb_4bit_use_double_quant: true

  dataset_loading_func: !ext colpali_engine.utils.dataset_transformation.load_train_set
  eval_dataset_loader: !import ../data/test_data.yaml

  max_length: 50
  run_eval: true
  add_suffix: true
  loss_func:
    (): colpali_engine.loss.colbert_loss.ColbertPairwiseCELoss
  tr_args: !import ../tr_args/default_tr_args.yaml
  peft_config:
    (): peft.LoraConfig
    r: 32
    lora_alpha: 32
    lora_dropout: 0.1
    init_lora_weights: "gaussian"
    bias: "none"
    task_type: "FEATURE_EXTRACTION"
    target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
    # target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(custom_text_proj).*$)'
```

#### Local training

```bash
USE_LOCAL_DATASET=0 python scripts/train/train_colbert.py scripts/configs/siglip/train_siglip_model_debug.yaml
```

#### SLURM

```bash
sbatch --nodes=1 --cpus-per-task=16 --mem-per-cpu=32GB --time=20:00:00 --gres=gpu:1 -p gpua100 --job-name=colidefics --output=colidefics.out --error=colidefics.err --wrap="accelerate launch scripts/train/train_colbert.py scripts/configs/train_colidefics_model.yaml"

sbatch --nodes=1 --time=5:00:00 -A cad15443 --gres=gpu:8 --constraint=MI250 --job-name=colpali --wrap="python scripts/train/train_colbert.py scripts/configs/train_colpali_model.yaml"
```

## CITATION

```bibtex
@misc{faysse2024colpaliefficientdocumentretrieval,
      title={ColPali: Efficient Document Retrieval with Vision Language Models},
      author={Manuel Faysse and Hugues Sibille and Tony Wu and Bilel Omrani and Gautier Viaud and Céline Hudelot and Pierre Colombo},
      year={2024},
      eprint={2407.01449},
      archivePrefix={arXiv},
      primaryClass={cs.IR},
      url={https://arxiv.org/abs/2407.01449},
}
```

colpali-main/colpali_engine/__init__.py
ADDED
File without changes

colpali-main/colpali_engine/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (163 Bytes)

colpali-main/colpali_engine/dataset/__init__.py
ADDED
File without changes

colpali-main/colpali_engine/dataset/custom_collator.py
ADDED
@@ -0,0 +1,244 @@
from transformers import PreTrainedTokenizer, ProcessorMixin


class CustomCollator:
    def __init__(
        self,
        processor: ProcessorMixin = None,
        tokenizer: PreTrainedTokenizer = None,
        max_length: int = 2048,
        add_suffix: bool = False,
    ):
        self.processor = processor
        self.tokenizer = tokenizer
        self.image_token_id = None
        self.max_length = max_length
        self.suffix = ""
        if add_suffix:
            self.suffix = "\n" * 10

        if tokenizer is None and processor is None:
            raise ValueError("Either processor or tokenizer should be provided.")

        if self.processor is not None:
            if self.processor.__class__.__name__ != "SiglipProcessor":
                self.image_token_id = self.processor.tokenizer.additional_special_tokens_ids[
                    self.processor.tokenizer.additional_special_tokens.index("<image>")
                ]

            if self.tokenizer is not None:
                raise ValueError("Only one of processor or tokenizer should be provided.")

        if self.tokenizer and self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(self, examples):
        if self.processor is None:
            return self.forward_text(examples)
        if self.processor.__class__.__name__ == "Idefics2Processor":
            return self.forward_vision_idefics(examples)
        if self.processor.__class__.__name__ == "PaliGemmaProcessor":
            return self.forward_vision_pali(examples)
        if self.processor.__class__.__name__ == "SiglipProcessor":
            return self.forward_vision_siglip(examples)
        raise ValueError("Processor not supported")

    def forward_text(self, examples):
        texts_doc = []
        texts_query = []
        for example in examples:
            text_query = example["query"] + self.suffix
            text_doc = example["doc"]

            texts_doc.append(text_doc.strip())
            texts_query.append(text_query.strip())

        batch_doc = self.tokenizer(
            texts_doc, max_length=self.max_length, padding="longest", truncation=True, return_tensors="pt"
        )
        batch_query = self.tokenizer(
            texts_query, max_length=self.max_length, padding="longest", truncation=True, return_tensors="pt"
        )

        # prefix each key with "doc_" or "query_" to avoid key conflicts
        batch_doc = {f"doc_{k}": v for k, v in batch_doc.items()}
        batch_query = {f"query_{k}": v for k, v in batch_query.items()}
        batch_doc.update(batch_query)

        return batch_doc

    def forward_vision_idefics(self, examples):
        texts_doc = []
        texts_query = []
        images = []
        for example in examples:
            image = example["image"]

            text_query = None
            if example["query"] is not None:
                query = example["query"]
                messages_query = [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": f"Question: {query}<end_of_utterance><end_of_utterance><end_of_utterance><end_of_utterance><end_of_utterance>",
                            },
                        ],
                    },
                ]
                text_query = self.processor.apply_chat_template(messages_query, add_generation_prompt=False).strip()

            messages_doc = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Describe the image."},
                        {"type": "image"},
                    ],
                },
            ]

            text_doc = self.processor.apply_chat_template(messages_doc, add_generation_prompt=False)

            texts_doc.append(text_doc.strip())
            texts_query.append(text_query)
            images.append([image])

        batch_doc = self.processor(
            text=texts_doc, images=images, return_tensors="pt", padding="longest", max_length=self.max_length
        )

        batch_query = None
        if all([t is None for t in texts_query]):
            print("All queries are None. Returning None for all queries.")
        elif any([t is None for t in texts_query]):
            raise ValueError("Some queries are None. This collator does not support None queries yet.")
        else:
            batch_query = self.processor(
                text=texts_query, return_tensors="pt", padding="longest", max_length=self.max_length
            )

        # prefix each key with "doc_" or "query_" to avoid key conflicts
        batch_doc = {f"doc_{k}": v for k, v in batch_doc.items()}

        if batch_query is not None:
            batch_query = {f"query_{k}": v for k, v in batch_query.items()}
            batch_doc.update(batch_query)

        return batch_doc

    def forward_vision_pali(self, examples):
        texts_doc = []
        texts_query = []
        images = []
        for example in examples:

            if example["image"] is None:
                raise ValueError("Image is None - This collator does not support None images yet.")

            image = example["image"].convert("RGB")
            images.append(image)
            texts_doc.append("Describe the image.")

            if example["query"] is None:
                texts_query.append(None)
            else:
                query = example["query"]
                query = f"Question: {query}<unused0><unused0><unused0><unused0><unused0>"
                texts_query.append(query)

        batch_doc = self.processor(
            text=texts_doc,
            images=images,
            return_tensors="pt",
            padding="longest",
            max_length=self.max_length + self.processor.image_seq_length,
        )

        batch_query = None
        # check if some but not all queries are None
        if all([t is None for t in texts_query]):
            print("All queries are None. Returning None for all queries.")
        elif any([t is None for t in texts_query]):
            raise ValueError("Some queries are None. This collator does not support None queries yet.")
        else:
            batch_query = self.processor(
                images=images,  # NOTE: the image is not used in batch_query but it is required for calling the processor
                text=texts_query,
                return_tensors="pt",
                padding="longest",
                max_length=self.max_length + self.processor.image_seq_length,
            )
            del batch_query["pixel_values"]
            batch_query["input_ids"] = batch_query["input_ids"][..., self.processor.image_seq_length :]
            batch_query["attention_mask"] = batch_query["attention_mask"][..., self.processor.image_seq_length :]

        # prefix each key with "doc_" or "query_" to avoid key conflicts
        batch_doc = {f"doc_{k}": v for k, v in batch_doc.items()}

        if batch_query is not None:
            batch_query = {f"query_{k}": v for k, v in batch_query.items()}
            batch_doc.update(batch_query)

        return batch_doc

    def forward_vision_siglip(self, examples):
        texts_doc = []
        texts_query = []
        images = []
        for example in examples:

            if example["image"] is None:
                raise ValueError("Image is None - This collator does not support None images yet.")

            image = example["image"].convert("RGB")
            images.append(image)
            texts_doc.append("Describe the image.")

            if example["query"] is None:
                texts_query.append(None)
            else:
                query = f"Question: {example['query']}"
                texts_query.append(query)

        batch_doc = self.processor(
            text=texts_doc,
            images=images,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
        )

        batch_query = None
        # check if some but not all queries are None
        if all([t is None for t in texts_query]):
            # print("All queries are None.")
            pass
        elif any([t is None for t in texts_query]):
            raise ValueError("Some queries are None. This collator does not support None queries yet.")
        else:
            batch_query = self.processor(
                images=images,
                text=texts_query,
                return_tensors="pt",
                padding="max_length",
                max_length=self.max_length,
                truncation=True,
            )
            del batch_query["pixel_values"]

        # prefix each key with "doc_" or "query_" to avoid key conflicts
        batch_doc = {f"doc_{k}": v for k, v in batch_doc.items()}

        if batch_query is not None:
            batch_query = {f"query_{k}": v for k, v in batch_query.items()}
            batch_doc.update(batch_query)
            # add attention mask for queries
            batch_doc["query_attention_mask"] = batch_doc["query_input_ids"].ne(0).long()

        # add attention mask for docs
        batch_doc["doc_attention_mask"] = batch_doc["doc_input_ids"].ne(0).long()

        return batch_doc
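
For orientation, here is a minimal usage sketch of the collator above, not part of this commit: the checkpoint name and the toy in-memory dataset are assumptions, and the repo's real entry point for training is `scripts/train/train_colbert.py`.

```python
# Hypothetical sketch: wiring CustomCollator into a PyTorch DataLoader.
from PIL import Image
from torch.utils.data import DataLoader
from transformers import AutoProcessor

from colpali_engine.dataset.custom_collator import CustomCollator

processor = AutoProcessor.from_pretrained("google/paligemma-3b-mix-448")  # assumed checkpoint
collator = CustomCollator(processor=processor, max_length=50)

# Toy dataset standing in for a real training split; forward_vision_pali expects
# examples carrying "image" and "query" keys.
train_dataset = [
    {"image": Image.new("RGB", (448, 448), "white"), "query": "What is the total revenue in 2023?"},
    {"image": Image.new("RGB", (448, 448), "white"), "query": "Who signed the agreement?"},
]

dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=collator)
for batch in dataloader:
    print(sorted(batch.keys()))  # "doc_*" and "query_*" tensors ready for the model forward pass
```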

colpali-main/colpali_engine/dataset/hf_dataset_names.py
ADDED
@@ -0,0 +1,52 @@
from enum import Enum


class TrainDatasets(Enum):
    """
    Dataset names for the training datasets used in HuggingFace Datasets.
    """

    government_reports = "vidore/syntheticDocQA_government_reports_train"
    healthcare_industry = "vidore/syntheticDocQA_healthcare_industry_train"
    energy = "vidore/syntheticDocQA_energy_train"
    artificial_intelligence = "vidore/syntheticDocQA_artificial_intelligence_train"
    arxivqa = "vidore/arxivqa_train"
    docvqa = "vidore/docvqa_train"
    infovqa = "vidore/infovqa_train"
    tatqa = "vidore/tatqa_train"

    @staticmethod
    def get_synthetic_datasets():
        return [
            TrainDatasets.government_reports,
            TrainDatasets.healthcare_industry,
            TrainDatasets.energy,
            TrainDatasets.artificial_intelligence,
        ]


class TestImagesDirpath(Enum):
    """
    Dataset names for the test datasets used in HuggingFace Datasets.
    """

    government_reports = "data/government_reports"
    healthcare_industry = "data/healthcare_industry"
    energy = "data/energy"
    artificial_intelligence = "data/scrapped_pdfs_split/pages_extracted/artificial_intelligence_test"
    arxivqa = "data/arxivqa"
    docvqa = "data/docvqa"
    infovqa = "data/infovqa"
    tatqa = "data/tatqa"


class CaptionedSyntheticDatasets(Enum):
    """
    Dataset names for the captioned synthetic datasets used in HuggingFace Datasets.
    """

    shift = "vidore/baseline_cap_shiftproject_test"


class SyntheticDocQATest(Enum):
    shift = "vidore/shiftproject_test"
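
A small illustrative sketch (not part of this commit, assuming the Hugging Face `datasets` library that the repo depends on) of how these enum values map to actual dataset repos:

```python
# Hypothetical sketch: resolving the synthetic training splits by enum value.
from datasets import load_dataset

from colpali_engine.dataset.hf_dataset_names import TrainDatasets

for ds_name in TrainDatasets.get_synthetic_datasets():
    # Each enum member's value is a Hugging Face dataset repo id.
    ds = load_dataset(ds_name.value, split="train")
    print(ds_name.name, len(ds))
```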

colpali-main/colpali_engine/evaluation/__init__.py
ADDED
@@ -0,0 +1 @@
from .eval_manager import EvalManager

colpali-main/colpali_engine/evaluation/eval_manager.py
ADDED
@@ -0,0 +1,178 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, ClassVar, Dict, Optional

import pandas as pd


class EvalManager:
    """
    Stores evaluation results for various datasets and metrics.

    The data is stored in a pandas DataFrame with a MultiIndex for columns.
    The first level of the MultiIndex is the dataset name and the second level is the metric name.

    Usage:
    >>> evaluator = Evaluator.from_dirpath("data/evaluation_results/")
    >>> print(evaluator.data)

    """

    model_col: ClassVar[str] = "model"
    dataset_col: ClassVar[str] = "dataset"
    metric_col: ClassVar[str] = "metric"

    def __init__(self, data: Optional[pd.DataFrame] = None):
        if data is None:
            data = pd.DataFrame()
        self._df = data
        self._df.index = self._df.index.rename(EvalManager.model_col)

    def __str__(self) -> str:
        return self.data.__str__()

    @staticmethod
    def from_dict(data: Dict[Any, Any]) -> EvalManager:
        """
        Load evaluation results from a dictionary.

        Expected format:
        {
            "model1": pd.read_json(path1).T.stack(),
            "model2": pd.read_json(path2).T.stack(),
        }

        """
        df = pd.DataFrame.from_dict(data, orient="index")
        return EvalManager(df)

    @staticmethod
    def from_json(path: str | Path) -> EvalManager:
        datapath = Path(path)
        if not datapath.is_file():
            raise FileNotFoundError(f"{path} is not a file")
        data = {}
        data[datapath.stem] = pd.read_json(datapath).T.stack()  # pylint: disable=no-member
        return EvalManager.from_dict(data)

    @staticmethod
    def from_dir(datadir: str | Path) -> EvalManager:
        datadir_ = Path(datadir)
        if not datadir_.is_dir():
            raise FileNotFoundError(f"{datadir} is not a directory")

        eval_files = list(datadir_.glob("*.json"))

        data = {}

        for filepath in eval_files:
            data[filepath.stem] = pd.read_json(filepath).T.stack()  # pylint: disable=no-member

        return EvalManager.from_dict(data)

    @staticmethod
    def from_csv(path: str | Path) -> EvalManager:
        """
        Load evaluation results from a CSV file.
        """
        try:
            df = pd.read_csv(path, index_col=0, header=[0, 1])
            return EvalManager(df)
        except Exception as e:
            print(f"Error loading {path}: {e}")
            raise e

    @property
    def data(self) -> pd.DataFrame:
        """
        Returns the evaluation results as a pandas DataFrame.
        """
        return self._df.copy()

    @property
    def models(self) -> pd.Index:
        """
        Returns the models for which there are evaluation results.
        """
        return self.data.index

    @property
    def datasets(self) -> pd.Index:
        """
        Returns the datasets for which there are evaluation results.
        """
        return self.data.columns.get_level_values(0).unique()

    @property
    def metrics(self) -> pd.Index:
        """
        Returns the metrics for which there are evaluation results.
        """
        return self.data.columns.get_level_values(1)

    @staticmethod
    def melt(df: pd.DataFrame) -> pd.DataFrame:
        """
        Melt a suitable DataFrame (e.g. returned by `get_df_for_dataset` and
        `get_df_for_metric`) into a 'long' format.
        """
        return df.T.reset_index(names=[EvalManager.dataset_col, EvalManager.metric_col]).melt(
            id_vars=[EvalManager.dataset_col, EvalManager.metric_col],
            var_name=EvalManager.model_col,
            value_name="score",
        )

    @property
    def melted(self) -> pd.DataFrame:
        """
        Returns the evaluation results as a 'melted' DataFrame.
        Useful for plotting with seaborn.
        """
        return EvalManager.melt(self.data)

    def get_df_for_model(self, model: str) -> pd.DataFrame:
        if model not in self.data.index:
            raise ValueError(f"Model {model} not found in the evaluation results")
        return self.data.loc[[model], :]  # type: ignore

    def get_df_for_dataset(self, dataset: str) -> pd.DataFrame:
        if dataset not in self.datasets:
            raise ValueError(f"Dataset {dataset} not found in the evaluation results")
        return self.data.loc[:, (dataset, slice(None))]  # type: ignore

    def get_df_for_metric(self, metric: str) -> pd.DataFrame:
        if metric not in self.metrics:
            raise ValueError(f"Metric {metric} not found in the evaluation results")
        return self.data.loc[:, (slice(None), metric)]  # type: ignore

    def sort_by_dataset(self, ascending: bool = True) -> EvalManager:
        """
        Sort the evaluation results by dataset name.
        """
        df = self.data.T.sort_index(level=0, ascending=ascending).T
        return EvalManager(df)

    def sort_by_metric(self, ascending: bool = True) -> EvalManager:
        """
        Sort the evaluation results by metric name.
        """
        df = self.data.T.sort_index(level=1, ascending=ascending).T
        return EvalManager(df)

    def sort_columns(self, ascending: bool = True) -> EvalManager:
        """
        Sort the evaluation results by dataset name and then by metric name.
        """
        df = self.data.T.sort_index(level=[0, 1], ascending=ascending).T
        return EvalManager(df)

    def to_csv(self, path: str | Path):
        """
        Save the evaluation results to a CSV file.

        Using `Evaluation.from_csv(path_to_saved_csv)` will load the evaluation results back into memory.
        """
        savepath = Path(path)
        savepath.parent.mkdir(parents=True, exist_ok=True)
        self.data.to_csv(savepath)
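
A brief hypothetical usage sketch of `EvalManager` (not part of this commit; the directory and output paths are placeholders mirroring the class docstring):

```python
# Hypothetical sketch: aggregating per-model JSON results and exporting a summary.
from colpali_engine.evaluation import EvalManager

manager = EvalManager.from_dir("data/evaluation_results/")  # one JSON file per model
print(manager.models)    # DataFrame index: model names
print(manager.datasets)  # first column level: dataset names

# Long-format table, convenient for plotting with seaborn.
long_df = manager.melted

# Round-trips through CSV via EvalManager.from_csv.
manager.sort_columns().to_csv("data/eval_summary.csv")
```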

colpali-main/colpali_engine/interpretability/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .plot_utils import *
from .processor import *
from .torch_utils import *
from .vit_configs import *

colpali-main/colpali_engine/interpretability/gen_interpretability_plots.py
ADDED
@@ -0,0 +1,113 @@
import pprint
from dataclasses import asdict, dataclass
from pathlib import Path
from uuid import uuid4

import matplotlib.pyplot as plt
import torch
from einops import rearrange
from PIL import Image
from tqdm import trange

from colpali_engine.interpretability.plot_utils import plot_patches
from colpali_engine.interpretability.processor import ColPaliProcessor
from colpali_engine.interpretability.torch_utils import normalize_attention_map_per_query_token
from colpali_engine.interpretability.vit_configs import VIT_CONFIG
from colpali_engine.models.paligemma_colbert_architecture import ColPali

OUTDIR_INTERPRETABILITY = Path("outputs/interpretability")


@dataclass
class InterpretabilityInput:
    query: str
    image: Image.Image
    start_idx_token: int
    end_idx_token: int


def generate_interpretability_plots(
    model: ColPali,
    processor: ColPaliProcessor,
    query: str,
    image: Image.Image,
    savedir: str | Path | None = None,
    add_special_prompt_to_doc: bool = True,
) -> None:

    # Sanity checks
    if len(model.active_adapters()) != 1:
        raise ValueError("The model must have exactly one active adapter.")

    if model.config.name_or_path not in VIT_CONFIG:
        raise ValueError("The model must be referred to in the VIT_CONFIG dictionary.")
    vit_config = VIT_CONFIG[model.config.name_or_path]

    # Handle savepath
    if not savedir:
        savedir = OUTDIR_INTERPRETABILITY / str(uuid4())
        print(f"No savepath provided. Results will be saved to: `{savedir}`.")
    elif isinstance(savedir, str):
        savedir = Path(savedir)
    savedir.mkdir(parents=True, exist_ok=True)

    # Resize the image to square
    input_image_square = image.resize((vit_config.resolution, vit_config.resolution))

    # Preprocess the inputs
    input_text_processed = processor.process_text(query).to(model.device)
    input_image_processed = processor.process_image(image, add_special_prompt=add_special_prompt_to_doc).to(
        model.device
    )

    # Forward pass
    with torch.no_grad():
        output_text = model.forward(**asdict(input_text_processed))  # (1, n_text_tokens, hidden_dim)

    # NOTE: `output_image` will have shape:
    # (1, n_patch_x * n_patch_y, hidden_dim) if `add_special_prompt_to_doc` is False
    # (1, n_patch_x * n_patch_y + n_special_tokens, hidden_dim) if `add_special_prompt_to_doc` is True
    with torch.no_grad():
        output_image = model.forward(**asdict(input_image_processed))

    if add_special_prompt_to_doc:  # remove the special tokens
        output_image = output_image[
            :, : processor.processor.image_seq_length, :
        ]  # (1, n_patch_x * n_patch_y, hidden_dim)

    output_image = rearrange(
        output_image, "b (h w) c -> b h w c", h=vit_config.n_patch_per_dim, w=vit_config.n_patch_per_dim
    )  # (1, n_patch_x, n_patch_y, hidden_dim)

    # Get the unnormalized attention map
    attention_map = torch.einsum(
        "bnk,bijk->bnij", output_text, output_image
    )  # (1, n_text_tokens, n_patch_x, n_patch_y)
    attention_map_normalized = normalize_attention_map_per_query_token(
        attention_map
    )  # (1, n_text_tokens, n_patch_x, n_patch_y)
    attention_map_normalized = attention_map_normalized.float()

    # Get text token information
    n_tokens = input_text_processed.input_ids.size(1)
    text_tokens = processor.tokenizer.tokenize(processor.decode(input_text_processed.input_ids[0]))
    print("Text tokens:")
    pprint.pprint(text_tokens)
    print("\n")

    for token_idx in trange(1, n_tokens - 1, desc="Iterating over tokens..."):  # exclude the <bos> and the "\n" tokens
        fig, axis = plot_patches(
            input_image_square,
            vit_config.patch_size,
            vit_config.resolution,
            patch_opacities=attention_map_normalized[0, token_idx, :, :],
            style="dark_background",
        )

        fig.suptitle(f"Token #{token_idx}: `{text_tokens[token_idx]}`", color="white", fontsize=14)
        savepath = savedir / f"token_{token_idx}.png"
        fig.savefig(savepath)
        print(f"Saved attention map for token {token_idx} (`{text_tokens[token_idx]}`) to `{savepath}`.\n")
        plt.close(fig)

    return
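
A hedged sketch of how this function might be invoked (not part of this commit): the checkpoint names, query, and image path are placeholders, the adapter-loading pattern mirrors the README inference example, and whether the base checkpoint is registered in `VIT_CONFIG` depends on `vit_configs.py`, which is outside this view.

```python
# Hypothetical sketch: generating per-token similarity maps for one query/page pair.
import torch
from PIL import Image

from colpali_engine.interpretability.gen_interpretability_plots import generate_interpretability_plots
from colpali_engine.interpretability.processor import ColPaliProcessor
from colpali_engine.models.paligemma_colbert_architecture import ColPali

model = ColPali.from_pretrained("google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, device_map="cuda").eval()
model.load_adapter("vidore/colpali")  # same pattern as the README inference example
processor = ColPaliProcessor.from_pretrained("vidore/colpali")

generate_interpretability_plots(
    model=model,
    processor=processor,
    query="Which hour of the day has the highest traffic?",
    image=Image.open("page.png"),  # placeholder document page
    savedir="outputs/interpretability/example",
)
```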

colpali-main/colpali_engine/interpretability/plot_utils.py
ADDED
@@ -0,0 +1,131 @@
from typing import Any, Dict, Optional, Tuple, cast

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import seaborn as sns
import torch
from PIL import Image

MAX_OPACITY = 255


def plot_patches(
    img: Image.Image,
    patch_size: int,
    image_resolution: int,
    patch_opacities: Optional[npt.NDArray | torch.Tensor] = None,
    figsize: Tuple[int, int] = (8, 8),
    style: Dict[str, Any] | str | None = None,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot patches of a square image.
    Set `style` to "dark_background" if your image has a light background.
    """

    # Get the number of patches
    if image_resolution % patch_size != 0:
        raise ValueError("The image resolution must be divisible by the patch size.")
    num_patches = image_resolution // patch_size

    # Default style
    if style is None:
        style = {}

    # Sanity checks
    if patch_opacities is not None:
        if isinstance(patch_opacities, torch.Tensor):
            patch_opacities = cast(npt.NDArray, patch_opacities.cpu().numpy())
        if patch_opacities.shape != (num_patches, num_patches):
            raise ValueError("The shape of the patch_opacities tensor is not correct.")
        if not np.all((0 <= patch_opacities) & (patch_opacities <= 1)):
            raise ValueError("The patch_opacities tensor must have values between 0 and 1.")

    # If the image is not square, raise an error
    if img.size[0] != img.size[1]:
        raise ValueError("The image must be square.")

    # Get the image as a numpy array
    img_array = np.array(img.convert("RGBA"))  # (H, W, C) where the last channel is the alpha channel

    # Create a figure
    with plt.style.context(style):
        fig, axis = plt.subplots(num_patches, num_patches, figsize=figsize)

        # Plot the patches
        for i in range(num_patches):
            for j in range(num_patches):
                patch = img_array[i * patch_size : (i + 1) * patch_size, j * patch_size : (j + 1) * patch_size, :]
                # Set the opacity of the patch
                if patch_opacities is not None:
                    patch[:, :, -1] = round(patch_opacities[i, j] * MAX_OPACITY)
                axis[i, j].imshow(patch)
                axis[i, j].axis("off")

        fig.subplots_adjust(wspace=0.1, hspace=0.1)

        fig.tight_layout()

    return fig, axis


def plot_attention_heatmap(
    img: Image.Image,
    patch_size: int,
    image_resolution: int,
    attention_map: npt.NDArray | torch.Tensor,
    figsize: Tuple[int, int] = (8, 8),
    style: Dict[str, Any] | str | None = None,
    show_colorbar: bool = False,
    show_axes: bool = False,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot a heatmap of the attention map over the image.
    The image must be square and `attention_map` must be normalized between 0 and 1.
    """

    # Get the number of patches
    if image_resolution % patch_size != 0:
        raise ValueError("The image resolution must be divisible by the patch size.")
    num_patches = image_resolution // patch_size

    # Default style
    if style is None:
        style = {}

    # Sanity checks
    if isinstance(attention_map, torch.Tensor):
        attention_map = cast(npt.NDArray, attention_map.cpu().numpy())
    if attention_map.shape != (num_patches, num_patches):
        raise ValueError("The shape of the patch_opacities tensor is not correct.")
    if not np.all((0 <= attention_map) & (attention_map <= 1)):
        raise ValueError("The patch_opacities tensor must have values between 0 and 1.")

    # If the image is not square, raise an error
    if img.size[0] != img.size[1]:
        raise ValueError("The image must be square.")

    # Get the image as a numpy array
    img_array = np.array(img.convert("RGBA"))  # (H, W, C) where the last channel is the alpha channel

    # Get the attention map as a numpy array
    attention_map_image = Image.fromarray((attention_map * 255).astype("uint8")).resize(
        img.size, Image.Resampling.BICUBIC
    )

    # Create a figure
    with plt.style.context(style):
        fig, ax = plt.subplots(figsize=figsize)
        ax.imshow(img_array)
        im = ax.imshow(
            attention_map_image,
            cmap=sns.color_palette("mako", as_cmap=True),
            alpha=0.5,
        )
        if show_colorbar:
            fig.colorbar(im)
        if not show_axes:
            ax.set_axis_off()
        fig.tight_layout()

    return fig, ax
colpali-main/colpali_engine/interpretability/processor.py
ADDED
@@ -0,0 +1,116 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List, cast

import torch
from PIL import Image
from transformers import LlamaTokenizerFast, PaliGemmaProcessor


@dataclass
class ColPaliTextInput:
    input_ids: torch.Tensor
    attention_mask: torch.Tensor

    def to(self, device: torch.device) -> ColPaliTextInput:
        return ColPaliTextInput(
            input_ids=self.input_ids.to(device),
            attention_mask=self.attention_mask.to(device),
        )


@dataclass
class ColPaliImageInput:
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    attention_mask: torch.Tensor

    def to(self, device: str | torch.device) -> ColPaliImageInput:
        return ColPaliImageInput(
            input_ids=self.input_ids.to(device),
            pixel_values=self.pixel_values.to(device),
            attention_mask=self.attention_mask.to(device),
        )


class ColPaliProcessor:
    def __init__(self, processor: PaliGemmaProcessor):
        self.processor = processor
        self.tokenizer = cast(LlamaTokenizerFast, self.processor.tokenizer)  # type: ignore

    @staticmethod
    def from_pretrained(model_name: str) -> ColPaliProcessor:
        return ColPaliProcessor(processor=cast(PaliGemmaProcessor, PaliGemmaProcessor.from_pretrained(model_name)))

    def process_text(
        self,
        text: str | List[str],
        padding: str = "longest",
        return_tensors: str = "pt",
        add_special_tokens: bool = True,
    ) -> ColPaliTextInput:
        """
        Process text inputs for the model.
        If `add_special_tokens` is True (default), the text will be prepended with the <bos> token and appended with " \n".
        """
        if add_special_tokens:
            if isinstance(text, str):
                text = self.tokenizer.bos_token + text + "\n"
            elif isinstance(text, list):
                text = [self.tokenizer.bos_token + t + "\n" for t in text]
            else:
                raise ValueError("text must be a string or a list of strings.")

        batch_output = self.tokenizer(
            text, padding=padding, return_tensors=return_tensors, add_special_tokens=add_special_tokens
        )

        return ColPaliTextInput(
            input_ids=cast(torch.Tensor, batch_output["input_ids"]),
            attention_mask=cast(torch.Tensor, batch_output["attention_mask"]),
        )

    def process_image(
        self,
        image: Image.Image | List[Image.Image],
        padding: str = "longest",
        do_convert_rgb: bool = True,
        return_tensors: str = "pt",
        add_special_prompt: bool = True,
    ) -> ColPaliImageInput:
        # NOTE: The special prompt was used at training time,
        special_prompt = "Describe the image." if add_special_prompt else None
        if isinstance(image, Image.Image):
            text_input = [special_prompt]
        elif isinstance(image, list):
            text_input = [special_prompt] * len(image)
        else:
            raise ValueError("image must be a PIL Image or a list of PIL Images.")

        batch_output = self.processor(
            text=text_input,
            images=image,
            padding=padding,
            do_convert_rgb=do_convert_rgb,
            return_tensors=return_tensors,
        )

        if add_special_prompt:
            return ColPaliImageInput(
                input_ids=batch_output["input_ids"],
                pixel_values=batch_output["pixel_values"],
                attention_mask=batch_output["attention_mask"],
            )
        else:
            return ColPaliImageInput(
                input_ids=batch_output["input_ids"][:, : self.processor.image_seq_length],
                pixel_values=batch_output["pixel_values"][:, : self.processor.image_seq_length],
                attention_mask=batch_output["attention_mask"][:, : self.processor.image_seq_length],
            )

    def decode(self, *args, **kwargs):
        return self.tokenizer.decode(*args, **kwargs)

    def batch_decode(self, *args, **kwargs):
        return self.tokenizer.batch_decode(*args, **kwargs)

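A minimal usage sketch for the ColPaliProcessor defined above. The checkpoint name and the local image path are placeholders chosen for illustration, not values taken from this repository; any PaliGemma-style checkpoint that ships this processor configuration should behave the same way.

import torch
from PIL import Image

from colpali_engine.interpretability.processor import ColPaliProcessor

processor = ColPaliProcessor.from_pretrained("vidore/colpali")  # assumed checkpoint name

# Queries become ColPaliTextInput objects (input_ids + attention_mask).
text_input = processor.process_text("Which chart shows the energy mix?").to(torch.device("cpu"))

# Pages become ColPaliImageInput objects (input_ids + pixel_values + attention_mask).
image_input = processor.process_image(Image.open("page.png"))  # assumed local file

print(text_input.input_ids.shape, image_input.pixel_values.shape)
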
colpali-main/colpali_engine/interpretability/torch_utils.py
ADDED
@@ -0,0 +1,60 @@
import logging

import torch

logger = logging.getLogger(__name__)

EPSILON = 1e-10


def normalize_attention_map_per_query_token(x: torch.Tensor) -> torch.Tensor:
    """
    Normalizes the attention map for ColPali for each query token.
    The output tensor will have values in the range [0, 1] and the
    same shape as the input tensor.

    Args:
        x: The attention map tensor of shape (batch_size, n_text_tokens, n_patch_x, n_patch_y).
    """
    if x.ndim != 4:
        raise ValueError("The input tensor must have 4 dimensions.")

    # Compute the minimum values along the last two dimensions (n_patch_x, n_patch_y)
    min_vals = x.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[0]

    # Compute the maximum values along the last two dimensions (n_patch_x, n_patch_y)
    max_vals = x.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0]

    # Normalize the tensor
    x_normalized = (x - min_vals) / (max_vals - min_vals + EPSILON)  # Adding a small epsilon to avoid division by zero

    return x_normalized


def normalize_attention_map_per_query(x: torch.Tensor) -> torch.Tensor:
    """
    Normalizes the attention map for ColPali for each query token.
    The output tensor will have values in the range [0, 1] and the
    same shape as the input tensor.

    Args:
        x: The attention map tensor of shape (batch_size, n_text_tokens, n_patch_x, n_patch_y).
    """
    # Log warning
    logger.warning(
        "This function should not be used for ColPali because it doesn't make sense to normalize the attention map across the text tokens."
    )

    if x.ndim != 4:
        raise ValueError("The input tensor must have 4 dimensions.")

    # Compute the minimum values along the last three dimensions (n_text_tokens, n_patch_x, n_patch_y)
    min_vals = x.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[0].min(dim=-3, keepdim=True)[0]

    # Compute the maximum values along the last three dimensions (n_text_tokens, n_patch_x, n_patch_y)
    max_vals = x.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0].max(dim=-3, keepdim=True)[0]

    # Normalize the tensor
    x_normalized = (x - min_vals) / (max_vals - min_vals + EPSILON)  # Adding a small epsilon to avoid division by zero

    return x_normalized

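A quick sanity check for normalize_attention_map_per_query_token. The 32x32 patch grid is only an illustrative assumption (it matches a 448-pixel ViT with 14-pixel patches); the function itself accepts any (batch, n_text_tokens, n_patch_x, n_patch_y) tensor.

# Random attention map: 1 query of 5 tokens over a 32x32 patch grid.
import torch

from colpali_engine.interpretability.torch_utils import normalize_attention_map_per_query_token

attention_map = torch.rand(1, 5, 32, 32)
normalized = normalize_attention_map_per_query_token(attention_map)

# Each (query token, patch grid) slice is rescaled independently to [0, 1].
print(normalized.shape)                     # torch.Size([1, 5, 32, 32])
print(normalized.amin(dim=(-1, -2)).max())  # ~0 for every token
print(normalized.amax(dim=(-1, -2)).min())  # ~1 for every token
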
colpali-main/colpali_engine/interpretability/vit_configs.py
ADDED
@@ -0,0 +1,23 @@
from dataclasses import dataclass
from typing import Dict


@dataclass
class ViTConfig:
    patch_size: int
    resolution: int

    @property
    def n_patch_per_dim(self) -> int:
        if self.resolution % self.patch_size != 0:
            raise ValueError(f"Resolution {self.resolution} is not divisible by patch size {self.patch_size}")
        return self.resolution // self.patch_size


VIT_CONFIG: Dict[str, ViTConfig] = {
    "google/siglip-so400m-patch14-384": ViTConfig(patch_size=14, resolution=384),
    "timm/ViT-SO400M-14-SigLIP-384": ViTConfig(patch_size=14, resolution=384),
    "google/paligemma-3b-mix-448": ViTConfig(
        patch_size=14, resolution=448
    ),  # based on "timm/ViT-SO400M-14-SigLIP-384" with increased resolution
}

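A short example of reading a patch-grid size out of VIT_CONFIG; the arithmetic simply mirrors the n_patch_per_dim property above.

from colpali_engine.interpretability.vit_configs import VIT_CONFIG

config = VIT_CONFIG["google/paligemma-3b-mix-448"]
n = config.n_patch_per_dim  # 448 // 14 = 32
print(n, n * n)             # 32 patches per side, 1024 patches per page
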
colpali-main/colpali_engine/loss/__init__.py
ADDED
@@ -0,0 +1 @@
from .colbert_loss import ColbertLoss

colpali-main/colpali_engine/loss/colbert_loss.py
ADDED
@@ -0,0 +1,122 @@
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss


class BiEncoderLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ce_loss = CrossEntropyLoss()
        # self.pooling_strategy = pooling_strategy

    def forward(self, query_embeddings, doc_embeddings):
        """
        query_embeddings: (batch_size, dim)
        doc_embeddings: (batch_size, dim)
        """

        scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)

        loss_rowwise = self.ce_loss(scores, torch.arange(scores.shape[0], device=scores.device))
        # loss_columnwise = self.ce_loss(scores.T, torch.arange(scores.shape[1], device=scores.device))
        # loss = (loss_rowwise + loss_columnwise) / 2
        return loss_rowwise


class ColbertLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ce_loss = CrossEntropyLoss()

    def forward(self, query_embeddings, doc_embeddings):
        """
        query_embeddings: (batch_size, num_query_tokens, dim)
        doc_embeddings: (batch_size, num_doc_tokens, dim)
        """

        scores = torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings).max(dim=3)[0].sum(dim=2)

        # scores = torch.zeros((query_embeddings.shape[0], doc_embeddings.shape[0]), device=query_embeddings.device)
        # for i in range(query_embeddings.shape[0]):
        #     for j in range(doc_embeddings.shape[0]):
        #         # step 1 - dot product --> (s1,s2)
        #         q2d_scores = torch.matmul(query_embeddings[i], doc_embeddings[j].T)
        #         # step 2 -> max on doc --> (s1)
        #         q_scores = torch.max(q2d_scores, dim=1)[0]
        #         # step 3 --> sum the max score --> (1)
        #         sum_q_score = torch.sum(q_scores)
        #         # step 4 --> assert is scalar
        #         scores[i, j] = sum_q_score

        # assert (scores_einsum - scores < 0.0001).all().item()

        loss_rowwise = self.ce_loss(scores, torch.arange(scores.shape[0], device=scores.device))
        # TODO: comparing between queries might not make sense since it's a sum over the length of the query
        # loss_columnwise = self.ce_loss(scores.T, torch.arange(scores.shape[1], device=scores.device))
        # loss = (loss_rowwise + loss_columnwise) / 2
        return loss_rowwise


class ColbertPairwiseCELoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ce_loss = CrossEntropyLoss()

    def forward(self, query_embeddings, doc_embeddings):
        """
        query_embeddings: (batch_size, num_query_tokens, dim)
        doc_embeddings: (batch_size, num_doc_tokens, dim)

        Positive scores are the diagonal of the scores matrix.
        """

        # Compute the ColBERT scores
        scores = (
            torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings).max(dim=3)[0].sum(dim=2)
        )  # (batch_size, batch_size)

        # Positive scores are the diagonal of the scores matrix.
        pos_scores = scores.diagonal()  # (batch_size,)

        # Negative score for a given query is the maximum of the scores against all all other pages.
        # NOTE: We exclude the diagonal by setting it to a very low value: since we know the maximum score is 1,
        # we can subtract 1 from the diagonal to exclude it from the maximum operation.
        neg_scores = scores - torch.eye(scores.shape[0], device=scores.device) * 1e6  # (batch_size, batch_size)
        neg_scores = neg_scores.max(dim=1)[0]  # (batch_size,)

        # Compute the loss
        # The loss is computed as the negative log of the softmax of the positive scores
        # relative to the negative scores.
        # This can be simplified to log-sum-exp of negative scores minus the positive score
        # for numerical stability.
        # torch.vstack((pos_scores, neg_scores)).T.softmax(1)[:, 0].log()*(-1)
        loss = F.softplus(neg_scores - pos_scores).mean()

        return loss


class BiPairwiseCELoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ce_loss = CrossEntropyLoss()

    def forward(self, query_embeddings, doc_embeddings):
        """
        query_embeddings: (batch_size, dim)
        doc_embeddings: (batch_size, dim)
        """

        scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)

        pos_scores = scores.diagonal()
        neg_scores = scores - torch.eye(scores.shape[0], device=scores.device) * 1e6
        neg_scores = neg_scores.max(dim=1)[0]

        # Compute the loss
        # The loss is computed as the negative log of the softmax of the positive scores
        # relative to the negative scores.
        # This can be simplified to log-sum-exp of negative scores minus the positive score
        # for numerical stability.
        loss = F.softplus(neg_scores - pos_scores).mean()

        return loss

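A shape-level sketch of ColbertLoss on random embeddings. The token counts (16 query tokens, 1030 page tokens) are assumptions for illustration; the loss only requires the (batch, tokens, dim) layout documented in the forward docstring.

import torch

from colpali_engine.loss.colbert_loss import ColbertLoss

batch_size, n_query_tokens, n_doc_tokens, dim = 4, 16, 1030, 128
query_embeddings = torch.randn(batch_size, n_query_tokens, dim)
doc_embeddings = torch.randn(batch_size, n_doc_tokens, dim)

# The einsum builds a (4, 4) late-interaction score matrix; the diagonal
# entries are the positive query/page pairs for the cross-entropy.
loss = ColbertLoss()(query_embeddings, doc_embeddings)
print(loss)
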
colpali-main/colpali_engine/models/__init__.py
ADDED
File without changes

colpali-main/colpali_engine/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (170 Bytes).

colpali-main/colpali_engine/models/__pycache__/paligemma_colbert_architecture.cpython-310.pyc
ADDED
Binary file (4.87 kB).

colpali-main/colpali_engine/models/clip_baselines.py
ADDED
@@ -0,0 +1,144 @@
import os
from typing import Optional

import torch
from transformers import SiglipModel


class SigLIP(SiglipModel):
    def forward(self, *args, **kwargs):
        """
        Forward pass through Llama and the linear layer for dimensionality reduction

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        return self.forward_branch(*args, **kwargs)

    def forward_branch(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ):

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is not None:
            # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.

            outputs = self.vision_model(
                pixel_values=pixel_values.to(dtype=self.dtype),
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                interpolate_pos_encoding=interpolate_pos_encoding,
            )

        else:
            outputs = self.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        embeds = outputs[1]

        # normalized features
        embeds = embeds / embeds.norm(p=2, dim=-1, keepdim=True)
        return embeds


class ColSigLIP(SiglipModel):
    def __init__(self, config):
        super(ColSigLIP, self).__init__(config=config)
        self.dim = 128
        self.custom_vision_proj = torch.nn.Linear(self.config.vision_config.hidden_size, self.dim)
        self.custom_text_proj = torch.nn.Linear(self.config.text_config.hidden_size, self.dim)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through Llama and the linear layer for dimensionality reduction

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        return self.forward_branch(*args, **kwargs)

    def forward_branch(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ):

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is not None:
            # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.

            outputs = self.vision_model(
                pixel_values=pixel_values.to(dtype=self.dtype),
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                interpolate_pos_encoding=interpolate_pos_encoding,
            )

            last_hidden_states = outputs.last_hidden_state

            proj = self.custom_vision_proj(last_hidden_states)
            # normalize l2 norm
            proj = proj / proj.norm(dim=-1, keepdim=True)

        else:
            outputs = self.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            last_hidden_states = outputs.last_hidden_state

            proj = self.custom_text_proj(last_hidden_states)
            # normalize l2 norm
            proj = proj / proj.norm(dim=-1, keepdim=True)
            proj = proj * attention_mask.unsqueeze(-1)

        # normalized features
        return proj

colpali-main/colpali_engine/models/colbert_architectures.py
ADDED
@@ -0,0 +1,177 @@
from torch import nn
from transformers import (
    BertModel,
    BertPreTrainedModel,
    CamembertModel,
    CamembertPreTrainedModel,
    LlamaModel,
    LlamaPreTrainedModel,
    XLMRobertaModel,
    XLMRobertaPreTrainedModel,
)


class ColCamembert(CamembertPreTrainedModel):
    def __init__(self, config):
        super(ColCamembert, self).__init__(config=config)
        self.roberta: CamembertPreTrainedModel = CamembertModel(config)
        self.dim = 128
        self.linear = nn.Linear(self.roberta.config.hidden_size, self.dim)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through Camenbert and the linear layer for dimensionality reduction

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        outputs = self.roberta(*args, **kwargs)
        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
        proj = self.linear(last_hidden_states)
        # normalize l2 norm
        proj = proj / proj.norm(dim=-1, keepdim=True)
        proj = proj * kwargs["attention_mask"].unsqueeze(-1)
        return proj


class ColXLMRoBERTa(XLMRobertaPreTrainedModel):
    def __init__(self, config):
        super(ColXLMRoBERTa, self).__init__(config=config)
        self.roberta: XLMRobertaPreTrainedModel = XLMRobertaModel(config)
        self.dim = 128
        self.linear = nn.Linear(self.roberta.config.hidden_size, self.dim)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through Roberta and the linear layer for dimensionality reduction

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        outputs = self.roberta(*args, **kwargs)
        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
        proj = self.linear(last_hidden_states)
        # normalize l2 norm
        proj = proj / proj.norm(dim=-1, keepdim=True)
        proj = proj * kwargs["attention_mask"].unsqueeze(-1)
        return proj


class BiXLMRoBERTa(XLMRobertaPreTrainedModel):
    def __init__(self, config):
        super(BiXLMRoBERTa, self).__init__(config=config)
        self.roberta: XLMRobertaPreTrainedModel = XLMRobertaModel(config)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through Roberta and the linear layer for dimensionality reduction

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        outputs = self.roberta(*args, **kwargs)
        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
        # pooling - mean tokens that have attention mask == 1
        proj = last_hidden_states * kwargs["attention_mask"].unsqueeze(-1)
        proj = proj.sum(dim=1) / kwargs["attention_mask"].sum(dim=1, keepdim=True)
        # normalize l2 norm
        proj = proj / proj.norm(dim=-1, keepdim=True)
        return proj


class ColBERT(BertPreTrainedModel):
    def __init__(self, config):
        super(ColBERT, self).__init__(config=config)
        self.bert: BertModel = BertModel(config)
        self.dim = 128
        self.linear = nn.Linear(self.bert.config.hidden_size, self.dim)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through BERT and the linear layer for dimensionality reduction

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        outputs = self.bert(*args, **kwargs)
        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
        proj = self.linear(last_hidden_states)
        # normalize l2 norm
        proj = proj / proj.norm(dim=-1, keepdim=True)
        proj = proj * kwargs["attention_mask"].unsqueeze(-1)
        return proj


class BiBERT(BertPreTrainedModel):
    def __init__(self, config):
        super(BiBERT, self).__init__(config=config)
        self.bert: BertModel = BertModel(config)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through BERT and the linear layer for dimensionality reduction

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        outputs = self.bert(*args, **kwargs)
        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
        # pooling - mean tokens that have attention mask == 1
        proj = last_hidden_states * kwargs["attention_mask"].unsqueeze(-1)
        proj = proj.sum(dim=1) / kwargs["attention_mask"].sum(dim=1, keepdim=True)
        # normalize l2 norm
        proj = proj / proj.norm(dim=-1, keepdim=True)
        return proj


class ColLlama(LlamaPreTrainedModel):
    def __init__(self, config):
        super(ColLlama, self).__init__(config=config)
        self.model: LlamaModel = LlamaModel(config)
        self.dim = 128
        self.linear = nn.Linear(self.model.config.hidden_size, self.dim)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through Llama and the linear layer for dimensionality reduction

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        outputs = self.model(*args, **kwargs)
        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
        proj = self.linear(last_hidden_states)
        # normalize l2 norm
        proj = proj / proj.norm(dim=-1, keepdim=True)
        proj = proj * kwargs["attention_mask"].unsqueeze(-1)
        return proj

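A hedged loading sketch for the text-only ColBERT baseline above. The bert-base-uncased checkpoint is an assumption, not a checkpoint referenced by this repository; when loading a plain BERT checkpoint the 128-d projection layer is newly initialized.

import torch
from transformers import AutoTokenizer

from colpali_engine.models.colbert_architectures import ColBERT

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed base checkpoint
model = ColBERT.from_pretrained("bert-base-uncased")            # linear head is randomly initialized

batch = tokenizer(["a short query", "another query"], padding=True, return_tensors="pt")
with torch.no_grad():
    embeddings = model(**batch)
print(embeddings.shape)  # (2, seq_len, 128): one 128-d vector per token
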
colpali-main/colpali_engine/models/idefics_colbert_architecture.py
ADDED
@@ -0,0 +1,57 @@
from torch import nn
from transformers import Idefics2Model, Idefics2PreTrainedModel


class BiIdefics(Idefics2PreTrainedModel):
    def __init__(self, config):
        super(BiIdefics, self).__init__(config=config)
        self.model: Idefics2Model = Idefics2Model(config)
        self.pooling_strategy = "last"
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through Llama and the linear layer for dimensionality reduction

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        outputs = self.model(*args, **kwargs)
        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
        # pooling - last token
        proj = last_hidden_states[:, -1, :]
        # normalize l2 norm
        proj = proj / proj.norm(dim=-1, keepdim=True)
        return proj


class ColIdefics(Idefics2PreTrainedModel):
    def __init__(self, config):
        super(ColIdefics, self).__init__(config=config)
        self.model: Idefics2Model = Idefics2Model(config)
        self.dim = 128
        self.linear = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through Llama and the linear layer for dimensionality reduction

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        outputs = self.model(*args, **kwargs)
        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
        proj = self.linear(last_hidden_states)
        # normalize l2 norm
        proj = proj / proj.norm(dim=-1, keepdim=True)
        proj = proj * kwargs["attention_mask"].unsqueeze(-1)
        return proj

colpali-main/colpali_engine/models/paligemma_colbert_architecture.py
ADDED
@@ -0,0 +1,191 @@
import torch
from torch import nn
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration, PaliGemmaPreTrainedModel


class BiPaliLast(PaliGemmaPreTrainedModel):
    def __init__(self, config):
        super(BiPaliLast, self).__init__(config=config)
        self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
        self.pooling_strategy = "last"
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through Llama and the linear layer for dimensionality reduction

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        outputs = self.model(*args, output_hidden_states=True, **kwargs)
        last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
        # pooling - last token
        proj = last_hidden_states[:, -1, :]
        # normalize l2 norm
        proj = proj / proj.norm(dim=-1, keepdim=True)
        return proj


class BiPaliMean(PaliGemmaPreTrainedModel):
    def __init__(self, config):
        super(BiPaliMean, self).__init__(config=config)
        self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
        self.pooling_strategy = "mean"
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through Llama and the linear layer for dimensionality reduction

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        outputs = self.model(*args, output_hidden_states=True, **kwargs)
        last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
        # pooling -mean on attention mask==1
        proj = torch.sum(last_hidden_states * kwargs["attention_mask"].unsqueeze(-1), dim=1) / torch.sum(
            kwargs["attention_mask"], dim=1, keepdim=True
        )
        proj = proj / proj.norm(dim=-1, keepdim=True)
        return proj


class ColPali(PaliGemmaPreTrainedModel):
    def __init__(self, config):
        super(ColPali, self).__init__(config=config)
        self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
        self.dim = 128
        self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through Llama and the linear layer for dimensionality reduction

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        outputs = self.model(*args, output_hidden_states=True, **kwargs)
        last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
        proj = self.custom_text_proj(last_hidden_states)
        # normalize l2 norm
        proj = proj / proj.norm(dim=-1, keepdim=True)
        proj = proj * kwargs["attention_mask"].unsqueeze(-1)
        return proj


class ColNewSiglip(PaliGemmaPreTrainedModel):
    def __init__(self, config):
        super(ColNewSiglip, self).__init__(config=config)
        self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
        self.dim = 128
        self.custom_image_proj = nn.Linear(self.model.config.vision_config.projection_dim, self.dim)
        self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through Llama and the linear layer for dimensionality reduction

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        # outputs = self.model(*args, output_hidden_states=True, **kwargs)
        if "pixel_values" in kwargs:
            image_features = self.vision_model_output(*args, **kwargs)
            # print(f"Doc: {image_features.shape}")
            proj = self.custom_image_proj(image_features)
            # print(f"Doc proj: {proj.shape}")
            proj = proj / proj.norm(dim=-1, keepdim=True)
        else:
            outputs = self.model(*args, output_hidden_states=True, **kwargs)
            last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
            # print(f"Query: {last_hidden_states.shape}")
            proj = self.custom_text_proj(last_hidden_states)
            # print(f"Query proj: {proj.shape}")
            # normalize l2 norm
            proj = proj / proj.norm(dim=-1, keepdim=True)
            proj = proj * kwargs["attention_mask"].unsqueeze(-1)
        return proj

    def vision_model_output(self, input_ids: torch.LongTensor = None, pixel_values: torch.FloatTensor = None, **kwargs):

        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        # 2. Merge text and images
        if pixel_values is not None and input_ids.shape[1] != 1:
            image_outputs = self.model.vision_tower(pixel_values.to(inputs_embeds.dtype))
            selected_image_feature = image_outputs.last_hidden_state
            image_features = self.model.multi_modal_projector(selected_image_feature)

            return image_features

        raise ValueError("pixel_values is None or input_ids.shape[1] == 1")


class BiNewSiglip(PaliGemmaPreTrainedModel):
    def __init__(self, config):
        super(BiNewSiglip, self).__init__(config=config)
        self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through Llama and the linear layer for dimensionality reduction

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        # outputs = self.model(*args, output_hidden_states=True, **kwargs)
        if "pixel_values" in kwargs:
            image_features = self.vision_model_output(*args, **kwargs)
            # print(f"Doc: {image_features.shape}")
            # pool image features
            proj = torch.mean(image_features, dim=1)
            # print(f"Doc proj: {proj.shape}")
            norm = proj.norm(dim=-1, keepdim=True)
            proj = proj / norm
        else:
            outputs = self.model(*args, output_hidden_states=True, **kwargs)
            last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
            # pooling -mean on attention mask==1

            proj = torch.sum(last_hidden_states * kwargs["attention_mask"].unsqueeze(-1), dim=1) / torch.sum(
                kwargs["attention_mask"], dim=1, keepdim=True
            )
            # print(f"Query proj: {proj.shape}")
            norm = proj.norm(dim=-1, keepdim=True)
            proj = proj / norm
        return proj

    def vision_model_output(self, input_ids: torch.LongTensor = None, pixel_values: torch.FloatTensor = None, **kwargs):

        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        # 2. Merge text and images
        if pixel_values is not None and input_ids.shape[1] != 1:
            image_outputs = self.model.vision_tower(pixel_values.to(inputs_embeds.dtype))
            selected_image_feature = image_outputs.last_hidden_state
            image_features = self.model.multi_modal_projector(selected_image_feature)

            return image_features

        raise ValueError("pixel_values is None or input_ids.shape[1] == 1")

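A retrieval-time sketch of how ColPali-style token embeddings are scored: each query/page pair is reduced with the same max-over-page-tokens, sum-over-query-tokens einsum used in the losses earlier in this commit. Random tensors stand in for real model outputs, and the token counts are illustrative assumptions.

import torch

n_queries, n_pages = 2, 8
query_embeddings = torch.randn(n_queries, 16, 128)  # (queries, query tokens, dim)
page_embeddings = torch.randn(n_pages, 1030, 128)   # (pages, page tokens, dim)

# MaxSim: for every query token take the best-matching page token, then sum over query tokens.
scores = torch.einsum("bnd,csd->bcns", query_embeddings, page_embeddings).max(dim=3)[0].sum(dim=2)
best_pages = scores.argmax(dim=1)
print(scores.shape, best_pages)  # (2, 8) score matrix, best page index per query
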
colpali-main/colpali_engine/trainer/__init__.py
ADDED
File without changes

colpali-main/colpali_engine/trainer/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (171 Bytes).

colpali-main/colpali_engine/trainer/__pycache__/retrieval_evaluator.cpython-310.pyc
ADDED
Binary file (3.18 kB).

colpali-main/colpali_engine/trainer/contrastive_trainer.py
ADDED
@@ -0,0 +1,64 @@
import torch
from transformers import Trainer


class ContrastiveTrainer(Trainer):
    def __init__(self, loss_func, is_vision_model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_func = loss_func
        self.is_vision_model = is_vision_model

    def compute_loss(self, model, inputs, return_outputs=False):
        query_outputs = model(input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"])
        if self.is_vision_model:
            if "doc_pixel_attention_mask" not in inputs:
                doc_outputs = model(
                    input_ids=inputs["doc_input_ids"],
                    attention_mask=inputs["doc_attention_mask"],
                    pixel_values=inputs["doc_pixel_values"],
                )
            else:
                doc_outputs = model(
                    input_ids=inputs["doc_input_ids"],
                    attention_mask=inputs["doc_attention_mask"],
                    pixel_values=inputs["doc_pixel_values"],
                    pixel_attention_mask=inputs["doc_pixel_attention_mask"],
                )
        else:
            doc_outputs = model(input_ids=inputs["doc_input_ids"], attention_mask=inputs["doc_attention_mask"])

        loss = self.loss_func(query_outputs, doc_outputs)
        return (loss, (query_outputs, doc_outputs)) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=True):
        """This function is used to generate predictions and return the loss for the given inputs."""
        if not prediction_loss_only:
            raise ValueError("prediction_step is only called with prediction_loss_only=True")

        with torch.no_grad():
            if self.is_vision_model:
                if "doc_pixel_attention_mask" not in inputs:
                    doc_outputs = model(
                        input_ids=inputs["doc_input_ids"],
                        attention_mask=inputs["doc_attention_mask"],
                        pixel_values=inputs["doc_pixel_values"],
                    )
                else:
                    doc_outputs = model(
                        input_ids=inputs["doc_input_ids"],
                        attention_mask=inputs["doc_attention_mask"],
                        pixel_values=inputs["doc_pixel_values"],
                        pixel_attention_mask=inputs["doc_pixel_attention_mask"],
                    )
                query_outputs = model(
                    input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"]
                )
            else:

                query_outputs = model(
                    input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"]
                )
                doc_outputs = model(input_ids=inputs["doc_input_ids"], attention_mask=inputs["doc_attention_mask"])

            loss = self.loss_func(query_outputs, doc_outputs)
            return loss, None, None

colpali-main/colpali_engine/trainer/retrieval_evaluator.py
ADDED
@@ -0,0 +1,72 @@
import torch
from mteb.evaluation.evaluators import RetrievalEvaluator


class CustomEvaluator:
    def __init__(self, is_multi_vector=False):
        self.is_multi_vector = is_multi_vector
        self.mteb_evaluator = RetrievalEvaluator()

    def evaluate(self, qs, ps):
        if self.is_multi_vector:
            scores = self.evaluate_colbert(qs, ps)
        else:
            scores = self.evaluate_biencoder(qs, ps)

        assert scores.shape[0] == len(qs)

        arg_score = scores.argmax(dim=1)
        # compare to arange
        accuracy = (arg_score == torch.arange(scores.shape[0], device=scores.device)).sum().item() / scores.shape[0]
        print(arg_score)
        print(f"Top 1 Accuracy (verif): {accuracy}")

        # cast to numpy
        # scores = scores.cpu().numpy()
        scores = scores.to(torch.float32).cpu().numpy()
        return scores

    def compute_metrics(self, relevant_docs, results, **kwargs):
        # wrap mteb package

        ndcg, _map, recall, precision, naucs = self.mteb_evaluator.evaluate(
            relevant_docs,
            results,
            self.mteb_evaluator.k_values,
            ignore_identical_ids=kwargs.get("ignore_identical_ids", True),
        )
        mrr = self.mteb_evaluator.evaluate_custom(relevant_docs, results, self.mteb_evaluator.k_values, "mrr")
        scores = {
            **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
            **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
            **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()},
            **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()},
            **{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr[0].items()},
            **{f"naucs_at_{k.split('@')[1]}": v for (k, v) in naucs.items()},
        }
        return scores

    def evaluate_colbert(self, qs, ps, batch_size=128) -> torch.Tensor:
        scores = []
        for i in range(0, len(qs), batch_size):
            scores_batch = []
            qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to(
                "cpu"
            )
            for j in range(0, len(ps), batch_size):
                ps_batch = torch.nn.utils.rnn.pad_sequence(
                    ps[j : j + batch_size], batch_first=True, padding_value=0
                ).to("cpu")
                scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2))
            scores_batch = torch.cat(scores_batch, dim=1).cpu()
            scores.append(scores_batch)
        scores = torch.cat(scores, dim=0)
        return scores

    def evaluate_biencoder(self, qs, ps) -> torch.Tensor:

        qs = torch.stack(qs)
        ps = torch.stack(ps)

        scores = torch.einsum("bd,cd->bc", qs, ps)
        return scores

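A minimal run of CustomEvaluator in multi-vector mode with random stand-in embeddings. It requires the mteb package; the list lengths and token counts below are arbitrary assumptions, and variable-length sequences are handled by the padding inside evaluate_colbert.

import torch

from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator

qs = [torch.randn(16, 128) for _ in range(4)]    # 4 queries, 16 tokens each
ps = [torch.randn(1030, 128) for _ in range(4)]  # 4 pages, 1030 tokens each

evaluator = CustomEvaluator(is_multi_vector=True)
scores = evaluator.evaluate(qs, ps)  # prints the argmax per query and top-1 accuracy
print(scores.shape)                  # (4, 4) numpy array of query-page scores
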
colpali-main/colpali_engine/utils/__init__.py
ADDED
File without changes

colpali-main/colpali_engine/utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (169 Bytes).

colpali-main/colpali_engine/utils/__pycache__/colpali_processing_utils.cpython-310.pyc
ADDED
Binary file (1.2 kB).

colpali-main/colpali_engine/utils/__pycache__/image_from_page_utils.cpython-310.pyc
ADDED
Binary file (998 Bytes).

colpali-main/colpali_engine/utils/colidefics_processing_utils.py
ADDED
@@ -0,0 +1,53 @@
# Utils for processing images and queries for ColPaLi

def process_images(processor, images, max_length: int = 50):
    texts_doc = []
    images = [image.convert("RGB") for image in images]

    for _ in images:
        messages_doc = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe the image."},
                    {"type": "image"},
                ],
            },
        ]

        text_doc = processor.apply_chat_template(messages_doc, add_generation_prompt=False)
        texts_doc.append(text_doc.strip())

    batch_doc = processor(
        text=texts_doc,
        images=images,
        return_tensors="pt",
        padding="longest",
    )
    return batch_doc


def process_queries(processor, queries, mock_image, max_length: int = 50):
    texts_query = []
    for query in queries:
        messages_query = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": f"Question: {query}<end_of_utterance><end_of_utterance><end_of_utterance><end_of_utterance><end_of_utterance>",
                    },
                ],
            },
        ]
        text_query = processor.apply_chat_template(messages_query, add_generation_prompt=False).strip()
        texts_query.append(text_query)

    batch_query = processor(
        text=texts_query,
        return_tensors="pt",
        padding="longest",
        max_length=max_length,
    )
    return batch_query

colpali-main/colpali_engine/utils/colpali_processing_utils.py
ADDED
@@ -0,0 +1,36 @@
# Utils for processing images and queries for ColPaLi


def process_images(processor, images, max_length: int = 50):
    texts_doc = ["Describe the image."] * len(images)
    images = [image.convert("RGB") for image in images]

    batch_doc = processor(
        text=texts_doc,
        images=images,
        return_tensors="pt",
        padding="longest",
        max_length=max_length + processor.image_seq_length,
    )
    return batch_doc


def process_queries(processor, queries, mock_image, max_length: int = 50):
    texts_query = []
    for query in queries:
        query = f"Question: {query}<unused0><unused0><unused0><unused0><unused0>"
        texts_query.append(query)

    batch_query = processor(
        images=[mock_image.convert("RGB")] * len(texts_query),
        # NOTE: the image is not used in batch_query but it is required for calling the processor
        text=texts_query,
        return_tensors="pt",
        padding="longest",
        max_length=max_length + processor.image_seq_length,
    )
    del batch_query["pixel_values"]

    batch_query["input_ids"] = batch_query["input_ids"][..., processor.image_seq_length :]
    batch_query["attention_mask"] = batch_query["attention_mask"][..., processor.image_seq_length :]
    return batch_query

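A hedged usage sketch for the two helpers above. The processor checkpoint and the plain white mock image are assumptions for illustration; any PaliGemma processor exposing image_seq_length should behave the same way.

from PIL import Image
from transformers import AutoProcessor

from colpali_engine.utils.colpali_processing_utils import process_images, process_queries

processor = AutoProcessor.from_pretrained("google/paligemma-3b-mix-448")  # assumed checkpoint
mock_image = Image.new("RGB", (448, 448), (255, 255, 255))                # placeholder page

batch_doc = process_images(processor, [mock_image])                        # pixel_values + prompt tokens
batch_query = process_queries(processor, ["What is shown?"], mock_image)   # text tokens only

print(batch_doc["input_ids"].shape, batch_query["input_ids"].shape)
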
colpali-main/colpali_engine/utils/dataset_transformation.py
ADDED
@@ -0,0 +1,158 @@
import os

from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset

USE_LOCAL_DATASET = os.environ.get("USE_LOCAL_DATASET", "1") == "1"


def add_metadata_column(dataset, column_name, value):
    def add_source(example):
        example[column_name] = value
        return example

    return dataset.map(add_source)


def load_train_set() -> DatasetDict:

    ds_paths = [
        "infovqa_train",
        "docvqa_train",
        "arxivqa_train",
        "tatdqa_train",
        "syntheticDocQA_government_reports_train",
        "syntheticDocQA_healthcare_industry_train",
        "syntheticDocQA_artificial_intelligence_train",
        "syntheticDocQA_energy_train",
    ]
    base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
    ds_tot = []
    for path in ds_paths:
        cpath = base_path + path
        ds = load_dataset(cpath, split="train")
        if "arxivqa" in path:
            # subsample 10k
            ds = ds.shuffle(42).select(range(10000))
        ds_tot.append(ds)

    dataset = concatenate_datasets(ds_tot)
    dataset = dataset.shuffle(seed=42)
    # split into train and test
    dataset_eval = dataset.select(range(500))
    dataset = dataset.select(range(500, len(dataset)))
    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
    return ds_dict


def load_train_set_with_tabfquad() -> DatasetDict:

    ds_paths = [
        "infovqa_train",
        "docvqa_train",
        "arxivqa_train",
        "tatdqa_train",
        "tabfquad_train_subsampled",
        "syntheticDocQA_government_reports_train",
        "syntheticDocQA_healthcare_industry_train",
        "syntheticDocQA_artificial_intelligence_train",
        "syntheticDocQA_energy_train",
    ]
    base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
    ds_tot = []
    for path in ds_paths:
        cpath = base_path + path
        ds = load_dataset(cpath, split="train")
        if "arxivqa" in path:
            # subsample 10k
            ds = ds.shuffle(42).select(range(10000))
        ds_tot.append(ds)

    dataset = concatenate_datasets(ds_tot)
    dataset = dataset.shuffle(seed=42)
    # split into train and test
    dataset_eval = dataset.select(range(500))
    dataset = dataset.select(range(500, len(dataset)))
    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
    return ds_dict


def load_train_set_with_docmatix() -> DatasetDict:

    ds_paths = [
        "infovqa_train",
        "docvqa_train",
        "arxivqa_train",
        "tatdqa_train",
        "tabfquad_train_subsampled",
        "syntheticDocQA_government_reports_train",
        "syntheticDocQA_healthcare_industry_train",
        "syntheticDocQA_artificial_intelligence_train",
        "syntheticDocQA_energy_train",
        "Docmatix_filtered_train",
    ]
    base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
    ds_tot = []
    for path in ds_paths:
        cpath = base_path + path
        ds = load_dataset(cpath, split="train")
        if "arxivqa" in path:
            # subsample 10k
            ds = ds.shuffle(42).select(range(10000))
        ds_tot.append(ds)

    dataset = concatenate_datasets(ds_tot)
    dataset = dataset.shuffle(seed=42)
    # split into train and test
    dataset_eval = dataset.select(range(500))
    dataset = dataset.select(range(500, len(dataset)))
    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
    return ds_dict


def load_docvqa_dataset() -> DatasetDict:
    if USE_LOCAL_DATASET:
        dataset_doc = load_dataset("./data_dir/DocVQA", "DocVQA", split="validation")
        dataset_doc_eval = load_dataset("./data_dir/DocVQA", "DocVQA", split="test")
        dataset_info = load_dataset("./data_dir/DocVQA", "InfographicVQA", split="validation")
        dataset_info_eval = load_dataset("./data_dir/DocVQA", "InfographicVQA", split="test")
    else:
        dataset_doc = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation")
        dataset_doc_eval = load_dataset("lmms-lab/DocVQA", "DocVQA", split="test")
        dataset_info = load_dataset("lmms-lab/DocVQA", "InfographicVQA", split="validation")
        dataset_info_eval = load_dataset("lmms-lab/DocVQA", "InfographicVQA", split="test")

    # concatenate the two datasets
    dataset = concatenate_datasets([dataset_doc, dataset_info])
    dataset_eval = concatenate_datasets([dataset_doc_eval, dataset_info_eval])
    # sample 100 from eval dataset
    dataset_eval = dataset_eval.shuffle(seed=42).select(range(200))

    # rename question as query
    dataset = dataset.rename_column("question", "query")
    dataset_eval = dataset_eval.rename_column("question", "query")

    # create new column image_filename that corresponds to ucsf_document_id if not None, else image_url
    dataset = dataset.map(
        lambda x: {"image_filename": x["ucsf_document_id"] if x["ucsf_document_id"] is not None else x["image_url"]}
    )
    dataset_eval = dataset_eval.map(
        lambda x: {"image_filename": x["ucsf_document_id"] if x["ucsf_document_id"] is not None else x["image_url"]}
    )

    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})

    return ds_dict


class TestSetFactory:
    def __init__(self, dataset_path):
        self.dataset_path = dataset_path

    def __call__(self, *args, **kwargs):
        dataset = load_dataset(self.dataset_path, split="test")
        return dataset


if __name__ == "__main__":
    ds = TestSetFactory("vidore/tabfquad_test_subsampled")()
    print(ds)

colpali-main/colpali_engine/utils/gpu_stats.py
ADDED
@@ -0,0 +1,24 @@
# cond import
try:
    from pynvml import *

    def print_gpu_utilization():
        nvmlInit()
        handle = nvmlDeviceGetHandleByIndex(0)
        info = nvmlDeviceGetMemoryInfo(handle)
        print(f"GPU memory occupied: {info.used // 1024 ** 2} MB.")

    def print_summary(result):
        print(f"Time: {result.metrics['train_runtime']:.2f}")
        print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
        print_gpu_utilization()

except ImportError:
    print("pynvml not found. GPU stats will not be printed.")

    def print_summary(result):
        print(f"Time: {result.metrics['train_runtime']:.2f}")
        print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")

    def print_gpu_utilization():
        pass
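For reference, a minimal sketch of how these helpers are meant to be called (assuming pynvml is installed and a GPU is visible); print_summary expects the TrainOutput object returned by Trainer.train().

# Illustrative only: report GPU memory before and after some GPU work.
print_gpu_utilization()
# ... load a model or run a training step here ...
print_gpu_utilization()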
colpali-main/colpali_engine/utils/image_from_page_utils.py
ADDED
@@ -0,0 +1,21 @@
import requests
from PIL import Image


def load_from_pdf(pdf_path: str):
    from pdf2image import convert_from_path

    images = convert_from_path(pdf_path)
    return images


def load_from_image_urls(urls: str):
    images = [Image.open(requests.get(url, stream=True).raw) for url in urls]
    return images


def load_from_dataset(dataset):
    from datasets import load_dataset

    dataset = load_dataset(dataset, split="test")
    return dataset["image"]
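A hedged usage example; the URL below is a placeholder, not something shipped with the repo.

# Hypothetical example: replace the URL with a real image location.
urls = ["https://example.com/page_1.jpg"]
images = load_from_image_urls(urls)
print(images[0].size)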
colpali-main/colpali_engine/utils/image_utils.py
ADDED
@@ -0,0 +1,64 @@
"""
Utility functions for working with images.
"""

import base64
import io

from PIL import Image


def scale_image(image: Image.Image, new_height: int = 1024) -> Image.Image:
    """
    Scale an image to a new height while maintaining the aspect ratio.
    """
    # Calculate the scaling factor
    width, height = image.size
    aspect_ratio = width / height
    new_width = int(new_height * aspect_ratio)

    # Resize the image
    scaled_image = image.resize((new_width, new_height))

    return scaled_image


def scale_to_max_dimension(image: Image.Image, max_dimension: int = 1024) -> Image.Image:
    """
    Scale an image to a maximum dimension while maintaining the aspect ratio.
    """
    # Get the dimensions of the image
    width, height = image.size

    max_original_dimension = max(width, height)

    if max_original_dimension < max_dimension:
        return image

    # Calculate the scaling factor
    aspect_ratio = max_dimension / max_original_dimension
    new_width = int(width * aspect_ratio)
    new_height = int(height * aspect_ratio)

    # Resize the image
    scaled_image = image.resize((new_width, new_height))

    return scaled_image


def get_base64_image(img: str | Image.Image, add_url_prefix: bool = True) -> str:
    """
    Convert an image (from a filepath or a PIL.Image object) to a JPEG-base64 string.
    """
    if isinstance(img, str):
        img = Image.open(img)
    elif isinstance(img, Image.Image):
        pass
    else:
        raise ValueError("`img` must be a path to an image or a PIL Image object.")

    buffered = io.BytesIO()
    img.save(buffered, format="jpeg")
    b64_data = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return f"data:image/jpeg;base64,{b64_data}" if add_url_prefix else b64_data
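A minimal sketch of chaining these helpers; the file path is a placeholder.

# Illustrative: downscale a page image and embed it as a base64 data URL.
img = Image.open("page_1.jpg")  # placeholder path
img = scale_to_max_dimension(img, max_dimension=1024)
b64 = get_base64_image(img, add_url_prefix=True)
print(b64[:30])  # "data:image/jpeg;base64,..."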
colpali-main/colpali_engine/utils/iter_utils.py
ADDED
@@ -0,0 +1,42 @@
import sys


def islice(iterable, *args):
    """
    Yield a slice of an iterable.
    >>> islice('ABCDEFG', 2) → A B
    >>> islice('ABCDEFG', 2, 4) → C D
    >>> islice('ABCDEFG', 2, None) → C D E F G
    >>> islice('ABCDEFG', 0, None, 2) → A C E G
    """
    s = slice(*args)
    start, stop, step = s.start or 0, s.stop or sys.maxsize, s.step or 1
    it = iter(range(start, stop, step))
    try:
        nexti = next(it)
    except StopIteration:
        # Consume *iterable* up to the *start* position.
        for i, element in zip(range(start), iterable):
            pass
        return
    try:
        for i, element in enumerate(iterable):
            if i == nexti:
                yield element
                nexti = next(it)
    except StopIteration:
        # Consume to *stop*.
        for i, element in zip(range(i + 1, stop), iterable):
            pass


def batched(iterable, n: int):
    """
    Yield batches of n elements from an iterable.
    >>> batched('ABCDEFG', 3) → ABC DEF G
    """
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        yield batch
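Usage example for batched, which mirrors the classic itertools recipe for Python versions without itertools.batched:

# Example: group an iterable into tuples of at most 3 elements.
for batch in batched("ABCDEFG", 3):
    print(batch)  # ('A', 'B', 'C'), then ('D', 'E', 'F'), then ('G',)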
colpali-main/colpali_engine/utils/pdf_utils.py
ADDED
@@ -0,0 +1,87 @@
import glob
import os
import random
from pathlib import Path

from pdf2image import convert_from_path
from tqdm import tqdm

random.seed(42)


def convert_pdf_to_images(pdf_file: str, save_folder: str):
    """
    Convert each page of a pdf to a jpg image and save them in a folder.

    Args:
    - pdf_file (str): path to the pdf file
    - save_folder (str): path to the folder where the images will be saved

    """
    images = convert_from_path(pdf_file)

    for i, image in enumerate(images):
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)
        image.save(os.path.join(save_folder, f"page_{i+1}.jpg"), "JPEG")


def convert_all_pdfs_to_images(path_to_folder: str, n_samples: int = 0):
    """
    Convert all pdfs in a folder and its subfolders to images and save them in a folder.
    It samples n_samples pdf files in each subfolder, giving fine-grained control over the number of pdf files to convert.

    Args:
    - path_to_folder (str): path to the folder containing the pdf files
    - n_samples (int): number of pdf files to sample in each subfolder

    directory structure:
    - path_to_folder
        - subfolder1
            - pdf1
            - pdf2
            - ...
        - subfolder2
            - pdf1
            - pdf2
            - ...
        - ...

    """
    # sample n_samples pdf files in each subfolder
    sub_dirs = [d for d in os.listdir(path_to_folder) if os.path.isdir(os.path.join(path_to_folder, d))]

    sampled_files = []

    for sub_dir in sub_dirs:
        pdf_files = glob.glob(os.path.join(path_to_folder, sub_dir, "*.pdf"))

        if (n_samples == 0) or (len(pdf_files) <= n_samples):
            print(f"Taking all pdf files in {sub_dir}")
            sampled_files.extend(pdf_files)
        else:
            print(f"Taking {n_samples} pdf files in {sub_dir}")
            sampled_files.extend(random.sample(pdf_files, n_samples))

    pdf_files = [str(file) for file in sampled_files]

    # Create an empty text file that will contain the file paths of the corrupted pdf files
    dirpath_corrupted = Path(path_to_folder) / "corrupted_pdf_files.txt"
    dirpath_corrupted.parent.mkdir(parents=True, exist_ok=True)

    with dirpath_corrupted.open("w") as f:
        with tqdm(total=len(pdf_files)) as pbar:
            for pdf_file in pdf_files:
                pbar.set_description(f"Processing {pdf_file}")
                save_folder = os.path.join("pages_extracted", *Path(pdf_file).parts[-2:])
                if not os.path.exists(os.path.join(path_to_folder, save_folder)):
                    try:
                        convert_pdf_to_images(pdf_file, os.path.join(path_to_folder, save_folder))
                    except Exception as e:
                        print(f"Error converting {pdf_file}: {e}")
                        f.write(pdf_file)
                        f.write("\n")
                pbar.update(1)
    return
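A hedged call example; the folder layout is the one described in the docstring above and the path below is a placeholder.

# Hypothetical: sample up to 10 PDFs per subfolder of ./data_dir/pdfs and rasterize their pages.
# Page images end up under <path_to_folder>/pages_extracted/<subfolder>/<pdf_file>/.
convert_all_pdfs_to_images("./data_dir/pdfs", n_samples=10)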
colpali-main/colpali_engine/utils/plot_utils.py
ADDED
@@ -0,0 +1,6 @@
import seaborn as sns


def setup_seaborn():
    sns.set_style("white")
    sns.set_context("paper", font_scale=2)
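Typical call site, for reference:

# Apply the plotting style once, before building any figures.
setup_seaborn()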
colpali-main/colpali_engine/utils/torch_utils.py
ADDED
@@ -0,0 +1,18 @@
"""
Utility functions for interpretability.
"""

import torch


def get_torch_device() -> str:
    """
    Return the device to be used for torch tensors.
    """
    if torch.cuda.is_available():
        device = "cuda:0"
    elif torch.backends.mps.is_available():  # for Apple Silicon
        device = "mps"
    else:
        device = "cpu"
    return device
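Usage sketch:

# Pick the best available device once and reuse it for tensors and models.
device = get_torch_device()
x = torch.randn(2, 3).to(device)
print(x.device)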
colpali-main/colpali_engine/utils/train_colpali_engine_models.py
ADDED
@@ -0,0 +1,247 @@
# HuggingFace trainer
import json
import os
from dataclasses import dataclass
from typing import Callable, Dict, Optional

import torch
from datasets import concatenate_datasets
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, Idefics2Processor, PreTrainedModel, PreTrainedTokenizer, TrainingArguments

from colpali_engine.dataset.custom_collator import CustomCollator
from colpali_engine.loss.colbert_loss import BiEncoderLoss, BiPairwiseCELoss, ColbertLoss, ColbertPairwiseCELoss
from colpali_engine.trainer.contrastive_trainer import ContrastiveTrainer
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.gpu_stats import print_gpu_utilization, print_summary


@dataclass
class ColModelTrainingConfig:
    model: PreTrainedModel
    tr_args: TrainingArguments = None
    output_dir: str = None
    max_length: int = 256
    run_eval: bool = True
    run_train: bool = True
    peft_config: Optional[LoraConfig] = None
    add_suffix: bool = False
    processor: Idefics2Processor = None
    tokenizer: PreTrainedTokenizer = None
    loss_func: Optional[Callable] = ColbertLoss()
    dataset_loading_func: Optional[Callable] = None
    eval_dataset_loader: Optional[Dict[str, Callable]] = None
    pretrained_peft_model_name_or_path: Optional[str] = None

    def __post_init__(self):
        if self.output_dir is None:
            sanitized_name = str(self.model.name_or_path).replace("/", "_")
            self.output_dir = f"./models/{sanitized_name}"

        if self.tr_args is None:
            self.tr_args = TrainingArguments(output_dir=self.output_dir)
        elif self.tr_args.output_dir is None:
            self.tr_args.output_dir = self.output_dir

        # cast if string
        if isinstance(self.tr_args.learning_rate, str):
            self.tr_args.learning_rate = float(self.tr_args.learning_rate)
        self.tr_args.remove_unused_columns = False

        if self.processor is None and self.tokenizer is None:
            print("Using textual model tokenization")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model.name_or_path)

        if self.pretrained_peft_model_name_or_path is not None:
            self.model.load_adapter(self.pretrained_peft_model_name_or_path)
            print(f"Loaded pretrained adapter from {self.pretrained_peft_model_name_or_path}")

        if self.peft_config is not None:
            print("Configuring PEFT model")
            if self.processor is None:
                # Might be deprecated - use the "else" branch
                self.model = prepare_model_for_kbit_training(self.model)  # use_gradient_checkpointing=True
                # self.model.enable_input_require_grads()
                self.model = get_peft_model(self.model, self.peft_config)
                self.model.print_trainable_parameters()
            else:
                # Ugly debugging hack
                # if self.model.model.config.text_config.vocab_size == 32000:
                #     print("DEBUG: Resizing token embeddings - This should not happen in a real scenario!")
                #     self.model.model.text_model.resize_token_embeddings(32003)
                #     self.model.model.vision_model.encoder.layers = self.model.model.vision_model.encoder.layers[0:2]
                # self.model.enable_input_require_grads()
                if self.pretrained_peft_model_name_or_path is None:
                    self.model.add_adapter(self.peft_config)
                    self.model.enable_adapters()
                else:
                    print(f"Adapter already loaded from {self.pretrained_peft_model_name_or_path}. Not overwriting.")

        print_gpu_utilization()


class ColModelTraining:
    def __init__(self, config: ColModelTrainingConfig) -> None:
        self.config = config
        self.model = self.config.model
        self.dataset = self.config.dataset_loading_func()
        self.collator = CustomCollator(
            processor=self.config.processor, tokenizer=self.config.tokenizer, max_length=self.config.max_length
        )
        self.current_git_hash = os.popen("git rev-parse HEAD").read().strip()
        self.retriever_evaluator = CustomEvaluator(
            is_multi_vector=(
                isinstance(self.config.loss_func, ColbertLoss)
                or isinstance(self.config.loss_func, ColbertPairwiseCELoss)
            )
        )

    def train(self) -> None:
        trainer = ContrastiveTrainer(
            model=self.model,
            train_dataset=self.dataset["train"],
            eval_dataset=self.dataset["test"],
            args=self.config.tr_args,
            data_collator=self.collator,
            loss_func=self.config.loss_func,
            is_vision_model=self.config.processor is not None,
        )
        trainer.args.remove_unused_columns = False

        result = trainer.train()
        print_summary(result)

    def eval_dataset(self, test_dataset):
        self.model.eval()

        # # debug
        # if len(test_dataset) > 200:
        #     test_dataset = test_dataset.select(range(0, 100))

        idx_with_query = [idx for idx, sample in enumerate(test_dataset["query"]) if sample is not None]
        idx_without_query = [idx for idx, sample in enumerate(test_dataset["query"]) if sample is None]

        dataloader_with_query = DataLoader(
            test_dataset.select(idx_with_query),
            batch_size=self.config.tr_args.per_device_eval_batch_size,
            shuffle=False,
            collate_fn=self.collator,
        )
        dataloader_without_query = DataLoader(
            test_dataset.select(idx_without_query),
            batch_size=self.config.tr_args.per_device_eval_batch_size,
            shuffle=False,
            collate_fn=self.collator,
        )

        # dataset is ordered so that non-null queries come first
        test_dataset = concatenate_datasets(
            [test_dataset.select(idx_with_query), test_dataset.select(idx_without_query)]
        )

        relevant_docs = {}
        docidx_2_docid = {}
        qsidx_2_query = []
        for idx, sample in enumerate(test_dataset):
            doc_id = sample["image_filename"] if "image_filename" in sample else str(hash(sample["doc"]))
            # query_id = sample["query_id"] if "query_id" in sample else str(hash(sample["query"]))
            if sample["query"] is not None:
                relevant_docs[str(idx)] = {doc_id: 1}
                qsidx_2_query.append(str(idx))
            docidx_2_docid[str(idx)] = doc_id

        qs = []
        ps = []

        device = self.model.device
        with torch.no_grad():
            for dataloader in [dataloader_with_query, dataloader_without_query]:
                for batch in tqdm(dataloader):
                    if "doc_pixel_values" not in batch:
                        doc = self.model(
                            input_ids=batch["doc_input_ids"].to(device),
                            attention_mask=batch["doc_attention_mask"].to(device),
                        )
                    else:
                        if "doc_pixel_attention_mask" in batch:
                            doc = self.model(
                                input_ids=batch["doc_input_ids"].to(device),
                                attention_mask=batch["doc_attention_mask"].to(device),
                                pixel_values=batch["doc_pixel_values"].to(device),
                                pixel_attention_mask=batch["doc_pixel_attention_mask"].to(device),
                            )
                        else:
                            doc = self.model(
                                input_ids=batch["doc_input_ids"].to(device),
                                attention_mask=batch["doc_attention_mask"].to(device),
                                pixel_values=batch["doc_pixel_values"].to(device),
                            )

                    ps.extend(list(torch.unbind(doc.to("cpu"))))

                    if "query_input_ids" in batch:
                        query = self.model(
                            input_ids=batch["query_input_ids"].to(device),
                            attention_mask=batch["query_attention_mask"].to(device),
                        )
                        # variable len
                        qs.extend(list(torch.unbind(query.to("cpu"))))

        print("Embeddings computed, evaluating")
        scores = self.retriever_evaluator.evaluate(qs, ps)
        # scores is a 2d array of shape (n_queries, n_docs); turn it into a dict
        results = {}
        assert scores.shape[0] == len(qsidx_2_query)
        for idx, scores_per_query in enumerate(scores):
            results[qsidx_2_query[idx]] = {
                docidx_2_docid[str(docidx)]: float(score) for docidx, score in enumerate(scores_per_query)
            }

        # evaluate
        metrics = self.retriever_evaluator.compute_metrics(relevant_docs, results)
        print(metrics)
        return metrics

    def eval(self) -> None:
        print("Evaluating on validation set")
        metrics = self.eval_dataset(self.dataset["test"])
        print(f"Metrics for validation set: {metrics}")
        all_metrics = {"validation_set": metrics}

        if self.config.eval_dataset_loader is not None:
            for test_name, test_dataset_loading_func in self.config.eval_dataset_loader.items():
                print(f"Evaluating {test_name}")
                test_ds = test_dataset_loading_func()
                metrics = self.eval_dataset(test_ds)
                all_metrics[test_name] = metrics
                print(f"Metrics for {test_name}: {metrics}")

                # checkpoint dumps
                with open(f"{self.config.output_dir}/results.json", "w") as f:
                    json.dump(all_metrics, f)

        # save results as json
        with open(f"{self.config.output_dir}/results.json", "w") as f:
            json.dump(all_metrics, f)

    def save(self, config_file):
        # save model
        self.model.save_pretrained(self.config.output_dir)
        if self.config.tokenizer is not None:
            self.config.tokenizer.save_pretrained(self.config.output_dir)
        if self.config.processor is not None:
            self.config.processor.save_pretrained(self.config.output_dir)  # save config

        # copy the yml config file alongside the checkpoint
        os.system(f"cp {config_file} {self.config.output_dir}/training_config.yml")

        # save the git hash of the commit at the beginning of training
        with open(f"{self.config.output_dir}/git_hash.txt", "w") as f:
            f.write(self.current_git_hash)
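A heavily hedged wiring sketch of how this training harness fits together; in the repo the config is normally built from a YAML file, and the backbone checkpoint, dataset loader, and paths below are illustrative assumptions, not values taken from this commit.

# Hypothetical wiring example, not a committed entry point.
from colpali_engine.utils.dataset_transformation import load_docvqa_dataset
from colpali_engine.utils.wrapper import AutoColModelWrapper

model = AutoColModelWrapper("bert-base-uncased", training_objective="colbertv1")  # placeholder backbone
config = ColModelTrainingConfig(
    model=model,
    dataset_loading_func=load_docvqa_dataset,
    tr_args=TrainingArguments(output_dir="./models/debug_run", num_train_epochs=1),
)
training_app = ColModelTraining(config)
if config.run_train:
    training_app.train()
if config.run_eval:
    training_app.eval()
training_app.save(config_file="path/to/training_config.yml")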
colpali-main/colpali_engine/utils/wrapper.py
ADDED
@@ -0,0 +1,83 @@
import importlib

from colpali_engine.models.clip_baselines import ColSigLIP, SigLIP
from colpali_engine.models.colbert_architectures import (
    BiBERT,
    BiXLMRoBERTa,
    ColBERT,
    ColCamembert,
    ColLlama,
    ColXLMRoBERTa,
)
from colpali_engine.models.idefics_colbert_architecture import BiIdefics, ColIdefics
from colpali_engine.models.paligemma_colbert_architecture import (
    BiNewSiglip,
    BiPaliLast,
    BiPaliMean,
    ColNewSiglip,
    ColPali,
)

if importlib.util.find_spec("transformers") is not None:
    from transformers import AutoProcessor, AutoTokenizer
    from transformers.tokenization_utils import PreTrainedTokenizer

    class AutoProcessorWrapper:
        def __new__(cls, *args, **kwargs):
            return AutoProcessor.from_pretrained(*args, **kwargs)

    class AutoTokenizerWrapper(PreTrainedTokenizer):
        def __new__(cls, *args, **kwargs):
            return AutoTokenizer.from_pretrained(*args, **kwargs)

    class AutoColModelWrapper:
        def __new__(cls, *args, **kwargs):
            pretrained_model_name_or_path = None
            if args:
                pretrained_model_name_or_path = args[0]
            elif kwargs:
                pretrained_model_name_or_path = kwargs["pretrained_model_name_or_path"]

            training_objective = kwargs.pop("training_objective", "colbertv1")

            if "camembert" in pretrained_model_name_or_path:
                return ColCamembert.from_pretrained(*args, **kwargs)
            elif "xlm-roberta" in pretrained_model_name_or_path:
                if training_objective == "biencoder":
                    return BiXLMRoBERTa.from_pretrained(*args, **kwargs)
                return ColXLMRoBERTa.from_pretrained(*args, **kwargs)
            elif (
                "llama" in pretrained_model_name_or_path.lower() or "croissant" in pretrained_model_name_or_path.lower()
            ):
                return ColLlama.from_pretrained(*args, **kwargs)
            elif "idefics2" in pretrained_model_name_or_path:
                if training_objective == "biencoder":
                    return BiIdefics.from_pretrained(*args, **kwargs)
                return ColIdefics.from_pretrained(*args, **kwargs)
            elif "siglip" in pretrained_model_name_or_path:
                if training_objective == "biencoder_mean":
                    return SigLIP.from_pretrained(*args, **kwargs)
                elif training_objective == "colbertv1":
                    return ColSigLIP.from_pretrained(*args, **kwargs)
                else:
                    raise ValueError(f"Training objective {training_objective} not recognized")
            elif "paligemma" in pretrained_model_name_or_path:
                if training_objective == "biencoder_mean":
                    return BiPaliMean.from_pretrained(*args, **kwargs)
                elif training_objective == "biencoder_last":
                    return BiPaliLast.from_pretrained(*args, **kwargs)
                elif training_objective == "biencoder_mean_vision":
                    return BiNewSiglip.from_pretrained(*args, **kwargs)
                elif training_objective == "colbertv1_vision":
                    return ColNewSiglip.from_pretrained(*args, **kwargs)
                elif training_objective == "colbertv1":
                    return ColPali.from_pretrained(*args, **kwargs)
                else:
                    raise ValueError(f"Training objective {training_objective} not recognized")
            else:
                if training_objective == "biencoder":
                    return BiBERT.from_pretrained(*args, **kwargs)
                return ColBERT.from_pretrained(*args, **kwargs)

else:
    raise ModuleNotFoundError("Transformers must be loaded")
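Dispatch example for the wrapper; the checkpoint name is an assumption, not pinned by this file.

# Illustrative: "paligemma" in the name routes to ColPali under the default "colbertv1" objective.
model = AutoColModelWrapper("google/paligemma-3b-mix-448", training_objective="colbertv1")
processor = AutoProcessorWrapper("google/paligemma-3b-mix-448")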
colpali-main/demo/README.md
ADDED
@@ -0,0 +1,6 @@
---
title: cvquest-colpali
app_file: app.py
sdk: gradio
sdk_version: 4.39.0
---