msabia commited on
Commit
a51fba7
·
verified ·
1 Parent(s): 7b85781

Update imageSearching.py

Browse files
Files changed (1) hide show
  1. imageSearching.py +97 -0
imageSearching.py CHANGED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import h5py
5
+ import faiss
6
+ import json
7
+ import hydra
8
+ import time
9
+ import random
10
+ from PIL import Image
11
+ import io
12
+ import pickle
13
+
14
+ def get_image(file, dataset_image_mask, processid_to_index, idx):
15
+ # idx = processid_to_index[query_id]
16
+ image_enc_padded = file["image"][idx].astype(np.uint8)
17
+ enc_length = dataset_image_mask[idx]
18
+ image_enc = image_enc_padded[:enc_length]
19
+ image = Image.open(io.BytesIO(image_enc))
20
+ return image
21
+
22
+ def searchEmbeddings(id):
23
+ # get embeddings from file
24
+ embeddings_file = h5py.File('5m/extracted_features_of_all_keys.hdf5', 'r')
25
+
26
+ # variable and index initialization
27
+ dim = 768
28
+ count = 0
29
+ num_neighbors = 10
30
+
31
+ image_index = faiss.IndexFlatIP(dim)
32
+
33
+ # load dictionaries
34
+ with open("id_emb_dict.pickle", "rb") as f:
35
+ id_to_emb_dict = pickle.load(f)
36
+ with open("indx_to_id.pickle", "rb") as f:
37
+ indx_to_id_dict = pickle.load(f)
38
+
39
+ # get index
40
+ image_index = faiss.read_index("image_index.index")
41
+
42
+ # search for query
43
+ query = id_to_emb_dict[id]
44
+ query = query.astype(np.float32)
45
+ D, I = image_index.search(query, num_neighbors)
46
+
47
+ id_list = []
48
+ i = 1
49
+ for indx in I[0]:
50
+ id = indx_to_id_dict[indx]
51
+ id_list.append(id)
52
+
53
+ # get image data
54
+ dataset_hdf5_all_key = h5py.File('full5m/BIOSCAN_5M.hdf5', "r", libver="latest")['all_keys']
55
+ dataset_processid_list = [item.decode("utf-8") for item in dataset_hdf5_all_key["processid"][:]]
56
+ dataset_image_mask = dataset_hdf5_all_key["image_mask"][:]
57
+ processid_to_index = {pid: idx for idx, pid in enumerate(dataset_processid_list)}
58
+
59
+ image1 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][0])
60
+ image2 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][1])
61
+ image3 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][2])
62
+ image4 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][3])
63
+ image5 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][4])
64
+ image6 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][5])
65
+ image7 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][6])
66
+ image8 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][7])
67
+ image9 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][8])
68
+ image10 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][9])
69
+
70
+ # return id_list, id_list[0], id_list[1], id_list[2], id_list[3], id_list[4], id_list[5], id_list[6], id_list[7], id_list[8], id_list[9], image1, image2, image3, image4, image5, image6, image7, image8, image9, image10
71
+ # return id_list, indx_to_id_dict[I[0][0]], indx_to_id_dict[I[0][1]], indx_to_id_dict[I[0][2]], indx_to_id_dict[I[0][3]], indx_to_id_dict[I[0][4]], indx_to_id_dict[I[0][5]], indx_to_id_dict[I[0][6]], indx_to_id_dict[I[0][7]], indx_to_id_dict[I[0][8]], indx_to_id_dict[I[0][9]]
72
+ return id_list, image1, image2, image3, image4, image5, image6, image7, image8, image9, image10
73
+
74
+ with gr.Blocks() as demo:
75
+ with gr.Column():
76
+ process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
77
+ process_id_list = gr.Textbox(label="Closest 10 matches:" )
78
+ search_btn = gr.Button("Search")
79
+
80
+ with gr.Row():
81
+ image1 = gr.Image(label=1)
82
+ image2 = gr.Image(label=2)
83
+ image3 = gr.Image(label=3)
84
+ image4 = gr.Image(label=4)
85
+ image5 = gr.Image(label=5)
86
+ with gr.Row():
87
+ image6 = gr.Image(label=6)
88
+ image7 = gr.Image(label=7)
89
+ image8 = gr.Image(label=8)
90
+ image9 = gr.Image(label=9)
91
+ image10 = gr.Image(label=10)
92
+
93
+ search_btn.click(fn=searchEmbeddings, inputs=process_id,
94
+ outputs=[process_id_list, image1, image2, image3, image4, image5, image6, image7, image8, image9, image10])
95
+
96
+ # ARONZ671-20
97
+ demo.launch(share=True)