geekyrakshit committed on
Commit e382637 · unverified · 2 parent(s): bf14736 2ab36c4

Merge pull request #13 from soumik12345/feat/retrieval

.github/workflows/tests.yml ADDED
@@ -0,0 +1,21 @@
+ name: Tests
+ on:
+   pull_request:
+     paths:
+       - .github/workflows/tests.yml
+       - medrag_multi_modal/**
+       - pyproject.toml
+
+ jobs:
+   code-format:
+     name: check code format using black
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v3
+       - uses: psf/black@stable
+   lint:
+     name: Check linting using ruff
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v4
+       - uses: chartboost/ruff-action@v1
.gitignore CHANGED
@@ -18,3 +18,6 @@ wandb/
  cursor_prompt.txt
  test.py
  uv.lock
+ grays-anatomy-bm25s/
+ prompt**.txt
+ **.safetensors
docs/retreival/bm25s.md ADDED
@@ -0,0 +1,3 @@
+ # BM25-Sparse Retrieval
+
+ ::: medrag_multi_modal.retrieval.bm25s_retrieval
docs/retreival/colpali.md ADDED
@@ -0,0 +1,3 @@
+ # ColPali Retrieval
+
+ ::: medrag_multi_modal.retrieval.colpali_retrieval
docs/retreival/contriever.md ADDED
@@ -0,0 +1,3 @@
+ # Contriever Retrieval
+
+ ::: medrag_multi_modal.retrieval.contriever_retrieval
docs/retreival/medcpt.md ADDED
@@ -0,0 +1,3 @@
+ # MedCPT Retrieval
+
+ ::: medrag_multi_modal.retrieval.medcpt_retrieval
docs/retreival/multi_modal_retrieval.md DELETED
@@ -1,3 +0,0 @@
1
- # Multi-Modal Retrieval
2
-
3
- ::: medrag_multi_modal.retrieval.multi_modal_retrieval
 
 
 
 
docs/retreival/nv_embed_2.md ADDED
@@ -0,0 +1,3 @@
+ # NV-Embed-v2 Retrieval
+
+ ::: medrag_multi_modal.retrieval.nv_embed_2
medrag_multi_modal/retrieval/__init__.py CHANGED
@@ -1,3 +1,15 @@
- from .multi_modal_retrieval import MultiModalRetriever
+ from .bm25s_retrieval import BM25sRetriever
+ from .colpali_retrieval import CalPaliRetriever
+ from .common import SimilarityMetric
+ from .contriever_retrieval import ContrieverRetriever
+ from .medcpt_retrieval import MedCPTRetriever
+ from .nv_embed_2 import NVEmbed2Retriever
 
- __all__ = ["MultiModalRetriever"]
+ __all__ = [
+     "CalPaliRetriever",
+     "BM25sRetriever",
+     "ContrieverRetriever",
+     "SimilarityMetric",
+     "MedCPTRetriever",
+     "NVEmbed2Retriever",
+ ]
medrag_multi_modal/retrieval/bm25s_retrieval.py ADDED
@@ -0,0 +1,213 @@
+ import os
+ from glob import glob
+ from typing import Optional
+
+ import bm25s
+ import weave
+ from Stemmer import Stemmer
+
+ import wandb
+
+ LANGUAGE_DICT = {
+     "english": "en",
+     "french": "fr",
+     "german": "de",
+ }
+
+
+ class BM25sRetriever(weave.Model):
+     """
+     `BM25sRetriever` is a class that provides functionality for indexing and
+     retrieving documents using the [BM25-Sparse](https://github.com/xhluca/bm25s) library.
+
+     Args:
+         language (str): The language of the documents to be indexed and retrieved.
+         use_stemmer (bool): A flag indicating whether to use stemming during tokenization.
+         retriever (Optional[bm25s.BM25]): An instance of the BM25 retriever. If not provided,
+             a new instance is created.
+     """
+
+     language: str
+     use_stemmer: bool
+     _retriever: Optional[bm25s.BM25]
+
+     def __init__(
+         self,
+         language: str = "english",
+         use_stemmer: bool = True,
+         retriever: Optional[bm25s.BM25] = None,
+     ):
+         super().__init__(language=language, use_stemmer=use_stemmer)
+         self._retriever = retriever or bm25s.BM25()
+
+     def index(self, chunk_dataset_name: str, index_name: Optional[str] = None):
+         """
+         Indexes a dataset of text chunks using the BM25 algorithm.
+
+         This function takes a dataset of text chunks identified by `chunk_dataset_name`,
+         tokenizes the text using the BM25 tokenizer with optional stemming, and indexes
+         the tokenized text using the BM25 retriever. If an `index_name` is provided, the
+         index is saved to disk and logged as a Weights & Biases artifact.
+
+         !!! example "Example Usage"
+             ```python
+             import weave
+             from dotenv import load_dotenv
+
+             import wandb
+             from medrag_multi_modal.retrieval import BM25sRetriever
+
+             load_dotenv()
+             weave.init(project_name="ml-colabs/medrag-multi-modal")
+             wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="bm25s-index")
+             retriever = BM25sRetriever()
+             retriever.index(chunk_dataset_name="grays-anatomy-text:v13", index_name="grays-anatomy-bm25s")
+             ```
+
+         Args:
+             chunk_dataset_name (str): The name of the dataset containing text chunks to be indexed.
+             index_name (Optional[str]): The name to save the index under. If provided, the index
+                 is saved to disk and logged as a Weights & Biases artifact.
+         """
+         chunk_dataset = weave.ref(chunk_dataset_name).get().rows
+         corpus = [row["text"] for row in chunk_dataset]
+         corpus_tokens = bm25s.tokenize(
+             corpus,
+             stopwords=LANGUAGE_DICT[self.language],
+             stemmer=Stemmer(self.language) if self.use_stemmer else None,
+         )
+         self._retriever.index(corpus_tokens)
+         if index_name:
+             self._retriever.save(
+                 index_name, corpus=[dict(row) for row in chunk_dataset]
+             )
+             if wandb.run:
+                 artifact = wandb.Artifact(
+                     name=index_name,
+                     type="bm25s-index",
+                     metadata={
+                         "language": self.language,
+                         "use_stemmer": self.use_stemmer,
+                     },
+                 )
+                 artifact.add_dir(index_name, name=index_name)
+                 artifact.save()
+
+     @classmethod
+     def from_wandb_artifact(cls, index_artifact_address: str):
+         """
+         Creates an instance of the class from a Weights & Biases artifact.
+
+         This class method retrieves a BM25 index artifact from Weights & Biases,
+         downloads the artifact, and loads the BM25 retriever with the index and its
+         associated corpus. The method also extracts metadata from the artifact to
+         initialize the class instance with the appropriate language and stemming
+         settings.
+
+         !!! example "Example Usage"
+             ```python
+             import weave
+             from dotenv import load_dotenv
+
+             from medrag_multi_modal.retrieval import BM25sRetriever
+
+             load_dotenv()
+             weave.init(project_name="ml-colabs/medrag-multi-modal")
+             retriever = BM25sRetriever.from_wandb_artifact(
+                 index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-bm25s:latest"
+             )
+             ```
+
+         Args:
+             index_artifact_address (str): The address of the Weights & Biases artifact
+                 containing the BM25 index.
+
+         Returns:
+             An instance of the class initialized with the BM25 retriever and metadata
+             from the artifact.
+         """
+         if wandb.run:
+             artifact = wandb.run.use_artifact(
+                 index_artifact_address, type="bm25s-index"
+             )
+             artifact_dir = artifact.download()
+         else:
+             api = wandb.Api()
+             artifact = api.artifact(index_artifact_address)
+             artifact_dir = artifact.download()
+         retriever = bm25s.BM25.load(
+             glob(os.path.join(artifact_dir, "*"))[0], load_corpus=True
+         )
+         metadata = artifact.metadata
+         return cls(
+             language=metadata["language"],
+             use_stemmer=metadata["use_stemmer"],
+             retriever=retriever,
+         )
+
+     @weave.op()
+     def retrieve(self, query: str, top_k: int = 2):
+         """
+         Retrieves the top-k most relevant chunks for a given query using the BM25 algorithm.
+
+         This method tokenizes the input query using the BM25 tokenizer, which takes into
+         account the language-specific stopwords and optional stemming. It then retrieves
+         the top-k most relevant chunks from the BM25 index based on the tokenized query.
+         The results are returned as a list of dictionaries, each containing a chunk and
+         its corresponding relevance score.
+
+         Args:
+             query (str): The input query string to search for relevant chunks.
+             top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.
+
+         Returns:
+             list: A list of dictionaries, each containing a retrieved chunk and its
+                 relevance score.
+         """
+         query_tokens = bm25s.tokenize(
+             query,
+             stopwords=LANGUAGE_DICT[self.language],
+             stemmer=Stemmer(self.language) if self.use_stemmer else None,
+         )
+         results = self._retriever.retrieve(query_tokens, k=top_k)
+         retrieved_chunks = []
+         for chunk, score in zip(
+             results.documents.flatten().tolist(),
+             results.scores.flatten().tolist(),
+         ):
+             retrieved_chunks.append({"chunk": chunk, "score": score})
+         return retrieved_chunks
+
+     @weave.op()
+     def predict(self, query: str, top_k: int = 2):
+         """
+         Predicts the top-k most relevant chunks for a given query using the BM25 algorithm.
+
+         This function is a wrapper around the `retrieve` method. It takes an input query string,
+         tokenizes it using the BM25 tokenizer, and retrieves the top-k most relevant chunks from
+         the BM25 index. The results are returned as a list of dictionaries, each containing a chunk
+         and its corresponding relevance score.
+
+         !!! example "Example Usage"
+             ```python
+             import weave
+             from dotenv import load_dotenv
+
+             from medrag_multi_modal.retrieval import BM25sRetriever
+
+             load_dotenv()
+             weave.init(project_name="ml-colabs/medrag-multi-modal")
+             retriever = BM25sRetriever.from_wandb_artifact(
+                 index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-bm25s:latest"
+             )
+             retrieved_chunks = retriever.predict(query="What are Ribosomes?")
+             ```
+
+         Args:
+             query (str): The input query string to search for relevant chunks.
+             top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.
+
+         Returns:
+             list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
+         """
+         return self.retrieve(query, top_k)
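A minimal sketch (not part of the commit) of the shape of the data that `retrieve`/`predict` return, assuming an index has already been loaded via `from_wandb_artifact`; the score values shown in the comment are illustrative only:

```python
# Hypothetical usage; assumes `retriever` was built with BM25sRetriever.from_wandb_artifact(...)
retrieved_chunks = retriever.retrieve(query="What are Ribosomes?", top_k=2)

# Each entry pairs a stored corpus row (the chunk dict) with its BM25 score,
# ordered from most to least relevant, e.g.:
# [
#     {"chunk": {"text": "Ribosomes are ..."}, "score": 12.7},
#     {"chunk": {"text": "..."}, "score": 9.3},
# ]
for result in retrieved_chunks:
    print(result["score"], result["chunk"])
```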
medrag_multi_modal/retrieval/{multi_modal_retrieval.py → colpali_retrieval.py} RENAMED
@@ -1,8 +1,11 @@
  import os
- from typing import Any, Optional
+ from typing import TYPE_CHECKING, Any, Optional
 
  import weave
- from byaldi import RAGMultiModalModel
+
+ if TYPE_CHECKING:
+     from byaldi import RAGMultiModalModel
+
  from PIL import Image
 
  import wandb
@@ -10,64 +13,33 @@ import wandb
  from ..utils import get_wandb_artifact
 
 
- class MultiModalRetriever(weave.Model):
+ class CalPaliRetriever(weave.Model):
      """
-     MultiModalRetriever is a class that facilitates the retrieval of page images using ColPali.
+     CalPaliRetriever is a class that facilitates the retrieval of page images using ColPali.
 
      This class leverages the `byaldi.RAGMultiModalModel` to perform document retrieval tasks.
      It can be initialized with a pre-trained model or from a specified W&B artifact. The class
      also provides methods to index new data and to predict/retrieve documents based on a query.
 
-     !!! example "Indexing Data"
-         ```python
-         import wandb
-         from medrag_multi_modal.retrieval import MultiModalRetriever
-
-         wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="index")
-         retriever = MultiModalRetriever()
-         retriever.index(
-             data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
-             weave_dataset_name="grays-anatomy-images:v0",
-             index_name="grays-anatomy",
-         )
-         ```
-
-     !!! example "Retrieving Documents"
-         ```python
-         import weave
-
-         import wandb
-         from medrag_multi_modal.retrieval import MultiModalRetriever
-
-         weave.init(project_name="ml-colabs/medrag-multi-modal")
-         retriever = MultiModalRetriever.from_artifact(
-             index_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy:v0",
-             metadata_dataset_name="grays-anatomy-images:v0",
-             data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
-         )
-         retriever.predict(
-             query="which neurotransmitters convey information between Merkel cells and sensory afferents?",
-             top_k=3,
-         )
-         ```
-
      Attributes:
          model_name (str): The name of the model to be used for retrieval.
      """
 
      model_name: str
-     _docs_retrieval_model: Optional[RAGMultiModalModel] = None
+     _docs_retrieval_model: Optional["RAGMultiModalModel"] = None
      _metadata: Optional[dict] = None
      _data_artifact_dir: Optional[str] = None
 
      def __init__(
          self,
          model_name: str = "vidore/colpali-v1.2",
-         docs_retrieval_model: Optional[RAGMultiModalModel] = None,
+         docs_retrieval_model: Optional["RAGMultiModalModel"] = None,
          data_artifact_dir: Optional[str] = None,
          metadata_dataset_name: Optional[str] = None,
      ):
          super().__init__(model_name=model_name)
+         from byaldi import RAGMultiModalModel
+
          self._docs_retrieval_model = (
              docs_retrieval_model or RAGMultiModalModel.from_pretrained(self.model_name)
          )
@@ -78,25 +50,54 @@ class MultiModalRetriever(weave.Model):
              else None
          )
 
-     @classmethod
-     def from_artifact(
-         cls,
-         index_artifact_name: str,
-         metadata_dataset_name: str,
-         data_artifact_name: str,
-     ):
-         index_artifact_dir = get_wandb_artifact(index_artifact_name, "colpali-index")
-         data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
-         docs_retrieval_model = RAGMultiModalModel.from_index(
-             index_path=os.path.join(index_artifact_dir, "index")
-         )
-         return cls(
-             docs_retrieval_model=docs_retrieval_model,
-             metadata_dataset_name=metadata_dataset_name,
-             data_artifact_dir=data_artifact_dir,
-         )
-
      def index(self, data_artifact_name: str, weave_dataset_name: str, index_name: str):
+         """
+         Indexes a dataset of documents and saves the index as a W&B artifact.
+
+         This method retrieves a dataset of documents from a W&B artifact using the provided
+         data artifact name. It then indexes the documents using the document retrieval model
+         and assigns the specified index name. The index is stored locally without storing the
+         collection with the index, and overwrites any existing index with the same name.
+
+         If a W&B run is active, the method creates a new W&B artifact with the specified
+         index name and type "colpali-index". It adds the local index directory to the artifact
+         and logs it to W&B, including metadata with the provided Weave dataset name.
+
+         !!! example "Indexing Data"
+             First, you need to install the `Byaldi` library by Answer.ai.
+
+             ```bash
+             uv pip install Byaldi>=0.0.5
+             ```
+
+             Next, you can index the data by running the following code:
+
+             ```python
+             import wandb
+             from medrag_multi_modal.retrieval import CalPaliRetriever
+
+             wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="index")
+             retriever = CalPaliRetriever()
+             retriever.index(
+                 data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
+                 weave_dataset_name="grays-anatomy-images:v0",
+                 index_name="grays-anatomy",
+             )
+             ```
+
+         ??? note "Optional Speedup using Flash Attention"
+             If you have a GPU with Flash Attention support, you can enable it for ColPali by simply
+             installing the `flash-attn` package.
+
+             ```bash
+             uv pip install flash-attn --no-build-isolation
+             ```
+
+         Args:
+             data_artifact_name (str): The name of the W&B artifact containing the dataset.
+             weave_dataset_name (str): The name of the Weave dataset to include in the artifact metadata.
+             index_name (str): The name to assign to the created index.
+         """
          data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
          self._docs_retrieval_model.index(
              input_path=data_artifact_dir,
@@ -115,6 +116,76 @@ class MultiModalRetriever(weave.Model):
              )
              artifact.save()
 
+     @classmethod
+     def from_wandb_artifact(
+         cls,
+         index_artifact_name: str,
+         metadata_dataset_name: str,
+         data_artifact_name: str,
+     ):
+         """
+         Creates an instance of the class from Weights & Biases (wandb) artifacts.
+
+         This method retrieves the necessary artifacts from wandb to initialize the
+         CalPaliRetriever. It fetches the index artifact directory and the data artifact
+         directory using the provided artifact names. It then loads the document retrieval
+         model from the index path within the index artifact directory. Finally, it returns
+         an instance of the class initialized with the retrieved document retrieval model,
+         metadata dataset name, and data artifact directory.
+
+         !!! example "Retrieving Documents"
+             First, you need to install the `Byaldi` library by Answer.ai.
+
+             ```bash
+             uv pip install Byaldi>=0.0.5
+             ```
+
+             Next, you can retrieve the documents by running the following code:
+
+             ```python
+             import weave
+
+             import wandb
+             from medrag_multi_modal.retrieval import CalPaliRetriever
+
+             weave.init(project_name="ml-colabs/medrag-multi-modal")
+             retriever = CalPaliRetriever.from_wandb_artifact(
+                 index_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy:v0",
+                 metadata_dataset_name="grays-anatomy-images:v0",
+                 data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
+             )
+             ```
+
+         ??? note "Optional Speedup using Flash Attention"
+             If you have a GPU with Flash Attention support, you can enable it for ColPali by simply
+             installing the `flash-attn` package.
+
+             ```bash
+             uv pip install flash-attn --no-build-isolation
+             ```
+
+         Args:
+             index_artifact_name (str): The name of the wandb artifact containing the index.
+             metadata_dataset_name (str): The name of the dataset containing metadata.
+             data_artifact_name (str): The name of the wandb artifact containing the data.
+
+         Returns:
+             An instance of the class initialized with the retrieved document retrieval model,
+             metadata dataset name, and data artifact directory.
+         """
+         from byaldi import RAGMultiModalModel
+
+         index_artifact_dir = get_wandb_artifact(index_artifact_name, "colpali-index")
+         data_artifact_dir = get_wandb_artifact(data_artifact_name, "dataset")
+         docs_retrieval_model = RAGMultiModalModel.from_index(
+             index_path=os.path.join(index_artifact_dir, "index")
+         )
+         return cls(
+             docs_retrieval_model=docs_retrieval_model,
+             metadata_dataset_name=metadata_dataset_name,
+             data_artifact_dir=data_artifact_dir,
+         )
+
      @weave.op()
      def predict(self, query: str, top_k: int = 3) -> list[dict[str, Any]]:
          """
@@ -125,6 +196,41 @@ class MultiModalRetriever(weave.Model):
          documents based on the provided query. It returns a list of dictionaries, each
          containing the document image, document ID, and the relevance score.
 
+         !!! example "Retrieving Documents"
+             First, you need to install the `Byaldi` library by Answer.ai.
+
+             ```bash
+             uv pip install Byaldi>=0.0.5
+             ```
+
+             Next, you can retrieve the documents by running the following code:
+
+             ```python
+             import weave
+
+             import wandb
+             from medrag_multi_modal.retrieval import CalPaliRetriever
+
+             weave.init(project_name="ml-colabs/medrag-multi-modal")
+             retriever = CalPaliRetriever.from_wandb_artifact(
+                 index_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy:v0",
+                 metadata_dataset_name="grays-anatomy-images:v0",
+                 data_artifact_name="ml-colabs/medrag-multi-modal/grays-anatomy-images:v1",
+             )
+             retriever.predict(
+                 query="which neurotransmitters convey information between Merkel cells and sensory afferents?",
+                 top_k=3,
+             )
+             ```
+
+         ??? note "Optional Speedup using Flash Attention"
+             If you have a GPU with Flash Attention support, you can enable it for ColPali by simply
+             installing the `flash-attn` package.
+
+             ```bash
+             uv pip install flash-attn --no-build-isolation
+             ```
+
          Args:
              query (str): The search query string.
              top_k (int, optional): The number of top results to retrieve. Defaults to 10.
medrag_multi_modal/retrieval/common.py ADDED
@@ -0,0 +1,45 @@
+ from enum import Enum
+
+ import safetensors
+ import safetensors.torch
+ import torch
+
+ import wandb
+
+
+ class SimilarityMetric(Enum):
+     COSINE = "cosine"
+     EUCLIDEAN = "euclidean"
+
+
+ def mean_pooling(token_embeddings, mask):
+     token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.0)
+     sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
+     return sentence_embeddings
+
+
+ def argsort_scores(scores: list[float], descending: bool = False):
+     return [
+         {"item": item, "original_index": idx}
+         for idx, item in sorted(
+             list(enumerate(scores)), key=lambda x: x[1], reverse=descending
+         )
+     ]
+
+
+ def save_vector_index(
+     vector_index: torch.Tensor,
+     type: str,
+     index_name: str,
+     metadata: dict,
+     filename: str = "vector_index.safetensors",
+ ):
+     safetensors.torch.save_file({"vector_index": vector_index.cpu()}, filename)
+     if wandb.run:
+         artifact = wandb.Artifact(
+             name=index_name,
+             type=type,
+             metadata=metadata,
+         )
+         artifact.add_file(filename)
+         artifact.save()
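Since `common.py` ships without docstrings, here is a short, self-contained sketch (not part of the commit) of how its two pure helpers behave; the tensor shapes and score values are made up for illustration:

```python
import torch

from medrag_multi_modal.retrieval.common import argsort_scores, mean_pooling

# argsort_scores remembers where each score came from, so callers can map
# ranked scores back to chunks; highest score comes first when descending=True.
print(argsort_scores([0.1, 0.9, 0.4], descending=True))
# [{'item': 0.9, 'original_index': 1},
#  {'item': 0.4, 'original_index': 2},
#  {'item': 0.1, 'original_index': 0}]

# mean_pooling averages token embeddings over the unmasked (non-padding)
# positions, producing one sentence embedding per input sequence.
token_embeddings = torch.randn(2, 5, 8)            # (batch, tokens, hidden)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
print(mean_pooling(token_embeddings, mask).shape)  # torch.Size([2, 8])
```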
medrag_multi_modal/retrieval/contriever_retrieval.py ADDED
@@ -0,0 +1,240 @@
+ import os
+ from typing import Optional
+
+ import safetensors
+ import safetensors.torch
+ import torch
+ import torch.nn.functional as F
+ import weave
+ from transformers import (
+     AutoModel,
+     AutoTokenizer,
+     BertPreTrainedModel,
+     PreTrainedTokenizerFast,
+ )
+
+ from ..utils import get_torch_backend, get_wandb_artifact
+ from .common import SimilarityMetric, argsort_scores, mean_pooling, save_vector_index
+
+
+ class ContrieverRetriever(weave.Model):
+     """
+     `ContrieverRetriever` is a class to perform retrieval tasks using the Contriever model.
+
+     It provides methods to encode text data into embeddings, index a dataset of text chunks,
+     and retrieve the most relevant chunks for a given query based on similarity metrics.
+
+     Args:
+         model_name (str): The name of the pre-trained model to use for encoding.
+         vector_index (Optional[torch.Tensor]): The tensor containing the vector representations
+             of the indexed chunks.
+         chunk_dataset (Optional[list[dict]]): The Weave dataset of text chunks to be indexed.
+     """
+
+     model_name: str
+     _chunk_dataset: Optional[list[dict]]
+     _tokenizer: PreTrainedTokenizerFast
+     _model: BertPreTrainedModel
+     _vector_index: Optional[torch.Tensor]
+
+     def __init__(
+         self,
+         model_name: str = "facebook/contriever",
+         vector_index: Optional[torch.Tensor] = None,
+         chunk_dataset: Optional[list[dict]] = None,
+     ):
+         super().__init__(model_name=model_name)
+         self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+         self._model = AutoModel.from_pretrained(self.model_name)
+         self._vector_index = vector_index
+         self._chunk_dataset = chunk_dataset
+
+     def encode(self, corpus: list[str]) -> torch.Tensor:
+         inputs = self._tokenizer(
+             corpus, padding=True, truncation=True, return_tensors="pt"
+         )
+         outputs = self._model(**inputs)
+         return mean_pooling(outputs[0], inputs["attention_mask"])
+
+     def index(self, chunk_dataset_name: str, index_name: Optional[str] = None):
+         """
+         Indexes a dataset of text chunks and optionally saves the vector index to a file.
+
+         This method retrieves a dataset of text chunks from a Weave reference, encodes the
+         text chunks into vector representations using the Contriever model, and stores the
+         resulting vector index. If an index name is provided, the vector index is saved to
+         a file in the safetensors format. Additionally, if a W&B run is active, the vector
+         index file is logged as an artifact to W&B.
+
+         !!! example "Example Usage"
+             ```python
+             import weave
+             from dotenv import load_dotenv
+
+             import wandb
+             from medrag_multi_modal.retrieval import ContrieverRetriever, SimilarityMetric
+
+             load_dotenv()
+             weave.init(project_name="ml-colabs/medrag-multi-modal")
+             wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="contriever-index")
+             retriever = ContrieverRetriever(model_name="facebook/contriever")
+             retriever.index(
+                 chunk_dataset_name="grays-anatomy-chunks:v0",
+                 index_name="grays-anatomy-contriever",
+             )
+             ```
+
+         Args:
+             chunk_dataset_name (str): The name of the Weave dataset containing the text chunks
+                 to be indexed.
+             index_name (Optional[str]): The name of the index artifact to be saved. If provided,
+                 the vector index is saved to a file and logged as an artifact to W&B.
+         """
+         self._chunk_dataset = weave.ref(chunk_dataset_name).get().rows
+         corpus = [row["text"] for row in self._chunk_dataset]
+         with torch.no_grad():
+             vector_index = self.encode(corpus)
+         self._vector_index = vector_index
+         if index_name:
+             save_vector_index(
+                 self._vector_index,
+                 "contriever-index",
+                 index_name,
+                 {"model_name": self.model_name},
+             )
+
+     @classmethod
+     def from_wandb_artifact(cls, chunk_dataset_name: str, index_artifact_address: str):
+         """
+         Creates an instance of the class from a Weights & Biases (W&B) artifact.
+
+         This method retrieves a vector index and metadata from an artifact stored in
+         Weights & Biases (wandb). It also retrieves a dataset of text chunks from a Weave
+         reference. The vector index is loaded from a safetensors file and moved to the
+         appropriate device (CPU or GPU). The text chunks are converted into a list of
+         dictionaries. The method then returns an instance of the class initialized with
+         the retrieved model name, vector index, and chunk dataset.
+
+         !!! example "Example Usage"
+             ```python
+             import weave
+             from dotenv import load_dotenv
+
+             from medrag_multi_modal.retrieval import ContrieverRetriever, SimilarityMetric
+
+             load_dotenv()
+             weave.init(project_name="ml-colabs/medrag-multi-modal")
+             retriever = ContrieverRetriever.from_wandb_artifact(
+                 chunk_dataset_name="grays-anatomy-chunks:v0",
+                 index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-contriever:v1",
+             )
+             ```
+
+         Args:
+             chunk_dataset_name (str): The name of the Weave dataset containing the text chunks.
+             index_artifact_address (str): The address of the W&B artifact containing the
+                 vector index.
+
+         Returns:
+             An instance of the class initialized with the retrieved model name, vector index,
+             and chunk dataset.
+         """
+         artifact_dir, metadata = get_wandb_artifact(
+             index_artifact_address, "contriever-index", get_metadata=True
+         )
+         with safetensors.torch.safe_open(
+             os.path.join(artifact_dir, "vector_index.safetensors"), framework="pt"
+         ) as f:
+             vector_index = f.get_tensor("vector_index")
+         device = torch.device(get_torch_backend())
+         vector_index = vector_index.to(device)
+         chunk_dataset = [dict(row) for row in weave.ref(chunk_dataset_name).get().rows]
+         return cls(
+             model_name=metadata["model_name"],
+             vector_index=vector_index,
+             chunk_dataset=chunk_dataset,
+         )
+
+     @weave.op()
+     def retrieve(
+         self,
+         query: str,
+         top_k: int = 2,
+         metric: SimilarityMetric = SimilarityMetric.COSINE,
+     ):
+         """
+         Retrieves the top-k most relevant chunks for a given query using the specified similarity metric.
+
+         This method encodes the input query into an embedding and computes similarity scores between
+         the query embedding and the precomputed vector index. The similarity metric can be either
+         cosine similarity or Euclidean distance. The top-k chunks with the highest similarity scores
+         are returned as a list of dictionaries, each containing a chunk and its corresponding score.
+
+         Args:
+             query (str): The input query string to search for relevant chunks.
+             top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.
+             metric (SimilarityMetric, optional): The similarity metric to use for scoring.
+
+         Returns:
+             list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
+         """
+         query = [query]
+         device = torch.device(get_torch_backend())
+         with torch.no_grad():
+             query_embedding = self.encode(query).to(device)
+             if metric == SimilarityMetric.EUCLIDEAN:
+                 scores = torch.squeeze(query_embedding @ self._vector_index.T)
+             else:
+                 scores = F.cosine_similarity(query_embedding, self._vector_index)
+             scores = scores.cpu().numpy().tolist()
+         scores = argsort_scores(scores, descending=True)[:top_k]
+         retrieved_chunks = []
+         for score in scores:
+             retrieved_chunks.append(
+                 {
+                     "chunk": self._chunk_dataset[score["original_index"]],
+                     "score": score["item"],
+                 }
+             )
+         return retrieved_chunks
+
+     @weave.op()
+     def predict(
+         self,
+         query: str,
+         top_k: int = 2,
+         metric: SimilarityMetric = SimilarityMetric.COSINE,
+     ):
+         """
+         Predicts the top-k most relevant chunks for a given query using the specified similarity metric.
+
+         This function is a wrapper around the `retrieve` method. It takes an input query string,
+         retrieves the top-k most relevant chunks from the precomputed vector index based on the
+         specified similarity metric, and returns the results as a list of dictionaries, each containing
+         a chunk and its corresponding relevance score.
+
+         !!! example "Example Usage"
+             ```python
+             import weave
+             from dotenv import load_dotenv
+
+             from medrag_multi_modal.retrieval import ContrieverRetriever, SimilarityMetric
+
+             load_dotenv()
+             weave.init(project_name="ml-colabs/medrag-multi-modal")
+             retriever = ContrieverRetriever.from_wandb_artifact(
+                 chunk_dataset_name="grays-anatomy-chunks:v0",
+                 index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-contriever:v1",
+             )
+             scores = retriever.predict(query="What are Ribosomes?", metric=SimilarityMetric.COSINE)
+             ```
+
+         Args:
+             query (str): The input query string to search for relevant chunks.
+             top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.
+             metric (SimilarityMetric, optional): The similarity metric to use for scoring. Defaults to cosine similarity.
+
+         Returns:
+             list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
+         """
+         return self.retrieve(query, top_k, metric)
medrag_multi_modal/retrieval/medcpt_retrieval.py ADDED
@@ -0,0 +1,279 @@
+ import os
+ from typing import Optional
+
+ import safetensors
+ import safetensors.torch
+ import torch
+ import torch.nn.functional as F
+ import weave
+ from transformers import (
+     AutoModel,
+     AutoTokenizer,
+     BertPreTrainedModel,
+     PreTrainedTokenizerFast,
+ )
+
+ from ..utils import get_torch_backend, get_wandb_artifact
+ from .common import SimilarityMetric, argsort_scores, save_vector_index
+
+
+ class MedCPTRetriever(weave.Model):
+     """
+     A class to retrieve relevant text chunks using MedCPT models.
+
+     This class provides methods to index a dataset of text chunks and retrieve the most relevant
+     chunks for a given query using MedCPT models. It uses separate models for encoding queries
+     and articles, and supports both cosine similarity and Euclidean distance as similarity metrics.
+
+     Args:
+         query_encoder_model_name (str): The name of the model used for encoding queries.
+         article_encoder_model_name (str): The name of the model used for encoding articles.
+         chunk_size (Optional[int]): The maximum length of text chunks.
+         vector_index (Optional[torch.Tensor]): The vector index of encoded text chunks.
+         chunk_dataset (Optional[list[dict]]): The dataset of text chunks.
+     """
+
+     query_encoder_model_name: str
+     article_encoder_model_name: str
+     chunk_size: Optional[int]
+     _chunk_dataset: Optional[list[dict]]
+     _query_tokenizer: PreTrainedTokenizerFast
+     _article_tokenizer: PreTrainedTokenizerFast
+     _query_encoder_model: BertPreTrainedModel
+     _article_encoder_model: BertPreTrainedModel
+     _vector_index: Optional[torch.Tensor]
+
+     def __init__(
+         self,
+         query_encoder_model_name: str,
+         article_encoder_model_name: str,
+         chunk_size: Optional[int] = None,
+         vector_index: Optional[torch.Tensor] = None,
+         chunk_dataset: Optional[list[dict]] = None,
+     ):
+         super().__init__(
+             query_encoder_model_name=query_encoder_model_name,
+             article_encoder_model_name=article_encoder_model_name,
+             chunk_size=chunk_size,
+         )
+         self._query_tokenizer = AutoTokenizer.from_pretrained(
+             self.query_encoder_model_name
+         )
+         self._article_tokenizer = AutoTokenizer.from_pretrained(
+             self.article_encoder_model_name
+         )
+         self._query_encoder_model = AutoModel.from_pretrained(
+             self.query_encoder_model_name
+         )
+         self._article_encoder_model = AutoModel.from_pretrained(
+             self.article_encoder_model_name
+         )
+         self._chunk_dataset = chunk_dataset
+         self._vector_index = vector_index
+
+     def index(self, chunk_dataset_name: str, index_name: Optional[str] = None):
+         """
+         Indexes a dataset of text chunks and optionally saves the vector index.
+
+         This method retrieves a dataset of text chunks from a Weave reference, encodes the text
+         chunks using the article encoder model, and stores the resulting vector index. If an
+         index name is provided, the vector index is saved to a file using the `save_vector_index`
+         function.
+
+         !!! example "Example Usage"
+             ```python
+             import weave
+             from dotenv import load_dotenv
+
+             import wandb
+             from medrag_multi_modal.retrieval import MedCPTRetriever
+
+             load_dotenv()
+             weave.init(project_name="ml-colabs/medrag-multi-modal")
+             wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="medcpt-index")
+             retriever = MedCPTRetriever(
+                 query_encoder_model_name="ncbi/MedCPT-Query-Encoder",
+                 article_encoder_model_name="ncbi/MedCPT-Article-Encoder",
+             )
+             retriever.index(
+                 chunk_dataset_name="grays-anatomy-chunks:v0",
+                 index_name="grays-anatomy-medcpt",
+             )
+             ```
+
+         Args:
+             chunk_dataset_name (str): The name of the dataset containing text chunks to be indexed.
+             index_name (Optional[str]): The name to use when saving the vector index. If not provided,
+                 the vector index is not saved.
+
+         """
+         self._chunk_dataset = weave.ref(chunk_dataset_name).get().rows
+         corpus = [row["text"] for row in self._chunk_dataset]
+         with torch.no_grad():
+             encoded = self._article_tokenizer(
+                 corpus,
+                 truncation=True,
+                 padding=True,
+                 return_tensors="pt",
+                 max_length=self.chunk_size,
+             )
+             vector_index = (
+                 self._article_encoder_model(**encoded)
+                 .last_hidden_state[:, 0, :]
+                 .contiguous()
+             )
+         self._vector_index = vector_index
+         if index_name:
+             save_vector_index(
+                 self._vector_index,
+                 "medcpt-index",
+                 index_name,
+                 {
+                     "query_encoder_model_name": self.query_encoder_model_name,
+                     "article_encoder_model_name": self.article_encoder_model_name,
+                     "chunk_size": self.chunk_size,
+                 },
+             )
+
+     @classmethod
+     def from_wandb_artifact(cls, chunk_dataset_name: str, index_artifact_address: str):
+         """
+         Initializes an instance of the class from a Weights & Biases (W&B) artifact.
+
+         This method retrieves a precomputed vector index and its associated metadata from an artifact
+         stored in Weights & Biases (wandb). It then loads the vector index into memory and initializes an
+         instance of the class with the retrieved model names, vector index, and chunk dataset.
+
+         !!! example "Example Usage"
+             ```python
+             import weave
+             from dotenv import load_dotenv
+
+             import wandb
+             from medrag_multi_modal.retrieval import MedCPTRetriever
+
+             load_dotenv()
+             weave.init(project_name="ml-colabs/medrag-multi-modal")
+             retriever = MedCPTRetriever.from_wandb_artifact(
+                 chunk_dataset_name="grays-anatomy-chunks:v0",
+                 index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
+             )
+             ```
+
+         Args:
+             chunk_dataset_name (str): The name of the dataset containing text chunks to be indexed.
+             index_artifact_address (str): The address of the W&B artifact containing the precomputed vector index.
+
+         Returns:
+             An instance of the class initialized with the retrieved model names, vector index, and chunk dataset.
+         """
+         artifact_dir, metadata = get_wandb_artifact(
+             index_artifact_address, "medcpt-index", get_metadata=True
+         )
+         with safetensors.torch.safe_open(
+             os.path.join(artifact_dir, "vector_index.safetensors"), framework="pt"
+         ) as f:
+             vector_index = f.get_tensor("vector_index")
+         device = torch.device(get_torch_backend())
+         vector_index = vector_index.to(device)
+         chunk_dataset = [dict(row) for row in weave.ref(chunk_dataset_name).get().rows]
+         return cls(
+             query_encoder_model_name=metadata["query_encoder_model_name"],
+             article_encoder_model_name=metadata["article_encoder_model_name"],
+             chunk_size=metadata["chunk_size"],
+             vector_index=vector_index,
+             chunk_dataset=chunk_dataset,
+         )
+
+     @weave.op()
+     def retrieve(
+         self,
+         query: str,
+         top_k: int = 2,
+         metric: SimilarityMetric = SimilarityMetric.COSINE,
+     ):
+         """
+         Retrieves the top-k most relevant chunks for a given query using the specified similarity metric.
+
+         This method encodes the input query into an embedding and computes similarity scores between
+         the query embedding and the precomputed vector index. The similarity metric can be either
+         cosine similarity or Euclidean distance. The top-k chunks with the highest similarity scores
+         are returned as a list of dictionaries, each containing a chunk and its corresponding score.
+
+         Args:
+             query (str): The input query string to search for relevant chunks.
+             top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.
+             metric (SimilarityMetric, optional): The similarity metric to use for scoring. Defaults to cosine similarity.
+
+         Returns:
+             list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
+         """
+         query = [query]
+         device = torch.device(get_torch_backend())
+         with torch.no_grad():
+             encoded = self._query_tokenizer(
+                 query,
+                 truncation=True,
+                 padding=True,
+                 return_tensors="pt",
+             )
+             query_embedding = self._query_encoder_model(**encoded).last_hidden_state[
+                 :, 0, :
+             ]
+             query_embedding = query_embedding.to(device)
+             if metric == SimilarityMetric.EUCLIDEAN:
+                 scores = torch.squeeze(query_embedding @ self._vector_index.T)
+             else:
+                 scores = F.cosine_similarity(query_embedding, self._vector_index)
+             scores = scores.cpu().numpy().tolist()
+         scores = argsort_scores(scores, descending=True)[:top_k]
+         retrieved_chunks = []
+         for score in scores:
+             retrieved_chunks.append(
+                 {
+                     "chunk": self._chunk_dataset[score["original_index"]],
+                     "score": score["item"],
+                 }
+             )
+         return retrieved_chunks
+
+     @weave.op()
+     def predict(
+         self,
+         query: str,
+         top_k: int = 2,
+         metric: SimilarityMetric = SimilarityMetric.COSINE,
+     ):
+         """
+         Predicts the most relevant chunks for a given query.
+
+         This function uses the `retrieve` method to find the top-k relevant chunks
+         from the dataset based on the input query. It allows specifying the number
+         of top relevant chunks to retrieve and the similarity metric to use for scoring.
+
+         !!! example "Example Usage"
+             ```python
+             import weave
+             from dotenv import load_dotenv
+
+             import wandb
+             from medrag_multi_modal.retrieval import MedCPTRetriever
+
+             load_dotenv()
+             weave.init(project_name="ml-colabs/medrag-multi-modal")
+             retriever = MedCPTRetriever.from_wandb_artifact(
+                 chunk_dataset_name="grays-anatomy-chunks:v0",
+                 index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
+             )
+             retriever.predict(query="What are Ribosomes?")
+             ```
+
+         Args:
+             query (str): The input query string to search for relevant chunks.
+             top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.
+             metric (SimilarityMetric, optional): The similarity metric to use for scoring. Defaults to cosine similarity.
+
+         Returns:
+             list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
+         """
+         return self.retrieve(query, top_k, metric)
medrag_multi_modal/retrieval/nv_embed_2.py ADDED
@@ -0,0 +1,282 @@
+ import os
+ from typing import Optional
+
+ import safetensors
+ import torch
+ import torch.nn.functional as F
+ import weave
+ from sentence_transformers import SentenceTransformer
+
+ from ..utils import get_torch_backend, get_wandb_artifact
+ from .common import SimilarityMetric, argsort_scores, save_vector_index
+
+
+ class NVEmbed2Retriever(weave.Model):
+     """
+     `NVEmbed2Retriever` is a class for retrieving relevant text chunks from a dataset using the
+     [NV-Embed-v2](https://huggingface.co/nvidia/NV-Embed-v2) model.
+
+     This class leverages the SentenceTransformer model to encode text chunks into vector representations and
+     performs similarity-based retrieval. It supports indexing a dataset of text chunks, saving the vector index,
+     and retrieving the most relevant chunks for a given query.
+
+     Args:
+         model_name (str): The name of the pre-trained model to use for encoding.
+         vector_index (Optional[torch.Tensor]): The tensor containing the vector representations of the indexed chunks.
+         chunk_dataset (Optional[list[dict]]): The dataset of text chunks to be indexed.
+     """
+
+     model_name: str
+     _chunk_dataset: Optional[list[dict]]
+     _model: SentenceTransformer
+     _vector_index: Optional[torch.Tensor]
+
+     def __init__(
+         self,
+         model_name: str = "sentence-transformers/nvembed2-nli-v1",
+         vector_index: Optional[torch.Tensor] = None,
+         chunk_dataset: Optional[list[dict]] = None,
+     ):
+         super().__init__(model_name=model_name)
+         self._model = SentenceTransformer(
+             self.model_name,
+             trust_remote_code=True,
+             model_kwargs={"torch_dtype": torch.float16},
+             device=get_torch_backend(),
+         )
+         self._model.max_seq_length = 32768
+         self._model.tokenizer.padding_side = "right"
+         self._vector_index = vector_index
+         self._chunk_dataset = chunk_dataset
+
+     def add_eos(self, input_examples):
+         input_examples = [
+             input_example + self._model.tokenizer.eos_token
+             for input_example in input_examples
+         ]
+         return input_examples
+
+     def index(self, chunk_dataset_name: str, index_name: Optional[str] = None):
+         """
+         Indexes a dataset of text chunks and optionally saves the vector index to a file.
+
+         This method retrieves a dataset of text chunks from a Weave reference, encodes the
+         text chunks into vector representations using the NV-Embed-v2 model, and stores the
+         resulting vector index. If an index name is provided, the vector index is saved to
+         a file in the safetensors format. Additionally, if a W&B run is active, the vector
+         index file is logged as an artifact to W&B.
+
+         !!! example "Example Usage"
+             ```python
+             import weave
+             from dotenv import load_dotenv
+
+             import wandb
+             from medrag_multi_modal.retrieval import NVEmbed2Retriever
+
+             load_dotenv()
+             weave.init(project_name="ml-colabs/medrag-multi-modal")
+             wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="nvembed2-index")
+             retriever = NVEmbed2Retriever(model_name="nvidia/NV-Embed-v2")
+             retriever.index(
+                 chunk_dataset_name="grays-anatomy-chunks:v0",
+                 index_name="grays-anatomy-nvembed2",
+             )
+             ```
+
+         ??? note "Optional Speedup using Flash Attention"
+             If you have a GPU with Flash Attention support, you can enable it for NV-Embed-v2 by simply
+             installing the `flash-attn` package.
+
+             ```bash
+             uv pip install flash-attn --no-build-isolation
+             ```
+
+         Args:
+             chunk_dataset_name (str): The name of the Weave dataset containing the text chunks
+                 to be indexed.
+             index_name (Optional[str]): The name of the index artifact to be saved. If provided,
+                 the vector index is saved to a file and logged as an artifact to W&B.
+         """
+         self._chunk_dataset = weave.ref(chunk_dataset_name).get().rows
+         corpus = [row["text"] for row in self._chunk_dataset]
+         self._vector_index = self._model.encode(
+             self.add_eos(corpus), batch_size=len(corpus), normalize_embeddings=True
+         )
+         with torch.no_grad():
+             if index_name:
+                 save_vector_index(
+                     torch.from_numpy(self._vector_index),
+                     "nvembed2-index",
+                     index_name,
+                     {"model_name": self.model_name},
+                 )
+
+     @classmethod
+     def from_wandb_artifact(cls, chunk_dataset_name: str, index_artifact_address: str):
+         """
+         Creates an instance of the class from a Weights & Biases (W&B) artifact.
+
+         This method retrieves a vector index and metadata from an artifact stored in
+         Weights & Biases (wandb). It also retrieves a dataset of text chunks from a Weave
+         reference. The vector index is loaded from a safetensors file and moved to the
+         appropriate device (CPU or GPU). The text chunks are converted into a list of
+         dictionaries. The method then returns an instance of the class initialized with
+         the retrieved model name, vector index, and chunk dataset.
+
+         !!! example "Example Usage"
+             ```python
+             import weave
+             from dotenv import load_dotenv
+
+             import wandb
+             from medrag_multi_modal.retrieval import NVEmbed2Retriever
+
+             load_dotenv()
+             weave.init(project_name="ml-colabs/medrag-multi-modal")
+             retriever = NVEmbed2Retriever(model_name="nvidia/NV-Embed-v2")
+             retriever.index(
+                 chunk_dataset_name="grays-anatomy-chunks:v0",
+                 index_name="grays-anatomy-nvembed2",
+             )
+             retriever = NVEmbed2Retriever.from_wandb_artifact(
+                 chunk_dataset_name="grays-anatomy-chunks:v0",
+                 index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-nvembed2:v0",
+             )
+             ```
+
+         ??? note "Optional Speedup using Flash Attention"
+             If you have a GPU with Flash Attention support, you can enable it for NV-Embed-v2 by simply
+             installing the `flash-attn` package.
+
+             ```bash
+             uv pip install flash-attn --no-build-isolation
+             ```
+
+         Args:
+             chunk_dataset_name (str): The name of the Weave dataset containing the text chunks.
+             index_artifact_address (str): The address of the W&B artifact containing the
+                 vector index.
+
+         Returns:
+             An instance of the class initialized with the retrieved model name, vector index,
+             and chunk dataset.
+         """
+         artifact_dir, metadata = get_wandb_artifact(
+             index_artifact_address, "nvembed2-index", get_metadata=True
+         )
+         with safetensors.torch.safe_open(
+             os.path.join(artifact_dir, "vector_index.safetensors"), framework="pt"
+         ) as f:
+             vector_index = f.get_tensor("vector_index")
+         device = torch.device(get_torch_backend())
+         vector_index = vector_index.to(device)
+         chunk_dataset = [dict(row) for row in weave.ref(chunk_dataset_name).get().rows]
+         return cls(
+             model_name=metadata["model_name"],
+             vector_index=vector_index,
+             chunk_dataset=chunk_dataset,
+         )
+
+     @weave.op()
+     def retrieve(
+         self,
+         query: list[str],
+         top_k: int = 2,
+         metric: SimilarityMetric = SimilarityMetric.COSINE,
+     ):
+         """
+         Retrieves the top-k most relevant chunks for a given query using the specified similarity metric.
+
+         This method encodes the input query into an embedding and computes similarity scores between
+         the query embedding and the precomputed vector index. The similarity metric can be either
+         cosine similarity or Euclidean distance. The top-k chunks with the highest similarity scores
+         are returned as a list of dictionaries, each containing a chunk and its corresponding score.
+
+         Args:
+             query (list[str]): The input query strings to search for relevant chunks.
+             top_k (int, optional): The number of top relevant chunks to retrieve.
+             metric (SimilarityMetric, optional): The similarity metric to use for scoring.
+
+         Returns:
+             list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
+         """
+         device = torch.device(get_torch_backend())
+         with torch.no_grad():
+             query_embedding = self._model.encode(
+                 self.add_eos(query), normalize_embeddings=True
+             )
+             query_embedding = torch.from_numpy(query_embedding).to(device)
+             if metric == SimilarityMetric.EUCLIDEAN:
+                 scores = torch.squeeze(query_embedding @ self._vector_index.T)
+             else:
+                 scores = F.cosine_similarity(query_embedding, self._vector_index)
+             scores = scores.cpu().numpy().tolist()
+         scores = argsort_scores(scores, descending=True)[:top_k]
+         retrieved_chunks = []
+         for score in scores:
+             retrieved_chunks.append(
+                 {
+                     "chunk": self._chunk_dataset[score["original_index"]],
+                     "score": score["item"],
+                 }
+             )
+         return retrieved_chunks
+
+     @weave.op()
+     def predict(
+         self,
+         query: str,
+         top_k: int = 2,
+         metric: SimilarityMetric = SimilarityMetric.COSINE,
+     ):
+         """
+         Predicts the top-k most relevant chunks for a given query using the specified similarity metric.
+
+         This method formats the input query string by prepending an instruction prompt and then calls the
+         `retrieve` method to get the most relevant chunks. The similarity metric can be either cosine similarity
+         or Euclidean distance. The top-k chunks with the highest similarity scores are returned.
+
+         !!! example "Example Usage"
+             ```python
+             import weave
+             from dotenv import load_dotenv
+
+             import wandb
+             from medrag_multi_modal.retrieval import NVEmbed2Retriever
+
+             load_dotenv()
+             weave.init(project_name="ml-colabs/medrag-multi-modal")
+             retriever = NVEmbed2Retriever(model_name="nvidia/NV-Embed-v2")
+             retriever.index(
+                 chunk_dataset_name="grays-anatomy-chunks:v0",
+                 index_name="grays-anatomy-nvembed2",
+             )
+             retriever = NVEmbed2Retriever.from_wandb_artifact(
+                 chunk_dataset_name="grays-anatomy-chunks:v0",
+                 index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-nvembed2:v0",
+             )
+             retriever.predict(query="What are Ribosomes?")
+             ```
+
+         ??? note "Optional Speedup using Flash Attention"
+             If you have a GPU with Flash Attention support, you can enable it for NV-Embed-v2 by simply
+             installing the `flash-attn` package.
+
+             ```bash
+             uv pip install flash-attn --no-build-isolation
+             ```
+
+         Args:
+             query (str): The input query string to search for relevant chunks.
+             top_k (int, optional): The number of top relevant chunks to retrieve.
+             metric (SimilarityMetric, optional): The similarity metric to use for scoring.
+
+         Returns:
+             list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
+         """
+         query = [
+             f"""Instruct: Given a question, retrieve passages that answer the question
+             Query: {query}"""
+         ]
+         return self.retrieve(query, top_k, metric)
medrag_multi_modal/utils.py CHANGED
@@ -1,7 +1,13 @@
+ import torch
+
  import wandb
 
 
- def get_wandb_artifact(artifact_name: str, artifact_type: str) -> str:
+ def get_wandb_artifact(
+     artifact_name: str,
+     artifact_type: str,
+     get_metadata: bool = False,
+ ) -> str:
      if wandb.run:
          artifact = wandb.use_artifact(artifact_name, type=artifact_type)
          artifact_dir = artifact.download()
@@ -9,4 +15,17 @@ def get_wandb_artifact(artifact_name: str, artifact_type: str) -> str:
          api = wandb.Api()
          artifact = api.artifact(artifact_name)
          artifact_dir = artifact.download()
+     if get_metadata:
+         return artifact_dir, artifact.metadata
      return artifact_dir
+
+
+ def get_torch_backend():
+     if torch.cuda.is_available():
+         if torch.backends.cuda.is_built():
+             return "cuda"
+     if torch.backends.mps.is_available():
+         if torch.backends.mps.is_built():
+             return "mps"
+         return "cpu"
+     return "cpu"
mkdocs.yml CHANGED
@@ -78,6 +78,10 @@ nav:
        - FitzPIL: 'document_loader/image_loader/fitzpil_img_loader.md'
      - Chunking: 'chunking.md'
      - Retrieval:
-       - Multi-Modal Retrieval: 'retreival/multi_modal_retrieval.md'
+       - BM25-Sparse: 'retreival/bm25s.md'
+       - ColPali: 'retreival/colpali.md'
+       - Contriever: 'retreival/contriever.md'
+       - MedCPT: 'retreival/medcpt.md'
+       - NV-Embed-v2: 'retreival/nv_embed_2.md'
 
  repo_url: https://github.com/soumik12345/medrag-multi-modal
pyproject.toml CHANGED
@@ -5,8 +5,12 @@ description = ""
5
  readme = "README.md"
6
  requires-python = ">=3.10"
7
  dependencies = [
8
- "Byaldi>=0.0.5",
 
 
 
9
  "firerequests>=0.0.7",
 
10
  "pdf2image>=1.17.0",
11
  "python-dotenv>=1.0.1",
12
  "pymupdf4llm>=0.0.17",
@@ -16,6 +20,8 @@ dependencies = [
16
  "uv>=0.4.20",
17
  "pytest>=8.3.3",
18
  "PyPDF2>=3.0.1",
 
 
19
  "isort>=5.13.2",
20
  "black>=24.10.0",
21
  "ruff>=0.6.9",
@@ -31,30 +37,34 @@ dependencies = [
31
  "pdfplumber>=0.11.4",
32
  "semchunk>=2.2.0",
33
  "tiktoken>=0.8.0",
 
34
  ]
35
 
36
  [project.optional-dependencies]
37
  core = [
38
- "Byaldi>=0.0.5",
 
 
 
39
  "firerequests>=0.0.7",
 
40
  "marker-pdf>=0.2.17",
41
  "pdf2image>=1.17.0",
42
  "pdfplumber>=0.11.4",
43
  "PyPDF2>=3.0.1",
 
44
  "python-dotenv>=1.0.1",
45
  "pymupdf4llm>=0.0.17",
 
46
  "semchunk>=2.2.0",
47
  "tiktoken>=0.8.0",
48
  "torch>=2.4.1",
49
  "weave>=0.51.14",
 
50
  ]
51
 
52
- dev = [
53
- "pytest>=8.3.3",
54
- "isort>=5.13.2",
55
- "black>=24.10.0",
56
- "ruff>=0.6.9",
57
- ]
58
 
59
  docs = [
60
  "mkdocs>=1.6.1",
 
5
  readme = "README.md"
6
  requires-python = ">=3.10"
7
  dependencies = [
8
+ "adapters>=1.0.0",
9
+ "bm25s[full]>=0.2.2",
10
+ "datasets>=3.0.1",
11
+ "einops>=0.8.0",
12
  "firerequests>=0.0.7",
13
+ "jax[cpu]>=0.4.34",
14
  "pdf2image>=1.17.0",
15
  "python-dotenv>=1.0.1",
16
  "pymupdf4llm>=0.0.17",
 
20
  "uv>=0.4.20",
21
  "pytest>=8.3.3",
22
  "PyPDF2>=3.0.1",
23
+ "PyStemmer>=2.2.0.3",
24
+ "safetensors>=0.4.5",
25
  "isort>=5.13.2",
26
  "black>=24.10.0",
27
  "ruff>=0.6.9",
 
37
  "pdfplumber>=0.11.4",
38
  "semchunk>=2.2.0",
39
  "tiktoken>=0.8.0",
40
+ "sentence-transformers>=3.2.0",
41
  ]
42
 
43
  [project.optional-dependencies]
44
  core = [
45
+ "adapters>=1.0.0",
46
+ "bm25s[full]>=0.2.2",
47
+ "datasets>=3.0.1",
48
+ "einops>=0.8.0",
49
  "firerequests>=0.0.7",
50
+ "jax[cpu]>=0.4.34",
51
  "marker-pdf>=0.2.17",
52
  "pdf2image>=1.17.0",
53
  "pdfplumber>=0.11.4",
54
  "PyPDF2>=3.0.1",
55
+ "PyStemmer>=2.2.0.3",
56
  "python-dotenv>=1.0.1",
57
  "pymupdf4llm>=0.0.17",
58
+ "safetensors>=0.4.5",
59
  "semchunk>=2.2.0",
60
  "tiktoken>=0.8.0",
61
  "torch>=2.4.1",
62
  "weave>=0.51.14",
63
+ "sentence-transformers>=3.2.0",
64
  ]
65
 
66
+ dev = ["pytest>=8.3.3", "isort>=5.13.2", "black>=24.10.0", "ruff>=0.6.9"]
67
+
 
 
 
 
68
 
69
  docs = [
70
  "mkdocs>=1.6.1",