geekyrakshit committed on
Commit
e4a917d
·
1 Parent(s): 9c51e22

update: override MarkerImageLoader.load_data to align page indices to reflect pdf page numbers

Browse files
medrag_multi_modal/document_loader/image_loader/marker_img_loader.py CHANGED
@@ -1,8 +1,9 @@
1
  import os
2
- from typing import Any, Dict
3
 
4
  from marker.convert import convert_single_pdf
5
  from marker.models import load_all_models
 
6
 
7
  from .base_img_loader import BaseImageLoader
8
 
@@ -48,10 +49,18 @@ class MarkerImageLoader(BaseImageLoader):
48
  url (str): The URL of the PDF document.
49
  document_name (str): The name of the document.
50
  document_file_path (str): The path to the PDF file.
 
51
  """
52
 
53
- def __init__(self, url: str, document_name: str, document_file_path: str):
 
 
 
 
 
 
54
  super().__init__(url, document_name, document_file_path)
 
55
  self.model_lst = load_all_models()
56
 
57
  async def extract_page_data(
@@ -92,6 +101,15 @@ class MarkerImageLoader(BaseImageLoader):
92
  image.save(image_file_path, "png")
93
  image_file_paths.append(image_file_path)
94
 
 
 
 
 
 
 
 
 
 
95
  return {
96
  "page_idx": page_idx,
97
  "document_name": self.document_name,
@@ -100,3 +118,25 @@ class MarkerImageLoader(BaseImageLoader):
100
  "image_file_paths": os.path.join(image_save_dir, "*.png"),
101
  "meta": out_meta,
102
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from typing import Any, Coroutine, Dict, List
3
 
4
  from marker.convert import convert_single_pdf
5
  from marker.models import load_all_models
6
+ from pdf2image.pdf2image import convert_from_path
7
 
8
  from .base_img_loader import BaseImageLoader
9
 
 
49
  url (str): The URL of the PDF document.
50
  document_name (str): The name of the document.
51
  document_file_path (str): The path to the PDF file.
52
+ save_page_image (bool): Whether to additionally save the image of the entire page.
53
  """
54
 
55
def __init__(
    self,
    url: str,
    document_name: str,
    document_file_path: str,
    save_page_image: bool = False,
):
    """Set up the loader and load the marker models.

    Args:
        url (str): The URL of the PDF document.
        document_name (str): The name of the document.
        document_file_path (str): The path to the PDF file.
        save_page_image (bool): Whether to additionally save the image of
            the entire page. Defaults to False.
    """
    # The three document identifiers are handled by the base loader.
    super().__init__(url, document_name, document_file_path)

    # Kept on the instance so the per-page extraction step can check it.
    self.save_page_image = save_page_image

    # Models are loaded once at construction time and kept on the instance.
    self.model_lst = load_all_models()
65
 
66
  async def extract_page_data(
 
101
  image.save(image_file_path, "png")
102
  image_file_paths.append(image_file_path)
103
 
104
+ if self.save_page_image:
105
+ page_image = convert_from_path(
106
+ self.document_file_path,
107
+ first_page=page_idx + 1,
108
+ last_page=page_idx + 1,
109
+ **kwargs,
110
+ )[0]
111
+ page_image.save(os.path.join(image_save_dir, f"page{page_idx}.png"))
112
+
113
  return {
114
  "page_idx": page_idx,
115
  "document_name": self.document_name,
 
118
  "image_file_paths": os.path.join(image_save_dir, "*.png"),
119
  "meta": out_meta,
120
  }
121
+
122
def load_data(
    self,
    start_page: int | None = None,
    end_page: int | None = None,
    wandb_artifact_name: str | None = None,
    image_save_dir: str = "./images",
    exclude_file_extensions: list[str] | None = None,
    cleanup: bool = False,
    **kwargs,
) -> Coroutine[Any, Any, List[Dict[str, str]]]:
    """Load page data using PDF page numbers instead of raw page indices.

    ``start_page`` and ``end_page`` are each decremented by one before being
    handed to the parent class's ``load_data``, so callers can pass page
    numbers as they appear in the PDF. Every other argument is forwarded
    unchanged.

    Args:
        start_page (int | None): First page number to load; ``None`` is
            passed through to the parent as-is.
        end_page (int | None): Last page number to load; ``None`` is passed
            through to the parent as-is.
            NOTE(review): whether the bound is inclusive is decided by the
            parent implementation -- confirm there.
        wandb_artifact_name (str | None): Forwarded to the parent loader.
        image_save_dir (str): Forwarded to the parent loader.
        exclude_file_extensions (list[str] | None): Forwarded to the parent
            loader; ``None`` (the default) is replaced by a new empty list.
        cleanup (bool): Forwarded to the parent loader.
        **kwargs: Extra keyword arguments forwarded to the parent loader.

    Returns:
        Whatever the parent ``load_data`` returns. This method is not
        ``async``; it hands back the parent's result without awaiting it.
    """
    # A literal ``[]`` default is created once and shared by every call, so
    # any mutation downstream would leak between calls. Build a fresh list
    # per call instead.
    if exclude_file_extensions is None:
        exclude_file_extensions = []

    # Shift the page numbers down by one for the parent; ``None`` (no bound)
    # must stay ``None``.
    start_page = start_page - 1 if start_page is not None else None
    end_page = end_page - 1 if end_page is not None else None

    return super().load_data(
        start_page,
        end_page,
        wandb_artifact_name,
        image_save_dir,
        exclude_file_extensions,
        cleanup,
        **kwargs,
    )