HUANG-Stephanie committed on
Commit
c4d37d5
1 parent: d106c36

Delete colpali_engine

Files changed (43)
  1. colpali_engine/__init__.py +0 -0
  2. colpali_engine/__pycache__/__init__.cpython-310.pyc +0 -0
  3. colpali_engine/dataset/__init__.py +0 -0
  4. colpali_engine/dataset/custom_collator.py +0 -244
  5. colpali_engine/dataset/hf_dataset_names.py +0 -52
  6. colpali_engine/evaluation/__init__.py +0 -1
  7. colpali_engine/evaluation/eval_manager.py +0 -178
  8. colpali_engine/interpretability/__init__.py +0 -4
  9. colpali_engine/interpretability/gen_interpretability_plots.py +0 -113
  10. colpali_engine/interpretability/plot_utils.py +0 -131
  11. colpali_engine/interpretability/processor.py +0 -116
  12. colpali_engine/interpretability/torch_utils.py +0 -60
  13. colpali_engine/interpretability/vit_configs.py +0 -23
  14. colpali_engine/loss/__init__.py +0 -1
  15. colpali_engine/loss/colbert_loss.py +0 -122
  16. colpali_engine/models/__init__.py +0 -0
  17. colpali_engine/models/__pycache__/__init__.cpython-310.pyc +0 -0
  18. colpali_engine/models/__pycache__/paligemma_colbert_architecture.cpython-310.pyc +0 -0
  19. colpali_engine/models/clip_baselines.py +0 -144
  20. colpali_engine/models/colbert_architectures.py +0 -177
  21. colpali_engine/models/idefics_colbert_architecture.py +0 -57
  22. colpali_engine/models/paligemma_colbert_architecture.py +0 -191
  23. colpali_engine/trainer/__init__.py +0 -0
  24. colpali_engine/trainer/__pycache__/__init__.cpython-310.pyc +0 -0
  25. colpali_engine/trainer/__pycache__/retrieval_evaluator.cpython-310.pyc +0 -0
  26. colpali_engine/trainer/contrastive_trainer.py +0 -64
  27. colpali_engine/trainer/retrieval_evaluator.py +0 -72
  28. colpali_engine/utils/__init__.py +0 -0
  29. colpali_engine/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  30. colpali_engine/utils/__pycache__/colpali_processing_utils.cpython-310.pyc +0 -0
  31. colpali_engine/utils/__pycache__/image_from_page_utils.cpython-310.pyc +0 -0
  32. colpali_engine/utils/colidefics_processing_utils.py +0 -53
  33. colpali_engine/utils/colpali_processing_utils.py +0 -36
  34. colpali_engine/utils/dataset_transformation.py +0 -158
  35. colpali_engine/utils/gpu_stats.py +0 -24
  36. colpali_engine/utils/image_from_page_utils.py +0 -21
  37. colpali_engine/utils/image_utils.py +0 -64
  38. colpali_engine/utils/iter_utils.py +0 -42
  39. colpali_engine/utils/pdf_utils.py +0 -87
  40. colpali_engine/utils/plot_utils.py +0 -6
  41. colpali_engine/utils/torch_utils.py +0 -18
  42. colpali_engine/utils/train_colpali_engine_models.py +0 -247
  43. colpali_engine/utils/wrapper.py +0 -83
colpali_engine/__init__.py DELETED
File without changes
colpali_engine/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (163 Bytes)
 
colpali_engine/dataset/__init__.py DELETED
File without changes
colpali_engine/dataset/custom_collator.py DELETED
@@ -1,244 +0,0 @@
-from transformers import PreTrainedTokenizer, ProcessorMixin
-
-
-class CustomCollator:
-    def __init__(
-        self,
-        processor: ProcessorMixin = None,
-        tokenizer: PreTrainedTokenizer = None,
-        max_length: int = 2048,
-        add_suffix: bool = False,
-    ):
-        self.processor = processor
-        self.tokenizer = tokenizer
-        self.image_token_id = None
-        self.max_length = max_length
-        self.suffix = ""
-        if add_suffix:
-            self.suffix = "\n" * 10
-
-        if tokenizer is None and processor is None:
-            raise ValueError("Either processor or tokenizer should be provided.")
-
-        if self.processor is not None:
-            if self.processor.__class__.__name__ != "SiglipProcessor":
-                self.image_token_id = self.processor.tokenizer.additional_special_tokens_ids[
-                    self.processor.tokenizer.additional_special_tokens.index("<image>")
-                ]
-
-            if self.tokenizer is not None:
-                raise ValueError("Only one of processor or tokenizer should be provided.")
-
-        if self.tokenizer and self.tokenizer.pad_token is None:
-            self.tokenizer.pad_token = self.tokenizer.eos_token
-
-    def __call__(self, examples):
-        if self.processor is None:
-            return self.forward_text(examples)
-        if self.processor.__class__.__name__ == "Idefics2Processor":
-            return self.forward_vision_idefics(examples)
-        if self.processor.__class__.__name__ == "PaliGemmaProcessor":
-            return self.forward_vision_pali(examples)
-        if self.processor.__class__.__name__ == "SiglipProcessor":
-            return self.forward_vision_siglip(examples)
-        raise ValueError("Processor not supported")
-
-    def forward_text(self, examples):
-        texts_doc = []
-        texts_query = []
-        for example in examples:
-            text_query = example["query"] + self.suffix
-            text_doc = example["doc"]
-
-            texts_doc.append(text_doc.strip())
-            texts_query.append(text_query.strip())
-
-        batch_doc = self.tokenizer(
-            texts_doc, max_length=self.max_length, padding="longest", truncation=True, return_tensors="pt"
-        )
-        batch_query = self.tokenizer(
-            texts_query, max_length=self.max_length, padding="longest", truncation=True, return_tensors="pt"
-        )
-
-        # prefix each key with "doc_" or "query_" to avoid key conflicts
-        batch_doc = {f"doc_{k}": v for k, v in batch_doc.items()}
-        batch_query = {f"query_{k}": v for k, v in batch_query.items()}
-        batch_doc.update(batch_query)
-
-        return batch_doc
-
-    def forward_vision_idefics(self, examples):
-        texts_doc = []
-        texts_query = []
-        images = []
-        for example in examples:
-            image = example["image"]
-
-            text_query = None
-            if example["query"] is not None:
-                query = example["query"]
-                messages_query = [
-                    {
-                        "role": "user",
-                        "content": [
-                            {
-                                "type": "text",
-                                "text": f"Question: {query}<end_of_utterance><end_of_utterance><end_of_utterance><end_of_utterance><end_of_utterance>",
-                            },
-                        ],
-                    },
-                ]
-                text_query = self.processor.apply_chat_template(messages_query, add_generation_prompt=False).strip()
-
-            messages_doc = [
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "text", "text": "Describe the image."},
-                        {"type": "image"},
-                    ],
-                },
-            ]
-
-            text_doc = self.processor.apply_chat_template(messages_doc, add_generation_prompt=False)
-
-            texts_doc.append(text_doc.strip())
-            texts_query.append(text_query)
-            images.append([image])
-
-        batch_doc = self.processor(
-            text=texts_doc, images=images, return_tensors="pt", padding="longest", max_length=self.max_length
-        )
-
-        batch_query = None
-        if all([t is None for t in texts_query]):
-            print("All queries are None. Returning None for all queries.")
-        elif any([t is None for t in texts_query]):
-            raise ValueError("Some queries are None. This collator does not support None queries yet.")
-        else:
-            batch_query = self.processor(
-                text=texts_query, return_tensors="pt", padding="longest", max_length=self.max_length
-            )
-
-        # prefix each key with "doc_" or "query_" to avoid key conflicts
-        batch_doc = {f"doc_{k}": v for k, v in batch_doc.items()}
-
-        if batch_query is not None:
-            batch_query = {f"query_{k}": v for k, v in batch_query.items()}
-            batch_doc.update(batch_query)
-
-        return batch_doc
-
-    def forward_vision_pali(self, examples):
-        texts_doc = []
-        texts_query = []
-        images = []
-        for example in examples:
-
-            if example["image"] is None:
-                raise ValueError("Image is None - This collator does not support None images yet.")
-
-            image = example["image"].convert("RGB")
-            images.append(image)
-            texts_doc.append("Describe the image.")
-
-            if example["query"] is None:
-                texts_query.append(None)
-            else:
-                query = example["query"]
-                query = f"Question: {query}<unused0><unused0><unused0><unused0><unused0>"
-                texts_query.append(query)
-
-        batch_doc = self.processor(
-            text=texts_doc,
-            images=images,
-            return_tensors="pt",
-            padding="longest",
-            max_length=self.max_length + self.processor.image_seq_length,
-        )
-
-        batch_query = None
-        # check if some but not all queries are None
-        if all([t is None for t in texts_query]):
-            print("All queries are None. Returning None for all queries.")
-        elif any([t is None for t in texts_query]):
-            raise ValueError("Some queries are None. This collator does not support None queries yet.")
-        else:
-            batch_query = self.processor(
-                images=images,  # NOTE: the image is not used in batch_query but it is required for calling the processor
-                text=texts_query,
-                return_tensors="pt",
-                padding="longest",
-                max_length=self.max_length + self.processor.image_seq_length,
-            )
-            del batch_query["pixel_values"]
-            batch_query["input_ids"] = batch_query["input_ids"][..., self.processor.image_seq_length :]
-            batch_query["attention_mask"] = batch_query["attention_mask"][..., self.processor.image_seq_length :]
-
-        # prefix each key with "doc_" or "query_" to avoid key conflicts
-        batch_doc = {f"doc_{k}": v for k, v in batch_doc.items()}
-
-        if batch_query is not None:
-            batch_query = {f"query_{k}": v for k, v in batch_query.items()}
-            batch_doc.update(batch_query)
-
-        return batch_doc
-
-    def forward_vision_siglip(self, examples):
-        texts_doc = []
-        texts_query = []
-        images = []
-        for example in examples:
-
-            if example["image"] is None:
-                raise ValueError("Image is None - This collator does not support None images yet.")
-
-            image = example["image"].convert("RGB")
-            images.append(image)
-            texts_doc.append("Describe the image.")
-
-            if example["query"] is None:
-                texts_query.append(None)
-            else:
-                query = f"Question: {example['query']}"
-                texts_query.append(query)
-
-        batch_doc = self.processor(
-            text=texts_doc,
-            images=images,
-            return_tensors="pt",
-            padding="max_length",
-            truncation=True,
-        )
-
-        batch_query = None
-        # check if some but not all queries are None
-        if all([t is None for t in texts_query]):
-            # print("All queries are None.")
-            pass
-        elif any([t is None for t in texts_query]):
-            raise ValueError("Some queries are None. This collator does not support None queries yet.")
-        else:
-            batch_query = self.processor(
-                images=images,
-                text=texts_query,
-                return_tensors="pt",
-                padding="max_length",
-                max_length=self.max_length,
-                truncation=True,
-            )
-            del batch_query["pixel_values"]
-
-        # prefix each key with "doc_" or "query_" to avoid key conflicts
-        batch_doc = {f"doc_{k}": v for k, v in batch_doc.items()}
-
-        if batch_query is not None:
-            batch_query = {f"query_{k}": v for k, v in batch_query.items()}
-            batch_doc.update(batch_query)
-            # add attention mask for queries
-            batch_doc["query_attention_mask"] = batch_doc["query_input_ids"].ne(0).long()
-
-        # add attention mask for docs
-        batch_doc["doc_attention_mask"] = batch_doc["doc_input_ids"].ne(0).long()
-
-        return batch_doc
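For context, a minimal usage sketch for this collator. The checkpoint name and toy example below are illustrative assumptions, not part of this commit:

from PIL import Image
from torch.utils.data import DataLoader
from transformers import AutoProcessor

from colpali_engine.dataset.custom_collator import CustomCollator

# Any checkpoint that loads a PaliGemmaProcessor would do here (hypothetical choice).
processor = AutoProcessor.from_pretrained("google/paligemma-3b-mix-448")
collator = CustomCollator(processor=processor, max_length=50)

# Toy in-memory dataset: each example is a dict with "image" and "query" keys.
examples = [{"image": Image.new("RGB", (448, 448)), "query": "What is the total revenue?"}]
batch = next(iter(DataLoader(examples, batch_size=1, collate_fn=collator)))
print(sorted(batch.keys()))  # doc_attention_mask, doc_input_ids, doc_pixel_values, query_attention_mask, query_input_ids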
colpali_engine/dataset/hf_dataset_names.py DELETED
@@ -1,52 +0,0 @@
-from enum import Enum
-
-
-class TrainDatasets(Enum):
-    """
-    Dataset names for the training datasets used in HuggingFace Datasets.
-    """
-
-    government_reports = "vidore/syntheticDocQA_government_reports_train"
-    healthcare_industry = "vidore/syntheticDocQA_healthcare_industry_train"
-    energy = "vidore/syntheticDocQA_energy_train"
-    artificial_intelligence = "vidore/syntheticDocQA_artificial_intelligence_train"
-    arxivqa = "vidore/arxivqa_train"
-    docvqa = "vidore/docvqa_train"
-    infovqa = "vidore/infovqa_train"
-    tatqa = "vidore/tatqa_train"
-
-    @staticmethod
-    def get_synthetic_datasets():
-        return [
-            TrainDatasets.government_reports,
-            TrainDatasets.healthcare_industry,
-            TrainDatasets.energy,
-            TrainDatasets.artificial_intelligence,
-        ]
-
-
-class TestImagesDirpath(Enum):
-    """
-    Dataset names for the test datasets used in HuggingFace Datasets.
-    """
-
-    government_reports = "data/government_reports"
-    healthcare_industry = "data/healthcare_industry"
-    energy = "data/energy"
-    artificial_intelligence = "data/scrapped_pdfs_split/pages_extracted/artificial_intelligence_test"
-    arxivqa = "data/arxivqa"
-    docvqa = "data/docvqa"
-    infovqa = "data/infovqa"
-    tatqa = "data/tatqa"
-
-
-class CaptionedSyntheticDatasets(Enum):
-    """
-    Dataset names for the captioned synthetic datasets used in HuggingFace Datasets.
-    """
-
-    shift = "vidore/baseline_cap_shiftproject_test"
-
-
-class SyntheticDocQATest(Enum):
-    shift = "vidore/shiftproject_test"
colpali_engine/evaluation/__init__.py DELETED
@@ -1 +0,0 @@
-from .eval_manager import EvalManager
colpali_engine/evaluation/eval_manager.py DELETED
@@ -1,178 +0,0 @@
-from __future__ import annotations
-
-from pathlib import Path
-from typing import Any, ClassVar, Dict, Optional
-
-import pandas as pd
-
-
-class EvalManager:
-    """
-    Stores evaluation results for various datasets and metrics.
-
-    The data is stored in a pandas DataFrame with a MultiIndex for columns.
-    The first level of the MultiIndex is the dataset name and the second level is the metric name.
-
-    Usage:
-    >>> evaluator = EvalManager.from_dir("data/evaluation_results/")
-    >>> print(evaluator.data)
-
-    """
-
-    model_col: ClassVar[str] = "model"
-    dataset_col: ClassVar[str] = "dataset"
-    metric_col: ClassVar[str] = "metric"
-
-    def __init__(self, data: Optional[pd.DataFrame] = None):
-        if data is None:
-            data = pd.DataFrame()
-        self._df = data
-        self._df.index = self._df.index.rename(EvalManager.model_col)
-
-    def __str__(self) -> str:
-        return self.data.__str__()
-
-    @staticmethod
-    def from_dict(data: Dict[Any, Any]) -> EvalManager:
-        """
-        Load evaluation results from a dictionary.
-
-        Expected format:
-        {
-            "model1": pd.read_json(path1).T.stack(),
-            "model2": pd.read_json(path2).T.stack(),
-        }
-
-        """
-        df = pd.DataFrame.from_dict(data, orient="index")
-        return EvalManager(df)
-
-    @staticmethod
-    def from_json(path: str | Path) -> EvalManager:
-        datapath = Path(path)
-        if not datapath.is_file():
-            raise FileNotFoundError(f"{path} is not a file")
-        data = {}
-        data[datapath.stem] = pd.read_json(datapath).T.stack()  # pylint: disable=no-member
-        return EvalManager.from_dict(data)
-
-    @staticmethod
-    def from_dir(datadir: str | Path) -> EvalManager:
-        datadir_ = Path(datadir)
-        if not datadir_.is_dir():
-            raise FileNotFoundError(f"{datadir} is not a directory")
-
-        eval_files = list(datadir_.glob("*.json"))
-
-        data = {}
-
-        for filepath in eval_files:
-            data[filepath.stem] = pd.read_json(filepath).T.stack()  # pylint: disable=no-member
-
-        return EvalManager.from_dict(data)
-
-    @staticmethod
-    def from_csv(path: str | Path) -> EvalManager:
-        """
-        Load evaluation results from a CSV file.
-        """
-        try:
-            df = pd.read_csv(path, index_col=0, header=[0, 1])
-            return EvalManager(df)
-        except Exception as e:
-            print(f"Error loading {path}: {e}")
-            raise e
-
-    @property
-    def data(self) -> pd.DataFrame:
-        """
-        Returns the evaluation results as a pandas DataFrame.
-        """
-        return self._df.copy()
-
-    @property
-    def models(self) -> pd.Index:
-        """
-        Returns the models for which there are evaluation results.
-        """
-        return self.data.index
-
-    @property
-    def datasets(self) -> pd.Index:
-        """
-        Returns the datasets for which there are evaluation results.
-        """
-        return self.data.columns.get_level_values(0).unique()
-
-    @property
-    def metrics(self) -> pd.Index:
-        """
-        Returns the metrics for which there are evaluation results.
-        """
-        return self.data.columns.get_level_values(1)
-
-    @staticmethod
-    def melt(df: pd.DataFrame) -> pd.DataFrame:
-        """
-        Melt a suitable DataFrame (e.g. returned by `get_df_for_dataset` and
-        `get_df_for_metric`) into a 'long' format.
-        """
-        return df.T.reset_index(names=[EvalManager.dataset_col, EvalManager.metric_col]).melt(
-            id_vars=[EvalManager.dataset_col, EvalManager.metric_col],
-            var_name=EvalManager.model_col,
-            value_name="score",
-        )
-
-    @property
-    def melted(self) -> pd.DataFrame:
-        """
-        Returns the evaluation results as a 'melted' DataFrame.
-        Useful for plotting with seaborn.
-        """
-        return EvalManager.melt(self.data)
-
-    def get_df_for_model(self, model: str) -> pd.DataFrame:
-        if model not in self.data.index:
-            raise ValueError(f"Model {model} not found in the evaluation results")
-        return self.data.loc[[model], :]  # type: ignore
-
-    def get_df_for_dataset(self, dataset: str) -> pd.DataFrame:
-        if dataset not in self.datasets:
-            raise ValueError(f"Dataset {dataset} not found in the evaluation results")
-        return self.data.loc[:, (dataset, slice(None))]  # type: ignore
-
-    def get_df_for_metric(self, metric: str) -> pd.DataFrame:
-        if metric not in self.metrics:
-            raise ValueError(f"Metric {metric} not found in the evaluation results")
-        return self.data.loc[:, (slice(None), metric)]  # type: ignore
-
-    def sort_by_dataset(self, ascending: bool = True) -> EvalManager:
-        """
-        Sort the evaluation results by dataset name.
-        """
-        df = self.data.T.sort_index(level=0, ascending=ascending).T
-        return EvalManager(df)
-
-    def sort_by_metric(self, ascending: bool = True) -> EvalManager:
-        """
-        Sort the evaluation results by metric name.
-        """
-        df = self.data.T.sort_index(level=1, ascending=ascending).T
-        return EvalManager(df)
-
-    def sort_columns(self, ascending: bool = True) -> EvalManager:
-        """
-        Sort the evaluation results by dataset name and then by metric name.
-        """
-        df = self.data.T.sort_index(level=[0, 1], ascending=ascending).T
-        return EvalManager(df)
-
-    def to_csv(self, path: str | Path):
-        """
-        Save the evaluation results to a CSV file.
-
-        Using `EvalManager.from_csv(path_to_saved_csv)` will load the evaluation results back into memory.
-        """
-        savepath = Path(path)
-        savepath.parent.mkdir(parents=True, exist_ok=True)
-        self.data.to_csv(savepath)
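A short usage sketch for the manager above (file paths and the metric name are hypothetical):

from colpali_engine.evaluation import EvalManager

# Aggregate one JSON result file per model into a single (model x (dataset, metric)) DataFrame.
manager = EvalManager.from_dir("data/evaluation_results/")
ndcg = manager.get_df_for_metric("ndcg_at_5")  # assumes this metric name appears in the JSON files
manager.sort_columns().to_csv("outputs/eval_results.csv")
manager_restored = EvalManager.from_csv("outputs/eval_results.csv")  # round-trips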
colpali_engine/interpretability/__init__.py DELETED
@@ -1,4 +0,0 @@
-from .plot_utils import *
-from .processor import *
-from .torch_utils import *
-from .vit_configs import *
colpali_engine/interpretability/gen_interpretability_plots.py DELETED
@@ -1,113 +0,0 @@
-import pprint
-from dataclasses import asdict, dataclass
-from pathlib import Path
-from uuid import uuid4
-
-import matplotlib.pyplot as plt
-import torch
-from einops import rearrange
-from PIL import Image
-from tqdm import trange
-
-from colpali_engine.interpretability.plot_utils import plot_patches
-from colpali_engine.interpretability.processor import ColPaliProcessor
-from colpali_engine.interpretability.torch_utils import normalize_attention_map_per_query_token
-from colpali_engine.interpretability.vit_configs import VIT_CONFIG
-from colpali_engine.models.paligemma_colbert_architecture import ColPali
-
-OUTDIR_INTERPRETABILITY = Path("outputs/interpretability")
-
-
-@dataclass
-class InterpretabilityInput:
-    query: str
-    image: Image.Image
-    start_idx_token: int
-    end_idx_token: int
-
-
-def generate_interpretability_plots(
-    model: ColPali,
-    processor: ColPaliProcessor,
-    query: str,
-    image: Image.Image,
-    savedir: str | Path | None = None,
-    add_special_prompt_to_doc: bool = True,
-) -> None:
-
-    # Sanity checks
-    if len(model.active_adapters()) != 1:
-        raise ValueError("The model must have exactly one active adapter.")
-
-    if model.config.name_or_path not in VIT_CONFIG:
-        raise ValueError("The model must be referred to in the VIT_CONFIG dictionary.")
-    vit_config = VIT_CONFIG[model.config.name_or_path]
-
-    # Handle savepath
-    if not savedir:
-        savedir = OUTDIR_INTERPRETABILITY / str(uuid4())
-        print(f"No savepath provided. Results will be saved to: `{savedir}`.")
-    elif isinstance(savedir, str):
-        savedir = Path(savedir)
-    savedir.mkdir(parents=True, exist_ok=True)
-
-    # Resize the image to square
-    input_image_square = image.resize((vit_config.resolution, vit_config.resolution))
-
-    # Preprocess the inputs
-    input_text_processed = processor.process_text(query).to(model.device)
-    input_image_processed = processor.process_image(image, add_special_prompt=add_special_prompt_to_doc).to(
-        model.device
-    )
-
-    # Forward pass
-    with torch.no_grad():
-        output_text = model.forward(**asdict(input_text_processed))  # (1, n_text_tokens, hidden_dim)
-
-    # NOTE: `output_image` will have shape:
-    # (1, n_patch_x * n_patch_y, hidden_dim) if `add_special_prompt_to_doc` is False
-    # (1, n_patch_x * n_patch_y + n_special_tokens, hidden_dim) if `add_special_prompt_to_doc` is True
-    with torch.no_grad():
-        output_image = model.forward(**asdict(input_image_processed))
-
-    if add_special_prompt_to_doc:  # remove the special tokens
-        output_image = output_image[
-            :, : processor.processor.image_seq_length, :
-        ]  # (1, n_patch_x * n_patch_y, hidden_dim)
-
-    output_image = rearrange(
-        output_image, "b (h w) c -> b h w c", h=vit_config.n_patch_per_dim, w=vit_config.n_patch_per_dim
-    )  # (1, n_patch_x, n_patch_y, hidden_dim)
-
-    # Get the unnormalized attention map
-    attention_map = torch.einsum(
-        "bnk,bijk->bnij", output_text, output_image
-    )  # (1, n_text_tokens, n_patch_x, n_patch_y)
-    attention_map_normalized = normalize_attention_map_per_query_token(
-        attention_map
-    )  # (1, n_text_tokens, n_patch_x, n_patch_y)
-    attention_map_normalized = attention_map_normalized.float()
-
-    # Get text token information
-    n_tokens = input_text_processed.input_ids.size(1)
-    text_tokens = processor.tokenizer.tokenize(processor.decode(input_text_processed.input_ids[0]))
-    print("Text tokens:")
-    pprint.pprint(text_tokens)
-    print("\n")
-
-    for token_idx in trange(1, n_tokens - 1, desc="Iterating over tokens..."):  # exclude the <bos> and the "\n" tokens
-        fig, axis = plot_patches(
-            input_image_square,
-            vit_config.patch_size,
-            vit_config.resolution,
-            patch_opacities=attention_map_normalized[0, token_idx, :, :],
-            style="dark_background",
-        )
-
-        fig.suptitle(f"Token #{token_idx}: `{text_tokens[token_idx]}`", color="white", fontsize=14)
-        savepath = savedir / f"token_{token_idx}.png"
-        fig.savefig(savepath)
-        print(f"Saved attention map for token {token_idx} (`{text_tokens[token_idx]}`) to `{savepath}`.\n")
-        plt.close(fig)
-
-    return
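A sketch of how this entry point was meant to be driven; the checkpoint, adapter, and image path below are assumptions, and the model must be a key of VIT_CONFIG with exactly one active adapter:

import torch
from PIL import Image

from colpali_engine.interpretability.gen_interpretability_plots import generate_interpretability_plots
from colpali_engine.interpretability.processor import ColPaliProcessor
from colpali_engine.models.paligemma_colbert_architecture import ColPali

model_name = "google/paligemma-3b-mix-448"  # must be present in VIT_CONFIG
model = ColPali.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="cuda")
# A single (hypothetical) trained ColPali LoRA adapter must be loaded for the sanity check to pass.
processor = ColPaliProcessor.from_pretrained(model_name)

generate_interpretability_plots(
    model=model,
    processor=processor,
    query="What is the chart about?",
    image=Image.open("page.png"),  # hypothetical document page
    savedir="outputs/interpretability/demo",
)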
colpali_engine/interpretability/plot_utils.py DELETED
@@ -1,131 +0,0 @@
-from typing import Any, Dict, Optional, Tuple, cast
-
-import matplotlib.pyplot as plt
-import numpy as np
-import numpy.typing as npt
-import seaborn as sns
-import torch
-from PIL import Image
-
-MAX_OPACITY = 255
-
-
-def plot_patches(
-    img: Image.Image,
-    patch_size: int,
-    image_resolution: int,
-    patch_opacities: Optional[npt.NDArray | torch.Tensor] = None,
-    figsize: Tuple[int, int] = (8, 8),
-    style: Dict[str, Any] | str | None = None,
-) -> Tuple[plt.Figure, plt.Axes]:
-    """
-    Plot patches of a square image.
-    Set `style` to "dark_background" if your image has a light background.
-    """
-
-    # Get the number of patches
-    if image_resolution % patch_size != 0:
-        raise ValueError("The image resolution must be divisible by the patch size.")
-    num_patches = image_resolution // patch_size
-
-    # Default style
-    if style is None:
-        style = {}
-
-    # Sanity checks
-    if patch_opacities is not None:
-        if isinstance(patch_opacities, torch.Tensor):
-            patch_opacities = cast(npt.NDArray, patch_opacities.cpu().numpy())
-        if patch_opacities.shape != (num_patches, num_patches):
-            raise ValueError("The shape of the patch_opacities tensor is not correct.")
-        if not np.all((0 <= patch_opacities) & (patch_opacities <= 1)):
-            raise ValueError("The patch_opacities tensor must have values between 0 and 1.")
-
-    # If the image is not square, raise an error
-    if img.size[0] != img.size[1]:
-        raise ValueError("The image must be square.")
-
-    # Get the image as a numpy array
-    img_array = np.array(img.convert("RGBA"))  # (H, W, C) where the last channel is the alpha channel
-
-    # Create a figure
-    with plt.style.context(style):
-        fig, axis = plt.subplots(num_patches, num_patches, figsize=figsize)
-
-        # Plot the patches
-        for i in range(num_patches):
-            for j in range(num_patches):
-                patch = img_array[i * patch_size : (i + 1) * patch_size, j * patch_size : (j + 1) * patch_size, :]
-                # Set the opacity of the patch
-                if patch_opacities is not None:
-                    patch[:, :, -1] = round(patch_opacities[i, j] * MAX_OPACITY)
-                axis[i, j].imshow(patch)
-                axis[i, j].axis("off")
-
-        fig.subplots_adjust(wspace=0.1, hspace=0.1)
-
-        fig.tight_layout()
-
-    return fig, axis
-
-
-def plot_attention_heatmap(
-    img: Image.Image,
-    patch_size: int,
-    image_resolution: int,
-    attention_map: npt.NDArray | torch.Tensor,
-    figsize: Tuple[int, int] = (8, 8),
-    style: Dict[str, Any] | str | None = None,
-    show_colorbar: bool = False,
-    show_axes: bool = False,
-) -> Tuple[plt.Figure, plt.Axes]:
-    """
-    Plot a heatmap of the attention map over the image.
-    The image must be square and `attention_map` must be normalized between 0 and 1.
-    """
-
-    # Get the number of patches
-    if image_resolution % patch_size != 0:
-        raise ValueError("The image resolution must be divisible by the patch size.")
-    num_patches = image_resolution // patch_size
-
-    # Default style
-    if style is None:
-        style = {}
-
-    # Sanity checks
-    if isinstance(attention_map, torch.Tensor):
-        attention_map = cast(npt.NDArray, attention_map.cpu().numpy())
-    if attention_map.shape != (num_patches, num_patches):
-        raise ValueError("The shape of the attention_map tensor is not correct.")
-    if not np.all((0 <= attention_map) & (attention_map <= 1)):
-        raise ValueError("The attention_map tensor must have values between 0 and 1.")
-
-    # If the image is not square, raise an error
-    if img.size[0] != img.size[1]:
-        raise ValueError("The image must be square.")
-
-    # Get the image as a numpy array
-    img_array = np.array(img.convert("RGBA"))  # (H, W, C) where the last channel is the alpha channel
-
-    # Get the attention map as a numpy array
-    attention_map_image = Image.fromarray((attention_map * 255).astype("uint8")).resize(
-        img.size, Image.Resampling.BICUBIC
-    )
-
-    # Create a figure
-    with plt.style.context(style):
-        fig, ax = plt.subplots(figsize=figsize)
-        ax.imshow(img_array)
-        im = ax.imshow(
-            attention_map_image,
-            cmap=sns.color_palette("mako", as_cmap=True),
-            alpha=0.5,
-        )
-        if show_colorbar:
-            fig.colorbar(im)
-        if not show_axes:
-            ax.set_axis_off()
-        fig.tight_layout()
-
-    return fig, ax
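Both helpers require a square image and a (num_patches, num_patches) map with values in [0, 1]; a minimal sketch with random values (dimensions follow the PaliGemma config used elsewhere in this package):

import torch
from PIL import Image

from colpali_engine.interpretability.plot_utils import plot_attention_heatmap

img = Image.new("RGB", (448, 448), "white")  # must be square
attention = torch.rand(32, 32)  # 448 / 14 = 32 patches per side, already in [0, 1]
fig, ax = plot_attention_heatmap(img, patch_size=14, image_resolution=448, attention_map=attention)
fig.savefig("heatmap.png")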
colpali_engine/interpretability/processor.py DELETED
@@ -1,116 +0,0 @@
-from __future__ import annotations
-
-from dataclasses import dataclass
-from typing import List, cast
-
-import torch
-from PIL import Image
-from transformers import LlamaTokenizerFast, PaliGemmaProcessor
-
-
-@dataclass
-class ColPaliTextInput:
-    input_ids: torch.Tensor
-    attention_mask: torch.Tensor
-
-    def to(self, device: torch.device) -> ColPaliTextInput:
-        return ColPaliTextInput(
-            input_ids=self.input_ids.to(device),
-            attention_mask=self.attention_mask.to(device),
-        )
-
-
-@dataclass
-class ColPaliImageInput:
-    input_ids: torch.Tensor
-    pixel_values: torch.Tensor
-    attention_mask: torch.Tensor
-
-    def to(self, device: str | torch.device) -> ColPaliImageInput:
-        return ColPaliImageInput(
-            input_ids=self.input_ids.to(device),
-            pixel_values=self.pixel_values.to(device),
-            attention_mask=self.attention_mask.to(device),
-        )
-
-
-class ColPaliProcessor:
-    def __init__(self, processor: PaliGemmaProcessor):
-        self.processor = processor
-        self.tokenizer = cast(LlamaTokenizerFast, self.processor.tokenizer)  # type: ignore
-
-    @staticmethod
-    def from_pretrained(model_name: str) -> ColPaliProcessor:
-        return ColPaliProcessor(processor=cast(PaliGemmaProcessor, PaliGemmaProcessor.from_pretrained(model_name)))
-
-    def process_text(
-        self,
-        text: str | List[str],
-        padding: str = "longest",
-        return_tensors: str = "pt",
-        add_special_tokens: bool = True,
-    ) -> ColPaliTextInput:
-        """
-        Process text inputs for the model.
-        If `add_special_tokens` is True (default), the text will be prepended with the <bos> token and appended with "\n".
-        """
-        if add_special_tokens:
-            if isinstance(text, str):
-                text = self.tokenizer.bos_token + text + "\n"
-            elif isinstance(text, list):
-                text = [self.tokenizer.bos_token + t + "\n" for t in text]
-            else:
-                raise ValueError("text must be a string or a list of strings.")
-
-        batch_output = self.tokenizer(
-            text, padding=padding, return_tensors=return_tensors, add_special_tokens=add_special_tokens
-        )
-
-        return ColPaliTextInput(
-            input_ids=cast(torch.Tensor, batch_output["input_ids"]),
-            attention_mask=cast(torch.Tensor, batch_output["attention_mask"]),
-        )
-
-    def process_image(
-        self,
-        image: Image.Image | List[Image.Image],
-        padding: str = "longest",
-        do_convert_rgb: bool = True,
-        return_tensors: str = "pt",
-        add_special_prompt: bool = True,
-    ) -> ColPaliImageInput:
-        # NOTE: The special prompt was used at training time.
-        special_prompt = "Describe the image." if add_special_prompt else None
-        if isinstance(image, Image.Image):
-            text_input = [special_prompt]
-        elif isinstance(image, list):
-            text_input = [special_prompt] * len(image)
-        else:
-            raise ValueError("image must be a PIL Image or a list of PIL Images.")
-
-        batch_output = self.processor(
-            text=text_input,
-            images=image,
-            padding=padding,
-            do_convert_rgb=do_convert_rgb,
-            return_tensors=return_tensors,
-        )
-
-        if add_special_prompt:
-            return ColPaliImageInput(
-                input_ids=batch_output["input_ids"],
-                pixel_values=batch_output["pixel_values"],
-                attention_mask=batch_output["attention_mask"],
-            )
-        else:
-            return ColPaliImageInput(
-                input_ids=batch_output["input_ids"][:, : self.processor.image_seq_length],
-                pixel_values=batch_output["pixel_values"][:, : self.processor.image_seq_length],
-                attention_mask=batch_output["attention_mask"][:, : self.processor.image_seq_length],
-            )
-
-    def decode(self, *args, **kwargs):
-        return self.tokenizer.decode(*args, **kwargs)
-
-    def batch_decode(self, *args, **kwargs):
-        return self.tokenizer.batch_decode(*args, **kwargs)
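A minimal sketch of the processor wrapper (the checkpoint name is an assumption and may require Hub access):

from PIL import Image

from colpali_engine.interpretability.processor import ColPaliProcessor

processor = ColPaliProcessor.from_pretrained("google/paligemma-3b-mix-448")
text_input = processor.process_text("What is the total revenue?")  # <bos> + text + "\n"
image_input = processor.process_image(Image.new("RGB", (448, 448)))
print(text_input.input_ids.shape, image_input.pixel_values.shape)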
colpali_engine/interpretability/torch_utils.py DELETED
@@ -1,60 +0,0 @@
-import logging
-
-import torch
-
-logger = logging.getLogger(__name__)
-
-EPSILON = 1e-10
-
-
-def normalize_attention_map_per_query_token(x: torch.Tensor) -> torch.Tensor:
-    """
-    Normalizes the attention map for ColPali for each query token.
-    The output tensor will have values in the range [0, 1] and the
-    same shape as the input tensor.
-
-    Args:
-        x: The attention map tensor of shape (batch_size, n_text_tokens, n_patch_x, n_patch_y).
-    """
-    if x.ndim != 4:
-        raise ValueError("The input tensor must have 4 dimensions.")
-
-    # Compute the minimum values along the last two dimensions (n_patch_x, n_patch_y)
-    min_vals = x.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[0]
-
-    # Compute the maximum values along the last two dimensions (n_patch_x, n_patch_y)
-    max_vals = x.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0]
-
-    # Normalize the tensor
-    x_normalized = (x - min_vals) / (max_vals - min_vals + EPSILON)  # Adding a small epsilon to avoid division by zero
-
-    return x_normalized
-
-
-def normalize_attention_map_per_query(x: torch.Tensor) -> torch.Tensor:
-    """
-    Normalizes the attention map for ColPali across all query tokens at once.
-    The output tensor will have values in the range [0, 1] and the
-    same shape as the input tensor.
-
-    Args:
-        x: The attention map tensor of shape (batch_size, n_text_tokens, n_patch_x, n_patch_y).
-    """
-    # Log warning
-    logger.warning(
-        "This function should not be used for ColPali because it doesn't make sense to normalize the attention map across the text tokens."
-    )
-
-    if x.ndim != 4:
-        raise ValueError("The input tensor must have 4 dimensions.")
-
-    # Compute the minimum values along the last three dimensions (n_text_tokens, n_patch_x, n_patch_y)
-    min_vals = x.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[0].min(dim=-3, keepdim=True)[0]
-
-    # Compute the maximum values along the last three dimensions (n_text_tokens, n_patch_x, n_patch_y)
-    max_vals = x.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0].max(dim=-3, keepdim=True)[0]
-
-    # Normalize the tensor
-    x_normalized = (x - min_vals) / (max_vals - min_vals + EPSILON)  # Adding a small epsilon to avoid division by zero
-
-    return x_normalized
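A quick self-contained check of the per-token normalization on random data:

import torch

from colpali_engine.interpretability.torch_utils import normalize_attention_map_per_query_token

x = torch.randn(1, 12, 32, 32)  # (batch_size, n_text_tokens, n_patch_x, n_patch_y)
y = normalize_attention_map_per_query_token(x)
assert y.shape == x.shape
assert 0.0 <= float(y.min()) and float(y.max()) <= 1.0  # each token's patch map spans [0, 1]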
colpali_engine/interpretability/vit_configs.py DELETED
@@ -1,23 +0,0 @@
-from dataclasses import dataclass
-from typing import Dict
-
-
-@dataclass
-class ViTConfig:
-    patch_size: int
-    resolution: int
-
-    @property
-    def n_patch_per_dim(self) -> int:
-        if self.resolution % self.patch_size != 0:
-            raise ValueError(f"Resolution {self.resolution} is not divisible by patch size {self.patch_size}")
-        return self.resolution // self.patch_size
-
-
-VIT_CONFIG: Dict[str, ViTConfig] = {
-    "google/siglip-so400m-patch14-384": ViTConfig(patch_size=14, resolution=384),
-    "timm/ViT-SO400M-14-SigLIP-384": ViTConfig(patch_size=14, resolution=384),
-    "google/paligemma-3b-mix-448": ViTConfig(
-        patch_size=14, resolution=448
-    ),  # based on "timm/ViT-SO400M-14-SigLIP-384" with increased resolution
-}
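Looking up a config and the derived patch grid:

from colpali_engine.interpretability.vit_configs import VIT_CONFIG

config = VIT_CONFIG["google/paligemma-3b-mix-448"]
print(config.patch_size, config.resolution, config.n_patch_per_dim)  # 14 448 32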
colpali_engine/loss/__init__.py DELETED
@@ -1 +0,0 @@
-from .colbert_loss import ColbertLoss
colpali_engine/loss/colbert_loss.py DELETED
@@ -1,122 +0,0 @@
-import torch
-import torch.nn.functional as F
-from torch.nn import CrossEntropyLoss
-
-
-class BiEncoderLoss(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.ce_loss = CrossEntropyLoss()
-        # self.pooling_strategy = pooling_strategy
-
-    def forward(self, query_embeddings, doc_embeddings):
-        """
-        query_embeddings: (batch_size, dim)
-        doc_embeddings: (batch_size, dim)
-        """
-
-        scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
-
-        loss_rowwise = self.ce_loss(scores, torch.arange(scores.shape[0], device=scores.device))
-        # loss_columnwise = self.ce_loss(scores.T, torch.arange(scores.shape[1], device=scores.device))
-        # loss = (loss_rowwise + loss_columnwise) / 2
-        return loss_rowwise
-
-
-class ColbertLoss(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.ce_loss = CrossEntropyLoss()
-
-    def forward(self, query_embeddings, doc_embeddings):
-        """
-        query_embeddings: (batch_size, num_query_tokens, dim)
-        doc_embeddings: (batch_size, num_doc_tokens, dim)
-        """
-
-        scores = torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings).max(dim=3)[0].sum(dim=2)
-
-        # scores = torch.zeros((query_embeddings.shape[0], doc_embeddings.shape[0]), device=query_embeddings.device)
-        # for i in range(query_embeddings.shape[0]):
-        #     for j in range(doc_embeddings.shape[0]):
-        #         # step 1 - dot product --> (s1,s2)
-        #         q2d_scores = torch.matmul(query_embeddings[i], doc_embeddings[j].T)
-        #         # step 2 -> max on doc --> (s1)
-        #         q_scores = torch.max(q2d_scores, dim=1)[0]
-        #         # step 3 --> sum the max score --> (1)
-        #         sum_q_score = torch.sum(q_scores)
-        #         # step 4 --> assert is scalar
-        #         scores[i, j] = sum_q_score
-
-        # assert (scores_einsum - scores < 0.0001).all().item()
-
-        loss_rowwise = self.ce_loss(scores, torch.arange(scores.shape[0], device=scores.device))
-        # TODO: comparing between queries might not make sense since it's a sum over the length of the query
-        # loss_columnwise = self.ce_loss(scores.T, torch.arange(scores.shape[1], device=scores.device))
-        # loss = (loss_rowwise + loss_columnwise) / 2
-        return loss_rowwise
-
-
-class ColbertPairwiseCELoss(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.ce_loss = CrossEntropyLoss()
-
-    def forward(self, query_embeddings, doc_embeddings):
-        """
-        query_embeddings: (batch_size, num_query_tokens, dim)
-        doc_embeddings: (batch_size, num_doc_tokens, dim)
-
-        Positive scores are the diagonal of the scores matrix.
-        """
-
-        # Compute the ColBERT scores
-        scores = (
-            torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings).max(dim=3)[0].sum(dim=2)
-        )  # (batch_size, batch_size)
-
-        # Positive scores are the diagonal of the scores matrix.
-        pos_scores = scores.diagonal()  # (batch_size,)
-
-        # Negative score for a given query is the maximum of the scores against all other pages.
-        # NOTE: We exclude the diagonal by setting it to a very low value: since we know the maximum score is 1,
-        # we can subtract 1 from the diagonal to exclude it from the maximum operation.
-        neg_scores = scores - torch.eye(scores.shape[0], device=scores.device) * 1e6  # (batch_size, batch_size)
-        neg_scores = neg_scores.max(dim=1)[0]  # (batch_size,)
-
-        # Compute the loss
-        # The loss is computed as the negative log of the softmax of the positive scores
-        # relative to the negative scores.
-        # This can be simplified to log-sum-exp of negative scores minus the positive score
-        # for numerical stability.
-        # torch.vstack((pos_scores, neg_scores)).T.softmax(1)[:, 0].log()*(-1)
-        loss = F.softplus(neg_scores - pos_scores).mean()
-
-        return loss
-
-
-class BiPairwiseCELoss(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.ce_loss = CrossEntropyLoss()
-
-    def forward(self, query_embeddings, doc_embeddings):
-        """
-        query_embeddings: (batch_size, dim)
-        doc_embeddings: (batch_size, dim)
-        """
-
-        scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
-
-        pos_scores = scores.diagonal()
-        neg_scores = scores - torch.eye(scores.shape[0], device=scores.device) * 1e6
-        neg_scores = neg_scores.max(dim=1)[0]
-
-        # Compute the loss
-        # The loss is computed as the negative log of the softmax of the positive scores
-        # relative to the negative scores.
-        # This can be simplified to log-sum-exp of negative scores minus the positive score
-        # for numerical stability.
-        loss = F.softplus(neg_scores - pos_scores).mean()
-
-        return loss
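A sketch of the in-batch contrastive setup these losses expect: matching query/document pairs sit on the diagonal of the score matrix (random embeddings, illustrative sizes):

import torch

from colpali_engine.loss.colbert_loss import ColbertLoss, ColbertPairwiseCELoss

batch_size, n_query_tokens, n_doc_tokens, dim = 4, 16, 1024, 128
q = torch.randn(batch_size, n_query_tokens, dim)
d = torch.randn(batch_size, n_doc_tokens, dim)

# Late-interaction score: per query token, max over doc tokens, then summed over query tokens.
print(ColbertLoss()(q, d))            # cross-entropy against the diagonal
print(ColbertPairwiseCELoss()(q, d))  # softplus(hardest in-batch negative - positive)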
colpali_engine/models/__init__.py DELETED
File without changes
colpali_engine/models/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (170 Bytes)
 
colpali_engine/models/__pycache__/paligemma_colbert_architecture.cpython-310.pyc DELETED
Binary file (4.87 kB)
 
colpali_engine/models/clip_baselines.py DELETED
@@ -1,144 +0,0 @@
-import os
-from typing import Optional
-
-import torch
-from transformers import SiglipModel
-
-
-class SigLIP(SiglipModel):
-    def forward(self, *args, **kwargs):
-        """
-        Forward pass through the SigLIP text or vision branch.
-
-        Args:
-        - input_ids (torch.LongTensor): The input tokens tensor.
-        - attention_mask (torch.LongTensor): The attention mask tensor.
-
-        Returns:
-        - torch.Tensor: Pooled, L2-normalized embeddings of shape (batch_size, dim)
-        """
-        return self.forward_branch(*args, **kwargs)
-
-    def forward_branch(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        pixel_values: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        return_loss: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        interpolate_pos_encoding: bool = False,
-    ):
-
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        if pixel_values is not None:
-            # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
-
-            outputs = self.vision_model(
-                pixel_values=pixel_values.to(dtype=self.dtype),
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-                interpolate_pos_encoding=interpolate_pos_encoding,
-            )
-
-        else:
-            outputs = self.text_model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-
-        embeds = outputs[1]
-
-        # normalized features
-        embeds = embeds / embeds.norm(p=2, dim=-1, keepdim=True)
-        return embeds
-
-
-class ColSigLIP(SiglipModel):
-    def __init__(self, config):
-        super(ColSigLIP, self).__init__(config=config)
-        self.dim = 128
-        self.custom_vision_proj = torch.nn.Linear(self.config.vision_config.hidden_size, self.dim)
-        self.custom_text_proj = torch.nn.Linear(self.config.text_config.hidden_size, self.dim)
-        self.main_input_name = "doc_input_ids"
-
-    def forward(self, *args, **kwargs):
-        """
-        Forward pass through SigLIP and the projection layer for dimensionality reduction.
-
-        Args:
-        - input_ids (torch.LongTensor): The input tokens tensor.
-        - attention_mask (torch.LongTensor): The attention mask tensor.
-
-        Returns:
-        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
-        """
-        return self.forward_branch(*args, **kwargs)
-
-    def forward_branch(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        pixel_values: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        return_loss: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        interpolate_pos_encoding: bool = False,
-    ):
-
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        if pixel_values is not None:
-            # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
-
-            outputs = self.vision_model(
-                pixel_values=pixel_values.to(dtype=self.dtype),
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-                interpolate_pos_encoding=interpolate_pos_encoding,
-            )
-
-            last_hidden_states = outputs.last_hidden_state
-
-            proj = self.custom_vision_proj(last_hidden_states)
-            # normalize l2 norm
-            proj = proj / proj.norm(dim=-1, keepdim=True)
-
-        else:
-            outputs = self.text_model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-
-            last_hidden_states = outputs.last_hidden_state
-
-            proj = self.custom_text_proj(last_hidden_states)
-            # normalize l2 norm
-            proj = proj / proj.norm(dim=-1, keepdim=True)
-            proj = proj * attention_mask.unsqueeze(-1)
-
-        # normalized features
-        return proj
colpali_engine/models/colbert_architectures.py DELETED
@@ -1,177 +0,0 @@
-from torch import nn
-from transformers import (
-    BertModel,
-    BertPreTrainedModel,
-    CamembertModel,
-    CamembertPreTrainedModel,
-    LlamaModel,
-    LlamaPreTrainedModel,
-    XLMRobertaModel,
-    XLMRobertaPreTrainedModel,
-)
-
-
-class ColCamembert(CamembertPreTrainedModel):
-    def __init__(self, config):
-        super(ColCamembert, self).__init__(config=config)
-        self.roberta: CamembertPreTrainedModel = CamembertModel(config)
-        self.dim = 128
-        self.linear = nn.Linear(self.roberta.config.hidden_size, self.dim)
-        self.main_input_name = "doc_input_ids"
-
-    def forward(self, *args, **kwargs):
-        """
-        Forward pass through Camembert and the linear layer for dimensionality reduction
-
-        Args:
-        - input_ids (torch.LongTensor): The input tokens tensor.
-        - attention_mask (torch.LongTensor): The attention mask tensor.
-
-        Returns:
-        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
-        """
-        outputs = self.roberta(*args, **kwargs)
-        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
-        proj = self.linear(last_hidden_states)
-        # normalize l2 norm
-        proj = proj / proj.norm(dim=-1, keepdim=True)
-        proj = proj * kwargs["attention_mask"].unsqueeze(-1)
-        return proj
-
-
-class ColXLMRoBERTa(XLMRobertaPreTrainedModel):
-    def __init__(self, config):
-        super(ColXLMRoBERTa, self).__init__(config=config)
-        self.roberta: XLMRobertaPreTrainedModel = XLMRobertaModel(config)
-        self.dim = 128
-        self.linear = nn.Linear(self.roberta.config.hidden_size, self.dim)
-        self.main_input_name = "doc_input_ids"
-
-    def forward(self, *args, **kwargs):
-        """
-        Forward pass through RoBERTa and the linear layer for dimensionality reduction
-
-        Args:
-        - input_ids (torch.LongTensor): The input tokens tensor.
-        - attention_mask (torch.LongTensor): The attention mask tensor.
-
-        Returns:
-        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
-        """
-        outputs = self.roberta(*args, **kwargs)
-        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
-        proj = self.linear(last_hidden_states)
-        # normalize l2 norm
-        proj = proj / proj.norm(dim=-1, keepdim=True)
-        proj = proj * kwargs["attention_mask"].unsqueeze(-1)
-        return proj
-
-
-class BiXLMRoBERTa(XLMRobertaPreTrainedModel):
-    def __init__(self, config):
-        super(BiXLMRoBERTa, self).__init__(config=config)
-        self.roberta: XLMRobertaPreTrainedModel = XLMRobertaModel(config)
-        self.main_input_name = "doc_input_ids"
-
-    def forward(self, *args, **kwargs):
-        """
-        Forward pass through RoBERTa with mean pooling over the attention mask
-
-        Args:
-        - input_ids (torch.LongTensor): The input tokens tensor.
-        - attention_mask (torch.LongTensor): The attention mask tensor.
-
-        Returns:
-        - torch.Tensor: Pooled, L2-normalized embeddings of shape (batch_size, hidden_size)
-        """
-        outputs = self.roberta(*args, **kwargs)
-        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
-        # pooling - mean tokens that have attention mask == 1
-        proj = last_hidden_states * kwargs["attention_mask"].unsqueeze(-1)
-        proj = proj.sum(dim=1) / kwargs["attention_mask"].sum(dim=1, keepdim=True)
-        # normalize l2 norm
-        proj = proj / proj.norm(dim=-1, keepdim=True)
-        return proj
-
-
-class ColBERT(BertPreTrainedModel):
-    def __init__(self, config):
-        super(ColBERT, self).__init__(config=config)
-        self.bert: BertModel = BertModel(config)
-        self.dim = 128
-        self.linear = nn.Linear(self.bert.config.hidden_size, self.dim)
-        self.main_input_name = "doc_input_ids"
-
-    def forward(self, *args, **kwargs):
-        """
-        Forward pass through BERT and the linear layer for dimensionality reduction
-
-        Args:
-        - input_ids (torch.LongTensor): The input tokens tensor.
-        - attention_mask (torch.LongTensor): The attention mask tensor.
-
-        Returns:
-        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
-        """
-        outputs = self.bert(*args, **kwargs)
-        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
-        proj = self.linear(last_hidden_states)
-        # normalize l2 norm
-        proj = proj / proj.norm(dim=-1, keepdim=True)
-        proj = proj * kwargs["attention_mask"].unsqueeze(-1)
-        return proj
-
-
-class BiBERT(BertPreTrainedModel):
-    def __init__(self, config):
-        super(BiBERT, self).__init__(config=config)
-        self.bert: BertModel = BertModel(config)
-        self.main_input_name = "doc_input_ids"
-
-    def forward(self, *args, **kwargs):
-        """
-        Forward pass through BERT with mean pooling over the attention mask
-
-        Args:
-        - input_ids (torch.LongTensor): The input tokens tensor.
-        - attention_mask (torch.LongTensor): The attention mask tensor.
-
-        Returns:
-        - torch.Tensor: Pooled, L2-normalized embeddings of shape (batch_size, hidden_size)
-        """
-        outputs = self.bert(*args, **kwargs)
-        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
-        # pooling - mean tokens that have attention mask == 1
-        proj = last_hidden_states * kwargs["attention_mask"].unsqueeze(-1)
-        proj = proj.sum(dim=1) / kwargs["attention_mask"].sum(dim=1, keepdim=True)
-        # normalize l2 norm
-        proj = proj / proj.norm(dim=-1, keepdim=True)
-        return proj
-
-
-class ColLlama(LlamaPreTrainedModel):
-    def __init__(self, config):
-        super(ColLlama, self).__init__(config=config)
-        self.model: LlamaModel = LlamaModel(config)
-        self.dim = 128
-        self.linear = nn.Linear(self.model.config.hidden_size, self.dim)
-        self.main_input_name = "doc_input_ids"
-
-    def forward(self, *args, **kwargs):
-        """
-        Forward pass through Llama and the linear layer for dimensionality reduction
-
-        Args:
-        - input_ids (torch.LongTensor): The input tokens tensor.
-        - attention_mask (torch.LongTensor): The attention mask tensor.
-
-        Returns:
-        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
-        """
-        outputs = self.model(*args, **kwargs)
-        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
-        proj = self.linear(last_hidden_states)
-        # normalize l2 norm
-        proj = proj / proj.norm(dim=-1, keepdim=True)
-        proj = proj * kwargs["attention_mask"].unsqueeze(-1)
-        return proj
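These wrappers follow the standard `from_pretrained` pattern; a sketch for the BERT variant (the checkpoint name is an assumption, and the projection layer is randomly initialized until fine-tuned):

import torch
from transformers import AutoTokenizer

from colpali_engine.models.colbert_architectures import ColBERT

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = ColBERT.from_pretrained("bert-base-uncased")

inputs = tokenizer(["a query", "a document"], padding=True, return_tensors="pt")
with torch.no_grad():
    embeddings = model(**inputs)  # (batch_size, num_tokens, 128), L2-normalized, padding zeroed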
colpali_engine/models/idefics_colbert_architecture.py DELETED
@@ -1,57 +0,0 @@
-from torch import nn
-from transformers import Idefics2Model, Idefics2PreTrainedModel
-
-
-class BiIdefics(Idefics2PreTrainedModel):
-    def __init__(self, config):
-        super(BiIdefics, self).__init__(config=config)
-        self.model: Idefics2Model = Idefics2Model(config)
-        self.pooling_strategy = "last"
-        self.main_input_name = "doc_input_ids"
-
-    def forward(self, *args, **kwargs):
-        """
-        Forward pass through Idefics2 with last-token pooling
-
-        Args:
-        - input_ids (torch.LongTensor): The input tokens tensor.
-        - attention_mask (torch.LongTensor): The attention mask tensor.
-
-        Returns:
-        - torch.Tensor: Pooled, L2-normalized embeddings of shape (batch_size, hidden_size)
-        """
-        outputs = self.model(*args, **kwargs)
-        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
-        # pooling - last token
-        proj = last_hidden_states[:, -1, :]
-        # normalize l2 norm
-        proj = proj / proj.norm(dim=-1, keepdim=True)
-        return proj
-
-
-class ColIdefics(Idefics2PreTrainedModel):
-    def __init__(self, config):
-        super(ColIdefics, self).__init__(config=config)
-        self.model: Idefics2Model = Idefics2Model(config)
-        self.dim = 128
-        self.linear = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
-        self.main_input_name = "doc_input_ids"
-
-    def forward(self, *args, **kwargs):
-        """
-        Forward pass through Idefics2 and the linear layer for dimensionality reduction
-
-        Args:
-        - input_ids (torch.LongTensor): The input tokens tensor.
-        - attention_mask (torch.LongTensor): The attention mask tensor.
-
-        Returns:
-        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
-        """
-        outputs = self.model(*args, **kwargs)
-        last_hidden_states = outputs[0]  # (batch_size, sequence_length, hidden_size)
-        proj = self.linear(last_hidden_states)
-        # normalize l2 norm
-        proj = proj / proj.norm(dim=-1, keepdim=True)
-        proj = proj * kwargs["attention_mask"].unsqueeze(-1)
-        return proj
 
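Note: BiIdefics pools to a single vector per sequence while ColIdefics keeps a 128-d vector per token for late interaction. A toy sketch of the two output shapes (dimensions are illustrative):

import torch

hidden = torch.randn(2, 4, 6)                        # (batch, seq, hidden)

# BiIdefics-style: one unit-norm vector per sequence via last-token pooling.
single = hidden[:, -1, :]
single = single / single.norm(dim=-1, keepdim=True)  # (2, 6)

# ColIdefics-style: a projected, unit-norm vector per token.
proj_head = torch.nn.Linear(6, 128)
multi = proj_head(hidden)
multi = multi / multi.norm(dim=-1, keepdim=True)     # (2, 4, 128)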
colpali_engine/models/paligemma_colbert_architecture.py DELETED
@@ -1,191 +0,0 @@
- import torch
- from torch import nn
- from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration, PaliGemmaPreTrainedModel
-
-
- class BiPaliLast(PaliGemmaPreTrainedModel):
-     def __init__(self, config):
-         super(BiPaliLast, self).__init__(config=config)
-         self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
-         self.pooling_strategy = "last"
-         self.main_input_name = "doc_input_ids"
-
-     def forward(self, *args, **kwargs):
-         """
-         Forward pass through PaliGemma, pooled to a single embedding per sequence.
-
-         Args:
-         - input_ids (torch.LongTensor): The input tokens tensor.
-         - attention_mask (torch.LongTensor): The attention mask tensor.
-
-         Returns:
-         - torch.Tensor: Embeddings of shape (batch_size, hidden_size)
-         """
-         outputs = self.model(*args, output_hidden_states=True, **kwargs)
-         last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
-         # pooling: keep the last token
-         proj = last_hidden_states[:, -1, :]
-         # L2 normalization
-         proj = proj / proj.norm(dim=-1, keepdim=True)
-         return proj
-
-
- class BiPaliMean(PaliGemmaPreTrainedModel):
-     def __init__(self, config):
-         super(BiPaliMean, self).__init__(config=config)
-         self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
-         self.pooling_strategy = "mean"
-         self.main_input_name = "doc_input_ids"
-
-     def forward(self, *args, **kwargs):
-         """
-         Forward pass through PaliGemma, mean-pooled to a single embedding per sequence.
-
-         Args:
-         - input_ids (torch.LongTensor): The input tokens tensor.
-         - attention_mask (torch.LongTensor): The attention mask tensor.
-
-         Returns:
-         - torch.Tensor: Embeddings of shape (batch_size, hidden_size)
-         """
-         outputs = self.model(*args, output_hidden_states=True, **kwargs)
-         last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
-         # pooling: mean over tokens where attention_mask == 1
-         proj = torch.sum(last_hidden_states * kwargs["attention_mask"].unsqueeze(-1), dim=1) / torch.sum(
-             kwargs["attention_mask"], dim=1, keepdim=True
-         )
-         proj = proj / proj.norm(dim=-1, keepdim=True)
-         return proj
-
-
- class ColPali(PaliGemmaPreTrainedModel):
-     def __init__(self, config):
-         super(ColPali, self).__init__(config=config)
-         self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
-         self.dim = 128
-         self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
-         self.main_input_name = "doc_input_ids"
-
-     def forward(self, *args, **kwargs):
-         """
-         Forward pass through PaliGemma and the linear layer for dimensionality reduction.
-
-         Args:
-         - input_ids (torch.LongTensor): The input tokens tensor.
-         - attention_mask (torch.LongTensor): The attention mask tensor.
-
-         Returns:
-         - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
-         """
-         outputs = self.model(*args, output_hidden_states=True, **kwargs)
-         last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
-         proj = self.custom_text_proj(last_hidden_states)
-         # L2 normalization
-         proj = proj / proj.norm(dim=-1, keepdim=True)
-         proj = proj * kwargs["attention_mask"].unsqueeze(-1)
-         return proj
-
-
- class ColNewSiglip(PaliGemmaPreTrainedModel):
-     def __init__(self, config):
-         super(ColNewSiglip, self).__init__(config=config)
-         self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
-         self.dim = 128
-         self.custom_image_proj = nn.Linear(self.model.config.vision_config.projection_dim, self.dim)
-         self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
-         self.main_input_name = "doc_input_ids"
-
-     def forward(self, *args, **kwargs):
-         """
-         Forward pass: project vision-tower patches for documents, text tokens for queries.
-
-         Args:
-         - input_ids (torch.LongTensor): The input tokens tensor.
-         - attention_mask (torch.LongTensor): The attention mask tensor.
-
-         Returns:
-         - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
-         """
-         if "pixel_values" in kwargs:
-             image_features = self.vision_model_output(*args, **kwargs)
-             proj = self.custom_image_proj(image_features)
-             proj = proj / proj.norm(dim=-1, keepdim=True)
-         else:
-             outputs = self.model(*args, output_hidden_states=True, **kwargs)
-             last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
-             proj = self.custom_text_proj(last_hidden_states)
-             # L2 normalization
-             proj = proj / proj.norm(dim=-1, keepdim=True)
-             proj = proj * kwargs["attention_mask"].unsqueeze(-1)
-         return proj
-
-     def vision_model_output(self, input_ids: torch.LongTensor = None, pixel_values: torch.FloatTensor = None, **kwargs):
-         inputs_embeds = self.model.get_input_embeddings()(input_ids)
-         # merge text and images
-         if pixel_values is not None and input_ids.shape[1] != 1:
-             image_outputs = self.model.vision_tower(pixel_values.to(inputs_embeds.dtype))
-             selected_image_feature = image_outputs.last_hidden_state
-             image_features = self.model.multi_modal_projector(selected_image_feature)
-             return image_features
-
-         raise ValueError("pixel_values is None or input_ids.shape[1] == 1")
-
-
- class BiNewSiglip(PaliGemmaPreTrainedModel):
-     def __init__(self, config):
-         super(BiNewSiglip, self).__init__(config=config)
-         self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
-         self.main_input_name = "doc_input_ids"
-
-     def forward(self, *args, **kwargs):
-         """
-         Forward pass: mean-pool vision patches for documents, text tokens for queries.
-
-         Args:
-         - input_ids (torch.LongTensor): The input tokens tensor.
-         - attention_mask (torch.LongTensor): The attention mask tensor.
-
-         Returns:
-         - torch.Tensor: Embeddings of shape (batch_size, hidden_size)
-         """
-         if "pixel_values" in kwargs:
-             image_features = self.vision_model_output(*args, **kwargs)
-             # pooling: mean over image patches
-             proj = torch.mean(image_features, dim=1)
-             norm = proj.norm(dim=-1, keepdim=True)
-             proj = proj / norm
-         else:
-             outputs = self.model(*args, output_hidden_states=True, **kwargs)
-             last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
-             # pooling: mean over tokens where attention_mask == 1
-             proj = torch.sum(last_hidden_states * kwargs["attention_mask"].unsqueeze(-1), dim=1) / torch.sum(
-                 kwargs["attention_mask"], dim=1, keepdim=True
-             )
-             norm = proj.norm(dim=-1, keepdim=True)
-             proj = proj / norm
-         return proj
-
-     def vision_model_output(self, input_ids: torch.LongTensor = None, pixel_values: torch.FloatTensor = None, **kwargs):
-         inputs_embeds = self.model.get_input_embeddings()(input_ids)
-         # merge text and images
-         if pixel_values is not None and input_ids.shape[1] != 1:
-             image_outputs = self.model.vision_tower(pixel_values.to(inputs_embeds.dtype))
-             selected_image_feature = image_outputs.last_hidden_state
-             image_features = self.model.multi_modal_projector(selected_image_feature)
-             return image_features
-
-         raise ValueError("pixel_values is None or input_ids.shape[1] == 1")
 
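Note: the per-token embeddings produced by ColPali above are consumed by a ColBERT-style MaxSim: each query token is matched against its best document token and the per-token maxima are summed into one relevance score. A self-contained sketch on random unit vectors (sizes are illustrative):

import torch
import torch.nn.functional as F

q = F.normalize(torch.randn(3, 128), dim=-1)   # 3 query-token embeddings
d = F.normalize(torch.randn(7, 128), dim=-1)   # 7 document-token embeddings

sim = q @ d.T                                  # (3, 7) cosine similarities
score = sim.max(dim=1).values.sum()            # best doc token per query token, summed
print(float(score))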
colpali_engine/trainer/__init__.py DELETED
File without changes
colpali_engine/trainer/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (171 Bytes)
 
colpali_engine/trainer/__pycache__/retrieval_evaluator.cpython-310.pyc DELETED
Binary file (3.18 kB)
 
colpali_engine/trainer/contrastive_trainer.py DELETED
@@ -1,64 +0,0 @@
- import torch
- from transformers import Trainer
-
-
- class ContrastiveTrainer(Trainer):
-     def __init__(self, loss_func, is_vision_model, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-         self.loss_func = loss_func
-         self.is_vision_model = is_vision_model
-
-     def compute_loss(self, model, inputs, return_outputs=False):
-         query_outputs = model(input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"])
-         if self.is_vision_model:
-             if "doc_pixel_attention_mask" not in inputs:
-                 doc_outputs = model(
-                     input_ids=inputs["doc_input_ids"],
-                     attention_mask=inputs["doc_attention_mask"],
-                     pixel_values=inputs["doc_pixel_values"],
-                 )
-             else:
-                 doc_outputs = model(
-                     input_ids=inputs["doc_input_ids"],
-                     attention_mask=inputs["doc_attention_mask"],
-                     pixel_values=inputs["doc_pixel_values"],
-                     pixel_attention_mask=inputs["doc_pixel_attention_mask"],
-                 )
-         else:
-             doc_outputs = model(input_ids=inputs["doc_input_ids"], attention_mask=inputs["doc_attention_mask"])
-
-         loss = self.loss_func(query_outputs, doc_outputs)
-         return (loss, (query_outputs, doc_outputs)) if return_outputs else loss
-
-     def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=True):
-         """Generate predictions and return the loss for the given inputs."""
-         if not prediction_loss_only:
-             raise ValueError("prediction_step is only called with prediction_loss_only=True")
-
-         with torch.no_grad():
-             if self.is_vision_model:
-                 if "doc_pixel_attention_mask" not in inputs:
-                     doc_outputs = model(
-                         input_ids=inputs["doc_input_ids"],
-                         attention_mask=inputs["doc_attention_mask"],
-                         pixel_values=inputs["doc_pixel_values"],
-                     )
-                 else:
-                     doc_outputs = model(
-                         input_ids=inputs["doc_input_ids"],
-                         attention_mask=inputs["doc_attention_mask"],
-                         pixel_values=inputs["doc_pixel_values"],
-                         pixel_attention_mask=inputs["doc_pixel_attention_mask"],
-                     )
-                 query_outputs = model(
-                     input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"]
-                 )
-             else:
-                 query_outputs = model(
-                     input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"]
-                 )
-                 doc_outputs = model(input_ids=inputs["doc_input_ids"], attention_mask=inputs["doc_attention_mask"])
-
-             loss = self.loss_func(query_outputs, doc_outputs)
-         return loss, None, None
 
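Note: the trainer above delegates scoring entirely to loss_func. The deleted loss module is not shown in this hunk, but an in-batch pairwise cross-entropy over late-interaction scores, which is the kind of objective a ColbertLoss-style callable plausibly computes, can be sketched as follows (a minimal sketch, not the repository's exact implementation):

import torch
import torch.nn.functional as F

def in_batch_colbert_loss(q, d):
    # q: (B, Nq, dim), d: (B, Nd, dim); row i of q is the positive for row i of d.
    sim = torch.einsum("bnd,csd->bcns", q, d)      # all query/doc token pairs
    scores = sim.max(dim=3).values.sum(dim=2)      # (B, B) late-interaction scores
    labels = torch.arange(scores.shape[0])         # positives on the diagonal
    return F.cross_entropy(scores, labels)

loss = in_batch_colbert_loss(
    F.normalize(torch.randn(4, 3, 128), dim=-1),
    F.normalize(torch.randn(4, 7, 128), dim=-1),
)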
colpali_engine/trainer/retrieval_evaluator.py DELETED
@@ -1,72 +0,0 @@
- import torch
- from mteb.evaluation.evaluators import RetrievalEvaluator
-
-
- class CustomEvaluator:
-     def __init__(self, is_multi_vector=False):
-         self.is_multi_vector = is_multi_vector
-         self.mteb_evaluator = RetrievalEvaluator()
-
-     def evaluate(self, qs, ps):
-         if self.is_multi_vector:
-             scores = self.evaluate_colbert(qs, ps)
-         else:
-             scores = self.evaluate_biencoder(qs, ps)
-
-         assert scores.shape[0] == len(qs)
-
-         arg_score = scores.argmax(dim=1)
-         # sanity check: the positive document for query i sits at index i
-         accuracy = (arg_score == torch.arange(scores.shape[0], device=scores.device)).sum().item() / scores.shape[0]
-         print(arg_score)
-         print(f"Top 1 Accuracy (verif): {accuracy}")
-
-         # cast to numpy
-         scores = scores.to(torch.float32).cpu().numpy()
-         return scores
-
-     def compute_metrics(self, relevant_docs, results, **kwargs):
-         # wrap the mteb retrieval evaluator
-         ndcg, _map, recall, precision, naucs = self.mteb_evaluator.evaluate(
-             relevant_docs,
-             results,
-             self.mteb_evaluator.k_values,
-             ignore_identical_ids=kwargs.get("ignore_identical_ids", True),
-         )
-         mrr = self.mteb_evaluator.evaluate_custom(relevant_docs, results, self.mteb_evaluator.k_values, "mrr")
-         scores = {
-             **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
-             **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
-             **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()},
-             **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()},
-             **{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr[0].items()},
-             **{f"naucs_at_{k.split('@')[1]}": v for (k, v) in naucs.items()},
-         }
-         return scores
-
-     def evaluate_colbert(self, qs, ps, batch_size=128) -> torch.Tensor:
-         scores = []
-         for i in range(0, len(qs), batch_size):
-             scores_batch = []
-             qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to(
-                 "cpu"
-             )
-             for j in range(0, len(ps), batch_size):
-                 ps_batch = torch.nn.utils.rnn.pad_sequence(
-                     ps[j : j + batch_size], batch_first=True, padding_value=0
-                 ).to("cpu")
-                 scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2))
-             scores_batch = torch.cat(scores_batch, dim=1).cpu()
-             scores.append(scores_batch)
-         scores = torch.cat(scores, dim=0)
-         return scores
-
-     def evaluate_biencoder(self, qs, ps) -> torch.Tensor:
-         qs = torch.stack(qs)
-         ps = torch.stack(ps)
-
-         scores = torch.einsum("bd,cd->bc", qs, ps)
-         return scores
 
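Note: compute_metrics expects qrels/run-style nested dicts keyed by query id. A toy illustration of the expected input format (ids and score values are made up):

# Toy qrels (ground truth) and run (model scores).
relevant_docs = {"0": {"doc_a": 1}, "1": {"doc_b": 1}}
results = {
    "0": {"doc_a": 12.3, "doc_b": 4.2},
    "1": {"doc_a": 3.1, "doc_b": 9.8},
}
# metrics = CustomEvaluator().compute_metrics(relevant_docs, results)
# Both queries rank their positive first, so e.g. ndcg_at_1 would come out as 1.0.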
colpali_engine/utils/__init__.py DELETED
File without changes
colpali_engine/utils/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (169 Bytes)
 
colpali_engine/utils/__pycache__/colpali_processing_utils.cpython-310.pyc DELETED
Binary file (1.2 kB)
 
colpali_engine/utils/__pycache__/image_from_page_utils.cpython-310.pyc DELETED
Binary file (998 Bytes)
 
colpali_engine/utils/colidefics_processing_utils.py DELETED
@@ -1,53 +0,0 @@
- # Utils for processing images and queries for ColPali
-
- def process_images(processor, images, max_length: int = 50):
-     texts_doc = []
-     images = [image.convert("RGB") for image in images]
-
-     for _ in images:
-         messages_doc = [
-             {
-                 "role": "user",
-                 "content": [
-                     {"type": "text", "text": "Describe the image."},
-                     {"type": "image"},
-                 ],
-             },
-         ]
-
-         text_doc = processor.apply_chat_template(messages_doc, add_generation_prompt=False)
-         texts_doc.append(text_doc.strip())
-
-     batch_doc = processor(
-         text=texts_doc,
-         images=images,
-         return_tensors="pt",
-         padding="longest",
-     )
-     return batch_doc
-
-
- def process_queries(processor, queries, mock_image, max_length: int = 50):
-     texts_query = []
-     for query in queries:
-         messages_query = [
-             {
-                 "role": "user",
-                 "content": [
-                     {
-                         "type": "text",
-                         "text": f"Question: {query}<end_of_utterance><end_of_utterance><end_of_utterance><end_of_utterance><end_of_utterance>",
-                     },
-                 ],
-             },
-         ]
-         text_query = processor.apply_chat_template(messages_query, add_generation_prompt=False).strip()
-         texts_query.append(text_query)
-
-     batch_query = processor(
-         text=texts_query,
-         return_tensors="pt",
-         padding="longest",
-         max_length=max_length,
-     )
-     return batch_query
 
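Note: the five repeated <end_of_utterance> tokens appended to each question appear to serve as trailing buffer tokens, in the spirit of ColBERT-style query augmentation. Illustrative usage (the checkpoint name is a placeholder, not pinned by this file):

# from transformers import AutoProcessor
# processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
# batch_doc = process_images(processor, pil_page_images)
# batch_query = process_queries(processor, ["What does the table show?"], mock_image=None)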
colpali_engine/utils/colpali_processing_utils.py DELETED
@@ -1,36 +0,0 @@
- # Utils for processing images and queries for ColPali
-
-
- def process_images(processor, images, max_length: int = 50):
-     texts_doc = ["Describe the image."] * len(images)
-     images = [image.convert("RGB") for image in images]
-
-     batch_doc = processor(
-         text=texts_doc,
-         images=images,
-         return_tensors="pt",
-         padding="longest",
-         max_length=max_length + processor.image_seq_length,
-     )
-     return batch_doc
-
-
- def process_queries(processor, queries, mock_image, max_length: int = 50):
-     texts_query = []
-     for query in queries:
-         query = f"Question: {query}<unused0><unused0><unused0><unused0><unused0>"
-         texts_query.append(query)
-
-     batch_query = processor(
-         images=[mock_image.convert("RGB")] * len(texts_query),
-         # NOTE: the image is not used in batch_query, but it is required by the processor
-         text=texts_query,
-         return_tensors="pt",
-         padding="longest",
-         max_length=max_length + processor.image_seq_length,
-     )
-     del batch_query["pixel_values"]
-
-     batch_query["input_ids"] = batch_query["input_ids"][..., processor.image_seq_length :]
-     batch_query["attention_mask"] = batch_query["attention_mask"][..., processor.image_seq_length :]
-     return batch_query
 
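Note: queries reuse the image-capable processor, so the image placeholder positions it prepends must be sliced off afterwards. A toy illustration of that slicing (image_seq_length is a made-up value here; the real one comes from the processor):

import torch

image_seq_length = 4
input_ids = torch.tensor([[9, 9, 9, 9, 101, 102, 103]])  # 4 image slots + 3 text tokens

text_only = input_ids[..., image_seq_length:]
print(text_only)  # tensor([[101, 102, 103]])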
colpali_engine/utils/dataset_transformation.py DELETED
@@ -1,158 +0,0 @@
- import os
-
- from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
-
- USE_LOCAL_DATASET = os.environ.get("USE_LOCAL_DATASET", "1") == "1"
-
-
- def add_metadata_column(dataset, column_name, value):
-     def add_source(example):
-         example[column_name] = value
-         return example
-
-     return dataset.map(add_source)
-
-
- def load_train_set() -> DatasetDict:
-     ds_paths = [
-         "infovqa_train",
-         "docvqa_train",
-         "arxivqa_train",
-         "tatdqa_train",
-         "syntheticDocQA_government_reports_train",
-         "syntheticDocQA_healthcare_industry_train",
-         "syntheticDocQA_artificial_intelligence_train",
-         "syntheticDocQA_energy_train",
-     ]
-     base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
-     ds_tot = []
-     for path in ds_paths:
-         cpath = base_path + path
-         ds = load_dataset(cpath, split="train")
-         if "arxivqa" in path:
-             # subsample 10k
-             ds = ds.shuffle(42).select(range(10000))
-         ds_tot.append(ds)
-
-     dataset = concatenate_datasets(ds_tot)
-     dataset = dataset.shuffle(seed=42)
-     # split into train and test
-     dataset_eval = dataset.select(range(500))
-     dataset = dataset.select(range(500, len(dataset)))
-     ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
-     return ds_dict
-
-
- def load_train_set_with_tabfquad() -> DatasetDict:
-     ds_paths = [
-         "infovqa_train",
-         "docvqa_train",
-         "arxivqa_train",
-         "tatdqa_train",
-         "tabfquad_train_subsampled",
-         "syntheticDocQA_government_reports_train",
-         "syntheticDocQA_healthcare_industry_train",
-         "syntheticDocQA_artificial_intelligence_train",
-         "syntheticDocQA_energy_train",
-     ]
-     base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
-     ds_tot = []
-     for path in ds_paths:
-         cpath = base_path + path
-         ds = load_dataset(cpath, split="train")
-         if "arxivqa" in path:
-             # subsample 10k
-             ds = ds.shuffle(42).select(range(10000))
-         ds_tot.append(ds)
-
-     dataset = concatenate_datasets(ds_tot)
-     dataset = dataset.shuffle(seed=42)
-     # split into train and test
-     dataset_eval = dataset.select(range(500))
-     dataset = dataset.select(range(500, len(dataset)))
-     ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
-     return ds_dict
-
-
- def load_train_set_with_docmatix() -> DatasetDict:
-     ds_paths = [
-         "infovqa_train",
-         "docvqa_train",
-         "arxivqa_train",
-         "tatdqa_train",
-         "tabfquad_train_subsampled",
-         "syntheticDocQA_government_reports_train",
-         "syntheticDocQA_healthcare_industry_train",
-         "syntheticDocQA_artificial_intelligence_train",
-         "syntheticDocQA_energy_train",
-         "Docmatix_filtered_train",
-     ]
-     base_path = "./data_dir/" if USE_LOCAL_DATASET else "vidore/"
-     ds_tot = []
-     for path in ds_paths:
-         cpath = base_path + path
-         ds = load_dataset(cpath, split="train")
-         if "arxivqa" in path:
-             # subsample 10k
-             ds = ds.shuffle(42).select(range(10000))
-         ds_tot.append(ds)
-
-     dataset = concatenate_datasets(ds_tot)
-     dataset = dataset.shuffle(seed=42)
-     # split into train and test
-     dataset_eval = dataset.select(range(500))
-     dataset = dataset.select(range(500, len(dataset)))
-     ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
-     return ds_dict
-
-
- def load_docvqa_dataset() -> DatasetDict:
-     if USE_LOCAL_DATASET:
-         dataset_doc = load_dataset("./data_dir/DocVQA", "DocVQA", split="validation")
-         dataset_doc_eval = load_dataset("./data_dir/DocVQA", "DocVQA", split="test")
-         dataset_info = load_dataset("./data_dir/DocVQA", "InfographicVQA", split="validation")
-         dataset_info_eval = load_dataset("./data_dir/DocVQA", "InfographicVQA", split="test")
-     else:
-         dataset_doc = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation")
-         dataset_doc_eval = load_dataset("lmms-lab/DocVQA", "DocVQA", split="test")
-         dataset_info = load_dataset("lmms-lab/DocVQA", "InfographicVQA", split="validation")
-         dataset_info_eval = load_dataset("lmms-lab/DocVQA", "InfographicVQA", split="test")
-
-     # concatenate the two datasets
-     dataset = concatenate_datasets([dataset_doc, dataset_info])
-     dataset_eval = concatenate_datasets([dataset_doc_eval, dataset_info_eval])
-     # sample 200 examples from the eval dataset
-     dataset_eval = dataset_eval.shuffle(seed=42).select(range(200))
-
-     # rename "question" to "query"
-     dataset = dataset.rename_column("question", "query")
-     dataset_eval = dataset_eval.rename_column("question", "query")
-
-     # create a new column image_filename from ucsf_document_id if not None, else image_url
-     dataset = dataset.map(
-         lambda x: {"image_filename": x["ucsf_document_id"] if x["ucsf_document_id"] is not None else x["image_url"]}
-     )
-     dataset_eval = dataset_eval.map(
-         lambda x: {"image_filename": x["ucsf_document_id"] if x["ucsf_document_id"] is not None else x["image_url"]}
-     )
-
-     ds_dict = DatasetDict({"train": dataset, "test": dataset_eval})
-
-     return ds_dict
-
-
- class TestSetFactory:
-     def __init__(self, dataset_path):
-         self.dataset_path = dataset_path
-
-     def __call__(self, *args, **kwargs):
-         dataset = load_dataset(self.dataset_path, split="test")
-         return dataset
-
-
- if __name__ == "__main__":
-     ds = TestSetFactory("vidore/tabfquad_test_subsampled")()
-     print(ds)
 
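Note: all loaders above key off the USE_LOCAL_DATASET environment variable; an illustration of the two modes:

# USE_LOCAL_DATASET=1 (default): load_train_set() reads "./data_dir/infovqa_train", ...
# USE_LOCAL_DATASET=0:           load_train_set() reads "vidore/infovqa_train", ...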
colpali_engine/utils/gpu_stats.py DELETED
@@ -1,24 +0,0 @@
- # conditional import: fall back to no-ops when pynvml is unavailable
- try:
-     from pynvml import nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlInit
-
-     def print_gpu_utilization():
-         nvmlInit()
-         handle = nvmlDeviceGetHandleByIndex(0)
-         info = nvmlDeviceGetMemoryInfo(handle)
-         print(f"GPU memory occupied: {info.used // 1024 ** 2} MB.")
-
-     def print_summary(result):
-         print(f"Time: {result.metrics['train_runtime']:.2f}")
-         print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
-         print_gpu_utilization()
-
- except ImportError:
-     print("pynvml not found. GPU stats will not be printed.")
-
-     def print_summary(result):
-         print(f"Time: {result.metrics['train_runtime']:.2f}")
-         print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
-
-     def print_gpu_utilization():
-         pass
 
colpali_engine/utils/image_from_page_utils.py DELETED
@@ -1,21 +0,0 @@
- import requests
- from PIL import Image
-
-
- def load_from_pdf(pdf_path: str):
-     from pdf2image import convert_from_path
-
-     images = convert_from_path(pdf_path)
-     return images
-
-
- def load_from_image_urls(urls: list[str]):
-     images = [Image.open(requests.get(url, stream=True).raw) for url in urls]
-     return images
-
-
- def load_from_dataset(dataset):
-     from datasets import load_dataset
-
-     dataset = load_dataset(dataset, split="test")
-     return dataset["image"]
 
colpali_engine/utils/image_utils.py DELETED
@@ -1,64 +0,0 @@
- """
- Utility functions for working with images.
- """
-
- import base64
- import io
-
- from PIL import Image
-
-
- def scale_image(image: Image.Image, new_height: int = 1024) -> Image.Image:
-     """
-     Scale an image to a new height while maintaining the aspect ratio.
-     """
-     # Calculate the scaling factor
-     width, height = image.size
-     aspect_ratio = width / height
-     new_width = int(new_height * aspect_ratio)
-
-     # Resize the image
-     scaled_image = image.resize((new_width, new_height))
-
-     return scaled_image
-
-
- def scale_to_max_dimension(image: Image.Image, max_dimension: int = 1024) -> Image.Image:
-     """
-     Scale an image to a maximum dimension while maintaining the aspect ratio.
-     """
-     # Get the dimensions of the image
-     width, height = image.size
-
-     max_original_dimension = max(width, height)
-
-     if max_original_dimension < max_dimension:
-         return image
-
-     # Calculate the scaling factor
-     aspect_ratio = max_dimension / max_original_dimension
-     new_width = int(width * aspect_ratio)
-     new_height = int(height * aspect_ratio)
-
-     # Resize the image
-     scaled_image = image.resize((new_width, new_height))
-
-     return scaled_image
-
-
- def get_base64_image(img: str | Image.Image, add_url_prefix: bool = True) -> str:
-     """
-     Convert an image (from a filepath or a PIL.Image object) to a JPEG-base64 string.
-     """
-     if isinstance(img, str):
-         img = Image.open(img)
-     elif isinstance(img, Image.Image):
-         pass
-     else:
-         raise ValueError("`img` must be a path to an image or a PIL Image object.")
-
-     buffered = io.BytesIO()
-     img.save(buffered, format="jpeg")
-     b64_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
-
-     return f"data:image/jpeg;base64,{b64_data}" if add_url_prefix else b64_data
 
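Note: a quick worked example of scale_to_max_dimension, assuming the function above is in scope: a 2000x1000 image with max_dimension=1024 gets factor 1024/2000 = 0.512 applied to both sides.

from PIL import Image

img = Image.new("RGB", (2000, 1000))
scaled = scale_to_max_dimension(img, max_dimension=1024)
print(scaled.size)  # (1024, 512)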
colpali_engine/utils/iter_utils.py DELETED
@@ -1,42 +0,0 @@
- import sys
-
-
- def islice(iterable, *args):
-     """
-     Yield a slice of an iterable.
-     >>> islice('ABCDEFG', 2) → A B
-     >>> islice('ABCDEFG', 2, 4) → C D
-     >>> islice('ABCDEFG', 2, None) → C D E F G
-     >>> islice('ABCDEFG', 0, None, 2) → A C E G
-     """
-     s = slice(*args)
-     start, stop, step = s.start or 0, s.stop or sys.maxsize, s.step or 1
-     it = iter(range(start, stop, step))
-     try:
-         nexti = next(it)
-     except StopIteration:
-         # Consume *iterable* up to the *start* position.
-         for i, element in zip(range(start), iterable):
-             pass
-         return
-     try:
-         for i, element in enumerate(iterable):
-             if i == nexti:
-                 yield element
-                 nexti = next(it)
-     except StopIteration:
-         # Consume to *stop*.
-         for i, element in zip(range(i + 1, stop), iterable):
-             pass
-
-
- def batched(iterable, n: int):
-     """
-     Yield batches of n elements from an iterable.
-     >>> batched('ABCDEFG', 3) → ABC DEF G
-     """
-     if n < 1:
-         raise ValueError("n must be at least one")
-     it = iter(iterable)
-     while batch := tuple(islice(it, n)):
-         yield batch
 
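Note: batched builds on the local islice recipe; a quick usage check, assuming both functions above are in scope:

for batch in batched("ABCDEFG", 3):
    print(batch)
# ('A', 'B', 'C')
# ('D', 'E', 'F')
# ('G',)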
colpali_engine/utils/pdf_utils.py DELETED
@@ -1,87 +0,0 @@
- import glob
- import os
- import random
- from pathlib import Path
-
- from pdf2image import convert_from_path
- from tqdm import tqdm
-
- random.seed(42)
-
-
- def convert_pdf_to_images(pdf_file: str, save_folder: str):
-     """
-     Convert each page of a pdf to a jpg image and save them in a folder.
-
-     Args:
-     - pdf_file (str): path to the pdf file
-     - save_folder (str): path to the folder where the images will be saved
-     """
-     images = convert_from_path(pdf_file)
-
-     for i, image in enumerate(images):
-         if not os.path.exists(save_folder):
-             os.makedirs(save_folder)
-         image.save(os.path.join(save_folder, f"page_{i+1}.jpg"), "JPEG")
-
-
- def convert_all_pdfs_to_images(path_to_folder: str, n_samples: int = 0):
-     """
-     Convert all pdfs in a folder and its subfolders to images and save them in a folder.
-     Sampling n_samples pdf files per subfolder gives granular control over how many files are converted.
-
-     Args:
-     - path_to_folder (str): path to the folder containing the pdf files
-     - n_samples (int): number of pdf files to sample in each subfolder (0 = take all)
-
-     Expected directory structure:
-     - path_to_folder
-         - subfolder1
-             - pdf1
-             - pdf2
-             - ...
-         - subfolder2
-             - pdf1
-             - pdf2
-             - ...
-         - ...
-     """
-     # sample n_samples pdf files in each subfolder
-     sub_dirs = [d for d in os.listdir(path_to_folder) if os.path.isdir(os.path.join(path_to_folder, d))]
-
-     sampled_files = []
-
-     for sub_dir in sub_dirs:
-         pdf_files = glob.glob(os.path.join(path_to_folder, sub_dir, "*.pdf"))
-
-         if (n_samples == 0) or (len(pdf_files) <= n_samples):
-             print(f"Taking all pdf files in {sub_dir}")
-             sampled_files.extend(pdf_files)
-         else:
-             print(f"Taking {n_samples} pdf files in {sub_dir}")
-             sampled_files.extend(random.sample(pdf_files, n_samples))
-
-     pdf_files = [str(file) for file in sampled_files]
-
-     # Create an empty text file that will collect the file paths of corrupted pdf files
-     dirpath_corrupted = Path(path_to_folder) / "corrupted_pdf_files.txt"
-     dirpath_corrupted.parent.mkdir(parents=True, exist_ok=True)
-
-     with dirpath_corrupted.open("w") as f:
-         with tqdm(total=len(pdf_files)) as pbar:
-             for pdf_file in pdf_files:
-                 pbar.set_description(f"Processing {pdf_file}")
-                 save_folder = os.path.join("pages_extracted", *Path(pdf_file).parts[-2:])
-                 if not os.path.exists(os.path.join(path_to_folder, save_folder)):
-                     try:
-                         convert_pdf_to_images(pdf_file, os.path.join(path_to_folder, save_folder))
-                     except Exception as e:
-                         print(f"Error converting {pdf_file}: {e}")
-                         f.write(pdf_file)
-                         f.write("\n")
-                 pbar.update(1)
-     return
 
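Note: an illustrative call, with a hypothetical folder layout:

# convert_all_pdfs_to_images("./pdf_corpus", n_samples=10)
# -> samples up to 10 pdfs per subfolder and writes
#    ./pdf_corpus/pages_extracted/<subfolder>/<file.pdf>/page_1.jpg, page_2.jpg, ...
#    logging any unreadable files to ./pdf_corpus/corrupted_pdf_files.txt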
colpali_engine/utils/plot_utils.py DELETED
@@ -1,6 +0,0 @@
- import seaborn as sns
-
-
- def setup_seaborn():
-     sns.set_style("white")
-     sns.set_context("paper", font_scale=2)
 
colpali_engine/utils/torch_utils.py DELETED
@@ -1,18 +0,0 @@
- """
- Utility functions for interpretability.
- """
-
- import torch
-
-
- def get_torch_device() -> str:
-     """
-     Return the best available device for torch tensors.
-     """
-     if torch.cuda.is_available():
-         device = "cuda:0"
-     elif torch.backends.mps.is_available():  # for Apple Silicon
-         device = "mps"
-     else:
-         device = "cpu"
-     return device
 
colpali_engine/utils/train_colpali_engine_models.py DELETED
@@ -1,247 +0,0 @@
- # HuggingFace trainer
- import json
- import os
- from dataclasses import dataclass
- from typing import Callable, Dict, Optional
-
- import torch
- from datasets import concatenate_datasets
- from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
- from torch.utils.data import DataLoader
- from tqdm import tqdm
- from transformers import AutoTokenizer, Idefics2Processor, PreTrainedModel, PreTrainedTokenizer, TrainingArguments
-
- from colpali_engine.dataset.custom_collator import CustomCollator
- from colpali_engine.loss.colbert_loss import BiEncoderLoss, BiPairwiseCELoss, ColbertLoss, ColbertPairwiseCELoss
- from colpali_engine.trainer.contrastive_trainer import ContrastiveTrainer
- from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
- from colpali_engine.utils.gpu_stats import print_gpu_utilization, print_summary
-
-
- @dataclass
- class ColModelTrainingConfig:
-     model: PreTrainedModel
-     tr_args: TrainingArguments = None
-     output_dir: str = None
-     max_length: int = 256
-     run_eval: bool = True
-     run_train: bool = True
-     peft_config: Optional[LoraConfig] = None
-     add_suffix: bool = False
-     processor: Idefics2Processor = None
-     tokenizer: PreTrainedTokenizer = None
-     loss_func: Optional[Callable] = ColbertLoss()
-     dataset_loading_func: Optional[Callable] = None
-     eval_dataset_loader: Optional[Dict[str, Callable]] = None
-     pretrained_peft_model_name_or_path: Optional[str] = None
-
-     def __post_init__(self):
-         if self.output_dir is None:
-             sanitized_name = str(self.model.name_or_path).replace("/", "_")
-             self.output_dir = f"./models/{sanitized_name}"
-
-         if self.tr_args is None:
-             self.tr_args = TrainingArguments(output_dir=self.output_dir)
-         elif self.tr_args.output_dir is None:
-             self.tr_args.output_dir = self.output_dir
-
-         # cast the learning rate if it was given as a string
-         if isinstance(self.tr_args.learning_rate, str):
-             self.tr_args.learning_rate = float(self.tr_args.learning_rate)
-         self.tr_args.remove_unused_columns = False
-
-         if self.processor is None and self.tokenizer is None:
-             print("Using textual model tokenization")
-             self.tokenizer = AutoTokenizer.from_pretrained(self.model.name_or_path)
-
-         if self.pretrained_peft_model_name_or_path is not None:
-             self.model.load_adapter(self.pretrained_peft_model_name_or_path)
-             print(f"Loaded pretrained adapter from {self.pretrained_peft_model_name_or_path}")
-
-         if self.peft_config is not None:
-             print("Configuring PEFT model")
-             if self.processor is None:
-                 # Might be deprecated - use the "else" branch
-                 self.model = prepare_model_for_kbit_training(self.model)
-                 self.model = get_peft_model(self.model, self.peft_config)
-                 self.model.print_trainable_parameters()
-             else:
-                 if self.pretrained_peft_model_name_or_path is None:
-                     self.model.add_adapter(self.peft_config)
-                     self.model.enable_adapters()
-                 else:
-                     print(f"Adapter already loaded from {self.pretrained_peft_model_name_or_path}. Not overwriting.")
-
-         print_gpu_utilization()
-
-
- class ColModelTraining:
-     def __init__(self, config: ColModelTrainingConfig) -> None:
-         self.config = config
-         self.model = self.config.model
-         self.dataset = self.config.dataset_loading_func()
-         self.collator = CustomCollator(
-             processor=self.config.processor, tokenizer=self.config.tokenizer, max_length=self.config.max_length
-         )
-         self.current_git_hash = os.popen("git rev-parse HEAD").read().strip()
-         self.retriever_evaluator = CustomEvaluator(
-             is_multi_vector=(
-                 isinstance(self.config.loss_func, ColbertLoss)
-                 or isinstance(self.config.loss_func, ColbertPairwiseCELoss)
-             )
-         )
-
-     def train(self) -> None:
-         trainer = ContrastiveTrainer(
-             model=self.model,
-             train_dataset=self.dataset["train"],
-             eval_dataset=self.dataset["test"],
-             args=self.config.tr_args,
-             data_collator=self.collator,
-             loss_func=self.config.loss_func,
-             is_vision_model=self.config.processor is not None,
-         )
-         trainer.args.remove_unused_columns = False
-
-         result = trainer.train()
-         print_summary(result)
-
-     def eval_dataset(self, test_dataset):
-         self.model.eval()
-
-         idx_with_query = [idx for idx, sample in enumerate(test_dataset["query"]) if sample is not None]
-         idx_without_query = [idx for idx, sample in enumerate(test_dataset["query"]) if sample is None]
-
-         dataloader_with_query = DataLoader(
-             test_dataset.select(idx_with_query),
-             batch_size=self.config.tr_args.per_device_eval_batch_size,
-             shuffle=False,
-             collate_fn=self.collator,
-         )
-         dataloader_without_query = DataLoader(
-             test_dataset.select(idx_without_query),
-             batch_size=self.config.tr_args.per_device_eval_batch_size,
-             shuffle=False,
-             collate_fn=self.collator,
-         )
-
-         # reorder the dataset so that non-null queries come first
-         test_dataset = concatenate_datasets(
-             [test_dataset.select(idx_with_query), test_dataset.select(idx_without_query)]
-         )
-
-         relevant_docs = {}
-         docidx_2_docid = {}
-         qsidx_2_query = []
-         for idx, sample in enumerate(test_dataset):
-             doc_id = sample["image_filename"] if "image_filename" in sample else str(hash(sample["doc"]))
-             if sample["query"] is not None:
-                 relevant_docs[str(idx)] = {doc_id: 1}
-                 qsidx_2_query.append(str(idx))
-             docidx_2_docid[str(idx)] = doc_id
-
-         qs = []
-         ps = []
-
-         device = self.model.device
-         with torch.no_grad():
-             for dataloader in [dataloader_with_query, dataloader_without_query]:
-                 for batch in tqdm(dataloader):
-                     if "doc_pixel_values" not in batch:
-                         doc = self.model(
-                             input_ids=batch["doc_input_ids"].to(device),
-                             attention_mask=batch["doc_attention_mask"].to(device),
-                         )
-                     else:
-                         if "doc_pixel_attention_mask" in batch:
-                             doc = self.model(
-                                 input_ids=batch["doc_input_ids"].to(device),
-                                 attention_mask=batch["doc_attention_mask"].to(device),
-                                 pixel_values=batch["doc_pixel_values"].to(device),
-                                 pixel_attention_mask=batch["doc_pixel_attention_mask"].to(device),
-                             )
-                         else:
-                             doc = self.model(
-                                 input_ids=batch["doc_input_ids"].to(device),
-                                 attention_mask=batch["doc_attention_mask"].to(device),
-                                 pixel_values=batch["doc_pixel_values"].to(device),
-                             )
-
-                     ps.extend(list(torch.unbind(doc.to("cpu"))))
-
-                     if "query_input_ids" in batch:
-                         query = self.model(
-                             input_ids=batch["query_input_ids"].to(device),
-                             attention_mask=batch["query_attention_mask"].to(device),
-                         )
-                         # queries have variable length
-                         qs.extend(list(torch.unbind(query.to("cpu"))))
-
-         print("Embeddings computed, evaluating")
-         scores = self.retriever_evaluator.evaluate(qs, ps)
-         # scores is a 2d array of shape (n_queries, n_docs); turn it into a dict
-         results = {}
-         assert scores.shape[0] == len(qsidx_2_query)
-         for idx, scores_per_query in enumerate(scores):
-             results[qsidx_2_query[idx]] = {
-                 docidx_2_docid[str(docidx)]: float(score) for docidx, score in enumerate(scores_per_query)
-             }
-
-         # evaluate
-         metrics = self.retriever_evaluator.compute_metrics(relevant_docs, results)
-         print(metrics)
-         return metrics
-
-     def eval(self) -> None:
-         print("Evaluating on validation set")
-         metrics = self.eval_dataset(self.dataset["test"])
-         print(f"Metrics for validation set: {metrics}")
-         all_metrics = {"validation_set": metrics}
-
-         if self.config.eval_dataset_loader is not None:
-             for test_name, test_dataset_loading_func in self.config.eval_dataset_loader.items():
-                 print(f"Evaluating {test_name}")
-                 test_ds = test_dataset_loading_func()
-                 metrics = self.eval_dataset(test_ds)
-                 all_metrics[test_name] = metrics
-                 print(f"Metrics for {test_name}: {metrics}")
-
-         # save results as json
-         with open(f"{self.config.output_dir}/results.json", "w") as f:
-             json.dump(all_metrics, f)
-
-     def save(self, config_file):
-         # save model, tokenizer and processor
-         self.model.save_pretrained(self.config.output_dir)
-         if self.config.tokenizer is not None:
-             self.config.tokenizer.save_pretrained(self.config.output_dir)
-         if self.config.processor is not None:
-             self.config.processor.save_pretrained(self.config.output_dir)
-
-         # copy the yml config file into the output directory
-         os.system(f"cp {config_file} {self.config.output_dir}/training_config.yml")
-
-         # save the git hash of the commit at the beginning of training
-         with open(f"{self.config.output_dir}/git_hash.txt", "w") as f:
-             f.write(self.current_git_hash)
 
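Note: a minimal wiring sketch of how these pieces assemble, under the APIs shown above; the model, processor, and loader bindings are placeholders rather than values fixed by this file:

config = ColModelTrainingConfig(
    model=model,                          # e.g. a ColPali instance loaded beforehand
    processor=processor,                  # the matching PaliGemma/Idefics2 processor
    dataset_loading_func=load_train_set,  # from dataset_transformation.py
    loss_func=ColbertPairwiseCELoss(),
)
training_app = ColModelTraining(config)
training_app.train()
training_app.eval()
training_app.save("training_config.yml")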
colpali_engine/utils/wrapper.py DELETED
@@ -1,83 +0,0 @@
- import importlib
-
- from colpali_engine.models.clip_baselines import ColSigLIP, SigLIP
- from colpali_engine.models.colbert_architectures import (
-     BiBERT,
-     BiXLMRoBERTa,
-     ColBERT,
-     ColCamembert,
-     ColLlama,
-     ColXLMRoBERTa,
- )
- from colpali_engine.models.idefics_colbert_architecture import BiIdefics, ColIdefics
- from colpali_engine.models.paligemma_colbert_architecture import (
-     BiNewSiglip,
-     BiPaliLast,
-     BiPaliMean,
-     ColNewSiglip,
-     ColPali,
- )
-
- if importlib.util.find_spec("transformers") is not None:
-     from transformers import AutoProcessor, AutoTokenizer
-     from transformers.tokenization_utils import PreTrainedTokenizer
-
-     class AutoProcessorWrapper:
-         def __new__(cls, *args, **kwargs):
-             return AutoProcessor.from_pretrained(*args, **kwargs)
-
-     class AutoTokenizerWrapper(PreTrainedTokenizer):
-         def __new__(cls, *args, **kwargs):
-             return AutoTokenizer.from_pretrained(*args, **kwargs)
-
-     class AutoColModelWrapper:
-         def __new__(cls, *args, **kwargs):
-             pretrained_model_name_or_path = None
-             if args:
-                 pretrained_model_name_or_path = args[0]
-             elif kwargs:
-                 pretrained_model_name_or_path = kwargs["pretrained_model_name_or_path"]
-
-             training_objective = kwargs.pop("training_objective", "colbertv1")
-
-             if "camembert" in pretrained_model_name_or_path:
-                 return ColCamembert.from_pretrained(*args, **kwargs)
-             elif "xlm-roberta" in pretrained_model_name_or_path:
-                 if training_objective == "biencoder":
-                     return BiXLMRoBERTa.from_pretrained(*args, **kwargs)
-                 return ColXLMRoBERTa.from_pretrained(*args, **kwargs)
-             elif (
-                 "llama" in pretrained_model_name_or_path.lower() or "croissant" in pretrained_model_name_or_path.lower()
-             ):
-                 return ColLlama.from_pretrained(*args, **kwargs)
-             elif "idefics2" in pretrained_model_name_or_path:
-                 if training_objective == "biencoder":
-                     return BiIdefics.from_pretrained(*args, **kwargs)
-                 return ColIdefics.from_pretrained(*args, **kwargs)
-             elif "siglip" in pretrained_model_name_or_path:
-                 if training_objective == "biencoder_mean":
-                     return SigLIP.from_pretrained(*args, **kwargs)
-                 elif training_objective == "colbertv1":
-                     return ColSigLIP.from_pretrained(*args, **kwargs)
-                 else:
-                     raise ValueError(f"Training objective {training_objective} not recognized")
-             elif "paligemma" in pretrained_model_name_or_path:
-                 if training_objective == "biencoder_mean":
-                     return BiPaliMean.from_pretrained(*args, **kwargs)
-                 elif training_objective == "biencoder_last":
-                     return BiPaliLast.from_pretrained(*args, **kwargs)
-                 elif training_objective == "biencoder_mean_vision":
-                     return BiNewSiglip.from_pretrained(*args, **kwargs)
-                 elif training_objective == "colbertv1_vision":
-                     return ColNewSiglip.from_pretrained(*args, **kwargs)
-                 elif training_objective == "colbertv1":
-                     return ColPali.from_pretrained(*args, **kwargs)
-                 else:
-                     raise ValueError(f"Training objective {training_objective} not recognized")
-             else:
-                 if training_objective == "biencoder":
-                     return BiBERT.from_pretrained(*args, **kwargs)
-                 return ColBERT.from_pretrained(*args, **kwargs)
-
- else:
-     raise ModuleNotFoundError("Transformers must be loaded")
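Note: dispatch in AutoColModelWrapper is purely substring-based on the checkpoint name, plus a training_objective flag. Illustrative examples (checkpoint names are placeholders, not real repositories):

# AutoColModelWrapper("org/my-paligemma-ckpt")                                       -> ColPali
# AutoColModelWrapper("org/my-paligemma-ckpt", training_objective="biencoder_mean")  -> BiPaliMean
# AutoColModelWrapper("org/my-idefics2-ckpt", training_objective="biencoder")        -> BiIdefics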