initial commit
- ExtractEmbedding.py +59 -0
- README.md +2 -2
- SaveEmbedding.py +100 -0
- SimSearch.py +66 -0
- app.py +79 -0
ExtractEmbedding.py
ADDED
@@ -0,0 +1,59 @@
import time
import os
import torch
import numpy as np
import torchvision
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from tqdm import tqdm
import pickle
import argparse
from PIL import Image

concat = lambda x: np.concatenate(x, axis=0)
to_np = lambda x: x.data.to("cpu").numpy()


class Wrapper(torch.nn.Module):
    """Wraps a torchvision model and captures its avgpool activations via a forward hook."""

    def __init__(self, model):
        super(Wrapper, self).__init__()
        self.model = model
        self.avgpool_output = None
        self.query = None
        self.cossim_value = {}

        def fw_hook(module, input, output):
            self.avgpool_output = output.squeeze()

        self.model.avgpool.register_forward_hook(fw_hook)

    def forward(self, input):
        _ = self.model(input)
        return self.avgpool_output

    def __repr__(self):
        return "Wrapper"


def QueryToEmbedding(query_pil):
    # Standard ImageNet preprocessing: resize, center-crop, normalize
    dataset_transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    model = torchvision.models.resnet50(pretrained=True)
    model.eval()
    myw = Wrapper(model)

    # query_pil = Image.open(query_path)
    query_pt = dataset_transform(query_pil)

    with torch.no_grad():
        embedding = to_np(myw(query_pt.unsqueeze(0)))

    return np.asarray([embedding])
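Note: QueryToEmbedding reloads the pretrained ResNet-50 on every call, so the first invocation downloads the weights. A minimal usage sketch (the image path is an assumption; app.py references an ./examples/bird.jpg sample):

from PIL import Image
from ExtractEmbedding import QueryToEmbedding

# Hypothetical input image; any RGB image works after conversion
query = Image.open("./examples/bird.jpg").convert("RGB")
embedding = QueryToEmbedding(query)
print(embedding.shape)  # (1, 2048): one ResNet-50 avgpool feature vector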
README.md
CHANGED
@@ -1,8 +1,8 @@
 ---
-title: CHM
+title: CHM-Corr
 emoji: 🐨
 colorFrom: yellow
-colorTo:
+colorTo: blue
 sdk: gradio
 sdk_version: 3.1.1
 app_file: app.py
SaveEmbedding.py
ADDED
@@ -0,0 +1,100 @@
import time
import os
import torch
import numpy as np
import torchvision
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from tqdm import tqdm
import pickle
import argparse

concat = lambda x: np.concatenate(x, axis=0)
to_np = lambda x: x.data.to("cpu").numpy()


class Wrapper(torch.nn.Module):
    """Wraps a torchvision model and captures its avgpool activations via a forward hook."""

    def __init__(self, model):
        super(Wrapper, self).__init__()
        self.model = model
        self.avgpool_output = None
        self.query = None
        self.cossim_value = {}

        def fw_hook(module, input, output):
            self.avgpool_output = output.squeeze()

        self.model.avgpool.register_forward_hook(fw_hook)

    def forward(self, input):
        _ = self.model(input)
        return self.avgpool_output

    def __repr__(self):
        return "Wrapper"


def run(training_set_path):
    # Standard ImageNet preprocessing: apply ImageNet statistics to the input batch
    dataset_transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    training_imagefolder = ImageFolder(
        root=training_set_path, transform=dataset_transform
    )
    train_loader = torch.utils.data.DataLoader(
        training_imagefolder,
        batch_size=512,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )
    print(f"# of Training folder samples: {len(training_imagefolder)}")

    model = torchvision.models.resnet50(pretrained=True)
    model.eval()
    myw = Wrapper(model)

    training_embeddings = []
    training_labels = []

    with torch.no_grad():
        for _, (data, target) in enumerate(tqdm(train_loader)):
            embeddings = to_np(myw(data))
            labels = to_np(target)

            training_embeddings.append(embeddings)
            training_labels.append(labels)

    training_embeddings_concatted = concat(training_embeddings)
    training_labels_concatted = concat(training_labels)

    print(training_embeddings_concatted.shape)
    return training_embeddings_concatted, training_labels_concatted


def main():
    parser = argparse.ArgumentParser(description="Saving Embeddings")
    parser.add_argument("--train", help="Path to the Dataset", type=str, required=True)
    args = parser.parse_args()

    embeddings, labels = run(args.train)

    # Save the extracted embeddings and labels to disk
    with open("embeddings.pickle", "wb") as f:
        pickle.dump(embeddings, f)

    with open("labels.pickle", "wb") as f:
        pickle.dump(labels, f)


if __name__ == "__main__":
    main()
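The script is driven from the command line (python SaveEmbedding.py --train <path>), but run() can also be called directly. A minimal sketch, assuming a hypothetical ImageFolder-style tree at ./Training/train/:

from SaveEmbedding import run

# Hypothetical dataset path; any ImageFolder-compatible directory works
embeddings, labels = run("./Training/train/")
print(embeddings.shape, labels.shape)  # (N, 2048) features, (N,) integer labels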
SimSearch.py
ADDED
@@ -0,0 +1,66 @@
import faiss
import numpy as np


class FaissNeighbors:
    """Exact L2 nearest-neighbor search over a set of embeddings."""

    def __init__(self):
        self.index = None
        self.y = None

    def fit(self, X, y):
        self.index = faiss.IndexFlatL2(X.shape[1])
        self.index.add(X.astype(np.float32))
        self.y = y

    def get_distances_and_indices(self, X, top_K=1000):
        distances, indices = self.index.search(X.astype(np.float32), k=top_K)
        return np.copy(distances), np.copy(indices), np.copy(self.y[indices])

    def get_nearest_labels(self, X, top_K=1000):
        distances, indices = self.index.search(X.astype(np.float32), k=top_K)
        return np.copy(self.y[indices])


class FaissCosineNeighbors:
    """Exact cosine-similarity search: L2-normalize, then use inner product."""

    def __init__(self):
        self.cindex = None
        self.y = None

    def fit(self, X, y):
        self.cindex = faiss.index_factory(
            X.shape[1], "Flat", faiss.METRIC_INNER_PRODUCT
        )
        X = np.copy(X)
        X = X.astype(np.float32)
        faiss.normalize_L2(X)
        self.cindex.add(X)
        self.y = y

    def get_distances_and_indices(self, Q, topK):
        # Cast before normalizing: faiss.normalize_L2 operates in place on float32
        Q = np.copy(Q).astype(np.float32)
        faiss.normalize_L2(Q)
        distances, indices = self.cindex.search(Q, k=topK)
        return np.copy(distances), np.copy(indices), np.copy(self.y[indices])

    def get_nearest_labels(self, Q, topK=1000):
        Q = np.copy(Q).astype(np.float32)
        faiss.normalize_L2(Q)
        distances, indices = self.cindex.search(Q, k=topK)
        return np.copy(self.y[indices])


class SearchableTrainingSet:
    def __init__(self, embeddings, labels):
        self.simsearcher = FaissCosineNeighbors()
        self.X_train = embeddings
        self.y_train = labels

    def build_index(self):
        self.simsearcher.fit(self.X_train, self.y_train)

    def search(self, query, k=20):
        # NOTE: k is currently unused; the index always returns the top 100 hits
        nearest_data_points = self.simsearcher.get_distances_and_indices(
            Q=query, topK=100
        )
        # topKs = [x[0] for x in Counter(nearest_data_points[0]).most_common(k)]
        return nearest_data_points
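search() returns a (scores, indices, labels) triple, each of shape (n_queries, 100). A minimal sketch on synthetic data (the 2048-dim embeddings and 200-class label range are assumptions matching the CUB setup):

import numpy as np
from SimSearch import SearchableTrainingSet

rng = np.random.default_rng(0)
X = rng.standard_normal((1000, 2048)).astype(np.float32)  # fake training embeddings
y = rng.integers(0, 200, size=1000)                       # fake class ids

searcher = SearchableTrainingSet(X, y)
searcher.build_index()

q = rng.standard_normal((1, 2048)).astype(np.float32)     # one query embedding
scores, indices, labels = searcher.search(q)
print(scores.shape, indices.shape, labels.shape)          # (1, 100) each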
app.py
ADDED
@@ -0,0 +1,79 @@
import pickle
from collections import Counter
import numpy as np
import gradio as gr
import gdown
import torchvision
from torchvision.datasets import ImageFolder

from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
from ExtractEmbedding import QueryToEmbedding

concat = lambda x: np.concatenate(x, axis=0)

# Download precomputed embeddings and labels
gdown.download(id="116CiA_cXciGSl72tbAUDoN-f1B9Frp89")
gdown.download(id="1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e")

# CUB training set
gdown.cached_download(
    url="https://drive.google.com/uc?id=1iR19j7532xqPefWYT-BdtcaKnsEokIqo",
    path="./CUB_train.zip",
    quiet=False,
    md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
)

# EXTRACT
torchvision.datasets.utils.extract_archive(
    from_path="CUB_train.zip",
    to_path="Training/",
    remove_finished=False,
)


# Load the precomputed training-set embeddings and labels
with open("./embeddings.pickle", "rb") as f:
    Xtrain = pickle.load(f)
# FIXME: re-run the code to get the embeddings in the right format
with open("./labels.pickle", "rb") as f:
    ytrain = pickle.load(f)

searcher = SearchableTrainingSet(Xtrain, ytrain)
searcher.build_index()

# Extract label names
training_folder = ImageFolder(root="./Training/train/")
id_to_bird_name = {
    x[1]: x[0].split("/")[-2].replace(".", " ") for x in training_folder.imgs
}


def search(query_image, searcher=searcher):
    query_embedding = QueryToEmbedding(query_image)
    # search() returns (similarity scores, training-set indices, labels)
    scores, indices, labels = searcher.search(query_embedding, k=50)

    result_ctr = Counter(labels[0][:20]).most_common(5)

    top1_label = result_ctr[0][0]
    top_indices = []

    for a, b in zip(labels[0][:20], indices[0][:20]):
        if a == top1_label:
            top_indices.append(b)

    gallery_images = [training_folder.imgs[int(X)][0] for X in top_indices[:5]]
    predicted_labels = {id_to_bird_name[X[0]]: X[1] / 20.0 for X in result_ctr}

    return predicted_labels, gallery_images


demo = gr.Interface(
    search,
    gr.Image(type="pil"),
    ["label", "gallery"],
    examples=[["./examples/bird.jpg"]],
    description="WIP - kNN on CUB dataset",
    title="Work in Progress - CHM-Corr",
)

if __name__ == "__main__":
    demo.launch()
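The label output is a plain 20-NN majority vote: the labels of the 20 nearest training images are counted, and each class scores count/20. A worked sketch of that step with made-up class ids:

from collections import Counter

# 20 hypothetical nearest-neighbor class ids for one query
nearest_labels = [3, 3, 7, 3, 12, 7, 3, 3, 3, 7, 3, 3, 12, 3, 3, 7, 3, 3, 3, 7]
result_ctr = Counter(nearest_labels).most_common(5)
scores = {label: count / 20.0 for label, count in result_ctr}
print(scores)  # {3: 0.65, 7: 0.25, 12: 0.1}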