geekyrakshit commited on
Commit
abd20d0
·
1 Parent(s): 9a6c015

update: colpali index syncs with wandb artifact

Browse files
.gitignore CHANGED
@@ -7,4 +7,6 @@ cursor_prompt.txt
7
  test.py
8
  **.pdf
9
  images/
10
- wandb/
 
 
 
7
  test.py
8
  **.pdf
9
  images/
10
+ wandb/
11
+ .byaldi/
12
+ artifacts/
medrag_multi_modal/document_loader/load_image.py CHANGED
@@ -3,11 +3,11 @@ import os
3
  from typing import Optional
4
 
5
  import rich
 
6
  import weave
7
  from pdf2image.pdf2image import convert_from_path
8
  from PIL import Image
9
 
10
- import wandb
11
  from medrag_multi_modal.document_loader.load_text import TextLoader
12
 
13
 
 
3
  from typing import Optional
4
 
5
  import rich
6
+ import wandb
7
  import weave
8
  from pdf2image.pdf2image import convert_from_path
9
  from PIL import Image
10
 
 
11
  from medrag_multi_modal.document_loader.load_text import TextLoader
12
 
13
 
medrag_multi_modal/retrieval/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
  from .multi_modal_retrieval import MultiModalRetriever
2
 
3
- __all__ = ["MultiModalRetriever"]
 
1
  from .multi_modal_retrieval import MultiModalRetriever
2
 
3
+ __all__ = ["MultiModalRetriever"]
medrag_multi_modal/retrieval/multi_modal_retrieval.py CHANGED
@@ -1,22 +1,39 @@
 
 
 
1
  import weave
2
  from byaldi import RAGMultiModalModel
3
- import wandb
4
 
5
 
6
  class MultiModalRetriever(weave.Model):
7
  model_name: str
8
  _docs_retrieval_model: RAGMultiModalModel
9
-
10
  def __init__(self, model_name: str = "vidore/colpali-v1.2"):
11
  super().__init__(model_name=model_name)
12
  self._docs_retrieval_model = RAGMultiModalModel.from_pretrained(self.model_name)
13
-
14
  def index(self, data_artifact_name: str, weave_dataset_name: str, index_name: str):
15
  if wandb.run:
16
- artifact = wandb.use_artifact(data_artifact_name, type='dataset')
17
  artifact_dir = artifact.download()
18
  else:
19
  api = wandb.Api()
20
  artifact = api.artifact(data_artifact_name)
21
  artifact_dir = artifact.download()
22
- self._docs_retrieval_model.index(input_path=artifact_dir, index_name=index_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import wandb
4
  import weave
5
  from byaldi import RAGMultiModalModel
 
6
 
7
 
8
  class MultiModalRetriever(weave.Model):
9
  model_name: str
10
  _docs_retrieval_model: RAGMultiModalModel
11
+
12
  def __init__(self, model_name: str = "vidore/colpali-v1.2"):
13
  super().__init__(model_name=model_name)
14
  self._docs_retrieval_model = RAGMultiModalModel.from_pretrained(self.model_name)
15
+
16
  def index(self, data_artifact_name: str, weave_dataset_name: str, index_name: str):
17
  if wandb.run:
18
+ artifact = wandb.use_artifact(data_artifact_name, type="dataset")
19
  artifact_dir = artifact.download()
20
  else:
21
  api = wandb.Api()
22
  artifact = api.artifact(data_artifact_name)
23
  artifact_dir = artifact.download()
24
+ self._docs_retrieval_model.index(
25
+ input_path=artifact_dir,
26
+ index_name=index_name,
27
+ store_collection_with_index=False,
28
+ overwrite=True,
29
+ )
30
+ if wandb.run:
31
+ artifact = wandb.Artifact(
32
+ name=index_name,
33
+ type="colpali-index",
34
+ metadata={"weave_dataset_name": weave_dataset_name},
35
+ )
36
+ artifact.add_dir(
37
+ local_path=os.path.join(".byaldi", index_name), name="index"
38
+ )
39
+ artifact.save()