geekyrakshit committed on
Commit
9c51e22
·
1 Parent(s): 789b57f

update: BaseImageLoader + MarkerImageLoader

Browse files
medrag_multi_modal/document_loader/image_loader/base_img_loader.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  from abc import abstractmethod
4
  from typing import Dict, List, Optional
5
 
 
6
  import rich
7
 
8
  import wandb
@@ -41,7 +42,8 @@ class BaseImageLoader(BaseTextLoader):
41
  end_page: Optional[int] = None,
42
  wandb_artifact_name: Optional[str] = None,
43
  image_save_dir: str = "./images",
44
- cleanup: bool = True,
 
45
  **kwargs,
46
  ) -> List[Dict[str, str]]:
47
  """
@@ -61,10 +63,11 @@ class BaseImageLoader(BaseTextLoader):
61
  If a wandb_artifact_name is provided, the processed pages are published to a WandB artifact.
62
 
63
  Args:
64
- start_page (Optional[int]): The starting page index (0-based) to process. Defaults to the first page.
65
- end_page (Optional[int]): The ending page index (0-based) to process. Defaults to the last page.
66
  wandb_artifact_name (Optional[str]): The name of the WandB artifact to publish the pages to, if provided.
67
  image_save_dir (str): The directory to save the extracted images.
 
68
  cleanup (bool): Whether to remove extracted images from `image_save_dir`, if uploading to wandb artifact.
69
  **kwargs: Additional keyword arguments that will be passed to extract_page_data method and the underlying library.
70
 
@@ -99,6 +102,15 @@ class BaseImageLoader(BaseTextLoader):
99
  for task in asyncio.as_completed(tasks):
100
  await task
101
 
 
 
 
 
 
 
 
 
 
102
  if wandb_artifact_name:
103
  artifact = wandb.Artifact(
104
  name=wandb_artifact_name,
 
3
  from abc import abstractmethod
4
  from typing import Dict, List, Optional
5
 
6
+ import jsonlines
7
  import rich
8
 
9
  import wandb
 
42
  end_page: Optional[int] = None,
43
  wandb_artifact_name: Optional[str] = None,
44
  image_save_dir: str = "./images",
45
+ exclude_file_extensions: list[str] = [],
46
+ cleanup: bool = False,
47
  **kwargs,
48
  ) -> List[Dict[str, str]]:
49
  """
 
63
  If a wandb_artifact_name is provided, the processed pages are published to a WandB artifact.
64
 
65
  Args:
66
+ start_page (Optional[int]): The starting page index (0-based) to process.
67
+ end_page (Optional[int]): The ending page index (0-based) to process.
68
  wandb_artifact_name (Optional[str]): The name of the WandB artifact to publish the pages to, if provided.
69
  image_save_dir (str): The directory to save the extracted images.
70
+ exclude_file_extensions (list[str]): A list of file extensions to exclude from the image_save_dir.
71
  cleanup (bool): Whether to remove extracted images from `image_save_dir`, if uploading to wandb artifact.
72
  **kwargs: Additional keyword arguments that will be passed to extract_page_data method and the underlying library.
73
 
 
102
  for task in asyncio.as_completed(tasks):
103
  await task
104
 
105
+ with jsonlines.open(
106
+ os.path.join(image_save_dir, "metadata.jsonl"), mode="w"
107
+ ) as writer:
108
+ writer.write(pages)
109
+
110
+ for file in os.listdir(image_save_dir):
111
+ if file.endswith(tuple(exclude_file_extensions)):
112
+ os.remove(os.path.join(image_save_dir, file))
113
+
114
  if wandb_artifact_name:
115
  artifact = wandb.Artifact(
116
  name=wandb_artifact_name,
medrag_multi_modal/document_loader/image_loader/marker_img_loader.py CHANGED
@@ -97,6 +97,6 @@ class MarkerImageLoader(BaseImageLoader):
97
  "document_name": self.document_name,
98
  "file_path": self.document_file_path,
99
  "file_url": self.url,
100
- "image_file_paths": image_file_paths,
101
  "meta": out_meta,
102
  }
 
97
  "document_name": self.document_name,
98
  "file_path": self.document_file_path,
99
  "file_url": self.url,
100
+ "image_file_paths": os.path.join(image_save_dir, "*.png"),
101
  "meta": out_meta,
102
  }
pyproject.toml CHANGED
@@ -41,6 +41,7 @@ dependencies = [
41
  "google-generativeai>=0.8.3",
42
  "mistralai>=1.1.0",
43
  "instructor>=1.6.3",
 
44
  ]
45
 
46
  [project.optional-dependencies]
@@ -67,6 +68,7 @@ core = [
67
  "google-generativeai>=0.8.3",
68
  "mistralai>=1.1.0",
69
  "instructor>=1.6.3",
 
70
  ]
71
 
72
  dev = ["pytest>=8.3.3", "isort>=5.13.2", "black>=24.10.0", "ruff>=0.6.9"]
 
41
  "google-generativeai>=0.8.3",
42
  "mistralai>=1.1.0",
43
  "instructor>=1.6.3",
44
+ "jsonlines>=4.0.0",
45
  ]
46
 
47
  [project.optional-dependencies]
 
68
  "google-generativeai>=0.8.3",
69
  "mistralai>=1.1.0",
70
  "instructor>=1.6.3",
71
+ "jsonlines>=4.0.0",
72
  ]
73
 
74
  dev = ["pytest>=8.3.3", "isort>=5.13.2", "black>=24.10.0", "ruff>=0.6.9"]