NEXAS commited on
Commit
a1197b5
·
verified ·
1 Parent(s): fbe4e8b

Update utils/ingest_image.py

Browse files
Files changed (1) hide show
  1. utils/ingest_image.py +51 -0
utils/ingest_image.py CHANGED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import fitz
3
+ import chromadb
4
+ from chromadb.utils.data_loaders import ImageLoader
5
+ from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction # type: ignore
6
+
7
+ path = "mm_vdb2"
8
+ client = chromadb.PersistentClient(path=path)
9
+
10
+ def extract_and_store_images(pdf_path,images_dir=r'extracted_images'):
11
+ # Step 1: Extract images from PDF
12
+ pdf_document = fitz.open(pdf_path)
13
+ os.makedirs(images_dir, exist_ok=True)
14
+
15
+ for page_num in range(len(pdf_document)):
16
+ page = pdf_document.load_page(page_num)
17
+ image_list = page.get_images(full=True)
18
+
19
+ for image_index, img in enumerate(image_list):
20
+ xref = img[0]
21
+ base_image = pdf_document.extract_image(xref)
22
+ image_bytes = base_image["image"]
23
+ image_ext = base_image["ext"]
24
+ image_filename = f"{images_dir}/page_{page_num+1}_img_{image_index+1}.{image_ext}"
25
+
26
+ with open(image_filename, "wb") as image_file:
27
+ image_file.write(image_bytes)
28
+ print(f"Saved: {image_filename}")
29
+
30
+ print("Image extraction complete.")
31
+
32
+ # Step 2: Add extracted images to ChromaDB
33
+ image_loader = ImageLoader()
34
+ CLIP = OpenCLIPEmbeddingFunction()
35
+ image_collection = client.get_or_create_collection(name="image", embedding_function=CLIP, data_loader=image_loader)
36
+ print("clip embedding.")
37
+
38
+ ids = []
39
+ uris = []
40
+
41
+ for i, filename in enumerate(sorted(os.listdir(images_dir))):
42
+ if filename.endswith('.jpeg') or filename.endswith('.png'):
43
+ file_path = os.path.join(images_dir, filename)
44
+ ids.append(str(i))
45
+ uris.append(file_path)
46
+ print(file_path)
47
+
48
+ print("Image adding to db")
49
+ image_collection.add(ids=ids, uris=uris)
50
+ print("Images added to the database.")
51
+ return image_collection