Spaces:
Running
Running
HUANG-Stephanie
committed on
Commit
•
c4d37d5
1
Parent(s):
d106c36
Delete colpali_engine
Browse files- colpali_engine/__init__.py +0 -0
- colpali_engine/__pycache__/__init__.cpython-310.pyc +0 -0
- colpali_engine/dataset/__init__.py +0 -0
- colpali_engine/dataset/custom_collator.py +0 -244
- colpali_engine/dataset/hf_dataset_names.py +0 -52
- colpali_engine/evaluation/__init__.py +0 -1
- colpali_engine/evaluation/eval_manager.py +0 -178
- colpali_engine/interpretability/__init__.py +0 -4
- colpali_engine/interpretability/gen_interpretability_plots.py +0 -113
- colpali_engine/interpretability/plot_utils.py +0 -131
- colpali_engine/interpretability/processor.py +0 -116
- colpali_engine/interpretability/torch_utils.py +0 -60
- colpali_engine/interpretability/vit_configs.py +0 -23
- colpali_engine/loss/__init__.py +0 -1
- colpali_engine/loss/colbert_loss.py +0 -122
- colpali_engine/models/__init__.py +0 -0
- colpali_engine/models/__pycache__/__init__.cpython-310.pyc +0 -0
- colpali_engine/models/__pycache__/paligemma_colbert_architecture.cpython-310.pyc +0 -0
- colpali_engine/models/clip_baselines.py +0 -144
- colpali_engine/models/colbert_architectures.py +0 -177
- colpali_engine/models/idefics_colbert_architecture.py +0 -57
- colpali_engine/models/paligemma_colbert_architecture.py +0 -191
- colpali_engine/trainer/__init__.py +0 -0
- colpali_engine/trainer/__pycache__/__init__.cpython-310.pyc +0 -0
- colpali_engine/trainer/__pycache__/retrieval_evaluator.cpython-310.pyc +0 -0
- colpali_engine/trainer/contrastive_trainer.py +0 -64
- colpali_engine/trainer/retrieval_evaluator.py +0 -72
- colpali_engine/utils/__init__.py +0 -0
- colpali_engine/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- colpali_engine/utils/__pycache__/colpali_processing_utils.cpython-310.pyc +0 -0
- colpali_engine/utils/__pycache__/image_from_page_utils.cpython-310.pyc +0 -0
- colpali_engine/utils/colidefics_processing_utils.py +0 -53
- colpali_engine/utils/colpali_processing_utils.py +0 -36
- colpali_engine/utils/dataset_transformation.py +0 -158
- colpali_engine/utils/gpu_stats.py +0 -24
- colpali_engine/utils/image_from_page_utils.py +0 -21
- colpali_engine/utils/image_utils.py +0 -64
- colpali_engine/utils/iter_utils.py +0 -42
- colpali_engine/utils/pdf_utils.py +0 -87
- colpali_engine/utils/plot_utils.py +0 -6
- colpali_engine/utils/torch_utils.py +0 -18
- colpali_engine/utils/train_colpali_engine_models.py +0 -247
- colpali_engine/utils/wrapper.py +0 -83
colpali_engine/__init__.py
DELETED
File without changes
|
colpali_engine/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (163 Bytes)
|
|
colpali_engine/dataset/__init__.py
DELETED
File without changes
|
colpali_engine/dataset/custom_collator.py
DELETED
@@ -1,244 +0,0 @@
|
|
1 |
-
from transformers import PreTrainedTokenizer, ProcessorMixin
|
2 |
-
|
3 |
-
|
4 |
-
class CustomCollator:
    """
    Collate raw retrieval examples into one dict of batched model inputs.

    Exactly one of `processor` (vision models) or `tokenizer` (text-only models) must be
    supplied. Calling the collator on a list of examples dispatches on the processor's
    class name and returns a flat dict whose keys are prefixed with "doc_" or "query_".

    Each example is a mapping with a "query" entry plus either a "doc" entry (text mode)
    or an "image" entry (vision modes).

    Args:
        processor: multimodal processor; supported class names are "Idefics2Processor",
            "PaliGemmaProcessor" and "SiglipProcessor".
        tokenizer: text tokenizer, used only when no processor is given.
        max_length: maximum sequence length forwarded to the tokenizer / processor.
        add_suffix: if True, ten newline characters are appended to every text-mode query.

    Raises:
        ValueError: if neither or both of `processor` and `tokenizer` are given.
    """

    # Processor class name -> name of the method that collates for it.
    _VISION_COLLATORS = {
        "Idefics2Processor": "forward_vision_idefics",
        "PaliGemmaProcessor": "forward_vision_pali",
        "SiglipProcessor": "forward_vision_siglip",
    }

    def __init__(
        self,
        processor: "ProcessorMixin | None" = None,
        tokenizer: "PreTrainedTokenizer | None" = None,
        max_length: int = 2048,
        add_suffix: bool = False,
    ):
        # Validate the argument combination first so a misconfiguration is reported
        # before any processor attribute is touched.
        if tokenizer is None and processor is None:
            raise ValueError("Either processor or tokenizer should be provided.")
        if tokenizer is not None and processor is not None:
            raise ValueError("Only one of processor or tokenizer should be provided.")

        self.processor = processor
        self.tokenizer = tokenizer
        self.image_token_id = None
        self.max_length = max_length
        self.suffix = "\n" * 10 if add_suffix else ""

        # Every supported processor except Siglip exposes "<image>" as an additional
        # special token; remember its id.
        if self.processor is not None and self.processor.__class__.__name__ != "SiglipProcessor":
            special_tokens = self.processor.tokenizer.additional_special_tokens
            self.image_token_id = self.processor.tokenizer.additional_special_tokens_ids[
                special_tokens.index("<image>")
            ]

        # Padding needs a pad token; fall back to EOS when the tokenizer has none.
        if self.tokenizer is not None and self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(self, examples):
        """Collate `examples` with the method matching the configured processor/tokenizer."""
        if self.processor is None:
            return self.forward_text(examples)

        method_name = self._VISION_COLLATORS.get(self.processor.__class__.__name__)
        if method_name is None:
            raise ValueError("Processor not supported")
        return getattr(self, method_name)(examples)

    @staticmethod
    def _queries_available(texts_query, announce=True):
        """
        Return True when every query is set and False when every query is None.

        Raises:
            ValueError: if only some of the queries are None.
        """
        missing = [text is None for text in texts_query]
        if all(missing):
            if announce:
                print("All queries are None. Returning None for all queries.")
            return False
        if any(missing):
            raise ValueError("Some queries are None. This collator does not support None queries yet.")
        return True

    @staticmethod
    def _merge(batch_doc, batch_query):
        """Prefix each key with "doc_" or "query_" to avoid key conflicts, then merge."""
        merged = {f"doc_{k}": v for k, v in batch_doc.items()}
        if batch_query is not None:
            merged.update({f"query_{k}": v for k, v in batch_query.items()})
        return merged

    def forward_text(self, examples):
        """Tokenize the "doc" and "query" texts of each example (text-only mode)."""
        texts_doc = []
        texts_query = []
        for example in examples:
            texts_doc.append(example["doc"].strip())
            # Strip BEFORE appending the suffix: the suffix is pure whitespace, so
            # stripping afterwards would silently remove it again.
            texts_query.append(example["query"].strip() + self.suffix)

        batch_doc = self.tokenizer(
            texts_doc, max_length=self.max_length, padding="longest", truncation=True, return_tensors="pt"
        )
        batch_query = self.tokenizer(
            texts_query, max_length=self.max_length, padding="longest", truncation=True, return_tensors="pt"
        )

        return self._merge(batch_doc, batch_query)

    def forward_vision_idefics(self, examples):
        """Collate image documents and (optional) text queries for an Idefics2 processor."""
        texts_doc = []
        texts_query = []
        images = []

        # The document prompt is the same for every example.
        messages_doc = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe the image."},
                    {"type": "image"},
                ],
            },
        ]

        for example in examples:
            text_query = None
            if example["query"] is not None:
                query = example["query"]
                # NOTE(review): five "<end_of_utterance>" tokens are appended to the query,
                # presumably as extra query slots -- confirm against the model code.
                messages_query = [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": f"Question: {query}<end_of_utterance><end_of_utterance><end_of_utterance><end_of_utterance><end_of_utterance>",
                            },
                        ],
                    },
                ]
                text_query = self.processor.apply_chat_template(messages_query, add_generation_prompt=False).strip()

            text_doc = self.processor.apply_chat_template(messages_doc, add_generation_prompt=False)

            texts_doc.append(text_doc.strip())
            texts_query.append(text_query)
            images.append([example["image"]])

        batch_doc = self.processor(
            text=texts_doc, images=images, return_tensors="pt", padding="longest", max_length=self.max_length
        )

        batch_query = None
        if self._queries_available(texts_query):
            batch_query = self.processor(
                text=texts_query, return_tensors="pt", padding="longest", max_length=self.max_length
            )

        return self._merge(batch_doc, batch_query)

    def forward_vision_pali(self, examples):
        """
        Collate image documents and (optional) text queries for a PaliGemma processor.

        Raises:
            ValueError: if an example has no image, or if only some queries are None.
        """
        texts_doc = []
        texts_query = []
        images = []
        for example in examples:
            if example["image"] is None:
                raise ValueError("Image is None - This collator does not support None images yet.")

            images.append(example["image"].convert("RGB"))
            texts_doc.append("Describe the image.")

            if example["query"] is None:
                texts_query.append(None)
            else:
                query = example["query"]
                # NOTE(review): five "<unused0>" tokens are appended to the query,
                # presumably as extra query slots -- confirm against the model code.
                texts_query.append(f"Question: {query}<unused0><unused0><unused0><unused0><unused0>")

        # The processor prepends `image_seq_length` image tokens, so leave room for them.
        padded_max_length = self.max_length + self.processor.image_seq_length

        batch_doc = self.processor(
            text=texts_doc,
            images=images,
            return_tensors="pt",
            padding="longest",
            max_length=padded_max_length,
        )

        batch_query = None
        if self._queries_available(texts_query):
            batch_query = self.processor(
                images=images,  # NOTE: the image is not used in batch_query but it is required for calling the processor
                text=texts_query,
                return_tensors="pt",
                padding="longest",
                max_length=padded_max_length,
            )
            # Drop everything image-related from the query inputs: the pixel values and
            # the leading `image_seq_length` positions of the token tensors.
            del batch_query["pixel_values"]
            batch_query["input_ids"] = batch_query["input_ids"][..., self.processor.image_seq_length :]
            batch_query["attention_mask"] = batch_query["attention_mask"][..., self.processor.image_seq_length :]

        return self._merge(batch_doc, batch_query)

    def forward_vision_siglip(self, examples):
        """
        Collate image documents and (optional) text queries for a Siglip processor.

        Attention masks are not taken from the processor; they are rebuilt here from the
        input ids.

        Raises:
            ValueError: if an example has no image, or if only some queries are None.
        """
        texts_doc = []
        texts_query = []
        images = []
        for example in examples:
            if example["image"] is None:
                raise ValueError("Image is None - This collator does not support None images yet.")

            images.append(example["image"].convert("RGB"))
            texts_doc.append("Describe the image.")

            if example["query"] is None:
                texts_query.append(None)
            else:
                texts_query.append(f"Question: {example['query']}")

        batch_doc = self.processor(
            text=texts_doc,
            images=images,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
        )

        batch_query = None
        # This mode stays silent when all queries are None.
        if self._queries_available(texts_query, announce=False):
            batch_query = self.processor(
                images=images,
                text=texts_query,
                return_tensors="pt",
                padding="max_length",
                max_length=self.max_length,
                truncation=True,
            )
            del batch_query["pixel_values"]

        batch = self._merge(batch_doc, batch_query)

        # Build attention masks from the ids (assumes token id 0 is padding -- TODO confirm).
        if batch_query is not None:
            batch["query_attention_mask"] = batch["query_input_ids"].ne(0).long()
        batch["doc_attention_mask"] = batch["doc_input_ids"].ne(0).long()

        return batch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/dataset/hf_dataset_names.py
DELETED
@@ -1,52 +0,0 @@
|
|
1 |
-
from enum import Enum
|
2 |
-
|
3 |
-
|
4 |
-
class TrainDatasets(Enum):
    """
    Names of the training datasets, as used with HuggingFace Datasets.
    """

    # Synthetic document-QA datasets (see `get_synthetic_datasets`).
    government_reports = "vidore/syntheticDocQA_government_reports_train"
    healthcare_industry = "vidore/syntheticDocQA_healthcare_industry_train"
    energy = "vidore/syntheticDocQA_energy_train"
    artificial_intelligence = "vidore/syntheticDocQA_artificial_intelligence_train"

    # Remaining training datasets.
    arxivqa = "vidore/arxivqa_train"
    docvqa = "vidore/docvqa_train"
    infovqa = "vidore/infovqa_train"
    tatqa = "vidore/tatqa_train"

    @staticmethod
    def get_synthetic_datasets():
        """Return the four synthetic DocQA members as a list, in declaration order."""
        synthetic_names = (
            "government_reports",
            "healthcare_industry",
            "energy",
            "artificial_intelligence",
        )
        return [TrainDatasets[member_name] for member_name in synthetic_names]
|
26 |
-
|
27 |
-
|
28 |
-
class TestImagesDirpath(Enum):
    """
    Relative directory paths, one per test dataset.

    Unlike `TrainDatasets`, the values here are filesystem paths rather than
    HuggingFace dataset names.
    NOTE(review): presumably these directories hold the test page images -- confirm
    against the code that reads them.
    """

    government_reports = "data/government_reports"
    healthcare_industry = "data/healthcare_industry"
    energy = "data/energy"
    # This entry lives in a different directory layout from its siblings.
    artificial_intelligence = "data/scrapped_pdfs_split/pages_extracted/artificial_intelligence_test"
    arxivqa = "data/arxivqa"
    docvqa = "data/docvqa"
    infovqa = "data/infovqa"
    tatqa = "data/tatqa"
|
41 |
-
|
42 |
-
|
43 |
-
class CaptionedSyntheticDatasets(Enum):
    """
    Names of the captioned synthetic datasets, as used with HuggingFace Datasets.
    """

    # Currently the only captioned synthetic dataset.
    shift = "vidore/baseline_cap_shiftproject_test"
|
49 |
-
|
50 |
-
|
51 |
-
class SyntheticDocQATest(Enum):
    """
    Names of the synthetic DocQA test datasets.
    """

    # Currently the only member.
    shift = "vidore/shiftproject_test"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/evaluation/__init__.py
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
from .eval_manager import EvalManager
|
|
|
|
colpali_engine/evaluation/eval_manager.py
DELETED
@@ -1,178 +0,0 @@
|
|
1 |
-
from __future__ import annotations
|
2 |
-
|
3 |
-
from pathlib import Path
|
4 |
-
from typing import Any, ClassVar, Dict, Optional
|
5 |
-
|
6 |
-
import pandas as pd
|
7 |
-
|
8 |
-
|
9 |
-
class EvalManager:
    """
    Stores evaluation results for various datasets and metrics.

    The data is stored in a pandas DataFrame with a MultiIndex for columns.
    The first level of the MultiIndex is the dataset name and the second level is the metric name.
    Each row holds the results of one model; the row index is named "model".

    Usage:
    >>> evaluator = EvalManager.from_dir("data/evaluation_results/")
    >>> print(evaluator.data)

    """

    model_col: ClassVar[str] = "model"
    dataset_col: ClassVar[str] = "dataset"
    metric_col: ClassVar[str] = "metric"

    def __init__(self, data: Optional[pd.DataFrame] = None):
        """
        Wrap `data` (an empty DataFrame when omitted).

        The manager keeps its own copy, so naming the index does not modify the
        DataFrame passed in by the caller.
        """
        df = pd.DataFrame() if data is None else data.copy()
        df.index = df.index.rename(EvalManager.model_col)
        self._df = df

    def __str__(self) -> str:
        return str(self._df)

    @staticmethod
    def from_dict(data: Dict[Any, Any]) -> "EvalManager":
        """
        Load evaluation results from a dictionary.

        Expected format:
        {
            "model1": pd.read_json(path1).T.stack(),
            "model2": pd.read_json(path2).T.stack(),
        }

        """
        df = pd.DataFrame.from_dict(data, orient="index")
        return EvalManager(df)

    @staticmethod
    def from_json(path: "str | Path") -> "EvalManager":
        """
        Load the results of a single model from a JSON file.

        The file stem is used as the model name.

        Raises:
            FileNotFoundError: if `path` is not an existing file.
        """
        datapath = Path(path)
        if not datapath.is_file():
            raise FileNotFoundError(f"{path} is not a file")
        data = {datapath.stem: pd.read_json(datapath).T.stack()}  # pylint: disable=no-member
        return EvalManager.from_dict(data)

    @staticmethod
    def from_dir(datadir: "str | Path") -> "EvalManager":
        """
        Load the results of every `*.json` file directly inside `datadir`.

        Each file stem is used as a model name. Files are read in sorted order so the
        row order does not depend on the filesystem.

        Raises:
            FileNotFoundError: if `datadir` is not an existing directory.
        """
        datadir_ = Path(datadir)
        if not datadir_.is_dir():
            raise FileNotFoundError(f"{datadir} is not a directory")

        data = {
            filepath.stem: pd.read_json(filepath).T.stack()  # pylint: disable=no-member
            for filepath in sorted(datadir_.glob("*.json"))
        }

        return EvalManager.from_dict(data)

    @staticmethod
    def from_csv(path: "str | Path") -> "EvalManager":
        """
        Load evaluation results from a CSV file.

        The first column is read as the index and the first two rows as the
        (dataset, metric) column header. Any loading error is printed and re-raised.
        """
        try:
            df = pd.read_csv(path, index_col=0, header=[0, 1])
        except Exception as e:
            print(f"Error loading {path}: {e}")
            raise  # bare raise keeps the original traceback
        return EvalManager(df)

    @property
    def data(self) -> pd.DataFrame:
        """
        Returns the evaluation results as a pandas DataFrame.

        A copy is returned, so callers cannot modify the stored results.
        """
        return self._df.copy()

    @property
    def models(self) -> pd.Index:
        """
        Returns the models for which there are evaluation results.
        """
        return self.data.index

    @property
    def datasets(self) -> pd.Index:
        """
        Returns the (unique) datasets for which there are evaluation results.
        """
        return self.data.columns.get_level_values(0).unique()

    @property
    def metrics(self) -> pd.Index:
        """
        Returns the (unique) metrics for which there are evaluation results.
        """
        # Deduplicated for consistency with `datasets`: the same metric name appears
        # once per dataset in the column MultiIndex.
        return self.data.columns.get_level_values(1).unique()

    @staticmethod
    def melt(df: pd.DataFrame) -> pd.DataFrame:
        """
        Melt a suitable DataFrame (e.g. returned by `get_df_for_dataset` and
        `get_df_for_metric`) into a 'long' format.

        The result has one row per (dataset, metric, model) with the value in a
        "score" column.
        """
        id_columns = [EvalManager.dataset_col, EvalManager.metric_col]
        return df.T.reset_index(names=id_columns).melt(
            id_vars=id_columns,
            var_name=EvalManager.model_col,
            value_name="score",
        )

    @property
    def melted(self) -> pd.DataFrame:
        """
        Returns the evaluation results as a 'melted' DataFrame.
        Useful for plotting with seaborn.
        """
        return EvalManager.melt(self.data)

    def get_df_for_model(self, model: str) -> pd.DataFrame:
        """
        Return the single-row DataFrame holding the results of `model`.

        Raises:
            ValueError: if `model` has no results.
        """
        if model not in self._df.index:
            raise ValueError(f"Model {model} not found in the evaluation results")
        return self.data.loc[[model], :]  # type: ignore

    def get_df_for_dataset(self, dataset: str) -> pd.DataFrame:
        """
        Return all metric columns of `dataset`, for every model.

        Raises:
            ValueError: if `dataset` has no results.
        """
        if dataset not in self.datasets:
            raise ValueError(f"Dataset {dataset} not found in the evaluation results")
        return self.data.loc[:, (dataset, slice(None))]  # type: ignore

    def get_df_for_metric(self, metric: str) -> pd.DataFrame:
        """
        Return the `metric` column of every dataset, for every model.

        Raises:
            ValueError: if `metric` has no results.
        """
        if metric not in self.metrics:
            raise ValueError(f"Metric {metric} not found in the evaluation results")
        return self.data.loc[:, (slice(None), metric)]  # type: ignore

    def sort_by_dataset(self, ascending: bool = True) -> "EvalManager":
        """
        Sort the evaluation results by dataset name.
        """
        df = self.data.T.sort_index(level=0, ascending=ascending).T
        return EvalManager(df)

    def sort_by_metric(self, ascending: bool = True) -> "EvalManager":
        """
        Sort the evaluation results by metric name.
        """
        df = self.data.T.sort_index(level=1, ascending=ascending).T
        return EvalManager(df)

    def sort_columns(self, ascending: bool = True) -> "EvalManager":
        """
        Sort the evaluation results by dataset name and then by metric name.
        """
        df = self.data.T.sort_index(level=[0, 1], ascending=ascending).T
        return EvalManager(df)

    def to_csv(self, path: "str | Path"):
        """
        Save the evaluation results to a CSV file.

        Missing parent directories are created.
        Using `EvalManager.from_csv(path_to_saved_csv)` will load the evaluation results back into memory.
        """
        savepath = Path(path)
        savepath.parent.mkdir(parents=True, exist_ok=True)
        self.data.to_csv(savepath)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/interpretability/__init__.py
DELETED
@@ -1,4 +0,0 @@
|
|
1 |
-
from .plot_utils import *
|
2 |
-
from .processor import *
|
3 |
-
from .torch_utils import *
|
4 |
-
from .vit_configs import *
|
|
|
|
|
|
|
|
|
|
colpali_engine/interpretability/gen_interpretability_plots.py
DELETED
@@ -1,113 +0,0 @@
|
|
1 |
-
import pprint
|
2 |
-
from dataclasses import asdict, dataclass
|
3 |
-
from pathlib import Path
|
4 |
-
from uuid import uuid4
|
5 |
-
|
6 |
-
import matplotlib.pyplot as plt
|
7 |
-
import torch
|
8 |
-
from einops import rearrange
|
9 |
-
from PIL import Image
|
10 |
-
from tqdm import trange
|
11 |
-
|
12 |
-
from colpali_engine.interpretability.plot_utils import plot_patches
|
13 |
-
from colpali_engine.interpretability.processor import ColPaliProcessor
|
14 |
-
from colpali_engine.interpretability.torch_utils import normalize_attention_map_per_query_token
|
15 |
-
from colpali_engine.interpretability.vit_configs import VIT_CONFIG
|
16 |
-
from colpali_engine.models.paligemma_colbert_architecture import ColPali
|
17 |
-
|
18 |
-
OUTDIR_INTERPRETABILITY = Path("outputs/interpretability")
|
19 |
-
|
20 |
-
|
21 |
-
@dataclass
class InterpretabilityInput:
    """
    Plain record grouping a query, an image and a pair of token indices.

    NOTE(review): `start_idx_token` / `end_idx_token` presumably delimit the range of
    query tokens to visualise -- confirm against the code that consumes this class.
    """

    query: str  # query text
    image: Image.Image  # PIL image
    start_idx_token: int  # start token index
    end_idx_token: int  # end token index
|
27 |
-
|
28 |
-
|
29 |
-
def generate_interpretability_plots(
    model: ColPali,
    processor: ColPaliProcessor,
    query: str,
    image: Image.Image,
    savedir: str | Path | None = None,
    add_special_prompt_to_doc: bool = True,
) -> None:
    """
    Save one per-token attention plot of `query` over the patches of `image`.

    For every query token except the first and the last, the normalized similarity
    between that token and each image patch is drawn as patch opacity and written to
    `<savedir>/token_<idx>.png`.

    Args:
        model: model with exactly one active adapter; its `config.name_or_path` must be
            a key of `VIT_CONFIG`.
        processor: processor used to prepare the text and image inputs.
        query: the query text.
        image: the document image.
        savedir: output directory (created if missing). When not given, a fresh
            directory named by a random UUID is created under `OUTDIR_INTERPRETABILITY`.
        add_special_prompt_to_doc: forwarded to `processor.process_image`; when True the
            image output is cut to its first `image_seq_length` positions before use.

    Raises:
        ValueError: if the model does not have exactly one active adapter, or is not
            listed in `VIT_CONFIG`.
    """
    # --- Preconditions ---
    if len(model.active_adapters()) != 1:
        raise ValueError("The model must have exactly one active adapter.")

    model_name = model.config.name_or_path
    if model_name not in VIT_CONFIG:
        raise ValueError("The model must be referred to in the VIT_CONFIG dictionary.")
    vit_config = VIT_CONFIG[model_name]

    # --- Output directory ---
    if savedir:
        out_dir = Path(savedir) if isinstance(savedir, str) else savedir
    else:
        out_dir = OUTDIR_INTERPRETABILITY / str(uuid4())
        print(f"No savepath provided. Results will be saved to: `{out_dir}`.")
    out_dir.mkdir(parents=True, exist_ok=True)

    # The plot is drawn on a square copy of the image; the model receives the original.
    side = vit_config.resolution
    input_image_square = image.resize((side, side))

    # --- Preprocessing ---
    input_text_processed = processor.process_text(query).to(model.device)
    input_image_processed = processor.process_image(image, add_special_prompt=add_special_prompt_to_doc).to(
        model.device
    )

    # --- Forward passes (no gradients needed) ---
    with torch.no_grad():
        text_embeddings = model.forward(**asdict(input_text_processed))  # (1, n_text_tokens, hidden_dim)
        # `image_embeddings` has shape:
        # (1, n_patch_x * n_patch_y, hidden_dim) if `add_special_prompt_to_doc` is False
        # (1, n_patch_x * n_patch_y + n_special_tokens, hidden_dim) if `add_special_prompt_to_doc` is True
        image_embeddings = model.forward(**asdict(input_image_processed))

    if add_special_prompt_to_doc:
        # Keep only the patch positions, i.e. drop the special prompt tokens.
        n_image_positions = processor.processor.image_seq_length
        image_embeddings = image_embeddings[:, :n_image_positions, :]  # (1, n_patch_x * n_patch_y, hidden_dim)

    patches_per_dim = vit_config.n_patch_per_dim
    image_embeddings = rearrange(
        image_embeddings, "b (h w) c -> b h w c", h=patches_per_dim, w=patches_per_dim
    )  # (1, n_patch_x, n_patch_y, hidden_dim)

    # --- Attention map: dot product of every text token with every patch ---
    raw_attention = torch.einsum(
        "bnk,bijk->bnij", text_embeddings, image_embeddings
    )  # (1, n_text_tokens, n_patch_x, n_patch_y)
    attention_map_normalized = normalize_attention_map_per_query_token(raw_attention).float()

    # --- Token labels ---
    n_tokens = input_text_processed.input_ids.size(1)
    decoded_query = processor.decode(input_text_processed.input_ids[0])
    text_tokens = processor.tokenizer.tokenize(decoded_query)
    print("Text tokens:")
    pprint.pprint(text_tokens)
    print("\n")

    # --- One figure per token; the <bos> and the "\n" tokens are excluded ---
    for token_idx in trange(1, n_tokens - 1, desc="Iterating over tokens..."):
        fig, _ = plot_patches(
            input_image_square,
            vit_config.patch_size,
            vit_config.resolution,
            patch_opacities=attention_map_normalized[0, token_idx, :, :],
            style="dark_background",
        )

        token_label = text_tokens[token_idx]
        fig.suptitle(f"Token #{token_idx}: `{token_label}`", color="white", fontsize=14)
        savepath = out_dir / f"token_{token_idx}.png"
        fig.savefig(savepath)
        print(f"Saved attention map for token {token_idx} (`{token_label}`) to `{savepath}`.\n")
        plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/interpretability/plot_utils.py
DELETED
@@ -1,131 +0,0 @@
|
|
1 |
-
from typing import Any, Dict, Optional, Tuple, cast
|
2 |
-
|
3 |
-
import matplotlib.pyplot as plt
|
4 |
-
import numpy as np
|
5 |
-
import numpy.typing as npt
|
6 |
-
import seaborn as sns
|
7 |
-
import torch
|
8 |
-
from PIL import Image
|
9 |
-
|
10 |
-
MAX_OPACITY = 255
|
11 |
-
|
12 |
-
|
13 |
-
def plot_patches(
    img: Image.Image,
    patch_size: int,
    image_resolution: int,
    patch_opacities: Optional[npt.NDArray | torch.Tensor] = None,
    figsize: Tuple[int, int] = (8, 8),
    style: Dict[str, Any] | str | None = None,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot patches of a square image.
    Set `style` to "dark_background" if your image has a light background.

    The image is cut into a grid of `image_resolution // patch_size` patches per side
    and each patch is drawn in its own subplot.

    Args:
        img: square image to cut into patches.
        patch_size: side length of one patch, in pixels.
        image_resolution: side length used to compute the grid; must be divisible by
            `patch_size`.
        patch_opacities: optional array or tensor of shape (num_patches, num_patches)
            with values in [0, 1]; each value is scaled to `MAX_OPACITY` and written to
            the alpha channel of the matching patch.
        figsize: matplotlib figure size.
        style: matplotlib style (name or dict) active while the figure is built.

    Returns:
        The figure and the 2D array of subplot axes (always 2D, also for a 1x1 grid).

    Raises:
        ValueError: if the resolution is not divisible by the patch size, if
            `patch_opacities` has the wrong shape or range, or if `img` is not square.
    """

    # Get the number of patches
    if image_resolution % patch_size != 0:
        raise ValueError("The image resolution must be divisible by the patch size.")
    num_patches = image_resolution // patch_size

    # Default style
    if style is None:
        style = {}

    # Sanity checks
    if patch_opacities is not None:
        if isinstance(patch_opacities, torch.Tensor):
            # detach() so that tensors that require grad can be converted too
            patch_opacities = cast(npt.NDArray, patch_opacities.detach().cpu().numpy())
        if patch_opacities.shape != (num_patches, num_patches):
            raise ValueError("The shape of the patch_opacities tensor is not correct.")
        if not np.all((0 <= patch_opacities) & (patch_opacities <= 1)):
            raise ValueError("The patch_opacities tensor must have values between 0 and 1.")

    # If the image is not square, raise an error
    if img.size[0] != img.size[1]:
        raise ValueError("The image must be square.")

    # Get the image as a numpy array
    img_array = np.array(img.convert("RGBA"))  # (H, W, C) where the last channel is the alpha channel

    # Create a figure
    with plt.style.context(style):
        # squeeze=False keeps `axis` two-dimensional even for a 1x1 grid, where
        # plt.subplots would otherwise return a bare Axes and `axis[i, j]` would fail.
        fig, axis = plt.subplots(num_patches, num_patches, figsize=figsize, squeeze=False)

        # Plot the patches
        for i in range(num_patches):
            rows = slice(i * patch_size, (i + 1) * patch_size)
            for j in range(num_patches):
                cols = slice(j * patch_size, (j + 1) * patch_size)
                # `patch` is a view into `img_array`, which is a local copy of the image
                patch = img_array[rows, cols, :]
                # Set the opacity of the patch
                if patch_opacities is not None:
                    patch[:, :, -1] = round(patch_opacities[i, j] * MAX_OPACITY)
                axis[i, j].imshow(patch)
                axis[i, j].axis("off")

        fig.subplots_adjust(wspace=0.1, hspace=0.1)

    fig.tight_layout()

    return fig, axis
|
70 |
-
|
71 |
-
|
72 |
-
def plot_attention_heatmap(
|
73 |
-
img: Image.Image,
|
74 |
-
patch_size: int,
|
75 |
-
image_resolution: int,
|
76 |
-
attention_map: npt.NDArray | torch.Tensor,
|
77 |
-
figsize: Tuple[int, int] = (8, 8),
|
78 |
-
style: Dict[str, Any] | str | None = None,
|
79 |
-
show_colorbar: bool = False,
|
80 |
-
show_axes: bool = False,
|
81 |
-
) -> Tuple[plt.Figure, plt.Axes]:
|
82 |
-
"""
|
83 |
-
Plot a heatmap of the attention map over the image.
|
84 |
-
The image must be square and `attention_map` must be normalized between 0 and 1.
|
85 |
-
"""
|
86 |
-
|
87 |
-
# Get the number of patches
|
88 |
-
if image_resolution % patch_size != 0:
|
89 |
-
raise ValueError("The image resolution must be divisible by the patch size.")
|
90 |
-
num_patches = image_resolution // patch_size
|
91 |
-
|
92 |
-
# Default style
|
93 |
-
if style is None:
|
94 |
-
style = {}
|
95 |
-
|
96 |
-
# Sanity checks
|
97 |
-
if isinstance(attention_map, torch.Tensor):
|
98 |
-
attention_map = cast(npt.NDArray, attention_map.cpu().numpy())
|
99 |
-
if attention_map.shape != (num_patches, num_patches):
|
100 |
-
raise ValueError("The shape of the patch_opacities tensor is not correct.")
|
101 |
-
if not np.all((0 <= attention_map) & (attention_map <= 1)):
|
102 |
-
raise ValueError("The patch_opacities tensor must have values between 0 and 1.")
|
103 |
-
|
104 |
-
# If the image is not square, raise an error
|
105 |
-
if img.size[0] != img.size[1]:
|
106 |
-
raise ValueError("The image must be square.")
|
107 |
-
|
108 |
-
# Get the image as a numpy array
|
109 |
-
img_array = np.array(img.convert("RGBA")) # (H, W, C) where the last channel is the alpha channel
|
110 |
-
|
111 |
-
# Get the attention map as a numpy array
|
112 |
-
attention_map_image = Image.fromarray((attention_map * 255).astype("uint8")).resize(
|
113 |
-
img.size, Image.Resampling.BICUBIC
|
114 |
-
)
|
115 |
-
|
116 |
-
# Create a figure
|
117 |
-
with plt.style.context(style):
|
118 |
-
fig, ax = plt.subplots(figsize=figsize)
|
119 |
-
ax.imshow(img_array)
|
120 |
-
im = ax.imshow(
|
121 |
-
attention_map_image,
|
122 |
-
cmap=sns.color_palette("mako", as_cmap=True),
|
123 |
-
alpha=0.5,
|
124 |
-
)
|
125 |
-
if show_colorbar:
|
126 |
-
fig.colorbar(im)
|
127 |
-
if not show_axes:
|
128 |
-
ax.set_axis_off()
|
129 |
-
fig.tight_layout()
|
130 |
-
|
131 |
-
return fig, ax
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/interpretability/processor.py
DELETED
@@ -1,116 +0,0 @@
|
|
1 |
-
from __future__ import annotations
|
2 |
-
|
3 |
-
from dataclasses import dataclass
|
4 |
-
from typing import List, cast
|
5 |
-
|
6 |
-
import torch
|
7 |
-
from PIL import Image
|
8 |
-
from transformers import LlamaTokenizerFast, PaliGemmaProcessor
|
9 |
-
|
10 |
-
|
11 |
-
@dataclass
|
12 |
-
class ColPaliTextInput:
|
13 |
-
input_ids: torch.Tensor
|
14 |
-
attention_mask: torch.Tensor
|
15 |
-
|
16 |
-
def to(self, device: torch.device) -> ColPaliTextInput:
|
17 |
-
return ColPaliTextInput(
|
18 |
-
input_ids=self.input_ids.to(device),
|
19 |
-
attention_mask=self.attention_mask.to(device),
|
20 |
-
)
|
21 |
-
|
22 |
-
|
23 |
-
@dataclass
|
24 |
-
class ColPaliImageInput:
|
25 |
-
input_ids: torch.Tensor
|
26 |
-
pixel_values: torch.Tensor
|
27 |
-
attention_mask: torch.Tensor
|
28 |
-
|
29 |
-
def to(self, device: str | torch.device) -> ColPaliImageInput:
|
30 |
-
return ColPaliImageInput(
|
31 |
-
input_ids=self.input_ids.to(device),
|
32 |
-
pixel_values=self.pixel_values.to(device),
|
33 |
-
attention_mask=self.attention_mask.to(device),
|
34 |
-
)
|
35 |
-
|
36 |
-
|
37 |
-
class ColPaliProcessor:
    """
    Wrapper around a `PaliGemmaProcessor` that produces `ColPaliTextInput` and
    `ColPaliImageInput` batches and exposes the tokenizer's decoding helpers.
    """

    def __init__(self, processor: PaliGemmaProcessor):
        self.processor = processor
        self.tokenizer = cast(LlamaTokenizerFast, self.processor.tokenizer)  # type: ignore

    @staticmethod
    def from_pretrained(model_name: str) -> ColPaliProcessor:
        """Build a `ColPaliProcessor` from the `PaliGemmaProcessor` stored under `model_name`."""
        return ColPaliProcessor(processor=cast(PaliGemmaProcessor, PaliGemmaProcessor.from_pretrained(model_name)))

    def process_text(
        self,
        text: str | List[str],
        padding: str = "longest",
        return_tensors: str = "pt",
        add_special_tokens: bool = True,
    ) -> ColPaliTextInput:
        """
        Process text inputs for the model.
        If `add_special_tokens` is True (default), each text is prepended with the <bos> token
        and a newline character is appended to it.

        Raises:
            - ValueError: If `add_special_tokens` is True and `text` is neither a string nor a list.
        """
        if add_special_tokens:
            if isinstance(text, str):
                text = self.tokenizer.bos_token + text + "\n"
            elif isinstance(text, list):
                text = [self.tokenizer.bos_token + t + "\n" for t in text]
            else:
                raise ValueError("text must be a string or a list of strings.")

        # NOTE(review): `add_special_tokens` is also forwarded to the tokenizer. If the tokenizer
        # inserts a <bos> token on its own, the text may end up with two of them - confirm.
        batch_output = self.tokenizer(
            text, padding=padding, return_tensors=return_tensors, add_special_tokens=add_special_tokens
        )

        return ColPaliTextInput(
            input_ids=cast(torch.Tensor, batch_output["input_ids"]),
            attention_mask=cast(torch.Tensor, batch_output["attention_mask"]),
        )

    def process_image(
        self,
        image: Image.Image | List[Image.Image],
        padding: str = "longest",
        do_convert_rgb: bool = True,
        return_tensors: str = "pt",
        add_special_prompt: bool = True,
    ) -> ColPaliImageInput:
        """
        Process one image or a list of images for the model.

        If `add_special_prompt` is True (default), every image is paired with the prompt
        "Describe the image.". Otherwise `input_ids` and `attention_mask` are truncated to the
        first `self.processor.image_seq_length` positions.

        Raises:
            - ValueError: If `image` is neither a PIL image nor a list.
        """
        # NOTE: The special prompt was used at training time.
        special_prompt = "Describe the image." if add_special_prompt else None
        if isinstance(image, Image.Image):
            text_input = [special_prompt]
        elif isinstance(image, list):
            text_input = [special_prompt] * len(image)
        else:
            raise ValueError("image must be a PIL Image or a list of PIL Images.")

        # NOTE(review): without the special prompt the processor receives `None` as text for
        # every image - confirm that the wrapped processor accepts this.
        batch_output = self.processor(
            text=text_input,
            images=image,
            padding=padding,
            do_convert_rgb=do_convert_rgb,
            return_tensors=return_tensors,
        )

        input_ids = batch_output["input_ids"]
        attention_mask = batch_output["attention_mask"]
        if not add_special_prompt:
            # Keep only the leading `image_seq_length` positions of the token-level tensors.
            # `pixel_values` is not a token sequence, so it is returned untouched (the previous
            # slice along its second axis by a token count was presumably a no-op - TODO confirm).
            image_seq_length = self.processor.image_seq_length
            input_ids = input_ids[:, :image_seq_length]
            attention_mask = attention_mask[:, :image_seq_length]

        return ColPaliImageInput(
            input_ids=input_ids,
            pixel_values=batch_output["pixel_values"],
            attention_mask=attention_mask,
        )

    def decode(self, *args, **kwargs):
        """Forward to the tokenizer's `decode`."""
        return self.tokenizer.decode(*args, **kwargs)

    def batch_decode(self, *args, **kwargs):
        """Forward to the tokenizer's `batch_decode`."""
        return self.tokenizer.batch_decode(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/interpretability/torch_utils.py
DELETED
@@ -1,60 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
|
3 |
-
import torch
|
4 |
-
|
5 |
-
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Small constant added to the denominator of the min-max normalizations below to avoid division by zero.
EPSILON = 1e-10
|
8 |
-
|
9 |
-
|
10 |
-
def normalize_attention_map_per_query_token(x: torch.Tensor) -> torch.Tensor:
    """
    Min-max normalize a ColPali attention map separately for every query token.
    The result has the same shape as the input and values in the range [0, 1].

    Args:
        x: The attention map tensor of shape (batch_size, n_text_tokens, n_patch_x, n_patch_y).

    Raises:
        ValueError: If `x` does not have exactly 4 dimensions.
    """
    if x.ndim != 4:
        raise ValueError("The input tensor must have 4 dimensions.")

    # Reduce over the two patch axes one after the other; `keepdim=True` leaves size-1 axes
    # behind so that the extrema broadcast against `x`.
    lowest = torch.min(x, dim=-1, keepdim=True).values
    lowest = torch.min(lowest, dim=-2, keepdim=True).values

    highest = torch.max(x, dim=-1, keepdim=True).values
    highest = torch.max(highest, dim=-2, keepdim=True).values

    # EPSILON keeps the denominator non-zero when a map is constant.
    value_range = highest - lowest + EPSILON
    return (x - lowest) / value_range
|
32 |
-
|
33 |
-
|
34 |
-
def normalize_attention_map_per_query(x: torch.Tensor) -> torch.Tensor:
    """
    Normalizes the attention map for ColPali once per query, i.e. with a single
    minimum and maximum shared by all text tokens and all patches of a batch element.
    The output tensor will have values in the range [0, 1] and the
    same shape as the input tensor.

    A warning is logged on every call because this normalization is not
    recommended for ColPali.

    Args:
        x: The attention map tensor of shape (batch_size, n_text_tokens, n_patch_x, n_patch_y).

    Raises:
        ValueError: If `x` does not have exactly 4 dimensions.
    """
    # Log warning
    logger.warning(
        "This function should not be used for ColPali because it doesn't make sense to normalize the attention map across the text tokens."
    )

    if x.ndim != 4:
        raise ValueError("The input tensor must have 4 dimensions.")

    # Compute the minimum values along the last three dimensions (n_text_tokens, n_patch_x, n_patch_y)
    min_vals = x.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[0].min(dim=-3, keepdim=True)[0]

    # Compute the maximum values along the last three dimensions (n_text_tokens, n_patch_x, n_patch_y)
    max_vals = x.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0].max(dim=-3, keepdim=True)[0]

    # Normalize the tensor
    x_normalized = (x - min_vals) / (max_vals - min_vals + EPSILON)  # Adding a small epsilon to avoid division by zero

    return x_normalized
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/interpretability/vit_configs.py
DELETED
@@ -1,23 +0,0 @@
|
|
1 |
-
from dataclasses import dataclass
|
2 |
-
from typing import Dict
|
3 |
-
|
4 |
-
|
5 |
-
@dataclass
class ViTConfig:
    """Patch size and input resolution (both in pixels) of a vision transformer."""

    patch_size: int
    resolution: int

    @property
    def n_patch_per_dim(self) -> int:
        """
        Number of patches along one side of the input.

        Raises:
            ValueError: If the resolution is not a multiple of the patch size.
        """
        leftover = self.resolution % self.patch_size
        if leftover == 0:
            return self.resolution // self.patch_size
        raise ValueError(f"Resolution {self.resolution} is not divisible by patch size {self.patch_size}")
|
15 |
-
|
16 |
-
|
17 |
-
# Patch size and input resolution of the supported vision towers, keyed by model name.
# NOTE(review): 384 is not a multiple of 14, so `ViTConfig.n_patch_per_dim` raises ValueError for
# the two 384-resolution entries; only the 448 entry yields a patch count (32) - confirm intent.
VIT_CONFIG: Dict[str, ViTConfig] = {
    "google/siglip-so400m-patch14-384": ViTConfig(patch_size=14, resolution=384),
    "timm/ViT-SO400M-14-SigLIP-384": ViTConfig(patch_size=14, resolution=384),
    "google/paligemma-3b-mix-448": ViTConfig(
        patch_size=14, resolution=448
    ),  # based on "timm/ViT-SO400M-14-SigLIP-384" with increased resolution
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/loss/__init__.py
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
from .colbert_loss import ColbertLoss
|
|
|
|
colpali_engine/loss/colbert_loss.py
DELETED
@@ -1,122 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn.functional as F
|
3 |
-
from torch.nn import CrossEntropyLoss
|
4 |
-
|
5 |
-
|
6 |
-
class BiEncoderLoss(torch.nn.Module):
    """In-batch cross-entropy loss for single-vector query and document embeddings."""

    def __init__(self):
        super().__init__()
        self.ce_loss = CrossEntropyLoss()

    def forward(self, query_embeddings, doc_embeddings):
        """
        query_embeddings: (batch_size, dim)
        doc_embeddings: (batch_size, dim)

        Returns the cross-entropy over each query's row of dot-product scores,
        where the document at the same batch index is the target.
        """
        # similarity[b, c] is the dot product of query b with document c.
        similarity = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
        targets = torch.arange(similarity.shape[0], device=similarity.device)

        # Only the row-wise (query -> documents) direction is used.
        return self.ce_loss(similarity, targets)
|
24 |
-
|
25 |
-
|
26 |
-
class ColbertLoss(torch.nn.Module):
    """In-batch cross-entropy loss over ColBERT late-interaction (MaxSim) scores."""

    def __init__(self):
        super().__init__()
        self.ce_loss = CrossEntropyLoss()

    def forward(self, query_embeddings, doc_embeddings):
        """
        query_embeddings: (batch_size, num_query_tokens, dim)
        doc_embeddings: (batch_size, num_doc_tokens, dim)

        For every (query, document) pair the score is the sum, over query tokens, of the
        maximum dot product with any document token. The loss is the cross-entropy of each
        query's row of scores, with the document at the same batch index as target.
        """
        # token_scores[b, c, n, s]: dot product of token n of query b with token s of document c.
        token_scores = torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings)
        best_per_query_token = token_scores.max(dim=3).values  # (batch_size, batch_size, num_query_tokens)
        pair_scores = best_per_query_token.sum(dim=2)  # (batch_size, batch_size)

        targets = torch.arange(pair_scores.shape[0], device=pair_scores.device)
        # TODO: comparing between queries might not make sense since it's a sum over the length of the query,
        # so only the row-wise (query -> documents) direction is used.
        return self.ce_loss(pair_scores, targets)
|
58 |
-
|
59 |
-
|
60 |
-
class ColbertPairwiseCELoss(torch.nn.Module):
    """Pairwise loss between the positive document and the hardest in-batch negative (ColBERT scores)."""

    def __init__(self):
        super().__init__()
        self.ce_loss = CrossEntropyLoss()

    def forward(self, query_embeddings, doc_embeddings):
        """
        query_embeddings: (batch_size, num_query_tokens, dim)
        doc_embeddings: (batch_size, num_doc_tokens, dim)

        Positive scores are the diagonal of the scores matrix.
        """

        # Compute the ColBERT scores
        scores = (
            torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings).max(dim=3)[0].sum(dim=2)
        )  # (batch_size, batch_size)

        # Positive scores are the diagonal of the scores matrix.
        pos_scores = scores.diagonal()  # (batch_size,)

        # Negative score for a given query is the maximum of the scores against all other pages.
        # NOTE: We exclude the diagonal by subtracting a very large constant (1e6) from it, so that
        # the positive pair can never be selected by the maximum operation.
        neg_scores = scores - torch.eye(scores.shape[0], device=scores.device) * 1e6  # (batch_size, batch_size)
        neg_scores = neg_scores.max(dim=1)[0]  # (batch_size,)

        # Compute the loss
        # The loss is the negative log of the softmax of the positive score relative to the
        # hardest negative score, i.e. -log softmax([pos, neg])[0].
        # This equals softplus(neg - pos), which is the numerically stable form used here.
        loss = F.softplus(neg_scores - pos_scores).mean()

        return loss
|
96 |
-
|
97 |
-
|
98 |
-
class BiPairwiseCELoss(torch.nn.Module):
    """Pairwise loss between the positive document and the hardest in-batch negative (dot-product scores)."""

    def __init__(self):
        super().__init__()
        self.ce_loss = CrossEntropyLoss()

    def forward(self, query_embeddings, doc_embeddings):
        """
        query_embeddings: (batch_size, dim)
        doc_embeddings: (batch_size, dim)

        Positive scores are the diagonal of the scores matrix.
        """
        similarity = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
        positives = similarity.diagonal()

        # Push the diagonal far down so that the row-wise maximum only sees the other documents.
        diagonal_penalty = torch.eye(similarity.shape[0], device=similarity.device) * 1e6
        hardest_negatives = (similarity - diagonal_penalty).max(dim=1).values

        # softplus(neg - pos) is the negative log of the softmax of the positive score
        # relative to the hardest negative score.
        margin = hardest_negatives - positives
        return F.softplus(margin).mean()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/models/__init__.py
DELETED
File without changes
|
colpali_engine/models/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (170 Bytes)
|
|
colpali_engine/models/__pycache__/paligemma_colbert_architecture.cpython-310.pyc
DELETED
Binary file (4.87 kB)
|
|
colpali_engine/models/clip_baselines.py
DELETED
@@ -1,144 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
from typing import Optional
|
3 |
-
|
4 |
-
import torch
|
5 |
-
from transformers import SiglipModel
|
6 |
-
|
7 |
-
|
8 |
-
class SigLIP(SiglipModel):
    """SigLIP bi-encoder baseline: one L2-normalized embedding per image or per text."""

    def forward(self, *args, **kwargs):
        """
        Embed a batch of images or a batch of texts with SigLIP.
        All arguments are forwarded to `forward_branch`.

        Args:
            - input_ids (torch.LongTensor): The input tokens tensor (text inputs).
            - pixel_values (torch.FloatTensor): The image tensor (image inputs).
            - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
            - torch.Tensor: L2-normalized embeddings taken from index 1 of the tower outputs
              (presumably the pooled output of shape (batch_size, dim) - TODO confirm).
        """
        return self.forward_branch(*args, **kwargs)

    def forward_branch(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ):
        """
        Run the vision tower when `pixel_values` is given, otherwise the text tower,
        and return the L2-normalized second element of the tower outputs.
        `return_loss` is accepted but not used.
        """

        # Fall back to the model config for the output options that were not specified.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is not None:
            # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.

            outputs = self.vision_model(
                pixel_values=pixel_values.to(dtype=self.dtype),
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                interpolate_pos_encoding=interpolate_pos_encoding,
            )

        else:
            outputs = self.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        embeds = outputs[1]

        # normalized features
        embeds = embeds / embeds.norm(p=2, dim=-1, keepdim=True)
        return embeds
|
67 |
-
|
68 |
-
|
69 |
-
class ColSigLIP(SiglipModel):
    """SigLIP with per-token linear projections to 128 dimensions for late-interaction retrieval."""

    def __init__(self, config):
        super(ColSigLIP, self).__init__(config=config)
        self.dim = 128
        self.custom_vision_proj = torch.nn.Linear(self.config.vision_config.hidden_size, self.dim)
        self.custom_text_proj = torch.nn.Linear(self.config.text_config.hidden_size, self.dim)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through a SigLIP tower and the matching linear layer for dimensionality reduction.
        All arguments are forwarded to `forward_branch`.

        Args:
            - input_ids (torch.LongTensor): The input tokens tensor (text inputs).
            - pixel_values (torch.FloatTensor): The image tensor (image inputs).
            - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
            - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        return self.forward_branch(*args, **kwargs)

    def forward_branch(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ):
        """
        Run the vision tower when `pixel_values` is given, otherwise the text tower, then
        project the last hidden states and L2-normalize every token embedding.
        For text, masked positions are zeroed when an `attention_mask` is provided.
        `return_loss` is accepted but not used.
        """

        # Fall back to the model config for the output options that were not specified.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is not None:
            # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.

            outputs = self.vision_model(
                pixel_values=pixel_values.to(dtype=self.dtype),
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                interpolate_pos_encoding=interpolate_pos_encoding,
            )

            last_hidden_states = outputs.last_hidden_state

            proj = self.custom_vision_proj(last_hidden_states)
            # normalize l2 norm
            proj = proj / proj.norm(dim=-1, keepdim=True)

        else:
            outputs = self.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            last_hidden_states = outputs.last_hidden_state

            proj = self.custom_text_proj(last_hidden_states)
            # normalize l2 norm
            proj = proj / proj.norm(dim=-1, keepdim=True)
            # `attention_mask` is optional: without one there is nothing to zero out.
            if attention_mask is not None:
                proj = proj * attention_mask.unsqueeze(-1)

        # normalized features
        return proj
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/models/colbert_architectures.py
DELETED
@@ -1,177 +0,0 @@
|
|
1 |
-
from torch import nn
|
2 |
-
from transformers import (
|
3 |
-
BertModel,
|
4 |
-
BertPreTrainedModel,
|
5 |
-
CamembertModel,
|
6 |
-
CamembertPreTrainedModel,
|
7 |
-
LlamaModel,
|
8 |
-
LlamaPreTrainedModel,
|
9 |
-
XLMRobertaModel,
|
10 |
-
XLMRobertaPreTrainedModel,
|
11 |
-
)
|
12 |
-
|
13 |
-
|
14 |
-
class ColCamembert(CamembertPreTrainedModel):
    """CamemBERT encoder followed by a per-token linear projection to 128 dimensions."""

    def __init__(self, config):
        super().__init__(config=config)
        self.roberta: CamembertPreTrainedModel = CamembertModel(config)
        self.dim = 128
        self.linear = nn.Linear(self.roberta.config.hidden_size, self.dim)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through CamemBERT and the linear layer for dimensionality reduction.
        `attention_mask` must be passed as a keyword argument.

        Args:
            - input_ids (torch.LongTensor): The input tokens tensor.
            - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
            - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        hidden_states = self.roberta(*args, **kwargs)[0]  # (batch_size, sequence_length, hidden_size)
        token_embeddings = self.linear(hidden_states)
        # L2-normalize every token embedding, then zero out the masked positions.
        token_embeddings = token_embeddings / token_embeddings.norm(dim=-1, keepdim=True)
        token_mask = kwargs["attention_mask"].unsqueeze(-1)
        return token_embeddings * token_mask
|
40 |
-
|
41 |
-
|
42 |
-
class ColXLMRoBERTa(XLMRobertaPreTrainedModel):
    """XLM-RoBERTa encoder followed by a per-token linear projection to 128 dimensions."""

    def __init__(self, config):
        super().__init__(config=config)
        self.roberta: XLMRobertaPreTrainedModel = XLMRobertaModel(config)
        self.dim = 128
        self.linear = nn.Linear(self.roberta.config.hidden_size, self.dim)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through XLM-RoBERTa and the linear layer for dimensionality reduction.
        `attention_mask` must be passed as a keyword argument.

        Args:
            - input_ids (torch.LongTensor): The input tokens tensor.
            - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
            - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        hidden_states = self.roberta(*args, **kwargs)[0]  # (batch_size, sequence_length, hidden_size)
        token_embeddings = self.linear(hidden_states)
        # L2-normalize every token embedding, then zero out the masked positions.
        token_embeddings = token_embeddings / token_embeddings.norm(dim=-1, keepdim=True)
        token_mask = kwargs["attention_mask"].unsqueeze(-1)
        return token_embeddings * token_mask
|
68 |
-
|
69 |
-
|
70 |
-
class BiXLMRoBERTa(XLMRobertaPreTrainedModel):
    """XLM-RoBERTa bi-encoder: one mean-pooled, L2-normalized embedding per sequence."""

    def __init__(self, config):
        super(BiXLMRoBERTa, self).__init__(config=config)
        self.roberta: XLMRobertaPreTrainedModel = XLMRobertaModel(config)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through XLM-RoBERTa followed by masked mean pooling over the tokens.
        No projection layer is applied. `attention_mask` must be passed as a keyword argument.

        Args:
            - input_ids (torch.LongTensor): The input tokens tensor.
            - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
            - torch.Tensor: L2-normalized embeddings of shape (batch_size, hidden_size)
        """
        outputs = self.roberta(*args, **kwargs)
        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
        # pooling - mean tokens that have attention mask == 1
        proj = last_hidden_states * kwargs["attention_mask"].unsqueeze(-1)
        proj = proj.sum(dim=1) / kwargs["attention_mask"].sum(dim=1, keepdim=True)
        # normalize l2 norm
        proj = proj / proj.norm(dim=-1, keepdim=True)
        return proj
|
95 |
-
|
96 |
-
|
97 |
-
class ColBERT(BertPreTrainedModel):
    """BERT encoder followed by a per-token linear projection to 128 dimensions."""

    def __init__(self, config):
        super().__init__(config=config)
        self.bert: BertModel = BertModel(config)
        self.dim = 128
        self.linear = nn.Linear(self.bert.config.hidden_size, self.dim)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through BERT and the linear layer for dimensionality reduction.
        `attention_mask` must be passed as a keyword argument.

        Args:
            - input_ids (torch.LongTensor): The input tokens tensor.
            - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
            - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        hidden_states = self.bert(*args, **kwargs)[0]  # (batch_size, sequence_length, hidden_size)
        token_embeddings = self.linear(hidden_states)
        # L2-normalize every token embedding, then zero out the masked positions.
        token_embeddings = token_embeddings / token_embeddings.norm(dim=-1, keepdim=True)
        token_mask = kwargs["attention_mask"].unsqueeze(-1)
        return token_embeddings * token_mask
|
123 |
-
|
124 |
-
|
125 |
-
class BiBERT(BertPreTrainedModel):
    """BERT bi-encoder: one mean-pooled, L2-normalized embedding per sequence."""

    def __init__(self, config):
        super(BiBERT, self).__init__(config=config)
        self.bert: BertModel = BertModel(config)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through BERT followed by masked mean pooling over the tokens.
        No projection layer is applied. `attention_mask` must be passed as a keyword argument.

        Args:
            - input_ids (torch.LongTensor): The input tokens tensor.
            - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
            - torch.Tensor: L2-normalized embeddings of shape (batch_size, hidden_size)
        """
        outputs = self.bert(*args, **kwargs)
        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
        # pooling - mean tokens that have attention mask == 1
        proj = last_hidden_states * kwargs["attention_mask"].unsqueeze(-1)
        proj = proj.sum(dim=1) / kwargs["attention_mask"].sum(dim=1, keepdim=True)
        # normalize l2 norm
        proj = proj / proj.norm(dim=-1, keepdim=True)
        return proj
|
150 |
-
|
151 |
-
|
152 |
-
class ColLlama(LlamaPreTrainedModel):
    """Llama model followed by a per-token linear projection to 128 dimensions."""

    def __init__(self, config):
        super().__init__(config=config)
        self.model: LlamaModel = LlamaModel(config)
        self.dim = 128
        self.linear = nn.Linear(self.model.config.hidden_size, self.dim)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through Llama and the linear layer for dimensionality reduction.
        `attention_mask` must be passed as a keyword argument.

        Args:
            - input_ids (torch.LongTensor): The input tokens tensor.
            - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
            - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        hidden_states = self.model(*args, **kwargs)[0]  # (batch_size, sequence_length, hidden_size)
        token_embeddings = self.linear(hidden_states)
        # L2-normalize every token embedding, then zero out the masked positions.
        token_embeddings = token_embeddings / token_embeddings.norm(dim=-1, keepdim=True)
        token_mask = kwargs["attention_mask"].unsqueeze(-1)
        return token_embeddings * token_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/models/idefics_colbert_architecture.py
DELETED
@@ -1,57 +0,0 @@
|
|
1 |
-
from torch import nn
|
2 |
-
from transformers import Idefics2Model, Idefics2PreTrainedModel
|
3 |
-
|
4 |
-
|
5 |
-
class BiIdefics(Idefics2PreTrainedModel):
    """Idefics2 bi-encoder: one L2-normalised vector per sequence (last-token pooling)."""

    def __init__(self, config):
        super().__init__(config=config)
        self.model: Idefics2Model = Idefics2Model(config)
        self.pooling_strategy = "last"
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through Idefics2 followed by last-token pooling.

        No projection layer is applied: the hidden state at the last sequence
        position is taken and L2-normalised.

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, hidden_size)
        """
        outputs = self.model(*args, **kwargs)
        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
        # pooling - last token
        # NOTE(review): position -1 is a padding token if the batch is right-padded;
        # presumably callers left-pad - confirm against the processor configuration.
        proj = last_hidden_states[:, -1, :]
        # normalize l2 norm
        proj = proj / proj.norm(dim=-1, keepdim=True)
        return proj
|
30 |
-
|
31 |
-
|
32 |
-
class ColIdefics(Idefics2PreTrainedModel):
    """Idefics2 backbone with a linear head producing one 128-d vector per token."""

    def __init__(self, config):
        super().__init__(config=config)
        self.model: Idefics2Model = Idefics2Model(config)
        self.dim = 128
        self.linear = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through Idefics2 and the linear layer for dimensionality reduction

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor. It must be
          passed by keyword because it is read back from ``kwargs`` for masking.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)

        Raises:
        - KeyError: if ``attention_mask`` is not supplied as a keyword argument.
        """
        outputs = self.model(*args, **kwargs)
        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
        proj = self.linear(last_hidden_states)
        # normalize l2 norm
        proj = proj / proj.norm(dim=-1, keepdim=True)
        # zero out the embeddings of masked (attention_mask == 0) positions
        proj = proj * kwargs["attention_mask"].unsqueeze(-1)
        return proj
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/models/paligemma_colbert_architecture.py
DELETED
@@ -1,191 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from torch import nn
|
3 |
-
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration, PaliGemmaPreTrainedModel
|
4 |
-
|
5 |
-
|
6 |
-
class BiPaliLast(PaliGemmaPreTrainedModel):
    """PaliGemma bi-encoder: one L2-normalised vector per sequence (last-token pooling)."""

    def __init__(self, config):
        super().__init__(config=config)
        self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
        self.pooling_strategy = "last"
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through PaliGemma followed by last-token pooling.

        No projection layer is applied: the final-layer hidden state at the last
        sequence position is taken and L2-normalised.

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, hidden_size)
        """
        outputs = self.model(*args, output_hidden_states=True, **kwargs)
        last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
        # pooling - last token
        # NOTE(review): position -1 is a padding token if the batch is right-padded;
        # presumably callers left-pad - confirm against the processor configuration.
        proj = last_hidden_states[:, -1, :]
        # normalize l2 norm
        proj = proj / proj.norm(dim=-1, keepdim=True)
        return proj
|
31 |
-
|
32 |
-
|
33 |
-
class BiPaliMean(PaliGemmaPreTrainedModel):
    """PaliGemma bi-encoder: one L2-normalised vector per sequence (masked mean pooling)."""

    def __init__(self, config):
        super().__init__(config=config)
        self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
        self.pooling_strategy = "mean"
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through PaliGemma followed by masked mean pooling.

        No projection layer is applied: the final-layer hidden states are averaged
        over the positions whose attention mask is 1 and the result is L2-normalised.

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor. It must be
          passed by keyword because it is read back from ``kwargs`` for pooling.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, hidden_size)

        Raises:
        - KeyError: if ``attention_mask`` is not supplied as a keyword argument.
        """
        outputs = self.model(*args, output_hidden_states=True, **kwargs)
        last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
        attention_mask = kwargs["attention_mask"]
        # pooling - mean on attention mask == 1
        summed = torch.sum(last_hidden_states * attention_mask.unsqueeze(-1), dim=1)
        proj = summed / torch.sum(attention_mask, dim=1, keepdim=True)
        # normalize l2 norm
        proj = proj / proj.norm(dim=-1, keepdim=True)
        return proj
|
59 |
-
|
60 |
-
|
61 |
-
class ColPali(PaliGemmaPreTrainedModel):
    """PaliGemma backbone with a linear head producing one 128-d vector per token."""

    def __init__(self, config):
        super().__init__(config=config)
        self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
        self.dim = 128
        self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through PaliGemma and the linear layer for dimensionality reduction

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor. It must be
          passed by keyword because it is read back from ``kwargs`` for masking.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)

        Raises:
        - KeyError: if ``attention_mask`` is not supplied as a keyword argument.
        """
        outputs = self.model(*args, output_hidden_states=True, **kwargs)
        last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
        proj = self.custom_text_proj(last_hidden_states)
        # normalize l2 norm
        proj = proj / proj.norm(dim=-1, keepdim=True)
        # zero out the embeddings of masked (attention_mask == 0) positions
        proj = proj * kwargs["attention_mask"].unsqueeze(-1)
        return proj
|
87 |
-
|
88 |
-
|
89 |
-
class ColNewSiglip(PaliGemmaPreTrainedModel):
    """Multi-vector encoder: vision-tower features for documents, language model for queries."""

    def __init__(self, config):
        super().__init__(config=config)
        self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
        self.dim = 128
        self.custom_image_proj = nn.Linear(self.model.config.vision_config.projection_dim, self.dim)
        self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Embed a batch of documents (images) or queries (text).

        When ``pixel_values`` is passed by keyword the batch is treated as
        documents: the projected vision-tower features go through
        ``custom_image_proj`` and are L2-normalised (no attention masking).
        Otherwise the batch is treated as queries: the final-layer hidden states
        of the full model go through ``custom_text_proj``, are L2-normalised and
        masked positions are zeroed.

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor. Required by
          keyword on the query path.
        - pixel_values (torch.FloatTensor, optional): Image tensor; its presence
          selects the document path.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        if "pixel_values" in kwargs:
            image_features = self.vision_model_output(*args, **kwargs)
            proj = self.custom_image_proj(image_features)
            proj = proj / proj.norm(dim=-1, keepdim=True)
            return proj

        outputs = self.model(*args, output_hidden_states=True, **kwargs)
        last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
        proj = self.custom_text_proj(last_hidden_states)
        # normalize l2 norm
        proj = proj / proj.norm(dim=-1, keepdim=True)
        proj = proj * kwargs["attention_mask"].unsqueeze(-1)
        return proj

    def vision_model_output(self, input_ids: torch.LongTensor = None, pixel_values: torch.FloatTensor = None, **kwargs):
        """
        Run the vision tower and the multi-modal projector on ``pixel_values``.

        Args:
        - input_ids (torch.LongTensor): Token ids; only their sequence length is checked.
        - pixel_values (torch.FloatTensor): Image tensor fed to the vision tower.

        Returns:
        - torch.Tensor: Output of ``multi_modal_projector`` on the vision tower's
          last hidden state.

        Raises:
        - ValueError: if ``pixel_values`` is None, or ``input_ids`` is None or has
          a sequence length of 1.
        """
        # Validate first so a missing input raises the documented ValueError
        # instead of failing inside the embedding layer.
        if pixel_values is None or input_ids is None or input_ids.shape[1] == 1:
            raise ValueError("pixel_values is None or input_ids.shape[1] == 1")

        # Only the embedding dtype is needed to cast the pixels; reading it from
        # the weight avoids a full embedding lookup over input_ids.
        embed_dtype = self.model.get_input_embeddings().weight.dtype
        image_outputs = self.model.vision_tower(pixel_values.to(embed_dtype))
        selected_image_feature = image_outputs.last_hidden_state
        return self.model.multi_modal_projector(selected_image_feature)
|
139 |
-
|
140 |
-
|
141 |
-
class BiNewSiglip(PaliGemmaPreTrainedModel):
    """Single-vector encoder: pooled vision-tower features for documents, pooled LM states for queries."""

    def __init__(self, config):
        super().__init__(config=config)
        self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Embed a batch of documents (images) or queries (text) as one vector each.

        When ``pixel_values`` is passed by keyword the projected vision-tower
        features are averaged over the token axis. Otherwise the final-layer
        hidden states of the full model are averaged over the positions whose
        attention mask is 1. Both paths L2-normalise the pooled vector; no extra
        projection layer is applied.

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor. Required by
          keyword on the query path.
        - pixel_values (torch.FloatTensor, optional): Image tensor; its presence
          selects the document path.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, feature_dim)
        """
        if "pixel_values" in kwargs:
            image_features = self.vision_model_output(*args, **kwargs)
            # pool image features
            proj = torch.mean(image_features, dim=1)
        else:
            outputs = self.model(*args, output_hidden_states=True, **kwargs)
            last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
            attention_mask = kwargs["attention_mask"]
            # pooling - mean on attention mask == 1
            summed = torch.sum(last_hidden_states * attention_mask.unsqueeze(-1), dim=1)
            proj = summed / torch.sum(attention_mask, dim=1, keepdim=True)

        # normalize l2 norm
        norm = proj.norm(dim=-1, keepdim=True)
        return proj / norm

    def vision_model_output(self, input_ids: torch.LongTensor = None, pixel_values: torch.FloatTensor = None, **kwargs):
        """
        Run the vision tower and the multi-modal projector on ``pixel_values``.

        Args:
        - input_ids (torch.LongTensor): Token ids; only their sequence length is checked.
        - pixel_values (torch.FloatTensor): Image tensor fed to the vision tower.

        Returns:
        - torch.Tensor: Output of ``multi_modal_projector`` on the vision tower's
          last hidden state.

        Raises:
        - ValueError: if ``pixel_values`` is None, or ``input_ids`` is None or has
          a sequence length of 1.
        """
        # Validate first so a missing input raises the documented ValueError
        # instead of failing inside the embedding layer.
        if pixel_values is None or input_ids is None or input_ids.shape[1] == 1:
            raise ValueError("pixel_values is None or input_ids.shape[1] == 1")

        # Only the embedding dtype is needed to cast the pixels; reading it from
        # the weight avoids a full embedding lookup over input_ids.
        embed_dtype = self.model.get_input_embeddings().weight.dtype
        image_outputs = self.model.vision_tower(pixel_values.to(embed_dtype))
        selected_image_feature = image_outputs.last_hidden_state
        return self.model.multi_modal_projector(selected_image_feature)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/trainer/__init__.py
DELETED
File without changes
|
colpali_engine/trainer/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (171 Bytes)
|
|
colpali_engine/trainer/__pycache__/retrieval_evaluator.cpython-310.pyc
DELETED
Binary file (3.18 kB)
|
|
colpali_engine/trainer/contrastive_trainer.py
DELETED
@@ -1,64 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from transformers import Trainer
|
3 |
-
|
4 |
-
|
5 |
-
class ContrastiveTrainer(Trainer):
    """Trainer that scores query embeddings against document embeddings with ``loss_func``."""

    def __init__(self, loss_func, is_vision_model, *args, **kwargs):
        """
        Args:
        - loss_func: Callable taking ``(query_outputs, doc_outputs)`` and returning the loss.
        - is_vision_model (bool): Whether documents are images (``doc_pixel_values``
          is forwarded to the model) or plain text.
        - *args, **kwargs: Forwarded unchanged to ``transformers.Trainer``.
        """
        super().__init__(*args, **kwargs)
        self.loss_func = loss_func
        self.is_vision_model = is_vision_model

    def _encode_queries(self, model, inputs):
        """Run the model on the query half of the batch."""
        return model(input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"])

    def _encode_docs(self, model, inputs):
        """Run the model on the document half of the batch, adding image inputs for vision models."""
        doc_kwargs = {
            "input_ids": inputs["doc_input_ids"],
            "attention_mask": inputs["doc_attention_mask"],
        }
        if self.is_vision_model:
            doc_kwargs["pixel_values"] = inputs["doc_pixel_values"]
            # Only some processors emit a pixel attention mask; forward it when present.
            if "doc_pixel_attention_mask" in inputs:
                doc_kwargs["pixel_attention_mask"] = inputs["doc_pixel_attention_mask"]
        return model(**doc_kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute the contrastive loss for one batch.

        Args:
        - model: The model being trained.
        - inputs (dict): Batch with ``query_*`` and ``doc_*`` entries.
        - return_outputs (bool): Also return ``(query_outputs, doc_outputs)``.
        - num_items_in_batch: Accepted for compatibility with Trainer versions that
          pass it; unused because the loss is computed by ``loss_func``.

        Returns:
        - The loss, or ``(loss, (query_outputs, doc_outputs))`` when ``return_outputs`` is True.
        """
        query_outputs = self._encode_queries(model, inputs)
        doc_outputs = self._encode_docs(model, inputs)

        loss = self.loss_func(query_outputs, doc_outputs)
        return (loss, (query_outputs, doc_outputs)) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=True):
        """
        Compute the evaluation loss for one batch without tracking gradients.

        Args:
        - model: The model being evaluated.
        - inputs (dict): Batch with ``query_*`` and ``doc_*`` entries.
        - prediction_loss_only (bool): Must be True; predictions are never returned.
        - ignore_keys: Unused.

        Returns:
        - tuple: ``(loss, None, None)``.

        Raises:
        - ValueError: if ``prediction_loss_only`` is False.
        """
        if not prediction_loss_only:
            raise ValueError("prediction_step is only called with prediction_loss_only=True")

        with torch.no_grad():
            query_outputs = self._encode_queries(model, inputs)
            doc_outputs = self._encode_docs(model, inputs)
            loss = self.loss_func(query_outputs, doc_outputs)
            return loss, None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/trainer/retrieval_evaluator.py
DELETED
@@ -1,72 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from mteb.evaluation.evaluators import RetrievalEvaluator
|
3 |
-
|
4 |
-
|
5 |
-
class CustomEvaluator:
    """Scores queries against passages and wraps the mteb retrieval metrics."""

    def __init__(self, is_multi_vector=False):
        """
        Args:
        - is_multi_vector (bool): Use late-interaction (ColBERT-style) scoring
          instead of single-vector dot products.
        """
        self.is_multi_vector = is_multi_vector
        self.mteb_evaluator = RetrievalEvaluator()

    def evaluate(self, qs, ps):
        """
        Score every query against every passage.

        Also prints the arg-max passage per query and the top-1 accuracy obtained
        when query ``i`` is assumed to match passage ``i``.

        Args:
        - qs (list[torch.Tensor]): Query embeddings, one tensor per query.
        - ps (list[torch.Tensor]): Passage embeddings, one tensor per passage.

        Returns:
        - numpy.ndarray: float32 score matrix of shape (len(qs), len(ps)).
        """
        if self.is_multi_vector:
            scores = self.evaluate_colbert(qs, ps)
        else:
            scores = self.evaluate_biencoder(qs, ps)

        assert scores.shape[0] == len(qs)

        arg_score = scores.argmax(dim=1)
        # compare to arange: query i is expected to retrieve passage i
        expected = torch.arange(scores.shape[0], device=scores.device)
        accuracy = (arg_score == expected).sum().item() / scores.shape[0]
        print(arg_score)
        print(f"Top 1 Accuracy (verif): {accuracy}")

        # cast to float32 first, then to numpy
        scores = scores.to(torch.float32).cpu().numpy()
        return scores

    def compute_metrics(self, relevant_docs, results, **kwargs):
        """
        Compute retrieval metrics through the wrapped mteb evaluator.

        Args:
        - relevant_docs: Relevance judgements, forwarded to mteb.
        - results: Retrieval results, forwarded to mteb.
        - ignore_identical_ids (bool, keyword): Forwarded to mteb; defaults to True.

        Returns:
        - dict: ``{"<metric>_at_<k>": value}`` for ndcg, map, recall, precision, mrr and naucs.
        """
        # wrap mteb package
        ndcg, _map, recall, precision, naucs = self.mteb_evaluator.evaluate(
            relevant_docs,
            results,
            self.mteb_evaluator.k_values,
            ignore_identical_ids=kwargs.get("ignore_identical_ids", True),
        )
        mrr = self.mteb_evaluator.evaluate_custom(relevant_docs, results, self.mteb_evaluator.k_values, "mrr")

        # mteb keys look like "<NAME>@<k>"; keep only the part after "@".
        named_metrics = (
            ("ndcg", ndcg),
            ("map", _map),
            ("recall", recall),
            ("precision", precision),
            ("mrr", mrr[0]),
            ("naucs", naucs),
        )
        scores = {
            f"{name}_at_{key.split('@')[1]}": value
            for name, values in named_metrics
            for key, value in values.items()
        }
        return scores

    def evaluate_colbert(self, qs, ps, batch_size=128) -> torch.Tensor:
        """
        Late-interaction scoring: for each query token take the best-matching
        passage token, then sum over query tokens.

        Queries and passages are zero-padded per batch and scored on CPU in
        ``batch_size`` x ``batch_size`` tiles.

        Args:
        - qs (list[torch.Tensor]): Per-query tensors of shape (num_tokens, dim).
        - ps (list[torch.Tensor]): Per-passage tensors of shape (num_tokens, dim).
        - batch_size (int): Tile size along both the query and passage axes.

        Returns:
        - torch.Tensor: Score matrix of shape (len(qs), len(ps)).
        """
        score_rows = []
        for q_start in range(0, len(qs), batch_size):
            qs_batch = torch.nn.utils.rnn.pad_sequence(
                qs[q_start : q_start + batch_size], batch_first=True, padding_value=0
            ).to("cpu")
            row_tiles = []
            for p_start in range(0, len(ps), batch_size):
                ps_batch = torch.nn.utils.rnn.pad_sequence(
                    ps[p_start : p_start + batch_size], batch_first=True, padding_value=0
                ).to("cpu")
                # (b, c, n, s) token similarities -> max over passage tokens -> sum over query tokens
                token_sims = torch.einsum("bnd,csd->bcns", qs_batch, ps_batch)
                row_tiles.append(token_sims.max(dim=3)[0].sum(dim=2))
            score_rows.append(torch.cat(row_tiles, dim=1).cpu())
        scores = torch.cat(score_rows, dim=0)
        return scores

    def evaluate_biencoder(self, qs, ps) -> torch.Tensor:
        """
        Dot-product scoring of single-vector embeddings.

        Args:
        - qs (list[torch.Tensor]): Per-query vectors of shape (dim,).
        - ps (list[torch.Tensor]): Per-passage vectors of shape (dim,).

        Returns:
        - torch.Tensor: Score matrix of shape (len(qs), len(ps)).
        """
        qs = torch.stack(qs)
        ps = torch.stack(ps)

        scores = torch.einsum("bd,cd->bc", qs, ps)
        return scores
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/utils/__init__.py
DELETED
File without changes
|
colpali_engine/utils/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (169 Bytes)
|
|
colpali_engine/utils/__pycache__/colpali_processing_utils.cpython-310.pyc
DELETED
Binary file (1.2 kB)
|
|
colpali_engine/utils/__pycache__/image_from_page_utils.cpython-310.pyc
DELETED
Binary file (998 Bytes)
|
|
colpali_engine/utils/colidefics_processing_utils.py
DELETED
@@ -1,53 +0,0 @@
|
|
1 |
-
# Utils for processing images and queries for ColIdefics
|
2 |
-
|
3 |
-
def process_images(processor, images, max_length: int = 50):
    """
    Build a document batch: every image is paired with the same chat prompt.

    Args:
    - processor: Object exposing ``apply_chat_template`` and callable with
      ``text``/``images``/``return_tensors``/``padding`` keywords.
    - images: Images exposing ``convert("RGB")``.
    - max_length (int): Unused here; kept for signature parity with ``process_queries``.

    Returns:
    - Whatever ``processor(...)`` returns for the batch.
    """
    images = [image.convert("RGB") for image in images]

    texts_doc = []
    if images:
        # The prompt does not depend on the image, so render the template once
        # and reuse it for every image.
        messages_doc = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe the image."},
                    {"type": "image"},
                ],
            },
        ]
        text_doc = processor.apply_chat_template(messages_doc, add_generation_prompt=False).strip()
        texts_doc = [text_doc] * len(images)

    batch_doc = processor(
        text=texts_doc,
        images=images,
        return_tensors="pt",
        padding="longest",
    )
    return batch_doc
|
28 |
-
|
29 |
-
|
30 |
-
def process_queries(processor, queries, mock_image, max_length: int = 50):
    """
    Build a text-only query batch.

    Each query is wrapped as ``"Question: <query>"`` followed by five
    ``<end_of_utterance>`` tokens, rendered through the chat template.

    Args:
    - processor: Object exposing ``apply_chat_template`` and callable with
      ``text``/``return_tensors``/``padding``/``max_length`` keywords.
    - queries: Iterable of query strings.
    - mock_image: Unused here; kept so the signature matches the other
      ``process_queries`` implementations.
    - max_length (int): Forwarded to the processor.

    Returns:
    - Whatever ``processor(...)`` returns for the batch.
    """
    suffix = "<end_of_utterance>" * 5
    texts_query = []
    for query in queries:
        messages_query = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": f"Question: {query}{suffix}",
                    },
                ],
            },
        ]
        text_query = processor.apply_chat_template(messages_query, add_generation_prompt=False).strip()
        texts_query.append(text_query)

    batch_query = processor(
        text=texts_query,
        return_tensors="pt",
        padding="longest",
        max_length=max_length,
    )
    return batch_query
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/utils/colpali_processing_utils.py
DELETED
@@ -1,36 +0,0 @@
|
|
1 |
-
# Utils for processing images and queries for ColPaLi
|
2 |
-
|
3 |
-
|
4 |
-
def process_images(processor, images, max_length: int = 50):
    """
    Build a document batch: every image is paired with the prompt "Describe the image.".

    Args:
    - processor: Callable processor exposing ``image_seq_length``.
    - images: Images exposing ``convert("RGB")``.
    - max_length (int): Text budget; ``processor.image_seq_length`` is added on
      top before it is passed to the processor.

    Returns:
    - Whatever ``processor(...)`` returns for the batch.
    """
    texts_doc = ["Describe the image."] * len(images)
    images = [image.convert("RGB") for image in images]

    batch_doc = processor(
        text=texts_doc,
        images=images,
        return_tensors="pt",
        padding="longest",
        max_length=max_length + processor.image_seq_length,
    )
    return batch_doc
|
16 |
-
|
17 |
-
|
18 |
-
def process_queries(processor, queries, mock_image, max_length: int = 50):
    """
    Build a text-only query batch from a processor that insists on an image.

    Each query becomes ``"Question: <query>"`` followed by five ``<unused0>``
    tokens. ``mock_image`` is passed only because the processor requires an
    image; ``pixel_values`` is then dropped and the first
    ``processor.image_seq_length`` positions are cut from ``input_ids`` and
    ``attention_mask``.

    Args:
    - processor: Callable processor exposing ``image_seq_length``.
    - queries: Iterable of query strings.
    - mock_image: Placeholder image exposing ``convert("RGB")``.
    - max_length (int): Text budget; ``processor.image_seq_length`` is added on
      top before it is passed to the processor.

    Returns:
    - The processor output without ``pixel_values`` and without the leading image positions.
    """
    texts_query = [f"Question: {query}<unused0><unused0><unused0><unused0><unused0>" for query in queries]

    batch_query = processor(
        # NOTE: the image is not used in batch_query but it is required for calling the processor
        images=[mock_image.convert("RGB")] * len(texts_query),
        text=texts_query,
        return_tensors="pt",
        padding="longest",
        max_length=max_length + processor.image_seq_length,
    )
    del batch_query["pixel_values"]

    # Drop the leading image positions so only the text positions remain.
    n_image_tokens = processor.image_seq_length
    batch_query["input_ids"] = batch_query["input_ids"][..., n_image_tokens:]
    batch_query["attention_mask"] = batch_query["attention_mask"][..., n_image_tokens:]
    return batch_query
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/utils/dataset_transformation.py
DELETED
@@ -1,158 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
|
3 |
-
from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
|
4 |
-
|
5 |
-
# True when the USE_LOCAL_DATASET environment variable is "1" (also the default
# when it is unset). The loaders in this module then read from ./data_dir/
# instead of the hub repositories.
USE_LOCAL_DATASET = os.environ.get("USE_LOCAL_DATASET", "1") == "1"
|
6 |
-
|
7 |
-
|
8 |
-
def add_metadata_column(dataset, column_name, value):
    """
    Return ``dataset`` with ``column_name`` set to the constant ``value`` on every example.

    Args:
    - dataset: Object exposing ``map(fn)``.
    - column_name: Key written on each example.
    - value: Value stored under ``column_name``.

    Returns:
    - The result of ``dataset.map``.
    """

    def add_source(example):
        example[column_name] = value
        return example

    return dataset.map(add_source)
|
14 |
-
|
15 |
-
|
16 |
-
# Dataset names shared by every training mix, in concatenation order.
_CORE_QA_TRAIN_SETS = (
    "infovqa_train",
    "docvqa_train",
    "arxivqa_train",
    "tatdqa_train",
)
_SYNTHETIC_TRAIN_SETS = (
    "syntheticDocQA_government_reports_train",
    "syntheticDocQA_healthcare_industry_train",
    "syntheticDocQA_artificial_intelligence_train",
    "syntheticDocQA_energy_train",
)


def _load_and_split_train_sets(ds_paths) -> DatasetDict:
    """Load and concatenate the named train splits, shuffle, and hold out the first 500 rows as "test"."""
    base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
    ds_tot = []
    for path in ds_paths:
        ds = load_dataset(base_path + path, split="train")
        if "arxivqa" in path:
            # subsample 10k (or everything, if the split is smaller than that)
            ds = ds.shuffle(seed=42).select(range(min(10000, len(ds))))
        ds_tot.append(ds)

    dataset = concatenate_datasets(ds_tot)
    dataset = dataset.shuffle(seed=42)
    # split into train and test
    dataset_eval = dataset.select(range(500))
    dataset = dataset.select(range(500, len(dataset)))
    return DatasetDict({"train": dataset, "test": dataset_eval})


def load_train_set() -> DatasetDict:
    """
    Load the base training mix (InfoVQA, DocVQA, ArxivQA, TAT-DQA and the synthetic DocQA sets).

    Returns:
    - DatasetDict: ``"train"`` and a 500-row ``"test"`` split.
    """
    return _load_and_split_train_sets([*_CORE_QA_TRAIN_SETS, *_SYNTHETIC_TRAIN_SETS])


def load_train_set_with_tabfquad() -> DatasetDict:
    """
    Load the base training mix plus ``tabfquad_train_subsampled``.

    Returns:
    - DatasetDict: ``"train"`` and a 500-row ``"test"`` split.
    """
    return _load_and_split_train_sets(
        [*_CORE_QA_TRAIN_SETS, "tabfquad_train_subsampled", *_SYNTHETIC_TRAIN_SETS]
    )


def load_train_set_with_docmatix() -> DatasetDict:
    """
    Load the base training mix plus ``tabfquad_train_subsampled`` and ``Docmatix_filtered_train``.

    Returns:
    - DatasetDict: ``"train"`` and a 500-row ``"test"`` split.
    """
    return _load_and_split_train_sets(
        [
            *_CORE_QA_TRAIN_SETS,
            "tabfquad_train_subsampled",
            *_SYNTHETIC_TRAIN_SETS,
            "Docmatix_filtered_train",
        ]
    )
|
110 |
-
|
111 |
-
|
112 |
-
def load_docvqa_dataset() -> DatasetDict:
    """
    Load DocVQA and InfographicVQA and merge them into train/test splits.

    The validation splits form ``"train"``; 200 shuffled examples from the
    combined test splits form ``"test"``. ``question`` is renamed to ``query``
    and an ``image_filename`` column is added.

    Returns:
    - DatasetDict: ``"train"`` and ``"test"`` splits.
    """
    repo = "./data_dir/DocVQA" if USE_LOCAL_DATASET else "lmms-lab/DocVQA"
    dataset_doc = load_dataset(repo, "DocVQA", split="validation")
    dataset_doc_eval = load_dataset(repo, "DocVQA", split="test")
    dataset_info = load_dataset(repo, "InfographicVQA", split="validation")
    dataset_info_eval = load_dataset(repo, "InfographicVQA", split="test")

    # concatenate the two datasets
    dataset = concatenate_datasets([dataset_doc, dataset_info])
    dataset_eval = concatenate_datasets([dataset_doc_eval, dataset_info_eval])
    # sample 200 from eval dataset
    dataset_eval = dataset_eval.shuffle(seed=42).select(range(200))

    # rename question as query
    dataset = dataset.rename_column("question", "query")
    dataset_eval = dataset_eval.rename_column("question", "query")

    # create new column image_filename that corresponds to ucsf_document_id if not None, else image_url
    def with_image_filename(example):
        doc_id = example["ucsf_document_id"]
        return {"image_filename": doc_id if doc_id is not None else example["image_url"]}

    dataset = dataset.map(with_image_filename)
    dataset_eval = dataset_eval.map(with_image_filename)

    ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})

    return ds_dict
|
145 |
-
|
146 |
-
|
147 |
-
class TestSetFactory:
    """Callable that loads the ``"test"`` split of a fixed dataset path."""

    def __init__(self, dataset_path):
        """
        Args:
        - dataset_path: Path or hub name passed to ``load_dataset``.
        """
        self.dataset_path = dataset_path

    def __call__(self, *args, **kwargs):
        """Load and return the test split; positional and keyword arguments are ignored."""
        dataset = load_dataset(self.dataset_path, split="test")
        return dataset
|
154 |
-
|
155 |
-
|
156 |
-
if __name__ == "__main__":
    # Manual smoke test: load one test split from the hub and print its summary.
    ds = TestSetFactory("vidore/tabfquad_test_subsampled")()
    print(ds)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/utils/gpu_stats.py
DELETED
@@ -1,24 +0,0 @@
|
|
1 |
-
# cond import
|
2 |
-
try:
    import pynvml
except ImportError:
    pynvml = None
    print("pynvml not found. GPU stats will not be printed.")


def print_gpu_utilization():
    """Print the memory used on GPU 0 in MB; does nothing when pynvml is not installed."""
    if pynvml is None:
        return
    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        print(f"GPU memory occupied: {info.used // 1024 ** 2} MB.")
    finally:
        # Pair every nvmlInit with a shutdown so the NVML handle is released.
        pynvml.nvmlShutdown()


def print_summary(result):
    """
    Print runtime and throughput from a training result, then the GPU memory in use.

    Args:
    - result: Object whose ``metrics`` mapping has ``train_runtime`` and
      ``train_samples_per_second``.
    """
    print(f"Time: {result.metrics['train_runtime']:.2f}")
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
    print_gpu_utilization()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/utils/image_from_page_utils.py
DELETED
@@ -1,21 +0,0 @@
|
|
1 |
-
import requests
|
2 |
-
from PIL import Image
|
3 |
-
|
4 |
-
|
5 |
-
def load_from_pdf(pdf_path: str):
    """Return the result of ``pdf2image.convert_from_path`` for ``pdf_path``."""
    # Imported at call time rather than at module import.
    from pdf2image import convert_from_path

    return convert_from_path(pdf_path)
|
10 |
-
|
11 |
-
|
12 |
-
def load_from_image_urls(urls: "list[str]") -> "list[Image.Image]":
    """Download each URL and open the response body as a PIL image.

    Args:
        urls: the image URLs to fetch, in order.

    Returns:
        One opened image per URL, in the same order as ``urls``.

    Raises:
        requests.HTTPError: if a URL answers with an HTTP error status.
    """
    images = []
    for url in urls:
        response = requests.get(url, stream=True)
        # Fail on 4xx/5xx here instead of letting PIL choke on an error page.
        response.raise_for_status()
        images.append(Image.open(response.raw))
    return images
|
15 |
-
|
16 |
-
|
17 |
-
def load_from_dataset(dataset):
    """Load the ``test`` split of ``dataset`` and return its ``image`` column."""
    # Imported at call time rather than at module import.
    from datasets import load_dataset

    test_split = load_dataset(dataset, split="test")
    return test_split["image"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/utils/image_utils.py
DELETED
@@ -1,64 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Utility functions for working with images.
|
3 |
-
"""
|
4 |
-
|
5 |
-
import base64
|
6 |
-
import io
|
7 |
-
|
8 |
-
from PIL import Image
|
9 |
-
|
10 |
-
|
11 |
-
def scale_image(image: Image.Image, new_height: int = 1024) -> Image.Image:
    """
    Scale an image to a new height while maintaining the aspect ratio.

    Args:
        image: the image to resize.
        new_height: height of the returned image, in pixels.

    Returns:
        A resized copy whose width follows the original width/height ratio
        (truncated to an integer, but never below 1 pixel).
    """
    width, height = image.size
    aspect_ratio = width / height

    # int() truncates, so a very tall and narrow image could yield a width of 0,
    # which is not a valid resize target; keep at least one pixel.
    new_width = max(1, int(new_height * aspect_ratio))

    return image.resize((new_width, new_height))
|
24 |
-
|
25 |
-
|
26 |
-
def scale_to_max_dimension(image: Image.Image, max_dimension: int = 1024) -> Image.Image:
    """
    Scale an image to a maximum dimension while maintaining the aspect ratio.

    Args:
        image: the image to resize.
        max_dimension: target size, in pixels, of the longer side.

    Returns:
        ``image`` itself when its longer side is already smaller than
        ``max_dimension``; otherwise a resized copy whose longer side is
        ``max_dimension``.
    """
    width, height = image.size
    max_original_dimension = max(width, height)

    if max_original_dimension < max_dimension:
        return image

    # Same factor on both axes keeps the aspect ratio.
    scale = max_dimension / max_original_dimension
    # int() truncates, so the short side of an extreme aspect ratio could
    # become 0, which is not a valid resize target; keep at least one pixel.
    new_width = max(1, int(width * scale))
    new_height = max(1, int(height * scale))

    return image.resize((new_width, new_height))
|
47 |
-
|
48 |
-
|
49 |
-
def get_base64_image(img: str | Image.Image, add_url_prefix: bool = True) -> str:
    """
    Convert an image (from a filepath or a PIL.Image object) to a JPEG-base64 string.

    Args:
        img: path to an image file, or an already opened PIL image.
        add_url_prefix: when True, prepend ``data:image/jpeg;base64,`` so the
            result can be used directly as a data URL.

    Returns:
        The base64-encoded JPEG bytes, optionally prefixed as a data URL.

    Raises:
        ValueError: if ``img`` is neither a string nor a PIL image.
    """
    if isinstance(img, str):
        img = Image.open(img)
    elif not isinstance(img, Image.Image):
        raise ValueError("`img` must be a path to an image or a PIL Image object.")

    # JPEG has no alpha channel or palette; such images must be converted
    # before saving, otherwise the save call fails.
    if img.mode in ("RGBA", "LA", "P", "PA"):
        img = img.convert("RGB")

    buffered = io.BytesIO()
    img.save(buffered, format="jpeg")
    b64_data = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return f"data:image/jpeg;base64,{b64_data}" if add_url_prefix else b64_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/utils/iter_utils.py
DELETED
@@ -1,42 +0,0 @@
|
|
1 |
-
import sys
|
2 |
-
|
3 |
-
|
4 |
-
def islice(iterable, *args):
    """
    Yield a slice of an iterable.

    Accepts the same positional forms as ``itertools.islice``:
    ``islice(iterable, stop)`` or ``islice(iterable, start, stop[, step])``.
    An explicit ``stop`` of 0 yields nothing; ``None`` means "until exhausted".

    >>> list(islice('ABCDEFG', 2))
    ['A', 'B']
    >>> list(islice('ABCDEFG', 2, 4))
    ['C', 'D']
    >>> list(islice('ABCDEFG', 2, None))
    ['C', 'D', 'E', 'F', 'G']
    >>> list(islice('ABCDEFG', 0, None, 2))
    ['A', 'C', 'E', 'G']

    Raises:
        ValueError: on the first iteration, if an index is negative or the
            step is not a positive integer.
    """
    # Local import: the standard implementation replaces the hand-rolled copy,
    # which mistook ``stop=0`` for "no stop" because of ``s.stop or sys.maxsize``.
    import itertools

    yield from itertools.islice(iterable, *args)
|
31 |
-
|
32 |
-
|
33 |
-
def batched(iterable, n: int):
    """
    Yield successive tuples of up to ``n`` elements taken from ``iterable``.

    The last tuple may be shorter than ``n``.

    >>> list(batched('ABCDEFG', 3))
    [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]

    Raises:
        ValueError: on the first iteration, if ``n`` is smaller than one.
    """
    if n < 1:
        raise ValueError("n must be at least one")
    source = iter(iterable)
    while True:
        chunk = tuple(islice(source, n))
        if not chunk:
            # An empty chunk means the source is exhausted.
            return
        yield chunk
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/utils/pdf_utils.py
DELETED
@@ -1,87 +0,0 @@
|
|
1 |
-
import glob
|
2 |
-
import os
|
3 |
-
import random
|
4 |
-
from pathlib import Path
|
5 |
-
|
6 |
-
from pdf2image import convert_from_path
|
7 |
-
from tqdm import tqdm
|
8 |
-
|
9 |
-
# Seed the module-level RNG at import time so that the random.sample() calls in
# this module draw the same selection for the same input list.
random.seed(42)
|
10 |
-
|
11 |
-
|
12 |
-
def convert_pdf_to_images(pdf_file: str, save_folder: str):
    """
    Convert each page of a pdf to a jpg image and save them in a folder.

    Pages are written as ``page_1.jpg``, ``page_2.jpg``, ... The folder is only
    created when the pdf yields at least one page.

    Args:
        - pdf_file (str): path to the pdf file
        - save_folder (str): path to the folder where the images will be saved

    """
    images = convert_from_path(pdf_file)
    if not images:
        return

    # Create the folder once, before the loop; exist_ok avoids the
    # check-then-create race of a separate os.path.exists() test.
    os.makedirs(save_folder, exist_ok=True)

    for page_number, image in enumerate(images, start=1):
        image.save(os.path.join(save_folder, f"page_{page_number}.jpg"), "JPEG")
|
27 |
-
|
28 |
-
|
29 |
-
def convert_all_pdfs_to_images(path_to_folder: str, n_samples: int = 0):
    """
    Convert the pdf files found in each direct subfolder of a folder to images.

    For every subfolder, either all of its pdf files are taken (when
    ``n_samples`` is 0 or the subfolder holds at most ``n_samples`` pdfs) or
    ``n_samples`` of them are drawn with ``random.sample``. Pages are written
    under ``<path_to_folder>/pages_extracted/<subfolder>/<pdf name>/``; a pdf
    whose target folder already exists is skipped. Paths of pdfs that fail to
    convert are written, one per line, to
    ``<path_to_folder>/corrupted_pdf_files.txt``.

    Args:
        - path_to_folder (str): path to the folder containing the pdf files
        - n_samples (int): number of pdf files to sample in each subfolder

    directory structure:
    - path_to_folder
        - subfolder1
            - pdf1
            - pdf2
            - ...
        - subfolder2
            - pdf1
            - pdf2
            - ...
        - ...

    """
    subfolders = [
        entry for entry in os.listdir(path_to_folder) if os.path.isdir(os.path.join(path_to_folder, entry))
    ]

    selected = []
    for subfolder in subfolders:
        candidates = glob.glob(os.path.join(path_to_folder, subfolder, "*.pdf"))
        take_everything = (n_samples == 0) or (len(candidates) <= n_samples)

        if take_everything:
            print(f"Taking all pdf files in {subfolder}")
            selected.extend(candidates)
        else:
            print(f"Taking {n_samples} pdf files in {subfolder}")
            selected.extend(random.sample(candidates, n_samples))

    pdf_paths = list(map(str, selected))

    # Text file that collects the paths of the pdf files that failed to convert.
    corrupted_log = Path(path_to_folder) / "corrupted_pdf_files.txt"
    corrupted_log.parent.mkdir(parents=True, exist_ok=True)

    with corrupted_log.open("w") as log, tqdm(total=len(pdf_paths)) as progress:
        for pdf_path in pdf_paths:
            progress.set_description(f"Processing {pdf_path}")
            relative_target = os.path.join("pages_extracted", *Path(pdf_path).parts[-2:])
            target_dir = os.path.join(path_to_folder, relative_target)

            # An existing target folder means this pdf was already processed.
            if not os.path.exists(target_dir):
                try:
                    convert_pdf_to_images(pdf_path, target_dir)
                except Exception as e:
                    print(f"Error converting {pdf_path}: {e}")
                    log.write(pdf_path)
                    log.write("\n")

            progress.update(1)
    return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/utils/plot_utils.py
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
import seaborn as sns
|
2 |
-
|
3 |
-
|
4 |
-
def setup_seaborn():
    """Apply the seaborn ``white`` style and the ``paper`` context with a font scale of 2."""
    style_name = "white"
    context_name = "paper"
    sns.set_style(style_name)
    sns.set_context(context_name, font_scale=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/utils/torch_utils.py
DELETED
@@ -1,18 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Utility functions for selecting the torch device.
|
3 |
-
"""
|
4 |
-
|
5 |
-
import torch
|
6 |
-
|
7 |
-
|
8 |
-
def get_torch_device() -> str:
    """
    Return the name of the device to be used for torch tensors.

    Returns:
        ``"cuda:0"`` when CUDA is available, else ``"mps"`` when the MPS
        backend is available, else ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda:0"

    # Look the MPS backend (Apple Silicon) up defensively so that a torch
    # build without ``torch.backends.mps`` falls through to the CPU.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"

    return "cpu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/utils/train_colpali_engine_models.py
DELETED
@@ -1,247 +0,0 @@
|
|
1 |
-
# HuggingFace trainer
|
2 |
-
import json
|
3 |
-
import os
|
4 |
-
from dataclasses import dataclass
|
5 |
-
from typing import Callable, Dict, Optional
|
6 |
-
|
7 |
-
import torch
|
8 |
-
from datasets import concatenate_datasets
|
9 |
-
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
10 |
-
from torch.utils.data import DataLoader
|
11 |
-
from tqdm import tqdm
|
12 |
-
from transformers import AutoTokenizer, Idefics2Processor, PreTrainedModel, PreTrainedTokenizer, TrainingArguments
|
13 |
-
|
14 |
-
from colpali_engine.dataset.custom_collator import CustomCollator
|
15 |
-
from colpali_engine.loss.colbert_loss import BiEncoderLoss, BiPairwiseCELoss, ColbertLoss, ColbertPairwiseCELoss
|
16 |
-
from colpali_engine.trainer.contrastive_trainer import ContrastiveTrainer
|
17 |
-
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
|
18 |
-
from colpali_engine.utils.gpu_stats import print_gpu_utilization, print_summary
|
19 |
-
|
20 |
-
|
21 |
-
@dataclass
class ColModelTrainingConfig:
    """Configuration bundle for a training / evaluation run.

    ``__post_init__`` fills in missing pieces: a default output directory
    derived from the model name, default ``TrainingArguments``, a tokenizer
    when neither a processor nor a tokenizer is given, and the PEFT adapter
    setup when ``peft_config`` or ``pretrained_peft_model_name_or_path`` is set.
    """

    model: PreTrainedModel
    tr_args: Optional[TrainingArguments] = None
    output_dir: Optional[str] = None
    max_length: int = 256
    run_eval: bool = True
    run_train: bool = True
    peft_config: Optional[LoraConfig] = None
    add_suffix: bool = False
    processor: Optional[Idefics2Processor] = None
    tokenizer: Optional[PreTrainedTokenizer] = None
    # NOTE(review): this default instance is created once, when the class is
    # defined, and is shared by every config that does not pass its own loss.
    loss_func: Optional[Callable] = ColbertLoss()
    dataset_loading_func: Optional[Callable] = None
    eval_dataset_loader: Optional[Dict[str, Callable]] = None
    pretrained_peft_model_name_or_path: Optional[str] = None

    def __post_init__(self):
        """Resolve defaults and prepare the model (adapters, PEFT) for training."""
        if self.output_dir is None:
            sanitized_name = str(self.model.name_or_path).replace("/", "_")
            self.output_dir = f"./models/{sanitized_name}"

        if self.tr_args is None:
            self.tr_args = TrainingArguments(output_dir=self.output_dir)
        elif self.tr_args.output_dir is None:
            self.tr_args.output_dir = self.output_dir

        # The learning rate may arrive as a string; cast it to float.
        if isinstance(self.tr_args.learning_rate, str):
            self.tr_args.learning_rate = float(self.tr_args.learning_rate)
        self.tr_args.remove_unused_columns = False

        if self.processor is None and self.tokenizer is None:
            print("Using textual model tokenization")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model.name_or_path)

        if self.pretrained_peft_model_name_or_path is not None:
            self.model.load_adapter(self.pretrained_peft_model_name_or_path)
            print(f"Loaded pretrained adapter from {self.pretrained_peft_model_name_or_path}")

        if self.peft_config is not None:
            print("Configurating PEFT model")
            if self.processor is None:
                # Might be deprecated - use the "else" branch
                self.model = prepare_model_for_kbit_training(self.model)  # use_gradient_checkpointing=True
                self.model = get_peft_model(self.model, self.peft_config)
                self.model.print_trainable_parameters()
            else:
                if self.pretrained_peft_model_name_or_path is None:
                    self.model.add_adapter(self.peft_config)
                    self.model.enable_adapters()
                else:
                    # An adapter was loaded above; do not add a second one on top.
                    print(f"Adapter already loaded from {self.pretrained_peft_model_name_or_path}. Not overwriting.")

        print_gpu_utilization()
|
83 |
-
|
84 |
-
|
85 |
-
class ColModelTraining:
    """Drive training, retrieval evaluation and saving for a ``ColModelTrainingConfig``."""

    def __init__(self, config: ColModelTrainingConfig) -> None:
        """Load the dataset, build the collator and evaluator, and record the git commit."""
        self.config = config
        self.model = self.config.model
        self.dataset = self.config.dataset_loading_func()
        self.collator = CustomCollator(
            processor=self.config.processor, tokenizer=self.config.tokenizer, max_length=self.config.max_length
        )
        # Remember the commit checked out when training starts; the context
        # manager closes the pipe instead of leaving it to the garbage collector.
        with os.popen("git rev-parse HEAD") as git_output:
            self.current_git_hash = git_output.read().strip()
        self.retriever_evaluator = CustomEvaluator(
            is_multi_vector=isinstance(self.config.loss_func, (ColbertLoss, ColbertPairwiseCELoss))
        )

    def train(self) -> None:
        """Run the contrastive trainer on the train/test splits and print a summary."""
        trainer = ContrastiveTrainer(
            model=self.model,
            train_dataset=self.dataset["train"],
            eval_dataset=self.dataset["test"],
            args=self.config.tr_args,
            data_collator=self.collator,
            loss_func=self.config.loss_func,
            is_vision_model=self.config.processor is not None,
        )
        trainer.args.remove_unused_columns = False

        result = trainer.train()
        print_summary(result)

    def eval_dataset(self, test_dataset):
        """Embed queries and documents of ``test_dataset`` and compute retrieval metrics.

        Samples whose ``query`` is None contribute a document only. Each query's
        single relevant document is the document of the same sample.

        Returns:
            Whatever ``self.retriever_evaluator.compute_metrics`` returns.
        """
        self.model.eval()

        # Read the query column once instead of once per index list.
        queries = test_dataset["query"]
        idx_with_query = [idx for idx, query in enumerate(queries) if query is not None]
        idx_without_query = [idx for idx, query in enumerate(queries) if query is None]

        dataloader_with_query = DataLoader(
            test_dataset.select(idx_with_query),
            batch_size=self.config.tr_args.per_device_eval_batch_size,
            shuffle=False,
            collate_fn=self.collator,
        )
        dataloader_without_query = DataLoader(
            test_dataset.select(idx_without_query),
            batch_size=self.config.tr_args.per_device_eval_batch_size,
            shuffle=False,
            collate_fn=self.collator,
        )

        # Reorder so that non-null queries come first, matching the order in
        # which the two dataloaders are consumed below.
        test_dataset = concatenate_datasets(
            [test_dataset.select(idx_with_query), test_dataset.select(idx_without_query)]
        )

        relevant_docs = {}
        docidx_2_docid = {}
        qsidx_2_query = []
        for idx, sample in enumerate(test_dataset):
            doc_id = sample["image_filename"] if "image_filename" in sample else str(hash(sample["doc"]))
            if sample["query"] is not None:
                relevant_docs[str(idx)] = {doc_id: 1}
                qsidx_2_query.append(str(idx))
            docidx_2_docid[str(idx)] = doc_id

        qs = []
        ps = []

        device = self.model.device
        with torch.no_grad():
            for dataloader in [dataloader_with_query, dataloader_without_query]:
                for batch in tqdm(dataloader):
                    doc_inputs = {
                        "input_ids": batch["doc_input_ids"].to(device),
                        "attention_mask": batch["doc_attention_mask"].to(device),
                    }
                    # Vision inputs are optional; the pixel attention mask is
                    # only forwarded together with pixel values.
                    if "doc_pixel_values" in batch:
                        doc_inputs["pixel_values"] = batch["doc_pixel_values"].to(device)
                        if "doc_pixel_attention_mask" in batch:
                            doc_inputs["pixel_attention_mask"] = batch["doc_pixel_attention_mask"].to(device)
                    doc = self.model(**doc_inputs)
                    ps.extend(list(torch.unbind(doc.to("cpu"))))

                    if "query_input_ids" in batch:
                        query = self.model(
                            input_ids=batch["query_input_ids"].to(device),
                            attention_mask=batch["query_attention_mask"].to(device),
                        )
                        # variable len
                        qs.extend(list(torch.unbind(query.to("cpu"))))

        print("Embeddings computed, evaluating")
        scores = self.retriever_evaluator.evaluate(qs, ps)
        # scores is indexed (query, document); turn it into a nested dict.
        results = {}
        assert scores.shape[0] == len(qsidx_2_query)
        for idx, scores_per_query in enumerate(scores):
            results[qsidx_2_query[idx]] = {
                docidx_2_docid[str(docidx)]: float(score) for docidx, score in enumerate(scores_per_query)
            }

        metrics = self.retriever_evaluator.compute_metrics(relevant_docs, results)
        print(metrics)
        return metrics

    def _dump_results(self, all_metrics) -> None:
        """Write ``all_metrics`` to ``<output_dir>/results.json``, creating the directory if needed."""
        os.makedirs(self.config.output_dir, exist_ok=True)
        with open(f"{self.config.output_dir}/results.json", "w") as f:
            json.dump(all_metrics, f)

    def eval(self) -> None:
        """Evaluate on the validation split and on every extra test set, saving the metrics as JSON."""
        print("Evaluating on validation set")
        metrics = self.eval_dataset(self.dataset["test"])
        print(f"Metrics for validation set: {metrics}")
        all_metrics = {"validation_set": metrics}

        if self.config.eval_dataset_loader is not None:
            for test_name, test_dataset_loading_func in self.config.eval_dataset_loader.items():
                print(f"Evaluating {test_name}")
                test_ds = test_dataset_loading_func()
                metrics = self.eval_dataset(test_ds)
                all_metrics[test_name] = metrics
                print(f"Metrics for {test_name}: {metrics}")

                # Checkpoint after every test set so partial results survive a crash.
                self._dump_results(all_metrics)

        # save results as json
        self._dump_results(all_metrics)

    def save(self, config_file):
        """Save the model, tokenizer/processor, the training config file and the git hash."""
        # Local import: only needed here, to copy the config file.
        import shutil

        self.model.save_pretrained(self.config.output_dir)
        if self.config.tokenizer is not None:
            self.config.tokenizer.save_pretrained(self.config.output_dir)
        if self.config.processor is not None:
            self.config.processor.save_pretrained(self.config.output_dir)  # save config

        # Copy the yml file without going through a shell, so paths containing
        # spaces or shell metacharacters are handled. The copy stays
        # best-effort, as the previous `cp` call was.
        try:
            shutil.copy(config_file, os.path.join(self.config.output_dir, "training_config.yml"))
        except (OSError, TypeError) as e:
            print(f"Could not copy {config_file} to {self.config.output_dir}: {e}")

        # save git hash of the commit at beginning of training
        with open(f"{self.config.output_dir}/git_hash.txt", "w") as f:
            f.write(self.current_git_hash)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colpali_engine/utils/wrapper.py
DELETED
@@ -1,83 +0,0 @@
|
|
1 |
-
import importlib
|
2 |
-
|
3 |
-
from colpali_engine.models.clip_baselines import ColSigLIP, SigLIP
|
4 |
-
from colpali_engine.models.colbert_architectures import (
|
5 |
-
BiBERT,
|
6 |
-
BiXLMRoBERTa,
|
7 |
-
ColBERT,
|
8 |
-
ColCamembert,
|
9 |
-
ColLlama,
|
10 |
-
ColXLMRoBERTa,
|
11 |
-
)
|
12 |
-
from colpali_engine.models.idefics_colbert_architecture import BiIdefics, ColIdefics
|
13 |
-
from colpali_engine.models.paligemma_colbert_architecture import (
|
14 |
-
BiNewSiglip,
|
15 |
-
BiPaliLast,
|
16 |
-
BiPaliMean,
|
17 |
-
ColNewSiglip,
|
18 |
-
ColPali,
|
19 |
-
)
|
20 |
-
|
21 |
-
# ``import importlib`` alone does not guarantee that the ``importlib.util``
# submodule is loaded, so import it explicitly before using it.
import importlib.util

if importlib.util.find_spec("transformers") is not None:
    from transformers import AutoProcessor, AutoTokenizer
    from transformers.tokenization_utils import PreTrainedTokenizer

    class AutoProcessorWrapper:
        """Factory: instantiating it returns ``AutoProcessor.from_pretrained(...)``."""

        def __new__(cls, *args, **kwargs):
            return AutoProcessor.from_pretrained(*args, **kwargs)

    class AutoTokenizerWrapper(PreTrainedTokenizer):
        """Factory: instantiating it returns ``AutoTokenizer.from_pretrained(...)``."""

        def __new__(cls, *args, **kwargs):
            return AutoTokenizer.from_pretrained(*args, **kwargs)

    class AutoColModelWrapper:
        """Factory that picks a model class from the checkpoint name and training objective.

        The checkpoint name is the first positional argument or the
        ``pretrained_model_name_or_path`` keyword. The optional
        ``training_objective`` keyword (default ``"colbertv1"``) is consumed
        here and not forwarded to ``from_pretrained``.
        """

        def __new__(cls, *args, **kwargs):
            if args:
                pretrained_model_name_or_path = args[0]
            else:
                pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path")
            if pretrained_model_name_or_path is None:
                # Explicit error instead of an accidental KeyError or
                # "argument of type 'NoneType' is not iterable".
                raise TypeError(
                    "AutoColModelWrapper requires `pretrained_model_name_or_path` "
                    "as first positional argument or as keyword argument"
                )
            # Substring matching below needs a string (the name may be path-like).
            model_name = str(pretrained_model_name_or_path)

            training_objective = kwargs.pop("training_objective", "colbertv1")

            if "camembert" in model_name:
                return ColCamembert.from_pretrained(*args, **kwargs)

            if "xlm-roberta" in model_name:
                if training_objective == "biencoder":
                    return BiXLMRoBERTa.from_pretrained(*args, **kwargs)
                return ColXLMRoBERTa.from_pretrained(*args, **kwargs)

            # Only this family is matched case-insensitively.
            lowered_name = model_name.lower()
            if "llama" in lowered_name or "croissant" in lowered_name:
                return ColLlama.from_pretrained(*args, **kwargs)

            if "idefics2" in model_name:
                if training_objective == "biencoder":
                    return BiIdefics.from_pretrained(*args, **kwargs)
                return ColIdefics.from_pretrained(*args, **kwargs)

            if "siglip" in model_name:
                if training_objective == "biencoder_mean":
                    return SigLIP.from_pretrained(*args, **kwargs)
                if training_objective == "colbertv1":
                    return ColSigLIP.from_pretrained(*args, **kwargs)
                raise ValueError(f"Training objective {training_objective} not recognized")

            if "paligemma" in model_name:
                if training_objective == "biencoder_mean":
                    return BiPaliMean.from_pretrained(*args, **kwargs)
                if training_objective == "biencoder_last":
                    return BiPaliLast.from_pretrained(*args, **kwargs)
                if training_objective == "biencoder_mean_vision":
                    return BiNewSiglip.from_pretrained(*args, **kwargs)
                if training_objective == "colbertv1_vision":
                    return ColNewSiglip.from_pretrained(*args, **kwargs)
                if training_objective == "colbertv1":
                    return ColPali.from_pretrained(*args, **kwargs)
                raise ValueError(f"Training objective {training_objective} not recognized")

            # Fallback for every other checkpoint name.
            if training_objective == "biencoder":
                return BiBERT.from_pretrained(*args, **kwargs)
            return ColBERT.from_pretrained(*args, **kwargs)

else:
    raise ModuleNotFoundError("Transformers must be loaded")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|