File size: 2,185 Bytes
c4a8d1c
 
bbd199b
 
 
 
 
 
 
 
 
 
 
c4a8d1c
 
bbd199b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import csv
import sys
import pickle
from collections import Counter
import numpy as np
import gradio as gr
import gdown
import torchvision
from torchvision.datasets import ImageFolder

from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
from ExtractEmbedding import QueryToEmbedding

# Lift the csv module's per-field size cap as far as the platform allows.
# ``sys.maxsize`` does not fit in a C long where that type is 32 bits
# (notably 64-bit Windows) and the call raises OverflowError there, so fall
# back to the largest signed 32-bit value in that case.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)


def concat(x):
    """Concatenate a sequence of NumPy arrays along axis 0."""
    return np.concatenate(x, axis=0)

# Fetch the two pre-computed artefacts stored on Google Drive by file id.
# No output path is passed, so gdown picks the local filename.
# NOTE(review): presumably these are the embeddings.pickle / labels.pickle
# files loaded further down — confirm.
for _drive_file_id in (
    "116CiA_cXciGSl72tbAUDoN-f1B9Frp89",
    "1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e",
):
    gdown.download(id=_drive_file_id)

# CUB training-set archive, fetched with gdown's cached download and the
# archive's expected md5.
gdown.cached_download(
    url="https://drive.google.com/uc?id=1iR19j7532xqPefWYT-BdtcaKnsEokIqo",
    path="./CUB_train.zip",
    md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
    quiet=False,
)

# Unpack the archive into Training/ and keep the .zip on disk afterwards.
torchvision.datasets.utils.extract_archive(
    to_path="Training/",
    from_path="CUB_train.zip",
    remove_finished=False,
)


# Load the pre-computed training-set embeddings and their labels, then build
# the search index that ``search`` queries.
# SECURITY NOTE(review): ``pickle.load`` can execute arbitrary code embedded in
# the file. These pickles are presumably the files fetched from Google Drive
# earlier in this script — only load them from a source you trust.
with open("./embeddings.pickle", "rb") as f:
    Xtrain = pickle.load(f)
# FIXME: re-run the code to get the embeddings in the right format
with open("./labels.pickle", "rb") as f:
    ytrain = pickle.load(f)

searcher = SearchableTrainingSet(Xtrain, ytrain)
searcher.build_index()

# Map each class index to a human-readable bird name: the class folder name
# with every "." replaced by a space.
# ``class_to_idx`` is ImageFolder's own folder-name -> class-index mapping,
# i.e. exactly the indices stored in ``training_folder.imgs``. Using it avoids
# splitting file paths on "/" (not the separator on every OS) and is also
# correct for images nested in sub-folders of a class directory.
training_folder = ImageFolder(root="./Training/train/")
id_to_bird_name = {
    class_idx: class_dir.replace(".", " ")
    for class_dir, class_idx in training_folder.class_to_idx.items()
}


def search(
    query_imag,
    searcher=searcher,
    *,
    k=50,
    n_votes=20,
    n_classes=5,
    n_gallery=5,
):
    """Predict the bird class of a query image by a nearest-neighbour vote.

    The query is embedded with ``QueryToEmbedding`` and ``searcher`` is asked
    for ``k`` results (presumably nearest neighbours, nearest first — TODO
    confirm against SimSearch). The labels of the first ``n_votes`` results
    are counted; the ``n_classes`` most frequent labels are reported with
    their share of the votes, and up to ``n_gallery`` training images of the
    winning label are returned for display.

    Args:
        query_imag: Query image as delivered by the ``gr.Image(type="pil")``
            input.
        searcher: Index to query; defaults to the module-level training set.
        k: Number of results requested from the index.
        n_votes: How many of the returned results take part in the vote.
        n_classes: How many top labels to report.
        n_gallery: Maximum number of gallery images to return.

    Returns:
        ``(predicted_labels, gallery_images)``: a dict mapping bird name to
        vote share in ``[0, 1]``, and a list of training-image file paths.
        Both are empty if there were no votes to count.

    Raises:
        ValueError: if no query image was supplied (e.g. the form was
            submitted without an image).
    """
    if query_imag is None:
        raise ValueError("No query image was provided.")

    query_embedding = QueryToEmbedding(query_imag)
    # The first return value is unused. The second is used as row positions
    # into the training set (it indexes ``training_folder.imgs`` below) and
    # the third holds the matching class labels. Row 0 is our single query.
    _, neighbor_ids, neighbor_labels = searcher.search(query_embedding, k=k)
    vote_labels = neighbor_labels[0][:n_votes]
    vote_ids = neighbor_ids[0][:n_votes]

    # (label, count) pairs, most frequent first; equal counts keep the order
    # in which the labels were first encountered.
    vote_counts = Counter(vote_labels).most_common(n_classes)
    if not vote_counts:
        # Nothing to vote on (e.g. k, n_votes or n_classes is 0).
        return {}, []

    winner = vote_counts[0][0]
    winner_ids = [idx for lbl, idx in zip(vote_labels, vote_ids) if lbl == winner]
    gallery_images = [
        training_folder.imgs[int(idx)][0] for idx in winner_ids[:n_gallery]
    ]

    # Normalise by the votes actually cast; this equals n_votes whenever the
    # index returned at least that many results.
    n_cast = len(vote_labels)
    predicted_labels = {
        id_to_bird_name[label]: count / n_cast for label, count in vote_counts
    }

    return predicted_labels, gallery_images


# Web UI: one image in; the predicted-label distribution and a gallery of
# matching training images out.
demo = gr.Interface(
    fn=search,
    inputs=gr.Image(type="pil"),
    outputs=["label", "gallery"],
    title="Work in Progress - CHM-Corr",
    description="WIP - kNN on CUB dataset",
    examples=[["./examples/bird.jpg"]],
)

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()