geekyrakshit commited on
Commit
7934a8e
·
1 Parent(s): e4a917d

update: fix bug in LLMClient + add FigureAnnotator

Browse files
.gitignore CHANGED
@@ -17,6 +17,7 @@ wandb/
17
  .byaldi/
18
  cursor_prompt.txt
19
  test.py
 
20
  uv.lock
21
  grays-anatomy-bm25s/
22
  prompt**.txt
 
17
  .byaldi/
18
  cursor_prompt.txt
19
  test.py
20
+ test.ipynb
21
  uv.lock
22
  grays-anatomy-bm25s/
23
  prompt**.txt
medrag_multi_modal/assistant/__init__.py CHANGED
@@ -1,4 +1,5 @@
1
- from .llm_client import LLMClient
 
2
  from .medqa_assistant import MedQAAssistant
3
 
4
- __all__ = ["LLMClient", "MedQAAssistant"]
 
1
+ from .figure_annotation import FigureAnnotator
2
+ from .llm_client import ClientType, LLMClient
3
  from .medqa_assistant import MedQAAssistant
4
 
5
+ __all__ = ["LLMClient", "ClientType", "MedQAAssistant", "FigureAnnotator"]
medrag_multi_modal/assistant/figure_annotation.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Union
3
+
4
+ import cv2
5
+ import weave
6
+ from PIL import Image
7
+ from rich.progress import track
8
+
9
+ from ..utils import get_wandb_artifact, read_jsonl_file
10
+ from .llm_client import LLMClient
11
+
12
+
13
+ class FigureAnnotator(weave.Model):
14
+ llm_client: LLMClient
15
+
16
+ @weave.op()
17
+ def annotate_figures(
18
+ self, page_image: Image.Image
19
+ ) -> dict[str, Union[Image.Image, str]]:
20
+ annotation = self.llm_client.predict(
21
+ system_prompt="""
22
+ You are an expert in the domain of scientific textbooks, especially medical texts.
23
+ You are presented with a page from a scientific textbook.
24
+ You are to first identify the number of figures in the image.
25
+ Then you are to identify the figure IDs associated with each figure in the image.
26
+ Then, you are to extract the exact figure descriptions from the image.
27
+
28
+ Here are some clues you need to follow:
29
+ 1. Figure IDs are unique identifiers for each figure in the image.
30
+ 2. Sometimes figure IDs can also be found as captions to the immediate left, right, top, or bottom of the figure.
31
+ 3. Figure IDs are in the form "Fig X.Y" where X and Y are integers. For example, 1.1, 1.2, 1.3, etc.
32
+ 4. Figure descriptions are contained as captions under the figures in the image, just after the figure ID.
33
+ 5. The text in the image is written in English and is present in a two-column format.
34
+ 6. There is a clear distinction between the figure caption and the regular text in the image in the form of extra white space.
35
+ 7. There might be multiple figures present in the image.
36
+ """,
37
+ user_prompt=[page_image],
38
+ )
39
+ return {"page_image": page_image, "annotations": annotation}
40
+
41
+ @weave.op()
42
+ def predict(self, image_artifact_address: str):
43
+ artifact_dir = get_wandb_artifact(image_artifact_address, "dataset")
44
+ metadata = read_jsonl_file(os.path.join(artifact_dir, "metadata.jsonl"))
45
+ annotations = []
46
+ for item in track(metadata, description="Annotating images:"):
47
+ page_image = cv2.imread(
48
+ os.path.join(artifact_dir, f"page{item['page_idx']}.png")
49
+ )
50
+ page_image = cv2.cvtColor(page_image, cv2.COLOR_BGR2RGB)
51
+ page_image = Image.fromarray(page_image)
52
+ annotations.append(self.annotate_figures(page_image=page_image))
53
+ return annotations
medrag_multi_modal/assistant/llm_client.py CHANGED
@@ -9,7 +9,7 @@ from PIL import Image
9
  from ..utils import base64_encode_image
10
 
11
 
12
- class ClientType(Enum, str):
13
  GEMINI = "gemini"
14
  MISTRAL = "mistral"
15
 
@@ -80,7 +80,7 @@ class LLMClient(weave.Model):
80
  ]
81
 
82
  client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"))
83
- client = instructor.from_mistral(client)
84
 
85
  response = (
86
  client.chat.complete(model=self.model_name, messages=messages)
 
9
  from ..utils import base64_encode_image
10
 
11
 
12
+ class ClientType(str, Enum):
13
  GEMINI = "gemini"
14
  MISTRAL = "mistral"
15
 
 
80
  ]
81
 
82
  client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"))
83
+ client = instructor.from_mistral(client) if schema is not None else client
84
 
85
  response = (
86
  client.chat.complete(model=self.model_name, messages=messages)
medrag_multi_modal/utils.py CHANGED
@@ -1,6 +1,7 @@
1
  import base64
2
  import io
3
 
 
4
  import torch
5
  from PIL import Image
6
 
@@ -36,8 +37,17 @@ def get_torch_backend():
36
 
37
 
38
  def base64_encode_image(image: Image.Image, mimetype: str) -> str:
 
 
 
39
  byte_arr = io.BytesIO()
40
  image.save(byte_arr, format="PNG")
41
  encoded_string = base64.b64encode(byte_arr.getvalue()).decode("utf-8")
42
  encoded_string = f"data:{mimetype};base64,{encoded_string}"
43
  return str(encoded_string)
 
 
 
 
 
 
 
1
  import base64
2
  import io
3
 
4
+ import jsonlines
5
  import torch
6
  from PIL import Image
7
 
 
37
 
38
 
39
  def base64_encode_image(image: Image.Image, mimetype: str) -> str:
40
+ image.load()
41
+ if image.mode not in ("RGB", "RGBA"):
42
+ image = image.convert("RGB")
43
  byte_arr = io.BytesIO()
44
  image.save(byte_arr, format="PNG")
45
  encoded_string = base64.b64encode(byte_arr.getvalue()).decode("utf-8")
46
  encoded_string = f"data:{mimetype};base64,{encoded_string}"
47
  return str(encoded_string)
48
+
49
+
50
+ def read_jsonl_file(file_path: str) -> list[dict[str, any]]:
51
+ with jsonlines.open(file_path) as reader:
52
+ for obj in reader:
53
+ return obj
pyproject.toml CHANGED
@@ -42,6 +42,7 @@ dependencies = [
42
  "mistralai>=1.1.0",
43
  "instructor>=1.6.3",
44
  "jsonlines>=4.0.0",
 
45
  ]
46
 
47
  [project.optional-dependencies]
@@ -69,6 +70,7 @@ core = [
69
  "mistralai>=1.1.0",
70
  "instructor>=1.6.3",
71
  "jsonlines>=4.0.0",
 
72
  ]
73
 
74
  dev = ["pytest>=8.3.3", "isort>=5.13.2", "black>=24.10.0", "ruff>=0.6.9"]
 
42
  "mistralai>=1.1.0",
43
  "instructor>=1.6.3",
44
  "jsonlines>=4.0.0",
45
+ "opencv-python>=4.10.0.84",
46
  ]
47
 
48
  [project.optional-dependencies]
 
70
  "mistralai>=1.1.0",
71
  "instructor>=1.6.3",
72
  "jsonlines>=4.0.0",
73
+ "opencv-python>=4.10.0.84",
74
  ]
75
 
76
  dev = ["pytest>=8.3.3", "isort>=5.13.2", "black>=24.10.0", "ruff>=0.6.9"]