taesiri commited on
Commit
d526dbf
1 Parent(s): c4a8d1c

added CHM classification

Browse files
CHMCorr.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CHM-Corr Classifier
2
+ import argparse
3
+ import json
4
+ import pickle
5
+ import random
6
+ from itertools import product
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torchvision.transforms as transforms
11
+ from torch.utils.data import DataLoader
12
+ from torchvision.datasets import ImageFolder
13
+ from tqdm import tqdm
14
+ from common.evaluation import Evaluator
15
+ from model import chmnet
16
+ from model.base.geometry import Geometry
17
+
18
+ from Utils import (
19
+ CosineCustomDataset,
20
+ PairedLayer4Extractor,
21
+ compute_spatial_similarity,
22
+ generate_mask,
23
+ normalize_array,
24
+ get_transforms,
25
+ arg_topK,
26
+ )
27
+
28
+ # Setting the random seed
29
+ random.seed(42)
30
+
31
+ # Helper Function
32
+ to_np = lambda x: x.data.to("cpu").numpy()
33
+
34
+ # CHMNet Config
35
+ chm_args = dict(
36
+ {
37
+ "alpha": [0.05, 0.1],
38
+ "img_size": 240,
39
+ "ktype": "psi",
40
+ "load": "pas_psi.pt",
41
+ }
42
+ )
43
+
44
+
45
+ class CHMGridTransfer:
46
+ def __init__(
47
+ self,
48
+ query_image,
49
+ support_set,
50
+ support_set_labels,
51
+ train_folder,
52
+ top_N,
53
+ top_K,
54
+ binarization_threshold,
55
+ chm_source_transform,
56
+ chm_target_transform,
57
+ cosine_source_transform,
58
+ cosine_target_transform,
59
+ batch_size=64,
60
+ ):
61
+ self.N = top_N
62
+ self.K = top_K
63
+ self.BS = batch_size
64
+
65
+ self.chm_source_transform = chm_source_transform
66
+ self.chm_target_transform = chm_target_transform
67
+ self.cosine_source_transform = cosine_source_transform
68
+ self.cosine_target_transform = cosine_target_transform
69
+
70
+ self.source_embeddings = None
71
+ self.target_embeddings = None
72
+ self.correspondence_map = None
73
+ self.similarity_maps = None
74
+ self.reverse_similarity_maps = None
75
+ self.transferred_points = None
76
+
77
+ self.binarization_threshold = binarization_threshold
78
+
79
+ # UPDATE THIS
80
+ self.q = query_image
81
+ self.support_set = support_set
82
+ self.labels_ss = support_set_labels
83
+
84
+ def build(self):
85
+ # C.M.H
86
+ test_ds = CosineCustomDataset(
87
+ query_image=self.q,
88
+ supporting_set=self.support_set,
89
+ source_transform=self.chm_source_transform,
90
+ target_transform=self.chm_target_transform,
91
+ )
92
+ test_dl = DataLoader(test_ds, batch_size=self.BS, shuffle=False)
93
+ self.find_correspondences(test_dl)
94
+
95
+ # LAYER 4s
96
+ test_ds = CosineCustomDataset(
97
+ query_image=self.q,
98
+ supporting_set=self.support_set,
99
+ source_transform=self.cosine_source_transform,
100
+ target_transform=self.cosine_target_transform,
101
+ )
102
+ test_dl = DataLoader(test_ds, batch_size=self.BS, shuffle=False)
103
+ self.compute_embeddings(test_dl)
104
+ self.compute_similarity_map()
105
+
106
+ def find_correspondences(self, test_dl):
107
+ model = chmnet.CHMNet(chm_args["ktype"])
108
+ model.load_state_dict(
109
+ torch.load(chm_args["load"], map_location=torch.device("cpu"))
110
+ )
111
+ Evaluator.initialize(chm_args["alpha"])
112
+ Geometry.initialize(img_size=chm_args["img_size"])
113
+
114
+ grid_results = []
115
+ transferred_points = []
116
+
117
+ # FIXED GRID HARD CODED
118
+ fixed_src_grid_points = list(
119
+ product(
120
+ np.linspace(1 + 17, 240 - 17 - 1, 7),
121
+ np.linspace(1 + 17, 240 - 17 - 1, 7),
122
+ )
123
+ )
124
+ fixed_src_grid_points = np.asarray(fixed_src_grid_points, dtype=np.float64).T
125
+
126
+ with torch.no_grad():
127
+ model.eval()
128
+ for idx, batch in enumerate(tqdm(test_dl)):
129
+
130
+ keypoints = (
131
+ torch.tensor(fixed_src_grid_points)
132
+ .unsqueeze(0)
133
+ .repeat(batch["src_img"].shape[0], 1, 1)
134
+ )
135
+ n_pts = torch.tensor(
136
+ np.asarray(batch["src_img"].shape[0] * [49]), dtype=torch.long
137
+ )
138
+
139
+ corr_matrix = model(batch["src_img"], batch["trg_img"])
140
+ prd_kps = Geometry.transfer_kps(
141
+ corr_matrix, keypoints, n_pts, normalized=False
142
+ )
143
+ transferred_points.append(prd_kps.cpu().numpy())
144
+ for tgt_points in prd_kps:
145
+ tgt_grid = []
146
+ for x, y in zip(tgt_points[0], tgt_points[1]):
147
+ tgt_grid.append(
148
+ [int(((x + 1) / 2.0) * 7), int(((y + 1) / 2.0) * 7)]
149
+ )
150
+ grid_results.append(tgt_grid)
151
+
152
+ self.correspondence_map = grid_results
153
+ self.transferred_points = np.vstack(transferred_points)
154
+
155
+ def compute_embeddings(self, test_dl):
156
+ paired_extractor = PairedLayer4Extractor()
157
+
158
+ source_embeddings = []
159
+ target_embeddings = []
160
+
161
+ with torch.no_grad():
162
+ for idx, batch in enumerate(test_dl):
163
+ s_e, t_e = paired_extractor((batch["src_img"], batch["trg_img"]))
164
+
165
+ source_embeddings.append(s_e)
166
+ target_embeddings.append(t_e)
167
+
168
+ # EMBEDDINGS
169
+ self.source_embeddings = torch.cat(source_embeddings, axis=0)
170
+ self.target_embeddings = torch.cat(target_embeddings, axis=0)
171
+
172
+ def compute_similarity_map(self):
173
+ CosSim = nn.CosineSimilarity(dim=0, eps=1e-6)
174
+
175
+ similarity_maps = []
176
+ rsimilarity_maps = []
177
+
178
+ grid = []
179
+ for i in range(7):
180
+ for j in range(7):
181
+ grid.append([i, j])
182
+
183
+ # Compute for all image pairs
184
+ for i in range(len(self.correspondence_map)):
185
+ cosine_map = np.zeros((7, 7))
186
+ reverse_cosine_map = np.zeros((7, 7))
187
+
188
+ # calculate cosine based on the chm corr. map
189
+ for S, T in zip(grid, self.correspondence_map[i]):
190
+ v1 = self.source_embeddings[i][:, S[0], S[1]]
191
+ v2 = self.target_embeddings[i][:, T[0], T[1]]
192
+ covalue = CosSim(v1, v2)
193
+ cosine_map[S[0], S[1]] = covalue
194
+ reverse_cosine_map[T[0], T[1]] = covalue
195
+
196
+ similarity_maps.append(cosine_map)
197
+ rsimilarity_maps.append(reverse_cosine_map)
198
+
199
+ self.similarity_maps = similarity_maps
200
+ self.reverse_similarity_maps = rsimilarity_maps
201
+
202
+ def compute_score_using_cc(self):
203
+ # CC MAPS
204
+ SIMS_source, SIMS_target = [], []
205
+ for i in range(len(self.source_embeddings)):
206
+ simA, simB = compute_spatial_similarity(
207
+ to_np(self.source_embeddings[i]), to_np(self.target_embeddings[i])
208
+ )
209
+
210
+ SIMS_source.append(simA)
211
+ SIMS_target.append(simB)
212
+
213
+ SIMS_source = np.stack(SIMS_source, axis=0)
214
+ # SIMS_target = np.stack(SIMS_target, axis=0)
215
+
216
+ top_cos_values = []
217
+
218
+ for i in range(len(self.similarity_maps)):
219
+ cosine_value = np.multiply(
220
+ self.similarity_maps[i],
221
+ generate_mask(
222
+ normalize_array(SIMS_source[i]), t=self.binarization_threshold
223
+ ),
224
+ )
225
+ top_5_indicies = np.argsort(cosine_value.T.reshape(-1))[::-1][:5]
226
+ mean_of_top_5 = np.mean(
227
+ [cosine_value.T.reshape(-1)[x] for x in top_5_indicies]
228
+ )
229
+ top_cos_values.append(np.mean(mean_of_top_5))
230
+
231
+ return top_cos_values
232
+
233
+ def compute_score_using_custom_points(self, selected_keypoint_masks):
234
+ top_cos_values = []
235
+
236
+ for i in range(len(self.similarity_maps)):
237
+ cosine_value = np.multiply(self.similarity_maps[i], selected_keypoint_masks)
238
+ top_indicies = np.argsort(cosine_value.T.reshape(-1))[::-1]
239
+ mean_of_tops = np.mean(
240
+ [cosine_value.T.reshape(-1)[x] for x in top_indicies]
241
+ )
242
+ top_cos_values.append(np.mean(mean_of_tops))
243
+
244
+ return top_cos_values
245
+
246
+ def export(self):
247
+ storage = {
248
+ "N": self.N,
249
+ "K": self.K,
250
+ "source_embeddings": self.source_embeddings,
251
+ "target_embeddings": self.target_embeddings,
252
+ "correspondence_map": self.correspondence_map,
253
+ "similarity_maps": self.similarity_maps,
254
+ "T": self.binarization_threshold,
255
+ "query": self.q,
256
+ "support_set": self.support_set,
257
+ "labels_for_support_set": self.labels_ss,
258
+ "rsimilarity_maps": self.reverse_similarity_maps,
259
+ "transferred_points": self.transferred_points,
260
+ }
261
+
262
+ return ModifiableCHMResults(storage)
263
+
264
+
265
+ class ModifiableCHMResults:
266
+ def __init__(self, storage):
267
+ self.N = storage["N"]
268
+ self.K = storage["K"]
269
+ self.source_embeddings = storage["source_embeddings"]
270
+ self.target_embeddings = storage["target_embeddings"]
271
+ self.correspondence_map = storage["correspondence_map"]
272
+ self.similarity_maps = storage["similarity_maps"]
273
+ self.T = storage["T"]
274
+ self.q = storage["query"]
275
+ self.support_set = storage["support_set"]
276
+ self.labels_ss = storage["labels_for_support_set"]
277
+ self.rsimilarity_maps = storage["rsimilarity_maps"]
278
+ self.transferred_points = storage["transferred_points"]
279
+ self.similarity_maps_masked = None
280
+ self.SIMS_source = None
281
+ self.SIMS_target = None
282
+ self.masked_sim_values = []
283
+ self.top_cos_values = []
284
+
285
+ def compute_score_using_cc(self):
286
+ # CC MAPS
287
+ SIMS_source, SIMS_target = [], []
288
+ for i in range(len(self.source_embeddings)):
289
+ simA, simB = compute_spatial_similarity(
290
+ to_np(self.source_embeddings[i]), to_np(self.target_embeddings[i])
291
+ )
292
+
293
+ SIMS_source.append(simA)
294
+ SIMS_target.append(simB)
295
+
296
+ SIMS_source = np.stack(SIMS_source, axis=0)
297
+ SIMS_target = np.stack(SIMS_target, axis=0)
298
+
299
+ self.SIMS_source = SIMS_source
300
+ self.SIMS_target = SIMS_target
301
+
302
+ top_cos_values = []
303
+
304
+ for i in range(len(self.similarity_maps)):
305
+ masked_sim_values = np.multiply(
306
+ self.similarity_maps[i],
307
+ generate_mask(normalize_array(SIMS_source[i]), t=self.T),
308
+ )
309
+ self.masked_sim_values.append(masked_sim_values)
310
+ top_5_indicies = np.argsort(masked_sim_values.T.reshape(-1))[::-1][:5]
311
+ mean_of_top_5 = np.mean(
312
+ [masked_sim_values.T.reshape(-1)[x] for x in top_5_indicies]
313
+ )
314
+ top_cos_values.append(np.mean(mean_of_top_5))
315
+
316
+ self.top_cos_values = top_cos_values
317
+
318
+ return top_cos_values
319
+
320
+ def compute_score_using_custom_points(self, selected_keypoint_masks):
321
+ top_cos_values = []
322
+ similarity_maps_masked = []
323
+
324
+ for i in range(len(self.similarity_maps)):
325
+ cosine_value = np.multiply(self.similarity_maps[i], selected_keypoint_masks)
326
+ similarity_maps_masked.append(cosine_value)
327
+ top_indicies = np.argsort(cosine_value.T.reshape(-1))[::-1]
328
+ mean_of_tops = np.mean(
329
+ [cosine_value.T.reshape(-1)[x] for x in top_indicies]
330
+ )
331
+ top_cos_values.append(np.mean(mean_of_tops))
332
+
333
+ self.similarity_maps_masked = similarity_maps_masked
334
+ return top_cos_values
335
+
336
+ def predict_using_cc(self):
337
+ top_cos_values = self.compute_score_using_cc()
338
+ # Predict
339
+ prediction = np.argmax(
340
+ np.bincount(
341
+ [self.labels_ss[x] for x in np.argsort(top_cos_values)[::-1][: self.K]]
342
+ )
343
+ )
344
+ prediction_weight = np.max(
345
+ np.bincount(
346
+ [self.labels_ss[x] for x in np.argsort(top_cos_values)[::-1][: self.K]]
347
+ )
348
+ )
349
+
350
+ reranked_nns_idx = [x for x in np.argsort(top_cos_values)[::-1]]
351
+ reranked_nns_files = [self.support_set[x] for x in reranked_nns_idx]
352
+
353
+ topK_idx = [
354
+ x
355
+ for x in np.argsort(top_cos_values)[::-1]
356
+ if self.labels_ss[x] == prediction
357
+ ]
358
+ topK_files = [self.support_set[x] for x in topK_idx]
359
+ topK_cmaps = [self.correspondence_map[x] for x in topK_idx]
360
+ topK_similarity_maps = [self.similarity_maps[x] for x in topK_idx]
361
+ topK_rsimilarity_maps = [self.rsimilarity_maps[x] for x in topK_idx]
362
+ topK_transfered_points = [self.transferred_points[x] for x in topK_idx]
363
+ predicted_folder_name = topK_files[0].split("/")[-2]
364
+
365
+ return (
366
+ topK_idx,
367
+ prediction,
368
+ predicted_folder_name,
369
+ prediction_weight,
370
+ topK_files[: self.K],
371
+ reranked_nns_files[: self.K],
372
+ topK_cmaps[: self.K],
373
+ topK_similarity_maps[: self.K],
374
+ topK_rsimilarity_maps[: self.K],
375
+ topK_transfered_points[: self.K],
376
+ )
377
+
378
+ def predict_custom_pairs(self, selected_keypoint_masks):
379
+ top_cos_values = self.compute_score_using_custom_points(selected_keypoint_masks)
380
+
381
+ # Predict
382
+ prediction = np.argmax(
383
+ np.bincount(
384
+ [self.labels_ss[x] for x in np.argsort(top_cos_values)[::-1][: self.K]]
385
+ )
386
+ )
387
+ prediction_weight = np.max(
388
+ np.bincount(
389
+ [self.labels_ss[x] for x in np.argsort(top_cos_values)[::-1][: self.K]]
390
+ )
391
+ )
392
+
393
+ reranked_nns_idx = [x for x in np.argsort(top_cos_values)[::-1]]
394
+ reranked_nns_files = [self.support_set[x] for x in reranked_nns_idx]
395
+
396
+ topK_idx = [
397
+ x
398
+ for x in np.argsort(top_cos_values)[::-1]
399
+ if self.labels_ss[x] == prediction
400
+ ]
401
+ topK_files = [self.support_set[x] for x in topK_idx]
402
+ topK_cmaps = [self.correspondence_map[x] for x in topK_idx]
403
+ topK_similarity_maps = [self.similarity_maps[x] for x in topK_idx]
404
+ topK_rsimilarity_maps = [self.rsimilarity_maps[x] for x in topK_idx]
405
+ topK_transferred_points = [self.transferred_points[x] for x in topK_idx]
406
+ # topK_scores = [top_cos_values[x] for x in topK_idx]
407
+ topK_masked_sims = [self.similarity_maps_masked[x] for x in topK_idx]
408
+ predicted_folder_name = topK_files[0].split("/")[-2]
409
+
410
+ non_zero_mask = np.count_nonzero(selected_keypoint_masks)
411
+
412
+ return (
413
+ topK_idx,
414
+ prediction,
415
+ predicted_folder_name,
416
+ prediction_weight,
417
+ topK_files[: self.K],
418
+ reranked_nns_files[: self.K],
419
+ topK_cmaps[: self.K],
420
+ topK_similarity_maps[: self.K],
421
+ topK_rsimilarity_maps[: self.K],
422
+ topK_transferred_points[: self.K],
423
+ topK_masked_sims[: self.K],
424
+ non_zero_mask,
425
+ )
426
+
427
+
428
+ def export_visualizations_results(
429
+ reranker_output,
430
+ knn_predicted_label,
431
+ knn_confidence,
432
+ topK_knns,
433
+ K=20,
434
+ N=50,
435
+ T=0.55,
436
+ ):
437
+ """
438
+ Export all details for visualization and analysis
439
+ """
440
+
441
+ non_zero_mask = 5 # default value
442
+ (
443
+ topK_idx,
444
+ p,
445
+ pfn,
446
+ pr,
447
+ rfiles,
448
+ reranked_nns,
449
+ cmaps,
450
+ sims,
451
+ rsims,
452
+ trns_kpts,
453
+ ) = reranker_output.predict_using_cc()
454
+
455
+ MASKED_COSINE_VALUES = [
456
+ np.multiply(
457
+ sims[X],
458
+ generate_mask(
459
+ normalize_array(reranker_output.SIMS_source[topK_idx[X]]), t=T
460
+ ),
461
+ )
462
+ for X in range(len(sims))
463
+ ]
464
+
465
+ list_of_source_points = []
466
+ list_of_target_points = []
467
+
468
+ for CK in range(len(sims)):
469
+ target_keypoints = []
470
+ topk_index = arg_topK(MASKED_COSINE_VALUES[CK], topK=non_zero_mask)
471
+
472
+ for i in range(non_zero_mask): # Number of Connections
473
+ # Psource = point_list[topk_index[i]]
474
+ x, y = trns_kpts[CK].T[topk_index[i]]
475
+ Ptarget = int(((x + 1) / 2.0) * 240), int(((y + 1) / 2.0) * 240)
476
+ target_keypoints.append(Ptarget)
477
+
478
+ # Uniform Grid of points
479
+ a = np.linspace(1 + 17, 240 - 17 - 1, 7)
480
+ b = np.linspace(1 + 17, 240 - 17 - 1, 7)
481
+ point_list = list(product(a, b))
482
+
483
+ list_of_source_points.append(np.asarray([point_list[x] for x in topk_index]))
484
+ list_of_target_points.append(np.asarray(target_keypoints))
485
+
486
+ # EXPORT OUTPUT
487
+ detailed_output = {
488
+ "q": reranker_output.q,
489
+ "K": K,
490
+ "N": N,
491
+ "knn-prediction": knn_predicted_label,
492
+ "knn-prediction-confidence": knn_confidence,
493
+ "knn-nearest-neighbors": topK_knns,
494
+ "chm-prediction": pfn,
495
+ "chm-prediction-confidence": pr,
496
+ "chm-nearest-neighbors": rfiles,
497
+ "correspondance_map": cmaps,
498
+ "masked_cos_values": MASKED_COSINE_VALUES,
499
+ "src-keypoints": list_of_source_points,
500
+ "tgt-keypoints": list_of_target_points,
501
+ "non_zero_mask": non_zero_mask,
502
+ "transferred_kpoints": trns_kpts,
503
+ }
504
+
505
+ return detailed_output
506
+
507
+
508
+ def chm_classify_and_visualize(
509
+ query_image, kNN_results, support, TRAIN_SET, N=50, K=20, T=0.55, BS=64
510
+ ):
511
+ global chm_args
512
+ chm_src_t, chm_tgt_t, cos_src_t, cos_tgt_t = get_transforms("single", chm_args)
513
+ knn_predicted_label, knn_confidence, topK_knns = kNN_results
514
+
515
+ reranker = CHMGridTransfer(
516
+ query_image=query_image,
517
+ support_set=support[0],
518
+ support_set_labels=support[1],
519
+ train_folder=TRAIN_SET,
520
+ top_N=N,
521
+ top_K=K,
522
+ binarization_threshold=T,
523
+ chm_source_transform=chm_src_t,
524
+ chm_target_transform=chm_tgt_t,
525
+ cosine_source_transform=cos_src_t,
526
+ cosine_target_transform=cos_tgt_t,
527
+ batch_size=BS,
528
+ )
529
+
530
+ # Building the reranker
531
+ reranker.build()
532
+ # Make a ModifiableCHMResults
533
+ exported_reranker = reranker.export()
534
+ # Export A details for visualizations
535
+
536
+ output = export_visualizations_results(
537
+ exported_reranker,
538
+ knn_predicted_label,
539
+ knn_confidence,
540
+ topK_knns,
541
+ K,
542
+ N,
543
+ T,
544
+ )
545
+
546
+ return output
ExtractEmbedding.py CHANGED
@@ -36,7 +36,7 @@ class Wrapper(torch.nn.Module):
36
  return "Wrappper"
37
 
38
 
39
- def QueryToEmbedding(query_pil):
40
  dataset_transform = transforms.Compose(
41
  [
42
  transforms.Resize(256),
@@ -50,7 +50,7 @@ def QueryToEmbedding(query_pil):
50
  model.eval()
51
  myw = Wrapper(model)
52
 
53
- # query_pil = Image.open(query_path)
54
  query_pt = dataset_transform(query_pil)
55
 
56
  with torch.no_grad():
 
36
  return "Wrappper"
37
 
38
 
39
+ def QueryToEmbedding(query_path):
40
  dataset_transform = transforms.Compose(
41
  [
42
  transforms.Resize(256),
 
50
  model.eval()
51
  myw = Wrapper(model)
52
 
53
+ query_pil = Image.open(query_path)
54
  query_pt = dataset_transform(query_pil)
55
 
56
  with torch.no_grad():
FeatureExtractors.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Original Author: Jonathan Donnellya (jonathan.donnelly@maine.edu)
2
+ # Modified by Mohammad Reza Taesiri (mtaesiri@gmail.com)
3
+
4
+ import os
5
+ import torch
6
+ import torch.nn as nn
7
+ from collections import OrderedDict
8
+
9
+ model_dir = os.path.dirname(os.path.realpath(__file__))
10
+
11
+
12
+ def conv3x3(in_planes, out_planes, stride=1):
13
+ """3x3 convolution with padding"""
14
+ return nn.Conv2d(
15
+ in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
16
+ )
17
+
18
+
19
+ def conv1x1(in_planes, out_planes, stride=1):
20
+ """1x1 convolution"""
21
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
22
+
23
+
24
+ class BasicBlock(nn.Module):
25
+ # class attribute
26
+ expansion = 1
27
+ num_layers = 2
28
+
29
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
30
+ super(BasicBlock, self).__init__()
31
+ # only conv with possibly not 1 stride
32
+ self.conv1 = conv3x3(inplanes, planes, stride)
33
+ self.bn1 = nn.BatchNorm2d(planes)
34
+ self.relu = nn.ReLU(inplace=True)
35
+ self.conv2 = conv3x3(planes, planes)
36
+ self.bn2 = nn.BatchNorm2d(planes)
37
+
38
+ # if stride is not 1 then self.downsample cannot be None
39
+ self.downsample = downsample
40
+ self.stride = stride
41
+
42
+ def forward(self, x):
43
+ identity = x
44
+
45
+ out = self.conv1(x)
46
+ out = self.bn1(out)
47
+ out = self.relu(out)
48
+
49
+ out = self.conv2(out)
50
+ out = self.bn2(out)
51
+
52
+ if self.downsample is not None:
53
+ identity = self.downsample(x)
54
+
55
+ # the residual connection
56
+ out += identity
57
+ out = self.relu(out)
58
+
59
+ return out
60
+
61
+ def block_conv_info(self):
62
+ block_kernel_sizes = [3, 3]
63
+ block_strides = [self.stride, 1]
64
+ block_paddings = [1, 1]
65
+
66
+ return block_kernel_sizes, block_strides, block_paddings
67
+
68
+
69
+ class Bottleneck(nn.Module):
70
+ # class attribute
71
+ expansion = 4
72
+ num_layers = 3
73
+
74
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
75
+ super(Bottleneck, self).__init__()
76
+ self.conv1 = conv1x1(inplanes, planes)
77
+ self.bn1 = nn.BatchNorm2d(planes)
78
+ # only conv with possibly not 1 stride
79
+ self.conv2 = conv3x3(planes, planes, stride)
80
+ self.bn2 = nn.BatchNorm2d(planes)
81
+ self.conv3 = conv1x1(planes, planes * self.expansion)
82
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
83
+ self.relu = nn.ReLU(inplace=True)
84
+
85
+ # if stride is not 1 then self.downsample cannot be None
86
+ self.downsample = downsample
87
+ self.stride = stride
88
+
89
+ def forward(self, x):
90
+ identity = x
91
+
92
+ out = self.conv1(x)
93
+ out = self.bn1(out)
94
+ out = self.relu(out)
95
+
96
+ out = self.conv2(out)
97
+ out = self.bn2(out)
98
+ out = self.relu(out)
99
+
100
+ out = self.conv3(out)
101
+ out = self.bn3(out)
102
+
103
+ if self.downsample is not None:
104
+ identity = self.downsample(x)
105
+
106
+ out += identity
107
+ out = self.relu(out)
108
+
109
+ return out
110
+
111
+ def block_conv_info(self):
112
+ block_kernel_sizes = [1, 3, 1]
113
+ block_strides = [1, self.stride, 1]
114
+ block_paddings = [0, 1, 0]
115
+
116
+ return block_kernel_sizes, block_strides, block_paddings
117
+
118
+
119
+ class ResNet_features(nn.Module):
120
+ """
121
+ the convolutional layers of ResNet
122
+ the average pooling and final fully convolutional layer is removed
123
+ """
124
+
125
+ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
126
+ super(ResNet_features, self).__init__()
127
+
128
+ self.inplanes = 64
129
+
130
+ # the first convolutional layer before the structured sequence of blocks
131
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
132
+ self.bn1 = nn.BatchNorm2d(64)
133
+ self.relu = nn.ReLU(inplace=True)
134
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
135
+ # comes from the first conv and the following max pool
136
+ self.kernel_sizes = [7, 3]
137
+ self.strides = [2, 2]
138
+ self.paddings = [3, 1]
139
+
140
+ # the following layers, each layer is a sequence of blocks
141
+ self.block = block
142
+ self.layers = layers
143
+ self.layer1 = self._make_layer(
144
+ block=block, planes=64, num_blocks=self.layers[0]
145
+ )
146
+ self.layer2 = self._make_layer(
147
+ block=block, planes=128, num_blocks=self.layers[1], stride=2
148
+ )
149
+ self.layer3 = self._make_layer(
150
+ block=block, planes=256, num_blocks=self.layers[2], stride=2
151
+ )
152
+ self.layer4 = self._make_layer(
153
+ block=block, planes=512, num_blocks=self.layers[3], stride=2
154
+ )
155
+
156
+ # initialize the parameters
157
+ for m in self.modules():
158
+ if isinstance(m, nn.Conv2d):
159
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
160
+ elif isinstance(m, nn.BatchNorm2d):
161
+ nn.init.constant_(m.weight, 1)
162
+ nn.init.constant_(m.bias, 0)
163
+
164
+ # Zero-initialize the last BN in each residual branch,
165
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
166
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
167
+ if zero_init_residual:
168
+ for m in self.modules():
169
+ if isinstance(m, Bottleneck):
170
+ nn.init.constant_(m.bn3.weight, 0)
171
+ elif isinstance(m, BasicBlock):
172
+ nn.init.constant_(m.bn2.weight, 0)
173
+
174
+ def _make_layer(self, block, planes, num_blocks, stride=1):
175
+ downsample = None
176
+ if stride != 1 or self.inplanes != planes * block.expansion:
177
+ downsample = nn.Sequential(
178
+ conv1x1(self.inplanes, planes * block.expansion, stride),
179
+ nn.BatchNorm2d(planes * block.expansion),
180
+ )
181
+
182
+ layers = []
183
+ # only the first block has downsample that is possibly not None
184
+ layers.append(block(self.inplanes, planes, stride, downsample))
185
+
186
+ self.inplanes = planes * block.expansion
187
+ for _ in range(1, num_blocks):
188
+ layers.append(block(self.inplanes, planes))
189
+
190
+ # keep track of every block's conv size, stride size, and padding size
191
+ for each_block in layers:
192
+ (
193
+ block_kernel_sizes,
194
+ block_strides,
195
+ block_paddings,
196
+ ) = each_block.block_conv_info()
197
+ self.kernel_sizes.extend(block_kernel_sizes)
198
+ self.strides.extend(block_strides)
199
+ self.paddings.extend(block_paddings)
200
+
201
+ return nn.Sequential(*layers)
202
+
203
+ def forward(self, x):
204
+ x = self.conv1(x)
205
+ x = self.bn1(x)
206
+ x = self.relu(x)
207
+ x = self.maxpool(x)
208
+
209
+ x = self.layer1(x)
210
+ x = self.layer2(x)
211
+ x = self.layer3(x)
212
+ x = self.layer4(x)
213
+
214
+ return x
215
+
216
+ def conv_info(self):
217
+ return self.kernel_sizes, self.strides, self.paddings
218
+
219
+ def num_layers(self):
220
+ """
221
+ the number of conv layers in the network, not counting the number
222
+ of bypass layers
223
+ """
224
+
225
+ return (
226
+ self.block.num_layers * self.layers[0]
227
+ + self.block.num_layers * self.layers[1]
228
+ + self.block.num_layers * self.layers[2]
229
+ + self.block.num_layers * self.layers[3]
230
+ + 1
231
+ )
232
+
233
+ def __repr__(self):
234
+ template = "resnet{}_features"
235
+ return template.format(self.num_layers() + 1)
236
+
237
+
238
+ def resnet50_features(pretrained=True, inat=True, **kwargs):
239
+ """Constructs a ResNet-50 model.
240
+ Args:
241
+ pretrained (bool): If True, returns a model pre-trained on ImageNet or iNaturalist
242
+ pretrained (bool): If True, returns a model pre-trained on iNaturalst; else, ImageNet
243
+ """
244
+ model = ResNet_features(Bottleneck, [3, 4, 6, 4], **kwargs)
245
+ if pretrained:
246
+ if inat:
247
+ # print('Loading iNat model')
248
+ model_dict = torch.load(
249
+ model_dir
250
+ + "/../../weights/"
251
+ + "BBN.iNaturalist2017.res50.90epoch.best_model.pth.pt"
252
+ )
253
+ else:
254
+ raise
255
+
256
+ if inat:
257
+ model_dict.pop("module.classifier.weight")
258
+ model_dict.pop("module.classifier.bias")
259
+ for key in list(model_dict.keys()):
260
+ model_dict[
261
+ key.replace("module.backbone.", "")
262
+ .replace("cb_block", "layer4.2")
263
+ .replace("rb_block", "layer4.3")
264
+ ] = model_dict.pop(key)
265
+
266
+ else:
267
+ raise
268
+
269
+ model.load_state_dict(model_dict, strict=False)
270
+
271
+ return model
272
+
273
+
274
+ class ResNet_classifier(nn.Module):
275
+ """
276
+ A classifier for Deformable ProtoPNet
277
+ """
278
+
279
+ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
280
+ super(ResNet_classifier, self).__init__()
281
+
282
+ self.inplanes = 64
283
+
284
+ # the first convolutional layer before the structured sequence of blocks
285
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
286
+ self.bn1 = nn.BatchNorm2d(64)
287
+ self.relu = nn.ReLU(inplace=True)
288
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
289
+ # comes from the first conv and the following max pool
290
+ self.kernel_sizes = [7, 3]
291
+ self.strides = [2, 2]
292
+ self.paddings = [3, 1]
293
+
294
+ # the following layers, each layer is a sequence of blocks
295
+ self.block = block
296
+ self.layers = layers
297
+ self.layer1 = self._make_layer(
298
+ block=block, planes=64, num_blocks=self.layers[0]
299
+ )
300
+ self.layer2 = self._make_layer(
301
+ block=block, planes=128, num_blocks=self.layers[1], stride=2
302
+ )
303
+ self.layer3 = self._make_layer(
304
+ block=block, planes=256, num_blocks=self.layers[2], stride=2
305
+ )
306
+ self.layer4 = self._make_layer(
307
+ block=block, planes=512, num_blocks=self.layers[3], stride=2
308
+ )
309
+
310
+ self.classifier = nn.Linear(2048 * 7 * 7, 200)
311
+
312
+ # initialize the parameters
313
+ for m in self.modules():
314
+ if isinstance(m, nn.Conv2d):
315
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
316
+ elif isinstance(m, nn.BatchNorm2d):
317
+ nn.init.constant_(m.weight, 1)
318
+ nn.init.constant_(m.bias, 0)
319
+
320
+ # Zero-initialize the last BN in each residual branch,
321
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
322
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
323
+ if zero_init_residual:
324
+ for m in self.modules():
325
+ if isinstance(m, Bottleneck):
326
+ nn.init.constant_(m.bn3.weight, 0)
327
+ elif isinstance(m, BasicBlock):
328
+ nn.init.constant_(m.bn2.weight, 0)
329
+
330
+ def _make_layer(self, block, planes, num_blocks, stride=1):
331
+ downsample = None
332
+ if stride != 1 or self.inplanes != planes * block.expansion:
333
+ downsample = nn.Sequential(
334
+ conv1x1(self.inplanes, planes * block.expansion, stride),
335
+ nn.BatchNorm2d(planes * block.expansion),
336
+ )
337
+
338
+ layers = []
339
+ # only the first block has downsample that is possibly not None
340
+ layers.append(block(self.inplanes, planes, stride, downsample))
341
+
342
+ self.inplanes = planes * block.expansion
343
+ for _ in range(1, num_blocks):
344
+ layers.append(block(self.inplanes, planes))
345
+
346
+ # keep track of every block's conv size, stride size, and padding size
347
+ for each_block in layers:
348
+ (
349
+ block_kernel_sizes,
350
+ block_strides,
351
+ block_paddings,
352
+ ) = each_block.block_conv_info()
353
+ self.kernel_sizes.extend(block_kernel_sizes)
354
+ self.strides.extend(block_strides)
355
+ self.paddings.extend(block_paddings)
356
+
357
+ return nn.Sequential(*layers)
358
+
359
+ def forward(self, x):
360
+ x = self.conv1(x)
361
+ x = self.bn1(x)
362
+ x = self.relu(x)
363
+ x = self.maxpool(x)
364
+
365
+ x = self.layer1(x)
366
+ x = self.layer2(x)
367
+ x = self.layer3(x)
368
+ x = self.layer4(x)
369
+ x = self.classifier(torch.flatten(x, start_dim=1))
370
+ return x
371
+
372
+ def conv_info(self):
373
+ return self.kernel_sizes, self.strides, self.paddings
374
+
375
+ def num_layers(self):
376
+ """
377
+ the number of conv layers in the network, not counting the number
378
+ of bypass layers
379
+ """
380
+
381
+ return (
382
+ self.block.num_layers * self.layers[0]
383
+ + self.block.num_layers * self.layers[1]
384
+ + self.block.num_layers * self.layers[2]
385
+ + self.block.num_layers * self.layers[3]
386
+ + 1
387
+ )
388
+
389
+ def __repr__(self):
390
+ template = "resnet{}_features"
391
+ return template.format(self.num_layers() + 1)
Utils.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torchvision.models as models
4
+ from numpy import matlib as mb
5
+ from PIL import Image
6
+ from torch.utils.data import Dataset
7
+ from torchvision.datasets import ImageFolder
8
+ import torchvision.transforms as transforms
9
+ from FeatureExtractors import resnet50_features
10
+
11
+ to_np = lambda x: x.data.to("cpu").numpy()
12
+
13
+
14
+ def compute_spatial_similarity(conv1, conv2):
15
+ """
16
+ Takes in the last convolutional layer from two images, computes the pooled output
17
+ feature, and then generates the spatial similarity map for both images.
18
+ """
19
+ conv1 = conv1.reshape(-1, 7 * 7).T
20
+ conv2 = conv2.reshape(-1, 7 * 7).T
21
+
22
+ pool1 = np.mean(conv1, axis=0)
23
+ pool2 = np.mean(conv2, axis=0)
24
+ out_sz = (int(np.sqrt(conv1.shape[0])), int(np.sqrt(conv1.shape[0])))
25
+ conv1_normed = conv1 / np.linalg.norm(pool1) / conv1.shape[0]
26
+ conv2_normed = conv2 / np.linalg.norm(pool2) / conv2.shape[0]
27
+ im_similarity = np.zeros((conv1_normed.shape[0], conv1_normed.shape[0]))
28
+
29
+ for zz in range(conv1_normed.shape[0]):
30
+ repPx = mb.repmat(conv1_normed[zz, :], conv1_normed.shape[0], 1)
31
+ im_similarity[zz, :] = np.multiply(repPx, conv2_normed).sum(axis=1)
32
+ similarity1 = np.reshape(np.sum(im_similarity, axis=1), out_sz)
33
+ similarity2 = np.reshape(np.sum(im_similarity, axis=0), out_sz)
34
+ return similarity1, similarity2
35
+
36
+
37
+ def normalize_array(x):
38
+ x = np.asarray(x).copy()
39
+ x -= np.min(x)
40
+ x /= np.max(x)
41
+ return x
42
+
43
+
44
+ def apply_threshold(x, t):
45
+ x = np.asarray(x).copy()
46
+ x[x < t] = 0
47
+ return x
48
+
49
+
50
+ def generate_mask(x, t):
51
+ v = np.zeros_like(x)
52
+ v[x >= t] = 1
53
+ return v
54
+
55
+
56
+ def get_transforms(args_transform, chm_args):
57
+ # TRANSFORMS
58
+ cosine_transform_target = transforms.Compose(
59
+ [
60
+ transforms.Resize(256),
61
+ transforms.CenterCrop(224),
62
+ transforms.ToTensor(),
63
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
64
+ ]
65
+ )
66
+
67
+ chm_transform_target = transforms.Compose(
68
+ [
69
+ transforms.Resize(chm_args["img_size"]),
70
+ transforms.CenterCrop((chm_args["img_size"], chm_args["img_size"])),
71
+ transforms.ToTensor(),
72
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
73
+ ]
74
+ )
75
+
76
+ if args_transform == "multi":
77
+ cosine_transform_source = transforms.Compose(
78
+ [
79
+ transforms.Resize((224, 224)),
80
+ transforms.ToTensor(),
81
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
82
+ ]
83
+ )
84
+
85
+ chm_transform_source = transforms.Compose(
86
+ [
87
+ transforms.Resize((chm_args["img_size"], chm_args["img_size"])),
88
+ transforms.ToTensor(),
89
+ transforms.Normalize(
90
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
91
+ ),
92
+ ]
93
+ )
94
+
95
+ elif args_transform == "single":
96
+ cosine_transform_source = transforms.Compose(
97
+ [
98
+ transforms.Resize(chm_args["img_size"]),
99
+ transforms.CenterCrop((chm_args["img_size"], chm_args["img_size"])),
100
+ transforms.Resize((224, 224)),
101
+ transforms.ToTensor(),
102
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
103
+ ]
104
+ )
105
+
106
+ chm_transform_source = transforms.Compose(
107
+ [
108
+ transforms.Resize(chm_args["img_size"]),
109
+ transforms.CenterCrop((chm_args["img_size"], chm_args["img_size"])),
110
+ transforms.ToTensor(),
111
+ transforms.Normalize(
112
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
113
+ ),
114
+ ]
115
+ )
116
+
117
+ return (
118
+ chm_transform_source,
119
+ chm_transform_target,
120
+ cosine_transform_source,
121
+ cosine_transform_target,
122
+ )
123
+
124
+
125
+ def clamp(x, min_value, max_value):
126
+ return max(min_value, min(x, max_value))
127
+
128
+
129
+ def keep_top5(input_array, K=5):
130
+ top_5 = np.sort(input_array.reshape(-1))[::-1][K - 1]
131
+ masked = np.zeros_like(input_array)
132
+ masked[input_array >= top_5] = 1
133
+ return masked
134
+
135
+
136
+ def arg_topK(input_array, topK=5):
137
+ return np.argsort(input_array.T.reshape(-1))[::-1][:topK]
138
+
139
+
140
+ class KNNSupportSet:
141
+ def __init__(self, train_folder, val_folder, knn_scores, custom_val_labels=None):
142
+ self.train_data = ImageFolder(root=train_folder)
143
+ self.val_data = ImageFolder(root=val_folder)
144
+ self.knn_scores = knn_scores
145
+
146
+ if custom_val_labels is None:
147
+ self.val_labels = np.asarray([x[1] for x in self.val_data.imgs])
148
+ else:
149
+ self.val_labels = custom_val_labels
150
+
151
+ self.train_labels = np.asarray([x[1] for x in self.train_data.imgs])
152
+
153
+ def get_knn_predictions(self, k=20):
154
+ knn_predictions = [
155
+ np.argmax(np.bincount(self.train_labels[self.knn_scores[I][::-1][:k]]))
156
+ for I in range(len(self.knn_scores))
157
+ ]
158
+ knn_accuracy = (
159
+ 100
160
+ * np.sum((np.asarray(knn_predictions) == self.val_labels))
161
+ / len(self.val_labels)
162
+ )
163
+ return knn_predictions, knn_accuracy
164
+
165
+ def get_support_set(self, selected_index, top_N=20):
166
+ support_set = self.knn_scores[selected_index][-top_N:][::-1]
167
+ return [self.train_data.imgs[x][0] for x in support_set]
168
+
169
+ def get_support_set_labels(self, selected_index, top_N=20):
170
+ support_set = self.knn_scores[selected_index][-top_N:][::-1]
171
+ return [self.train_data.imgs[x][1] for x in support_set]
172
+
173
+ def get_image_and_label_by_id(self, q_id):
174
+ q = self.val_data.imgs[q_id][0]
175
+ ql = self.val_data.imgs[q_id][1]
176
+ return (q, ql)
177
+
178
+ def get_folder_name(self, q_id):
179
+ q = self.val_data.imgs[q_id][0]
180
+ return q.split("/")[-2]
181
+
182
+ def get_top5_knn(self, query_id, k=20):
183
+ knn_pred, knn_acc = self.get_knn_predictions(k=k)
184
+ top_5s_index = np.where(
185
+ np.equal(
186
+ self.train_labels[self.knn_scores[query_id][::-1]], knn_pred[query_id]
187
+ )
188
+ )[0][:5]
189
+ top_5s = self.knn_scores[query_id][::-1][top_5s_index]
190
+ top_5s_files = [self.train_data.imgs[x][0] for x in top_5s]
191
+ return top_5s_files
192
+
193
+ def get_topK_knn(self, query_id, k=20):
194
+ knn_pred, knn_acc = self.get_knn_predictions(k=k)
195
+ top_ks_index = np.where(
196
+ np.equal(
197
+ self.train_labels[self.knn_scores[query_id][::-1]], knn_pred[query_id]
198
+ )
199
+ )[0][:k]
200
+ top_ks = self.knn_scores[query_id][::-1][top_ks_index]
201
+ top_ks_files = [self.train_data.imgs[x][0] for x in top_ks]
202
+ return top_ks_files
203
+
204
+ def get_foldername_for_label(self, label):
205
+ for i in range(len(self.train_data)):
206
+ if self.train_data.imgs[i][1] == label:
207
+ return self.train_data.imgs[i][0].split("/")[-2]
208
+
209
+ def get_knn_confidence(self, query_id, k=20):
210
+ return np.max(
211
+ np.bincount(self.train_labels[self.knn_scores[query_id][::-1][:k]])
212
+ )
213
+
214
+
215
+ class CosineCustomDataset(Dataset):
216
+ r"""Parent class of PFPascal, PFWillow, and SPair"""
217
+
218
+ def __init__(self, query_image, supporting_set, source_transform, target_transform):
219
+ r"""XAICustomDataset constructor"""
220
+ super(CosineCustomDataset, self).__init__()
221
+
222
+ self.supporting_set = supporting_set
223
+ self.query_image = [query_image] * len(supporting_set)
224
+
225
+ self.source_transform = source_transform
226
+ self.target_transform = target_transform
227
+
228
+ def __len__(self):
229
+ r"""Returns the number of pairs"""
230
+ return len(self.supporting_set)
231
+
232
+ def __getitem__(self, idx):
233
+ r"""Constructs and return a batch"""
234
+
235
+ # Image name
236
+ batch = dict()
237
+ batch["src_imname"] = self.query_image[idx]
238
+ batch["trg_imname"] = self.supporting_set[idx]
239
+
240
+ # Image as numpy (original width, original height)
241
+ src_pil = self.get_image(self.query_image, idx)
242
+ trg_pil = self.get_image(self.supporting_set, idx)
243
+
244
+ batch["src_imsize"] = src_pil.size
245
+ batch["trg_imsize"] = trg_pil.size
246
+
247
+ # Image as tensor
248
+ batch["src_img"] = self.source_transform(src_pil)
249
+ batch["trg_img"] = self.target_transform(trg_pil)
250
+
251
+ # Total number of pairs in training split
252
+ batch["datalen"] = len(self.query_image)
253
+ return batch
254
+
255
+ def get_image(self, image_pathes, idx):
256
+ r"""Reads PIL image from path"""
257
+ path = image_pathes[idx]
258
+ return Image.open(path).convert("RGB")
259
+
260
+
261
+ class PairedLayer4Extractor(torch.nn.Module):
262
+ """
263
+ Extracting layer-4 embedding for source and target images using ResNet-50 features
264
+ """
265
+
266
+ def __init__(self):
267
+ super(PairedLayer4Extractor, self).__init__()
268
+
269
+ self.modelA = models.resnet50(pretrained=True)
270
+ self.modelA.eval()
271
+
272
+ self.modelB = models.resnet50(pretrained=True)
273
+ self.modelB.eval()
274
+
275
+ self.a_embeddings = None
276
+ self.b_embeddings = None
277
+
278
+ def a_hook(module, input, output):
279
+ self.a_embeddings = output
280
+
281
+ def b_hook(module, input, output):
282
+ self.b_embeddings = output
283
+
284
+ self.modelA._modules.get("layer4").register_forward_hook(a_hook)
285
+ self.modelB._modules.get("layer4").register_forward_hook(b_hook)
286
+
287
+ def forward(self, inputs):
288
+ inputA, inputB = inputs
289
+ self.modelA(inputA)
290
+ self.modelB(inputB)
291
+
292
+ return self.a_embeddings, self.b_embeddings
293
+
294
+ def __repr__(self):
295
+ return "PairedLayer4Extractor"
296
+
297
+
298
+ class iNaturalistPairedLayer4Extractor(torch.nn.Module):
299
+ """
300
+ Extracting layer-4 embedding for source and target images using iNaturalist ResNet-50 features
301
+ """
302
+
303
+ def __init__(self):
304
+ super(iNaturalistPairedLayer4Extractor, self).__init__()
305
+
306
+ self.modelA = resnet50_features(inat=True, pretrained=True)
307
+ self.modelA.eval()
308
+
309
+ self.modelB = resnet50_features(inat=True, pretrained=True)
310
+ self.modelB.eval()
311
+
312
+ self.source_embedding = None
313
+ self.target_embedding = None
314
+
315
+ def forward(self, inputs):
316
+ source_image, target_image = inputs
317
+ self.source_embedding = self.modelA(source_image)
318
+ self.target_embedding = self.modelB(target_image)
319
+
320
+ return self.source_embedding, self.target_embedding
321
+
322
+ def __repr__(self):
323
+ return "iNatPairedLayer4Extractor"
app.py CHANGED
@@ -10,12 +10,21 @@ from torchvision.datasets import ImageFolder
10
 
11
  from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
12
  from ExtractEmbedding import QueryToEmbedding
 
 
13
 
14
  csv.field_size_limit(sys.maxsize)
15
 
16
  concat = lambda x: np.concatenate(x, axis=0)
17
 
18
- gdown.download(id="116CiA_cXciGSl72tbAUDoN-f1B9Frp89")
 
 
 
 
 
 
 
19
  gdown.download(id="1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e")
20
 
21
  # CUB training set
@@ -26,13 +35,21 @@ gdown.cached_download(
26
  md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
27
  )
28
 
29
- # EXTRACT
30
  torchvision.datasets.utils.extract_archive(
31
  from_path="CUB_train.zip",
32
- to_path="Training/",
33
  remove_finished=False,
34
  )
35
 
 
 
 
 
 
 
 
 
36
 
37
  # Caluclate Accuracy
38
  with open(f"./embeddings.pickle", "rb") as f:
@@ -45,35 +62,53 @@ searcher = SearchableTrainingSet(Xtrain, ytrain)
45
  searcher.build_index()
46
 
47
  # Extract label names
48
- training_folder = ImageFolder(root="./Training/train/")
49
  id_to_bird_name = {
50
  x[1]: x[0].split("/")[-2].replace(".", " ") for x in training_folder.imgs
51
  }
52
 
53
 
54
- def search(query_imag, searcher=searcher):
55
- query_embedding = QueryToEmbedding(query_imag)
56
- indices, scores, labels = searcher.search(query_embedding, k=50)
57
 
58
  result_ctr = Counter(labels[0][:20]).most_common(5)
59
 
60
  top1_label = result_ctr[0][0]
61
  top_indices = []
62
 
63
- for a, b in zip(labels[0][:20], scores[0][:20]):
64
  if a == top1_label:
65
  top_indices.append(b)
66
 
67
  gallery_images = [training_folder.imgs[int(X)][0] for X in top_indices[:5]]
68
  predicted_labels = {id_to_bird_name[X[0]]: X[1] / 20.0 for X in result_ctr}
69
 
70
- return predicted_labels, gallery_images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
 
73
  demo = gr.Interface(
74
  search,
75
- gr.Image(type="pil"),
76
- ["label", "gallery"],
77
  examples=[["./examples/bird.jpg"]],
78
  description="WIP - kNN on CUB dataset",
79
  title="Work in Progress - CHM-Corr",
 
10
 
11
  from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
12
  from ExtractEmbedding import QueryToEmbedding
13
+ from CHMCorr import chm_classify_and_visualize
14
+ from visualization import plot_from_reranker_output
15
 
16
  csv.field_size_limit(sys.maxsize)
17
 
18
  concat = lambda x: np.concatenate(x, axis=0)
19
 
20
+ # Embeddings
21
+ gdown.cached_download(
22
+ url="https://drive.google.com/uc?id=116CiA_cXciGSl72tbAUDoN-f1B9Frp89",
23
+ path="./embeddings.pkl",
24
+ quiet=False,
25
+ md5="002b2a7f5c80d910b9cc740c2265f058",
26
+ )
27
+
28
  gdown.download(id="1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e")
29
 
30
  # CUB training set
 
35
  md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
36
  )
37
 
38
+ # EXTRACT training set
39
  torchvision.datasets.utils.extract_archive(
40
  from_path="CUB_train.zip",
41
+ to_path="data/",
42
  remove_finished=False,
43
  )
44
 
45
+ # CHM Weights
46
+ gdown.cached_download(
47
+ url="https://drive.google.com/u/0/uc?id=1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6&export=download",
48
+ path="pas_psi.pt",
49
+ quiet=False,
50
+ md5="6b7b4d7bad7f89600fac340d6aa7708b",
51
+ )
52
+
53
 
54
  # Caluclate Accuracy
55
  with open(f"./embeddings.pickle", "rb") as f:
 
62
  searcher.build_index()
63
 
64
  # Extract label names
65
+ training_folder = ImageFolder(root="./data/train/")
66
  id_to_bird_name = {
67
  x[1]: x[0].split("/")[-2].replace(".", " ") for x in training_folder.imgs
68
  }
69
 
70
 
71
+ def search(query_image, searcher=searcher):
72
+ query_embedding = QueryToEmbedding(query_image)
73
+ scores, indices, labels = searcher.search(query_embedding, k=50)
74
 
75
  result_ctr = Counter(labels[0][:20]).most_common(5)
76
 
77
  top1_label = result_ctr[0][0]
78
  top_indices = []
79
 
80
+ for a, b in zip(labels[0][:20], indices[0][:20]):
81
  if a == top1_label:
82
  top_indices.append(b)
83
 
84
  gallery_images = [training_folder.imgs[int(X)][0] for X in top_indices[:5]]
85
  predicted_labels = {id_to_bird_name[X[0]]: X[1] / 20.0 for X in result_ctr}
86
 
87
+ print("gallery_images:", gallery_images)
88
+
89
+ # CHM Prediction
90
+ kNN_results = (top1_label, result_ctr[0][1], gallery_images)
91
+ support_files = [training_folder.imgs[int(X)][0] for X in indices[0]]
92
+
93
+ print(support_files)
94
+ support_labels = [training_folder.imgs[int(X)][1] for X in indices[0]]
95
+ print(support_labels)
96
+
97
+ support = [support_files, support_labels]
98
+
99
+ chm_output = chm_classify_and_visualize(
100
+ query_image, kNN_results, support, training_folder
101
+ )
102
+
103
+ viz_plot = plot_from_reranker_output(chm_output, draw_arcs=False)
104
+
105
+ return predicted_labels, gallery_images, viz_plot
106
 
107
 
108
  demo = gr.Interface(
109
  search,
110
+ gr.Image(type="filepath"),
111
+ ["label", "gallery", "plot"],
112
  examples=[["./examples/bird.jpg"]],
113
  description="WIP - kNN on CUB dataset",
114
  title="Work in Progress - CHM-Corr",
common/evaluation.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r""" Evaluates CHMNet with PCK """
2
+
3
+ import torch
4
+
5
+
6
+ class Evaluator:
7
+ r""" Computes evaluation metrics of PCK """
8
+ @classmethod
9
+ def initialize(cls, alpha):
10
+ cls.alpha = torch.tensor(alpha).unsqueeze(1)
11
+
12
+ @classmethod
13
+ def evaluate(cls, prd_kps, batch):
14
+ r""" Compute percentage of correct key-points (PCK) with multiple alpha {0.05, 0.1, 0.15 }"""
15
+
16
+ pcks = []
17
+ for idx, (pk, tk) in enumerate(zip(prd_kps, batch['trg_kps'])):
18
+ pckthres = batch['pckthres'][idx]
19
+ npt = batch['n_pts'][idx]
20
+ prd_kps = pk[:, :npt]
21
+ trg_kps = tk[:, :npt]
22
+
23
+ l2dist = (prd_kps - trg_kps).pow(2).sum(dim=0).pow(0.5).unsqueeze(0).repeat(len(cls.alpha), 1)
24
+ thres = pckthres.expand_as(l2dist).float() * cls.alpha
25
+ pck = torch.le(l2dist, thres).sum(dim=1) / float(npt)
26
+ if len(pck) == 1: pck = pck[0]
27
+ pcks.append(pck)
28
+
29
+ eval_result = {'pck': pcks}
30
+
31
+ return eval_result
32
+
common/logger.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r""" Logging """
2
+
3
+ import datetime
4
+ import logging
5
+ import os
6
+
7
+ from tensorboardX import SummaryWriter
8
+ import torch
9
+
10
+
11
+ class Logger:
12
+ r""" Writes results of training/testing """
13
+ @classmethod
14
+ def initialize(cls, args, training):
15
+ logtime = datetime.datetime.now().__format__('_%m%d_%H%M%S')
16
+ logpath = args.logpath if training else '_TEST_' + args.load.split('/')[-1].split('.')[0] + logtime
17
+ if logpath == '': logpath = logtime
18
+
19
+ cls.logpath = os.path.join('logs', logpath + '.log')
20
+ cls.benchmark = args.benchmark
21
+ os.makedirs(cls.logpath)
22
+
23
+ logging.basicConfig(filemode='w',
24
+ filename=os.path.join(cls.logpath, 'log.txt'),
25
+ level=logging.INFO,
26
+ format='%(message)s',
27
+ datefmt='%m-%d %H:%M:%S')
28
+
29
+ # Console log config
30
+ console = logging.StreamHandler()
31
+ console.setLevel(logging.INFO)
32
+ formatter = logging.Formatter('%(message)s')
33
+ console.setFormatter(formatter)
34
+ logging.getLogger('').addHandler(console)
35
+
36
+ # Tensorboard writer
37
+ cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs'))
38
+
39
+ # Log arguments
40
+ if training:
41
+ logging.info(':======== Convolutional Hough Matching Networks =========')
42
+ for arg_key in args.__dict__:
43
+ logging.info('| %20s: %-24s' % (arg_key, str(args.__dict__[arg_key])))
44
+ logging.info(':========================================================\n')
45
+
46
+ @classmethod
47
+ def info(cls, msg):
48
+ r""" Writes message to .txt """
49
+ logging.info(msg)
50
+
51
+ @classmethod
52
+ def save_model(cls, model, epoch, val_pck):
53
+ torch.save(model.state_dict(), os.path.join(cls.logpath, 'pck_best_model.pt'))
54
+ cls.info('Model saved @%d w/ val. PCK: %5.2f.\n' % (epoch, val_pck))
55
+
56
+
57
+ class AverageMeter:
58
+ r""" Stores loss, evaluation results, selected layers """
59
+ def __init__(self, benchamrk):
60
+ r""" Constructor of AverageMeter """
61
+ self.buffer_keys = ['pck']
62
+ self.buffer = {}
63
+ for key in self.buffer_keys:
64
+ self.buffer[key] = []
65
+
66
+ self.loss_buffer = []
67
+
68
+ def update(self, eval_result, loss=None):
69
+ for key in self.buffer_keys:
70
+ self.buffer[key] += eval_result[key]
71
+
72
+ if loss is not None:
73
+ self.loss_buffer.append(loss)
74
+
75
+ def write_result(self, split, epoch):
76
+ msg = '\n*** %s ' % split
77
+ msg += '[@Epoch %02d] ' % epoch
78
+
79
+ if len(self.loss_buffer) > 0:
80
+ msg += 'Loss: %5.2f ' % (sum(self.loss_buffer) / len(self.loss_buffer))
81
+
82
+ for key in self.buffer_keys:
83
+ msg += '%s: %6.2f ' % (key.upper(), sum(self.buffer[key]) / len(self.buffer[key]))
84
+ msg += '***\n'
85
+ Logger.info(msg)
86
+
87
+ def write_process(self, batch_idx, datalen, epoch):
88
+ msg = '[Epoch: %02d] ' % epoch
89
+ msg += '[Batch: %04d/%04d] ' % (batch_idx+1, datalen)
90
+ if len(self.loss_buffer) > 0:
91
+ msg += 'Loss: %5.2f ' % self.loss_buffer[-1]
92
+ msg += 'Avg Loss: %5.5f ' % (sum(self.loss_buffer) / len(self.loss_buffer))
93
+
94
+ for key in self.buffer_keys:
95
+ msg += 'Avg %s: %5.2f ' % (key.upper(), sum(self.buffer[key]) / len(self.buffer[key]) * 100)
96
+ Logger.info(msg)
97
+
98
+ def write_test_process(self, batch_idx, datalen):
99
+ msg = '[Batch: %04d/%04d] ' % (batch_idx+1, datalen)
100
+
101
+ for key in self.buffer_keys:
102
+ if key == 'pck':
103
+ pcks = torch.stack(self.buffer[key]).mean(dim=0) * 100
104
+ val = ''
105
+ for p in pcks:
106
+ val += '%5.2f ' % p.item()
107
+ msg += 'Avg %s: %s ' % (key.upper(), val)
108
+ else:
109
+ msg += 'Avg %s: %5.2f ' % (key.upper(), sum(self.buffer[key]) / len(self.buffer[key]))
110
+ Logger.info(msg)
111
+
112
+ def get_test_result(self):
113
+ result = {}
114
+ for key in self.buffer_keys:
115
+ result[key] = torch.stack(self.buffer[key]).mean(dim=0) * 100
116
+
117
+ return result
examples/Red_Winged_Blackbird_0012_6015.jpg ADDED
examples/Red_Winged_Blackbird_0025_5342.jpg ADDED
examples/Yellow_Headed_Blackbird_0020_8549.jpg ADDED
examples/Yellow_Headed_Blackbird_0026_8545.jpg ADDED
examples/sample1.jpeg ADDED
examples/sample2.jpeg ADDED
model/base/backbone.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r""" ResNet-101 backbone network """
2
+
3
+ import torch.utils.model_zoo as model_zoo
4
+ import torch.nn as nn
5
+ import torch
6
+
7
+
8
+ __all__ = ['Backbone', 'resnet101']
9
+
10
+
11
+ model_urls = {
12
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
13
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
14
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
15
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
16
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
17
+ }
18
+
19
+
20
+ def conv3x3(in_planes, out_planes, stride=1):
21
+ r""" 3x3 convolution with padding """
22
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
23
+ padding=1, groups=2, bias=False)
24
+
25
+
26
+ def conv1x1(in_planes, out_planes, stride=1):
27
+ r""" 1x1 convolution """
28
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, groups=2, bias=False)
29
+
30
+
31
+ class Bottleneck(nn.Module):
32
+ expansion = 4
33
+
34
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
35
+ super(Bottleneck, self).__init__()
36
+ self.conv1 = conv1x1(inplanes, planes)
37
+ self.bn1 = nn.BatchNorm2d(planes)
38
+ self.conv2 = conv3x3(planes, planes, stride)
39
+ self.bn2 = nn.BatchNorm2d(planes)
40
+ self.conv3 = conv1x1(planes, planes * self.expansion)
41
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
42
+ self.relu = nn.ReLU(inplace=True)
43
+ self.downsample = downsample
44
+ self.stride = stride
45
+
46
+ def forward(self, x):
47
+ identity = x
48
+
49
+ out = self.conv1(x)
50
+ out = self.bn1(out)
51
+ out = self.relu(out)
52
+
53
+ out = self.conv2(out)
54
+ out = self.bn2(out)
55
+ out = self.relu(out)
56
+
57
+ out = self.conv3(out)
58
+ out = self.bn3(out)
59
+
60
+ if self.downsample is not None:
61
+ identity = self.downsample(x)
62
+
63
+ out += identity
64
+ out = self.relu(out)
65
+
66
+ return out
67
+
68
+
69
+ class Backbone(nn.Module):
70
+ def __init__(self, block, layers, zero_init_residual=False):
71
+ super(Backbone, self).__init__()
72
+
73
+ self.inplanes = 128
74
+ self.conv1 = nn.Conv2d(6, 128, kernel_size=7, stride=2, padding=3, groups=2,
75
+ bias=False)
76
+ self.bn1 = nn.BatchNorm2d(128)
77
+ self.relu = nn.ReLU(inplace=True)
78
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
79
+ self.layer1 = self._make_layer(block, 128, layers[0])
80
+ self.layer2 = self._make_layer(block, 256, layers[1], stride=2)
81
+ self.layer3 = self._make_layer(block, 512, layers[2], stride=2)
82
+ self.layer4 = self._make_layer(block, 1024, layers[3], stride=2)
83
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
84
+ self.fc = nn.Linear(512 * block.expansion, 1000)
85
+
86
+ for m in self.modules():
87
+ if isinstance(m, nn.Conv2d):
88
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
89
+ elif isinstance(m, nn.BatchNorm2d):
90
+ nn.init.constant_(m.weight, 1)
91
+ nn.init.constant_(m.bias, 0)
92
+
93
+ # Zero-initialize the last BN in each residual branch,
94
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
95
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
96
+ if zero_init_residual:
97
+ for m in self.modules():
98
+ if isinstance(m, Bottleneck):
99
+ nn.init.constant_(m.bn3.weight, 0)
100
+
101
+ def _make_layer(self, block, planes, blocks, stride=1):
102
+ downsample = None
103
+ if stride != 1 or self.inplanes != planes * block.expansion:
104
+ downsample = nn.Sequential(
105
+ conv1x1(self.inplanes, planes * block.expansion, stride),
106
+ nn.BatchNorm2d(planes * block.expansion),
107
+ )
108
+
109
+ layers = []
110
+ layers.append(block(self.inplanes, planes, stride, downsample))
111
+ self.inplanes = planes * block.expansion
112
+ for _ in range(1, blocks):
113
+ layers.append(block(self.inplanes, planes))
114
+
115
+ return nn.Sequential(*layers)
116
+
117
+
118
+ def resnet101(pretrained=False, **kwargs):
119
+ """Constructs a ResNet-101 model.
120
+
121
+ Args:
122
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
123
+ """
124
+ model = Backbone(Bottleneck, [3, 4, 23, 3], **kwargs)
125
+ if pretrained:
126
+ weights = model_zoo.load_url(model_urls['resnet101'])
127
+
128
+ for key in weights:
129
+ if key.split('.')[0] == 'fc':
130
+ weights[key] = weights[key].clone()
131
+ continue
132
+ weights[key] = torch.cat([weights[key].clone(), weights[key].clone()], dim=0)
133
+
134
+ model.load_state_dict(weights)
135
+ return model
136
+
model/base/chm.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r""" 4D and 6D convolutional Hough matching layers """
2
+
3
+ from torch.nn.modules.conv import _ConvNd
4
+ import torch.nn.functional as F
5
+ import torch.nn as nn
6
+ import torch
7
+
8
+ from common.logger import Logger
9
+ from . import chm_kernel
10
+
11
+
12
+ def fast4d(corr, kernel, bias=None):
13
+ r""" Optimized implementation of 4D convolution """
14
+ bsz, ch, srch, srcw, trgh, trgw = corr.size()
15
+ out_channels, _, kernel_size, kernel_size, kernel_size, kernel_size = kernel.size()
16
+ psz = kernel_size // 2
17
+
18
+ out_corr = torch.zeros((bsz, out_channels, srch, srcw, trgh, trgw))
19
+ corr = corr.transpose(1, 2).contiguous().view(bsz * srch, ch, srcw, trgh, trgw)
20
+
21
+ for pidx, k3d in enumerate(kernel.permute(2, 0, 1, 3, 4, 5)):
22
+ inter_corr = F.conv3d(corr, k3d, bias=None, stride=1, padding=psz)
23
+ inter_corr = inter_corr.view(bsz, srch, out_channels, srcw, trgh, trgw).transpose(1, 2).contiguous()
24
+
25
+ add_sid = max(psz - pidx, 0)
26
+ add_fid = min(srch, srch + psz - pidx)
27
+ slc_sid = max(pidx - psz, 0)
28
+ slc_fid = min(srch, srch - psz + pidx)
29
+
30
+ out_corr[:, :, add_sid:add_fid, :, :, :] += inter_corr[:, :, slc_sid:slc_fid, :, :, :]
31
+
32
+ if bias is not None:
33
+ out_corr += bias.view(1, out_channels, 1, 1, 1, 1)
34
+
35
+ return out_corr
36
+
37
+
38
+ def fast6d(corr, kernel, bias, diagonal_idx):
39
+ r""" Optimized implementation of 6D convolutional Hough matching
40
+ NOTE: this function only supports kernel size of (3, 3, 5, 5, 5, 5).
41
+ r"""
42
+ bsz, _, s6d, s6d, s4d, s4d, s4d, s4d = corr.size()
43
+ _, _, ks6d, ks6d, ks4d, ks4d, ks4d, ks4d = kernel.size()
44
+ corr = corr.permute(0, 2, 3, 1, 4, 5, 6, 7).contiguous().view(-1, 1, s4d, s4d, s4d, s4d)
45
+ kernel = kernel.view(-1, ks6d ** 2, ks4d, ks4d, ks4d, ks4d).transpose(0, 1)
46
+ corr = fast4d(corr, kernel).view(bsz, s6d * s6d, ks6d * ks6d, s4d, s4d, s4d, s4d)
47
+ corr = corr.view(bsz, s6d, s6d, ks6d, ks6d, s4d, s4d, s4d, s4d).transpose(2, 3).\
48
+ contiguous().view(-1, s6d * ks6d, s4d, s4d, s4d, s4d)
49
+
50
+ ndiag = s6d + (ks6d // 2) * 2
51
+ first_sum = []
52
+ for didx in diagonal_idx:
53
+ first_sum.append(corr[:, didx, :, :, :, :].sum(dim=1))
54
+ first_sum = torch.stack(first_sum).transpose(0, 1).view(bsz, s6d * ks6d, ndiag, s4d, s4d, s4d, s4d)
55
+
56
+ corr = []
57
+ for didx in diagonal_idx:
58
+ corr.append(first_sum[:, didx, :, :, :, :, :].sum(dim=1))
59
+ sidx = ks6d // 2
60
+ eidx = ndiag - sidx
61
+ corr = torch.stack(corr).transpose(0, 1)[:, sidx:eidx, sidx:eidx, :, :, :, :].unsqueeze(1).contiguous()
62
+ corr += bias.view(1, -1, 1, 1, 1, 1, 1, 1)
63
+
64
+ reverse_idx = torch.linspace(s6d * s6d - 1, 0, s6d * s6d).long()
65
+ corr = corr.view(bsz, 1, s6d * s6d, s4d, s4d, s4d, s4d)[:, :, reverse_idx, :, :, :, :].\
66
+ view(bsz, 1, s6d, s6d, s4d, s4d, s4d, s4d)
67
+ return corr
68
+
69
+ def init_param_idx4d(param_dict):
70
+ param_idx = []
71
+ for key in param_dict:
72
+ curr_offset = int(key.split('_')[-1])
73
+ param_idx.append(torch.tensor(param_dict[key]))
74
+ return param_idx
75
+
76
+ class CHM4d(_ConvNd):
77
+ r""" 4D convolutional Hough matching layer
78
+ NOTE: this function only supports in_channels=1 and out_channels=1.
79
+ r"""
80
+ def __init__(self, in_channels, out_channels, ksz4d, ktype, bias=True):
81
+ super(CHM4d, self).__init__(in_channels, out_channels, (ksz4d,) * 4,
82
+ (1,) * 4, (0,) * 4, (1,) * 4, False, (0,) * 4,
83
+ 1, bias, padding_mode='zeros')
84
+
85
+ # Zero kernel initialization
86
+ self.zero_kernel4d = torch.zeros((in_channels, out_channels, ksz4d, ksz4d, ksz4d, ksz4d))
87
+ self.nkernels = in_channels * out_channels
88
+
89
+ # Initialize kernel indices
90
+ param_dict4d = chm_kernel.KernelGenerator(ksz4d, ktype).generate()
91
+ param_shared = param_dict4d is not None
92
+
93
+ if param_shared:
94
+ # Initialize the shared parameters (multiplied by the number of times being shared)
95
+ self.param_idx = init_param_idx4d(param_dict4d)
96
+ weights = torch.abs(torch.randn(len(self.param_idx) * self.nkernels)) * 1e-3
97
+ for weight, param_idx in zip(weights.sort()[0], self.param_idx):
98
+ weight *= len(param_idx)
99
+ self.weight = nn.Parameter(weights)
100
+ else: # full kernel initialziation
101
+ self.param_idx = None
102
+ self.weight = nn.Parameter(torch.abs(self.weight))
103
+ if bias: self.bias = nn.Parameter(torch.tensor(0.0))
104
+ Logger.info('(%s) # params in CHM 4D: %d' % (ktype, len(self.weight.view(-1))))
105
+
106
+ def forward(self, x):
107
+ kernel = self.init_kernel()
108
+ x = fast4d(x, kernel, self.bias)
109
+ return x
110
+
111
+ def init_kernel(self):
112
+ # Initialize CHM kernel (divided by the number of times being shared)
113
+ ksz = self.kernel_size[-1]
114
+ if self.param_idx is None:
115
+ kernel = self.weight
116
+ else:
117
+ kernel = torch.zeros_like(self.zero_kernel4d)
118
+ for idx, pdx in enumerate(self.param_idx):
119
+ kernel = kernel.view(-1, ksz, ksz, ksz, ksz)
120
+ for jdx, kernel_single in enumerate(kernel):
121
+ weight = self.weight[idx + jdx * len(self.param_idx)].repeat(len(pdx)) / len(pdx)
122
+ kernel_single.view(-1)[pdx] += weight
123
+ kernel = kernel.view(self.in_channels, self.out_channels, ksz, ksz, ksz, ksz)
124
+ return kernel
125
+
126
+
127
+ class CHM6d(_ConvNd):
128
+ r""" 6D convolutional Hough matching layer with kernel (3, 3, 5, 5, 5, 5)
129
+ NOTE: this function only supports in_channels=1 and out_channels=1.
130
+ r"""
131
+ def __init__(self, in_channels, out_channels, ksz6d, ksz4d, ktype):
132
+ kernel_size = (ksz6d, ksz6d, ksz4d, ksz4d, ksz4d, ksz4d)
133
+ super(CHM6d, self).__init__(in_channels, out_channels, kernel_size, (1,) * 6,
134
+ (0,) * 6, (1,) * 6, False, (0,) * 6,
135
+ 1, bias=True, padding_mode='zeros')
136
+
137
+ # Zero kernel initialization
138
+ self.zero_kernel4d = torch.zeros((ksz4d, ksz4d, ksz4d, ksz4d))
139
+ self.zero_kernel6d = torch.zeros((ksz6d, ksz6d, ksz4d, ksz4d, ksz4d, ksz4d))
140
+ self.nkernels = in_channels * out_channels
141
+
142
+ # Initialize kernel indices
143
+ # Indices in scale-space where 4D convolutions are performed (3 by 3 scale-space)
144
+ self.diagonal_idx = [torch.tensor(x) for x in [[6], [3, 7], [0, 4, 8], [1, 5], [2]]]
145
+ param_dict4d = chm_kernel.KernelGenerator(ksz4d, ktype).generate()
146
+ param_shared = param_dict4d is not None
147
+
148
+ if param_shared: # psi & iso kernel initialization
149
+ if ktype == 'psi':
150
+ self.param_dict6d = [[4], [0, 8], [2, 6], [1, 3, 5, 7]]
151
+ elif ktype == 'iso':
152
+ self.param_dict6d = [[0, 4, 8], [2, 6], [1, 3, 5, 7]]
153
+ self.param_dict6d = [torch.tensor(i) for i in self.param_dict6d]
154
+
155
+ # Initialize the shared parameters (multiplied by the number of times being shared)
156
+ self.param_idx = init_param_idx4d(param_dict4d)
157
+ self.param = []
158
+ for param_dict6d in self.param_dict6d:
159
+ weights = torch.abs(torch.randn(len(self.param_idx))) * 1e-3
160
+ for weight, param_idx in zip(weights, self.param_idx):
161
+ weight *= (len(param_idx) * len(param_dict6d))
162
+ self.param.append(nn.Parameter(weights))
163
+ self.param = nn.ParameterList(self.param)
164
+ else: # full kernel initialziation
165
+ self.param_idx = None
166
+ self.param = nn.Parameter(torch.abs(self.weight) * 1e-3)
167
+ Logger.info('(%s) # params in CHM 6D: %d' % (ktype, sum([len(x.view(-1)) for x in self.param])))
168
+ self.weight = None
169
+
170
+ def forward(self, corr):
171
+ kernel = self.init_kernel()
172
+ corr = fast6d(corr, kernel, self.bias, self.diagonal_idx)
173
+ return corr
174
+
175
+ def init_kernel(self):
176
+ # Initialize CHM kernel (divided by the number of times being shared)
177
+ if self.param_idx is None:
178
+ return self.param
179
+
180
+ kernel6d = torch.zeros_like(self.zero_kernel6d)
181
+ for idx, (param, param_dict6d) in enumerate(zip(self.param, self.param_dict6d)):
182
+ ksz4d = self.kernel_size[-1]
183
+ kernel4d = torch.zeros_like(self.zero_kernel4d)
184
+ for jdx, pdx in enumerate(self.param_idx):
185
+ kernel4d.view(-1)[pdx] += ((param[jdx] / len(pdx)) / len(param_dict6d))
186
+ kernel6d.view(-1, ksz4d, ksz4d, ksz4d, ksz4d)[param_dict6d] += kernel4d.view(ksz4d, ksz4d, ksz4d, ksz4d)
187
+ kernel6d = kernel6d.unsqueeze(0).unsqueeze(0)
188
+
189
+ return kernel6d
190
+
model/base/chm_kernel.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r""" CHM 4D kernel (psi, iso, and full) generator """
2
+
3
+ import torch
4
+
5
+ from .geometry import Geometry
6
+
7
+
8
+ class KernelGenerator:
9
+ def __init__(self, ksz, ktype):
10
+ self.ksz = ksz
11
+ self.idx4d = Geometry.init_idx4d(ksz)
12
+ self.kernel = torch.zeros((ksz, ksz, ksz, ksz))
13
+ self.center = (ksz // 2, ksz // 2)
14
+ self.ktype = ktype
15
+
16
+ def quadrant(self, crd):
17
+ if crd[0] < self.center[0]:
18
+ horz_quad = -1
19
+ elif crd[0] < self.center[0]:
20
+ horz_quad = 1
21
+ else:
22
+ horz_quad = 0
23
+
24
+ if crd[1] < self.center[1]:
25
+ vert_quad = -1
26
+ elif crd[1] < self.center[1]:
27
+ vert_quad = 1
28
+ else:
29
+ vert_quad = 0
30
+
31
+ return horz_quad, vert_quad
32
+
33
+ def generate(self):
34
+ return None if self.ktype == 'full' else self.generate_chm_kernel()
35
+
36
+ def generate_chm_kernel(self):
37
+ param_dict = {}
38
+ for idx in self.idx4d:
39
+ src_i, src_j, trg_i, trg_j = idx
40
+ d_tail = Geometry.get_distance((src_i, src_j), self.center)
41
+ d_head = Geometry.get_distance((trg_i, trg_j), self.center)
42
+ d_off = Geometry.get_distance((src_i, src_j), (trg_i, trg_j))
43
+ horz_quad, vert_quad = self.quadrant((src_j, src_i))
44
+
45
+ src_crd = (src_i, src_j)
46
+ trg_crd = (trg_i, trg_j)
47
+
48
+ key = self.build_key(horz_quad, vert_quad, d_head, d_tail, src_crd, trg_crd, d_off)
49
+ coord1d = Geometry.get_coord1d((src_i, src_j, trg_i, trg_j), self.ksz)
50
+
51
+ if param_dict.get(key) is None: param_dict[key] = []
52
+ param_dict[key].append(coord1d)
53
+
54
+ return param_dict
55
+
56
+ def build_key(self, horz_quad, vert_quad, d_head, d_tail, src_crd, trg_crd, d_off):
57
+
58
+ if self.ktype == 'iso':
59
+ return '%d' % d_off
60
+ elif self.ktype == 'psi':
61
+ d_max = max(d_head, d_tail)
62
+ d_min = min(d_head, d_tail)
63
+ return '%d_%d_%d' % (d_max, d_min, d_off)
64
+ else:
65
+ raise Exception('not implemented.')
66
+
model/base/correlation.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r""" Provides functions that creates/manipulates correlation matrices """
2
+
3
+ import math
4
+
5
+ from torch.nn.functional import interpolate as resize
6
+ import torch
7
+
8
+ from .geometry import Geometry
9
+
10
+
11
+ class Correlation:
12
+
13
+ @classmethod
14
+ def mutual_nn_filter(cls, correlation_matrix, eps=1e-30):
15
+ r""" Mutual nearest neighbor filtering (Rocco et al. NeurIPS'18 )"""
16
+ corr_src_max = torch.max(correlation_matrix, dim=2, keepdim=True)[0]
17
+ corr_trg_max = torch.max(correlation_matrix, dim=1, keepdim=True)[0]
18
+ corr_src_max[corr_src_max == 0] += eps
19
+ corr_trg_max[corr_trg_max == 0] += eps
20
+
21
+ corr_src = correlation_matrix / corr_src_max
22
+ corr_trg = correlation_matrix / corr_trg_max
23
+
24
+ return correlation_matrix * (corr_src * corr_trg)
25
+
26
+ @classmethod
27
+ def build_correlation6d(self, src_feat, trg_feat, scales, conv2ds):
28
+ r""" Build 6-dimensional correlation tensor """
29
+
30
+ bsz, _, side, side = src_feat.size()
31
+
32
+ # Construct feature pairs with multiple scales
33
+ _src_feats = []
34
+ _trg_feats = []
35
+ for scale, conv in zip(scales, conv2ds):
36
+ s = (round(side * math.sqrt(scale)),) * 2
37
+ _src_feat = conv(resize(src_feat, s, mode='bilinear', align_corners=True))
38
+ _trg_feat = conv(resize(trg_feat, s, mode='bilinear', align_corners=True))
39
+ _src_feats.append(_src_feat)
40
+ _trg_feats.append(_trg_feat)
41
+
42
+ # Build multiple 4-dimensional correlation tensor
43
+ corr6d = []
44
+ for src_feat in _src_feats:
45
+ ch = src_feat.size(1)
46
+
47
+ src_side = src_feat.size(-1)
48
+ src_feat = src_feat.view(bsz, ch, -1).transpose(1, 2)
49
+ src_norm = src_feat.norm(p=2, dim=2, keepdim=True)
50
+
51
+ for trg_feat in _trg_feats:
52
+ trg_side = trg_feat.size(-1)
53
+ trg_feat = trg_feat.view(bsz, ch, -1)
54
+ trg_norm = trg_feat.norm(p=2, dim=1, keepdim=True)
55
+
56
+ correlation = torch.bmm(src_feat, trg_feat) / torch.bmm(src_norm, trg_norm)
57
+ correlation = correlation.view(bsz, src_side, src_side, trg_side, trg_side).contiguous()
58
+ corr6d.append(correlation)
59
+
60
+ # Resize the spatial sizes of the 4D tensors to the same size
61
+ for idx, correlation in enumerate(corr6d):
62
+ corr6d[idx] = Geometry.interpolate4d(correlation, [side, side])
63
+
64
+ # Build 6-dimensional correlation tensor
65
+ corr6d = torch.stack(corr6d).view(len(scales), len(scales),
66
+ bsz, side, side, side, side).permute(2, 0, 1, 3, 4, 5, 6)
67
+ return corr6d.clamp(min=0)
68
+
model/base/geometry.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r""" Provides functions that manipulate boxes and points """
2
+
3
+ import math
4
+
5
+ import torch.nn.functional as F
6
+ import torch
7
+
8
+
9
+ class Geometry(object):
10
+
11
+ @classmethod
12
+ def initialize(cls, img_size):
13
+ cls.img_size = img_size
14
+
15
+ cls.spatial_side = int(img_size / 8)
16
+ norm_grid1d = torch.linspace(-1, 1, cls.spatial_side)
17
+
18
+ cls.norm_grid_x = norm_grid1d.view(1, -1).repeat(cls.spatial_side, 1).view(1, 1, -1)
19
+ cls.norm_grid_y = norm_grid1d.view(-1, 1).repeat(1, cls.spatial_side).view(1, 1, -1)
20
+ cls.grid = torch.stack(list(reversed(torch.meshgrid(norm_grid1d, norm_grid1d)))).permute(1, 2, 0)
21
+
22
+ cls.feat_idx = torch.arange(0, cls.spatial_side).float()
23
+
24
+ @classmethod
25
+ def normalize_kps(cls, kps):
26
+ kps = kps.clone().detach()
27
+ kps[kps != -2] -= (cls.img_size // 2)
28
+ kps[kps != -2] /= (cls.img_size // 2)
29
+ return kps
30
+
31
+ @classmethod
32
+ def unnormalize_kps(cls, kps):
33
+ kps = kps.clone().detach()
34
+ kps[kps != -2] *= (cls.img_size // 2)
35
+ kps[kps != -2] += (cls.img_size // 2)
36
+ return kps
37
+
38
+ @classmethod
39
+ def attentive_indexing(cls, kps, thres=0.1):
40
+ r"""kps: normalized keypoints x, y (N, 2)
41
+ returns attentive index map(N, spatial_side, spatial_side)
42
+ """
43
+ nkps = kps.size(0)
44
+ kps = kps.view(nkps, 1, 1, 2)
45
+
46
+ eps = 1e-5
47
+ attmap = (cls.grid.unsqueeze(0).repeat(nkps, 1, 1, 1) - kps).pow(2).sum(dim=3)
48
+ attmap = (attmap + eps).pow(0.5)
49
+ attmap = (thres - attmap).clamp(min=0).view(nkps, -1)
50
+ attmap = attmap / attmap.sum(dim=1, keepdim=True)
51
+ attmap = attmap.view(nkps, cls.spatial_side, cls.spatial_side)
52
+
53
+ return attmap
54
+
55
+ @classmethod
56
+ def apply_gaussian_kernel(cls, corr, sigma=17):
57
+ bsz, side, side = corr.size()
58
+
59
+ center = corr.max(dim=2)[1]
60
+ center_y = center // cls.spatial_side
61
+ center_x = center % cls.spatial_side
62
+
63
+ y = cls.feat_idx.view(1, 1, cls.spatial_side).repeat(bsz, center_y.size(1), 1) - center_y.unsqueeze(2)
64
+ x = cls.feat_idx.view(1, 1, cls.spatial_side).repeat(bsz, center_x.size(1), 1) - center_x.unsqueeze(2)
65
+
66
+ y = y.unsqueeze(3).repeat(1, 1, 1, cls.spatial_side)
67
+ x = x.unsqueeze(2).repeat(1, 1, cls.spatial_side, 1)
68
+
69
+ gauss_kernel = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * sigma ** 2))
70
+ filtered_corr = gauss_kernel * corr.view(bsz, -1, cls.spatial_side, cls.spatial_side)
71
+ filtered_corr = filtered_corr.view(bsz, side, side)
72
+
73
+ return filtered_corr
74
+
75
+ @classmethod
76
+ def transfer_kps(cls, confidence_ts, src_kps, n_pts, normalized):
77
+ r""" Transfer keypoints by weighted average """
78
+
79
+ if not normalized:
80
+ src_kps = Geometry.normalize_kps(src_kps)
81
+ confidence_ts = cls.apply_gaussian_kernel(confidence_ts)
82
+
83
+ pdf = F.softmax(confidence_ts, dim=2)
84
+ prd_x = (pdf * cls.norm_grid_x).sum(dim=2)
85
+ prd_y = (pdf * cls.norm_grid_y).sum(dim=2)
86
+
87
+ prd_kps = []
88
+ for idx, (x, y, src_kp, np) in enumerate(zip(prd_x, prd_y, src_kps, n_pts)):
89
+ max_pts = src_kp.size()[1]
90
+ prd_xy = torch.stack([x, y]).t()
91
+
92
+ src_kp = src_kp[:, :np].t()
93
+ attmap = cls.attentive_indexing(src_kp).view(np, -1)
94
+ prd_kp = (prd_xy.unsqueeze(0) * attmap.unsqueeze(-1)).sum(dim=1).t()
95
+ pads = (torch.zeros((2, max_pts - np)) - 2)
96
+ prd_kp = torch.cat([prd_kp, pads], dim=1)
97
+ prd_kps.append(prd_kp)
98
+
99
+ return torch.stack(prd_kps)
100
+
101
+ @staticmethod
102
+ def get_coord1d(coord4d, ksz):
103
+ i, j, k, l = coord4d
104
+ coord1d = i * (ksz ** 3) + j * (ksz ** 2) + k * (ksz) + l
105
+ return coord1d
106
+
107
+ @staticmethod
108
+ def get_distance(coord1, coord2):
109
+ delta_y = int(math.pow(coord1[0] - coord2[0], 2))
110
+ delta_x = int(math.pow(coord1[1] - coord2[1], 2))
111
+ dist = delta_y + delta_x
112
+ return dist
113
+
114
+ @staticmethod
115
+ def interpolate4d(tensor4d, size):
116
+ bsz, h1, w1, h2, w2 = tensor4d.size()
117
+ tensor4d = tensor4d.view(bsz, h1, w1, -1).permute(0, 3, 1, 2)
118
+ tensor4d = F.interpolate(tensor4d, size, mode='bilinear', align_corners=True)
119
+ tensor4d = tensor4d.view(bsz, h2, w2, -1).permute(0, 3, 1, 2)
120
+ tensor4d = F.interpolate(tensor4d, size, mode='bilinear', align_corners=True)
121
+ tensor4d = tensor4d.view(bsz, size[0], size[0], size[0], size[0])
122
+
123
+ return tensor4d
124
+ @staticmethod
125
+ def init_idx4d(ksz):
126
+ i0 = torch.arange(0, ksz).repeat(ksz ** 3)
127
+ i1 = torch.arange(0, ksz).unsqueeze(1).repeat(1, ksz).view(-1).repeat(ksz ** 2)
128
+ i2 = torch.arange(0, ksz).unsqueeze(1).repeat(1, ksz ** 2).view(-1).repeat(ksz)
129
+ i3 = torch.arange(0, ksz).unsqueeze(1).repeat(1, ksz ** 3).view(-1)
130
+ idx4d = torch.stack([i3, i2, i1, i0]).t().numpy()
131
+
132
+ return idx4d
133
+
model/chmlearner.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r""" Conovlutional Hough matching layers """
2
+
3
+ import torch.nn as nn
4
+ import torch
5
+
6
+ from .base.correlation import Correlation
7
+ from .base.geometry import Geometry
8
+ from .base.chm import CHM4d, CHM6d
9
+
10
+
11
+ class CHMLearner(nn.Module):
12
+
13
+ def __init__(self, ktype, feat_dim):
14
+ super(CHMLearner, self).__init__()
15
+
16
+ # Scale-wise feature transformation
17
+ self.scales = [0.5, 1, 2]
18
+ self.conv2ds = nn.ModuleList([nn.Conv2d(feat_dim, feat_dim // 4, kernel_size=3, padding=1, bias=False) for _ in self.scales])
19
+
20
+ # CHM layers
21
+ ksz_translation = 5
22
+ ksz_scale = 3
23
+ self.chm6d = CHM6d(1, 1, ksz_scale, ksz_translation, ktype)
24
+ self.chm4d = CHM4d(1, 1, ksz_translation, ktype, bias=True)
25
+
26
+ # Activations
27
+ self.relu = nn.ReLU(inplace=True)
28
+ self.sigmoid = nn.Sigmoid()
29
+ self.softplus = nn.Softplus()
30
+
31
+ def forward(self, src_feat, trg_feat):
32
+
33
+ corr = Correlation.build_correlation6d(src_feat, trg_feat, self.scales, self.conv2ds).unsqueeze(1)
34
+ bsz, ch, s, s, h, w, h, w = corr.size()
35
+
36
+ # CHM layer (6D)
37
+ corr = self.chm6d(corr)
38
+ corr = self.sigmoid(corr)
39
+
40
+ # Scale-space maxpool
41
+ corr = corr.view(bsz, -1, h, w, h, w).max(dim=1)[0]
42
+ corr = Geometry.interpolate4d(corr, [h * 2, w * 2]).unsqueeze(1)
43
+
44
+ # CHM layer (4D)
45
+ corr = self.chm4d(corr).squeeze(1)
46
+
47
+ # To ensure non-negative vote scores & soft cyclic constraints
48
+ corr = self.softplus(corr)
49
+ corr = Correlation.mutual_nn_filter(corr.view(bsz, corr.size(-1) ** 2, corr.size(-1) ** 2).contiguous())
50
+
51
+ return corr
52
+
model/chmnet.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r""" Convolutional Hough Matching Networks """
2
+
3
+ import torch.nn as nn
4
+ import torch
5
+
6
+ from . import chmlearner as chmlearner
7
+ from .base import backbone
8
+
9
+
10
+ class CHMNet(nn.Module):
11
+ def __init__(self, ktype):
12
+ super(CHMNet, self).__init__()
13
+
14
+ self.backbone = backbone.resnet101(pretrained=True)
15
+ self.learner = chmlearner.CHMLearner(ktype, feat_dim=1024)
16
+
17
+ def forward(self, src_img, trg_img):
18
+ src_feat, trg_feat = self.extract_features(src_img, trg_img)
19
+ correlation = self.learner(src_feat, trg_feat)
20
+ return correlation
21
+
22
+ def extract_features(self, src_img, trg_img):
23
+ feat = self.backbone.conv1.forward(torch.cat([src_img, trg_img], dim=1))
24
+ feat = self.backbone.bn1.forward(feat)
25
+ feat = self.backbone.relu.forward(feat)
26
+ feat = self.backbone.maxpool.forward(feat)
27
+
28
+ for idx in range(1, 5):
29
+ feat = self.backbone.__getattr__('layer%d' % idx)(feat)
30
+
31
+ if idx == 3:
32
+ src_feat = feat.narrow(1, 0, feat.size(1) // 2).clone()
33
+ trg_feat = feat.narrow(1, feat.size(1) // 2, feat.size(1) // 2).clone()
34
+ return src_feat, trg_feat
35
+
36
+ def training_objective(cls, prd_kps, trg_kps, npts):
37
+ l2dist = (prd_kps - trg_kps).pow(2).sum(dim=1)
38
+ loss = []
39
+ for dist, npt in zip(l2dist, npts):
40
+ loss.append(dist[:npt].mean())
41
+ return torch.stack(loss).mean()
42
+
visualization.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ from collections import Counter
3
+ from itertools import product
4
+
5
+ import matplotlib
6
+ import matplotlib.patches as patches
7
+ import numpy as np
8
+ import torchvision.transforms as transforms
9
+ from matplotlib import gridspec
10
+ from matplotlib import pyplot as plt
11
+ from matplotlib.patches import ConnectionPatch, ConnectionStyle
12
+ from PIL import Image
13
+
14
+ connectionstyle = ConnectionStyle("Arc3, rad=0.2")
15
+
16
+ display_transform = transforms.Compose(
17
+ [transforms.Resize(240), transforms.CenterCrop((240, 240))]
18
+ )
19
+ display_transform_knn = transforms.Compose(
20
+ [transforms.Resize(256), transforms.CenterCrop((224, 224))]
21
+ )
22
+
23
+
24
+ def keep_top_k(input_array, K=5):
25
+ """
26
+ return top 5 (k) from numpy array
27
+ """
28
+ top_5 = np.sort(input_array.reshape(-1))[::-1][K - 1]
29
+ masked = np.zeros_like(input_array)
30
+ masked[input_array >= top_5] = 1
31
+ return masked
32
+
33
+
34
+ def arg_topK(inputarray, topK=5):
35
+ """
36
+ returns indicies related to top K element (largest)
37
+ """
38
+ return np.argsort(inputarray.T.reshape(-1))[::-1][:topK]
39
+
40
+
41
+ # FOR MULTI
42
+ def plot_from_reranker_output(reranker_output, draw_box=True, draw_arcs=True):
43
+ """
44
+ visualize chm results from a reranker output dict
45
+ """
46
+
47
+ ### SET COLORS
48
+ cmap = matplotlib.cm.get_cmap("gist_rainbow")
49
+ rgba = cmap(0.5)
50
+ colors = []
51
+ for k in range(5):
52
+ colors.append(cmap(k / 5.0))
53
+
54
+ ### SET POINTS
55
+ A = np.linspace(1 + 17, 240 - 17 - 1, 7)
56
+ point_list = list(product(A, A))
57
+
58
+ nrow = 4
59
+ ncol = 7
60
+
61
+ fig = plt.figure(figsize=(32, 18))
62
+ gs = gridspec.GridSpec(
63
+ nrow,
64
+ ncol,
65
+ width_ratios=[1, 0.2, 1, 1, 1, 1, 1],
66
+ height_ratios=[1, 1, 1, 1],
67
+ wspace=0.1,
68
+ hspace=0.1,
69
+ top=0.9,
70
+ bottom=0.05,
71
+ left=0.17,
72
+ right=0.845,
73
+ )
74
+ axes = [[None for n in range(ncol - 1)] for x in range(nrow)]
75
+
76
+ for i in range(4):
77
+ axes[i] = []
78
+ for j in range(7):
79
+ if j != 1:
80
+ if (i, j) in [(2, 0), (3, 0)]:
81
+ axes[i].append(new_ax)
82
+ else:
83
+ new_ax = plt.subplot(gs[i, j])
84
+ new_ax.set_xticklabels([])
85
+ new_ax.set_xticks([])
86
+ new_ax.set_yticklabels([])
87
+ new_ax.set_yticks([])
88
+ new_ax.axis("off")
89
+ axes[i].append(new_ax)
90
+
91
+ ##################### DRAW EVERYTHING
92
+ axes[0][0].imshow(
93
+ display_transform(Image.open(reranker_output["q"]).convert("RGB"))
94
+ )
95
+ axes[0][0].set_title(
96
+ f'Query - K={reranker_output["K"]}, N={reranker_output["N"]}', fontsize=21
97
+ )
98
+
99
+ axes[1][0].imshow(
100
+ display_transform(Image.open(reranker_output["q"]).convert("RGB"))
101
+ )
102
+ axes[1][0].set_title(f'Query - K={reranker_output["K"]}', fontsize=21)
103
+
104
+ # axes[2][0].imshow(display_transform(Image.open(reranker_output['q'])))
105
+
106
+ # CHM Top5
107
+ for i in range(min(5, reranker_output["chm-prediction-confidence"])):
108
+ axes[0][1 + i].imshow(
109
+ display_transform(
110
+ Image.open(reranker_output["chm-nearest-neighbors"][i]).convert("RGB")
111
+ )
112
+ )
113
+ axes[0][1 + i].set_title(f"CHM - Top - {i+1}", fontsize=21)
114
+
115
+ if reranker_output["chm-prediction-confidence"] < 5:
116
+ for i in range(reranker_output["chm-prediction-confidence"], 5):
117
+ axes[0][1 + i].imshow(Image.new(mode="RGB", size=(224, 224), color="white"))
118
+ axes[0][1 + i].set_title(f"", fontsize=21)
119
+
120
+ # KNN top5
121
+ for i in range(min(5, reranker_output["knn-prediction-confidence"])):
122
+ axes[1][1 + i].imshow(
123
+ display_transform_knn(
124
+ Image.open(reranker_output["knn-nearest-neighbors"][i]).convert("RGB")
125
+ )
126
+ )
127
+ axes[1][1 + i].set_title(f"kNN - Top - {i+1}", fontsize=21)
128
+
129
+ if reranker_output["knn-prediction-confidence"] < 5:
130
+ for i in range(reranker_output["knn-prediction-confidence"], 5):
131
+ axes[1][1 + i].imshow(Image.new(mode="RGB", size=(240, 240), color="white"))
132
+ axes[1][1 + i].set_title(f"", fontsize=21)
133
+
134
+ for i in range(min(5, reranker_output["chm-prediction-confidence"])):
135
+ axes[2][i + 1].imshow(
136
+ display_transform(Image.open(reranker_output["q"]).convert("RGB"))
137
+ )
138
+
139
+ # Lower ROWs CHM Top5
140
+ for i in range(min(5, reranker_output["chm-prediction-confidence"])):
141
+ axes[3][1 + i].imshow(
142
+ display_transform(
143
+ Image.open(reranker_output["chm-nearest-neighbors"][i]).convert("RGB")
144
+ )
145
+ )
146
+
147
+ if reranker_output["chm-prediction-confidence"] < 5:
148
+ for i in range(reranker_output["chm-prediction-confidence"], 5):
149
+ axes[2][i + 1].imshow(Image.new(mode="RGB", size=(240, 240), color="white"))
150
+ axes[3][1 + i].imshow(Image.new(mode="RGB", size=(240, 240), color="white"))
151
+
152
+ nzm = reranker_output["non_zero_mask"]
153
+ # Go throught top 5 nearest images
154
+
155
+ # #################################################################################
156
+ if draw_box:
157
+ # SQUARAES
158
+ for NC in range(min(5, reranker_output["chm-prediction-confidence"])):
159
+ # ON SOURCE
160
+ valid_patches_source = arg_topK(
161
+ reranker_output["masked_cos_values"][NC], topK=nzm
162
+ )
163
+
164
+ # ON QUERY
165
+ target_masked_patches = arg_topK(
166
+ reranker_output["masked_cos_values"][NC], topK=nzm
167
+ )
168
+ valid_patches_target = [
169
+ reranker_output["correspondance_map"][NC][x]
170
+ for x in target_masked_patches
171
+ ]
172
+ valid_patches_target = [(x[0] * 7) + x[1] for x in valid_patches_target]
173
+
174
+ patch_colors = [c for c in colors]
175
+ overlaps = [
176
+ item
177
+ for item, count in Counter(valid_patches_target).items()
178
+ if count > 1
179
+ ]
180
+
181
+ for O in overlaps:
182
+ indices = [i for i, val in enumerate(valid_patches_target) if val == O]
183
+ for ii in indices[1:]:
184
+ patch_colors[ii] = patch_colors[indices[0]]
185
+
186
+ for i in valid_patches_source:
187
+ Psource = point_list[i]
188
+ rect = patches.Rectangle(
189
+ (Psource[0] - 16, Psource[1] - 16),
190
+ 32,
191
+ 32,
192
+ linewidth=2,
193
+ edgecolor=patch_colors[valid_patches_source.tolist().index(i)],
194
+ facecolor="none",
195
+ alpha=1,
196
+ )
197
+ axes[2][1 + NC].add_patch(rect)
198
+
199
+ for i in valid_patches_target:
200
+ Psource = point_list[i]
201
+ rect = patches.Rectangle(
202
+ (Psource[0] - 16, Psource[1] - 16),
203
+ 32,
204
+ 32,
205
+ linewidth=2,
206
+ edgecolor=patch_colors[valid_patches_target.index(i)],
207
+ facecolor="none",
208
+ alpha=1,
209
+ )
210
+ axes[3][1 + NC].add_patch(rect)
211
+
212
+ #################################################################################
213
+ # Show correspondence lines and points
214
+ if draw_arcs:
215
+ for CK in range(min(5, reranker_output["chm-prediction-confidence"])):
216
+ target_keypoints = []
217
+ topk_index = arg_topK(reranker_output["masked_cos_values"][CK], topK=nzm)
218
+ for i in range(nzm): # Number of Connections
219
+ con = ConnectionPatch(
220
+ xyA=(
221
+ reranker_output["src-keypoints"][CK][i, 0],
222
+ reranker_output["src-keypoints"][CK][i, 1],
223
+ ),
224
+ xyB=(
225
+ reranker_output["tgt-keypoints"][CK][i, 0],
226
+ reranker_output["tgt-keypoints"][CK][i, 1],
227
+ ),
228
+ coordsA="data",
229
+ coordsB="data",
230
+ axesA=axes[2][1 + CK],
231
+ axesB=axes[3][1 + CK],
232
+ color=colors[i],
233
+ connectionstyle=connectionstyle,
234
+ shrinkA=1.0,
235
+ shrinkB=1.0,
236
+ linewidth=1,
237
+ )
238
+
239
+ axes[3][1 + CK].add_artist(con)
240
+
241
+ # Scatter Plot
242
+ axes[2][1 + CK].scatter(
243
+ reranker_output["src-keypoints"][CK][:, 0],
244
+ reranker_output["src-keypoints"][CK][:, 1],
245
+ c=colors[:nzm],
246
+ s=10,
247
+ )
248
+ axes[3][1 + CK].scatter(
249
+ reranker_output["tgt-keypoints"][CK][:, 0],
250
+ reranker_output["tgt-keypoints"][CK][:, 1],
251
+ c=colors[:nzm],
252
+ s=10,
253
+ )
254
+
255
+ fig.text(
256
+ 0.5,
257
+ 0.95,
258
+ f"CHM: {reranker_output['chm-prediction']}",
259
+ ha="center",
260
+ va="bottom",
261
+ color="black",
262
+ fontsize=22,
263
+ )
264
+ fig.text(
265
+ 0.8,
266
+ 0.95,
267
+ f"KNN: {reranker_output['knn-prediction']}",
268
+ ha="right",
269
+ va="bottom",
270
+ color="black",
271
+ fontsize=22,
272
+ )
273
+
274
+ return fig