"""Gradio demo: kNN retrieval on the CUB training set followed by CHM-Corr.

Importing this module has side effects: it downloads the embeddings, the CUB
training archive and the CHM weights with ``gdown``, extracts the archive into
``data/``, loads the pickled embeddings/labels and builds the search index.
"""
import csv
import sys
import pickle
from collections import Counter

import numpy as np
import gradio as gr
import gdown
import torchvision
from torchvision.datasets import ImageFolder

from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
from ExtractEmbedding import QueryToEmbedding
from CHMCorr import chm_classify_and_visualize
from visualization import plot_from_reranker_output

csv.field_size_limit(sys.maxsize)

# Number of neighbours requested from the index; all of them are handed to
# the CHM re-ranker as the support set.
NUM_NEIGHBORS = 50
# Number of nearest neighbours that take part in the majority vote.
NUM_VOTERS = 20
# Number of classes reported in the label output.
NUM_TOP_CLASSES = 5
# Maximum number of images of the winning class shown in the gallery.
NUM_GALLERY_IMAGES = 5


def concat(x):
    """Concatenate a sequence of arrays along axis 0."""
    return np.concatenate(x, axis=0)


# Embeddings
gdown.cached_download(
    url="https://drive.google.com/uc?id=116CiA_cXciGSl72tbAUDoN-f1B9Frp89",
    path="./embeddings.pkl",
    quiet=False,
    md5="002b2a7f5c80d910b9cc740c2265f058",
)

# No output path is given, so gdown picks the file name itself
# (presumably the labels pickle read below -- TODO confirm).
gdown.download(id="1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e")

# CUB training set
gdown.cached_download(
    url="https://drive.google.com/uc?id=1iR19j7532xqPefWYT-BdtcaKnsEokIqo",
    path="./CUB_train.zip",
    quiet=False,
    md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
)

# EXTRACT training set
torchvision.datasets.utils.extract_archive(
    from_path="CUB_train.zip",
    to_path="data/",
    remove_finished=False,
)

# CHM Weights
gdown.cached_download(
    url="https://drive.google.com/u/0/uc?id=1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6&export=download",
    path="pas_psi.pt",
    quiet=False,
    md5="6b7b4d7bad7f89600fac340d6aa7708b",
)

# Load the training embeddings and labels, then build the search index.
# NOTE(review): the embeddings download above writes ./embeddings.pkl, but the
# file opened here is ./embeddings.pickle -- confirm which name is intended.
# NOTE(security): pickle.load executes code embedded in the file; only use
# these files from a trusted source.
with open("./embeddings.pickle", "rb") as f:
    Xtrain = pickle.load(f)
# FIXME: re-run the code to get the embeddings in the right format
with open("./labels.pickle", "rb") as f:
    ytrain = pickle.load(f)

searcher = SearchableTrainingSet(Xtrain, ytrain)
searcher.build_index()

# Extract label names
training_folder = ImageFolder(root="./data/train/")
# Map class index -> readable name, taken from the class directory names that
# ImageFolder already collected. This avoids splitting file paths on "/",
# which is not portable across path separators.
id_to_bird_name = {
    class_idx: class_dir.replace(".", " ")
    for class_dir, class_idx in training_folder.class_to_idx.items()
}


def search(query_image, searcher=searcher):
    """Classify a query image by kNN vote and re-rank it with CHM-Corr.

    Args:
        query_image: Query passed to ``QueryToEmbedding`` and to
            ``chm_classify_and_visualize``; the Gradio input below is
            configured with ``type="filepath"``.
        searcher: Index object exposing ``search(embedding, k=...)``;
            defaults to the module-level index built at import time.

    Returns:
        A tuple ``(predicted_labels, gallery_images, viz_plot)``:
        ``predicted_labels`` maps a bird name to its share of the
        ``NUM_VOTERS`` votes, ``gallery_images`` lists up to
        ``NUM_GALLERY_IMAGES`` training image paths of the winning class, and
        ``viz_plot`` is the output of ``plot_from_reranker_output``.
    """
    query_embedding = QueryToEmbedding(query_image)
    scores, indices, labels = searcher.search(query_embedding, k=NUM_NEIGHBORS)

    # Row 0 holds the results for the (single) query.
    voter_labels = labels[0][:NUM_VOTERS]
    voter_indices = indices[0][:NUM_VOTERS]

    # Majority vote among the nearest neighbours.
    result_ctr = Counter(voter_labels).most_common(NUM_TOP_CLASSES)
    top1_label, top1_votes = result_ctr[0]

    # Training-set indices of the voters that agree with the winning class.
    top_indices = [
        index
        for label, index in zip(voter_labels, voter_indices)
        if label == top1_label
    ]

    gallery_images = [
        training_folder.imgs[int(index)][0]
        for index in top_indices[:NUM_GALLERY_IMAGES]
    ]
    predicted_labels = {
        id_to_bird_name[label]: votes / float(NUM_VOTERS)
        for label, votes in result_ctr
    }

    print("gallery_images:", gallery_images)

    # CHM Prediction
    kNN_results = (top1_label, top1_votes, gallery_images)

    # Every retrieved neighbour (not only the voters) is used as support.
    support_files = [training_folder.imgs[int(index)][0] for index in indices[0]]
    print(support_files)
    support_labels = [training_folder.imgs[int(index)][1] for index in indices[0]]
    print(support_labels)

    support = [support_files, support_labels]

    chm_output = chm_classify_and_visualize(
        query_image, kNN_results, support, training_folder
    )

    viz_plot = plot_from_reranker_output(chm_output, draw_arcs=False)

    return predicted_labels, gallery_images, viz_plot


demo = gr.Interface(
    search,
    gr.Image(type="filepath"),
    ["label", "gallery", "plot"],
    examples=[["./examples/bird.jpg"]],
    description="WIP - kNN on CUB dataset",
    title="Work in Progress - CHM-Corr",
)

if __name__ == "__main__":
    demo.launch()