Spaces:
Running
Running
update code to download dataset files from separate repo
Browse files- .gitignore +7 -0
- app.py +23 -57
- big_indx_to_id_dict.pickle +0 -3
- bioscan_5m_dna_IndexFlatIP.index +0 -3
- bioscan_5m_image_IndexFlatIP.index +0 -3
- data.py +34 -0
- prepare_index.py +18 -33
.gitignore
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.build
|
2 |
+
.data
|
3 |
+
.singularity
|
4 |
+
slurm/
|
5 |
+
.env
|
6 |
+
*.sif
|
7 |
+
__pycache__/
|
app.py
CHANGED
@@ -1,13 +1,10 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
import torch
|
3 |
-
import numpy as np
|
4 |
-
import h5py
|
5 |
-
import faiss
|
6 |
-
from PIL import Image
|
7 |
-
import io
|
8 |
import pickle
|
9 |
import random
|
10 |
-
|
|
|
|
|
|
|
|
|
11 |
|
12 |
|
13 |
def getRandID():
|
@@ -16,37 +13,17 @@ def getRandID():
|
|
16 |
|
17 |
|
18 |
def get_image_index(indexType):
|
19 |
-
|
20 |
-
return
|
21 |
-
|
22 |
-
raise
|
23 |
-
return image_index_L2
|
24 |
-
elif indexType == "HNSWFlat":
|
25 |
-
raise NotImplementedError
|
26 |
-
return image_index_HNSW
|
27 |
-
elif indexType == "IVFFlat":
|
28 |
-
raise NotImplementedError
|
29 |
-
return image_index_IVF
|
30 |
-
elif indexType == "LSH":
|
31 |
-
raise NotImplementedError
|
32 |
-
return image_index_LSH
|
33 |
|
34 |
|
35 |
def get_dna_index(indexType):
|
36 |
-
|
37 |
-
return
|
38 |
-
|
39 |
-
raise
|
40 |
-
return dna_index_L2
|
41 |
-
elif indexType == "HNSWFlat":
|
42 |
-
raise NotImplementedError
|
43 |
-
return dna_index_HNSW
|
44 |
-
elif indexType == "IVFFlat":
|
45 |
-
raise NotImplementedError
|
46 |
-
return dna_index_IVF
|
47 |
-
elif indexType == "LSH":
|
48 |
-
raise NotImplementedError
|
49 |
-
return dna_index_LSH
|
50 |
|
51 |
|
52 |
def searchEmbeddings(id, key_type, query_type, index_type, num_results: int = 10):
|
@@ -86,24 +63,13 @@ with gr.Blocks() as demo:
|
|
86 |
# for hf: change all file paths, indx_to_id_dict as well
|
87 |
|
88 |
# load indexes
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
# dna_index_L2 = faiss.read_index("big_dna_index_FlatL2.index")
|
97 |
-
# dna_index_HNSW = faiss.read_index("big_dna_index_HNSWFlat.index")
|
98 |
-
# dna_index_IVF = faiss.read_index("big_dna_index_IVFFlat.index")
|
99 |
-
# dna_index_LSH = faiss.read_index("big_dna_index_LSH.index")
|
100 |
-
|
101 |
-
# with open("dataset_processid_list.pickle", "rb") as f:
|
102 |
-
# dataset_processid_list = pickle.load(f)
|
103 |
-
# with open("processid_to_index.pickle", "rb") as f:
|
104 |
-
# processid_to_index = pickle.load(f)
|
105 |
-
with open("big_indx_to_id_dict.pickle", "rb") as f:
|
106 |
-
index_to_id_dict = pickle.load(f)
|
107 |
id_to_index_dict = {v: k for k, v in index_to_id_dict.items()}
|
108 |
|
109 |
with gr.Column():
|
@@ -113,8 +79,8 @@ with gr.Blocks() as demo:
|
|
113 |
rand_id_indx = gr.Textbox(label="Index:")
|
114 |
id_btn = gr.Button("Get Random ID")
|
115 |
with gr.Column():
|
116 |
-
|
117 |
-
|
118 |
|
119 |
index_type = gr.Radio(
|
120 |
choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"], label="Index:", value="FlatIP(default)"
|
@@ -122,7 +88,7 @@ with gr.Blocks() as demo:
|
|
122 |
num_results = gr.Number(label="Number of Results:", value=10, precision=0)
|
123 |
|
124 |
process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
|
125 |
-
process_id_list = gr.Textbox(label="Closest
|
126 |
search_btn = gr.Button("Search")
|
127 |
id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
|
128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import pickle
|
2 |
import random
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from data import load_indexes_local, load_indexes_hf, load_index_pickle
|
8 |
|
9 |
|
10 |
def getRandID():
|
|
|
13 |
|
14 |
|
15 |
def get_image_index(indexType):
|
16 |
+
try:
|
17 |
+
return image_indexes[indexType]
|
18 |
+
except KeyError:
|
19 |
+
raise KeyError(f"Tried to load an image index that is not supported: {indexType}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
|
22 |
def get_dna_index(indexType):
|
23 |
+
try:
|
24 |
+
return dna_indexes[indexType]
|
25 |
+
except KeyError:
|
26 |
+
raise KeyError(f"Tried to load a DNA index that is not supported: {indexType}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
|
29 |
def searchEmbeddings(id, key_type, query_type, index_type, num_results: int = 10):
|
|
|
63 |
# for hf: change all file paths, indx_to_id_dict as well
|
64 |
|
65 |
# load indexes
|
66 |
+
image_indexes = load_indexes_hf(
|
67 |
+
{"FlatIP(default)": "bioscan_5m_image_IndexFlatIP.index"}, repo_name="bioscan-ml/bioscan-clibd"
|
68 |
+
)
|
69 |
+
dna_indexes = load_indexes_hf(
|
70 |
+
{"FlatIP(default)": "bioscan_5m_dna_IndexFlatIP.index"}, repo_name="bioscan-ml/bioscan-clibd"
|
71 |
+
)
|
72 |
+
index_to_id_dict = load_index_pickle("big_indx_to_id_dict.pickle", repo_name="bioscan-ml/bioscan-clibd")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
id_to_index_dict = {v: k for k, v in index_to_id_dict.items()}
|
74 |
|
75 |
with gr.Column():
|
|
|
79 |
rand_id_indx = gr.Textbox(label="Index:")
|
80 |
id_btn = gr.Button("Get Random ID")
|
81 |
with gr.Column():
|
82 |
+
query_type = gr.Radio(choices=["Image", "DNA"], label="Query:", value="Image")
|
83 |
+
key_type = gr.Radio(choices=["Image", "DNA"], label="Key:", value="Image")
|
84 |
|
85 |
index_type = gr.Radio(
|
86 |
choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"], label="Index:", value="FlatIP(default)"
|
|
|
88 |
num_results = gr.Number(label="Number of Results:", value=10, precision=0)
|
89 |
|
90 |
process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
|
91 |
+
process_id_list = gr.Textbox(label="Closest matches:")
|
92 |
search_btn = gr.Button("Search")
|
93 |
id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
|
94 |
|
big_indx_to_id_dict.pickle
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:ee0a9044e054f640b704247a2fa2e74219180b78ded6ba07f551bfc222657fc5
|
3 |
-
size 885457
|
|
|
|
|
|
|
|
bioscan_5m_dna_IndexFlatIP.index
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:83fe6599724756652689b76ef942ffaca2f8d5863ff3dd7fe7ac655199e0968d
|
3 |
-
size 136009773
|
|
|
|
|
|
|
|
bioscan_5m_image_IndexFlatIP.index
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:01e5d74fd5194551e2b8e43aba8e41153efeb29589fa82a7839791d2e057c21d
|
3 |
-
size 136009773
|
|
|
|
|
|
|
|
data.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pickle
|
3 |
+
from typing import Any
|
4 |
+
|
5 |
+
import faiss
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
+
|
8 |
+
|
9 |
+
def load_indexes_local(index_files: dict[str, str], *, data_folder: str, **kw) -> dict[str, Any]:
|
10 |
+
indexes = {}
|
11 |
+
for index_type, index_file in index_files.items():
|
12 |
+
indexes[index_type] = faiss.read_index(os.path.join(data_folder, index_file))
|
13 |
+
|
14 |
+
return indexes
|
15 |
+
|
16 |
+
|
17 |
+
def load_indexes_hf(index_files: dict[str, str], *, repo_name: str, **kw) -> dict[str, Any]:
|
18 |
+
indexes = {}
|
19 |
+
for index_type, index_file in index_files.items():
|
20 |
+
indexes[index_type] = faiss.read_index(
|
21 |
+
hf_hub_download(repo_id=repo_name, filename=index_file, repo_type="dataset")
|
22 |
+
)
|
23 |
+
|
24 |
+
return indexes
|
25 |
+
|
26 |
+
|
27 |
+
def load_index_pickle(index_file: str, repo_name: str) -> Any:
|
28 |
+
index_to_id_dict_file = hf_hub_download(
|
29 |
+
repo_id=repo_name,
|
30 |
+
filename=index_file,
|
31 |
+
repo_type="dataset",
|
32 |
+
)
|
33 |
+
with open(index_to_id_dict_file, "rb") as f:
|
34 |
+
return pickle.load(f)
|
prepare_index.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from pathlib import Path
|
2 |
|
3 |
import click
|
@@ -9,55 +10,33 @@ ALL_INDEX_TYPES = ["IndexFlatIP", "IndexFlatL2", "IndexIVFFlat", "IndexHNSWFlat"
|
|
9 |
EMBEDDING_SIZE = 768
|
10 |
|
11 |
|
12 |
-
def process(
|
13 |
# load embeddings
|
14 |
-
|
15 |
-
f"encoded_{key_type}_feature"
|
16 |
-
][:]
|
17 |
-
seen_test = h5py.File(input / "extracted_features_of_seen_test.hdf5", "r", libver="latest")[
|
18 |
-
f"encoded_{key_type}_feature"
|
19 |
-
][:]
|
20 |
-
unseen_test = h5py.File(input / "extracted_features_of_unseen_test.hdf5", "r", libver="latest")[
|
21 |
-
f"encoded_{key_type}_feature"
|
22 |
-
][:]
|
23 |
-
seen_val = h5py.File(input / "extracted_features_of_seen_val.hdf5", "r", libver="latest")[
|
24 |
-
f"encoded_{key_type}_feature"
|
25 |
-
][:]
|
26 |
-
unseen_val = h5py.File(input / "extracted_features_of_unseen_val.hdf5", "r", libver="latest")[
|
27 |
-
f"encoded_{key_type}_feature"
|
28 |
-
][:]
|
29 |
|
30 |
# FlatIP and FlatL2
|
31 |
if index_type == "IndexFlatIP":
|
32 |
-
test_index = faiss.IndexFlatIP(
|
33 |
elif index_type == "IndexFlatL2":
|
34 |
-
test_index = faiss.IndexFlatL2(
|
35 |
elif index_type == "IndexIVFFlat":
|
36 |
# IVFFlat
|
37 |
-
quantizer = faiss.IndexFlatIP(
|
38 |
-
test_index = faiss.IndexIVFFlat(quantizer,
|
39 |
-
test_index.train(
|
40 |
-
test_index.train(seen_test)
|
41 |
-
test_index.train(unseen_test)
|
42 |
-
test_index.train(seen_val)
|
43 |
-
test_index.train(unseen_val)
|
44 |
elif index_type == "IndexHNSWFlat":
|
45 |
# HNSW
|
46 |
# 16: connections for each vertex. efSearch: depth of search during search. efConstruction: depth of search during build
|
47 |
-
test_index = faiss.IndexHNSWFlat(
|
48 |
test_index.hnsw.efSearch = 32
|
49 |
test_index.hnsw.efConstruction = 64
|
50 |
elif index_type == "IndexLSH":
|
51 |
# LSH
|
52 |
-
test_index = faiss.IndexLSH(
|
53 |
else:
|
54 |
raise ValueError(f"Index type {index_type} is not supported")
|
55 |
|
56 |
-
test_index.add(
|
57 |
-
test_index.add(seen_test)
|
58 |
-
test_index.add(unseen_test)
|
59 |
-
test_index.add(seen_val)
|
60 |
-
test_index.add(unseen_val)
|
61 |
|
62 |
faiss.write_index(test_index, str(output / f"bioscan_5m_{key_type}_{index_type}.index"))
|
63 |
print("Saved index to", output / f"bioscan_5m_{key_type}_{index_type}.index")
|
@@ -96,9 +75,15 @@ def main(input, output, key_type, index_type):
|
|
96 |
else:
|
97 |
index_types = [index_type]
|
98 |
|
|
|
99 |
for key_type in key_types:
|
100 |
for index_type in index_types:
|
101 |
-
process(
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
|
104 |
if __name__ == "__main__":
|
|
|
1 |
+
import pickle
|
2 |
from pathlib import Path
|
3 |
|
4 |
import click
|
|
|
10 |
EMBEDDING_SIZE = 768
|
11 |
|
12 |
|
13 |
+
def process(embedding_data, output: Path, key_type: str, index_type: str):
|
14 |
# load embeddings
|
15 |
+
embeddings = embedding_data[f"encoded_{key_type}_feature"][:]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
# FlatIP and FlatL2
|
18 |
if index_type == "IndexFlatIP":
|
19 |
+
test_index = faiss.IndexFlatIP(embeddings.shape[-1])
|
20 |
elif index_type == "IndexFlatL2":
|
21 |
+
test_index = faiss.IndexFlatL2(embeddings.shape[-1])
|
22 |
elif index_type == "IndexIVFFlat":
|
23 |
# IVFFlat
|
24 |
+
quantizer = faiss.IndexFlatIP(embeddings.shape[-1])
|
25 |
+
test_index = faiss.IndexIVFFlat(quantizer, embeddings.shape[-1], 128)
|
26 |
+
test_index.train(embeddings)
|
|
|
|
|
|
|
|
|
27 |
elif index_type == "IndexHNSWFlat":
|
28 |
# HNSW
|
29 |
# 16: connections for each vertex. efSearch: depth of search during search. efConstruction: depth of search during build
|
30 |
+
test_index = faiss.IndexHNSWFlat(embeddings.shape[-1])
|
31 |
test_index.hnsw.efSearch = 32
|
32 |
test_index.hnsw.efConstruction = 64
|
33 |
elif index_type == "IndexLSH":
|
34 |
# LSH
|
35 |
+
test_index = faiss.IndexLSH(embeddings.shape[-1], embeddings.shape[-1] * 2)
|
36 |
else:
|
37 |
raise ValueError(f"Index type {index_type} is not supported")
|
38 |
|
39 |
+
test_index.add(embeddings)
|
|
|
|
|
|
|
|
|
40 |
|
41 |
faiss.write_index(test_index, str(output / f"bioscan_5m_{key_type}_{index_type}.index"))
|
42 |
print("Saved index to", output / f"bioscan_5m_{key_type}_{index_type}.index")
|
|
|
75 |
else:
|
76 |
index_types = [index_type]
|
77 |
|
78 |
+
embedding_data = h5py.File(input / "extracted_features_for_all_5m_data.hdf5", "r", libver="latest")
|
79 |
for key_type in key_types:
|
80 |
for index_type in index_types:
|
81 |
+
process(embedding_data, output, key_type, index_type)
|
82 |
+
|
83 |
+
sample_ids = [raw_id.decode("utf-8") for raw_id in embedding_data["file_name_list"][:]]
|
84 |
+
index_to_id = {index: id for index, id in enumerate(sample_ids)}
|
85 |
+
with open(output / "big_indx_to_id_dict.pickle", "wb") as f:
|
86 |
+
pickle.dump(index_to_id, f)
|
87 |
|
88 |
|
89 |
if __name__ == "__main__":
|