Commit a19b9c3 (unverified) · Atanu Sarkar
Parents: 8bd2693 6c6905f

Merge pull request #21 from soumik12345/feat/llm-client
.gitignore CHANGED
@@ -17,6 +17,7 @@ wandb/
 .byaldi/
 cursor_prompt.txt
 test.py
+test.ipynb
 uv.lock
 grays-anatomy-bm25s/
 prompt**.txt
docs/assistant/figure_annotation.md ADDED
@@ -0,0 +1,3 @@
+# Figure Annotation
+
+::: medrag_multi_modal.assistant.figure_annotation
docs/assistant/llm_client.md ADDED
@@ -0,0 +1,3 @@
+# LLM Client
+
+::: medrag_multi_modal.assistant.llm_client
docs/assistant/medqa_assistant.md ADDED
@@ -0,0 +1,3 @@
+# MedQA Assistant
+
+::: medrag_multi_modal.assistant.medqa_assistant
medrag_multi_modal/assistant/__init__.py ADDED
@@ -0,0 +1,5 @@
+from .figure_annotation import FigureAnnotatorFromPageImage
+from .llm_client import ClientType, LLMClient
+from .medqa_assistant import MedQAAssistant
+
+__all__ = ["LLMClient", "ClientType", "MedQAAssistant", "FigureAnnotatorFromPageImage"]
medrag_multi_modal/assistant/figure_annotation.py ADDED
@@ -0,0 +1,156 @@
+import os
+from glob import glob
+from typing import Optional, Union
+
+import cv2
+import weave
+from PIL import Image
+from pydantic import BaseModel
+
+from ..utils import get_wandb_artifact, read_jsonl_file
+from .llm_client import LLMClient
+
+
+class FigureAnnotation(BaseModel):
+    figure_id: str
+    figure_description: str
+
+
+class FigureAnnotations(BaseModel):
+    annotations: list[FigureAnnotation]
+
+
+class FigureAnnotatorFromPageImage(weave.Model):
+    """
+    `FigureAnnotatorFromPageImage` is a class that leverages two LLM clients to annotate
+    figures from a page image of a scientific textbook.
+
+    !!! example "Example Usage"
+        ```python
+        import weave
+        from dotenv import load_dotenv
+
+        from medrag_multi_modal.assistant import (
+            FigureAnnotatorFromPageImage, LLMClient
+        )
+
+        load_dotenv()
+        weave.init(project_name="ml-colabs/medrag-multi-modal")
+        figure_annotator = FigureAnnotatorFromPageImage(
+            figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
+            structured_output_llm_client=LLMClient(model_name="gpt-4o"),
+            image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
+        )
+        annotations = figure_annotator.predict(page_idx=34)
+        ```
+
+    Args:
+        figure_extraction_llm_client (LLMClient): An LLM client used to extract figure annotations
+            from the page image.
+        structured_output_llm_client (LLMClient): An LLM client used to convert the extracted
+            annotations into a structured format.
+        image_artifact_address (Optional[str]): The address of the image artifact containing the
+            page images.
+    """
+
+    figure_extraction_llm_client: LLMClient
+    structured_output_llm_client: LLMClient
+    _artifact_dir: str
+
+    def __init__(
+        self,
+        figure_extraction_llm_client: LLMClient,
+        structured_output_llm_client: LLMClient,
+        image_artifact_address: Optional[str] = None,
+    ):
+        super().__init__(
+            figure_extraction_llm_client=figure_extraction_llm_client,
+            structured_output_llm_client=structured_output_llm_client,
+        )
+        self._artifact_dir = get_wandb_artifact(image_artifact_address, "dataset")
+
+    @weave.op()
+    def annotate_figures(
+        self, page_image: Image.Image
+    ) -> dict[str, Union[Image.Image, str]]:
+        annotation = self.figure_extraction_llm_client.predict(
+            system_prompt="""
+            You are an expert in the domain of scientific textbooks, especially medical texts.
+            You are presented with a page from a scientific textbook from the domain of biology, specifically anatomy.
+            You are to first identify all the figures in the page image, which could be images or biological diagrams, charts, graphs, etc.
+            Then you are to identify the figure IDs associated with each figure in the page image.
+            Then, you are to extract only the exact figure descriptions from the page image.
+            You need to output the figure IDs and figure descriptions only, in a structured manner as a JSON object.
+
+            Here are some clues you need to follow:
+            1. Figure IDs are unique identifiers for each figure in the page image.
+            2. Sometimes figure IDs can also be found as captions to the immediate left, right, top, or bottom of the figure.
+            3. Figure IDs are in the form "Fig X.Y" where X and Y are integers. For example, 1.1, 1.2, 1.3, etc.
+            4. Figure descriptions are contained as captions under the figures in the image, just after the figure ID.
+            5. The text in the page image is written in English and is present in a two-column format.
+            6. There is a clear distinction between the figure caption and the regular text in the page image in the form of extra white space.
+               You are to carefully identify all the figures in the page image.
+            7. There might be multiple figures or even no figures present in the page image. Sometimes the figures can be present side-by-side
+               or one above the other.
+            8. The figures may or may not have a distinct border against a white background.
+            9. You are not supposed to alter the figure description in any way present in the page image and you are to extract it as is.
+            """,
+            user_prompt=[page_image],
+        )
+        return {"page_image": page_image, "annotations": annotation}
+
+    @weave.op
+    def extract_structured_output(self, annotations: str) -> FigureAnnotations:
+        return self.structured_output_llm_client.predict(
+            system_prompt="You are supposed to extract a list of figure annotations consisting of figure IDs and corresponding figure descriptions.",
+            user_prompt=[annotations],
+            schema=FigureAnnotations,
+        )
+
+    @weave.op()
+    def predict(self, page_idx: int) -> dict[int, list[FigureAnnotation]]:
+        """
+        Predicts figure annotations for a specific page in a document.
+
+        This function retrieves the artifact directory from the given image artifact address,
+        reads the metadata from the 'metadata.jsonl' file, and iterates through the metadata
+        to find the specified page index. If the page index matches, it reads the page image
+        and associated figure images, and then uses the `annotate_figures` method to extract
+        figure annotations from the page image. The extracted annotations are then structured
+        using the `extract_structured_output` method and returned as a dictionary.
+
+        Args:
+            page_idx (int): The index of the page to annotate.
+
+        Returns:
+            dict: A dictionary containing the page index as the key and the extracted figure
+                annotations as the value.
+        """
+
+        metadata = read_jsonl_file(os.path.join(self._artifact_dir, "metadata.jsonl"))
+        annotations = {}
+        for item in metadata:
+            if item["page_idx"] == page_idx:
+                page_image_file = os.path.join(
+                    self._artifact_dir, f"page{item['page_idx']}.png"
+                )
+                figure_image_files = glob(
+                    os.path.join(self._artifact_dir, f"page{item['page_idx']}_fig*.png")
+                )
+                if len(figure_image_files) > 0:
+                    page_image = cv2.imread(page_image_file)
+                    page_image = cv2.cvtColor(page_image, cv2.COLOR_BGR2RGB)
+                    page_image = Image.fromarray(page_image)
+                    figure_extracted_annotations = self.annotate_figures(
+                        page_image=page_image
+                    )
+                    figure_extracted_annotations = self.extract_structured_output(
+                        figure_extracted_annotations["annotations"]
+                    ).model_dump()
+                    annotations[item["page_idx"]] = figure_extracted_annotations[
+                        "annotations"
+                    ]
+                break
+        return annotations
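
A note for reviewers on the return shape of `FigureAnnotatorFromPageImage.predict`: the structured annotations are dumped per page, so callers index the result by the page they asked for. A minimal sketch, assuming the artifact addresses from the docstring above and that the requested page actually contains extracted figure images:

```python
import weave
from dotenv import load_dotenv

from medrag_multi_modal.assistant import FigureAnnotatorFromPageImage, LLMClient

load_dotenv()
weave.init(project_name="ml-colabs/medrag-multi-modal")

figure_annotator = FigureAnnotatorFromPageImage(
    figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
    structured_output_llm_client=LLMClient(model_name="gpt-4o"),
    image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
)

# `predict` returns {page_idx: [{"figure_id": ..., "figure_description": ...}, ...]};
# a page with no extracted figure images yields an empty dict.
annotations = figure_annotator.predict(page_idx=34)
for figure in annotations.get(34, []):
    print(figure["figure_id"], "->", figure["figure_description"][:80])
```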
medrag_multi_modal/assistant/llm_client.py ADDED
@@ -0,0 +1,237 @@
+import os
+from enum import Enum
+from typing import Any, Optional, Union
+
+import instructor
+import weave
+from PIL import Image
+
+from ..utils import base64_encode_image
+
+
+class ClientType(str, Enum):
+    GEMINI = "gemini"
+    MISTRAL = "mistral"
+    OPENAI = "openai"
+
+
+GOOGLE_MODELS = [
+    "gemini-1.0-pro-latest",
+    "gemini-1.0-pro",
+    "gemini-pro",
+    "gemini-1.0-pro-001",
+    "gemini-1.0-pro-vision-latest",
+    "gemini-pro-vision",
+    "gemini-1.5-pro-latest",
+    "gemini-1.5-pro-001",
+    "gemini-1.5-pro-002",
+    "gemini-1.5-pro",
+    "gemini-1.5-pro-exp-0801",
+    "gemini-1.5-pro-exp-0827",
+    "gemini-1.5-flash-latest",
+    "gemini-1.5-flash-001",
+    "gemini-1.5-flash-001-tuning",
+    "gemini-1.5-flash",
+    "gemini-1.5-flash-exp-0827",
+    "gemini-1.5-flash-002",
+    "gemini-1.5-flash-8b",
+    "gemini-1.5-flash-8b-001",
+    "gemini-1.5-flash-8b-latest",
+    "gemini-1.5-flash-8b-exp-0827",
+    "gemini-1.5-flash-8b-exp-0924",
+]
+
+MISTRAL_MODELS = [
+    "ministral-3b-latest",
+    "ministral-8b-latest",
+    "mistral-large-latest",
+    "mistral-small-latest",
+    "codestral-latest",
+    "pixtral-12b-2409",
+    "open-mistral-nemo",
+    "open-codestral-mamba",
+    "open-mistral-7b",
+    "open-mixtral-8x7b",
+    "open-mixtral-8x22b",
+]
+
+OPENAI_MODELS = ["gpt-4o", "gpt-4o-2024-08-06", "gpt-4o-mini", "gpt-4o-mini-2024-07-18"]
+
+
+class LLMClient(weave.Model):
+    """
+    LLMClient is a class that interfaces with different large language model (LLM) providers
+    such as Google Gemini, Mistral, and OpenAI. It abstracts the complexity of interacting with
+    these different APIs and provides a unified interface for making predictions.
+
+    Args:
+        model_name (str): The name of the model to be used for predictions.
+        client_type (Optional[ClientType]): The type of client (e.g., GEMINI, MISTRAL, OPENAI).
+            If not provided, it is inferred from the model_name.
+    """
+
+    model_name: str
+    client_type: Optional[ClientType]
+
+    def __init__(self, model_name: str, client_type: Optional[ClientType] = None):
+        if client_type is None:
+            if model_name in GOOGLE_MODELS:
+                client_type = ClientType.GEMINI
+            elif model_name in MISTRAL_MODELS:
+                client_type = ClientType.MISTRAL
+            elif model_name in OPENAI_MODELS:
+                client_type = ClientType.OPENAI
+            else:
+                raise ValueError(f"Invalid model name: {model_name}")
+        super().__init__(model_name=model_name, client_type=client_type)
+
+    @weave.op()
+    def execute_gemini_sdk(
+        self,
+        user_prompt: Union[str, list[str]],
+        system_prompt: Optional[Union[str, list[str]]] = None,
+        schema: Optional[Any] = None,
+    ) -> Union[str, Any]:
+        import google.generativeai as genai
+
+        system_prompt = (
+            [system_prompt] if isinstance(system_prompt, str) else system_prompt
+        )
+        user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
+
+        genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
+        model = genai.GenerativeModel(self.model_name)
+        generation_config = (
+            None
+            if schema is None
+            else genai.GenerationConfig(
+                response_mime_type="application/json", response_schema=list[schema]
+            )
+        )
+        response = model.generate_content(
+            system_prompt + user_prompt, generation_config=generation_config
+        )
+        return response.text if schema is None else response
+
+    @weave.op()
+    def execute_mistral_sdk(
+        self,
+        user_prompt: Union[str, list[str]],
+        system_prompt: Optional[Union[str, list[str]]] = None,
+        schema: Optional[Any] = None,
+    ) -> Union[str, Any]:
+        from mistralai import Mistral
+
+        system_prompt = (
+            [system_prompt] if isinstance(system_prompt, str) else system_prompt
+        )
+        user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
+        system_messages = [{"type": "text", "text": prompt} for prompt in system_prompt]
+        user_messages = []
+        for prompt in user_prompt:
+            if isinstance(prompt, Image.Image):
+                user_messages.append(
+                    {
+                        "type": "image_url",
+                        "image_url": base64_encode_image(prompt, "image/png"),
+                    }
+                )
+            else:
+                user_messages.append({"type": "text", "text": prompt})
+        messages = [
+            {"role": "system", "content": system_messages},
+            {"role": "user", "content": user_messages},
+        ]
+
+        client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"))
+        client = instructor.from_mistral(client) if schema is not None else client
+
+        response = (
+            client.chat.complete(model=self.model_name, messages=messages)
+            if schema is None
+            else client.messages.create(
+                response_model=schema, messages=messages, temperature=0
+            )
+        )
+        return response.choices[0].message.content
+
+    @weave.op()
+    def execute_openai_sdk(
+        self,
+        user_prompt: Union[str, list[str]],
+        system_prompt: Optional[Union[str, list[str]]] = None,
+        schema: Optional[Any] = None,
+    ) -> Union[str, Any]:
+        from openai import OpenAI
+
+        system_prompt = (
+            [system_prompt] if isinstance(system_prompt, str) else system_prompt
+        )
+        user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
+
+        system_messages = [
+            {"role": "system", "content": prompt} for prompt in system_prompt
+        ]
+        user_messages = []
+        for prompt in user_prompt:
+            if isinstance(prompt, Image.Image):
+                user_messages.append(
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": base64_encode_image(prompt, "image/png"),
+                        },
+                    },
+                )
+            else:
+                user_messages.append({"type": "text", "text": prompt})
+        messages = system_messages + [{"role": "user", "content": user_messages}]
+
+        client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
+
+        if schema is None:
+            completion = client.chat.completions.create(
+                model=self.model_name, messages=messages
+            )
+            return completion.choices[0].message.content
+
+        completion = weave.op()(client.beta.chat.completions.parse)(
+            model=self.model_name, messages=messages, response_format=schema
+        )
+        return completion.choices[0].message.parsed
+
+    @weave.op()
+    def predict(
+        self,
+        user_prompt: Union[str, list[str]],
+        system_prompt: Optional[Union[str, list[str]]] = None,
+        schema: Optional[Any] = None,
+    ) -> Union[str, Any]:
+        """
+        Predicts the response from a language model based on the provided prompts and schema.
+
+        This function determines the client type and calls the appropriate SDK execution function
+        to get the response from the language model. It supports multiple client types including
+        GEMINI, MISTRAL, and OPENAI. Depending on the client type, it calls the corresponding
+        execution function with the provided user and system prompts, and an optional schema.
+
+        Args:
+            user_prompt (Union[str, list[str]]): The user prompt(s) to be sent to the language model.
+            system_prompt (Optional[Union[str, list[str]]]): The system prompt(s) to be sent to the language model.
+            schema (Optional[Any]): The schema to be used for parsing the response, if applicable.
+
+        Returns:
+            Union[str, Any]: The response from the language model, which could be a string or any other type
+                depending on the schema provided.
+
+        Raises:
+            ValueError: If the client type is invalid.
+        """
+        if self.client_type == ClientType.GEMINI:
+            return self.execute_gemini_sdk(user_prompt, system_prompt, schema)
+        elif self.client_type == ClientType.MISTRAL:
+            return self.execute_mistral_sdk(user_prompt, system_prompt, schema)
+        elif self.client_type == ClientType.OPENAI:
+            return self.execute_openai_sdk(user_prompt, system_prompt, schema)
+        else:
+            raise ValueError(f"Invalid client type: {self.client_type}")
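
A short usage sketch of the unified `predict` interface may help when reviewing the provider-specific paths above. The `PageSummary` schema here is hypothetical and only illustrates structured output; on the OpenAI path this goes through `client.beta.chat.completions.parse` and returns the parsed Pydantic object:

```python
import weave
from dotenv import load_dotenv
from pydantic import BaseModel

from medrag_multi_modal.assistant import LLMClient


class PageSummary(BaseModel):  # hypothetical schema, for illustration only
    title: str
    summary: str


load_dotenv()
weave.init(project_name="ml-colabs/medrag-multi-modal")

client = LLMClient(model_name="gpt-4o")  # client_type inferred from OPENAI_MODELS

# Plain text generation: returns completion.choices[0].message.content (a str).
text = client.predict(
    system_prompt="You are a concise medical editor.",
    user_prompt=["Summarize the role of ribosomes in one sentence."],
)

# Structured generation: passing `schema` yields the parsed PageSummary instance.
structured = client.predict(
    system_prompt="Return a title and a one-sentence summary.",
    user_prompt=["Ribosomes synthesize proteins from mRNA templates."],
    schema=PageSummary,
)
```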
medrag_multi_modal/assistant/medqa_assistant.py ADDED
@@ -0,0 +1,108 @@
+import weave
+
+from ..retrieval import SimilarityMetric
+from .figure_annotation import FigureAnnotatorFromPageImage
+from .llm_client import LLMClient
+
+
+class MedQAAssistant(weave.Model):
+    """
+    `MedQAAssistant` is a class designed to assist with medical queries by leveraging a
+    language model client, a retriever model, and a figure annotator.
+
+    !!! example "Usage Example"
+        ```python
+        import weave
+        from dotenv import load_dotenv
+
+        from medrag_multi_modal.assistant import (
+            FigureAnnotatorFromPageImage,
+            LLMClient,
+            MedQAAssistant,
+        )
+        from medrag_multi_modal.retrieval import MedCPTRetriever
+
+        load_dotenv()
+        weave.init(project_name="ml-colabs/medrag-multi-modal")
+
+        llm_client = LLMClient(model_name="gemini-1.5-flash")
+
+        retriever = MedCPTRetriever.from_wandb_artifact(
+            chunk_dataset_name="grays-anatomy-chunks:v0",
+            index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
+        )
+
+        figure_annotator = FigureAnnotatorFromPageImage(
+            figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
+            structured_output_llm_client=LLMClient(model_name="gpt-4o"),
+            image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
+        )
+        medqa_assistant = MedQAAssistant(
+            llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator
+        )
+        medqa_assistant.predict(query="What is ribosome?")
+        ```
+
+    Args:
+        llm_client (LLMClient): The language model client used to generate responses.
+        retriever (weave.Model): The model used to retrieve relevant chunks of text from a medical document.
+        figure_annotator (FigureAnnotatorFromPageImage): The annotator used to extract figure descriptions from pages.
+        top_k_chunks (int): The number of top chunks to retrieve based on the similarity metric.
+        retrieval_similarity_metric (SimilarityMetric): The metric used to measure similarity for retrieval.
+    """
+
+    llm_client: LLMClient
+    retriever: weave.Model
+    figure_annotator: FigureAnnotatorFromPageImage
+    top_k_chunks: int = 2
+    retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE
+
+    @weave.op()
+    def predict(self, query: str) -> str:
+        """
+        Generates a response to a medical query by retrieving relevant text chunks and figure descriptions
+        from a medical document and using a language model to generate the final response.
+
+        This function performs the following steps:
+        1. Retrieves relevant text chunks from the medical document based on the query using the retriever model.
+        2. Extracts the text and page indices from the retrieved chunks.
+        3. Retrieves figure descriptions from the pages identified in the previous step using the figure annotator.
+        4. Constructs a system prompt and a user prompt combining the query, retrieved text chunks, and figure descriptions.
+        5. Uses the language model client to generate a response based on the constructed prompts.
+        6. Appends the source information (page numbers) to the generated response.
+
+        Args:
+            query (str): The medical query to be answered.
+
+        Returns:
+            str: The generated response to the query, including source information.
+        """
+        retrieved_chunks = self.retriever.predict(
+            query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric
+        )
+
+        retrieved_chunk_texts = []
+        page_indices = set()
+        for chunk in retrieved_chunks:
+            retrieved_chunk_texts.append(chunk["text"])
+            page_indices.add(int(chunk["page_idx"]))
+
+        figure_descriptions = []
+        for page_idx in page_indices:
+            figure_annotations = self.figure_annotator.predict(page_idx=page_idx)[
+                page_idx
+            ]
+            figure_descriptions += [
+                item["figure_description"] for item in figure_annotations
+            ]
+
+        system_prompt = """
+        You are an expert in medical science. You are given a query and a list of chunks from a medical document.
+        """
+        response = self.llm_client.predict(
+            system_prompt=system_prompt,
+            user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
+        )
+        page_numbers = ", ".join([str(int(page_idx) + 1) for page_idx in page_indices])
+        response += f"\n\n**Source:** {'Pages' if len(page_indices) > 1 else 'Page'} {page_numbers} from Gray's Anatomy"
+        return response
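
Worth noting while reviewing: `predict` assumes each retrieved chunk exposes `"text"` and `"page_idx"` at the top level, which is exactly what the flattened retriever outputs further down in this PR provide. A tiny sketch with made-up chunk values showing that contract and how the source footer is built:

```python
# Hypothetical retriever output; real values come from retriever.predict(query, ...).
retrieved_chunks = [
    {"text": "The ribosome is ...", "page_idx": 33, "score": 0.87},
    {"text": "Protein synthesis ...", "page_idx": 34, "score": 0.79},
]

page_indices = {int(chunk["page_idx"]) for chunk in retrieved_chunks}
# 0-based page indices are shifted to 1-based page numbers in the footer
# (sorted here only for readability; the assistant iterates a set).
page_numbers = ", ".join(str(idx + 1) for idx in sorted(page_indices))
print(f"**Source:** {'Pages' if len(page_indices) > 1 else 'Page'} {page_numbers} from Gray's Anatomy")
```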
medrag_multi_modal/document_loader/image_loader/base_img_loader.py CHANGED
@@ -3,6 +3,7 @@ import os
 from abc import abstractmethod
 from typing import Dict, List, Optional
 
+import jsonlines
 import rich
 
 import wandb
@@ -41,7 +42,8 @@ class BaseImageLoader(BaseTextLoader):
         end_page: Optional[int] = None,
         wandb_artifact_name: Optional[str] = None,
         image_save_dir: str = "./images",
-        cleanup: bool = True,
+        exclude_file_extensions: list[str] = [],
+        cleanup: bool = False,
         **kwargs,
     ) -> List[Dict[str, str]]:
         """
@@ -61,10 +63,11 @@ class BaseImageLoader(BaseTextLoader):
        If a wandb_artifact_name is provided, the processed pages are published to a WandB artifact.
 
        Args:
-            start_page (Optional[int]): The starting page index (0-based) to process. Defaults to the first page.
-            end_page (Optional[int]): The ending page index (0-based) to process. Defaults to the last page.
+            start_page (Optional[int]): The starting page index (0-based) to process.
+            end_page (Optional[int]): The ending page index (0-based) to process.
             wandb_artifact_name (Optional[str]): The name of the WandB artifact to publish the pages to, if provided.
             image_save_dir (str): The directory to save the extracted images.
+            exclude_file_extensions (list[str]): A list of file extensions to exclude from the image_save_dir.
             cleanup (bool): Whether to remove extracted images from `image_save_dir`, if uploading to wandb artifact.
             **kwargs: Additional keyword arguments that will be passed to extract_page_data method and the underlying library.
 
@@ -99,8 +102,21 @@ class BaseImageLoader(BaseTextLoader):
         for task in asyncio.as_completed(tasks):
             await task
 
+        with jsonlines.open(
+            os.path.join(image_save_dir, "metadata.jsonl"), mode="w"
+        ) as writer:
+            writer.write(pages)
+
+        for file in os.listdir(image_save_dir):
+            if file.endswith(tuple(exclude_file_extensions)):
+                os.remove(os.path.join(image_save_dir, file))
+
         if wandb_artifact_name:
-            artifact = wandb.Artifact(name=wandb_artifact_name, type="dataset")
+            artifact = wandb.Artifact(
+                name=wandb_artifact_name,
+                type="dataset",
+                metadata={"loader_name": self.__class__.__name__},
+            )
             artifact.add_dir(local_path=image_save_dir)
             artifact.save()
             rich.print("Artifact saved and uploaded to wandb!")
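
Two small behavioral points in the new `load_data` parameters, sketched below: `cleanup` now defaults to `False`, and the extension-based pruning is a no-op at the default empty `exclude_file_extensions`, because `str.endswith` with an empty tuple never matches:

```python
# Pruning check used above: file.endswith(tuple(exclude_file_extensions))
print("page0.png".endswith(tuple([])))      # False -> nothing is removed by default
print("page0.md".endswith(tuple([".md"])))  # True  -> removed before the artifact upload
```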
medrag_multi_modal/document_loader/image_loader/marker_img_loader.py CHANGED
@@ -1,11 +1,14 @@
 import os
-from typing import Any, Dict
+from typing import Any, Coroutine, Dict, List
 
 from marker.convert import convert_single_pdf
 from marker.models import load_all_models
+from pdf2image.pdf2image import convert_from_path
 
 from .base_img_loader import BaseImageLoader
 
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
 
 class MarkerImageLoader(BaseImageLoader):
     """
@@ -46,10 +49,18 @@ class MarkerImageLoader(BaseImageLoader):
         url (str): The URL of the PDF document.
         document_name (str): The name of the document.
         document_file_path (str): The path to the PDF file.
+        save_page_image (bool): Whether to additionally save the image of the entire page.
     """
 
-    def __init__(self, url: str, document_name: str, document_file_path: str):
+    def __init__(
+        self,
+        url: str,
+        document_name: str,
+        document_file_path: str,
+        save_page_image: bool = False,
+    ):
         super().__init__(url, document_name, document_file_path)
+        self.save_page_image = save_page_image
         self.model_lst = load_all_models()
 
     async def extract_page_data(
@@ -90,11 +101,42 @@ class MarkerImageLoader(BaseImageLoader):
             image.save(image_file_path, "png")
             image_file_paths.append(image_file_path)
 
+        if self.save_page_image:
+            page_image = convert_from_path(
+                self.document_file_path,
+                first_page=page_idx + 1,
+                last_page=page_idx + 1,
+                **kwargs,
+            )[0]
+            page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
+
         return {
             "page_idx": page_idx,
             "document_name": self.document_name,
            "file_path": self.document_file_path,
             "file_url": self.url,
-            "image_file_paths": image_file_paths,
+            "image_file_paths": os.path.join(image_save_dir, "*.png"),
             "meta": out_meta,
         }
+
+    def load_data(
+        self,
+        start_page: int | None = None,
+        end_page: int | None = None,
+        wandb_artifact_name: str | None = None,
+        image_save_dir: str = "./images",
+        exclude_file_extensions: list[str] = [],
+        cleanup: bool = False,
+        **kwargs,
+    ) -> Coroutine[Any, Any, List[Dict[str, str]]]:
+        start_page = start_page - 1 if start_page is not None else None
+        end_page = end_page - 1 if end_page is not None else None
+        return super().load_data(
+            start_page,
+            end_page,
+            wandb_artifact_name,
+            image_save_dir,
+            exclude_file_extensions,
+            cleanup,
+            **kwargs,
+        )
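
A hedged end-to-end sketch of the loader with the new options; the URL, file paths, and the `.md` exclusion are placeholder values. The overridden `load_data` treats `start_page`/`end_page` as 1-based and shifts them to the 0-based indices that `BaseImageLoader` expects:

```python
import asyncio

from medrag_multi_modal.document_loader.image_loader.marker_img_loader import (
    MarkerImageLoader,
)

# Illustrative document; substitute a real PDF URL and local path.
loader = MarkerImageLoader(
    url="https://example.com/grays-anatomy.pdf",
    document_name="Gray's Anatomy",
    document_file_path="grays-anatomy.pdf",
    save_page_image=True,  # also dump page{idx}.png via pdf2image
)

# Pages 33-38 (1-based) become page_idx 32-37 internally; metadata.jsonl is
# written into image_save_dir and any excluded extensions are pruned before upload.
pages = asyncio.run(
    loader.load_data(
        start_page=33,
        end_page=38,
        wandb_artifact_name="grays-anatomy-images-marker",
        image_save_dir="./images",
        exclude_file_extensions=[".md"],
    )
)
```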
medrag_multi_modal/document_loader/text_loader/base_text_loader.py CHANGED
@@ -131,6 +131,7 @@ class BaseTextLoader(ABC):
         async def process_page(page_idx):
             nonlocal processed_pages_counter
             page_data = await self.extract_page_data(page_idx, **kwargs)
+            page_data["loader_name"] = self.__class__.__name__
             pages.append(page_data)
             rich.print(
                 f"Processed page idx: {page_idx}, progress: {processed_pages_counter}/{total_pages}"
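
With this change every page record (and therefore every entry written to `metadata.jsonl` by `BaseImageLoader.load_data`) carries the loader class name. Illustrative values only:

```python
# Example of a page record after the change (field values are made up):
page_data = {
    "page_idx": 34,
    "document_name": "Gray's Anatomy",
    "file_path": "grays-anatomy.pdf",
    "file_url": "https://example.com/grays-anatomy.pdf",
    "loader_name": "MarkerTextLoader",  # added by process_page
}
```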
medrag_multi_modal/document_loader/text_loader/marker_text_loader.py CHANGED
@@ -1,3 +1,4 @@
+import os
 from typing import Dict
 
 from marker.convert import convert_single_pdf
@@ -5,6 +6,8 @@ from marker.models import load_all_models
 
 from .base_text_loader import BaseTextLoader
 
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
 
 class MarkerTextLoader(BaseTextLoader):
     """
medrag_multi_modal/retrieval/bm25s_retrieval.py CHANGED
@@ -175,7 +175,7 @@ class BM25sRetriever(weave.Model):
             results.documents.flatten().tolist(),
             results.scores.flatten().tolist(),
         ):
-            retrieved_chunks.append({"chunk": chunk, "score": score})
+            retrieved_chunks.append({**chunk, **{"score": score}})
         return retrieved_chunks
 
     @weave.op()
medrag_multi_modal/retrieval/contriever_retrieval.py CHANGED
@@ -192,8 +192,8 @@ class ContrieverRetriever(weave.Model):
         for score in scores:
             retrieved_chunks.append(
                 {
-                    "chunk": self._chunk_dataset[score["original_index"]],
-                    "score": score["item"],
+                    **self._chunk_dataset[score["original_index"]],
+                    **{"score": score["item"]},
                 }
             )
         return retrieved_chunks
medrag_multi_modal/retrieval/medcpt_retrieval.py CHANGED
@@ -231,8 +231,8 @@ class MedCPTRetriever(weave.Model):
         for score in scores:
             retrieved_chunks.append(
                 {
-                    "chunk": self._chunk_dataset[score["original_index"]],
-                    "score": score["item"],
+                    **self._chunk_dataset[score["original_index"]],
+                    **{"score": score["item"]},
                 }
             )
         return retrieved_chunks
medrag_multi_modal/retrieval/nv_embed_2.py CHANGED
@@ -217,8 +217,8 @@ class NVEmbed2Retriever(weave.Model):
         for score in scores:
             retrieved_chunks.append(
                 {
-                    "chunk": self._chunk_dataset[score["original_index"]],
-                    "score": score["item"],
+                    **self._chunk_dataset[score["original_index"]],
+                    **{"score": score["item"]},
                 }
            )
         return retrieved_chunks
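
The same change is applied to all four retrievers (BM25s, Contriever, MedCPT, NV-Embed-v2): instead of nesting the chunk under a `"chunk"` key, its fields are merged into the result alongside `"score"`. A small sketch with made-up values shows why `MedQAAssistant.predict` can then read `chunk["text"]` and `chunk["page_idx"]` directly:

```python
chunk = {"text": "The ribosome is ...", "page_idx": 33, "document_name": "Gray's Anatomy"}
score = 0.87

# Before: the chunk was nested, so callers had to reach into result["chunk"]["text"].
old_result = {"chunk": chunk, "score": score}

# After: chunk fields are flattened to the top level next to the score.
new_result = {**chunk, **{"score": score}}
assert new_result["text"] == chunk["text"]
assert new_result["page_idx"] == 33
```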
medrag_multi_modal/utils.py CHANGED
@@ -1,4 +1,9 @@
+import base64
+import io
+
+import jsonlines
 import torch
+from PIL import Image
 
 import wandb
 
@@ -29,3 +34,20 @@ def get_torch_backend():
            return "mps"
        return "cpu"
    return "cpu"
+
+
+def base64_encode_image(image: Image.Image, mimetype: str) -> str:
+    image.load()
+    if image.mode not in ("RGB", "RGBA"):
+        image = image.convert("RGB")
+    byte_arr = io.BytesIO()
+    image.save(byte_arr, format="PNG")
+    encoded_string = base64.b64encode(byte_arr.getvalue()).decode("utf-8")
+    encoded_string = f"data:{mimetype};base64,{encoded_string}"
+    return str(encoded_string)
+
+
+def read_jsonl_file(file_path: str) -> list[dict[str, any]]:
+    with jsonlines.open(file_path) as reader:
+        for obj in reader:
+            return obj
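
A brief sketch of the two new helpers; note that `read_jsonl_file` returns only the first record of the file, which matches how `BaseImageLoader` writes the whole `pages` list as a single JSON line. Paths and the placeholder image are illustrative:

```python
from PIL import Image

from medrag_multi_modal.utils import base64_encode_image, read_jsonl_file

# Build the data URL used by the Mistral/OpenAI image messages in LLMClient.
image = Image.new("RGB", (64, 64), color="white")  # placeholder image
data_url = base64_encode_image(image, "image/png")
assert data_url.startswith("data:image/png;base64,")

# First (and only) record of metadata.jsonl is the full list of page dicts.
metadata = read_jsonl_file("./images/metadata.jsonl")  # path from a prior load_data run
```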
mkdocs.yml CHANGED
@@ -83,5 +83,9 @@ nav:
     - Contriever: 'retreival/contriever.md'
     - MedCPT: 'retreival/medcpt.md'
     - NV-Embed-v2: 'retreival/nv_embed_2.md'
+  - Assistant:
+    - MedQA Assistant: 'assistant/medqa_assistant.md'
+    - Figure Annotation: 'assistant/figure_annotation.md'
+    - LLM Client: 'assistant/llm_client.md'
 
 repo_url: https://github.com/soumik12345/medrag-multi-modal
pyproject.toml CHANGED
@@ -38,6 +38,12 @@ dependencies = [
     "semchunk>=2.2.0",
     "tiktoken>=0.8.0",
     "sentence-transformers>=3.2.0",
+    "google-generativeai>=0.8.3",
+    "mistralai>=1.1.0",
+    "instructor>=1.6.3",
+    "jsonlines>=4.0.0",
+    "opencv-python>=4.10.0.84",
+    "openai>=1.52.2",
 ]
 
 [project.optional-dependencies]
@@ -61,6 +67,12 @@ core = [
     "torch>=2.4.1",
     "weave>=0.51.14",
     "sentence-transformers>=3.2.0",
+    "google-generativeai>=0.8.3",
+    "mistralai>=1.1.0",
+    "instructor>=1.6.3",
+    "jsonlines>=4.0.0",
+    "opencv-python>=4.10.0.84",
+    "openai>=1.52.2",
 ]
 
 dev = ["pytest>=8.3.3", "isort>=5.13.2", "black>=24.10.0", "ruff>=0.6.9"]