Merge pull request #21 from soumik12345/feat/llm-client
Files changed:
- .gitignore +1 -0
- docs/assistant/figure_annotation.md +3 -0
- docs/assistant/llm_client.md +3 -0
- docs/assistant/medqa_assistant.md +3 -0
- medrag_multi_modal/assistant/__init__.py +5 -0
- medrag_multi_modal/assistant/figure_annotation.py +156 -0
- medrag_multi_modal/assistant/llm_client.py +237 -0
- medrag_multi_modal/assistant/medqa_assistant.py +108 -0
- medrag_multi_modal/document_loader/image_loader/base_img_loader.py +20 -4
- medrag_multi_modal/document_loader/image_loader/marker_img_loader.py +45 -3
- medrag_multi_modal/document_loader/text_loader/base_text_loader.py +1 -0
- medrag_multi_modal/document_loader/text_loader/marker_text_loader.py +3 -0
- medrag_multi_modal/retrieval/bm25s_retrieval.py +1 -1
- medrag_multi_modal/retrieval/contriever_retrieval.py +2 -2
- medrag_multi_modal/retrieval/medcpt_retrieval.py +2 -2
- medrag_multi_modal/retrieval/nv_embed_2.py +2 -2
- medrag_multi_modal/utils.py +22 -0
- mkdocs.yml +4 -0
- pyproject.toml +12 -0
.gitignore
CHANGED
@@ -17,6 +17,7 @@ wandb/
 .byaldi/
 cursor_prompt.txt
 test.py
+test.ipynb
 uv.lock
 grays-anatomy-bm25s/
 prompt**.txt
docs/assistant/figure_annotation.md
ADDED
@@ -0,0 +1,3 @@
+# Figure Annotation
+
+::: medrag_multi_modal.assistant.figure_annotation
docs/assistant/llm_client.md
ADDED
@@ -0,0 +1,3 @@
+# LLM Client
+
+::: medrag_multi_modal.assistant.llm_client
docs/assistant/medqa_assistant.md
ADDED
@@ -0,0 +1,3 @@
+# MedQA Assistant
+
+::: medrag_multi_modal.assistant.medqa_assistant
medrag_multi_modal/assistant/__init__.py
ADDED
@@ -0,0 +1,5 @@
+from .figure_annotation import FigureAnnotatorFromPageImage
+from .llm_client import ClientType, LLMClient
+from .medqa_assistant import MedQAAssistant
+
+__all__ = ["LLMClient", "ClientType", "MedQAAssistant", "FigureAnnotatorFromPageImage"]
medrag_multi_modal/assistant/figure_annotation.py
ADDED
@@ -0,0 +1,156 @@
+import os
+from glob import glob
+from typing import Optional, Union
+
+import cv2
+import weave
+from PIL import Image
+from pydantic import BaseModel
+
+from ..utils import get_wandb_artifact, read_jsonl_file
+from .llm_client import LLMClient
+
+
+class FigureAnnotation(BaseModel):
+    figure_id: str
+    figure_description: str
+
+
+class FigureAnnotations(BaseModel):
+    annotations: list[FigureAnnotation]
+
+
+class FigureAnnotatorFromPageImage(weave.Model):
+    """
+    `FigureAnnotatorFromPageImage` is a class that leverages two LLM clients to annotate
+    figures from a page image of a scientific textbook.
+
+    !!! example "Example Usage"
+        ```python
+        import weave
+        from dotenv import load_dotenv
+
+        from medrag_multi_modal.assistant import (
+            FigureAnnotatorFromPageImage, LLMClient
+        )
+
+        load_dotenv()
+        weave.init(project_name="ml-colabs/medrag-multi-modal")
+        figure_annotator = FigureAnnotatorFromPageImage(
+            figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
+            structured_output_llm_client=LLMClient(model_name="gpt-4o"),
+            image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
+        )
+        annotations = figure_annotator.predict(page_idx=34)
+        ```
+
+    Args:
+        figure_extraction_llm_client (LLMClient): An LLM client used to extract figure annotations
+            from the page image.
+        structured_output_llm_client (LLMClient): An LLM client used to convert the extracted
+            annotations into a structured format.
+        image_artifact_address (Optional[str]): The address of the image artifact containing the
+            page images.
+    """
+
+    figure_extraction_llm_client: LLMClient
+    structured_output_llm_client: LLMClient
+    _artifact_dir: str
+
+    def __init__(
+        self,
+        figure_extraction_llm_client: LLMClient,
+        structured_output_llm_client: LLMClient,
+        image_artifact_address: Optional[str] = None,
+    ):
+        super().__init__(
+            figure_extraction_llm_client=figure_extraction_llm_client,
+            structured_output_llm_client=structured_output_llm_client,
+        )
+        self._artifact_dir = get_wandb_artifact(image_artifact_address, "dataset")
+
+    @weave.op()
+    def annotate_figures(
+        self, page_image: Image.Image
+    ) -> dict[str, Union[Image.Image, str]]:
+        annotation = self.figure_extraction_llm_client.predict(
+            system_prompt="""
+            You are an expert in the domain of scientific textbooks, especially medical texts.
+            You are presented with a page from a scientific textbook from the domain of biology, specifically anatomy.
+            You are to first identify all the figures in the page image, which could be images or biological diagrams, charts, graphs, etc.
+            Then you are to identify the figure IDs associated with each figure in the page image.
+            Then, you are to extract only the exact figure descriptions from the page image.
+            You need to output the figure IDs and figure descriptions only, in a structured manner as a JSON object.
+
+            Here are some clues you need to follow:
+            1. Figure IDs are unique identifiers for each figure in the page image.
+            2. Sometimes figure IDs can also be found as captions to the immediate left, right, top, or bottom of the figure.
+            3. Figure IDs are in the form "Fig X.Y" where X and Y are integers. For example, 1.1, 1.2, 1.3, etc.
+            4. Figure descriptions are contained as captions under the figures in the image, just after the figure ID.
+            5. The text in the page image is written in English and is present in a two-column format.
+            6. There is a clear distinction between the figure caption and the regular text in the page image in the form of extra white space.
+                You are to carefully identify all the figures in the page image.
+            7. There might be multiple figures or even no figures present in the page image. Sometimes the figures can be present side-by-side
+                or one above the other.
+            8. The figures may or may not have a distinct border against a white background.
+            10. You are not supposed to alter the figure description in any way present in the page image and you are to extract it as is.
+            """,
+            user_prompt=[page_image],
+        )
+        return {"page_image": page_image, "annotations": annotation}
+
+    @weave.op
+    def extract_structured_output(self, annotations: str) -> FigureAnnotations:
+        return self.structured_output_llm_client.predict(
+            system_prompt="You are supposed to extract a list of figure annotations consisting of figure IDs and corresponding figure descriptions.",
+            user_prompt=[annotations],
+            schema=FigureAnnotations,
+        )
+
+    @weave.op()
+    def predict(self, page_idx: int) -> dict[int, list[FigureAnnotation]]:
+        """
+        Predicts figure annotations for a specific page in a document.
+
+        This function reads the metadata from the 'metadata.jsonl' file in the image artifact
+        directory and iterates through the metadata to find the specified page index. If the
+        page index matches, it reads the page image and associated figure images, and then uses
+        the `annotate_figures` method to extract figure annotations from the page image. The
+        extracted annotations are then structured using the `extract_structured_output` method
+        and returned as a dictionary.
+
+        Args:
+            page_idx (int): The index of the page to annotate.
+
+        Returns:
+            dict: A dictionary containing the page index as the key and the extracted figure
+                annotations as the value.
+        """
+
+        metadata = read_jsonl_file(os.path.join(self._artifact_dir, "metadata.jsonl"))
+        annotations = {}
+        for item in metadata:
+            if item["page_idx"] == page_idx:
+                page_image_file = os.path.join(
+                    self._artifact_dir, f"page{item['page_idx']}.png"
+                )
+                figure_image_files = glob(
+                    os.path.join(self._artifact_dir, f"page{item['page_idx']}_fig*.png")
+                )
+                if len(figure_image_files) > 0:
+                    page_image = cv2.imread(page_image_file)
+                    page_image = cv2.cvtColor(page_image, cv2.COLOR_BGR2RGB)
+                    page_image = Image.fromarray(page_image)
+                    figure_extracted_annotations = self.annotate_figures(
+                        page_image=page_image
+                    )
+                    figure_extracted_annotations = self.extract_structured_output(
+                        figure_extracted_annotations["annotations"]
+                    ).model_dump()
+                    annotations[item["page_idx"]] = figure_extracted_annotations[
+                        "annotations"
+                    ]
+                break
+        return annotations
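
For reference, a minimal sketch of consuming the annotator's output, building on the `figure_annotator` constructed in the docstring example above (the page index is illustrative):

```python
annotations = figure_annotator.predict(page_idx=34)

# predict returns {page_idx: [{"figure_id": ..., "figure_description": ...}, ...]};
# the dict is empty if no figure images were extracted for that page.
for figure in annotations.get(34, []):
    print(figure["figure_id"], "->", figure["figure_description"])
```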
medrag_multi_modal/assistant/llm_client.py
ADDED
@@ -0,0 +1,237 @@
+import os
+from enum import Enum
+from typing import Any, Optional, Union
+
+import instructor
+import weave
+from PIL import Image
+
+from ..utils import base64_encode_image
+
+
+class ClientType(str, Enum):
+    GEMINI = "gemini"
+    MISTRAL = "mistral"
+    OPENAI = "openai"
+
+
+GOOGLE_MODELS = [
+    "gemini-1.0-pro-latest",
+    "gemini-1.0-pro",
+    "gemini-pro",
+    "gemini-1.0-pro-001",
+    "gemini-1.0-pro-vision-latest",
+    "gemini-pro-vision",
+    "gemini-1.5-pro-latest",
+    "gemini-1.5-pro-001",
+    "gemini-1.5-pro-002",
+    "gemini-1.5-pro",
+    "gemini-1.5-pro-exp-0801",
+    "gemini-1.5-pro-exp-0827",
+    "gemini-1.5-flash-latest",
+    "gemini-1.5-flash-001",
+    "gemini-1.5-flash-001-tuning",
+    "gemini-1.5-flash",
+    "gemini-1.5-flash-exp-0827",
+    "gemini-1.5-flash-002",
+    "gemini-1.5-flash-8b",
+    "gemini-1.5-flash-8b-001",
+    "gemini-1.5-flash-8b-latest",
+    "gemini-1.5-flash-8b-exp-0827",
+    "gemini-1.5-flash-8b-exp-0924",
+]
+
+MISTRAL_MODELS = [
+    "ministral-3b-latest",
+    "ministral-8b-latest",
+    "mistral-large-latest",
+    "mistral-small-latest",
+    "codestral-latest",
+    "pixtral-12b-2409",
+    "open-mistral-nemo",
+    "open-codestral-mamba",
+    "open-mistral-7b",
+    "open-mixtral-8x7b",
+    "open-mixtral-8x22b",
+]
+
+OPENAI_MODELS = ["gpt-4o", "gpt-4o-2024-08-06", "gpt-4o-mini", "gpt-4o-mini-2024-07-18"]
+
+
+class LLMClient(weave.Model):
+    """
+    LLMClient is a class that interfaces with different large language model (LLM) providers
+    such as Google Gemini, Mistral, and OpenAI. It abstracts the complexity of interacting with
+    these different APIs and provides a unified interface for making predictions.
+
+    Args:
+        model_name (str): The name of the model to be used for predictions.
+        client_type (Optional[ClientType]): The type of client (e.g., GEMINI, MISTRAL, OPENAI).
+            If not provided, it is inferred from the model_name.
+    """
+
+    model_name: str
+    client_type: Optional[ClientType]
+
+    def __init__(self, model_name: str, client_type: Optional[ClientType] = None):
+        if client_type is None:
+            if model_name in GOOGLE_MODELS:
+                client_type = ClientType.GEMINI
+            elif model_name in MISTRAL_MODELS:
+                client_type = ClientType.MISTRAL
+            elif model_name in OPENAI_MODELS:
+                client_type = ClientType.OPENAI
+            else:
+                raise ValueError(f"Invalid model name: {model_name}")
+        super().__init__(model_name=model_name, client_type=client_type)
+
+    @weave.op()
+    def execute_gemini_sdk(
+        self,
+        user_prompt: Union[str, list[str]],
+        system_prompt: Optional[Union[str, list[str]]] = None,
+        schema: Optional[Any] = None,
+    ) -> Union[str, Any]:
+        import google.generativeai as genai
+
+        system_prompt = (
+            [system_prompt] if isinstance(system_prompt, str) else system_prompt
+        )
+        user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
+
+        genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
+        model = genai.GenerativeModel(self.model_name)
+        generation_config = (
+            None
+            if schema is None
+            else genai.GenerationConfig(
+                response_mime_type="application/json", response_schema=list[schema]
+            )
+        )
+        response = model.generate_content(
+            system_prompt + user_prompt, generation_config=generation_config
+        )
+        return response.text if schema is None else response
+
+    @weave.op()
+    def execute_mistral_sdk(
+        self,
+        user_prompt: Union[str, list[str]],
+        system_prompt: Optional[Union[str, list[str]]] = None,
+        schema: Optional[Any] = None,
+    ) -> Union[str, Any]:
+        from mistralai import Mistral
+
+        system_prompt = (
+            [system_prompt] if isinstance(system_prompt, str) else system_prompt
+        )
+        user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
+        system_messages = [{"type": "text", "text": prompt} for prompt in system_prompt]
+        user_messages = []
+        for prompt in user_prompt:
+            if isinstance(prompt, Image.Image):
+                user_messages.append(
+                    {
+                        "type": "image_url",
+                        "image_url": base64_encode_image(prompt, "image/png"),
+                    }
+                )
+            else:
+                user_messages.append({"type": "text", "text": prompt})
+        messages = [
+            {"role": "system", "content": system_messages},
+            {"role": "user", "content": user_messages},
+        ]
+
+        client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"))
+        client = instructor.from_mistral(client) if schema is not None else client
+
+        response = (
+            client.chat.complete(model=self.model_name, messages=messages)
+            if schema is None
+            else client.messages.create(
+                response_model=schema, messages=messages, temperature=0
+            )
+        )
+        return response.choices[0].message.content
+
+    @weave.op()
+    def execute_openai_sdk(
+        self,
+        user_prompt: Union[str, list[str]],
+        system_prompt: Optional[Union[str, list[str]]] = None,
+        schema: Optional[Any] = None,
+    ) -> Union[str, Any]:
+        from openai import OpenAI
+
+        system_prompt = (
+            [system_prompt] if isinstance(system_prompt, str) else system_prompt
+        )
+        user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
+
+        system_messages = [
+            {"role": "system", "content": prompt} for prompt in system_prompt
+        ]
+        user_messages = []
+        for prompt in user_prompt:
+            if isinstance(prompt, Image.Image):
+                user_messages.append(
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": base64_encode_image(prompt, "image/png"),
+                        },
+                    },
+                )
+            else:
+                user_messages.append({"type": "text", "text": prompt})
+        messages = system_messages + [{"role": "user", "content": user_messages}]
+
+        client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
+
+        if schema is None:
+            completion = client.chat.completions.create(
+                model=self.model_name, messages=messages
+            )
+            return completion.choices[0].message.content
+
+        completion = weave.op()(client.beta.chat.completions.parse)(
+            model=self.model_name, messages=messages, response_format=schema
+        )
+        return completion.choices[0].message.parsed
+
+    @weave.op()
+    def predict(
+        self,
+        user_prompt: Union[str, list[str]],
+        system_prompt: Optional[Union[str, list[str]]] = None,
+        schema: Optional[Any] = None,
+    ) -> Union[str, Any]:
+        """
+        Predicts the response from a language model based on the provided prompts and schema.
+
+        This function determines the client type and calls the appropriate SDK execution function
+        to get the response from the language model. It supports multiple client types including
+        GEMINI, MISTRAL, and OPENAI. Depending on the client type, it calls the corresponding
+        execution function with the provided user and system prompts, and an optional schema.
+
+        Args:
+            user_prompt (Union[str, list[str]]): The user prompt(s) to be sent to the language model.
+            system_prompt (Optional[Union[str, list[str]]]): The system prompt(s) to be sent to the language model.
+            schema (Optional[Any]): The schema to be used for parsing the response, if applicable.
+
+        Returns:
+            Union[str, Any]: The response from the language model, which could be a string or any other type
+                depending on the schema provided.
+
+        Raises:
+            ValueError: If the client type is invalid.
+        """
+        if self.client_type == ClientType.GEMINI:
+            return self.execute_gemini_sdk(user_prompt, system_prompt, schema)
+        elif self.client_type == ClientType.MISTRAL:
+            return self.execute_mistral_sdk(user_prompt, system_prompt, schema)
+        elif self.client_type == ClientType.OPENAI:
+            return self.execute_openai_sdk(user_prompt, system_prompt, schema)
+        else:
+            raise ValueError(f"Invalid client type: {self.client_type}")
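
To illustrate the new client outside the assistant classes, here is a minimal sketch (assuming `OPENAI_API_KEY` is set; the `Summary` schema and the prompts are purely illustrative and not part of this PR):

```python
import weave
from pydantic import BaseModel

from medrag_multi_modal.assistant import LLMClient


class Summary(BaseModel):  # hypothetical schema, for illustration only
    title: str
    key_points: list[str]


weave.init(project_name="ml-colabs/medrag-multi-modal")

# client_type is inferred from the model name ("gpt-4o" -> ClientType.OPENAI)
client = LLMClient(model_name="gpt-4o")

# free-form text response
text = client.predict(
    system_prompt="You are a concise medical summarizer.",
    user_prompt=["Summarize the function of the mitral valve."],
)

# structured response parsed into the pydantic schema via the OpenAI parse API
structured = client.predict(
    system_prompt="Return a structured summary.",
    user_prompt=["Summarize the function of the mitral valve."],
    schema=Summary,
)
print(text, structured)
```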
medrag_multi_modal/assistant/medqa_assistant.py
ADDED
@@ -0,0 +1,108 @@
+import weave
+
+from ..retrieval import SimilarityMetric
+from .figure_annotation import FigureAnnotatorFromPageImage
+from .llm_client import LLMClient
+
+
+class MedQAAssistant(weave.Model):
+    """
+    `MedQAAssistant` is a class designed to assist with medical queries by leveraging a
+    language model client, a retriever model, and a figure annotator.
+
+    !!! example "Usage Example"
+        ```python
+        import weave
+        from dotenv import load_dotenv
+
+        from medrag_multi_modal.assistant import (
+            FigureAnnotatorFromPageImage,
+            LLMClient,
+            MedQAAssistant,
+        )
+        from medrag_multi_modal.retrieval import MedCPTRetriever
+
+        load_dotenv()
+        weave.init(project_name="ml-colabs/medrag-multi-modal")
+
+        llm_client = LLMClient(model_name="gemini-1.5-flash")
+
+        retriever = MedCPTRetriever.from_wandb_artifact(
+            chunk_dataset_name="grays-anatomy-chunks:v0",
+            index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
+        )
+
+        figure_annotator = FigureAnnotatorFromPageImage(
+            figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
+            structured_output_llm_client=LLMClient(model_name="gpt-4o"),
+            image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
+        )
+        medqa_assistant = MedQAAssistant(
+            llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator
+        )
+        medqa_assistant.predict(query="What is ribosome?")
+        ```
+
+    Args:
+        llm_client (LLMClient): The language model client used to generate responses.
+        retriever (weave.Model): The model used to retrieve relevant chunks of text from a medical document.
+        figure_annotator (FigureAnnotatorFromPageImage): The annotator used to extract figure descriptions from pages.
+        top_k_chunks (int): The number of top chunks to retrieve based on similarity metric.
+        retrieval_similarity_metric (SimilarityMetric): The metric used to measure similarity for retrieval.
+    """
+
+    llm_client: LLMClient
+    retriever: weave.Model
+    figure_annotator: FigureAnnotatorFromPageImage
+    top_k_chunks: int = 2
+    retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE
+
+    @weave.op()
+    def predict(self, query: str) -> str:
+        """
+        Generates a response to a medical query by retrieving relevant text chunks and figure descriptions
+        from a medical document and using a language model to generate the final response.
+
+        This function performs the following steps:
+        1. Retrieves relevant text chunks from the medical document based on the query using the retriever model.
+        2. Extracts the text and page indices from the retrieved chunks.
+        3. Retrieves figure descriptions from the pages identified in the previous step using the figure annotator.
+        4. Constructs a system prompt and user prompt combining the query, retrieved text chunks, and figure descriptions.
+        5. Uses the language model client to generate a response based on the constructed prompts.
+        6. Appends the source information (page numbers) to the generated response.
+
+        Args:
+            query (str): The medical query to be answered.
+
+        Returns:
+            str: The generated response to the query, including source information.
+        """
+        retrieved_chunks = self.retriever.predict(
+            query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric
+        )
+
+        retrieved_chunk_texts = []
+        page_indices = set()
+        for chunk in retrieved_chunks:
+            retrieved_chunk_texts.append(chunk["text"])
+            page_indices.add(int(chunk["page_idx"]))
+
+        figure_descriptions = []
+        for page_idx in page_indices:
+            figure_annotations = self.figure_annotator.predict(page_idx=page_idx)[
+                page_idx
+            ]
+            figure_descriptions += [
+                item["figure_description"] for item in figure_annotations
+            ]
+
+        system_prompt = """
+        You are an expert in medical science. You are given a query and a list of chunks from a medical document.
+        """
+        response = self.llm_client.predict(
+            system_prompt=system_prompt,
+            user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
+        )
+        page_numbers = ", ".join([str(int(page_idx) + 1) for page_idx in page_indices])
+        response += f"\n\n**Source:** {'Pages' if len(page_indices) > 1 else 'Page'} {page_numbers} from Gray's Anatomy"
+        return response
medrag_multi_modal/document_loader/image_loader/base_img_loader.py
CHANGED
@@ -3,6 +3,7 @@ import os
 from abc import abstractmethod
 from typing import Dict, List, Optional
 
+import jsonlines
 import rich
 
 import wandb
@@ -41,7 +42,8 @@ class BaseImageLoader(BaseTextLoader):
         end_page: Optional[int] = None,
         wandb_artifact_name: Optional[str] = None,
         image_save_dir: str = "./images",
-
+        exclude_file_extensions: list[str] = [],
+        cleanup: bool = False,
         **kwargs,
     ) -> List[Dict[str, str]]:
         """
@@ -61,10 +63,11 @@ class BaseImageLoader(BaseTextLoader):
         If a wandb_artifact_name is provided, the processed pages are published to a WandB artifact.
 
         Args:
-            start_page (Optional[int]): The starting page index (0-based) to process.
-            end_page (Optional[int]): The ending page index (0-based) to process.
+            start_page (Optional[int]): The starting page index (0-based) to process.
+            end_page (Optional[int]): The ending page index (0-based) to process.
             wandb_artifact_name (Optional[str]): The name of the WandB artifact to publish the pages to, if provided.
             image_save_dir (str): The directory to save the extracted images.
+            exclude_file_extensions (list[str]): A list of file extensions to exclude from the image_save_dir.
             cleanup (bool): Whether to remove extracted images from `image_save_dir`, if uploading to wandb artifact.
             **kwargs: Additional keyword arguments that will be passed to extract_page_data method and the underlying library.
 
@@ -99,8 +102,21 @@ class BaseImageLoader(BaseTextLoader):
         for task in asyncio.as_completed(tasks):
             await task
 
+        with jsonlines.open(
+            os.path.join(image_save_dir, "metadata.jsonl"), mode="w"
+        ) as writer:
+            writer.write(pages)
+
+        for file in os.listdir(image_save_dir):
+            if file.endswith(tuple(exclude_file_extensions)):
+                os.remove(os.path.join(image_save_dir, file))
+
         if wandb_artifact_name:
-            artifact = wandb.Artifact(
+            artifact = wandb.Artifact(
+                name=wandb_artifact_name,
+                type="dataset",
+                metadata={"loader_name": self.__class__.__name__},
+            )
             artifact.add_dir(local_path=image_save_dir)
             artifact.save()
             rich.print("Artifact saved and uploaded to wandb!")
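
As a side note, the `metadata.jsonl` written above is what `read_jsonl_file` (added in `medrag_multi_modal/utils.py` below) and `FigureAnnotatorFromPageImage.predict` consume. A small sketch, assuming a loader has already populated `./images`:

```python
import os

from medrag_multi_modal.utils import read_jsonl_file

image_save_dir = "./images"  # illustrative path, matching the load_data default

# load_data writes the whole list of page dicts as a single JSON line,
# so read_jsonl_file returns that list directly.
pages = read_jsonl_file(os.path.join(image_save_dir, "metadata.jsonl"))
for page in pages:
    print(page["page_idx"], page["document_name"])
```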
medrag_multi_modal/document_loader/image_loader/marker_img_loader.py
CHANGED
@@ -1,11 +1,14 @@
 import os
-from typing import Any, Dict
+from typing import Any, Coroutine, Dict, List
 
 from marker.convert import convert_single_pdf
 from marker.models import load_all_models
+from pdf2image.pdf2image import convert_from_path
 
 from .base_img_loader import BaseImageLoader
 
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
 
 class MarkerImageLoader(BaseImageLoader):
     """
@@ -46,10 +49,18 @@ class MarkerImageLoader(BaseImageLoader):
         url (str): The URL of the PDF document.
         document_name (str): The name of the document.
         document_file_path (str): The path to the PDF file.
+        save_page_image (bool): Whether to additionally save the image of the entire page.
     """
 
-    def __init__(
+    def __init__(
+        self,
+        url: str,
+        document_name: str,
+        document_file_path: str,
+        save_page_image: bool = False,
+    ):
         super().__init__(url, document_name, document_file_path)
+        self.save_page_image = save_page_image
         self.model_lst = load_all_models()
 
     async def extract_page_data(
@@ -90,11 +101,42 @@ class MarkerImageLoader(BaseImageLoader):
             image.save(image_file_path, "png")
             image_file_paths.append(image_file_path)
 
+        if self.save_page_image:
+            page_image = convert_from_path(
+                self.document_file_path,
+                first_page=page_idx + 1,
+                last_page=page_idx + 1,
+                **kwargs,
+            )[0]
+            page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
+
         return {
             "page_idx": page_idx,
             "document_name": self.document_name,
             "file_path": self.document_file_path,
             "file_url": self.url,
-            "image_file_paths":
+            "image_file_paths": os.path.join(image_save_dir, "*.png"),
             "meta": out_meta,
         }
+
+    def load_data(
+        self,
+        start_page: int | None = None,
+        end_page: int | None = None,
+        wandb_artifact_name: str | None = None,
+        image_save_dir: str = "./images",
+        exclude_file_extensions: list[str] = [],
+        cleanup: bool = False,
+        **kwargs,
+    ) -> Coroutine[Any, Any, List[Dict[str, str]]]:
+        start_page = start_page - 1 if start_page is not None else None
+        end_page = end_page - 1 if end_page is not None else None
+        return super().load_data(
+            start_page,
+            end_page,
+            wandb_artifact_name,
+            image_save_dir,
+            exclude_file_extensions,
+            cleanup,
+            **kwargs,
+        )
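
A short sketch of how the overridden `load_data` might be driven (the URL and file path are placeholders; `asyncio.run` is assumed because the base loader's `load_data` is a coroutine, as the return annotation indicates):

```python
import asyncio

from medrag_multi_modal.document_loader.image_loader.marker_img_loader import (
    MarkerImageLoader,
)

loader = MarkerImageLoader(
    url="https://example.com/grays-anatomy.pdf",  # placeholder URL
    document_name="Gray's Anatomy",
    document_file_path="grays-anatomy.pdf",  # placeholder local path
    save_page_image=True,  # also save a PNG of each full page
)

# Pages are given 1-based here; this override shifts them to the 0-based
# indices expected by BaseImageLoader.load_data.
pages = asyncio.run(
    loader.load_data(
        start_page=31,
        end_page=36,
        image_save_dir="./images",
    )
)
```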
medrag_multi_modal/document_loader/text_loader/base_text_loader.py
CHANGED
@@ -131,6 +131,7 @@ class BaseTextLoader(ABC):
         async def process_page(page_idx):
             nonlocal processed_pages_counter
             page_data = await self.extract_page_data(page_idx, **kwargs)
+            page_data["loader_name"] = self.__class__.__name__
             pages.append(page_data)
             rich.print(
                 f"Processed page idx: {page_idx}, progress: {processed_pages_counter}/{total_pages}"
medrag_multi_modal/document_loader/text_loader/marker_text_loader.py
CHANGED
@@ -1,3 +1,4 @@
+import os
 from typing import Dict
 
 from marker.convert import convert_single_pdf
@@ -5,6 +6,8 @@ from marker.models import load_all_models
 
 from .base_text_loader import BaseTextLoader
 
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
 
 class MarkerTextLoader(BaseTextLoader):
     """
medrag_multi_modal/retrieval/bm25s_retrieval.py
CHANGED
@@ -175,7 +175,7 @@ class BM25sRetriever(weave.Model):
             results.documents.flatten().tolist(),
             results.scores.flatten().tolist(),
         ):
-            retrieved_chunks.append({
+            retrieved_chunks.append({**chunk, **{"score": score}})
         return retrieved_chunks
 
     @weave.op()
medrag_multi_modal/retrieval/contriever_retrieval.py
CHANGED
@@ -192,8 +192,8 @@ class ContrieverRetriever(weave.Model):
         for score in scores:
             retrieved_chunks.append(
                 {
-
-                    "score": score["item"],
+                    **self._chunk_dataset[score["original_index"]],
+                    **{"score": score["item"]},
                 }
             )
         return retrieved_chunks
medrag_multi_modal/retrieval/medcpt_retrieval.py
CHANGED
@@ -231,8 +231,8 @@ class MedCPTRetriever(weave.Model):
         for score in scores:
             retrieved_chunks.append(
                 {
-
-                    "score": score["item"],
+                    **self._chunk_dataset[score["original_index"]],
+                    **{"score": score["item"]},
                 }
             )
         return retrieved_chunks
medrag_multi_modal/retrieval/nv_embed_2.py
CHANGED
@@ -217,8 +217,8 @@ class NVEmbed2Retriever(weave.Model):
         for score in scores:
             retrieved_chunks.append(
                 {
-
-                    "score": score["item"],
+                    **self._chunk_dataset[score["original_index"]],
+                    **{"score": score["item"]},
                 }
             )
         return retrieved_chunks
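
The four retrieval changes above all do the same thing: each retrieved entry now carries the original chunk's fields merged with its score, rather than the score alone. A hedged sketch of what that buys downstream, reusing the retriever setup from the `MedQAAssistant` docstring (the query is illustrative, and the "text"/"page_idx" keys are the ones `MedQAAssistant.predict` relies on):

```python
import weave
from dotenv import load_dotenv

from medrag_multi_modal.retrieval import MedCPTRetriever, SimilarityMetric

load_dotenv()
weave.init(project_name="ml-colabs/medrag-multi-modal")

retriever = MedCPTRetriever.from_wandb_artifact(
    chunk_dataset_name="grays-anatomy-chunks:v0",
    index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
)

# Each entry merges the stored chunk fields (e.g. "text", "page_idx") with "score".
for chunk in retriever.predict(
    "What is the mitral valve?", top_k=2, metric=SimilarityMetric.COSINE
):
    print(chunk["page_idx"], chunk["score"])
```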
medrag_multi_modal/utils.py
CHANGED
@@ -1,4 +1,9 @@
+import base64
+import io
+
+import jsonlines
 import torch
+from PIL import Image
 
 import wandb
 
@@ -29,3 +34,20 @@ def get_torch_backend():
             return "mps"
         return "cpu"
     return "cpu"
+
+
+def base64_encode_image(image: Image.Image, mimetype: str) -> str:
+    image.load()
+    if image.mode not in ("RGB", "RGBA"):
+        image = image.convert("RGB")
+    byte_arr = io.BytesIO()
+    image.save(byte_arr, format="PNG")
+    encoded_string = base64.b64encode(byte_arr.getvalue()).decode("utf-8")
+    encoded_string = f"data:{mimetype};base64,{encoded_string}"
+    return str(encoded_string)
+
+
+def read_jsonl_file(file_path: str) -> list[dict[str, any]]:
+    with jsonlines.open(file_path) as reader:
+        for obj in reader:
+            return obj
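
A quick sketch of the new `base64_encode_image` helper in isolation (the image path is illustrative); this produces the data-URI format that the `image_url` message fields in `LLMClient` expect:

```python
from PIL import Image

from medrag_multi_modal.utils import base64_encode_image

# Encode a page image as a base64 data URI for the multimodal chat APIs.
page = Image.open("images/page34.png")  # illustrative path
data_uri = base64_encode_image(page, "image/png")
assert data_uri.startswith("data:image/png;base64,")
```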
mkdocs.yml
CHANGED
@@ -83,5 +83,9 @@ nav:
     - Contriever: 'retreival/contriever.md'
     - MedCPT: 'retreival/medcpt.md'
     - NV-Embed-v2: 'retreival/nv_embed_2.md'
+  - Assistant:
+      - MedQA Assistant: 'assistant/medqa_assistant.md'
+      - Figure Annotation: 'assistant/figure_annotation.md'
+      - LLM Client: 'assistant/llm_client.md'
 
 repo_url: https://github.com/soumik12345/medrag-multi-modal
pyproject.toml
CHANGED
@@ -38,6 +38,12 @@ dependencies = [
     "semchunk>=2.2.0",
     "tiktoken>=0.8.0",
     "sentence-transformers>=3.2.0",
+    "google-generativeai>=0.8.3",
+    "mistralai>=1.1.0",
+    "instructor>=1.6.3",
+    "jsonlines>=4.0.0",
+    "opencv-python>=4.10.0.84",
+    "openai>=1.52.2",
 ]
 
 [project.optional-dependencies]
@@ -61,6 +67,12 @@ core = [
     "torch>=2.4.1",
     "weave>=0.51.14",
     "sentence-transformers>=3.2.0",
+    "google-generativeai>=0.8.3",
+    "mistralai>=1.1.0",
+    "instructor>=1.6.3",
+    "jsonlines>=4.0.0",
+    "opencv-python>=4.10.0.84",
+    "openai>=1.52.2",
 ]
 
 dev = ["pytest>=8.3.3", "isort>=5.13.2", "black>=24.10.0", "ruff>=0.6.9"]