browser-backend / app.py
atwang's picture
update code to download dataset files from separate repo
07356cd
import pickle
import random
import gradio as gr
import numpy as np
from data import load_indexes_local, load_indexes_hf, load_index_pickle
def getRandID():
    """Pick one entry of ``index_to_id_dict`` at random.

    Returns:
        A ``(sample_id, index)`` tuple: the value stored under a randomly
        drawn integer key in ``[0, len(index_to_id_dict))``, followed by
        that key.
    """
    entry_count = len(index_to_id_dict)
    position = random.randrange(0, entry_count)
    sample_id = index_to_id_dict[position]
    return sample_id, position
def get_image_index(indexType):
    """Look up the image index registered under ``indexType``.

    Args:
        indexType: Key into the module-level ``image_indexes`` mapping.

    Returns:
        The object stored in ``image_indexes`` for that key.

    Raises:
        KeyError: If ``image_indexes`` has no entry for ``indexType``; the
            message names the unsupported index type.
    """
    try:
        selected = image_indexes[indexType]
    except KeyError:
        message = f"Tried to load an image index that is not supported: {indexType}"
        raise KeyError(message)
    return selected
def get_dna_index(indexType):
    """Look up the DNA index registered under ``indexType``.

    Args:
        indexType: Key into the module-level ``dna_indexes`` mapping.

    Returns:
        The object stored in ``dna_indexes`` for that key.

    Raises:
        KeyError: If ``dna_indexes`` has no entry for ``indexType``; the
            message names the unsupported index type.
    """
    try:
        found = dna_indexes[indexType]
    except KeyError:
        # Re-raise with a message that says which kind of index was missing.
        raise KeyError(
            f"Tried to load a DNA index that is not supported: {indexType}"
        )
    else:
        return found
def searchEmbeddings(id, key_type, query_type, index_type, num_results: int = 10):
    """Return the IDs of the samples nearest to a given sample's embedding.

    The vector stored for ``id`` in the ``query_type`` index is reconstructed
    and then used to search the ``key_type`` index, so a lookup can stay
    within one modality or cross between image and DNA.

    Args:
        id: Sample ID used as the query; must be a key of
            ``id_to_index_dict``. (The name shadows the builtin but is part
            of the public signature, so it is kept.)
        key_type: ``"Image"`` or ``"DNA"`` - which index is searched.
        query_type: ``"Image"`` or ``"DNA"`` - which index the query vector
            is read from.
        index_type: Name of the index variant, resolved through
            ``get_image_index`` and ``get_dna_index``.
        num_results: Number of neighbours to request; must be at least 1.

    Returns:
        A list of sample IDs in the order the index returned them. It can
        be shorter than ``num_results`` if the index reports fewer matches.

    Raises:
        KeyError: If ``index_type`` is not a loaded index, or ``id`` is not a
            known sample ID.
        ValueError: If ``query_type`` or ``key_type`` is not ``"Image"`` or
            ``"DNA"``, or ``num_results`` is missing or below 1.
    """
    image_index = get_image_index(index_type)
    dna_index = get_dna_index(index_type)
    indexes_by_modality = {"Image": image_index, "DNA": dna_index}

    # Validate every selector up front so a bad request fails with a clear
    # message before any index work is done.
    if query_type not in indexes_by_modality:
        raise ValueError(f"Invalid query type: {query_type}")
    if key_type not in indexes_by_modality:
        raise ValueError(f"Invalid key type: {key_type}")
    # A numeric UI field may hand over None when it is left empty (verify
    # against the caller); zero or negative counts make no sense either.
    if num_results is None or num_results < 1:
        raise ValueError(f"Invalid number of results: {num_results}")

    try:
        query_position = id_to_index_dict[id]
    except KeyError:
        raise KeyError(f"Unknown sample ID: {id}") from None

    # Rebuild the stored vector and shape it as a single-row float32 batch.
    query = indexes_by_modality[query_type].reconstruct(query_position)
    query = np.expand_dims(query.astype(np.float32), axis=0)

    _, neighbour_labels = indexes_by_modality[key_type].search(query, int(num_results))

    # Presumably these are FAISS indexes (they expose reconstruct/search):
    # FAISS pads the label row with -1 when fewer than k neighbours exist,
    # and -1 is not a valid key of index_to_id_dict, so skip those slots.
    return [index_to_id_dict[label] for label in neighbour_labels[0] if label != -1]
# Build the Gradio UI and load the data the callbacks rely on.
# NOTE(review): the source this was reconstructed from had lost its
# indentation; the nesting of the layout containers below is the most
# plausible reading - confirm against the deployed app.
with gr.Blocks() as demo:
    # for hf: change all file paths, indx_to_id_dict as well
    # Load the indexes once at start-up. These names are module-level
    # globals read by getRandID, get_image_index, get_dna_index and
    # searchEmbeddings.
    image_indexes = load_indexes_hf(
        {"FlatIP(default)": "bioscan_5m_image_IndexFlatIP.index"}, repo_name="bioscan-ml/bioscan-clibd"
    )
    dna_indexes = load_indexes_hf(
        {"FlatIP(default)": "bioscan_5m_dna_IndexFlatIP.index"}, repo_name="bioscan-ml/bioscan-clibd"
    )
    index_to_id_dict = load_index_pickle("big_indx_to_id_dict.pickle", repo_name="bioscan-ml/bioscan-clibd")
    # Reverse mapping (sample ID -> index position) for query lookups.
    id_to_index_dict = {v: k for k, v in index_to_id_dict.items()}

    # Offer only index types that were actually loaded for both modalities.
    # Previously the radio also listed "FlatL2", "HNSWFlat", "IVFFlat" and
    # "LSH", none of which are loaded here, so selecting them always failed
    # with a KeyError in the search callback.
    # Assumes load_indexes_hf returns a mapping keyed by the names passed in
    # (get_image_index / get_dna_index subscript it that way) - TODO confirm.
    loaded_index_types = [name for name in image_indexes if name in dna_indexes]

    with gr.Column():
        with gr.Row():
            # Left column: fetch a random sample ID to try out.
            with gr.Column():
                rand_id = gr.Textbox(label="Random ID:")
                rand_id_indx = gr.Textbox(label="Index:")
                id_btn = gr.Button("Get Random ID")
            # Right column: search settings.
            with gr.Column():
                query_type = gr.Radio(choices=["Image", "DNA"], label="Query:", value="Image")
                key_type = gr.Radio(choices=["Image", "DNA"], label="Key:", value="Image")
                index_type = gr.Radio(choices=loaded_index_types, label="Index:", value="FlatIP(default)")
                num_results = gr.Number(label="Number of Results:", value=10, precision=0)

        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        process_id_list = gr.Textbox(label="Closest matches:")
        search_btn = gr.Button("Search")

    # Event wiring; the inputs are passed positionally, so their order must
    # match the parameter order of searchEmbeddings.
    id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
    search_btn.click(
        fn=searchEmbeddings,
        inputs=[process_id, key_type, query_type, index_type, num_results],
        outputs=[process_id_list],
    )

demo.launch()