geekyrakshit committed on
Commit
0d77bb1
·
1 Parent(s): 77a97ce

add: MedCPTRetriever

Browse files
docs/retreival/medcpt.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # MedCPT Retrieval
2
+
3
+ ::: medrag_multi_modal.retrieval.medcpt_retrieval
medrag_multi_modal/retrieval/__init__.py CHANGED
@@ -1,10 +1,13 @@
1
  from .bm25s_retrieval import BM25sRetriever
2
  from .colpali_retrieval import CalPaliRetriever
3
- from .contriever_retrieval import ContrieverRetriever, SimilarityMetric
 
 
4
 
5
  __all__ = [
6
  "CalPaliRetriever",
7
  "BM25sRetriever",
8
  "ContrieverRetriever",
9
  "SimilarityMetric",
 
10
  ]
 
1
  from .bm25s_retrieval import BM25sRetriever
2
  from .colpali_retrieval import CalPaliRetriever
3
+ from .common import SimilarityMetric
4
+ from .contriever_retrieval import ContrieverRetriever
5
+ from .medcpt_retrieval import MedCPTRetriever
6
 
7
  __all__ = [
8
  "CalPaliRetriever",
9
  "BM25sRetriever",
10
  "ContrieverRetriever",
11
  "SimilarityMetric",
12
+ "MedCPTRetriever",
13
  ]
medrag_multi_modal/retrieval/common.py CHANGED
@@ -29,6 +29,7 @@ def argsort_scores(scores: list[float], descending: bool = False):
29
 
30
  def save_vector_index(
31
  vector_index: torch.Tensor,
 
32
  index_name: str,
33
  metadata: dict,
34
  filename: str = "vector_index.safetensors",
@@ -37,7 +38,7 @@ def save_vector_index(
37
  if wandb.run:
38
  artifact = wandb.Artifact(
39
  name=index_name,
40
- type="contriever-index",
41
  metadata=metadata,
42
  )
43
  artifact.add_file(filename)
 
29
 
30
  def save_vector_index(
31
  vector_index: torch.Tensor,
32
+ type: str,
33
  index_name: str,
34
  metadata: dict,
35
  filename: str = "vector_index.safetensors",
 
38
  if wandb.run:
39
  artifact = wandb.Artifact(
40
  name=index_name,
41
+ type=type,
42
  metadata=metadata,
43
  )
44
  artifact.add_file(filename)
medrag_multi_modal/retrieval/contriever_retrieval.py CHANGED
@@ -13,10 +13,8 @@ from transformers import (
13
  PreTrainedTokenizerFast,
14
  )
15
 
16
- import wandb
17
-
18
- from ..utils import get_wandb_artifact, get_torch_backend
19
- from .common import SimilarityMetric, argsort_scores, mean_pooling
20
 
21
 
22
  class ContrieverRetriever(weave.Model):
@@ -80,7 +78,10 @@ class ContrieverRetriever(weave.Model):
80
  weave.init(project_name="ml-colabs/medrag-multi-modal")
81
  wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="contriever-index")
82
  retriever = ContrieverRetriever(model_name="facebook/contriever")
83
- retriever.index(chunk_dataset_name="grays-anatomy-chunks:v0", index_name="grays-anatomy-contriever")
 
 
 
84
  ```
85
 
86
  Args:
@@ -95,17 +96,12 @@ class ContrieverRetriever(weave.Model):
95
  vector_index = self.encode(corpus)
96
  self._vector_index = vector_index
97
  if index_name:
98
- safetensors.torch.save_file(
99
- {"vector_index": vector_index.cpu()}, "vector_index.safetensors"
 
 
 
100
  )
101
- if wandb.run:
102
- artifact = wandb.Artifact(
103
- name=index_name,
104
- type="contriever-index",
105
- metadata={"model_name": self.model_name},
106
- )
107
- artifact.add_file("vector_index.safetensors")
108
- artifact.save()
109
 
110
  @classmethod
111
  def from_wandb_artifact(cls, chunk_dataset_name: str, index_artifact_address: str):
 
13
  PreTrainedTokenizerFast,
14
  )
15
 
16
+ from ..utils import get_torch_backend, get_wandb_artifact
17
+ from .common import SimilarityMetric, argsort_scores, mean_pooling, save_vector_index
 
 
18
 
19
 
20
  class ContrieverRetriever(weave.Model):
 
78
  weave.init(project_name="ml-colabs/medrag-multi-modal")
79
  wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="contriever-index")
80
  retriever = ContrieverRetriever(model_name="facebook/contriever")
81
+ retriever.index(
82
+ chunk_dataset_name="grays-anatomy-chunks:v0",
83
+ index_name="grays-anatomy-contriever",
84
+ )
85
  ```
86
 
87
  Args:
 
96
  vector_index = self.encode(corpus)
97
  self._vector_index = vector_index
98
  if index_name:
99
+ save_vector_index(
100
+ self._vector_index,
101
+ "contriever-index",
102
+ index_name,
103
+ {"model_name": self.model_name},
104
  )
 
 
 
 
 
 
 
 
105
 
106
  @classmethod
107
  def from_wandb_artifact(cls, chunk_dataset_name: str, index_artifact_address: str):
medrag_multi_modal/retrieval/medcpt_retrieval.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from typing import Optional

import safetensors
import safetensors.torch
import torch
import torch.nn.functional as F
import weave
from transformers import (
    AutoModel,
    AutoTokenizer,
    BertPreTrainedModel,
    PreTrainedTokenizerFast,
)

from ..utils import get_torch_backend, get_wandb_artifact
from .common import SimilarityMetric, argsort_scores, save_vector_index


class MedCPTRetriever(weave.Model):
    """
    A class to retrieve relevant text chunks using MedCPT models.

    This class provides methods to index a dataset of text chunks and retrieve the most relevant
    chunks for a given query using MedCPT models. It uses separate models for encoding queries
    and articles (MedCPT ships a distinct query encoder and article encoder), and supports both
    cosine similarity and Euclidean distance as similarity metrics.

    Args:
        query_encoder_model_name (str): The name of the model used for encoding queries.
        article_encoder_model_name (str): The name of the model used for encoding articles.
        chunk_size (Optional[int]): The maximum length of text chunks.
        vector_index (Optional[torch.Tensor]): The vector index of encoded text chunks.
        chunk_dataset (Optional[list[dict]]): The dataset of text chunks.
    """

    # Public (pydantic-validated) fields of the weave.Model; the underscore-prefixed
    # names below are private attributes excluded from weave serialization.
    query_encoder_model_name: str
    article_encoder_model_name: str
    chunk_size: Optional[int]
    _chunk_dataset: Optional[list[dict]]
    _query_tokenizer: PreTrainedTokenizerFast
    _article_tokenizer: PreTrainedTokenizerFast
    _query_encoder_model: BertPreTrainedModel
    _article_encoder_model: BertPreTrainedModel
    _vector_index: Optional[torch.Tensor]

    def __init__(
        self,
        query_encoder_model_name: str,
        article_encoder_model_name: str,
        chunk_size: Optional[int] = None,
        vector_index: Optional[torch.Tensor] = None,
        chunk_dataset: Optional[list[dict]] = None,
    ):
        # Only the public fields go through the weave.Model constructor; tokenizers,
        # models, the vector index, and the chunk dataset are stored as private attrs.
        super().__init__(
            query_encoder_model_name=query_encoder_model_name,
            article_encoder_model_name=article_encoder_model_name,
            chunk_size=chunk_size,
        )
        # Queries and articles use separate MedCPT encoders, so load both
        # tokenizer/model pairs. NOTE(review): models are left on the default
        # (CPU) device here — confirm whether they should be moved to the
        # torch backend device like the loaded vector index is.
        self._query_tokenizer = AutoTokenizer.from_pretrained(
            self.query_encoder_model_name
        )
        self._article_tokenizer = AutoTokenizer.from_pretrained(
            self.article_encoder_model_name
        )
        self._query_encoder_model = AutoModel.from_pretrained(
            self.query_encoder_model_name
        )
        self._article_encoder_model = AutoModel.from_pretrained(
            self.article_encoder_model_name
        )
        self._chunk_dataset = chunk_dataset
        self._vector_index = vector_index

    def index(self, chunk_dataset_name: str, index_name: Optional[str] = None):
        """
        Indexes a dataset of text chunks and optionally saves the vector index.

        This method retrieves a dataset of text chunks from a Weave reference, encodes the text
        chunks using the article encoder model, and stores the resulting vector index. If an
        index name is provided, the vector index is saved as a W&B artifact of type
        "medcpt-index" using the `save_vector_index` function.

        !!! example "Example Usage"
            ```python
            import weave
            from dotenv import load_dotenv

            import wandb
            from medrag_multi_modal.retrieval import MedCPTRetriever

            load_dotenv()
            weave.init(project_name="ml-colabs/medrag-multi-modal")
            wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="medcpt-index")
            retriever = MedCPTRetriever(
                query_encoder_model_name="ncbi/MedCPT-Query-Encoder",
                article_encoder_model_name="ncbi/MedCPT-Article-Encoder",
            )
            retriever.index(
                chunk_dataset_name="grays-anatomy-chunks:v0",
                index_name="grays-anatomy-medcpt",
            )
            ```

        Args:
            chunk_dataset_name (str): The name of the dataset containing text chunks to be indexed.
            index_name (Optional[str]): The name to use when saving the vector index. If not provided,
                the vector index is not saved.

        """
        self._chunk_dataset = weave.ref(chunk_dataset_name).get().rows
        corpus = [row["text"] for row in self._chunk_dataset]
        with torch.no_grad():
            # chunk_size=None makes the tokenizer fall back to its default
            # max length — TODO confirm this is the intended behavior.
            encoded = self._article_tokenizer(
                corpus,
                truncation=True,
                padding=True,
                return_tensors="pt",
                max_length=self.chunk_size,
            )
            # The CLS-token embedding ([:, 0, :]) of each chunk is used as its
            # article vector; .contiguous() so the slice can be safely saved.
            vector_index = (
                self._article_encoder_model(**encoded)
                .last_hidden_state[:, 0, :]
                .contiguous()
            )
        self._vector_index = vector_index
        if index_name:
            save_vector_index(
                self._vector_index,
                "medcpt-index",
                index_name,
                {
                    "query_encoder_model_name": self.query_encoder_model_name,
                    "article_encoder_model_name": self.article_encoder_model_name,
                    "chunk_size": self.chunk_size,
                },
            )

    @classmethod
    def from_wandb_artifact(cls, chunk_dataset_name: str, index_artifact_address: str):
        """
        Initializes an instance of the class from a Weave artifact.

        This method retrieves a precomputed vector index and its associated metadata from a Weave artifact
        stored in Weights & Biases (wandb). It then loads the vector index into memory (on the device
        selected by `get_torch_backend`) and initializes an instance of the class with the retrieved
        model names, vector index, and chunk dataset.

        !!! example "Example Usage"
            ```python
            import weave
            from dotenv import load_dotenv

            import wandb
            from medrag_multi_modal.retrieval import MedCPTRetriever

            load_dotenv()
            weave.init(project_name="ml-colabs/medrag-multi-modal")
            retriever = MedCPTRetriever.from_wandb_artifact(
                chunk_dataset_name="grays-anatomy-chunks:v0",
                index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
            )
            ```

        Args:
            chunk_dataset_name (str): The name of the dataset containing text chunks to be indexed.
            index_artifact_address (str): The address of the Weave artifact containing the precomputed vector index.

        Returns:
            An instance of the class initialized with the retrieved model name, vector index, and chunk dataset.
        """
        artifact_dir, metadata = get_wandb_artifact(
            index_artifact_address, "medcpt-index", get_metadata=True
        )
        # Lazily open the safetensors file and pull out the single stored tensor.
        with safetensors.torch.safe_open(
            os.path.join(artifact_dir, "vector_index.safetensors"), framework="pt"
        ) as f:
            vector_index = f.get_tensor("vector_index")
        device = torch.device(get_torch_backend())
        vector_index = vector_index.to(device)
        chunk_dataset = [dict(row) for row in weave.ref(chunk_dataset_name).get().rows]
        return cls(
            query_encoder_model_name=metadata["query_encoder_model_name"],
            article_encoder_model_name=metadata["article_encoder_model_name"],
            chunk_size=metadata["chunk_size"],
            vector_index=vector_index,
            chunk_dataset=chunk_dataset,
        )

    @weave.op()
    def retrieve(
        self,
        query: str,
        top_k: int = 2,
        metric: SimilarityMetric = SimilarityMetric.COSINE,
    ):
        """
        Retrieves the top-k most relevant chunks for a given query using the specified similarity metric.

        This method encodes the input query into an embedding and computes similarity scores between
        the query embedding and the precomputed vector index. The top-k chunks with the highest
        similarity scores are returned as a list of dictionaries, each containing a chunk and its
        corresponding score.

        !!! example "Example Usage"
            ```python
            import weave
            from dotenv import load_dotenv

            import wandb
            from medrag_multi_modal.retrieval import MedCPTRetriever

            load_dotenv()
            weave.init(project_name="ml-colabs/medrag-multi-modal")
            retriever = MedCPTRetriever.from_wandb_artifact(
                chunk_dataset_name="grays-anatomy-chunks:v0",
                index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
            )
            retriever.retrieve(query="What are Ribosomes?")
            ```

        Args:
            query (str): The input query string to search for relevant chunks.
            top_k (int, optional): The number of top relevant chunks to retrieve. Defaults to 2.
            metric (SimilarityMetric, optional): The similarity metric to use for scoring. Defaults to cosine similarity.

        Returns:
            list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
        """
        # The tokenizer expects a batch, so wrap the single query in a list.
        query = [query]
        device = torch.device(get_torch_backend())
        with torch.no_grad():
            encoded = self._query_tokenizer(
                query,
                truncation=True,
                padding=True,
                return_tensors="pt",
            )
            # CLS-token embedding of the query, mirroring the article encoding in index().
            query_embedding = self._query_encoder_model(**encoded).last_hidden_state[
                :, 0, :
            ]
            # Move the query embedding to the same device as the vector index.
            query_embedding = query_embedding.to(device)
            if metric == SimilarityMetric.EUCLIDEAN:
                # NOTE(review): despite the enum name, this branch computes an
                # inner-product score (query @ index.T), not a Euclidean distance,
                # and scores are later sorted descending (higher = better) — for a
                # true Euclidean distance lower would be better. Presumably this is
                # MedCPT's native dot-product scoring; confirm intent.
                scores = torch.squeeze(query_embedding @ self._vector_index.T)
            else:
                scores = F.cosine_similarity(query_embedding, self._vector_index)
            scores = scores.cpu().numpy().tolist()
        # argsort_scores returns [{"item": score, "original_index": i}, ...];
        # keep only the top_k highest-scoring entries.
        scores = argsort_scores(scores, descending=True)[:top_k]
        retrieved_chunks = []
        for score in scores:
            retrieved_chunks.append(
                {
                    "chunk": self._chunk_dataset[score["original_index"]],
                    "score": score["item"],
                }
            )
        return retrieved_chunks
medrag_multi_modal/utils.py CHANGED
@@ -1,9 +1,12 @@
1
  import torch
 
2
  import wandb
3
 
4
 
5
  def get_wandb_artifact(
6
- artifact_name: str, artifact_type: str, get_metadata: bool = False
 
 
7
  ) -> str:
8
  if wandb.run:
9
  artifact = wandb.use_artifact(artifact_name, type=artifact_type)
 
1
  import torch
2
+
3
  import wandb
4
 
5
 
6
  def get_wandb_artifact(
7
+ artifact_name: str,
8
+ artifact_type: str,
9
+ get_metadata: bool = False,
10
  ) -> str:
11
  if wandb.run:
12
  artifact = wandb.use_artifact(artifact_name, type=artifact_type)
mkdocs.yml CHANGED
@@ -81,5 +81,6 @@ nav:
81
  - BM25-Sparse: 'retreival/bm25s.md'
82
  - ColPali: 'retreival/colpali.md'
83
  - Contriever: 'retreival/contriever.md'
 
84
 
85
  repo_url: https://github.com/soumik12345/medrag-multi-modal
 
81
  - BM25-Sparse: 'retreival/bm25s.md'
82
  - ColPali: 'retreival/colpali.md'
83
  - Contriever: 'retreival/contriever.md'
84
+ - MedCPT: 'retreival/medcpt.md'
85
 
86
  repo_url: https://github.com/soumik12345/medrag-multi-modal