"""Gradio demo: k-nearest-neighbour bird classification on the CUB dataset.

At import time this module downloads its data files from Google Drive,
extracts the CUB training images, loads the pickled training embeddings and
labels, builds a similarity-search index over them, and constructs a Gradio
interface (``demo``).

For each uploaded image, the interface does two things:

* It reports the most common labels among the image's nearest training
  neighbours.
* It shows a gallery of the nearest neighbours that carry the winning label.
"""
import csv
import sys
import pickle
from collections import Counter

import numpy as np
import gradio as gr
import gdown
import torchvision
from torchvision.datasets import ImageFolder

from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
from ExtractEmbedding import QueryToEmbedding

# Lift the csv module's field-size limit. Nothing in this file reads CSV, so
# this is presumably needed by a dependency.
# NOTE(review): confirm it is still needed. On platforms where a C long is
# 32-bit (e.g. Windows), sys.maxsize may overflow here.
csv.field_size_limit(sys.maxsize)


def concat(x):
    """Concatenate a sequence of arrays along the first axis."""
    return np.concatenate(x, axis=0)


# Download data files by Google Drive ID. These are presumably the
# embeddings.pickle / labels.pickle files loaded below; confirm.
gdown.download(id="116CiA_cXciGSl72tbAUDoN-f1B9Frp89")
gdown.download(id="1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e")

# CUB training set (md5-checked so a corrupt or partial download is detected).
gdown.cached_download(
    url="https://drive.google.com/uc?id=1iR19j7532xqPefWYT-BdtcaKnsEokIqo",
    path="./CUB_train.zip",
    quiet=False,
    md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
)

# Extract the archive into ./Training/ and keep the zip (remove_finished=False).
torchvision.datasets.utils.extract_archive(
    from_path="CUB_train.zip",
    to_path="Training/",
    remove_finished=False,
)

# Load the training embeddings and labels and build the search index.
# NOTE: pickle.load can execute arbitrary code. These files are only as
# trustworthy as the fixed Google Drive uploads they were fetched from.
with open("./embeddings.pickle", "rb") as f:
    Xtrain = pickle.load(f)
# FIXME: re-run the code to get the embeddings in the right format
with open("./labels.pickle", "rb") as f:
    ytrain = pickle.load(f)

searcher = SearchableTrainingSet(Xtrain, ytrain)
searcher.build_index()

# Map class index -> human-readable bird name. The name is the class folder
# name with "." replaced by " ". class_to_idx is keyed by the folder name
# itself, so this does not depend on the OS path separator.
training_folder = ImageFolder(root="./Training/train/")
id_to_bird_name = {
    class_idx: folder_name.replace(".", " ")
    for folder_name, class_idx in training_folder.class_to_idx.items()
}

# Number of neighbours requested from the index for each query.
NUM_NEIGHBORS = 50
# Only the closest VOTE_POOL neighbours vote on the label. Vote counts are
# divided by this value to form the reported confidences.
VOTE_POOL = 20
# How many labels are reported, and how many gallery images are shown.
TOP_K = 5


def search(query_imag, searcher=searcher):
    """Classify ``query_imag`` by majority vote over its nearest neighbours.

    Args:
        query_imag: Query image as supplied by Gradio (``gr.Image(type="pil")``).
        searcher: Index to query; defaults to the module-level training index.

    Returns:
        A ``(predicted_labels, gallery_images)`` tuple:

        * ``predicted_labels`` maps up to ``TOP_K`` bird names to the fraction
          of the ``VOTE_POOL`` nearest neighbours that carry that label.
        * ``gallery_images`` holds file paths of up to ``TOP_K`` of those
          neighbours whose label is the top-voted one.
    """
    query_embedding = QueryToEmbedding(query_imag)
    # NOTE(review): names kept from the original code. The second value is
    # used below as a row index into training_folder.imgs, so it appears to
    # hold neighbour *indices* rather than similarity scores. Confirm the
    # return order of SearchableTrainingSet.search.
    indices, scores, labels = searcher.search(query_embedding, k=NUM_NEIGHBORS)

    nearest_labels = labels[0][:VOTE_POOL]
    nearest_ids = scores[0][:VOTE_POOL]

    result_ctr = Counter(nearest_labels).most_common(TOP_K)
    top1_label = result_ctr[0][0]

    # Keep the neighbours that agree with the winning label, in rank order.
    top_indices = [
        neighbor_id
        for label, neighbor_id in zip(nearest_labels, nearest_ids)
        if label == top1_label
    ]
    gallery_images = [
        training_folder.imgs[int(neighbor_id)][0]
        for neighbor_id in top_indices[:TOP_K]
    ]
    predicted_labels = {
        id_to_bird_name[label]: count / VOTE_POOL for label, count in result_ctr
    }
    return predicted_labels, gallery_images


demo = gr.Interface(
    search,
    gr.Image(type="pil"),
    ["label", "gallery"],
    examples=[["./examples/bird.jpg"]],
    description="WIP - kNN on CUB dataset",
    title="Work in Progress - CHM-Corr",
)
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not when imported.
    demo.launch()