"""Gradio demo: nearest-neighbour search between image and DNA embeddings.

Given a sample ID, the app looks up that sample's embedding in one modality
("Search From"), queries a FAISS index built over the other or the same
modality ("Search To"), and shows the images and taxonomy of the query sample
and of the ten closest matches.
"""

import io
import pickle
import random

# Third-party imports are kept in their original relative order.
import gradio as gr
import torch
import numpy as np
import h5py
import faiss
from PIL import Image

# Indices below this value are read from the first image array; the remaining
# indices are read from the second array after subtracting this offset.
IMAGE_SPLIT_INDEX = 162834

# Exclusive upper bound used when drawing a random sample index.
# NOTE(review): presumably the total number of samples -- confirm against the data.
NUM_SAMPLES = 325667

# Number of nearest neighbours returned per search; the UI has exactly this
# many result cells.
NUM_NEIGHBORS = 10

# FAISS index file for each "Search To" modality.
INDEX_FILES = {"Image": "image_index.index", "DNA": "dna_index.index"}


def _load_pickle(path):
    """Return the object stored in the pickle file at *path*."""
    # NOTE: pickle can execute arbitrary code while loading; these files must
    # only ever be trusted assets shipped with the app.
    with open(path, "rb") as f:
        return pickle.load(f)


# General dataset files.
dataset_image1 = _load_pickle("dataset_image1.pickle")
dataset_image2 = _load_pickle("dataset_image2.pickle")
dataset_processid_list = _load_pickle("dataset_processid_list.pickle")
dataset_image_mask = _load_pickle("dataset_image_mask.pickle")
processid_to_index = _load_pickle("processid_to_index.pickle")
indx_to_id_dict = _load_pickle("indx_to_id_dict.pickle")

# Per-sample embeddings, keyed by sample ID.
id_to_image_emb_dict = _load_pickle("id_to_image_emb_dict.pickle")
id_to_dna_emb_dict = _load_pickle("id_to_dna_emb_dict.pickle")

# Taxonomy labels are stored as bytes; decode them once at start-up.
family = [item.decode("utf-8") for item in _load_pickle("family.pickle")]
genus = [item.decode("utf-8") for item in _load_pickle("genus.pickle")]
species = [item.decode("utf-8") for item in _load_pickle("species.pickle")]


def get_image(image1, image2, dataset_image_mask, processid_to_index, idx):
    """Decode and return the image stored at dataset index *idx*.

    Args:
        image1: Array of padded, encoded images for indices below
            ``IMAGE_SPLIT_INDEX``.
        image2: Array of padded, encoded images for the remaining indices,
            addressed by ``idx - IMAGE_SPLIT_INDEX``.
        dataset_image_mask: Per-index length of the encoded image bytes
            inside the padded row.
        processid_to_index: Unused; kept so existing callers keep working.
        idx: Dataset index of the image to decode.

    Returns:
        The ``PIL.Image.Image`` opened from the encoded bytes.
    """
    if idx < IMAGE_SPLIT_INDEX:
        image_enc_padded = image1[idx].astype(np.uint8)
    else:
        image_enc_padded = image2[idx - IMAGE_SPLIT_INDEX].astype(np.uint8)

    # Rows are padded to a common length; keep only the real encoded bytes.
    enc_length = dataset_image_mask[idx]
    image_enc = image_enc_padded[:enc_length]
    return Image.open(io.BytesIO(image_enc))


def getTax(indx):
    """Return a three-line species/genus/family summary for index *indx*."""
    return (
        f"Species: {species[indx]}\n"
        f"Genus: {genus[indx]}\n"
        f"Family: {family[indx]}"
    )


def getRandID():
    """Return a random ``(sample_id, index)`` pair drawn from the dataset."""
    indx = random.randrange(0, NUM_SAMPLES)
    return indx_to_id_dict[indx], indx


def searchEmbeddings(id, mod1, mod2):  # noqa: A002 - name kept for callers
    """Find the nearest neighbours of a sample across modalities.

    Args:
        id: Sample ID entered by the user; surrounding whitespace is ignored.
        mod1: Modality of the query embedding, ``"Image"`` or ``"DNA"``.
        mod2: Modality of the index to search, ``"Image"`` or ``"DNA"``.

    Returns:
        A flat tuple of 23 values matching the UI outputs: the list of
        neighbour IDs, then 11 images (query sample followed by its 10
        neighbours), then the 11 corresponding taxonomy strings.

    Raises:
        gr.Error: If either modality is not selected or the sample ID is not
            present in the dataset.
    """
    sample_id = (id or "").strip()

    # An unselected radio button arrives as None; fail with a readable
    # message instead of searching an empty index or hitting an unbound name.
    if mod1 not in INDEX_FILES or mod2 not in INDEX_FILES:
        raise gr.Error('Select both "Search From" and "Search To" before searching.')

    embeddings_by_modality = {
        "Image": id_to_image_emb_dict,
        "DNA": id_to_dna_emb_dict,
    }
    try:
        original_indx = processid_to_index[sample_id]
        embedding = embeddings_by_modality[mod1][sample_id]
    except KeyError:
        raise gr.Error(f"Unknown sample ID: {sample_id!r}") from None

    # NOTE(review): the index is re-read from disk on every search. Caching
    # it would be faster but keeps the whole index resident in memory.
    index = faiss.read_index(INDEX_FILES[mod2])

    # The search call expects a 2-D float32 batch of query vectors; promote
    # a single 1-D vector to a batch of one (no-op if it is already 2-D).
    query = np.atleast_2d(embedding).astype(np.float32)
    _scores, neighbors = index.search(query, NUM_NEIGHBORS)
    neighbor_indices = [int(i) for i in neighbors[0]]

    id_list = [indx_to_id_dict[i] for i in neighbor_indices]

    # Query sample first, then its neighbours in ranked order.
    all_indices = [original_indx, *neighbor_indices]
    images = [
        get_image(
            dataset_image1,
            dataset_image2,
            dataset_image_mask,
            processid_to_index,
            i,
        )
        for i in all_indices
    ]
    taxonomy = [getTax(i) for i in all_indices]

    return (id_list, *images, *taxonomy)


def _result_cell(label):
    """Create one result column (image + taxonomy box); return both components."""
    with gr.Column():
        image = gr.Image(label=label)
        tax = gr.Textbox(label="Taxonomy")
    return image, tax


# NOTE(review): the layout nesting below was reconstructed from a flattened
# source; in particular, confirm whether the 10th result belongs inside the
# last row or in its own top-level column as written here.
with gr.Blocks(title="Bioscan-Clip") as demo:
    with gr.Column():
        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        process_id_list = gr.Textbox(label="Closest 10 matches:")
        mod1 = gr.Radio(choices=["DNA", "Image"], label="Search From:")
        mod2 = gr.Radio(choices=["DNA", "Image"], label="Search To:")
        search_btn = gr.Button("Search")

    with gr.Row():
        with gr.Column():
            image0 = gr.Image(label="Original", height=550)
            tax0 = gr.Textbox(label="Taxonomy")
        with gr.Column():
            rand_id = gr.Textbox(label="Random ID:")
            rand_id_indx = gr.Textbox(label="Index:")
            id_btn = gr.Button("Get Random ID")

    # Results 1-9 are laid out as three rows of three cells.
    result_images = []
    result_taxes = []
    for row_start in (1, 4, 7):
        with gr.Row():
            for label in range(row_start, row_start + 3):
                cell_image, cell_tax = _result_cell(label)
                result_images.append(cell_image)
                result_taxes.append(cell_tax)

    # Result 10 gets its own column after the grid.
    cell_image, cell_tax = _result_cell(10)
    result_images.append(cell_image)
    result_taxes.append(cell_tax)

    id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])

    # Output order must match the tuple returned by searchEmbeddings.
    search_btn.click(
        fn=searchEmbeddings,
        inputs=[process_id, mod1, mod2],
        outputs=[
            process_id_list,
            image0,
            *result_images,
            tax0,
            *result_taxes,
        ],
    )

    examples = gr.Examples(
        examples=[
            ["ABOTH966-22", "DNA", "DNA"],
            ["CRTOB8472-22", "DNA", "Image"],
            ["PLOAD050-20", "Image", "DNA"],
            ["HELAC26711-21", "Image", "Image"],
        ],
        inputs=[process_id, mod1, mod2],
    )

demo.launch()