Vincentqyw commited on
Commit
c74a070
·
1 Parent(s): c608946
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +3 -2
  2. third_party/ALIKE/alike.py +91 -36
  3. third_party/ALIKE/alnet.py +66 -36
  4. third_party/ALIKE/demo.py +82 -48
  5. third_party/ALIKE/hseq/eval.py +71 -36
  6. third_party/ALIKE/hseq/extract.py +45 -29
  7. third_party/ALIKE/soft_detect.py +72 -32
  8. third_party/ASpanFormer/configs/aspan/indoor/aspan_test.py +5 -4
  9. third_party/ASpanFormer/configs/aspan/indoor/aspan_train.py +4 -3
  10. third_party/ASpanFormer/configs/aspan/outdoor/aspan_test.py +6 -5
  11. third_party/ASpanFormer/configs/aspan/outdoor/aspan_train.py +4 -3
  12. third_party/ASpanFormer/configs/data/base.py +1 -0
  13. third_party/ASpanFormer/configs/data/megadepth_test_1500.py +3 -3
  14. third_party/ASpanFormer/configs/data/megadepth_trainval_832.py +7 -3
  15. third_party/ASpanFormer/configs/data/scannet_trainval.py +7 -3
  16. third_party/ASpanFormer/demo/demo.py +68 -40
  17. third_party/ASpanFormer/demo/demo_utils.py +71 -27
  18. third_party/ASpanFormer/src/ASpanFormer/aspan_module/__init__.py +1 -1
  19. third_party/ASpanFormer/src/ASpanFormer/aspan_module/attention.py +224 -110
  20. third_party/ASpanFormer/src/ASpanFormer/aspan_module/fine_preprocess.py +36 -20
  21. third_party/ASpanFormer/src/ASpanFormer/aspan_module/loftr.py +22 -26
  22. third_party/ASpanFormer/src/ASpanFormer/aspan_module/transformer.py +247 -140
  23. third_party/ASpanFormer/src/ASpanFormer/aspanformer.py +107 -62
  24. third_party/ASpanFormer/src/ASpanFormer/backbone/__init__.py +8 -6
  25. third_party/ASpanFormer/src/ASpanFormer/backbone/resnet_fpn.py +36 -21
  26. third_party/ASpanFormer/src/ASpanFormer/utils/coarse_matching.py +168 -132
  27. third_party/ASpanFormer/src/ASpanFormer/utils/cvpr_ds_config.py +6 -6
  28. third_party/ASpanFormer/src/ASpanFormer/utils/fine_matching.py +32 -22
  29. third_party/ASpanFormer/src/ASpanFormer/utils/geometry.py +29 -10
  30. third_party/ASpanFormer/src/ASpanFormer/utils/position_encoding.py +36 -17
  31. third_party/ASpanFormer/src/ASpanFormer/utils/supervision.py +62 -41
  32. third_party/ASpanFormer/src/config/default.py +50 -31
  33. third_party/ASpanFormer/src/datasets/__init__.py +0 -1
  34. third_party/ASpanFormer/src/datasets/megadepth.py +83 -56
  35. third_party/ASpanFormer/src/datasets/sampler.py +33 -20
  36. third_party/ASpanFormer/src/datasets/scannet.py +52 -42
  37. third_party/ASpanFormer/src/lightning/data.py +222 -143
  38. third_party/ASpanFormer/src/lightning/lightning_aspanformer.py +218 -120
  39. third_party/ASpanFormer/src/losses/aspan_loss.py +155 -97
  40. third_party/ASpanFormer/src/optimizers/__init__.py +22 -9
  41. third_party/ASpanFormer/src/utils/augment.py +33 -23
  42. third_party/ASpanFormer/src/utils/comm.py +12 -7
  43. third_party/ASpanFormer/src/utils/dataloader.py +8 -7
  44. third_party/ASpanFormer/src/utils/dataset.py +48 -38
  45. third_party/ASpanFormer/src/utils/metrics.py +100 -67
  46. third_party/ASpanFormer/src/utils/misc.py +83 -38
  47. third_party/ASpanFormer/src/utils/plotting.py +128 -94
  48. third_party/ASpanFormer/src/utils/profiler.py +5 -4
  49. third_party/ASpanFormer/test.py +43 -17
  50. third_party/ASpanFormer/tools/extract.py +59 -25
app.py CHANGED
@@ -9,9 +9,10 @@ from extra_utils.utils import (
9
  match_features,
10
  get_model,
11
  get_feature_model,
12
- display_matches
13
  )
14
 
 
15
  def run_matching(
16
  match_threshold, extract_max_keypoints, keypoint_threshold, key, image0, image1
17
  ):
@@ -277,7 +278,7 @@ def run(config):
277
  matcher_info,
278
  ]
279
  button_reset.click(fn=ui_reset_state, inputs=inputs, outputs=reset_outputs)
280
-
281
  app.launch(share=False)
282
 
283
 
 
9
  match_features,
10
  get_model,
11
  get_feature_model,
12
+ display_matches,
13
  )
14
 
15
+
16
  def run_matching(
17
  match_threshold, extract_max_keypoints, keypoint_threshold, key, image0, image1
18
  ):
 
278
  matcher_info,
279
  ]
280
  button_reset.click(fn=ui_reset_state, inputs=inputs, outputs=reset_outputs)
281
+
282
  app.launch(share=False)
283
 
284
 
third_party/ALIKE/alike.py CHANGED
@@ -12,46 +12,89 @@ from soft_detect import DKD
12
  import time
13
 
14
  configs = {
15
- 'alike-t': {'c1': 8, 'c2': 16, 'c3': 32, 'c4': 64, 'dim': 64, 'single_head': True, 'radius': 2,
16
- 'model_path': os.path.join(os.path.split(__file__)[0], 'models', 'alike-t.pth')},
17
- 'alike-s': {'c1': 8, 'c2': 16, 'c3': 48, 'c4': 96, 'dim': 96, 'single_head': True, 'radius': 2,
18
- 'model_path': os.path.join(os.path.split(__file__)[0], 'models', 'alike-s.pth')},
19
- 'alike-n': {'c1': 16, 'c2': 32, 'c3': 64, 'c4': 128, 'dim': 128, 'single_head': True, 'radius': 2,
20
- 'model_path': os.path.join(os.path.split(__file__)[0], 'models', 'alike-n.pth')},
21
- 'alike-l': {'c1': 32, 'c2': 64, 'c3': 128, 'c4': 128, 'dim': 128, 'single_head': False, 'radius': 2,
22
- 'model_path': os.path.join(os.path.split(__file__)[0], 'models', 'alike-l.pth')},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  }
24
 
25
 
26
  class ALike(ALNet):
27
- def __init__(self,
28
- # ================================== feature encoder
29
- c1: int = 32, c2: int = 64, c3: int = 128, c4: int = 128, dim: int = 128,
30
- single_head: bool = False,
31
- # ================================== detect parameters
32
- radius: int = 2,
33
- top_k: int = 500, scores_th: float = 0.5,
34
- n_limit: int = 5000,
35
- device: str = 'cpu',
36
- model_path: str = ''
37
- ):
 
 
 
 
 
 
38
  super().__init__(c1, c2, c3, c4, dim, single_head)
39
  self.radius = radius
40
  self.top_k = top_k
41
  self.n_limit = n_limit
42
  self.scores_th = scores_th
43
- self.dkd = DKD(radius=self.radius, top_k=self.top_k,
44
- scores_th=self.scores_th, n_limit=self.n_limit)
 
 
 
 
45
  self.device = device
46
 
47
- if model_path != '':
48
  state_dict = torch.load(model_path, self.device)
49
  self.load_state_dict(state_dict)
50
  self.to(self.device)
51
  self.eval()
52
- logging.info(f'Loaded model parameters from {model_path}')
53
  logging.info(
54
- f"Number of model parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e3}KB")
 
55
 
56
  def extract_dense_map(self, image, ret_dict=False):
57
  # ====================================================
@@ -81,7 +124,10 @@ class ALike(ALNet):
81
  descriptor_map = torch.nn.functional.normalize(descriptor_map, p=2, dim=1)
82
 
83
  if ret_dict:
84
- return {'descriptor_map': descriptor_map, 'scores_map': scores_map, }
 
 
 
85
  else:
86
  return descriptor_map, scores_map
87
 
@@ -104,15 +150,22 @@ class ALike(ALNet):
104
  image = cv2.resize(image, dsize=None, fx=ratio, fy=ratio)
105
 
106
  # ==================== convert image to tensor
107
- image = torch.from_numpy(image).to(self.device).to(torch.float32).permute(2, 0, 1)[None] / 255.0
 
 
 
 
 
 
108
 
109
  # ==================== extract keypoints
110
  start = time.time()
111
 
112
  with torch.no_grad():
113
  descriptor_map, scores_map = self.extract_dense_map(image)
114
- keypoints, descriptors, scores, _ = self.dkd(scores_map, descriptor_map,
115
- sub_pixel=sub_pixel)
 
116
  keypoints, descriptors, scores = keypoints[0], descriptors[0], scores[0]
117
  keypoints = (keypoints + 1) / 2 * keypoints.new_tensor([[W - 1, H - 1]])
118
 
@@ -124,14 +177,16 @@ class ALike(ALNet):
124
 
125
  end = time.time()
126
 
127
- return {'keypoints': keypoints.cpu().numpy(),
128
- 'descriptors': descriptors.cpu().numpy(),
129
- 'scores': scores.cpu().numpy(),
130
- 'scores_map': scores_map.cpu().numpy(),
131
- 'time': end - start, }
 
 
132
 
133
 
134
- if __name__ == '__main__':
135
  import numpy as np
136
  from thop import profile
137
 
@@ -139,5 +194,5 @@ if __name__ == '__main__':
139
 
140
  image = np.random.random((640, 480, 3)).astype(np.float32)
141
  flops, params = profile(net, inputs=(image, 9999, False), verbose=False)
142
- print('{:<30} {:<8} GFLops'.format('Computational complexity: ', flops / 1e9))
143
- print('{:<30} {:<8} KB'.format('Number of parameters: ', params / 1e3))
 
12
  import time
13
 
14
  configs = {
15
+ "alike-t": {
16
+ "c1": 8,
17
+ "c2": 16,
18
+ "c3": 32,
19
+ "c4": 64,
20
+ "dim": 64,
21
+ "single_head": True,
22
+ "radius": 2,
23
+ "model_path": os.path.join(os.path.split(__file__)[0], "models", "alike-t.pth"),
24
+ },
25
+ "alike-s": {
26
+ "c1": 8,
27
+ "c2": 16,
28
+ "c3": 48,
29
+ "c4": 96,
30
+ "dim": 96,
31
+ "single_head": True,
32
+ "radius": 2,
33
+ "model_path": os.path.join(os.path.split(__file__)[0], "models", "alike-s.pth"),
34
+ },
35
+ "alike-n": {
36
+ "c1": 16,
37
+ "c2": 32,
38
+ "c3": 64,
39
+ "c4": 128,
40
+ "dim": 128,
41
+ "single_head": True,
42
+ "radius": 2,
43
+ "model_path": os.path.join(os.path.split(__file__)[0], "models", "alike-n.pth"),
44
+ },
45
+ "alike-l": {
46
+ "c1": 32,
47
+ "c2": 64,
48
+ "c3": 128,
49
+ "c4": 128,
50
+ "dim": 128,
51
+ "single_head": False,
52
+ "radius": 2,
53
+ "model_path": os.path.join(os.path.split(__file__)[0], "models", "alike-l.pth"),
54
+ },
55
  }
56
 
57
 
58
  class ALike(ALNet):
59
+ def __init__(
60
+ self,
61
+ # ================================== feature encoder
62
+ c1: int = 32,
63
+ c2: int = 64,
64
+ c3: int = 128,
65
+ c4: int = 128,
66
+ dim: int = 128,
67
+ single_head: bool = False,
68
+ # ================================== detect parameters
69
+ radius: int = 2,
70
+ top_k: int = 500,
71
+ scores_th: float = 0.5,
72
+ n_limit: int = 5000,
73
+ device: str = "cpu",
74
+ model_path: str = "",
75
+ ):
76
  super().__init__(c1, c2, c3, c4, dim, single_head)
77
  self.radius = radius
78
  self.top_k = top_k
79
  self.n_limit = n_limit
80
  self.scores_th = scores_th
81
+ self.dkd = DKD(
82
+ radius=self.radius,
83
+ top_k=self.top_k,
84
+ scores_th=self.scores_th,
85
+ n_limit=self.n_limit,
86
+ )
87
  self.device = device
88
 
89
+ if model_path != "":
90
  state_dict = torch.load(model_path, self.device)
91
  self.load_state_dict(state_dict)
92
  self.to(self.device)
93
  self.eval()
94
+ logging.info(f"Loaded model parameters from {model_path}")
95
  logging.info(
96
+ f"Number of model parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e3}KB"
97
+ )
98
 
99
  def extract_dense_map(self, image, ret_dict=False):
100
  # ====================================================
 
124
  descriptor_map = torch.nn.functional.normalize(descriptor_map, p=2, dim=1)
125
 
126
  if ret_dict:
127
+ return {
128
+ "descriptor_map": descriptor_map,
129
+ "scores_map": scores_map,
130
+ }
131
  else:
132
  return descriptor_map, scores_map
133
 
 
150
  image = cv2.resize(image, dsize=None, fx=ratio, fy=ratio)
151
 
152
  # ==================== convert image to tensor
153
+ image = (
154
+ torch.from_numpy(image)
155
+ .to(self.device)
156
+ .to(torch.float32)
157
+ .permute(2, 0, 1)[None]
158
+ / 255.0
159
+ )
160
 
161
  # ==================== extract keypoints
162
  start = time.time()
163
 
164
  with torch.no_grad():
165
  descriptor_map, scores_map = self.extract_dense_map(image)
166
+ keypoints, descriptors, scores, _ = self.dkd(
167
+ scores_map, descriptor_map, sub_pixel=sub_pixel
168
+ )
169
  keypoints, descriptors, scores = keypoints[0], descriptors[0], scores[0]
170
  keypoints = (keypoints + 1) / 2 * keypoints.new_tensor([[W - 1, H - 1]])
171
 
 
177
 
178
  end = time.time()
179
 
180
+ return {
181
+ "keypoints": keypoints.cpu().numpy(),
182
+ "descriptors": descriptors.cpu().numpy(),
183
+ "scores": scores.cpu().numpy(),
184
+ "scores_map": scores_map.cpu().numpy(),
185
+ "time": end - start,
186
+ }
187
 
188
 
189
+ if __name__ == "__main__":
190
  import numpy as np
191
  from thop import profile
192
 
 
194
 
195
  image = np.random.random((640, 480, 3)).astype(np.float32)
196
  flops, params = profile(net, inputs=(image, 9999, False), verbose=False)
197
+ print("{:<30} {:<8} GFLops".format("Computational complexity: ", flops / 1e9))
198
+ print("{:<30} {:<8} KB".format("Number of parameters: ", params / 1e3))
third_party/ALIKE/alnet.py CHANGED
@@ -5,9 +5,13 @@ from typing import Optional, Callable
5
 
6
 
7
  class ConvBlock(nn.Module):
8
- def __init__(self, in_channels, out_channels,
9
- gate: Optional[Callable[..., nn.Module]] = None,
10
- norm_layer: Optional[Callable[..., nn.Module]] = None):
 
 
 
 
11
  super().__init__()
12
  if gate is None:
13
  self.gate = nn.ReLU(inplace=True)
@@ -31,16 +35,16 @@ class ResBlock(nn.Module):
31
  expansion: int = 1
32
 
33
  def __init__(
34
- self,
35
- inplanes: int,
36
- planes: int,
37
- stride: int = 1,
38
- downsample: Optional[nn.Module] = None,
39
- groups: int = 1,
40
- base_width: int = 64,
41
- dilation: int = 1,
42
- gate: Optional[Callable[..., nn.Module]] = None,
43
- norm_layer: Optional[Callable[..., nn.Module]] = None
44
  ) -> None:
45
  super(ResBlock, self).__init__()
46
  if gate is None:
@@ -50,7 +54,7 @@ class ResBlock(nn.Module):
50
  if norm_layer is None:
51
  norm_layer = nn.BatchNorm2d
52
  if groups != 1 or base_width != 64:
53
- raise ValueError('ResBlock only supports groups=1 and base_width=64')
54
  if dilation > 1:
55
  raise NotImplementedError("Dilation > 1 not supported in ResBlock")
56
  # Both self.conv1 and self.downsample layers downsample the input when stride != 1
@@ -81,9 +85,15 @@ class ResBlock(nn.Module):
81
 
82
 
83
  class ALNet(nn.Module):
84
- def __init__(self, c1: int = 32, c2: int = 64, c3: int = 128, c4: int = 128, dim: int = 128,
85
- single_head: bool = True,
86
- ):
 
 
 
 
 
 
87
  super().__init__()
88
 
89
  self.gate = nn.ReLU(inplace=True)
@@ -93,28 +103,48 @@ class ALNet(nn.Module):
93
 
94
  self.block1 = ConvBlock(3, c1, self.gate, nn.BatchNorm2d)
95
 
96
- self.block2 = ResBlock(inplanes=c1, planes=c2, stride=1,
97
- downsample=nn.Conv2d(c1, c2, 1),
98
- gate=self.gate,
99
- norm_layer=nn.BatchNorm2d)
100
- self.block3 = ResBlock(inplanes=c2, planes=c3, stride=1,
101
- downsample=nn.Conv2d(c2, c3, 1),
102
- gate=self.gate,
103
- norm_layer=nn.BatchNorm2d)
104
- self.block4 = ResBlock(inplanes=c3, planes=c4, stride=1,
105
- downsample=nn.Conv2d(c3, c4, 1),
106
- gate=self.gate,
107
- norm_layer=nn.BatchNorm2d)
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  # ================================== feature aggregation
110
  self.conv1 = resnet.conv1x1(c1, dim // 4)
111
  self.conv2 = resnet.conv1x1(c2, dim // 4)
112
  self.conv3 = resnet.conv1x1(c3, dim // 4)
113
  self.conv4 = resnet.conv1x1(dim, dim // 4)
114
- self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
115
- self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
116
- self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
117
- self.upsample32 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=True)
 
 
 
 
 
 
 
 
118
 
119
  # ================================== detector and descriptor head
120
  self.single_head = single_head
@@ -153,12 +183,12 @@ class ALNet(nn.Module):
153
  return scores_map, descriptor_map
154
 
155
 
156
- if __name__ == '__main__':
157
  from thop import profile
158
 
159
  net = ALNet(c1=16, c2=32, c3=64, c4=128, dim=128, single_head=True)
160
 
161
  image = torch.randn(1, 3, 640, 480)
162
  flops, params = profile(net, inputs=(image,), verbose=False)
163
- print('{:<30} {:<8} GFLops'.format('Computational complexity: ', flops / 1e9))
164
- print('{:<30} {:<8} KB'.format('Number of parameters: ', params / 1e3))
 
5
 
6
 
7
  class ConvBlock(nn.Module):
8
+ def __init__(
9
+ self,
10
+ in_channels,
11
+ out_channels,
12
+ gate: Optional[Callable[..., nn.Module]] = None,
13
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
14
+ ):
15
  super().__init__()
16
  if gate is None:
17
  self.gate = nn.ReLU(inplace=True)
 
35
  expansion: int = 1
36
 
37
  def __init__(
38
+ self,
39
+ inplanes: int,
40
+ planes: int,
41
+ stride: int = 1,
42
+ downsample: Optional[nn.Module] = None,
43
+ groups: int = 1,
44
+ base_width: int = 64,
45
+ dilation: int = 1,
46
+ gate: Optional[Callable[..., nn.Module]] = None,
47
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
48
  ) -> None:
49
  super(ResBlock, self).__init__()
50
  if gate is None:
 
54
  if norm_layer is None:
55
  norm_layer = nn.BatchNorm2d
56
  if groups != 1 or base_width != 64:
57
+ raise ValueError("ResBlock only supports groups=1 and base_width=64")
58
  if dilation > 1:
59
  raise NotImplementedError("Dilation > 1 not supported in ResBlock")
60
  # Both self.conv1 and self.downsample layers downsample the input when stride != 1
 
85
 
86
 
87
  class ALNet(nn.Module):
88
+ def __init__(
89
+ self,
90
+ c1: int = 32,
91
+ c2: int = 64,
92
+ c3: int = 128,
93
+ c4: int = 128,
94
+ dim: int = 128,
95
+ single_head: bool = True,
96
+ ):
97
  super().__init__()
98
 
99
  self.gate = nn.ReLU(inplace=True)
 
103
 
104
  self.block1 = ConvBlock(3, c1, self.gate, nn.BatchNorm2d)
105
 
106
+ self.block2 = ResBlock(
107
+ inplanes=c1,
108
+ planes=c2,
109
+ stride=1,
110
+ downsample=nn.Conv2d(c1, c2, 1),
111
+ gate=self.gate,
112
+ norm_layer=nn.BatchNorm2d,
113
+ )
114
+ self.block3 = ResBlock(
115
+ inplanes=c2,
116
+ planes=c3,
117
+ stride=1,
118
+ downsample=nn.Conv2d(c2, c3, 1),
119
+ gate=self.gate,
120
+ norm_layer=nn.BatchNorm2d,
121
+ )
122
+ self.block4 = ResBlock(
123
+ inplanes=c3,
124
+ planes=c4,
125
+ stride=1,
126
+ downsample=nn.Conv2d(c3, c4, 1),
127
+ gate=self.gate,
128
+ norm_layer=nn.BatchNorm2d,
129
+ )
130
 
131
  # ================================== feature aggregation
132
  self.conv1 = resnet.conv1x1(c1, dim // 4)
133
  self.conv2 = resnet.conv1x1(c2, dim // 4)
134
  self.conv3 = resnet.conv1x1(c3, dim // 4)
135
  self.conv4 = resnet.conv1x1(dim, dim // 4)
136
+ self.upsample2 = nn.Upsample(
137
+ scale_factor=2, mode="bilinear", align_corners=True
138
+ )
139
+ self.upsample4 = nn.Upsample(
140
+ scale_factor=4, mode="bilinear", align_corners=True
141
+ )
142
+ self.upsample8 = nn.Upsample(
143
+ scale_factor=8, mode="bilinear", align_corners=True
144
+ )
145
+ self.upsample32 = nn.Upsample(
146
+ scale_factor=32, mode="bilinear", align_corners=True
147
+ )
148
 
149
  # ================================== detector and descriptor head
150
  self.single_head = single_head
 
183
  return scores_map, descriptor_map
184
 
185
 
186
+ if __name__ == "__main__":
187
  from thop import profile
188
 
189
  net = ALNet(c1=16, c2=32, c3=64, c4=128, dim=128, single_head=True)
190
 
191
  image = torch.randn(1, 3, 640, 480)
192
  flops, params = profile(net, inputs=(image,), verbose=False)
193
+ print("{:<30} {:<8} GFLops".format("Computational complexity: ", flops / 1e9))
194
+ print("{:<30} {:<8} KB".format("Number of parameters: ", params / 1e3))
third_party/ALIKE/demo.py CHANGED
@@ -12,13 +12,13 @@ from alike import ALike, configs
12
  class ImageLoader(object):
13
  def __init__(self, filepath: str):
14
  self.N = 3000
15
- if filepath.startswith('camera'):
16
  camera = int(filepath[6:])
17
  self.cap = cv2.VideoCapture(camera)
18
  if not self.cap.isOpened():
19
  raise IOError(f"Can't open camera {camera}!")
20
- logging.info(f'Opened camera {camera}')
21
- self.mode = 'camera'
22
  elif os.path.exists(filepath):
23
  if os.path.isfile(filepath):
24
  self.cap = cv2.VideoCapture(filepath)
@@ -27,34 +27,38 @@ class ImageLoader(object):
27
  rate = self.cap.get(cv2.CAP_PROP_FPS)
28
  self.N = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1
29
  duration = self.N / rate
30
- logging.info(f'Opened video {filepath}')
31
- logging.info(f'Frames: {self.N}, FPS: {rate}, Duration: {duration}s')
32
- self.mode = 'video'
33
  else:
34
- self.images = glob.glob(os.path.join(filepath, '*.png')) + \
35
- glob.glob(os.path.join(filepath, '*.jpg')) + \
36
- glob.glob(os.path.join(filepath, '*.ppm'))
 
 
37
  self.images.sort()
38
  self.N = len(self.images)
39
- logging.info(f'Loading {self.N} images')
40
- self.mode = 'images'
41
  else:
42
- raise IOError('Error filepath (camerax/path of images/path of videos): ', filepath)
 
 
43
 
44
  def __getitem__(self, item):
45
- if self.mode == 'camera' or self.mode == 'video':
46
  if item > self.N:
47
  return None
48
  ret, img = self.cap.read()
49
  if not ret:
50
  raise "Can't read image from camera"
51
- if self.mode == 'video':
52
  self.cap.set(cv2.CAP_PROP_POS_FRAMES, item)
53
- elif self.mode == 'images':
54
  filename = self.images[item]
55
  img = cv2.imread(filename)
56
  if img is None:
57
- raise Exception('Error reading image %s' % filename)
58
  return img
59
 
60
  def __len__(self):
@@ -99,38 +103,68 @@ class SimpleTracker(object):
99
  nn12 = np.argmax(sim, axis=1)
100
  nn21 = np.argmax(sim, axis=0)
101
  ids1 = np.arange(0, sim.shape[0])
102
- mask = (ids1 == nn21[nn12])
103
  matches = np.stack([ids1[mask], nn12[mask]])
104
  return matches.transpose()
105
 
106
 
107
- if __name__ == '__main__':
108
- parser = argparse.ArgumentParser(description='ALike Demo.')
109
- parser.add_argument('input', type=str, default='',
110
- help='Image directory or movie file or "camera0" (for webcam0).')
111
- parser.add_argument('--model', choices=['alike-t', 'alike-s', 'alike-n', 'alike-l'], default="alike-t",
112
- help="The model configuration")
113
- parser.add_argument('--device', type=str, default='cuda', help="Running device (default: cuda).")
114
- parser.add_argument('--top_k', type=int, default=-1,
115
- help='Detect top K keypoints. -1 for threshold based mode, >0 for top K mode. (default: -1)')
116
- parser.add_argument('--scores_th', type=float, default=0.2,
117
- help='Detector score threshold (default: 0.2).')
118
- parser.add_argument('--n_limit', type=int, default=5000,
119
- help='Maximum number of keypoints to be detected (default: 5000).')
120
- parser.add_argument('--no_display', action='store_true',
121
- help='Do not display images to screen. Useful if running remotely (default: False).')
122
- parser.add_argument('--no_sub_pixel', action='store_true',
123
- help='Do not detect sub-pixel keypoints (default: False).')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  args = parser.parse_args()
125
 
126
  logging.basicConfig(level=logging.INFO)
127
 
128
  image_loader = ImageLoader(args.input)
129
- model = ALike(**configs[args.model],
130
- device=args.device,
131
- top_k=args.top_k,
132
- scores_th=args.scores_th,
133
- n_limit=args.n_limit)
 
 
134
  tracker = SimpleTracker()
135
 
136
  if not args.no_display:
@@ -142,26 +176,26 @@ if __name__ == '__main__':
142
  for img in progress_bar:
143
  if img is None:
144
  break
145
-
146
  img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
147
  pred = model(img_rgb, sub_pixel=not args.no_sub_pixel)
148
- kpts = pred['keypoints']
149
- desc = pred['descriptors']
150
- runtime.append(pred['time'])
151
 
152
  out, N_matches = tracker.update(img, kpts, desc)
153
 
154
- ave_fps = (1. / np.stack(runtime)).mean()
155
  status = f"Fps:{ave_fps:.1f}, Keypoints/Matches: {len(kpts)}/{N_matches}"
156
  progress_bar.set_description(status)
157
 
158
  if not args.no_display:
159
- cv2.setWindowTitle(args.model, args.model + ': ' + status)
160
  cv2.imshow(args.model, out)
161
- if cv2.waitKey(1) == ord('q'):
162
  break
163
 
164
- logging.info('Finished!')
165
  if not args.no_display:
166
- logging.info('Press any key to exit!')
167
  cv2.waitKey()
 
12
  class ImageLoader(object):
13
  def __init__(self, filepath: str):
14
  self.N = 3000
15
+ if filepath.startswith("camera"):
16
  camera = int(filepath[6:])
17
  self.cap = cv2.VideoCapture(camera)
18
  if not self.cap.isOpened():
19
  raise IOError(f"Can't open camera {camera}!")
20
+ logging.info(f"Opened camera {camera}")
21
+ self.mode = "camera"
22
  elif os.path.exists(filepath):
23
  if os.path.isfile(filepath):
24
  self.cap = cv2.VideoCapture(filepath)
 
27
  rate = self.cap.get(cv2.CAP_PROP_FPS)
28
  self.N = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1
29
  duration = self.N / rate
30
+ logging.info(f"Opened video {filepath}")
31
+ logging.info(f"Frames: {self.N}, FPS: {rate}, Duration: {duration}s")
32
+ self.mode = "video"
33
  else:
34
+ self.images = (
35
+ glob.glob(os.path.join(filepath, "*.png"))
36
+ + glob.glob(os.path.join(filepath, "*.jpg"))
37
+ + glob.glob(os.path.join(filepath, "*.ppm"))
38
+ )
39
  self.images.sort()
40
  self.N = len(self.images)
41
+ logging.info(f"Loading {self.N} images")
42
+ self.mode = "images"
43
  else:
44
+ raise IOError(
45
+ "Error filepath (camerax/path of images/path of videos): ", filepath
46
+ )
47
 
48
  def __getitem__(self, item):
49
+ if self.mode == "camera" or self.mode == "video":
50
  if item > self.N:
51
  return None
52
  ret, img = self.cap.read()
53
  if not ret:
54
  raise "Can't read image from camera"
55
+ if self.mode == "video":
56
  self.cap.set(cv2.CAP_PROP_POS_FRAMES, item)
57
+ elif self.mode == "images":
58
  filename = self.images[item]
59
  img = cv2.imread(filename)
60
  if img is None:
61
+ raise Exception("Error reading image %s" % filename)
62
  return img
63
 
64
  def __len__(self):
 
103
  nn12 = np.argmax(sim, axis=1)
104
  nn21 = np.argmax(sim, axis=0)
105
  ids1 = np.arange(0, sim.shape[0])
106
+ mask = ids1 == nn21[nn12]
107
  matches = np.stack([ids1[mask], nn12[mask]])
108
  return matches.transpose()
109
 
110
 
111
+ if __name__ == "__main__":
112
+ parser = argparse.ArgumentParser(description="ALike Demo.")
113
+ parser.add_argument(
114
+ "input",
115
+ type=str,
116
+ default="",
117
+ help='Image directory or movie file or "camera0" (for webcam0).',
118
+ )
119
+ parser.add_argument(
120
+ "--model",
121
+ choices=["alike-t", "alike-s", "alike-n", "alike-l"],
122
+ default="alike-t",
123
+ help="The model configuration",
124
+ )
125
+ parser.add_argument(
126
+ "--device", type=str, default="cuda", help="Running device (default: cuda)."
127
+ )
128
+ parser.add_argument(
129
+ "--top_k",
130
+ type=int,
131
+ default=-1,
132
+ help="Detect top K keypoints. -1 for threshold based mode, >0 for top K mode. (default: -1)",
133
+ )
134
+ parser.add_argument(
135
+ "--scores_th",
136
+ type=float,
137
+ default=0.2,
138
+ help="Detector score threshold (default: 0.2).",
139
+ )
140
+ parser.add_argument(
141
+ "--n_limit",
142
+ type=int,
143
+ default=5000,
144
+ help="Maximum number of keypoints to be detected (default: 5000).",
145
+ )
146
+ parser.add_argument(
147
+ "--no_display",
148
+ action="store_true",
149
+ help="Do not display images to screen. Useful if running remotely (default: False).",
150
+ )
151
+ parser.add_argument(
152
+ "--no_sub_pixel",
153
+ action="store_true",
154
+ help="Do not detect sub-pixel keypoints (default: False).",
155
+ )
156
  args = parser.parse_args()
157
 
158
  logging.basicConfig(level=logging.INFO)
159
 
160
  image_loader = ImageLoader(args.input)
161
+ model = ALike(
162
+ **configs[args.model],
163
+ device=args.device,
164
+ top_k=args.top_k,
165
+ scores_th=args.scores_th,
166
+ n_limit=args.n_limit,
167
+ )
168
  tracker = SimpleTracker()
169
 
170
  if not args.no_display:
 
176
  for img in progress_bar:
177
  if img is None:
178
  break
179
+
180
  img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
181
  pred = model(img_rgb, sub_pixel=not args.no_sub_pixel)
182
+ kpts = pred["keypoints"]
183
+ desc = pred["descriptors"]
184
+ runtime.append(pred["time"])
185
 
186
  out, N_matches = tracker.update(img, kpts, desc)
187
 
188
+ ave_fps = (1.0 / np.stack(runtime)).mean()
189
  status = f"Fps:{ave_fps:.1f}, Keypoints/Matches: {len(kpts)}/{N_matches}"
190
  progress_bar.set_description(status)
191
 
192
  if not args.no_display:
193
+ cv2.setWindowTitle(args.model, args.model + ": " + status)
194
  cv2.imshow(args.model, out)
195
+ if cv2.waitKey(1) == ord("q"):
196
  break
197
 
198
+ logging.info("Finished!")
199
  if not args.no_display:
200
+ logging.info("Press any key to exit!")
201
  cv2.waitKey()
third_party/ALIKE/hseq/eval.py CHANGED
@@ -6,29 +6,53 @@ import numpy as np
6
  from extract import extract_method
7
 
8
  use_cuda = torch.cuda.is_available()
9
- device = torch.device('cuda' if use_cuda else 'cpu')
10
-
11
- methods = ['d2', 'lfnet', 'superpoint', 'r2d2', 'aslfeat', 'disk',
12
- 'alike-n', 'alike-l', 'alike-n-ms', 'alike-l-ms']
13
- names = ['D2-Net(MS)', 'LF-Net(MS)', 'SuperPoint', 'R2D2(MS)', 'ASLFeat(MS)', 'DISK',
14
- 'ALike-N', 'ALike-L', 'ALike-N(MS)', 'ALike-L(MS)']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  top_k = None
17
  n_i = 52
18
  n_v = 56
19
- cache_dir = 'hseq/cache'
20
- dataset_path = 'hseq/hpatches-sequences-release'
21
 
22
 
23
- def generate_read_function(method, extension='ppm'):
24
  def read_function(seq_name, im_idx):
25
- aux = np.load(os.path.join(dataset_path, seq_name, '%d.%s.%s' % (im_idx, extension, method)))
 
 
 
 
26
  if top_k is None:
27
- return aux['keypoints'], aux['descriptors']
28
  else:
29
- assert ('scores' in aux)
30
- ids = np.argsort(aux['scores'])[-top_k:]
31
- return aux['keypoints'][ids, :], aux['descriptors'][ids, :]
32
 
33
  return read_function
34
 
@@ -39,7 +63,7 @@ def mnn_matcher(descriptors_a, descriptors_b):
39
  nn12 = torch.max(sim, dim=1)[1]
40
  nn21 = torch.max(sim, dim=0)[1]
41
  ids1 = torch.arange(0, sim.shape[0], device=device)
42
- mask = (ids1 == nn21[nn12])
43
  matches = torch.stack([ids1[mask], nn12[mask]])
44
  return matches.t().data.cpu().numpy()
45
 
@@ -73,7 +97,7 @@ def benchmark_features(read_feats):
73
  n_feats.append(keypoints_a.shape[0])
74
 
75
  # =========== compute homography
76
- ref_img = cv2.imread(os.path.join(dataset_path, seq_name, '1.ppm'))
77
  ref_img_shape = ref_img.shape
78
 
79
  for im_idx in range(2, 7):
@@ -82,17 +106,19 @@ def benchmark_features(read_feats):
82
 
83
  matches = mnn_matcher(
84
  torch.from_numpy(descriptors_a).to(device=device),
85
- torch.from_numpy(descriptors_b).to(device=device)
86
  )
87
 
88
- homography = np.loadtxt(os.path.join(dataset_path, seq_name, "H_1_" + str(im_idx)))
 
 
89
 
90
- pos_a = keypoints_a[matches[:, 0], : 2]
91
  pos_a_h = np.concatenate([pos_a, np.ones([matches.shape[0], 1])], axis=1)
92
  pos_b_proj_h = np.transpose(np.dot(homography, np.transpose(pos_a_h)))
93
- pos_b_proj = pos_b_proj_h[:, : 2] / pos_b_proj_h[:, 2:]
94
 
95
- pos_b = keypoints_b[matches[:, 1], : 2]
96
 
97
  dist = np.sqrt(np.sum((pos_b - pos_b_proj) ** 2, axis=1))
98
 
@@ -103,28 +129,37 @@ def benchmark_features(read_feats):
103
  dist = np.array([float("inf")])
104
 
105
  for thr in rng:
106
- if seq_name[0] == 'i':
107
  i_err[thr] += np.mean(dist <= thr)
108
  else:
109
  v_err[thr] += np.mean(dist <= thr)
110
 
111
  # =========== compute homography
112
  gt_homo = homography
113
- pred_homo, _ = cv2.findHomography(keypoints_a[matches[:, 0], : 2], keypoints_b[matches[:, 1], : 2],
114
- cv2.RANSAC)
 
 
 
115
  if pred_homo is None:
116
  homo_dist = np.array([float("inf")])
117
  else:
118
- corners = np.array([[0, 0],
119
- [ref_img_shape[1] - 1, 0],
120
- [0, ref_img_shape[0] - 1],
121
- [ref_img_shape[1] - 1, ref_img_shape[0] - 1]])
 
 
 
 
122
  real_warped_corners = homo_trans(corners, gt_homo)
123
  warped_corners = homo_trans(corners, pred_homo)
124
- homo_dist = np.mean(np.linalg.norm(real_warped_corners - warped_corners, axis=1))
 
 
125
 
126
  for thr in rng:
127
- if seq_name[0] == 'i':
128
  i_err_homo[thr] += np.mean(homo_dist <= thr)
129
  else:
130
  v_err_homo[thr] += np.mean(homo_dist <= thr)
@@ -136,10 +171,10 @@ def benchmark_features(read_feats):
136
  return i_err, v_err, i_err_homo, v_err_homo, [seq_type, n_feats, n_matches]
137
 
138
 
139
- if __name__ == '__main__':
140
  errors = {}
141
  for method in methods:
142
- output_file = os.path.join(cache_dir, method + '.npy')
143
  read_function = generate_read_function(method)
144
  if os.path.exists(output_file):
145
  errors[method] = np.load(output_file, allow_pickle=True)
@@ -152,11 +187,11 @@ if __name__ == '__main__':
152
  i_err, v_err, i_err_hom, v_err_hom, _ = errors[method]
153
 
154
  print(f"====={name}=====")
155
- print(f"MMA@1 MMA@2 MMA@3 MHA@1 MHA@2 MHA@3: ", end='')
156
  for thr in range(1, 4):
157
  err = (i_err[thr] + v_err[thr]) / ((n_i + n_v) * 5)
158
- print(f"{err * 100:.2f}%", end=' ')
159
  for thr in range(1, 4):
160
  err_hom = (i_err_hom[thr] + v_err_hom[thr]) / ((n_i + n_v) * 5)
161
- print(f"{err_hom * 100:.2f}%", end=' ')
162
- print('')
 
6
  from extract import extract_method
7
 
8
  use_cuda = torch.cuda.is_available()
9
+ device = torch.device("cuda" if use_cuda else "cpu")
10
+
11
+ methods = [
12
+ "d2",
13
+ "lfnet",
14
+ "superpoint",
15
+ "r2d2",
16
+ "aslfeat",
17
+ "disk",
18
+ "alike-n",
19
+ "alike-l",
20
+ "alike-n-ms",
21
+ "alike-l-ms",
22
+ ]
23
+ names = [
24
+ "D2-Net(MS)",
25
+ "LF-Net(MS)",
26
+ "SuperPoint",
27
+ "R2D2(MS)",
28
+ "ASLFeat(MS)",
29
+ "DISK",
30
+ "ALike-N",
31
+ "ALike-L",
32
+ "ALike-N(MS)",
33
+ "ALike-L(MS)",
34
+ ]
35
 
36
  top_k = None
37
  n_i = 52
38
  n_v = 56
39
+ cache_dir = "hseq/cache"
40
+ dataset_path = "hseq/hpatches-sequences-release"
41
 
42
 
43
+ def generate_read_function(method, extension="ppm"):
44
  def read_function(seq_name, im_idx):
45
+ aux = np.load(
46
+ os.path.join(
47
+ dataset_path, seq_name, "%d.%s.%s" % (im_idx, extension, method)
48
+ )
49
+ )
50
  if top_k is None:
51
+ return aux["keypoints"], aux["descriptors"]
52
  else:
53
+ assert "scores" in aux
54
+ ids = np.argsort(aux["scores"])[-top_k:]
55
+ return aux["keypoints"][ids, :], aux["descriptors"][ids, :]
56
 
57
  return read_function
58
 
 
63
  nn12 = torch.max(sim, dim=1)[1]
64
  nn21 = torch.max(sim, dim=0)[1]
65
  ids1 = torch.arange(0, sim.shape[0], device=device)
66
+ mask = ids1 == nn21[nn12]
67
  matches = torch.stack([ids1[mask], nn12[mask]])
68
  return matches.t().data.cpu().numpy()
69
 
 
97
  n_feats.append(keypoints_a.shape[0])
98
 
99
  # =========== compute homography
100
+ ref_img = cv2.imread(os.path.join(dataset_path, seq_name, "1.ppm"))
101
  ref_img_shape = ref_img.shape
102
 
103
  for im_idx in range(2, 7):
 
106
 
107
  matches = mnn_matcher(
108
  torch.from_numpy(descriptors_a).to(device=device),
109
+ torch.from_numpy(descriptors_b).to(device=device),
110
  )
111
 
112
+ homography = np.loadtxt(
113
+ os.path.join(dataset_path, seq_name, "H_1_" + str(im_idx))
114
+ )
115
 
116
+ pos_a = keypoints_a[matches[:, 0], :2]
117
  pos_a_h = np.concatenate([pos_a, np.ones([matches.shape[0], 1])], axis=1)
118
  pos_b_proj_h = np.transpose(np.dot(homography, np.transpose(pos_a_h)))
119
+ pos_b_proj = pos_b_proj_h[:, :2] / pos_b_proj_h[:, 2:]
120
 
121
+ pos_b = keypoints_b[matches[:, 1], :2]
122
 
123
  dist = np.sqrt(np.sum((pos_b - pos_b_proj) ** 2, axis=1))
124
 
 
129
  dist = np.array([float("inf")])
130
 
131
  for thr in rng:
132
+ if seq_name[0] == "i":
133
  i_err[thr] += np.mean(dist <= thr)
134
  else:
135
  v_err[thr] += np.mean(dist <= thr)
136
 
137
  # =========== compute homography
138
  gt_homo = homography
139
+ pred_homo, _ = cv2.findHomography(
140
+ keypoints_a[matches[:, 0], :2],
141
+ keypoints_b[matches[:, 1], :2],
142
+ cv2.RANSAC,
143
+ )
144
  if pred_homo is None:
145
  homo_dist = np.array([float("inf")])
146
  else:
147
+ corners = np.array(
148
+ [
149
+ [0, 0],
150
+ [ref_img_shape[1] - 1, 0],
151
+ [0, ref_img_shape[0] - 1],
152
+ [ref_img_shape[1] - 1, ref_img_shape[0] - 1],
153
+ ]
154
+ )
155
  real_warped_corners = homo_trans(corners, gt_homo)
156
  warped_corners = homo_trans(corners, pred_homo)
157
+ homo_dist = np.mean(
158
+ np.linalg.norm(real_warped_corners - warped_corners, axis=1)
159
+ )
160
 
161
  for thr in rng:
162
+ if seq_name[0] == "i":
163
  i_err_homo[thr] += np.mean(homo_dist <= thr)
164
  else:
165
  v_err_homo[thr] += np.mean(homo_dist <= thr)
 
171
  return i_err, v_err, i_err_homo, v_err_homo, [seq_type, n_feats, n_matches]
172
 
173
 
174
+ if __name__ == "__main__":
175
  errors = {}
176
  for method in methods:
177
+ output_file = os.path.join(cache_dir, method + ".npy")
178
  read_function = generate_read_function(method)
179
  if os.path.exists(output_file):
180
  errors[method] = np.load(output_file, allow_pickle=True)
 
187
  i_err, v_err, i_err_hom, v_err_hom, _ = errors[method]
188
 
189
  print(f"====={name}=====")
190
+ print(f"MMA@1 MMA@2 MMA@3 MHA@1 MHA@2 MHA@3: ", end="")
191
  for thr in range(1, 4):
192
  err = (i_err[thr] + v_err[thr]) / ((n_i + n_v) * 5)
193
+ print(f"{err * 100:.2f}%", end=" ")
194
  for thr in range(1, 4):
195
  err_hom = (i_err_hom[thr] + v_err_hom[thr]) / ((n_i + n_v) * 5)
196
+ print(f"{err_hom * 100:.2f}%", end=" ")
197
+ print("")
third_party/ALIKE/hseq/extract.py CHANGED
@@ -9,23 +9,23 @@ from tqdm import tqdm
9
  from copy import deepcopy
10
  from torchvision.transforms import ToTensor
11
 
12
- sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
13
  from alike import ALike, configs
14
 
15
- dataset_root = 'hseq/hpatches-sequences-release'
16
  use_cuda = torch.cuda.is_available()
17
- device = 'cuda' if use_cuda else 'cpu'
18
- methods = ['alike-n', 'alike-l', 'alike-n-ms', 'alike-l-ms']
19
 
20
 
21
  class HPatchesDataset(data.Dataset):
22
- def __init__(self, root: str = dataset_root, alteration: str = 'all'):
23
  """
24
  Args:
25
  root: dataset root path
26
  alteration: # 'all', 'i' for illumination or 'v' for viewpoint
27
  """
28
- assert (Path(root).exists()), f"Dataset root path {root} dose not exist!"
29
  self.root = root
30
 
31
  # get all image file name
@@ -35,15 +35,15 @@ class HPatchesDataset(data.Dataset):
35
  folders = [x for x in Path(self.root).iterdir() if x.is_dir()]
36
  self.seqs = []
37
  for folder in folders:
38
- if alteration == 'i' and folder.stem[0] != 'i':
39
  continue
40
- if alteration == 'v' and folder.stem[0] != 'v':
41
  continue
42
 
43
  self.seqs.append(folder)
44
 
45
  self.len = len(self.seqs)
46
- assert (self.len > 0), f'Can not find PatchDataset in path {self.root}'
47
 
48
  def __getitem__(self, item):
49
  folder = self.seqs[item]
@@ -51,12 +51,12 @@ class HPatchesDataset(data.Dataset):
51
  imgs = []
52
  homos = []
53
  for i in range(1, 7):
54
- img = cv2.imread(str(folder / f'{i}.ppm'), cv2.IMREAD_COLOR)
55
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # HxWxC
56
  imgs.append(img)
57
 
58
  if i != 1:
59
- homo = np.loadtxt(str(folder / f'H_1_{i}')).astype('float32')
60
  homos.append(homo)
61
 
62
  return imgs, homos, folder.stem
@@ -68,11 +68,18 @@ class HPatchesDataset(data.Dataset):
68
  return self.__class__
69
 
70
 
71
- def extract_multiscale(model, img, scale_f=2 ** 0.5,
72
- min_scale=1., max_scale=1.,
73
- min_size=0., max_size=99999.,
74
- image_size_max=99999,
75
- n_k=0, sort=False):
 
 
 
 
 
 
 
76
  H_, W_, three = img.shape
77
  assert three == 3, "input image shape should be [HxWx3]"
78
 
@@ -100,7 +107,9 @@ def extract_multiscale(model, img, scale_f=2 ** 0.5,
100
  # extract descriptors
101
  with torch.no_grad():
102
  descriptor_map, scores_map = model.extract_dense_map(image)
103
- keypoints_, descriptors_, scores_, _ = model.dkd(scores_map, descriptor_map)
 
 
104
 
105
  keypoints.append(keypoints_[0])
106
  descriptors.append(descriptors_[0])
@@ -110,7 +119,9 @@ def extract_multiscale(model, img, scale_f=2 ** 0.5,
110
 
111
  # down-scale the image for next iteration
112
  nh, nw = round(H * s), round(W * s)
113
- image = torch.nn.functional.interpolate(image, (nh, nw), mode='bilinear', align_corners=False)
 
 
114
 
115
  # restore value
116
  torch.backends.cudnn.benchmark = old_bm
@@ -131,29 +142,34 @@ def extract_multiscale(model, img, scale_f=2 ** 0.5,
131
  descriptors = descriptors[0:n_k]
132
  scores = scores[0:n_k]
133
 
134
- return {'keypoints': keypoints, 'descriptors': descriptors, 'scores': scores}
135
 
136
 
137
  def extract_method(m):
138
- hpatches = HPatchesDataset(root=dataset_root, alteration='all')
139
  model = m[:7]
140
- min_scale = 0.3 if m[8:] == 'ms' else 1.0
141
 
142
  model = ALike(**configs[model], device=device, top_k=0, scores_th=0.2, n_limit=5000)
143
 
144
- progbar = tqdm(hpatches, desc='Extracting for {}'.format(m))
145
  for imgs, homos, seq_name in progbar:
146
  for i in range(1, 7):
147
  img = imgs[i - 1]
148
- pred = extract_multiscale(model, img, min_scale=min_scale, max_scale=1, sort=False, n_k=5000)
149
- kpts, descs, scores = pred['keypoints'], pred['descriptors'], pred['scores']
 
 
150
 
151
- with open(os.path.join(dataset_root, seq_name, f'{i}.ppm.{m}'), 'wb') as f:
152
- np.savez(f, keypoints=kpts.cpu().numpy(),
153
- scores=scores.cpu().numpy(),
154
- descriptors=descs.cpu().numpy())
 
 
 
155
 
156
 
157
- if __name__ == '__main__':
158
  for method in methods:
159
  extract_method(method)
 
9
  from copy import deepcopy
10
  from torchvision.transforms import ToTensor
11
 
12
+ sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
13
  from alike import ALike, configs
14
 
15
+ dataset_root = "hseq/hpatches-sequences-release"
16
  use_cuda = torch.cuda.is_available()
17
+ device = "cuda" if use_cuda else "cpu"
18
+ methods = ["alike-n", "alike-l", "alike-n-ms", "alike-l-ms"]
19
 
20
 
21
  class HPatchesDataset(data.Dataset):
22
+ def __init__(self, root: str = dataset_root, alteration: str = "all"):
23
  """
24
  Args:
25
  root: dataset root path
26
  alteration: # 'all', 'i' for illumination or 'v' for viewpoint
27
  """
28
+ assert Path(root).exists(), f"Dataset root path {root} dose not exist!"
29
  self.root = root
30
 
31
  # get all image file name
 
35
  folders = [x for x in Path(self.root).iterdir() if x.is_dir()]
36
  self.seqs = []
37
  for folder in folders:
38
+ if alteration == "i" and folder.stem[0] != "i":
39
  continue
40
+ if alteration == "v" and folder.stem[0] != "v":
41
  continue
42
 
43
  self.seqs.append(folder)
44
 
45
  self.len = len(self.seqs)
46
+ assert self.len > 0, f"Can not find PatchDataset in path {self.root}"
47
 
48
  def __getitem__(self, item):
49
  folder = self.seqs[item]
 
51
  imgs = []
52
  homos = []
53
  for i in range(1, 7):
54
+ img = cv2.imread(str(folder / f"{i}.ppm"), cv2.IMREAD_COLOR)
55
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # HxWxC
56
  imgs.append(img)
57
 
58
  if i != 1:
59
+ homo = np.loadtxt(str(folder / f"H_1_{i}")).astype("float32")
60
  homos.append(homo)
61
 
62
  return imgs, homos, folder.stem
 
68
  return self.__class__
69
 
70
 
71
+ def extract_multiscale(
72
+ model,
73
+ img,
74
+ scale_f=2**0.5,
75
+ min_scale=1.0,
76
+ max_scale=1.0,
77
+ min_size=0.0,
78
+ max_size=99999.0,
79
+ image_size_max=99999,
80
+ n_k=0,
81
+ sort=False,
82
+ ):
83
  H_, W_, three = img.shape
84
  assert three == 3, "input image shape should be [HxWx3]"
85
 
 
107
  # extract descriptors
108
  with torch.no_grad():
109
  descriptor_map, scores_map = model.extract_dense_map(image)
110
+ keypoints_, descriptors_, scores_, _ = model.dkd(
111
+ scores_map, descriptor_map
112
+ )
113
 
114
  keypoints.append(keypoints_[0])
115
  descriptors.append(descriptors_[0])
 
119
 
120
  # down-scale the image for next iteration
121
  nh, nw = round(H * s), round(W * s)
122
+ image = torch.nn.functional.interpolate(
123
+ image, (nh, nw), mode="bilinear", align_corners=False
124
+ )
125
 
126
  # restore value
127
  torch.backends.cudnn.benchmark = old_bm
 
142
  descriptors = descriptors[0:n_k]
143
  scores = scores[0:n_k]
144
 
145
+ return {"keypoints": keypoints, "descriptors": descriptors, "scores": scores}
146
 
147
 
148
  def extract_method(m):
149
+ hpatches = HPatchesDataset(root=dataset_root, alteration="all")
150
  model = m[:7]
151
+ min_scale = 0.3 if m[8:] == "ms" else 1.0
152
 
153
  model = ALike(**configs[model], device=device, top_k=0, scores_th=0.2, n_limit=5000)
154
 
155
+ progbar = tqdm(hpatches, desc="Extracting for {}".format(m))
156
  for imgs, homos, seq_name in progbar:
157
  for i in range(1, 7):
158
  img = imgs[i - 1]
159
+ pred = extract_multiscale(
160
+ model, img, min_scale=min_scale, max_scale=1, sort=False, n_k=5000
161
+ )
162
+ kpts, descs, scores = pred["keypoints"], pred["descriptors"], pred["scores"]
163
 
164
+ with open(os.path.join(dataset_root, seq_name, f"{i}.ppm.{m}"), "wb") as f:
165
+ np.savez(
166
+ f,
167
+ keypoints=kpts.cpu().numpy(),
168
+ scores=scores.cpu().numpy(),
169
+ descriptors=descs.cpu().numpy(),
170
+ )
171
 
172
 
173
+ if __name__ == "__main__":
174
  for method in methods:
175
  extract_method(method)
third_party/ALIKE/soft_detect.py CHANGED
@@ -17,13 +17,15 @@ import torch.nn.functional as F
17
  # v
18
  # [ y: range=-1.0~1.0; h: range=0~H ]
19
 
 
20
  def simple_nms(scores, nms_radius: int):
21
- """ Fast Non-maximum suppression to remove nearby points """
22
- assert (nms_radius >= 0)
23
 
24
  def max_pool(x):
25
  return torch.nn.functional.max_pool2d(
26
- x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius)
 
27
 
28
  zeros = torch.zeros_like(scores)
29
  max_mask = scores == max_pool(scores)
@@ -50,8 +52,14 @@ def sample_descriptor(descriptor_map, kpts, bilinear_interp=False):
50
  kptsi = kpts[index] # Nx2,(x,y)
51
 
52
  if bilinear_interp:
53
- descriptors_ = torch.nn.functional.grid_sample(descriptor_map[index].unsqueeze(0), kptsi.view(1, 1, -1, 2),
54
- mode='bilinear', align_corners=True)[0, :, 0, :] # CxN
 
 
 
 
 
 
55
  else:
56
  kptsi = (kptsi + 1) / 2 * kptsi.new_tensor([[width - 1, height - 1]])
57
  kptsi = kptsi.long()
@@ -94,10 +102,10 @@ class DKD(nn.Module):
94
  nms_scores = simple_nms(scores_nograd, 2)
95
 
96
  # remove border
97
- nms_scores[:, :, :self.radius + 1, :] = 0
98
- nms_scores[:, :, :, :self.radius + 1] = 0
99
- nms_scores[:, :, h - self.radius:, :] = 0
100
- nms_scores[:, :, :, w - self.radius:] = 0
101
 
102
  # detect keypoints without grad
103
  if self.top_k > 0:
@@ -121,7 +129,7 @@ class DKD(nn.Module):
121
  if len(indices) > self.n_limit:
122
  kpts_sc = scores[indices]
123
  sort_idx = kpts_sc.sort(descending=True)[1]
124
- sel_idx = sort_idx[:self.n_limit]
125
  indices = indices[sel_idx]
126
  indices_keypoints.append(indices)
127
 
@@ -134,42 +142,73 @@ class DKD(nn.Module):
134
  self.hw_grid = self.hw_grid.to(patches) # to device
135
  for b_idx in range(b):
136
  patch = patches[b_idx].t() # (H*W) x (kernel**2)
137
- indices_kpt = indices_keypoints[b_idx] # one dimension vector, say its size is M
 
 
138
  patch_scores = patch[indices_kpt] # M x (kernel**2)
139
 
140
  # max is detached to prevent undesired backprop loops in the graph
141
  max_v = patch_scores.max(dim=1).values.detach()[:, None]
142
- x_exp = ((patch_scores - max_v) / self.temperature).exp() # M * (kernel**2), in [0, 1]
 
 
143
 
144
  # \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
145
- xy_residual = x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None] # Soft-argmax, Mx2
146
-
147
- hw_grid_dist2 = torch.norm((self.hw_grid[None, :, :] - xy_residual[:, None, :]) / self.radius,
148
- dim=-1) ** 2
 
 
 
 
 
 
 
 
149
  scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1)
150
 
151
  # compute result keypoints
152
- keypoints_xy_nms = torch.stack([indices_kpt % w, indices_kpt // w], dim=1) # Mx2
 
 
153
  keypoints_xy = keypoints_xy_nms + xy_residual
154
- keypoints_xy = keypoints_xy / keypoints_xy.new_tensor(
155
- [w - 1, h - 1]) * 2 - 1 # (w,h) -> (-1~1,-1~1)
156
-
157
- kptscore = torch.nn.functional.grid_sample(scores_map[b_idx].unsqueeze(0),
158
- keypoints_xy.view(1, 1, -1, 2),
159
- mode='bilinear', align_corners=True)[0, 0, 0, :] # CxN
 
 
 
 
 
 
160
 
161
  keypoints.append(keypoints_xy)
162
  scoredispersitys.append(scoredispersity)
163
  kptscores.append(kptscore)
164
  else:
165
  for b_idx in range(b):
166
- indices_kpt = indices_keypoints[b_idx] # one dimension vector, say its size is M
167
- keypoints_xy_nms = torch.stack([indices_kpt % w, indices_kpt // w], dim=1) # Mx2
168
- keypoints_xy = keypoints_xy_nms / keypoints_xy_nms.new_tensor(
169
- [w - 1, h - 1]) * 2 - 1 # (w,h) -> (-1~1,-1~1)
170
- kptscore = torch.nn.functional.grid_sample(scores_map[b_idx].unsqueeze(0),
171
- keypoints_xy.view(1, 1, -1, 2),
172
- mode='bilinear', align_corners=True)[0, 0, 0, :] # CxN
 
 
 
 
 
 
 
 
 
 
 
173
  keypoints.append(keypoints_xy)
174
  scoredispersitys.append(None)
175
  kptscores.append(kptscore)
@@ -183,8 +222,9 @@ class DKD(nn.Module):
183
  :param sub_pixel: whether to use sub-pixel keypoint detection
184
  :return: kpts: list[Nx2,...]; kptscores: list[N,....] normalised position: -1.0 ~ 1.0
185
  """
186
- keypoints, scoredispersitys, kptscores = self.detect_keypoints(scores_map,
187
- sub_pixel)
 
188
 
189
  descriptors = sample_descriptor(descriptor_map, keypoints, sub_pixel)
190
 
 
17
  # v
18
  # [ y: range=-1.0~1.0; h: range=0~H ]
19
 
20
+
21
  def simple_nms(scores, nms_radius: int):
22
+ """Fast Non-maximum suppression to remove nearby points"""
23
+ assert nms_radius >= 0
24
 
25
  def max_pool(x):
26
  return torch.nn.functional.max_pool2d(
27
+ x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
28
+ )
29
 
30
  zeros = torch.zeros_like(scores)
31
  max_mask = scores == max_pool(scores)
 
52
  kptsi = kpts[index] # Nx2,(x,y)
53
 
54
  if bilinear_interp:
55
+ descriptors_ = torch.nn.functional.grid_sample(
56
+ descriptor_map[index].unsqueeze(0),
57
+ kptsi.view(1, 1, -1, 2),
58
+ mode="bilinear",
59
+ align_corners=True,
60
+ )[
61
+ 0, :, 0, :
62
+ ] # CxN
63
  else:
64
  kptsi = (kptsi + 1) / 2 * kptsi.new_tensor([[width - 1, height - 1]])
65
  kptsi = kptsi.long()
 
102
  nms_scores = simple_nms(scores_nograd, 2)
103
 
104
  # remove border
105
+ nms_scores[:, :, : self.radius + 1, :] = 0
106
+ nms_scores[:, :, :, : self.radius + 1] = 0
107
+ nms_scores[:, :, h - self.radius :, :] = 0
108
+ nms_scores[:, :, :, w - self.radius :] = 0
109
 
110
  # detect keypoints without grad
111
  if self.top_k > 0:
 
129
  if len(indices) > self.n_limit:
130
  kpts_sc = scores[indices]
131
  sort_idx = kpts_sc.sort(descending=True)[1]
132
+ sel_idx = sort_idx[: self.n_limit]
133
  indices = indices[sel_idx]
134
  indices_keypoints.append(indices)
135
 
 
142
  self.hw_grid = self.hw_grid.to(patches) # to device
143
  for b_idx in range(b):
144
  patch = patches[b_idx].t() # (H*W) x (kernel**2)
145
+ indices_kpt = indices_keypoints[
146
+ b_idx
147
+ ] # one dimension vector, say its size is M
148
  patch_scores = patch[indices_kpt] # M x (kernel**2)
149
 
150
  # max is detached to prevent undesired backprop loops in the graph
151
  max_v = patch_scores.max(dim=1).values.detach()[:, None]
152
+ x_exp = (
153
+ (patch_scores - max_v) / self.temperature
154
+ ).exp() # M * (kernel**2), in [0, 1]
155
 
156
  # \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
157
+ xy_residual = (
158
+ x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None]
159
+ ) # Soft-argmax, Mx2
160
+
161
+ hw_grid_dist2 = (
162
+ torch.norm(
163
+ (self.hw_grid[None, :, :] - xy_residual[:, None, :])
164
+ / self.radius,
165
+ dim=-1,
166
+ )
167
+ ** 2
168
+ )
169
  scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1)
170
 
171
  # compute result keypoints
172
+ keypoints_xy_nms = torch.stack(
173
+ [indices_kpt % w, indices_kpt // w], dim=1
174
+ ) # Mx2
175
  keypoints_xy = keypoints_xy_nms + xy_residual
176
+ keypoints_xy = (
177
+ keypoints_xy / keypoints_xy.new_tensor([w - 1, h - 1]) * 2 - 1
178
+ ) # (w,h) -> (-1~1,-1~1)
179
+
180
+ kptscore = torch.nn.functional.grid_sample(
181
+ scores_map[b_idx].unsqueeze(0),
182
+ keypoints_xy.view(1, 1, -1, 2),
183
+ mode="bilinear",
184
+ align_corners=True,
185
+ )[
186
+ 0, 0, 0, :
187
+ ] # CxN
188
 
189
  keypoints.append(keypoints_xy)
190
  scoredispersitys.append(scoredispersity)
191
  kptscores.append(kptscore)
192
  else:
193
  for b_idx in range(b):
194
+ indices_kpt = indices_keypoints[
195
+ b_idx
196
+ ] # one dimension vector, say its size is M
197
+ keypoints_xy_nms = torch.stack(
198
+ [indices_kpt % w, indices_kpt // w], dim=1
199
+ ) # Mx2
200
+ keypoints_xy = (
201
+ keypoints_xy_nms / keypoints_xy_nms.new_tensor([w - 1, h - 1]) * 2
202
+ - 1
203
+ ) # (w,h) -> (-1~1,-1~1)
204
+ kptscore = torch.nn.functional.grid_sample(
205
+ scores_map[b_idx].unsqueeze(0),
206
+ keypoints_xy.view(1, 1, -1, 2),
207
+ mode="bilinear",
208
+ align_corners=True,
209
+ )[
210
+ 0, 0, 0, :
211
+ ] # CxN
212
  keypoints.append(keypoints_xy)
213
  scoredispersitys.append(None)
214
  kptscores.append(kptscore)
 
222
  :param sub_pixel: whether to use sub-pixel keypoint detection
223
  :return: kpts: list[Nx2,...]; kptscores: list[N,....] normalised position: -1.0 ~ 1.0
224
  """
225
+ keypoints, scoredispersitys, kptscores = self.detect_keypoints(
226
+ scores_map, sub_pixel
227
+ )
228
 
229
  descriptors = sample_descriptor(descriptor_map, keypoints, sub_pixel)
230
 
third_party/ASpanFormer/configs/aspan/indoor/aspan_test.py CHANGED
@@ -1,10 +1,11 @@
1
  import sys
2
  from pathlib import Path
3
- sys.path.append(str(Path(__file__).parent / '../../../'))
 
4
  from src.config.default import _CN as cfg
5
 
6
- cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
7
 
8
  cfg.ASPAN.MATCH_COARSE.BORDER_RM = 0
9
- cfg.ASPAN.COARSE.COARSEST_LEVEL= [15,20]
10
- cfg.ASPAN.COARSE.TRAIN_RES = [480,640]
 
1
  import sys
2
  from pathlib import Path
3
+
4
+ sys.path.append(str(Path(__file__).parent / "../../../"))
5
  from src.config.default import _CN as cfg
6
 
7
+ cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = "dual_softmax"
8
 
9
  cfg.ASPAN.MATCH_COARSE.BORDER_RM = 0
10
+ cfg.ASPAN.COARSE.COARSEST_LEVEL = [15, 20]
11
+ cfg.ASPAN.COARSE.TRAIN_RES = [480, 640]
third_party/ASpanFormer/configs/aspan/indoor/aspan_train.py CHANGED
@@ -1,10 +1,11 @@
1
  import sys
2
  from pathlib import Path
3
- sys.path.append(str(Path(__file__).parent / '../../../'))
 
4
  from src.config.default import _CN as cfg
5
 
6
- cfg.ASPAN.COARSE.COARSEST_LEVEL= [15,20]
7
- cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
8
 
9
  cfg.ASPAN.MATCH_COARSE.SPARSE_SPVS = False
10
  cfg.ASPAN.MATCH_COARSE.BORDER_RM = 0
 
1
  import sys
2
  from pathlib import Path
3
+
4
+ sys.path.append(str(Path(__file__).parent / "../../../"))
5
  from src.config.default import _CN as cfg
6
 
7
+ cfg.ASPAN.COARSE.COARSEST_LEVEL = [15, 20]
8
+ cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = "dual_softmax"
9
 
10
  cfg.ASPAN.MATCH_COARSE.SPARSE_SPVS = False
11
  cfg.ASPAN.MATCH_COARSE.BORDER_RM = 0
third_party/ASpanFormer/configs/aspan/outdoor/aspan_test.py CHANGED
@@ -1,12 +1,13 @@
1
  import sys
2
  from pathlib import Path
3
- sys.path.append(str(Path(__file__).parent / '../../../'))
 
4
  from src.config.default import _CN as cfg
5
 
6
- cfg.ASPAN.COARSE.COARSEST_LEVEL= [36,36]
7
- cfg.ASPAN.COARSE.TRAIN_RES = [832,832]
8
- cfg.ASPAN.COARSE.TEST_RES = [1152,1152]
9
- cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
10
 
11
  cfg.TRAINER.CANONICAL_LR = 8e-3
12
  cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
 
1
  import sys
2
  from pathlib import Path
3
+
4
+ sys.path.append(str(Path(__file__).parent / "../../../"))
5
  from src.config.default import _CN as cfg
6
 
7
+ cfg.ASPAN.COARSE.COARSEST_LEVEL = [36, 36]
8
+ cfg.ASPAN.COARSE.TRAIN_RES = [832, 832]
9
+ cfg.ASPAN.COARSE.TEST_RES = [1152, 1152]
10
+ cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = "dual_softmax"
11
 
12
  cfg.TRAINER.CANONICAL_LR = 8e-3
13
  cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs
third_party/ASpanFormer/configs/aspan/outdoor/aspan_train.py CHANGED
@@ -1,10 +1,11 @@
1
  import sys
2
  from pathlib import Path
3
- sys.path.append(str(Path(__file__).parent / '../../../'))
 
4
  from src.config.default import _CN as cfg
5
 
6
- cfg.ASPAN.COARSE.COARSEST_LEVEL= [26,26]
7
- cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'
8
  cfg.ASPAN.MATCH_COARSE.SPARSE_SPVS = False
9
 
10
  cfg.TRAINER.CANONICAL_LR = 8e-3
 
1
  import sys
2
  from pathlib import Path
3
+
4
+ sys.path.append(str(Path(__file__).parent / "../../../"))
5
  from src.config.default import _CN as cfg
6
 
7
+ cfg.ASPAN.COARSE.COARSEST_LEVEL = [26, 26]
8
+ cfg.ASPAN.MATCH_COARSE.MATCH_TYPE = "dual_softmax"
9
  cfg.ASPAN.MATCH_COARSE.SPARSE_SPVS = False
10
 
11
  cfg.TRAINER.CANONICAL_LR = 8e-3
third_party/ASpanFormer/configs/data/base.py CHANGED
@@ -4,6 +4,7 @@ Setups in data configs will override all existed setups!
4
  """
5
 
6
  from yacs.config import CfgNode as CN
 
7
  _CN = CN()
8
  _CN.DATASET = CN()
9
  _CN.TRAINER = CN()
 
4
  """
5
 
6
  from yacs.config import CfgNode as CN
7
+
8
  _CN = CN()
9
  _CN.DATASET = CN()
10
  _CN.TRAINER = CN()
third_party/ASpanFormer/configs/data/megadepth_test_1500.py CHANGED
@@ -8,6 +8,6 @@ cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}"
8
  cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/megadepth_test_1500.txt"
9
 
10
  cfg.DATASET.MGDPT_IMG_RESIZE = 1152
11
- cfg.DATASET.MGDPT_IMG_PAD=True
12
- cfg.DATASET.MGDPT_DF =8
13
- cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
 
8
  cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/megadepth_test_1500.txt"
9
 
10
  cfg.DATASET.MGDPT_IMG_RESIZE = 1152
11
+ cfg.DATASET.MGDPT_IMG_PAD = True
12
+ cfg.DATASET.MGDPT_DF = 8
13
+ cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
third_party/ASpanFormer/configs/data/megadepth_trainval_832.py CHANGED
@@ -11,9 +11,13 @@ cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0
11
  TEST_BASE_PATH = "data/megadepth/index"
12
  cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
13
  cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test"
14
- cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}/scene_info_val_1500"
15
- cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt"
16
- cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val
 
 
 
 
17
 
18
  # 368 scenes in total for MegaDepth
19
  # (with difficulty balanced (further split each scene to 3 sub-scenes))
 
11
  TEST_BASE_PATH = "data/megadepth/index"
12
  cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth"
13
  cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test"
14
+ cfg.DATASET.VAL_NPZ_ROOT = (
15
+ cfg.DATASET.TEST_NPZ_ROOT
16
+ ) = f"{TEST_BASE_PATH}/scene_info_val_1500"
17
+ cfg.DATASET.VAL_LIST_PATH = (
18
+ cfg.DATASET.TEST_LIST_PATH
19
+ ) = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt"
20
+ cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val
21
 
22
  # 368 scenes in total for MegaDepth
23
  # (with difficulty balanced (further split each scene to 3 sub-scenes))
third_party/ASpanFormer/configs/data/scannet_trainval.py CHANGED
@@ -12,6 +12,10 @@ TEST_BASE_PATH = "assets/scannet_test_1500"
12
  cfg.DATASET.TEST_DATA_SOURCE = "ScanNet"
13
  cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/scannet/test"
14
  cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = TEST_BASE_PATH
15
- cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/scannet_test.txt"
16
- cfg.DATASET.VAL_INTRINSIC_PATH = cfg.DATASET.TEST_INTRINSIC_PATH = f"{TEST_BASE_PATH}/intrinsics.npz"
17
- cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val
 
 
 
 
 
12
  cfg.DATASET.TEST_DATA_SOURCE = "ScanNet"
13
  cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/scannet/test"
14
  cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = TEST_BASE_PATH
15
+ cfg.DATASET.VAL_LIST_PATH = (
16
+ cfg.DATASET.TEST_LIST_PATH
17
+ ) = f"{TEST_BASE_PATH}/scannet_test.txt"
18
+ cfg.DATASET.VAL_INTRINSIC_PATH = (
19
+ cfg.DATASET.TEST_INTRINSIC_PATH
20
+ ) = f"{TEST_BASE_PATH}/intrinsics.npz"
21
+ cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val
third_party/ASpanFormer/demo/demo.py CHANGED
@@ -1,63 +1,91 @@
1
  import os
2
  import sys
 
3
  ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
4
  sys.path.insert(0, ROOT_DIR)
5
 
6
- from src.ASpanFormer.aspanformer import ASpanFormer
7
  from src.config.default import get_cfg_defaults
8
  from src.utils.misc import lower_config
9
- import demo_utils
10
 
11
  import cv2
12
  import torch
13
  import numpy as np
14
 
15
  import argparse
 
16
  parser = argparse.ArgumentParser()
17
- parser.add_argument('--config_path', type=str, default='../configs/aspan/outdoor/aspan_test.py',
18
- help='path for config file.')
19
- parser.add_argument('--img0_path', type=str, default='../assets/phototourism_sample_images/piazza_san_marco_06795901_3725050516.jpg',
20
- help='path for image0.')
21
- parser.add_argument('--img1_path', type=str, default='../assets/phototourism_sample_images/piazza_san_marco_15148634_5228701572.jpg',
22
- help='path for image1.')
23
- parser.add_argument('--weights_path', type=str, default='../weights/outdoor.ckpt',
24
- help='path for model weights.')
25
- parser.add_argument('--long_dim0', type=int, default=1024,
26
- help='resize for longest dim of image0.')
27
- parser.add_argument('--long_dim1', type=int, default=1024,
28
- help='resize for longest dim of image1.')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  args = parser.parse_args()
31
 
32
 
33
- if __name__=='__main__':
34
  config = get_cfg_defaults()
35
  config.merge_from_file(args.config_path)
36
  _config = lower_config(config)
37
- matcher = ASpanFormer(config=_config['aspan'])
38
- state_dict = torch.load(args.weights_path, map_location='cpu')['state_dict']
39
- matcher.load_state_dict(state_dict,strict=False)
40
- matcher.cuda(),matcher.eval()
41
-
42
- img0,img1=cv2.imread(args.img0_path),cv2.imread(args.img1_path)
43
- img0_g,img1_g=cv2.imread(args.img0_path,0),cv2.imread(args.img1_path,0)
44
- img0,img1=demo_utils.resize(img0,args.long_dim0),demo_utils.resize(img1,args.long_dim1)
45
- img0_g,img1_g=demo_utils.resize(img0_g,args.long_dim0),demo_utils.resize(img1_g,args.long_dim1)
46
- data={'image0':torch.from_numpy(img0_g/255.)[None,None].cuda().float(),
47
- 'image1':torch.from_numpy(img1_g/255.)[None,None].cuda().float()}
48
- with torch.no_grad():
49
- matcher(data,online_resize=True)
50
- corr0,corr1=data['mkpts0_f'].cpu().numpy(),data['mkpts1_f'].cpu().numpy()
51
-
52
- F_hat,mask_F=cv2.findFundamentalMat(corr0,corr1,method=cv2.FM_RANSAC,ransacReprojThreshold=1)
 
 
 
 
 
 
 
 
53
  if mask_F is not None:
54
- mask_F=mask_F[:,0].astype(bool)
55
  else:
56
- mask_F=np.zeros_like(corr0[:,0]).astype(bool)
57
-
58
- #visualize match
59
- display=demo_utils.draw_match(img0,img1,corr0,corr1)
60
- display_ransac=demo_utils.draw_match(img0,img1,corr0[mask_F],corr1[mask_F])
61
- cv2.imwrite('match.png',display)
62
- cv2.imwrite('match_ransac.png',display_ransac)
63
- print(len(corr1),len(corr1[mask_F]))
 
1
  import os
2
  import sys
3
+
4
  ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
5
  sys.path.insert(0, ROOT_DIR)
6
 
7
+ from src.ASpanFormer.aspanformer import ASpanFormer
8
  from src.config.default import get_cfg_defaults
9
  from src.utils.misc import lower_config
10
+ import demo_utils
11
 
12
  import cv2
13
  import torch
14
  import numpy as np
15
 
16
  import argparse
17
+
18
  parser = argparse.ArgumentParser()
19
+ parser.add_argument(
20
+ "--config_path",
21
+ type=str,
22
+ default="../configs/aspan/outdoor/aspan_test.py",
23
+ help="path for config file.",
24
+ )
25
+ parser.add_argument(
26
+ "--img0_path",
27
+ type=str,
28
+ default="../assets/phototourism_sample_images/piazza_san_marco_06795901_3725050516.jpg",
29
+ help="path for image0.",
30
+ )
31
+ parser.add_argument(
32
+ "--img1_path",
33
+ type=str,
34
+ default="../assets/phototourism_sample_images/piazza_san_marco_15148634_5228701572.jpg",
35
+ help="path for image1.",
36
+ )
37
+ parser.add_argument(
38
+ "--weights_path",
39
+ type=str,
40
+ default="../weights/outdoor.ckpt",
41
+ help="path for model weights.",
42
+ )
43
+ parser.add_argument(
44
+ "--long_dim0", type=int, default=1024, help="resize for longest dim of image0."
45
+ )
46
+ parser.add_argument(
47
+ "--long_dim1", type=int, default=1024, help="resize for longest dim of image1."
48
+ )
49
 
50
  args = parser.parse_args()
51
 
52
 
53
+ if __name__ == "__main__":
54
  config = get_cfg_defaults()
55
  config.merge_from_file(args.config_path)
56
  _config = lower_config(config)
57
+ matcher = ASpanFormer(config=_config["aspan"])
58
+ state_dict = torch.load(args.weights_path, map_location="cpu")["state_dict"]
59
+ matcher.load_state_dict(state_dict, strict=False)
60
+ matcher.cuda(), matcher.eval()
61
+
62
+ img0, img1 = cv2.imread(args.img0_path), cv2.imread(args.img1_path)
63
+ img0_g, img1_g = cv2.imread(args.img0_path, 0), cv2.imread(args.img1_path, 0)
64
+ img0, img1 = demo_utils.resize(img0, args.long_dim0), demo_utils.resize(
65
+ img1, args.long_dim1
66
+ )
67
+ img0_g, img1_g = demo_utils.resize(img0_g, args.long_dim0), demo_utils.resize(
68
+ img1_g, args.long_dim1
69
+ )
70
+ data = {
71
+ "image0": torch.from_numpy(img0_g / 255.0)[None, None].cuda().float(),
72
+ "image1": torch.from_numpy(img1_g / 255.0)[None, None].cuda().float(),
73
+ }
74
+ with torch.no_grad():
75
+ matcher(data, online_resize=True)
76
+ corr0, corr1 = data["mkpts0_f"].cpu().numpy(), data["mkpts1_f"].cpu().numpy()
77
+
78
+ F_hat, mask_F = cv2.findFundamentalMat(
79
+ corr0, corr1, method=cv2.FM_RANSAC, ransacReprojThreshold=1
80
+ )
81
  if mask_F is not None:
82
+ mask_F = mask_F[:, 0].astype(bool)
83
  else:
84
+ mask_F = np.zeros_like(corr0[:, 0]).astype(bool)
85
+
86
+ # visualize match
87
+ display = demo_utils.draw_match(img0, img1, corr0, corr1)
88
+ display_ransac = demo_utils.draw_match(img0, img1, corr0[mask_F], corr1[mask_F])
89
+ cv2.imwrite("match.png", display)
90
+ cv2.imwrite("match_ransac.png", display_ransac)
91
+ print(len(corr1), len(corr1[mask_F]))
third_party/ASpanFormer/demo/demo_utils.py CHANGED
@@ -1,44 +1,88 @@
1
  import cv2
2
  import numpy as np
3
 
4
- def resize(image,long_dim):
5
- h,w=image.shape[0],image.shape[1]
6
- image=cv2.resize(image,(int(w*long_dim/max(h,w)),int(h*long_dim/max(h,w))))
 
 
 
7
  return image
8
 
9
- def draw_points(img,points,color=(0,255,0),radius=3):
 
10
  dp = [(int(points[i, 0]), int(points[i, 1])) for i in range(points.shape[0])]
11
  for i in range(points.shape[0]):
12
- cv2.circle(img, dp[i],radius=radius,color=color)
13
  return img
14
-
15
 
16
- def draw_match(img1, img2, corr1, corr2,inlier=[True],color=None,radius1=1,radius2=1,resize=None):
 
 
 
 
 
 
 
 
 
 
 
17
  if resize is not None:
18
- scale1,scale2=[img1.shape[1]/resize[0],img1.shape[0]/resize[1]],[img2.shape[1]/resize[0],img2.shape[0]/resize[1]]
19
- img1,img2=cv2.resize(img1, resize, interpolation=cv2.INTER_AREA),cv2.resize(img2, resize, interpolation=cv2.INTER_AREA)
20
- corr1,corr2=corr1/np.asarray(scale1)[np.newaxis],corr2/np.asarray(scale2)[np.newaxis]
21
- corr1_key = [cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0])]
22
- corr2_key = [cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0])]
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  assert len(corr1) == len(corr2)
25
 
26
  draw_matches = [cv2.DMatch(i, i, 0) for i in range(len(corr1))]
27
  if color is None:
28
- color = [(0, 255, 0) if cur_inlier else (0,0,255) for cur_inlier in inlier]
29
- if len(color)==1:
30
- display = cv2.drawMatches(img1, corr1_key, img2, corr2_key, draw_matches, None,
31
- matchColor=color[0],
32
- singlePointColor=color[0],
33
- flags=4
34
- )
 
 
 
 
 
 
35
  else:
36
- height,width=max(img1.shape[0],img2.shape[0]),img1.shape[1]+img2.shape[1]
37
- display=np.zeros([height,width,3],np.uint8)
38
- display[:img1.shape[0],:img1.shape[1]]=img1
39
- display[:img2.shape[0],img1.shape[1]:]=img2
40
  for i in range(len(corr1)):
41
- left_x,left_y,right_x,right_y=int(corr1[i][0]),int(corr1[i][1]),int(corr2[i][0]+img1.shape[1]),int(corr2[i][1])
42
- cur_color=(int(color[i][0]),int(color[i][1]),int(color[i][2]))
43
- cv2.line(display, (left_x,left_y), (right_x,right_y),cur_color,1,lineType=cv2.LINE_AA)
44
- return display
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import cv2
2
  import numpy as np
3
 
4
+
5
+ def resize(image, long_dim):
6
+ h, w = image.shape[0], image.shape[1]
7
+ image = cv2.resize(
8
+ image, (int(w * long_dim / max(h, w)), int(h * long_dim / max(h, w)))
9
+ )
10
  return image
11
 
12
+
13
+ def draw_points(img, points, color=(0, 255, 0), radius=3):
14
  dp = [(int(points[i, 0]), int(points[i, 1])) for i in range(points.shape[0])]
15
  for i in range(points.shape[0]):
16
+ cv2.circle(img, dp[i], radius=radius, color=color)
17
  return img
 
18
 
19
+
20
+ def draw_match(
21
+ img1,
22
+ img2,
23
+ corr1,
24
+ corr2,
25
+ inlier=[True],
26
+ color=None,
27
+ radius1=1,
28
+ radius2=1,
29
+ resize=None,
30
+ ):
31
  if resize is not None:
32
+ scale1, scale2 = [img1.shape[1] / resize[0], img1.shape[0] / resize[1]], [
33
+ img2.shape[1] / resize[0],
34
+ img2.shape[0] / resize[1],
35
+ ]
36
+ img1, img2 = cv2.resize(img1, resize, interpolation=cv2.INTER_AREA), cv2.resize(
37
+ img2, resize, interpolation=cv2.INTER_AREA
38
+ )
39
+ corr1, corr2 = (
40
+ corr1 / np.asarray(scale1)[np.newaxis],
41
+ corr2 / np.asarray(scale2)[np.newaxis],
42
+ )
43
+ corr1_key = [
44
+ cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0])
45
+ ]
46
+ corr2_key = [
47
+ cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0])
48
+ ]
49
 
50
  assert len(corr1) == len(corr2)
51
 
52
  draw_matches = [cv2.DMatch(i, i, 0) for i in range(len(corr1))]
53
  if color is None:
54
+ color = [(0, 255, 0) if cur_inlier else (0, 0, 255) for cur_inlier in inlier]
55
+ if len(color) == 1:
56
+ display = cv2.drawMatches(
57
+ img1,
58
+ corr1_key,
59
+ img2,
60
+ corr2_key,
61
+ draw_matches,
62
+ None,
63
+ matchColor=color[0],
64
+ singlePointColor=color[0],
65
+ flags=4,
66
+ )
67
  else:
68
+ height, width = max(img1.shape[0], img2.shape[0]), img1.shape[1] + img2.shape[1]
69
+ display = np.zeros([height, width, 3], np.uint8)
70
+ display[: img1.shape[0], : img1.shape[1]] = img1
71
+ display[: img2.shape[0], img1.shape[1] :] = img2
72
  for i in range(len(corr1)):
73
+ left_x, left_y, right_x, right_y = (
74
+ int(corr1[i][0]),
75
+ int(corr1[i][1]),
76
+ int(corr2[i][0] + img1.shape[1]),
77
+ int(corr2[i][1]),
78
+ )
79
+ cur_color = (int(color[i][0]), int(color[i][1]), int(color[i][2]))
80
+ cv2.line(
81
+ display,
82
+ (left_x, left_y),
83
+ (right_x, right_y),
84
+ cur_color,
85
+ 1,
86
+ lineType=cv2.LINE_AA,
87
+ )
88
+ return display
third_party/ASpanFormer/src/ASpanFormer/aspan_module/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
  from .transformer import LocalFeatureTransformer_Flow
2
- from .loftr import LocalFeatureTransformer
3
  from .fine_preprocess import FinePreprocess
 
1
  from .transformer import LocalFeatureTransformer_Flow
2
+ from .loftr import LocalFeatureTransformer
3
  from .fine_preprocess import FinePreprocess
third_party/ASpanFormer/src/ASpanFormer/aspan_module/attention.py CHANGED
@@ -4,39 +4,59 @@ import torch.nn as nn
4
  from itertools import product
5
  from torch.nn import functional as F
6
 
 
7
  class layernorm2d(nn.Module):
8
-
9
- def __init__(self,dim) :
10
- super().__init__()
11
- self.dim=dim
12
- self.affine=nn.parameter.Parameter(torch.ones(dim), requires_grad=True)
13
- self.bias=nn.parameter.Parameter(torch.zeros(dim), requires_grad=True)
14
-
15
- def forward(self,x):
16
- #x: B*C*H*W
17
- mean,std=x.mean(dim=1,keepdim=True),x.std(dim=1,keepdim=True)
18
- return self.affine[None,:,None,None]*(x-mean)/(std+1e-6)+self.bias[None,:,None,None]
 
 
19
 
20
 
21
  class HierachicalAttention(Module):
22
- def __init__(self,d_model,nhead,nsample,radius_scale,nlevel=3):
23
  super().__init__()
24
- self.d_model=d_model
25
- self.nhead=nhead
26
- self.nsample=nsample
27
- self.nlevel=nlevel
28
- self.radius_scale=radius_scale
29
  self.merge_head = nn.Sequential(
30
- nn.Conv1d(d_model*3, d_model, kernel_size=1,bias=False),
31
  nn.ReLU(True),
32
- nn.Conv1d(d_model, d_model, kernel_size=1,bias=False),
33
  )
34
- self.fullattention=FullAttention(d_model,nhead)
35
- self.temp=nn.parameter.Parameter(torch.tensor(1.),requires_grad=True)
36
- sample_offset=torch.tensor([[pos[0]-nsample[1]/2+0.5, pos[1]-nsample[1]/2+0.5] for pos in product(range(nsample[1]), range(nsample[1]))]) #r^2*2
37
- self.sample_offset=nn.parameter.Parameter(sample_offset,requires_grad=False)
 
 
 
 
 
38
 
39
- def forward(self,query,key,value,flow,size_q,size_kv,mask0=None, mask1=None,ds0=[4,4],ds1=[4,4]):
 
 
 
 
 
 
 
 
 
 
 
 
40
  """
41
  Args:
42
  q,k,v (torch.Tensor): [B, C, L]
@@ -45,123 +65,217 @@ class HierachicalAttention(Module):
45
  Return:
46
  all_message (torch.Tensor): [B, C, H, W]
47
  """
48
-
49
- variance=flow[:,:,:,2:]
50
- offset=flow[:,:,:,:2] #B*H*W*2
51
- bs=query.shape[0]
52
- h0,w0=size_q[0],size_q[1]
53
- h1,w1=size_kv[0],size_kv[1]
54
- variance=torch.exp(0.5*variance)*self.radius_scale #b*h*w*2(pixel scale)
55
- span_scale=torch.clamp((variance*2/self.nsample[1]),min=1) #b*h*w*2
56
-
57
- sub_sample0,sub_sample1=[ds0,2,1],[ds1,2,1]
58
- q_list=[F.avg_pool2d(query.view(bs,-1,h0,w0),kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample0]
59
- k_list=[F.avg_pool2d(key.view(bs,-1,h1,w1),kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample1]
60
- v_list=[F.avg_pool2d(value.view(bs,-1,h1,w1),kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample1] #n_level
61
-
62
- offset_list=[F.avg_pool2d(offset.permute(0,3,1,2),kernel_size=sub_size*self.nsample[0],stride=sub_size*self.nsample[0]).permute(0,2,3,1)/sub_size for sub_size in sub_sample0[1:]] #n_level-1
63
- span_list=[F.avg_pool2d(span_scale.permute(0,3,1,2),kernel_size=sub_size*self.nsample[0],stride=sub_size*self.nsample[0]).permute(0,2,3,1) for sub_size in sub_sample0[1:]] #n_level-1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  if mask0 is not None:
66
- mask0,mask1=mask0.view(bs,1,h0,w0),mask1.view(bs,1,h1,w1)
67
- mask0_list=[-F.max_pool2d(-mask0,kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample0]
68
- mask1_list=[-F.max_pool2d(-mask1,kernel_size=sub_size,stride=sub_size) for sub_size in sub_sample1]
 
 
 
 
 
 
69
  else:
70
- mask0_list=mask1_list=[None,None,None]
71
-
72
- message_list=[]
73
- #full attention at coarse scale
74
- mask0_flatten=mask0_list[0].view(bs,-1) if mask0 is not None else None
75
- mask1_flatten=mask1_list[0].view(bs,-1) if mask1 is not None else None
76
- message_list.append(self.fullattention(q_list[0],k_list[0],v_list[0],mask0_flatten,mask1_flatten,self.temp).view(bs,self.d_model,h0//ds0[0],w0//ds0[1]))
77
-
78
- for index in range(1,self.nlevel):
79
- q,k,v=q_list[index],k_list[index],v_list[index]
80
- mask0,mask1=mask0_list[index],mask1_list[index]
81
- s,o=span_list[index-1],offset_list[index-1] #B*h*w(*2)
82
- q,k,v,sample_pixel,mask_sample=self.partition_token(q,k,v,o,s,mask0) #B*Head*D*G*N(G*N=H*W for q)
83
- message_list.append(self.group_attention(q,k,v,1,mask_sample).view(bs,self.d_model,h0//sub_sample0[index],w0//sub_sample0[index]))
84
- #fuse
85
- all_message=torch.cat([F.upsample(message_list[idx],scale_factor=sub_sample0[idx],mode='nearest') \
86
- for idx in range(self.nlevel)],dim=1).view(bs,-1,h0*w0) #b*3d*H*W
87
-
88
- all_message=self.merge_head(all_message).view(bs,-1,h0,w0) #b*d*H*W
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  return all_message
90
-
91
- def partition_token(self,q,k,v,offset,span_scale,maskv):
92
- #q,k,v: B*C*H*W
93
- #o: B*H/2*W/2*2
94
- #span_scale:B*H*W
95
- bs=q.shape[0]
96
- h,w=q.shape[2],q.shape[3]
97
- hk,wk=k.shape[2],k.shape[3]
98
- offset=offset.view(bs,-1,2)
99
- span_scale=span_scale.view(bs,-1,1,2)
100
- #B*G*2
101
- offset_sample=self.sample_offset[None,None]*span_scale
102
- sample_pixel=offset[:,:,None]+offset_sample#B*G*r^2*2
103
- sample_norm=sample_pixel/torch.tensor([wk/2,hk/2]).cuda()[None,None,None]-1
104
-
105
- q = q.view(bs, -1 , h // self.nsample[0], self.nsample[0], w // self.nsample[0], self.nsample[0]).\
106
- permute(0, 1, 2, 4, 3, 5).contiguous().view(bs, self.nhead,self.d_model//self.nhead, -1,self.nsample[0]**2)#B*head*D*G*N(G*N=H*W for q)
107
- #sample token
108
- k=F.grid_sample(k, grid=sample_norm).view(bs, self.nhead,self.d_model//self.nhead,-1, self.nsample[1]**2) #B*head*D*G*r^2
109
- v=F.grid_sample(v, grid=sample_norm).view(bs, self.nhead,self.d_model//self.nhead,-1, self.nsample[1]**2) #B*head*D*G*r^2
110
- #import pdb;pdb.set_trace()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  if maskv is not None:
112
- mask_sample=F.grid_sample(maskv.view(bs,-1,h,w).float(),grid=sample_norm,mode='nearest')==1 #B*1*G*r^2
 
 
 
 
 
113
  else:
114
- mask_sample=None
115
- return q,k,v,sample_pixel,mask_sample
116
-
117
 
118
- def group_attention(self,query,key,value,temp,mask_sample=None):
119
- #q,k,v: B*Head*D*G*N(G*N=H*W for q)
120
- bs=query.shape[0]
121
- #import pdb;pdb.set_trace()
122
  QK = torch.einsum("bhdgn,bhdgm->bhgnm", query, key)
123
  if mask_sample is not None:
124
- num_head,number_n=QK.shape[1],QK.shape[3]
125
- QK.masked_fill_(~(mask_sample[:,:,:,None]).expand(-1,num_head,-1,number_n,-1).bool(), float(-1e8))
 
 
 
 
 
126
  # Compute the attention and the weighted average
127
- softmax_temp = temp / query.size(2)**.5 # sqrt(D)
128
  A = torch.softmax(softmax_temp * QK, dim=-1)
129
- queried_values = torch.einsum("bhgnm,bhdgm->bhdgn", A, value).contiguous().view(bs,self.d_model,-1)
 
 
 
 
130
  return queried_values
131
 
132
-
133
 
134
  class FullAttention(Module):
135
- def __init__(self,d_model,nhead):
136
  super().__init__()
137
- self.d_model=d_model
138
- self.nhead=nhead
139
 
140
- def forward(self, q, k,v , mask0=None, mask1=None, temp=1):
141
- """ Multi-head scaled dot-product attention, a.k.a full attention.
142
  Args:
143
  q,k,v: [N, D, L]
144
  mask: [N, L]
145
  Returns:
146
  msg: [N,L]
147
  """
148
- bs=q.shape[0]
149
- q,k,v=q.view(bs,self.nhead,self.d_model//self.nhead,-1),k.view(bs,self.nhead,self.d_model//self.nhead,-1),v.view(bs,self.nhead,self.d_model//self.nhead,-1)
 
 
 
 
150
  # Compute the unnormalized attention and apply the masks
151
  QK = torch.einsum("nhdl,nhds->nhls", q, k)
152
  if mask0 is not None:
153
- QK.masked_fill_(~(mask0[:,None, :, None] * mask1[:, None, None]).bool(), float(-1e8))
 
 
154
  # Compute the attention and the weighted average
155
- softmax_temp = temp / q.size(2)**.5 # sqrt(D)
156
  A = torch.softmax(softmax_temp * QK, dim=-1)
157
- queried_values = torch.einsum("nhls,nhds->nhdl", A, v).contiguous().view(bs,self.d_model,-1)
 
 
 
 
158
  return queried_values
159
-
160
-
161
 
162
  def elu_feature_map(x):
163
  return F.elu(x) + 1
164
 
 
165
  class LinearAttention(Module):
166
  def __init__(self, eps=1e-6):
167
  super().__init__()
@@ -169,7 +283,7 @@ class LinearAttention(Module):
169
  self.eps = eps
170
 
171
  def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
172
- """ Multi-Head linear attention proposed in "Transformers are RNNs"
173
  Args:
174
  queries: [N, L, H, D]
175
  keys: [N, S, H, D]
@@ -195,4 +309,4 @@ class LinearAttention(Module):
195
  Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
196
  queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
197
 
198
- return queried_values.contiguous()
 
4
  from itertools import product
5
  from torch.nn import functional as F
6
 
7
+
8
  class layernorm2d(nn.Module):
9
+ def __init__(self, dim):
10
+ super().__init__()
11
+ self.dim = dim
12
+ self.affine = nn.parameter.Parameter(torch.ones(dim), requires_grad=True)
13
+ self.bias = nn.parameter.Parameter(torch.zeros(dim), requires_grad=True)
14
+
15
+ def forward(self, x):
16
+ # x: B*C*H*W
17
+ mean, std = x.mean(dim=1, keepdim=True), x.std(dim=1, keepdim=True)
18
+ return (
19
+ self.affine[None, :, None, None] * (x - mean) / (std + 1e-6)
20
+ + self.bias[None, :, None, None]
21
+ )
22
 
23
 
24
  class HierachicalAttention(Module):
25
+ def __init__(self, d_model, nhead, nsample, radius_scale, nlevel=3):
26
  super().__init__()
27
+ self.d_model = d_model
28
+ self.nhead = nhead
29
+ self.nsample = nsample
30
+ self.nlevel = nlevel
31
+ self.radius_scale = radius_scale
32
  self.merge_head = nn.Sequential(
33
+ nn.Conv1d(d_model * 3, d_model, kernel_size=1, bias=False),
34
  nn.ReLU(True),
35
+ nn.Conv1d(d_model, d_model, kernel_size=1, bias=False),
36
  )
37
+ self.fullattention = FullAttention(d_model, nhead)
38
+ self.temp = nn.parameter.Parameter(torch.tensor(1.0), requires_grad=True)
39
+ sample_offset = torch.tensor(
40
+ [
41
+ [pos[0] - nsample[1] / 2 + 0.5, pos[1] - nsample[1] / 2 + 0.5]
42
+ for pos in product(range(nsample[1]), range(nsample[1]))
43
+ ]
44
+ ) # r^2*2
45
+ self.sample_offset = nn.parameter.Parameter(sample_offset, requires_grad=False)
46
 
47
+ def forward(
48
+ self,
49
+ query,
50
+ key,
51
+ value,
52
+ flow,
53
+ size_q,
54
+ size_kv,
55
+ mask0=None,
56
+ mask1=None,
57
+ ds0=[4, 4],
58
+ ds1=[4, 4],
59
+ ):
60
  """
61
  Args:
62
  q,k,v (torch.Tensor): [B, C, L]
 
65
  Return:
66
  all_message (torch.Tensor): [B, C, H, W]
67
  """
68
+
69
+ variance = flow[:, :, :, 2:]
70
+ offset = flow[:, :, :, :2] # B*H*W*2
71
+ bs = query.shape[0]
72
+ h0, w0 = size_q[0], size_q[1]
73
+ h1, w1 = size_kv[0], size_kv[1]
74
+ variance = torch.exp(0.5 * variance) * self.radius_scale # b*h*w*2(pixel scale)
75
+ span_scale = torch.clamp((variance * 2 / self.nsample[1]), min=1) # b*h*w*2
76
+
77
+ sub_sample0, sub_sample1 = [ds0, 2, 1], [ds1, 2, 1]
78
+ q_list = [
79
+ F.avg_pool2d(
80
+ query.view(bs, -1, h0, w0), kernel_size=sub_size, stride=sub_size
81
+ )
82
+ for sub_size in sub_sample0
83
+ ]
84
+ k_list = [
85
+ F.avg_pool2d(
86
+ key.view(bs, -1, h1, w1), kernel_size=sub_size, stride=sub_size
87
+ )
88
+ for sub_size in sub_sample1
89
+ ]
90
+ v_list = [
91
+ F.avg_pool2d(
92
+ value.view(bs, -1, h1, w1), kernel_size=sub_size, stride=sub_size
93
+ )
94
+ for sub_size in sub_sample1
95
+ ] # n_level
96
+
97
+ offset_list = [
98
+ F.avg_pool2d(
99
+ offset.permute(0, 3, 1, 2),
100
+ kernel_size=sub_size * self.nsample[0],
101
+ stride=sub_size * self.nsample[0],
102
+ ).permute(0, 2, 3, 1)
103
+ / sub_size
104
+ for sub_size in sub_sample0[1:]
105
+ ] # n_level-1
106
+ span_list = [
107
+ F.avg_pool2d(
108
+ span_scale.permute(0, 3, 1, 2),
109
+ kernel_size=sub_size * self.nsample[0],
110
+ stride=sub_size * self.nsample[0],
111
+ ).permute(0, 2, 3, 1)
112
+ for sub_size in sub_sample0[1:]
113
+ ] # n_level-1
114
 
115
  if mask0 is not None:
116
+ mask0, mask1 = mask0.view(bs, 1, h0, w0), mask1.view(bs, 1, h1, w1)
117
+ mask0_list = [
118
+ -F.max_pool2d(-mask0, kernel_size=sub_size, stride=sub_size)
119
+ for sub_size in sub_sample0
120
+ ]
121
+ mask1_list = [
122
+ -F.max_pool2d(-mask1, kernel_size=sub_size, stride=sub_size)
123
+ for sub_size in sub_sample1
124
+ ]
125
  else:
126
+ mask0_list = mask1_list = [None, None, None]
127
+
128
+ message_list = []
129
+ # full attention at coarse scale
130
+ mask0_flatten = mask0_list[0].view(bs, -1) if mask0 is not None else None
131
+ mask1_flatten = mask1_list[0].view(bs, -1) if mask1 is not None else None
132
+ message_list.append(
133
+ self.fullattention(
134
+ q_list[0], k_list[0], v_list[0], mask0_flatten, mask1_flatten, self.temp
135
+ ).view(bs, self.d_model, h0 // ds0[0], w0 // ds0[1])
136
+ )
137
+
138
+ for index in range(1, self.nlevel):
139
+ q, k, v = q_list[index], k_list[index], v_list[index]
140
+ mask0, mask1 = mask0_list[index], mask1_list[index]
141
+ s, o = span_list[index - 1], offset_list[index - 1] # B*h*w(*2)
142
+ q, k, v, sample_pixel, mask_sample = self.partition_token(
143
+ q, k, v, o, s, mask0
144
+ ) # B*Head*D*G*N(G*N=H*W for q)
145
+ message_list.append(
146
+ self.group_attention(q, k, v, 1, mask_sample).view(
147
+ bs, self.d_model, h0 // sub_sample0[index], w0 // sub_sample0[index]
148
+ )
149
+ )
150
+ # fuse
151
+ all_message = torch.cat(
152
+ [
153
+ F.upsample(
154
+ message_list[idx], scale_factor=sub_sample0[idx], mode="nearest"
155
+ )
156
+ for idx in range(self.nlevel)
157
+ ],
158
+ dim=1,
159
+ ).view(
160
+ bs, -1, h0 * w0
161
+ ) # b*3d*H*W
162
+
163
+ all_message = self.merge_head(all_message).view(bs, -1, h0, w0) # b*d*H*W
164
  return all_message
165
+
166
+ def partition_token(self, q, k, v, offset, span_scale, maskv):
167
+ # q,k,v: B*C*H*W
168
+ # o: B*H/2*W/2*2
169
+ # span_scale:B*H*W
170
+ bs = q.shape[0]
171
+ h, w = q.shape[2], q.shape[3]
172
+ hk, wk = k.shape[2], k.shape[3]
173
+ offset = offset.view(bs, -1, 2)
174
+ span_scale = span_scale.view(bs, -1, 1, 2)
175
+ # B*G*2
176
+ offset_sample = self.sample_offset[None, None] * span_scale
177
+ sample_pixel = offset[:, :, None] + offset_sample # B*G*r^2*2
178
+ sample_norm = (
179
+ sample_pixel / torch.tensor([wk / 2, hk / 2]).cuda()[None, None, None] - 1
180
+ )
181
+
182
+ q = (
183
+ q.view(
184
+ bs,
185
+ -1,
186
+ h // self.nsample[0],
187
+ self.nsample[0],
188
+ w // self.nsample[0],
189
+ self.nsample[0],
190
+ )
191
+ .permute(0, 1, 2, 4, 3, 5)
192
+ .contiguous()
193
+ .view(bs, self.nhead, self.d_model // self.nhead, -1, self.nsample[0] ** 2)
194
+ ) # B*head*D*G*N(G*N=H*W for q)
195
+ # sample token
196
+ k = F.grid_sample(k, grid=sample_norm).view(
197
+ bs, self.nhead, self.d_model // self.nhead, -1, self.nsample[1] ** 2
198
+ ) # B*head*D*G*r^2
199
+ v = F.grid_sample(v, grid=sample_norm).view(
200
+ bs, self.nhead, self.d_model // self.nhead, -1, self.nsample[1] ** 2
201
+ ) # B*head*D*G*r^2
202
+ # import pdb;pdb.set_trace()
203
  if maskv is not None:
204
+ mask_sample = (
205
+ F.grid_sample(
206
+ maskv.view(bs, -1, h, w).float(), grid=sample_norm, mode="nearest"
207
+ )
208
+ == 1
209
+ ) # B*1*G*r^2
210
  else:
211
+ mask_sample = None
212
+ return q, k, v, sample_pixel, mask_sample
 
213
 
214
+ def group_attention(self, query, key, value, temp, mask_sample=None):
215
+ # q,k,v: B*Head*D*G*N(G*N=H*W for q)
216
+ bs = query.shape[0]
217
+ # import pdb;pdb.set_trace()
218
  QK = torch.einsum("bhdgn,bhdgm->bhgnm", query, key)
219
  if mask_sample is not None:
220
+ num_head, number_n = QK.shape[1], QK.shape[3]
221
+ QK.masked_fill_(
222
+ ~(mask_sample[:, :, :, None])
223
+ .expand(-1, num_head, -1, number_n, -1)
224
+ .bool(),
225
+ float(-1e8),
226
+ )
227
  # Compute the attention and the weighted average
228
+ softmax_temp = temp / query.size(2) ** 0.5 # sqrt(D)
229
  A = torch.softmax(softmax_temp * QK, dim=-1)
230
+ queried_values = (
231
+ torch.einsum("bhgnm,bhdgm->bhdgn", A, value)
232
+ .contiguous()
233
+ .view(bs, self.d_model, -1)
234
+ )
235
  return queried_values
236
 
 
237
 
238
  class FullAttention(Module):
239
+ def __init__(self, d_model, nhead):
240
  super().__init__()
241
+ self.d_model = d_model
242
+ self.nhead = nhead
243
 
244
+ def forward(self, q, k, v, mask0=None, mask1=None, temp=1):
245
+ """Multi-head scaled dot-product attention, a.k.a full attention.
246
  Args:
247
  q,k,v: [N, D, L]
248
  mask: [N, L]
249
  Returns:
250
  msg: [N,L]
251
  """
252
+ bs = q.shape[0]
253
+ q, k, v = (
254
+ q.view(bs, self.nhead, self.d_model // self.nhead, -1),
255
+ k.view(bs, self.nhead, self.d_model // self.nhead, -1),
256
+ v.view(bs, self.nhead, self.d_model // self.nhead, -1),
257
+ )
258
  # Compute the unnormalized attention and apply the masks
259
  QK = torch.einsum("nhdl,nhds->nhls", q, k)
260
  if mask0 is not None:
261
+ QK.masked_fill_(
262
+ ~(mask0[:, None, :, None] * mask1[:, None, None]).bool(), float(-1e8)
263
+ )
264
  # Compute the attention and the weighted average
265
+ softmax_temp = temp / q.size(2) ** 0.5 # sqrt(D)
266
  A = torch.softmax(softmax_temp * QK, dim=-1)
267
+ queried_values = (
268
+ torch.einsum("nhls,nhds->nhdl", A, v)
269
+ .contiguous()
270
+ .view(bs, self.d_model, -1)
271
+ )
272
  return queried_values
273
+
 
274
 
275
  def elu_feature_map(x):
276
  return F.elu(x) + 1
277
 
278
+
279
  class LinearAttention(Module):
280
  def __init__(self, eps=1e-6):
281
  super().__init__()
 
283
  self.eps = eps
284
 
285
  def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
286
+ """Multi-Head linear attention proposed in "Transformers are RNNs"
287
  Args:
288
  queries: [N, L, H, D]
289
  keys: [N, S, H, D]
 
309
  Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
310
  queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
311
 
312
+ return queried_values.contiguous()
third_party/ASpanFormer/src/ASpanFormer/aspan_module/fine_preprocess.py CHANGED
@@ -9,15 +9,15 @@ class FinePreprocess(nn.Module):
9
  super().__init__()
10
 
11
  self.config = config
12
- self.cat_c_feat = config['fine_concat_coarse_feat']
13
- self.W = self.config['fine_window_size']
14
 
15
- d_model_c = self.config['coarse']['d_model']
16
- d_model_f = self.config['fine']['d_model']
17
  self.d_model_f = d_model_f
18
  if self.cat_c_feat:
19
  self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
20
- self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True)
21
 
22
  self._reset_parameters()
23
 
@@ -28,32 +28,48 @@ class FinePreprocess(nn.Module):
28
 
29
  def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
30
  W = self.W
31
- stride = data['hw0_f'][0] // data['hw0_c'][0]
32
 
33
- data.update({'W': W})
34
- if data['b_ids'].shape[0] == 0:
35
  feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
36
  feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
37
  return feat0, feat1
38
 
39
  # 1. unfold(crop) all local windows
40
- feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
41
- feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
42
- feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
43
- feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
 
 
 
 
44
 
45
  # 2. select only the predicted matches
46
- feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf]
47
- feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
48
 
49
  # option: use coarse-level loftr feature as context: concat and linear
50
  if self.cat_c_feat:
51
- feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']],
52
- feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c]
53
- feat_cf_win = self.merge_feat(torch.cat([
54
- torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf]
55
- repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf]
56
- ], -1))
 
 
 
 
 
 
 
 
 
 
 
 
57
  feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)
58
 
59
  return feat_f0_unfold, feat_f1_unfold
 
9
  super().__init__()
10
 
11
  self.config = config
12
+ self.cat_c_feat = config["fine_concat_coarse_feat"]
13
+ self.W = self.config["fine_window_size"]
14
 
15
+ d_model_c = self.config["coarse"]["d_model"]
16
+ d_model_f = self.config["fine"]["d_model"]
17
  self.d_model_f = d_model_f
18
  if self.cat_c_feat:
19
  self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
20
+ self.merge_feat = nn.Linear(2 * d_model_f, d_model_f, bias=True)
21
 
22
  self._reset_parameters()
23
 
 
28
 
29
  def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
30
  W = self.W
31
+ stride = data["hw0_f"][0] // data["hw0_c"][0]
32
 
33
+ data.update({"W": W})
34
+ if data["b_ids"].shape[0] == 0:
35
  feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
36
  feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
37
  return feat0, feat1
38
 
39
  # 1. unfold(crop) all local windows
40
+ feat_f0_unfold = F.unfold(
41
+ feat_f0, kernel_size=(W, W), stride=stride, padding=W // 2
42
+ )
43
+ feat_f0_unfold = rearrange(feat_f0_unfold, "n (c ww) l -> n l ww c", ww=W**2)
44
+ feat_f1_unfold = F.unfold(
45
+ feat_f1, kernel_size=(W, W), stride=stride, padding=W // 2
46
+ )
47
+ feat_f1_unfold = rearrange(feat_f1_unfold, "n (c ww) l -> n l ww c", ww=W**2)
48
 
49
  # 2. select only the predicted matches
50
+ feat_f0_unfold = feat_f0_unfold[data["b_ids"], data["i_ids"]] # [n, ww, cf]
51
+ feat_f1_unfold = feat_f1_unfold[data["b_ids"], data["j_ids"]]
52
 
53
  # option: use coarse-level loftr feature as context: concat and linear
54
  if self.cat_c_feat:
55
+ feat_c_win = self.down_proj(
56
+ torch.cat(
57
+ [
58
+ feat_c0[data["b_ids"], data["i_ids"]],
59
+ feat_c1[data["b_ids"], data["j_ids"]],
60
+ ],
61
+ 0,
62
+ )
63
+ ) # [2n, c]
64
+ feat_cf_win = self.merge_feat(
65
+ torch.cat(
66
+ [
67
+ torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf]
68
+ repeat(feat_c_win, "n c -> n ww c", ww=W**2), # [2n, ww, cf]
69
+ ],
70
+ -1,
71
+ )
72
+ )
73
  feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)
74
 
75
  return feat_f0_unfold, feat_f1_unfold
third_party/ASpanFormer/src/ASpanFormer/aspan_module/loftr.py CHANGED
@@ -3,11 +3,9 @@ import torch
3
  import torch.nn as nn
4
  from .attention import LinearAttention
5
 
 
6
  class LoFTREncoderLayer(nn.Module):
7
- def __init__(self,
8
- d_model,
9
- nhead,
10
- attention='linear'):
11
  super(LoFTREncoderLayer, self).__init__()
12
 
13
  self.dim = d_model // nhead
@@ -22,9 +20,9 @@ class LoFTREncoderLayer(nn.Module):
22
 
23
  # feed-forward network
24
  self.mlp = nn.Sequential(
25
- nn.Linear(d_model*2, d_model*2, bias=False),
26
  nn.ReLU(True),
27
- nn.Linear(d_model*2, d_model, bias=False),
28
  )
29
 
30
  # norm and dropout
@@ -43,16 +41,14 @@ class LoFTREncoderLayer(nn.Module):
43
  query, key, value = x, source, source
44
 
45
  # multi-head attention
46
- query = self.q_proj(query).view(
47
- bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
48
- key = self.k_proj(key).view(bs, -1, self.nhead,
49
- self.dim) # [N, S, (H, D)]
50
  value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
51
 
52
  message = self.attention(
53
- query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)]
54
- message = self.merge(message.view(
55
- bs, -1, self.nhead*self.dim)) # [N, L, C]
56
  message = self.norm1(message)
57
 
58
  # feed-forward network
@@ -69,13 +65,15 @@ class LocalFeatureTransformer(nn.Module):
69
  super(LocalFeatureTransformer, self).__init__()
70
 
71
  self.config = config
72
- self.d_model = config['d_model']
73
- self.nhead = config['nhead']
74
- self.layer_names = config['layer_names']
75
  encoder_layer = LoFTREncoderLayer(
76
- config['d_model'], config['nhead'], config['attention'])
 
77
  self.layers = nn.ModuleList(
78
- [copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
 
79
  self._reset_parameters()
80
 
81
  def _reset_parameters(self):
@@ -93,20 +91,18 @@ class LocalFeatureTransformer(nn.Module):
93
  """
94
 
95
  assert self.d_model == feat0.size(
96
- 2), "the feature number of src and transformer must be equal"
 
97
 
98
  index = 0
99
  for layer, name in zip(self.layers, self.layer_names):
100
- if name == 'self':
101
- feat0 = layer(feat0, feat0, mask0, mask0,
102
- type='self', index=index)
103
  feat1 = layer(feat1, feat1, mask1, mask1)
104
- elif name == 'cross':
105
  feat0 = layer(feat0, feat1, mask0, mask1)
106
- feat1 = layer(feat1, feat0, mask1, mask0,
107
- type='cross', index=index)
108
  index += 1
109
  else:
110
  raise KeyError
111
  return feat0, feat1
112
-
 
3
  import torch.nn as nn
4
  from .attention import LinearAttention
5
 
6
+
7
  class LoFTREncoderLayer(nn.Module):
8
+ def __init__(self, d_model, nhead, attention="linear"):
 
 
 
9
  super(LoFTREncoderLayer, self).__init__()
10
 
11
  self.dim = d_model // nhead
 
20
 
21
  # feed-forward network
22
  self.mlp = nn.Sequential(
23
+ nn.Linear(d_model * 2, d_model * 2, bias=False),
24
  nn.ReLU(True),
25
+ nn.Linear(d_model * 2, d_model, bias=False),
26
  )
27
 
28
  # norm and dropout
 
41
  query, key, value = x, source, source
42
 
43
  # multi-head attention
44
+ query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
45
+ key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
 
 
46
  value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
47
 
48
  message = self.attention(
49
+ query, key, value, q_mask=x_mask, kv_mask=source_mask
50
+ ) # [N, L, (H, D)]
51
+ message = self.merge(message.view(bs, -1, self.nhead * self.dim)) # [N, L, C]
52
  message = self.norm1(message)
53
 
54
  # feed-forward network
 
65
  super(LocalFeatureTransformer, self).__init__()
66
 
67
  self.config = config
68
+ self.d_model = config["d_model"]
69
+ self.nhead = config["nhead"]
70
+ self.layer_names = config["layer_names"]
71
  encoder_layer = LoFTREncoderLayer(
72
+ config["d_model"], config["nhead"], config["attention"]
73
+ )
74
  self.layers = nn.ModuleList(
75
+ [copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]
76
+ )
77
  self._reset_parameters()
78
 
79
  def _reset_parameters(self):
 
91
  """
92
 
93
  assert self.d_model == feat0.size(
94
+ 2
95
+ ), "the feature number of src and transformer must be equal"
96
 
97
  index = 0
98
  for layer, name in zip(self.layers, self.layer_names):
99
+ if name == "self":
100
+ feat0 = layer(feat0, feat0, mask0, mask0, type="self", index=index)
 
101
  feat1 = layer(feat1, feat1, mask1, mask1)
102
+ elif name == "cross":
103
  feat0 = layer(feat0, feat1, mask0, mask1)
104
+ feat1 = layer(feat1, feat0, mask1, mask0, type="cross", index=index)
 
105
  index += 1
106
  else:
107
  raise KeyError
108
  return feat0, feat1
 
third_party/ASpanFormer/src/ASpanFormer/aspan_module/transformer.py CHANGED
@@ -2,44 +2,42 @@ import copy
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
- from .attention import FullAttention, HierachicalAttention ,layernorm2d
6
 
7
 
8
  class messageLayer_ini(nn.Module):
9
-
10
- def __init__(self, d_model, d_flow,d_value, nhead):
11
  super().__init__()
12
  super(messageLayer_ini, self).__init__()
13
 
14
  self.d_model = d_model
15
  self.d_flow = d_flow
16
- self.d_value=d_value
17
  self.nhead = nhead
18
- self.attention = FullAttention(d_model,nhead)
19
 
20
- self.q_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False)
21
- self.k_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False)
22
- self.v_proj = nn.Conv1d(d_value, d_model, kernel_size=1,bias=False)
23
- self.merge_head=nn.Conv1d(d_model,d_model,kernel_size=1,bias=False)
24
 
25
- self.merge_f= self.merge_f = nn.Sequential(
26
- nn.Conv2d(d_model*2, d_model*2, kernel_size=1, bias=False),
27
  nn.ReLU(True),
28
- nn.Conv2d(d_model*2, d_model, kernel_size=1, bias=False),
29
  )
30
 
31
  self.norm1 = layernorm2d(d_model)
32
  self.norm2 = layernorm2d(d_model)
33
 
 
 
 
 
 
 
34
 
35
- def forward(self, x0, x1,pos0,pos1,mask0=None,mask1=None):
36
- #x1,x2: b*d*L
37
- x0,x1=self.update(x0,x1,pos1,mask0,mask1),\
38
- self.update(x1,x0,pos0,mask1,mask0)
39
- return x0,x1
40
-
41
-
42
- def update(self,f0,f1,pos1,mask0,mask1):
43
  """
44
  Args:
45
  f0: [N, D, H, W]
@@ -47,53 +45,77 @@ class messageLayer_ini(nn.Module):
47
  Returns:
48
  f0_new: (N, d, h, w)
49
  """
50
- bs,h,w=f0.shape[0],f0.shape[2],f0.shape[3]
51
 
52
- f0_flatten,f1_flatten=f0.view(bs,self.d_model,-1),f1.view(bs,self.d_model,-1)
53
- pos1_flatten=pos1.view(bs,self.d_value-self.d_model,-1)
54
- f1_flatten_v=torch.cat([f1_flatten,pos1_flatten],dim=1)
 
 
55
 
56
- queries,keys=self.q_proj(f0_flatten),self.k_proj(f1_flatten)
57
- values=self.v_proj(f1_flatten_v).view(bs,self.nhead,self.d_model//self.nhead,-1)
58
-
59
- queried_values=self.attention(queries,keys,values,mask0,mask1)
60
- msg=self.merge_head(queried_values).view(bs,-1,h,w)
61
- msg=self.norm2(self.merge_f(torch.cat([f0,self.norm1(msg)],dim=1)))
62
- return f0+msg
63
 
 
 
 
 
64
 
65
 
66
  class messageLayer_gla(nn.Module):
67
-
68
- def __init__(self,d_model,d_flow,d_value,
69
- nhead,radius_scale,nsample,update_flow=True):
70
  super().__init__()
71
  self.d_model = d_model
72
- self.d_flow=d_flow
73
- self.d_value=d_value
74
  self.nhead = nhead
75
- self.radius_scale=radius_scale
76
- self.update_flow=update_flow
77
- self.flow_decoder=nn.Sequential(
78
- nn.Conv1d(d_flow, d_flow//2, kernel_size=1, bias=False),
79
- nn.ReLU(True),
80
- nn.Conv1d(d_flow//2, 4, kernel_size=1, bias=False))
81
- self.attention=HierachicalAttention(d_model,nhead,nsample,radius_scale)
82
-
83
- self.q_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False)
84
- self.k_proj = nn.Conv1d(d_model, d_model, kernel_size=1,bias=False)
85
- self.v_proj = nn.Conv1d(d_value, d_model, kernel_size=1,bias=False)
86
-
87
- d_extra=d_flow if update_flow else 0
88
- self.merge_f=nn.Sequential(
89
- nn.Conv2d(d_model*2+d_extra, d_model+d_flow, kernel_size=1, bias=False),
90
- nn.ReLU(True),
91
- nn.Conv2d(d_model+d_flow, d_model+d_extra, kernel_size=3,padding=1, bias=False),
92
- )
93
- self.norm1 = layernorm2d(d_model)
94
- self.norm2 = layernorm2d(d_model+d_extra)
95
 
96
- def forward(self, x0, x1, flow_feature0,flow_feature1,pos0,pos1,mask0=None,mask1=None,ds0=[4,4],ds1=[4,4]):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  """
98
  Args:
99
  x0 (torch.Tensor): [B, C, H, W]
@@ -101,88 +123,135 @@ class messageLayer_gla(nn.Module):
101
  flow_feature0 (torch.Tensor): [B, C', H, W]
102
  flow_feature1 (torch.Tensor): [B, C', H, W]
103
  """
104
- flow0,flow1=self.decode_flow(flow_feature0,flow_feature1.shape[2:]),self.decode_flow(flow_feature1,flow_feature0.shape[2:])
105
- x0_new,flow_feature0_new=self.update(x0,x1,flow0.detach(),flow_feature0,pos1,mask0,mask1,ds0,ds1)
106
- x1_new,flow_feature1_new=self.update(x1,x0,flow1.detach(),flow_feature1,pos0,mask1,mask0,ds1,ds0)
107
- return x0_new,x1_new,flow_feature0_new,flow_feature1_new,flow0,flow1
108
-
109
- def update(self,x0,x1,flow0,flow_feature0,pos1,mask0,mask1,ds0,ds1):
110
- bs=x0.shape[0]
111
- queries,keys=self.q_proj(x0.view(bs,self.d_model,-1)),self.k_proj(x1.view(bs,self.d_model,-1))
112
- x1_pos=torch.cat([x1,pos1],dim=1)
113
- values=self.v_proj(x1_pos.view(bs,self.d_value,-1))
114
- msg=self.attention(queries,keys,values,flow0,x0.shape[2:],x1.shape[2:],mask0,mask1,ds0,ds1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  if self.update_flow:
117
- update_feature=torch.cat([x0,flow_feature0],dim=1)
118
  else:
119
- update_feature=x0
120
- msg=self.norm2(self.merge_f(torch.cat([update_feature,self.norm1(msg)],dim=1)))
121
- update_feature=update_feature+msg
122
-
123
- x0_new,flow_feature0_new=update_feature[:,:self.d_model],update_feature[:,self.d_model:]
124
- return x0_new,flow_feature0_new
125
-
126
- def decode_flow(self,flow_feature,kshape):
127
- bs,h,w=flow_feature.shape[0],flow_feature.shape[2],flow_feature.shape[3]
128
- scale_factor=torch.tensor([kshape[1],kshape[0]]).cuda()[None,None,None]
129
- flow=self.flow_decoder(flow_feature.view(bs,-1,h*w)).permute(0,2,1).view(bs,h,w,4)
130
- flow_coordinates=torch.sigmoid(flow[:,:,:,:2])*scale_factor
131
- flow_var=flow[:,:,:,2:]
132
- flow=torch.cat([flow_coordinates,flow_var],dim=-1) #B*H*W*4
 
 
 
 
 
 
 
 
 
133
  return flow
134
 
135
 
136
  class flow_initializer(nn.Module):
137
-
138
  def __init__(self, dim, dim_flow, nhead, layer_num):
139
  super().__init__()
140
- self.layer_num= layer_num
141
  self.dim = dim
142
  self.dim_flow = dim_flow
143
 
144
- encoder_layer = messageLayer_ini(
145
- dim ,dim_flow,dim+dim_flow , nhead)
146
  self.layers_coarse = nn.ModuleList(
147
- [copy.deepcopy(encoder_layer) for _ in range(layer_num)])
148
- self.decoupler = nn.Conv2d(
149
- self.dim, self.dim+self.dim_flow, kernel_size=1)
150
- self.up_merge = nn.Conv2d(2*dim, dim, kernel_size=1)
151
 
152
- def forward(self, feat0, feat1,pos0,pos1,mask0=None,mask1=None,ds0=[4,4],ds1=[4,4]):
 
 
153
  # feat0: [B, C, H0, W0]
154
  # feat1: [B, C, H1, W1]
155
  # use low-res MHA to initialize flow feature
156
  bs = feat0.size(0)
157
- h0,w0,h1,w1=feat0.shape[2],feat0.shape[3],feat1.shape[2],feat1.shape[3]
158
 
159
  # coarse level
160
- sub_feat0, sub_feat1 = F.avg_pool2d(feat0, ds0, stride=ds0), \
161
- F.avg_pool2d(feat1, ds1, stride=ds1)
 
 
 
 
 
162
 
163
- sub_pos0,sub_pos1=F.avg_pool2d(pos0, ds0, stride=ds0), \
164
- F.avg_pool2d(pos1, ds1, stride=ds1)
165
-
166
  if mask0 is not None:
167
- mask0,mask1=-F.max_pool2d(-mask0.view(bs,1,h0,w0),ds0,stride=ds0).view(bs,-1),\
168
- -F.max_pool2d(-mask1.view(bs,1,h1,w1),ds1,stride=ds1).view(bs,-1)
169
-
 
 
 
 
 
170
  for layer in self.layers_coarse:
171
- sub_feat0, sub_feat1 = layer(sub_feat0, sub_feat1,sub_pos0,sub_pos1,mask0,mask1)
 
 
172
  # decouple flow and visual features
173
- decoupled_feature0, decoupled_feature1 = self.decoupler(sub_feat0),self.decoupler(sub_feat1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
- sub_feat0, sub_flow_feature0 = decoupled_feature0[:,:self.dim], decoupled_feature0[:, self.dim:]
176
- sub_feat1, sub_flow_feature1 = decoupled_feature1[:,:self.dim], decoupled_feature1[:, self.dim:]
177
- update_feat0, flow_feature0 = F.upsample(sub_feat0, scale_factor=ds0, mode='bilinear'),\
178
- F.upsample(sub_flow_feature0, scale_factor=ds0, mode='bilinear')
179
- update_feat1, flow_feature1 = F.upsample(sub_feat1, scale_factor=ds1, mode='bilinear'),\
180
- F.upsample(sub_flow_feature1, scale_factor=ds1, mode='bilinear')
181
-
182
- feat0 = feat0+self.up_merge(torch.cat([feat0, update_feat0], dim=1))
183
- feat1 = feat1+self.up_merge(torch.cat([feat1, update_feat1], dim=1))
184
-
185
- return feat0,feat1,flow_feature0,flow_feature1 #b*c*h*w
186
 
187
 
188
  class LocalFeatureTransformer_Flow(nn.Module):
@@ -192,27 +261,49 @@ class LocalFeatureTransformer_Flow(nn.Module):
192
  super(LocalFeatureTransformer_Flow, self).__init__()
193
 
194
  self.config = config
195
- self.d_model = config['d_model']
196
- self.nhead = config['nhead']
 
 
 
 
 
 
 
197
 
198
- self.pos_transform=nn.Conv2d(config['d_model'],config['d_flow'],kernel_size=1,bias=False)
199
- self.ini_layer = flow_initializer(self.d_model, config['d_flow'], config['nhead'],config['ini_layer_num'])
200
-
201
  encoder_layer = messageLayer_gla(
202
- config['d_model'], config['d_flow'], config['d_flow']+config['d_model'], config['nhead'],config['radius_scale'],config['nsample'])
203
- encoder_layer_last=messageLayer_gla(
204
- config['d_model'], config['d_flow'], config['d_flow']+config['d_model'], config['nhead'],config['radius_scale'],config['nsample'],update_flow=False)
205
- self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(config['layer_num']-1)]+[encoder_layer_last])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  self._reset_parameters()
207
-
208
  def _reset_parameters(self):
209
- for name,p in self.named_parameters():
210
- if 'temp' in name or 'sample_offset' in name:
211
  continue
212
  if p.dim() > 1:
213
  nn.init.xavier_uniform_(p)
214
 
215
- def forward(self, feat0, feat1,pos0,pos1,mask0=None,mask1=None,ds0=[4,4],ds1=[4,4]):
 
 
216
  """
217
  Args:
218
  feat0 (torch.Tensor): [N, C, H, W]
@@ -224,21 +315,37 @@ class LocalFeatureTransformer_Flow(nn.Module):
224
  flow_list: [L,N,H,W,4]*1(2)
225
  """
226
  bs = feat0.size(0)
227
-
228
- pos0,pos1=self.pos_transform(pos0),self.pos_transform(pos1)
229
- pos0,pos1=pos0.expand(bs,-1,-1,-1),pos1.expand(bs,-1,-1,-1)
230
  assert self.d_model == feat0.size(
231
- 1), "the feature number of src and transformer must be equal"
232
-
233
- flow_list=[[],[]]# [px,py,sx,sy]
 
234
  if mask0 is not None:
235
- mask0,mask1=mask0[:,None].float(),mask1[:,None].float()
236
- feat0,feat1, flow_feature0, flow_feature1 = self.ini_layer(feat0, feat1,pos0,pos1,mask0,mask1,ds0,ds1)
 
 
237
  for layer in self.layers:
238
- feat0,feat1,flow_feature0,flow_feature1,flow0,flow1=layer(feat0,feat1,flow_feature0,flow_feature1,pos0,pos1,mask0,mask1,ds0,ds1)
 
 
 
 
 
 
 
 
 
 
 
239
  flow_list[0].append(flow0)
240
  flow_list[1].append(flow1)
241
- flow_list[0]=torch.stack(flow_list[0],dim=0)
242
- flow_list[1]=torch.stack(flow_list[1],dim=0)
243
- feat0, feat1 = feat0.permute(0, 2, 3, 1).view(bs, -1, self.d_model), feat1.permute(0, 2, 3, 1).view(bs, -1, self.d_model)
244
- return feat0, feat1, flow_list
 
 
 
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
+ from .attention import FullAttention, HierachicalAttention, layernorm2d
6
 
7
 
8
  class messageLayer_ini(nn.Module):
9
+ def __init__(self, d_model, d_flow, d_value, nhead):
 
10
  super().__init__()
11
  super(messageLayer_ini, self).__init__()
12
 
13
  self.d_model = d_model
14
  self.d_flow = d_flow
15
+ self.d_value = d_value
16
  self.nhead = nhead
17
+ self.attention = FullAttention(d_model, nhead)
18
 
19
+ self.q_proj = nn.Conv1d(d_model, d_model, kernel_size=1, bias=False)
20
+ self.k_proj = nn.Conv1d(d_model, d_model, kernel_size=1, bias=False)
21
+ self.v_proj = nn.Conv1d(d_value, d_model, kernel_size=1, bias=False)
22
+ self.merge_head = nn.Conv1d(d_model, d_model, kernel_size=1, bias=False)
23
 
24
+ self.merge_f = self.merge_f = nn.Sequential(
25
+ nn.Conv2d(d_model * 2, d_model * 2, kernel_size=1, bias=False),
26
  nn.ReLU(True),
27
+ nn.Conv2d(d_model * 2, d_model, kernel_size=1, bias=False),
28
  )
29
 
30
  self.norm1 = layernorm2d(d_model)
31
  self.norm2 = layernorm2d(d_model)
32
 
33
+ def forward(self, x0, x1, pos0, pos1, mask0=None, mask1=None):
34
+ # x1,x2: b*d*L
35
+ x0, x1 = self.update(x0, x1, pos1, mask0, mask1), self.update(
36
+ x1, x0, pos0, mask1, mask0
37
+ )
38
+ return x0, x1
39
 
40
+ def update(self, f0, f1, pos1, mask0, mask1):
 
 
 
 
 
 
 
41
  """
42
  Args:
43
  f0: [N, D, H, W]
 
45
  Returns:
46
  f0_new: (N, d, h, w)
47
  """
48
+ bs, h, w = f0.shape[0], f0.shape[2], f0.shape[3]
49
 
50
+ f0_flatten, f1_flatten = f0.view(bs, self.d_model, -1), f1.view(
51
+ bs, self.d_model, -1
52
+ )
53
+ pos1_flatten = pos1.view(bs, self.d_value - self.d_model, -1)
54
+ f1_flatten_v = torch.cat([f1_flatten, pos1_flatten], dim=1)
55
 
56
+ queries, keys = self.q_proj(f0_flatten), self.k_proj(f1_flatten)
57
+ values = self.v_proj(f1_flatten_v).view(
58
+ bs, self.nhead, self.d_model // self.nhead, -1
59
+ )
 
 
 
60
 
61
+ queried_values = self.attention(queries, keys, values, mask0, mask1)
62
+ msg = self.merge_head(queried_values).view(bs, -1, h, w)
63
+ msg = self.norm2(self.merge_f(torch.cat([f0, self.norm1(msg)], dim=1)))
64
+ return f0 + msg
65
 
66
 
67
  class messageLayer_gla(nn.Module):
68
+ def __init__(
69
+ self, d_model, d_flow, d_value, nhead, radius_scale, nsample, update_flow=True
70
+ ):
71
  super().__init__()
72
  self.d_model = d_model
73
+ self.d_flow = d_flow
74
+ self.d_value = d_value
75
  self.nhead = nhead
76
+ self.radius_scale = radius_scale
77
+ self.update_flow = update_flow
78
+ self.flow_decoder = nn.Sequential(
79
+ nn.Conv1d(d_flow, d_flow // 2, kernel_size=1, bias=False),
80
+ nn.ReLU(True),
81
+ nn.Conv1d(d_flow // 2, 4, kernel_size=1, bias=False),
82
+ )
83
+ self.attention = HierachicalAttention(d_model, nhead, nsample, radius_scale)
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
+ self.q_proj = nn.Conv1d(d_model, d_model, kernel_size=1, bias=False)
86
+ self.k_proj = nn.Conv1d(d_model, d_model, kernel_size=1, bias=False)
87
+ self.v_proj = nn.Conv1d(d_value, d_model, kernel_size=1, bias=False)
88
+
89
+ d_extra = d_flow if update_flow else 0
90
+ self.merge_f = nn.Sequential(
91
+ nn.Conv2d(
92
+ d_model * 2 + d_extra, d_model + d_flow, kernel_size=1, bias=False
93
+ ),
94
+ nn.ReLU(True),
95
+ nn.Conv2d(
96
+ d_model + d_flow,
97
+ d_model + d_extra,
98
+ kernel_size=3,
99
+ padding=1,
100
+ bias=False,
101
+ ),
102
+ )
103
+ self.norm1 = layernorm2d(d_model)
104
+ self.norm2 = layernorm2d(d_model + d_extra)
105
+
106
+ def forward(
107
+ self,
108
+ x0,
109
+ x1,
110
+ flow_feature0,
111
+ flow_feature1,
112
+ pos0,
113
+ pos1,
114
+ mask0=None,
115
+ mask1=None,
116
+ ds0=[4, 4],
117
+ ds1=[4, 4],
118
+ ):
119
  """
120
  Args:
121
  x0 (torch.Tensor): [B, C, H, W]
 
123
  flow_feature0 (torch.Tensor): [B, C', H, W]
124
  flow_feature1 (torch.Tensor): [B, C', H, W]
125
  """
126
+ flow0, flow1 = self.decode_flow(
127
+ flow_feature0, flow_feature1.shape[2:]
128
+ ), self.decode_flow(flow_feature1, flow_feature0.shape[2:])
129
+ x0_new, flow_feature0_new = self.update(
130
+ x0, x1, flow0.detach(), flow_feature0, pos1, mask0, mask1, ds0, ds1
131
+ )
132
+ x1_new, flow_feature1_new = self.update(
133
+ x1, x0, flow1.detach(), flow_feature1, pos0, mask1, mask0, ds1, ds0
134
+ )
135
+ return x0_new, x1_new, flow_feature0_new, flow_feature1_new, flow0, flow1
136
+
137
+ def update(self, x0, x1, flow0, flow_feature0, pos1, mask0, mask1, ds0, ds1):
138
+ bs = x0.shape[0]
139
+ queries, keys = self.q_proj(x0.view(bs, self.d_model, -1)), self.k_proj(
140
+ x1.view(bs, self.d_model, -1)
141
+ )
142
+ x1_pos = torch.cat([x1, pos1], dim=1)
143
+ values = self.v_proj(x1_pos.view(bs, self.d_value, -1))
144
+ msg = self.attention(
145
+ queries,
146
+ keys,
147
+ values,
148
+ flow0,
149
+ x0.shape[2:],
150
+ x1.shape[2:],
151
+ mask0,
152
+ mask1,
153
+ ds0,
154
+ ds1,
155
+ )
156
 
157
  if self.update_flow:
158
+ update_feature = torch.cat([x0, flow_feature0], dim=1)
159
  else:
160
+ update_feature = x0
161
+ msg = self.norm2(
162
+ self.merge_f(torch.cat([update_feature, self.norm1(msg)], dim=1))
163
+ )
164
+ update_feature = update_feature + msg
165
+
166
+ x0_new, flow_feature0_new = (
167
+ update_feature[:, : self.d_model],
168
+ update_feature[:, self.d_model :],
169
+ )
170
+ return x0_new, flow_feature0_new
171
+
172
+ def decode_flow(self, flow_feature, kshape):
173
+ bs, h, w = flow_feature.shape[0], flow_feature.shape[2], flow_feature.shape[3]
174
+ scale_factor = torch.tensor([kshape[1], kshape[0]]).cuda()[None, None, None]
175
+ flow = (
176
+ self.flow_decoder(flow_feature.view(bs, -1, h * w))
177
+ .permute(0, 2, 1)
178
+ .view(bs, h, w, 4)
179
+ )
180
+ flow_coordinates = torch.sigmoid(flow[:, :, :, :2]) * scale_factor
181
+ flow_var = flow[:, :, :, 2:]
182
+ flow = torch.cat([flow_coordinates, flow_var], dim=-1) # B*H*W*4
183
  return flow
184
 
185
 
186
  class flow_initializer(nn.Module):
 
187
  def __init__(self, dim, dim_flow, nhead, layer_num):
188
  super().__init__()
189
+ self.layer_num = layer_num
190
  self.dim = dim
191
  self.dim_flow = dim_flow
192
 
193
+ encoder_layer = messageLayer_ini(dim, dim_flow, dim + dim_flow, nhead)
 
194
  self.layers_coarse = nn.ModuleList(
195
+ [copy.deepcopy(encoder_layer) for _ in range(layer_num)]
196
+ )
197
+ self.decoupler = nn.Conv2d(self.dim, self.dim + self.dim_flow, kernel_size=1)
198
+ self.up_merge = nn.Conv2d(2 * dim, dim, kernel_size=1)
199
 
200
+ def forward(
201
+ self, feat0, feat1, pos0, pos1, mask0=None, mask1=None, ds0=[4, 4], ds1=[4, 4]
202
+ ):
203
  # feat0: [B, C, H0, W0]
204
  # feat1: [B, C, H1, W1]
205
  # use low-res MHA to initialize flow feature
206
  bs = feat0.size(0)
207
+ h0, w0, h1, w1 = feat0.shape[2], feat0.shape[3], feat1.shape[2], feat1.shape[3]
208
 
209
  # coarse level
210
+ sub_feat0, sub_feat1 = F.avg_pool2d(feat0, ds0, stride=ds0), F.avg_pool2d(
211
+ feat1, ds1, stride=ds1
212
+ )
213
+
214
+ sub_pos0, sub_pos1 = F.avg_pool2d(pos0, ds0, stride=ds0), F.avg_pool2d(
215
+ pos1, ds1, stride=ds1
216
+ )
217
 
 
 
 
218
  if mask0 is not None:
219
+ mask0, mask1 = -F.max_pool2d(
220
+ -mask0.view(bs, 1, h0, w0), ds0, stride=ds0
221
+ ).view(bs, -1), -F.max_pool2d(
222
+ -mask1.view(bs, 1, h1, w1), ds1, stride=ds1
223
+ ).view(
224
+ bs, -1
225
+ )
226
+
227
  for layer in self.layers_coarse:
228
+ sub_feat0, sub_feat1 = layer(
229
+ sub_feat0, sub_feat1, sub_pos0, sub_pos1, mask0, mask1
230
+ )
231
  # decouple flow and visual features
232
+ decoupled_feature0, decoupled_feature1 = self.decoupler(
233
+ sub_feat0
234
+ ), self.decoupler(sub_feat1)
235
+
236
+ sub_feat0, sub_flow_feature0 = (
237
+ decoupled_feature0[:, : self.dim],
238
+ decoupled_feature0[:, self.dim :],
239
+ )
240
+ sub_feat1, sub_flow_feature1 = (
241
+ decoupled_feature1[:, : self.dim],
242
+ decoupled_feature1[:, self.dim :],
243
+ )
244
+ update_feat0, flow_feature0 = F.upsample(
245
+ sub_feat0, scale_factor=ds0, mode="bilinear"
246
+ ), F.upsample(sub_flow_feature0, scale_factor=ds0, mode="bilinear")
247
+ update_feat1, flow_feature1 = F.upsample(
248
+ sub_feat1, scale_factor=ds1, mode="bilinear"
249
+ ), F.upsample(sub_flow_feature1, scale_factor=ds1, mode="bilinear")
250
 
251
+ feat0 = feat0 + self.up_merge(torch.cat([feat0, update_feat0], dim=1))
252
+ feat1 = feat1 + self.up_merge(torch.cat([feat1, update_feat1], dim=1))
253
+
254
+ return feat0, feat1, flow_feature0, flow_feature1 # b*c*h*w
 
 
 
 
 
 
 
255
 
256
 
257
  class LocalFeatureTransformer_Flow(nn.Module):
 
261
  super(LocalFeatureTransformer_Flow, self).__init__()
262
 
263
  self.config = config
264
+ self.d_model = config["d_model"]
265
+ self.nhead = config["nhead"]
266
+
267
+ self.pos_transform = nn.Conv2d(
268
+ config["d_model"], config["d_flow"], kernel_size=1, bias=False
269
+ )
270
+ self.ini_layer = flow_initializer(
271
+ self.d_model, config["d_flow"], config["nhead"], config["ini_layer_num"]
272
+ )
273
 
 
 
 
274
  encoder_layer = messageLayer_gla(
275
+ config["d_model"],
276
+ config["d_flow"],
277
+ config["d_flow"] + config["d_model"],
278
+ config["nhead"],
279
+ config["radius_scale"],
280
+ config["nsample"],
281
+ )
282
+ encoder_layer_last = messageLayer_gla(
283
+ config["d_model"],
284
+ config["d_flow"],
285
+ config["d_flow"] + config["d_model"],
286
+ config["nhead"],
287
+ config["radius_scale"],
288
+ config["nsample"],
289
+ update_flow=False,
290
+ )
291
+ self.layers = nn.ModuleList(
292
+ [copy.deepcopy(encoder_layer) for _ in range(config["layer_num"] - 1)]
293
+ + [encoder_layer_last]
294
+ )
295
  self._reset_parameters()
296
+
297
  def _reset_parameters(self):
298
+ for name, p in self.named_parameters():
299
+ if "temp" in name or "sample_offset" in name:
300
  continue
301
  if p.dim() > 1:
302
  nn.init.xavier_uniform_(p)
303
 
304
+ def forward(
305
+ self, feat0, feat1, pos0, pos1, mask0=None, mask1=None, ds0=[4, 4], ds1=[4, 4]
306
+ ):
307
  """
308
  Args:
309
  feat0 (torch.Tensor): [N, C, H, W]
 
315
  flow_list: [L,N,H,W,4]*1(2)
316
  """
317
  bs = feat0.size(0)
318
+
319
+ pos0, pos1 = self.pos_transform(pos0), self.pos_transform(pos1)
320
+ pos0, pos1 = pos0.expand(bs, -1, -1, -1), pos1.expand(bs, -1, -1, -1)
321
  assert self.d_model == feat0.size(
322
+ 1
323
+ ), "the feature number of src and transformer must be equal"
324
+
325
+ flow_list = [[], []] # [px,py,sx,sy]
326
  if mask0 is not None:
327
+ mask0, mask1 = mask0[:, None].float(), mask1[:, None].float()
328
+ feat0, feat1, flow_feature0, flow_feature1 = self.ini_layer(
329
+ feat0, feat1, pos0, pos1, mask0, mask1, ds0, ds1
330
+ )
331
  for layer in self.layers:
332
+ feat0, feat1, flow_feature0, flow_feature1, flow0, flow1 = layer(
333
+ feat0,
334
+ feat1,
335
+ flow_feature0,
336
+ flow_feature1,
337
+ pos0,
338
+ pos1,
339
+ mask0,
340
+ mask1,
341
+ ds0,
342
+ ds1,
343
+ )
344
  flow_list[0].append(flow0)
345
  flow_list[1].append(flow1)
346
+ flow_list[0] = torch.stack(flow_list[0], dim=0)
347
+ flow_list[1] = torch.stack(flow_list[1], dim=0)
348
+ feat0, feat1 = feat0.permute(0, 2, 3, 1).view(
349
+ bs, -1, self.d_model
350
+ ), feat1.permute(0, 2, 3, 1).view(bs, -1, self.d_model)
351
+ return feat0, feat1, flow_list
third_party/ASpanFormer/src/ASpanFormer/aspanformer.py CHANGED
@@ -5,7 +5,11 @@ from einops.einops import rearrange
5
 
6
  from .backbone import build_backbone
7
  from .utils.position_encoding import PositionEncodingSine
8
- from .aspan_module import LocalFeatureTransformer_Flow, LocalFeatureTransformer, FinePreprocess
 
 
 
 
9
  from .utils.coarse_matching import CoarseMatching
10
  from .utils.fine_matching import FineMatching
11
 
@@ -19,16 +23,18 @@ class ASpanFormer(nn.Module):
19
  # Modules
20
  self.backbone = build_backbone(config)
21
  self.pos_encoding = PositionEncodingSine(
22
- config['coarse']['d_model'],pre_scaling=[config['coarse']['train_res'],config['coarse']['test_res']])
23
- self.loftr_coarse = LocalFeatureTransformer_Flow(config['coarse'])
24
- self.coarse_matching = CoarseMatching(config['match_coarse'])
 
 
25
  self.fine_preprocess = FinePreprocess(config)
26
  self.loftr_fine = LocalFeatureTransformer(config["fine"])
27
  self.fine_matching = FineMatching()
28
- self.coarsest_level=config['coarse']['coarsest_level']
29
 
30
  def forward(self, data, online_resize=False):
31
- """
32
  Update:
33
  data (dict): {
34
  'image0': (torch.Tensor): (N, 1, H, W)
@@ -38,96 +44,135 @@ class ASpanFormer(nn.Module):
38
  }
39
  """
40
  if online_resize:
41
- assert data['image0'].shape[0]==1 and data['image1'].shape[1]==1
42
- self.resize_input(data,self.config['coarse']['train_res'])
43
  else:
44
- data['pos_scale0'],data['pos_scale1']=None,None
45
 
46
  # 1. Local Feature CNN
47
- data.update({
48
- 'bs': data['image0'].size(0),
49
- 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
50
- })
51
-
52
- if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence
 
 
 
53
  feats_c, feats_f = self.backbone(
54
- torch.cat([data['image0'], data['image1']], dim=0))
 
55
  (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(
56
- data['bs']), feats_f.split(data['bs'])
 
57
  else: # handle different input shapes
58
  (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(
59
- data['image0']), self.backbone(data['image1'])
 
60
 
61
- data.update({
62
- 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
63
- 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:]
64
- })
 
 
 
 
65
 
66
  # 2. coarse-level loftr module
67
  # add featmap with positional encoding, then flatten it to sequence [N, HW, C]
68
- [feat_c0, pos_encoding0], [feat_c1, pos_encoding1] = self.pos_encoding(feat_c0,data['pos_scale0']), self.pos_encoding(feat_c1,data['pos_scale1'])
69
- feat_c0 = rearrange(feat_c0, 'n c h w -> n c h w ')
70
- feat_c1 = rearrange(feat_c1, 'n c h w -> n c h w ')
 
 
71
 
72
- #TODO:adjust ds
73
- ds0=[int(data['hw0_c'][0]/self.coarsest_level[0]),int(data['hw0_c'][1]/self.coarsest_level[1])]
74
- ds1=[int(data['hw1_c'][0]/self.coarsest_level[0]),int(data['hw1_c'][1]/self.coarsest_level[1])]
 
 
 
 
 
 
75
  if online_resize:
76
- ds0,ds1=[4,4],[4,4]
77
 
78
  mask_c0 = mask_c1 = None # mask is useful in training
79
- if 'mask0' in data:
80
- mask_c0, mask_c1 = data['mask0'].flatten(
81
- -2), data['mask1'].flatten(-2)
82
  feat_c0, feat_c1, flow_list = self.loftr_coarse(
83
- feat_c0, feat_c1,pos_encoding0,pos_encoding1,mask_c0,mask_c1,ds0,ds1)
 
84
 
85
  # 3. match coarse-level and register predicted offset
86
- self.coarse_matching(feat_c0, feat_c1, flow_list,data,
87
- mask_c0=mask_c0, mask_c1=mask_c1)
 
88
 
89
  # 4. fine-level refinement
90
  feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(
91
- feat_f0, feat_f1, feat_c0, feat_c1, data)
 
92
  if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted
93
  feat_f0_unfold, feat_f1_unfold = self.loftr_fine(
94
- feat_f0_unfold, feat_f1_unfold)
 
95
 
96
  # 5. match fine-level
97
  self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
98
 
99
  # 6. resize match coordinates back to input resolution
100
  if online_resize:
101
- data['mkpts0_f']*=data['online_resize_scale0']
102
- data['mkpts1_f']*=data['online_resize_scale1']
103
-
104
  def load_state_dict(self, state_dict, *args, **kwargs):
105
  for k in list(state_dict.keys()):
106
- if k.startswith('matcher.'):
107
- if 'sample_offset' in k:
108
  state_dict.pop(k)
109
  else:
110
- state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k)
111
  return super().load_state_dict(state_dict, *args, **kwargs)
112
-
113
- def resize_input(self,data,train_res,df=32):
114
- h0,w0,h1,w1=data['image0'].shape[2],data['image0'].shape[3],data['image1'].shape[2],data['image1'].shape[3]
115
- data['image0'],data['image1']=self.resize_df(data['image0'],df),self.resize_df(data['image1'],df)
116
-
117
- if len(train_res)==1:
118
- train_res_h=train_res_w=train_res
 
 
 
 
 
 
 
119
  else:
120
- train_res_h,train_res_w=train_res[0],train_res[1]
121
- data['pos_scale0'],data['pos_scale1']=[train_res_h/data['image0'].shape[2],train_res_w/data['image0'].shape[3]],\
122
- [train_res_h/data['image1'].shape[2],train_res_w/data['image1'].shape[3]]
123
- data['online_resize_scale0'],data['online_resize_scale1']=torch.tensor([w0/data['image0'].shape[3],h0/data['image0'].shape[2]])[None].cuda(),\
124
- torch.tensor([w1/data['image1'].shape[3],h1/data['image1'].shape[2]])[None].cuda()
125
-
126
- def resize_df(self,image,df=32):
127
- h,w=image.shape[2],image.shape[3]
128
- h_new,w_new=h//df*df,w//df*df
129
- if h!=h_new or w!=w_new:
130
- img_new=transforms.Resize([h_new,w_new]).forward(image)
 
 
 
 
 
 
 
 
 
 
 
131
  else:
132
- img_new=image
133
  return img_new
 
5
 
6
  from .backbone import build_backbone
7
  from .utils.position_encoding import PositionEncodingSine
8
+ from .aspan_module import (
9
+ LocalFeatureTransformer_Flow,
10
+ LocalFeatureTransformer,
11
+ FinePreprocess,
12
+ )
13
  from .utils.coarse_matching import CoarseMatching
14
  from .utils.fine_matching import FineMatching
15
 
 
23
  # Modules
24
  self.backbone = build_backbone(config)
25
  self.pos_encoding = PositionEncodingSine(
26
+ config["coarse"]["d_model"],
27
+ pre_scaling=[config["coarse"]["train_res"], config["coarse"]["test_res"]],
28
+ )
29
+ self.loftr_coarse = LocalFeatureTransformer_Flow(config["coarse"])
30
+ self.coarse_matching = CoarseMatching(config["match_coarse"])
31
  self.fine_preprocess = FinePreprocess(config)
32
  self.loftr_fine = LocalFeatureTransformer(config["fine"])
33
  self.fine_matching = FineMatching()
34
+ self.coarsest_level = config["coarse"]["coarsest_level"]
35
 
36
  def forward(self, data, online_resize=False):
37
+ """
38
  Update:
39
  data (dict): {
40
  'image0': (torch.Tensor): (N, 1, H, W)
 
44
  }
45
  """
46
  if online_resize:
47
+ assert data["image0"].shape[0] == 1 and data["image1"].shape[1] == 1
48
+ self.resize_input(data, self.config["coarse"]["train_res"])
49
  else:
50
+ data["pos_scale0"], data["pos_scale1"] = None, None
51
 
52
  # 1. Local Feature CNN
53
+ data.update(
54
+ {
55
+ "bs": data["image0"].size(0),
56
+ "hw0_i": data["image0"].shape[2:],
57
+ "hw1_i": data["image1"].shape[2:],
58
+ }
59
+ )
60
+
61
+ if data["hw0_i"] == data["hw1_i"]: # faster & better BN convergence
62
  feats_c, feats_f = self.backbone(
63
+ torch.cat([data["image0"], data["image1"]], dim=0)
64
+ )
65
  (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(
66
+ data["bs"]
67
+ ), feats_f.split(data["bs"])
68
  else: # handle different input shapes
69
  (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(
70
+ data["image0"]
71
+ ), self.backbone(data["image1"])
72
 
73
+ data.update(
74
+ {
75
+ "hw0_c": feat_c0.shape[2:],
76
+ "hw1_c": feat_c1.shape[2:],
77
+ "hw0_f": feat_f0.shape[2:],
78
+ "hw1_f": feat_f1.shape[2:],
79
+ }
80
+ )
81
 
82
  # 2. coarse-level loftr module
83
  # add featmap with positional encoding, then flatten it to sequence [N, HW, C]
84
+ [feat_c0, pos_encoding0], [feat_c1, pos_encoding1] = self.pos_encoding(
85
+ feat_c0, data["pos_scale0"]
86
+ ), self.pos_encoding(feat_c1, data["pos_scale1"])
87
+ feat_c0 = rearrange(feat_c0, "n c h w -> n c h w ")
88
+ feat_c1 = rearrange(feat_c1, "n c h w -> n c h w ")
89
 
90
+ # TODO:adjust ds
91
+ ds0 = [
92
+ int(data["hw0_c"][0] / self.coarsest_level[0]),
93
+ int(data["hw0_c"][1] / self.coarsest_level[1]),
94
+ ]
95
+ ds1 = [
96
+ int(data["hw1_c"][0] / self.coarsest_level[0]),
97
+ int(data["hw1_c"][1] / self.coarsest_level[1]),
98
+ ]
99
  if online_resize:
100
+ ds0, ds1 = [4, 4], [4, 4]
101
 
102
  mask_c0 = mask_c1 = None # mask is useful in training
103
+ if "mask0" in data:
104
+ mask_c0, mask_c1 = data["mask0"].flatten(-2), data["mask1"].flatten(-2)
 
105
  feat_c0, feat_c1, flow_list = self.loftr_coarse(
106
+ feat_c0, feat_c1, pos_encoding0, pos_encoding1, mask_c0, mask_c1, ds0, ds1
107
+ )
108
 
109
  # 3. match coarse-level and register predicted offset
110
+ self.coarse_matching(
111
+ feat_c0, feat_c1, flow_list, data, mask_c0=mask_c0, mask_c1=mask_c1
112
+ )
113
 
114
  # 4. fine-level refinement
115
  feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(
116
+ feat_f0, feat_f1, feat_c0, feat_c1, data
117
+ )
118
  if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted
119
  feat_f0_unfold, feat_f1_unfold = self.loftr_fine(
120
+ feat_f0_unfold, feat_f1_unfold
121
+ )
122
 
123
  # 5. match fine-level
124
  self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
125
 
126
  # 6. resize match coordinates back to input resolution
127
  if online_resize:
128
+ data["mkpts0_f"] *= data["online_resize_scale0"]
129
+ data["mkpts1_f"] *= data["online_resize_scale1"]
130
+
131
  def load_state_dict(self, state_dict, *args, **kwargs):
132
  for k in list(state_dict.keys()):
133
+ if k.startswith("matcher."):
134
+ if "sample_offset" in k:
135
  state_dict.pop(k)
136
  else:
137
+ state_dict[k.replace("matcher.", "", 1)] = state_dict.pop(k)
138
  return super().load_state_dict(state_dict, *args, **kwargs)
139
+
140
+ def resize_input(self, data, train_res, df=32):
141
+ h0, w0, h1, w1 = (
142
+ data["image0"].shape[2],
143
+ data["image0"].shape[3],
144
+ data["image1"].shape[2],
145
+ data["image1"].shape[3],
146
+ )
147
+ data["image0"], data["image1"] = self.resize_df(
148
+ data["image0"], df
149
+ ), self.resize_df(data["image1"], df)
150
+
151
+ if len(train_res) == 1:
152
+ train_res_h = train_res_w = train_res
153
  else:
154
+ train_res_h, train_res_w = train_res[0], train_res[1]
155
+ data["pos_scale0"], data["pos_scale1"] = [
156
+ train_res_h / data["image0"].shape[2],
157
+ train_res_w / data["image0"].shape[3],
158
+ ], [
159
+ train_res_h / data["image1"].shape[2],
160
+ train_res_w / data["image1"].shape[3],
161
+ ]
162
+ data["online_resize_scale0"], data["online_resize_scale1"] = (
163
+ torch.tensor([w0 / data["image0"].shape[3], h0 / data["image0"].shape[2]])[
164
+ None
165
+ ].cuda(),
166
+ torch.tensor([w1 / data["image1"].shape[3], h1 / data["image1"].shape[2]])[
167
+ None
168
+ ].cuda(),
169
+ )
170
+
171
+ def resize_df(self, image, df=32):
172
+ h, w = image.shape[2], image.shape[3]
173
+ h_new, w_new = h // df * df, w // df * df
174
+ if h != h_new or w != w_new:
175
+ img_new = transforms.Resize([h_new, w_new]).forward(image)
176
  else:
177
+ img_new = image
178
  return img_new
third_party/ASpanFormer/src/ASpanFormer/backbone/__init__.py CHANGED
@@ -2,10 +2,12 @@ from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4
2
 
3
 
4
  def build_backbone(config):
5
- if config['backbone_type'] == 'ResNetFPN':
6
- if config['resolution'] == (8, 2):
7
- return ResNetFPN_8_2(config['resnetfpn'])
8
- elif config['resolution'] == (16, 4):
9
- return ResNetFPN_16_4(config['resnetfpn'])
10
  else:
11
- raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
 
 
 
2
 
3
 
4
  def build_backbone(config):
5
+ if config["backbone_type"] == "ResNetFPN":
6
+ if config["resolution"] == (8, 2):
7
+ return ResNetFPN_8_2(config["resnetfpn"])
8
+ elif config["resolution"] == (16, 4):
9
+ return ResNetFPN_16_4(config["resnetfpn"])
10
  else:
11
+ raise ValueError(
12
+ f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported."
13
+ )
third_party/ASpanFormer/src/ASpanFormer/backbone/resnet_fpn.py CHANGED
@@ -4,12 +4,16 @@ import torch.nn.functional as F
4
 
5
  def conv1x1(in_planes, out_planes, stride=1):
6
  """1x1 convolution without padding"""
7
- return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
 
 
8
 
9
 
10
  def conv3x3(in_planes, out_planes, stride=1):
11
  """3x3 convolution with padding"""
12
- return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
 
 
13
 
14
 
15
  class BasicBlock(nn.Module):
@@ -25,8 +29,7 @@ class BasicBlock(nn.Module):
25
  self.downsample = None
26
  else:
27
  self.downsample = nn.Sequential(
28
- conv1x1(in_planes, planes, stride=stride),
29
- nn.BatchNorm2d(planes)
30
  )
31
 
32
  def forward(self, x):
@@ -37,7 +40,7 @@ class BasicBlock(nn.Module):
37
  if self.downsample is not None:
38
  x = self.downsample(x)
39
 
40
- return self.relu(x+y)
41
 
42
 
43
  class ResNetFPN_8_2(nn.Module):
@@ -50,14 +53,16 @@ class ResNetFPN_8_2(nn.Module):
50
  super().__init__()
51
  # Config
52
  block = BasicBlock
53
- initial_dim = config['initial_dim']
54
- block_dims = config['block_dims']
55
 
56
  # Class Variable
57
  self.in_planes = initial_dim
58
 
59
  # Networks
60
- self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
 
 
61
  self.bn1 = nn.BatchNorm2d(initial_dim)
62
  self.relu = nn.ReLU(inplace=True)
63
 
@@ -84,7 +89,7 @@ class ResNetFPN_8_2(nn.Module):
84
 
85
  for m in self.modules():
86
  if isinstance(m, nn.Conv2d):
87
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
88
  elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
89
  nn.init.constant_(m.weight, 1)
90
  nn.init.constant_(m.bias, 0)
@@ -107,13 +112,17 @@ class ResNetFPN_8_2(nn.Module):
107
  # FPN
108
  x3_out = self.layer3_outconv(x3)
109
 
110
- x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
 
 
111
  x2_out = self.layer2_outconv(x2)
112
- x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
113
 
114
- x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
 
 
115
  x1_out = self.layer1_outconv(x1)
116
- x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
117
 
118
  return [x3_out, x1_out]
119
 
@@ -128,14 +137,16 @@ class ResNetFPN_16_4(nn.Module):
128
  super().__init__()
129
  # Config
130
  block = BasicBlock
131
- initial_dim = config['initial_dim']
132
- block_dims = config['block_dims']
133
 
134
  # Class Variable
135
  self.in_planes = initial_dim
136
 
137
  # Networks
138
- self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
 
 
139
  self.bn1 = nn.BatchNorm2d(initial_dim)
140
  self.relu = nn.ReLU(inplace=True)
141
 
@@ -164,7 +175,7 @@ class ResNetFPN_16_4(nn.Module):
164
 
165
  for m in self.modules():
166
  if isinstance(m, nn.Conv2d):
167
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
168
  elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
169
  nn.init.constant_(m.weight, 1)
170
  nn.init.constant_(m.bias, 0)
@@ -188,12 +199,16 @@ class ResNetFPN_16_4(nn.Module):
188
  # FPN
189
  x4_out = self.layer4_outconv(x4)
190
 
191
- x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True)
 
 
192
  x3_out = self.layer3_outconv(x3)
193
- x3_out = self.layer3_outconv2(x3_out+x4_out_2x)
194
 
195
- x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
 
 
196
  x2_out = self.layer2_outconv(x2)
197
- x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
198
 
199
  return [x4_out, x2_out]
 
4
 
5
  def conv1x1(in_planes, out_planes, stride=1):
6
  """1x1 convolution without padding"""
7
+ return nn.Conv2d(
8
+ in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False
9
+ )
10
 
11
 
12
  def conv3x3(in_planes, out_planes, stride=1):
13
  """3x3 convolution with padding"""
14
+ return nn.Conv2d(
15
+ in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
16
+ )
17
 
18
 
19
  class BasicBlock(nn.Module):
 
29
  self.downsample = None
30
  else:
31
  self.downsample = nn.Sequential(
32
+ conv1x1(in_planes, planes, stride=stride), nn.BatchNorm2d(planes)
 
33
  )
34
 
35
  def forward(self, x):
 
40
  if self.downsample is not None:
41
  x = self.downsample(x)
42
 
43
+ return self.relu(x + y)
44
 
45
 
46
  class ResNetFPN_8_2(nn.Module):
 
53
  super().__init__()
54
  # Config
55
  block = BasicBlock
56
+ initial_dim = config["initial_dim"]
57
+ block_dims = config["block_dims"]
58
 
59
  # Class Variable
60
  self.in_planes = initial_dim
61
 
62
  # Networks
63
+ self.conv1 = nn.Conv2d(
64
+ 1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False
65
+ )
66
  self.bn1 = nn.BatchNorm2d(initial_dim)
67
  self.relu = nn.ReLU(inplace=True)
68
 
 
89
 
90
  for m in self.modules():
91
  if isinstance(m, nn.Conv2d):
92
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
93
  elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
94
  nn.init.constant_(m.weight, 1)
95
  nn.init.constant_(m.bias, 0)
 
112
  # FPN
113
  x3_out = self.layer3_outconv(x3)
114
 
115
+ x3_out_2x = F.interpolate(
116
+ x3_out, scale_factor=2.0, mode="bilinear", align_corners=True
117
+ )
118
  x2_out = self.layer2_outconv(x2)
119
+ x2_out = self.layer2_outconv2(x2_out + x3_out_2x)
120
 
121
+ x2_out_2x = F.interpolate(
122
+ x2_out, scale_factor=2.0, mode="bilinear", align_corners=True
123
+ )
124
  x1_out = self.layer1_outconv(x1)
125
+ x1_out = self.layer1_outconv2(x1_out + x2_out_2x)
126
 
127
  return [x3_out, x1_out]
128
 
 
137
  super().__init__()
138
  # Config
139
  block = BasicBlock
140
+ initial_dim = config["initial_dim"]
141
+ block_dims = config["block_dims"]
142
 
143
  # Class Variable
144
  self.in_planes = initial_dim
145
 
146
  # Networks
147
+ self.conv1 = nn.Conv2d(
148
+ 1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False
149
+ )
150
  self.bn1 = nn.BatchNorm2d(initial_dim)
151
  self.relu = nn.ReLU(inplace=True)
152
 
 
175
 
176
  for m in self.modules():
177
  if isinstance(m, nn.Conv2d):
178
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
179
  elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
180
  nn.init.constant_(m.weight, 1)
181
  nn.init.constant_(m.bias, 0)
 
199
  # FPN
200
  x4_out = self.layer4_outconv(x4)
201
 
202
+ x4_out_2x = F.interpolate(
203
+ x4_out, scale_factor=2.0, mode="bilinear", align_corners=True
204
+ )
205
  x3_out = self.layer3_outconv(x3)
206
+ x3_out = self.layer3_outconv2(x3_out + x4_out_2x)
207
 
208
+ x3_out_2x = F.interpolate(
209
+ x3_out, scale_factor=2.0, mode="bilinear", align_corners=True
210
+ )
211
  x2_out = self.layer2_outconv(x2)
212
+ x2_out = self.layer2_outconv2(x2_out + x3_out_2x)
213
 
214
  return [x4_out, x2_out]
third_party/ASpanFormer/src/ASpanFormer/utils/coarse_matching.py CHANGED
@@ -7,8 +7,9 @@ from time import time
7
 
8
  INF = 1e9
9
 
 
10
  def mask_border(m, b: int, v):
11
- """ Mask borders with value
12
  Args:
13
  m (torch.Tensor): [N, H0, W0, H1, W1]
14
  b (int)
@@ -39,22 +40,21 @@ def mask_border_with_padding(m, bd, v, p_m0, p_m1):
39
  h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
40
  h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
41
  for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
42
- m[b_idx, h0 - bd:] = v
43
- m[b_idx, :, w0 - bd:] = v
44
- m[b_idx, :, :, h1 - bd:] = v
45
- m[b_idx, :, :, :, w1 - bd:] = v
46
 
47
 
48
  def compute_max_candidates(p_m0, p_m1):
49
  """Compute the max candidates of all pairs within a batch
50
-
51
  Args:
52
  p_m0, p_m1 (torch.Tensor): padded masks
53
  """
54
  h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0]
55
  h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0]
56
- max_cand = torch.sum(
57
- torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
58
  return max_cand
59
 
60
 
@@ -63,29 +63,32 @@ class CoarseMatching(nn.Module):
63
  super().__init__()
64
  self.config = config
65
  # general config
66
- self.thr = config['thr']
67
- self.border_rm = config['border_rm']
68
  # -- # for trainig fine-level LoFTR
69
- self.train_coarse_percent = config['train_coarse_percent']
70
- self.train_pad_num_gt_min = config['train_pad_num_gt_min']
71
-
72
  # we provide 2 options for differentiable matching
73
- self.match_type = config['match_type']
74
- if self.match_type == 'dual_softmax':
75
- self.temperature=nn.parameter.Parameter(torch.tensor(10.), requires_grad=True)
76
- elif self.match_type == 'sinkhorn':
 
 
77
  try:
78
  from .superglue import log_optimal_transport
79
  except ImportError:
80
  raise ImportError("download superglue.py first!")
81
  self.log_optimal_transport = log_optimal_transport
82
  self.bin_score = nn.Parameter(
83
- torch.tensor(config['skh_init_bin_score'], requires_grad=True))
84
- self.skh_iters = config['skh_iters']
85
- self.skh_prefilter = config['skh_prefilter']
 
86
  else:
87
  raise NotImplementedError()
88
-
89
  def forward(self, feat_c0, feat_c1, flow_list, data, mask_c0=None, mask_c1=None):
90
  """
91
  Args:
@@ -108,29 +111,32 @@ class CoarseMatching(nn.Module):
108
  """
109
  N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2)
110
  # normalize
111
- feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
112
- [feat_c0, feat_c1])
113
-
114
- if self.match_type == 'dual_softmax':
115
- sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0,
116
- feat_c1) * self.temperature
 
 
117
  if mask_c0 is not None:
118
  sim_matrix.masked_fill_(
119
- ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
120
- -INF)
121
  conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
122
-
123
- elif self.match_type == 'sinkhorn':
124
  # sinkhorn, dustbin included
125
  sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1)
126
  if mask_c0 is not None:
127
  sim_matrix[:, :L, :S].masked_fill_(
128
- ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
129
- -INF)
130
 
131
  # build uniform prior & use sinkhorn
132
  log_assign_matrix = self.log_optimal_transport(
133
- sim_matrix, self.bin_score, self.skh_iters)
 
134
  assign_matrix = log_assign_matrix.exp()
135
  conf_matrix = assign_matrix[:, :-1, :-1]
136
 
@@ -141,18 +147,21 @@ class CoarseMatching(nn.Module):
141
  conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0
142
  conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0
143
 
144
- if self.config['sparse_spvs']:
145
- data.update({'conf_matrix_with_bin': assign_matrix.clone()})
146
 
147
- data.update({'conf_matrix': conf_matrix})
148
  # predict coarse matches from conf_matrix
149
  data.update(**self.get_coarse_match(conf_matrix, data))
150
 
151
- #update predicted offset
152
- if flow_list[0].shape[2]==flow_list[1].shape[2] and flow_list[0].shape[3]==flow_list[1].shape[3]:
153
- flow_list=torch.stack(flow_list,dim=0)
154
- data.update({'predict_flow':flow_list}) #[2*L*B*H*W*4]
155
- self.get_offset_match(flow_list,data,mask_c0,mask_c1)
 
 
 
156
 
157
  @torch.no_grad()
158
  def get_coarse_match(self, conf_matrix, data):
@@ -172,28 +181,33 @@ class CoarseMatching(nn.Module):
172
  'mconf' (torch.Tensor): [M]}
173
  """
174
  axes_lengths = {
175
- 'h0c': data['hw0_c'][0],
176
- 'w0c': data['hw0_c'][1],
177
- 'h1c': data['hw1_c'][0],
178
- 'w1c': data['hw1_c'][1]
179
  }
180
  _device = conf_matrix.device
181
  # 1. confidence thresholding
182
  mask = conf_matrix > self.thr
183
- mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
184
- **axes_lengths)
185
- if 'mask0' not in data:
 
186
  mask_border(mask, self.border_rm, False)
187
  else:
188
- mask_border_with_padding(mask, self.border_rm, False,
189
- data['mask0'], data['mask1'])
190
- mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
191
- **axes_lengths)
 
 
192
 
193
  # 2. mutual nearest
194
- mask = mask \
195
- * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \
 
196
  * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
 
197
 
198
  # 3. find all valid coarse matches
199
  # this only works when at most one `True` in each row
@@ -208,67 +222,79 @@ class CoarseMatching(nn.Module):
208
  # NOTE:
209
  # The sampling is performed across all pairs in a batch without manually balancing
210
  # #samples for fine-level increases w.r.t. batch_size
211
- if 'mask0' not in data:
212
- num_candidates_max = mask.size(0) * max(
213
- mask.size(1), mask.size(2))
214
  else:
215
  num_candidates_max = compute_max_candidates(
216
- data['mask0'], data['mask1'])
217
- num_matches_train = int(num_candidates_max *
218
- self.train_coarse_percent)
219
  num_matches_pred = len(b_ids)
220
- assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"
221
-
 
 
222
  # pred_indices is to select from prediction
223
  if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
224
  pred_indices = torch.arange(num_matches_pred, device=_device)
225
  else:
226
  pred_indices = torch.randint(
227
  num_matches_pred,
228
- (num_matches_train - self.train_pad_num_gt_min, ),
229
- device=_device)
 
230
 
231
  # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
232
  gt_pad_indices = torch.randint(
233
- len(data['spv_b_ids']),
234
- (max(num_matches_train - num_matches_pred,
235
- self.train_pad_num_gt_min), ),
236
- device=_device)
237
- mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero
 
 
238
 
239
  b_ids, i_ids, j_ids, mconf = map(
240
- lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
241
- dim=0),
242
- *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']],
243
- [j_ids, data['spv_j_ids']], [mconf, mconf_gt]))
 
 
 
 
244
 
245
  # These matches select patches that feed into fine-level network
246
- coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}
247
 
248
  # 4. Update with matches in original image resolution
249
- scale = data['hw0_i'][0] / data['hw0_c'][0]
250
- scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
251
- scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
252
- mkpts0_c = torch.stack(
253
- [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]],
254
- dim=1) * scale0
255
- mkpts1_c = torch.stack(
256
- [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]],
257
- dim=1) * scale1
 
 
258
 
259
  # These matches is the current prediction (for visualization)
260
- coarse_matches.update({
261
- 'gt_mask': mconf == 0,
262
- 'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches
263
- 'mkpts0_c': mkpts0_c[mconf != 0],
264
- 'mkpts1_c': mkpts1_c[mconf != 0],
265
- 'mconf': mconf[mconf != 0]
266
- })
 
 
267
 
268
  return coarse_matches
269
 
270
  @torch.no_grad()
271
- def get_offset_match(self, flow_list, data,mask1,mask2):
272
  """
273
  Args:
274
  offset (torch.Tensor): [L, B, H, W, 2]
@@ -280,52 +306,62 @@ class CoarseMatching(nn.Module):
280
  'mkpts1_c' (torch.Tensor): [M, 2],
281
  'mconf' (torch.Tensor): [M]}
282
  """
283
- offset1=flow_list[0]
284
- bs,layer_num=offset1.shape[1],offset1.shape[0]
285
-
286
- #left side
287
- offset1=offset1.view(layer_num,bs,-1,4)
288
- conf1=offset1[:,:,:,2:].mean(dim=-1)
289
  if mask1 is not None:
290
- conf1.masked_fill_(~mask1.bool()[None].expand(layer_num,-1,-1),100)
291
- offset1=offset1[:,:,:,:2]
292
- self.get_offset_match_work(offset1,conf1,data,'left')
293
-
294
- #rihgt side
295
- if len(flow_list)==2:
296
- offset2=flow_list[1].view(layer_num,bs,-1,4)
297
- conf2=offset2[:,:,:,2:].mean(dim=-1)
298
  if mask2 is not None:
299
- conf2.masked_fill_(~mask2.bool()[None].expand(layer_num,-1,-1),100)
300
- offset2=offset2[:,:,:,:2]
301
- self.get_offset_match_work(offset2,conf2,data,'right')
302
-
303
 
304
  @torch.no_grad()
305
- def get_offset_match_work(self, offset,conf, data,side):
306
- bs,layer_num=offset.shape[1],offset.shape[0]
307
  # 1. confidence thresholding
308
- mask_conf= conf<2
309
  for index in range(bs):
310
- mask_conf[:,index,0]=True #safe guard in case that no match survives
311
  # 3. find offset matches
312
- scale = data['hw0_i'][0] / data['hw0_c'][0]
313
- l_ids,b_ids,i_ids = torch.where(mask_conf)
314
- j_coor=offset[l_ids,b_ids,i_ids,:2] *scale#[N,2]
315
- i_coor=torch.stack([i_ids%data['hw0_c'][1],i_ids//data['hw0_c'][1]],dim=1)*scale
316
- #i_coor=torch.as_tensor([[index%data['hw0_c'][1],index//data['hw0_c'][1]] for index in i_ids]).cuda().float()*scale #[N,2]
 
 
 
317
  # These matches is the current prediction (for visualization)
318
- data.update({
319
- 'offset_bids_'+side: b_ids, # mconf == 0 => gt matches
320
- 'offset_lids_'+side: l_ids,
321
- 'conf'+side: conf[mask_conf]
322
- })
323
-
324
- if side=='right':
325
- data.update({'offset_kpts0_f_'+side: j_coor.detach(),
326
- 'offset_kpts1_f_'+side: i_coor})
 
 
 
 
 
 
327
  else:
328
- data.update({'offset_kpts0_f_'+side: i_coor,
329
- 'offset_kpts1_f_'+side: j_coor.detach()})
330
-
331
-
 
 
 
7
 
8
  INF = 1e9
9
 
10
+
11
  def mask_border(m, b: int, v):
12
+ """Mask borders with value
13
  Args:
14
  m (torch.Tensor): [N, H0, W0, H1, W1]
15
  b (int)
 
40
  h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
41
  h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
42
  for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
43
+ m[b_idx, h0 - bd :] = v
44
+ m[b_idx, :, w0 - bd :] = v
45
+ m[b_idx, :, :, h1 - bd :] = v
46
+ m[b_idx, :, :, :, w1 - bd :] = v
47
 
48
 
49
  def compute_max_candidates(p_m0, p_m1):
50
  """Compute the max candidates of all pairs within a batch
51
+
52
  Args:
53
  p_m0, p_m1 (torch.Tensor): padded masks
54
  """
55
  h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0]
56
  h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0]
57
+ max_cand = torch.sum(torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
 
58
  return max_cand
59
 
60
 
 
63
  super().__init__()
64
  self.config = config
65
  # general config
66
+ self.thr = config["thr"]
67
+ self.border_rm = config["border_rm"]
68
  # -- # for trainig fine-level LoFTR
69
+ self.train_coarse_percent = config["train_coarse_percent"]
70
+ self.train_pad_num_gt_min = config["train_pad_num_gt_min"]
71
+
72
  # we provide 2 options for differentiable matching
73
+ self.match_type = config["match_type"]
74
+ if self.match_type == "dual_softmax":
75
+ self.temperature = nn.parameter.Parameter(
76
+ torch.tensor(10.0), requires_grad=True
77
+ )
78
+ elif self.match_type == "sinkhorn":
79
  try:
80
  from .superglue import log_optimal_transport
81
  except ImportError:
82
  raise ImportError("download superglue.py first!")
83
  self.log_optimal_transport = log_optimal_transport
84
  self.bin_score = nn.Parameter(
85
+ torch.tensor(config["skh_init_bin_score"], requires_grad=True)
86
+ )
87
+ self.skh_iters = config["skh_iters"]
88
+ self.skh_prefilter = config["skh_prefilter"]
89
  else:
90
  raise NotImplementedError()
91
+
92
  def forward(self, feat_c0, feat_c1, flow_list, data, mask_c0=None, mask_c1=None):
93
  """
94
  Args:
 
111
  """
112
  N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2)
113
  # normalize
114
+ feat_c0, feat_c1 = map(
115
+ lambda feat: feat / feat.shape[-1] ** 0.5, [feat_c0, feat_c1]
116
+ )
117
+
118
+ if self.match_type == "dual_softmax":
119
+ sim_matrix = (
120
+ torch.einsum("nlc,nsc->nls", feat_c0, feat_c1) * self.temperature
121
+ )
122
  if mask_c0 is not None:
123
  sim_matrix.masked_fill_(
124
+ ~(mask_c0[..., None] * mask_c1[:, None]).bool(), -INF
125
+ )
126
  conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
127
+
128
+ elif self.match_type == "sinkhorn":
129
  # sinkhorn, dustbin included
130
  sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1)
131
  if mask_c0 is not None:
132
  sim_matrix[:, :L, :S].masked_fill_(
133
+ ~(mask_c0[..., None] * mask_c1[:, None]).bool(), -INF
134
+ )
135
 
136
  # build uniform prior & use sinkhorn
137
  log_assign_matrix = self.log_optimal_transport(
138
+ sim_matrix, self.bin_score, self.skh_iters
139
+ )
140
  assign_matrix = log_assign_matrix.exp()
141
  conf_matrix = assign_matrix[:, :-1, :-1]
142
 
 
147
  conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0
148
  conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0
149
 
150
+ if self.config["sparse_spvs"]:
151
+ data.update({"conf_matrix_with_bin": assign_matrix.clone()})
152
 
153
+ data.update({"conf_matrix": conf_matrix})
154
  # predict coarse matches from conf_matrix
155
  data.update(**self.get_coarse_match(conf_matrix, data))
156
 
157
+ # update predicted offset
158
+ if (
159
+ flow_list[0].shape[2] == flow_list[1].shape[2]
160
+ and flow_list[0].shape[3] == flow_list[1].shape[3]
161
+ ):
162
+ flow_list = torch.stack(flow_list, dim=0)
163
+ data.update({"predict_flow": flow_list}) # [2*L*B*H*W*4]
164
+ self.get_offset_match(flow_list, data, mask_c0, mask_c1)
165
 
166
  @torch.no_grad()
167
  def get_coarse_match(self, conf_matrix, data):
 
181
  'mconf' (torch.Tensor): [M]}
182
  """
183
  axes_lengths = {
184
+ "h0c": data["hw0_c"][0],
185
+ "w0c": data["hw0_c"][1],
186
+ "h1c": data["hw1_c"][0],
187
+ "w1c": data["hw1_c"][1],
188
  }
189
  _device = conf_matrix.device
190
  # 1. confidence thresholding
191
  mask = conf_matrix > self.thr
192
+ mask = rearrange(
193
+ mask, "b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c", **axes_lengths
194
+ )
195
+ if "mask0" not in data:
196
  mask_border(mask, self.border_rm, False)
197
  else:
198
+ mask_border_with_padding(
199
+ mask, self.border_rm, False, data["mask0"], data["mask1"]
200
+ )
201
+ mask = rearrange(
202
+ mask, "b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)", **axes_lengths
203
+ )
204
 
205
  # 2. mutual nearest
206
+ mask = (
207
+ mask
208
+ * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0])
209
  * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
210
+ )
211
 
212
  # 3. find all valid coarse matches
213
  # this only works when at most one `True` in each row
 
222
  # NOTE:
223
  # The sampling is performed across all pairs in a batch without manually balancing
224
  # #samples for fine-level increases w.r.t. batch_size
225
+ if "mask0" not in data:
226
+ num_candidates_max = mask.size(0) * max(mask.size(1), mask.size(2))
 
227
  else:
228
  num_candidates_max = compute_max_candidates(
229
+ data["mask0"], data["mask1"]
230
+ )
231
+ num_matches_train = int(num_candidates_max * self.train_coarse_percent)
232
  num_matches_pred = len(b_ids)
233
+ assert (
234
+ self.train_pad_num_gt_min < num_matches_train
235
+ ), "min-num-gt-pad should be less than num-train-matches"
236
+
237
  # pred_indices is to select from prediction
238
  if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
239
  pred_indices = torch.arange(num_matches_pred, device=_device)
240
  else:
241
  pred_indices = torch.randint(
242
  num_matches_pred,
243
+ (num_matches_train - self.train_pad_num_gt_min,),
244
+ device=_device,
245
+ )
246
 
247
  # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
248
  gt_pad_indices = torch.randint(
249
+ len(data["spv_b_ids"]),
250
+ (max(num_matches_train - num_matches_pred, self.train_pad_num_gt_min),),
251
+ device=_device,
252
+ )
253
+ mconf_gt = torch.zeros(
254
+ len(data["spv_b_ids"]), device=_device
255
+ ) # set conf of gt paddings to all zero
256
 
257
  b_ids, i_ids, j_ids, mconf = map(
258
+ lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]], dim=0),
259
+ *zip(
260
+ [b_ids, data["spv_b_ids"]],
261
+ [i_ids, data["spv_i_ids"]],
262
+ [j_ids, data["spv_j_ids"]],
263
+ [mconf, mconf_gt],
264
+ )
265
+ )
266
 
267
  # These matches select patches that feed into fine-level network
268
+ coarse_matches = {"b_ids": b_ids, "i_ids": i_ids, "j_ids": j_ids}
269
 
270
  # 4. Update with matches in original image resolution
271
+ scale = data["hw0_i"][0] / data["hw0_c"][0]
272
+ scale0 = scale * data["scale0"][b_ids] if "scale0" in data else scale
273
+ scale1 = scale * data["scale1"][b_ids] if "scale1" in data else scale
274
+ mkpts0_c = (
275
+ torch.stack([i_ids % data["hw0_c"][1], i_ids // data["hw0_c"][1]], dim=1)
276
+ * scale0
277
+ )
278
+ mkpts1_c = (
279
+ torch.stack([j_ids % data["hw1_c"][1], j_ids // data["hw1_c"][1]], dim=1)
280
+ * scale1
281
+ )
282
 
283
  # These matches is the current prediction (for visualization)
284
+ coarse_matches.update(
285
+ {
286
+ "gt_mask": mconf == 0,
287
+ "m_bids": b_ids[mconf != 0], # mconf == 0 => gt matches
288
+ "mkpts0_c": mkpts0_c[mconf != 0],
289
+ "mkpts1_c": mkpts1_c[mconf != 0],
290
+ "mconf": mconf[mconf != 0],
291
+ }
292
+ )
293
 
294
  return coarse_matches
295
 
296
  @torch.no_grad()
297
+ def get_offset_match(self, flow_list, data, mask1, mask2):
298
  """
299
  Args:
300
  offset (torch.Tensor): [L, B, H, W, 2]
 
306
  'mkpts1_c' (torch.Tensor): [M, 2],
307
  'mconf' (torch.Tensor): [M]}
308
  """
309
+ offset1 = flow_list[0]
310
+ bs, layer_num = offset1.shape[1], offset1.shape[0]
311
+
312
+ # left side
313
+ offset1 = offset1.view(layer_num, bs, -1, 4)
314
+ conf1 = offset1[:, :, :, 2:].mean(dim=-1)
315
  if mask1 is not None:
316
+ conf1.masked_fill_(~mask1.bool()[None].expand(layer_num, -1, -1), 100)
317
+ offset1 = offset1[:, :, :, :2]
318
+ self.get_offset_match_work(offset1, conf1, data, "left")
319
+
320
+ # rihgt side
321
+ if len(flow_list) == 2:
322
+ offset2 = flow_list[1].view(layer_num, bs, -1, 4)
323
+ conf2 = offset2[:, :, :, 2:].mean(dim=-1)
324
  if mask2 is not None:
325
+ conf2.masked_fill_(~mask2.bool()[None].expand(layer_num, -1, -1), 100)
326
+ offset2 = offset2[:, :, :, :2]
327
+ self.get_offset_match_work(offset2, conf2, data, "right")
 
328
 
329
  @torch.no_grad()
330
+ def get_offset_match_work(self, offset, conf, data, side):
331
+ bs, layer_num = offset.shape[1], offset.shape[0]
332
  # 1. confidence thresholding
333
+ mask_conf = conf < 2
334
  for index in range(bs):
335
+ mask_conf[:, index, 0] = True # safe guard in case that no match survives
336
  # 3. find offset matches
337
+ scale = data["hw0_i"][0] / data["hw0_c"][0]
338
+ l_ids, b_ids, i_ids = torch.where(mask_conf)
339
+ j_coor = offset[l_ids, b_ids, i_ids, :2] * scale # [N,2]
340
+ i_coor = (
341
+ torch.stack([i_ids % data["hw0_c"][1], i_ids // data["hw0_c"][1]], dim=1)
342
+ * scale
343
+ )
344
+ # i_coor=torch.as_tensor([[index%data['hw0_c'][1],index//data['hw0_c'][1]] for index in i_ids]).cuda().float()*scale #[N,2]
345
  # These matches is the current prediction (for visualization)
346
+ data.update(
347
+ {
348
+ "offset_bids_" + side: b_ids, # mconf == 0 => gt matches
349
+ "offset_lids_" + side: l_ids,
350
+ "conf" + side: conf[mask_conf],
351
+ }
352
+ )
353
+
354
+ if side == "right":
355
+ data.update(
356
+ {
357
+ "offset_kpts0_f_" + side: j_coor.detach(),
358
+ "offset_kpts1_f_" + side: i_coor,
359
+ }
360
+ )
361
  else:
362
+ data.update(
363
+ {
364
+ "offset_kpts0_f_" + side: i_coor,
365
+ "offset_kpts1_f_" + side: j_coor.detach(),
366
+ }
367
+ )
third_party/ASpanFormer/src/ASpanFormer/utils/cvpr_ds_config.py CHANGED
@@ -8,7 +8,7 @@ def lower_config(yacs_cfg):
8
 
9
 
10
  _CN = CN()
11
- _CN.BACKBONE_TYPE = 'ResNetFPN'
12
  _CN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)]
13
  _CN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
14
  _CN.FINE_CONCAT_COARSE_FEAT = True
@@ -23,15 +23,15 @@ _CN.COARSE = CN()
23
  _CN.COARSE.D_MODEL = 256
24
  _CN.COARSE.D_FFN = 256
25
  _CN.COARSE.NHEAD = 8
26
- _CN.COARSE.LAYER_NAMES = ['self', 'cross'] * 4
27
- _CN.COARSE.ATTENTION = 'linear' # options: ['linear', 'full']
28
  _CN.COARSE.TEMP_BUG_FIX = False
29
 
30
  # 3. Coarse-Matching config
31
  _CN.MATCH_COARSE = CN()
32
  _CN.MATCH_COARSE.THR = 0.1
33
  _CN.MATCH_COARSE.BORDER_RM = 2
34
- _CN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax, 'sinkhorn']
35
  _CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
36
  _CN.MATCH_COARSE.SKH_ITERS = 3
37
  _CN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
@@ -44,7 +44,7 @@ _CN.FINE = CN()
44
  _CN.FINE.D_MODEL = 128
45
  _CN.FINE.D_FFN = 128
46
  _CN.FINE.NHEAD = 8
47
- _CN.FINE.LAYER_NAMES = ['self', 'cross'] * 1
48
- _CN.FINE.ATTENTION = 'linear'
49
 
50
  default_cfg = lower_config(_CN)
 
8
 
9
 
10
  _CN = CN()
11
+ _CN.BACKBONE_TYPE = "ResNetFPN"
12
  _CN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)]
13
  _CN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
14
  _CN.FINE_CONCAT_COARSE_FEAT = True
 
23
  _CN.COARSE.D_MODEL = 256
24
  _CN.COARSE.D_FFN = 256
25
  _CN.COARSE.NHEAD = 8
26
+ _CN.COARSE.LAYER_NAMES = ["self", "cross"] * 4
27
+ _CN.COARSE.ATTENTION = "linear" # options: ['linear', 'full']
28
  _CN.COARSE.TEMP_BUG_FIX = False
29
 
30
  # 3. Coarse-Matching config
31
  _CN.MATCH_COARSE = CN()
32
  _CN.MATCH_COARSE.THR = 0.1
33
  _CN.MATCH_COARSE.BORDER_RM = 2
34
+ _CN.MATCH_COARSE.MATCH_TYPE = "dual_softmax" # options: ['dual_softmax, 'sinkhorn']
35
  _CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
36
  _CN.MATCH_COARSE.SKH_ITERS = 3
37
  _CN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
 
44
  _CN.FINE.D_MODEL = 128
45
  _CN.FINE.D_FFN = 128
46
  _CN.FINE.NHEAD = 8
47
+ _CN.FINE.LAYER_NAMES = ["self", "cross"] * 1
48
+ _CN.FINE.ATTENTION = "linear"
49
 
50
  default_cfg = lower_config(_CN)
third_party/ASpanFormer/src/ASpanFormer/utils/fine_matching.py CHANGED
@@ -26,35 +26,46 @@ class FineMatching(nn.Module):
26
  """
27
  M, WW, C = feat_f0.shape
28
  W = int(math.sqrt(WW))
29
- scale = data['hw0_i'][0] / data['hw0_f'][0]
30
  self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale
31
 
32
  # corner case: if no coarse matches found
33
  if M == 0:
34
- assert self.training == False, "M is always >0, when training, see coarse_matching.py"
 
 
35
  # logger.warning('No matches found in coarse-level.')
36
- data.update({
37
- 'expec_f': torch.empty(0, 3, device=feat_f0.device),
38
- 'mkpts0_f': data['mkpts0_c'],
39
- 'mkpts1_f': data['mkpts1_c'],
40
- })
 
 
41
  return
42
 
43
- feat_f0_picked = feat_f0_picked = feat_f0[:, WW//2, :]
44
- sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1)
45
- softmax_temp = 1. / C**.5
46
  heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W)
47
 
48
  # compute coordinates from heatmap
49
  coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2]
50
- grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2]
 
 
51
 
52
  # compute std over <x, y>
53
- var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2]
54
- std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability
55
-
 
 
 
 
 
56
  # for fine-level supervision
57
- data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)})
58
 
59
  # compute absolute kpt coords
60
  self.get_fine_match(coords_normalized, data)
@@ -64,11 +75,10 @@ class FineMatching(nn.Module):
64
  W, WW, C, scale = self.W, self.WW, self.C, self.scale
65
 
66
  # mkpts0_f and mkpts1_f
67
- mkpts0_f = data['mkpts0_c']
68
- scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale
69
- mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])]
 
 
70
 
71
- data.update({
72
- "mkpts0_f": mkpts0_f,
73
- "mkpts1_f": mkpts1_f
74
- })
 
26
  """
27
  M, WW, C = feat_f0.shape
28
  W = int(math.sqrt(WW))
29
+ scale = data["hw0_i"][0] / data["hw0_f"][0]
30
  self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale
31
 
32
  # corner case: if no coarse matches found
33
  if M == 0:
34
+ assert (
35
+ self.training == False
36
+ ), "M is always >0, when training, see coarse_matching.py"
37
  # logger.warning('No matches found in coarse-level.')
38
+ data.update(
39
+ {
40
+ "expec_f": torch.empty(0, 3, device=feat_f0.device),
41
+ "mkpts0_f": data["mkpts0_c"],
42
+ "mkpts1_f": data["mkpts1_c"],
43
+ }
44
+ )
45
  return
46
 
47
+ feat_f0_picked = feat_f0_picked = feat_f0[:, WW // 2, :]
48
+ sim_matrix = torch.einsum("mc,mrc->mr", feat_f0_picked, feat_f1)
49
+ softmax_temp = 1.0 / C**0.5
50
  heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W)
51
 
52
  # compute coordinates from heatmap
53
  coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2]
54
+ grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(
55
+ 1, -1, 2
56
+ ) # [1, WW, 2]
57
 
58
  # compute std over <x, y>
59
+ var = (
60
+ torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1)
61
+ - coords_normalized**2
62
+ ) # [M, 2]
63
+ std = torch.sum(
64
+ torch.sqrt(torch.clamp(var, min=1e-10)), -1
65
+ ) # [M] clamp needed for numerical stability
66
+
67
  # for fine-level supervision
68
+ data.update({"expec_f": torch.cat([coords_normalized, std.unsqueeze(1)], -1)})
69
 
70
  # compute absolute kpt coords
71
  self.get_fine_match(coords_normalized, data)
 
75
  W, WW, C, scale = self.W, self.WW, self.C, self.scale
76
 
77
  # mkpts0_f and mkpts1_f
78
+ mkpts0_f = data["mkpts0_c"]
79
+ scale1 = scale * data["scale1"][data["b_ids"]] if "scale0" in data else scale
80
+ mkpts1_f = (
81
+ data["mkpts1_c"] + (coords_normed * (W // 2) * scale1)[: len(data["mconf"])]
82
+ )
83
 
84
+ data.update({"mkpts0_f": mkpts0_f, "mkpts1_f": mkpts1_f})
 
 
 
third_party/ASpanFormer/src/ASpanFormer/utils/geometry.py CHANGED
@@ -3,10 +3,10 @@ import torch
3
 
4
  @torch.no_grad()
5
  def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
6
- """ Warp kpts0 from I0 to I1 with depth, K and Rt
7
  Also check covisibility and depth consistency.
8
  Depth is consistent if relative error < 0.2 (hard-coded).
9
-
10
  Args:
11
  kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
12
  depth0 (torch.Tensor): [N, H, W],
@@ -22,33 +22,52 @@ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
22
 
23
  # Sample depth, get calculable_mask on depth != 0
24
  kpts0_depth = torch.stack(
25
- [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
 
 
 
 
26
  ) # (N, L)
27
  nonzero_mask = kpts0_depth != 0
28
 
29
  # Unproject
30
- kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3)
 
 
 
31
  kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
32
 
33
  # Rigid Transform
34
- w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
35
  w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
36
 
37
  # Project
38
  w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
39
- w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth
 
 
40
 
41
  # Covisible Check
42
  h, w = depth1.shape[1:3]
43
- covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
44
- (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
 
 
 
 
45
  w_kpts0_long = w_kpts0.long()
46
  w_kpts0_long[~covisible_mask, :] = 0
47
 
48
  w_kpts0_depth = torch.stack(
49
- [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
 
 
 
 
50
  ) # (N, L)
51
- consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
 
 
52
  valid_mask = nonzero_mask * covisible_mask * consistent_mask
53
 
54
  return valid_mask, w_kpts0
 
3
 
4
  @torch.no_grad()
5
  def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
6
+ """Warp kpts0 from I0 to I1 with depth, K and Rt
7
  Also check covisibility and depth consistency.
8
  Depth is consistent if relative error < 0.2 (hard-coded).
9
+
10
  Args:
11
  kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
12
  depth0 (torch.Tensor): [N, H, W],
 
22
 
23
  # Sample depth, get calculable_mask on depth != 0
24
  kpts0_depth = torch.stack(
25
+ [
26
+ depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]]
27
+ for i in range(kpts0.shape[0])
28
+ ],
29
+ dim=0,
30
  ) # (N, L)
31
  nonzero_mask = kpts0_depth != 0
32
 
33
  # Unproject
34
+ kpts0_h = (
35
+ torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
36
+ * kpts0_depth[..., None]
37
+ ) # (N, L, 3)
38
  kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
39
 
40
  # Rigid Transform
41
+ w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
42
  w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
43
 
44
  # Project
45
  w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
46
+ w_kpts0 = w_kpts0_h[:, :, :2] / (
47
+ w_kpts0_h[:, :, [2]] + 1e-4
48
+ ) # (N, L, 2), +1e-4 to avoid zero depth
49
 
50
  # Covisible Check
51
  h, w = depth1.shape[1:3]
52
+ covisible_mask = (
53
+ (w_kpts0[:, :, 0] > 0)
54
+ * (w_kpts0[:, :, 0] < w - 1)
55
+ * (w_kpts0[:, :, 1] > 0)
56
+ * (w_kpts0[:, :, 1] < h - 1)
57
+ )
58
  w_kpts0_long = w_kpts0.long()
59
  w_kpts0_long[~covisible_mask, :] = 0
60
 
61
  w_kpts0_depth = torch.stack(
62
+ [
63
+ depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]]
64
+ for i in range(w_kpts0_long.shape[0])
65
+ ],
66
+ dim=0,
67
  ) # (N, L)
68
+ consistent_mask = (
69
+ (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
70
+ ).abs() < 0.2
71
  valid_mask = nonzero_mask * covisible_mask * consistent_mask
72
 
73
  return valid_mask, w_kpts0
third_party/ASpanFormer/src/ASpanFormer/utils/position_encoding.py CHANGED
@@ -8,7 +8,7 @@ class PositionEncodingSine(nn.Module):
8
  This is a sinusoidal position encoding that generalized to 2-dimensional images
9
  """
10
 
11
- def __init__(self, d_model, max_shape=(256, 256),pre_scaling=None):
12
  """
13
  Args:
14
  max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
@@ -18,44 +18,63 @@ class PositionEncodingSine(nn.Module):
18
  We will remove the buggy impl after re-training all variants of our released models.
19
  """
20
  super().__init__()
21
- self.d_model=d_model
22
- self.max_shape=max_shape
23
- self.pre_scaling=pre_scaling
24
 
25
  pe = torch.zeros((d_model, *max_shape))
26
  y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
27
  x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
28
 
29
  if pre_scaling[0] is not None and pre_scaling[1] is not None:
30
- train_res,test_res=pre_scaling[0],pre_scaling[1]
31
- x_position,y_position=x_position*train_res[1]/test_res[1],y_position*train_res[0]/test_res[0]
 
 
 
32
 
33
- div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
 
 
 
34
  div_term = div_term[:, None, None] # [C//4, 1, 1]
35
  pe[0::4, :, :] = torch.sin(x_position * div_term)
36
  pe[1::4, :, :] = torch.cos(x_position * div_term)
37
  pe[2::4, :, :] = torch.sin(y_position * div_term)
38
  pe[3::4, :, :] = torch.cos(y_position * div_term)
39
 
40
- self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W]
41
 
42
- def forward(self, x,scaling=None):
43
  """
44
  Args:
45
  x: [N, C, H, W]
46
  """
47
- if scaling is None: #onliner scaling overwrites pre_scaling
48
- return x + self.pe[:, :, :x.size(2), :x.size(3)],self.pe[:, :, :x.size(2), :x.size(3)]
 
 
 
49
  else:
50
  pe = torch.zeros((self.d_model, *self.max_shape))
51
- y_position = torch.ones(self.max_shape).cumsum(0).float().unsqueeze(0)*scaling[0]
52
- x_position = torch.ones(self.max_shape).cumsum(1).float().unsqueeze(0)*scaling[1]
53
-
54
- div_term = torch.exp(torch.arange(0, self.d_model//2, 2).float() * (-math.log(10000.0) / (self.d_model//2)))
 
 
 
 
 
 
 
55
  div_term = div_term[:, None, None] # [C//4, 1, 1]
56
  pe[0::4, :, :] = torch.sin(x_position * div_term)
57
  pe[1::4, :, :] = torch.cos(x_position * div_term)
58
  pe[2::4, :, :] = torch.sin(y_position * div_term)
59
  pe[3::4, :, :] = torch.cos(y_position * div_term)
60
- pe=pe.unsqueeze(0).to(x.device)
61
- return x + pe[:, :, :x.size(2), :x.size(3)],pe[:, :, :x.size(2), :x.size(3)]
 
 
 
 
8
  This is a sinusoidal position encoding that generalized to 2-dimensional images
9
  """
10
 
11
+ def __init__(self, d_model, max_shape=(256, 256), pre_scaling=None):
12
  """
13
  Args:
14
  max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
 
18
  We will remove the buggy impl after re-training all variants of our released models.
19
  """
20
  super().__init__()
21
+ self.d_model = d_model
22
+ self.max_shape = max_shape
23
+ self.pre_scaling = pre_scaling
24
 
25
  pe = torch.zeros((d_model, *max_shape))
26
  y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
27
  x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
28
 
29
  if pre_scaling[0] is not None and pre_scaling[1] is not None:
30
+ train_res, test_res = pre_scaling[0], pre_scaling[1]
31
+ x_position, y_position = (
32
+ x_position * train_res[1] / test_res[1],
33
+ y_position * train_res[0] / test_res[0],
34
+ )
35
 
36
+ div_term = torch.exp(
37
+ torch.arange(0, d_model // 2, 2).float()
38
+ * (-math.log(10000.0) / (d_model // 2))
39
+ )
40
  div_term = div_term[:, None, None] # [C//4, 1, 1]
41
  pe[0::4, :, :] = torch.sin(x_position * div_term)
42
  pe[1::4, :, :] = torch.cos(x_position * div_term)
43
  pe[2::4, :, :] = torch.sin(y_position * div_term)
44
  pe[3::4, :, :] = torch.cos(y_position * div_term)
45
 
46
+ self.register_buffer("pe", pe.unsqueeze(0), persistent=False) # [1, C, H, W]
47
 
48
+ def forward(self, x, scaling=None):
49
  """
50
  Args:
51
  x: [N, C, H, W]
52
  """
53
+ if scaling is None: # onliner scaling overwrites pre_scaling
54
+ return (
55
+ x + self.pe[:, :, : x.size(2), : x.size(3)],
56
+ self.pe[:, :, : x.size(2), : x.size(3)],
57
+ )
58
  else:
59
  pe = torch.zeros((self.d_model, *self.max_shape))
60
+ y_position = (
61
+ torch.ones(self.max_shape).cumsum(0).float().unsqueeze(0) * scaling[0]
62
+ )
63
+ x_position = (
64
+ torch.ones(self.max_shape).cumsum(1).float().unsqueeze(0) * scaling[1]
65
+ )
66
+
67
+ div_term = torch.exp(
68
+ torch.arange(0, self.d_model // 2, 2).float()
69
+ * (-math.log(10000.0) / (self.d_model // 2))
70
+ )
71
  div_term = div_term[:, None, None] # [C//4, 1, 1]
72
  pe[0::4, :, :] = torch.sin(x_position * div_term)
73
  pe[1::4, :, :] = torch.cos(x_position * div_term)
74
  pe[2::4, :, :] = torch.sin(y_position * div_term)
75
  pe[3::4, :, :] = torch.cos(y_position * div_term)
76
+ pe = pe.unsqueeze(0).to(x.device)
77
+ return (
78
+ x + pe[:, :, : x.size(2), : x.size(3)],
79
+ pe[:, :, : x.size(2), : x.size(3)],
80
+ )
third_party/ASpanFormer/src/ASpanFormer/utils/supervision.py CHANGED
@@ -13,7 +13,7 @@ from .geometry import warp_kpts
13
  @torch.no_grad()
14
  def mask_pts_at_padded_regions(grid_pt, mask):
15
  """For megadepth dataset, zero-padding exists in images"""
16
- mask = repeat(mask, 'n h w -> n (h w) c', c=2)
17
  grid_pt[~mask.bool()] = 0
18
  return grid_pt
19
 
@@ -30,37 +30,55 @@ def spvs_coarse(data, config):
30
  'spv_w_pt0_i': [N, hw0, 2], in original image resolution
31
  'spv_pt1_i': [N, hw1, 2], in original image resolution
32
  }
33
-
34
  NOTE:
35
  - for scannet dataset, there're 3 kinds of resolution {i, c, f}
36
  - for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f}
37
  """
38
  # 1. misc
39
- device = data['image0'].device
40
- N, _, H0, W0 = data['image0'].shape
41
- _, _, H1, W1 = data['image1'].shape
42
- scale = config['ASPAN']['RESOLUTION'][0]
43
- scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
44
- scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale
45
  h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
46
 
47
  # 2. warp grids
48
  # create kpts in meshgrid and resize them to image resolution
49
- grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2]
 
 
50
  grid_pt0_i = scale0 * grid_pt0_c
51
- grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1)
 
 
52
  grid_pt1_i = scale1 * grid_pt1_c
53
 
54
  # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt
55
- if 'mask0' in data:
56
- grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0'])
57
- grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1'])
58
 
59
  # warp kpts bi-directionally and resize them to coarse-level resolution
60
  # (no depth consistency check, since it leads to worse results experimentally)
61
  # (unhandled edge case: points with 0-depth will be warped to the left-up corner)
62
- _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'])
63
- _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  w_pt0_c = w_pt0_i / scale1
65
  w_pt1_c = w_pt1_i / scale0
66
 
@@ -72,21 +90,26 @@ def spvs_coarse(data, config):
72
 
73
  # corner case: out of boundary
74
  def out_bound_mask(pt, w, h):
75
- return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
 
 
 
76
  nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0
77
  nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0
78
 
79
- loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0)
80
- correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1)
 
 
81
  correct_0to1[:, 0] = False # ignore the top-left corner
82
 
83
  # 4. construct a gt conf_matrix
84
- conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device)
85
  b_ids, i_ids = torch.where(correct_0to1 != 0)
86
  j_ids = nearest_index1[b_ids, i_ids]
87
 
88
  conf_matrix_gt[b_ids, i_ids, j_ids] = 1
89
- data.update({'conf_matrix_gt': conf_matrix_gt})
90
 
91
  # 5. save coarse matches(gt) for training fine level
92
  if len(b_ids) == 0:
@@ -96,30 +119,26 @@ def spvs_coarse(data, config):
96
  i_ids = torch.tensor([0], device=device)
97
  j_ids = torch.tensor([0], device=device)
98
 
99
- data.update({
100
- 'spv_b_ids': b_ids,
101
- 'spv_i_ids': i_ids,
102
- 'spv_j_ids': j_ids
103
- })
104
 
105
  # 6. save intermediate results (for fast fine-level computation)
106
- data.update({
107
- 'spv_w_pt0_i': w_pt0_i,
108
- 'spv_pt1_i': grid_pt1_i
109
- })
110
 
111
 
112
  def compute_supervision_coarse(data, config):
113
- assert len(set(data['dataset_name'])) == 1, "Do not support mixed datasets training!"
114
- data_source = data['dataset_name'][0]
115
- if data_source.lower() in ['scannet', 'megadepth']:
 
 
116
  spvs_coarse(data, config)
117
  else:
118
- raise ValueError(f'Unknown data source: {data_source}')
119
 
120
 
121
  ############## ↓ Fine-Level supervision ↓ ##############
122
 
 
123
  @torch.no_grad()
124
  def spvs_fine(data, config):
125
  """
@@ -129,23 +148,25 @@ def spvs_fine(data, config):
129
  """
130
  # 1. misc
131
  # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i')
132
- w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i']
133
- scale = config['ASPAN']['RESOLUTION'][1]
134
- radius = config['ASPAN']['FINE_WINDOW_SIZE'] // 2
135
 
136
  # 2. get coarse prediction
137
- b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids']
138
 
139
  # 3. compute gt
140
- scale = scale * data['scale1'][b_ids] if 'scale0' in data else scale
141
  # `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later
142
- expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius # [M, 2]
 
 
143
  data.update({"expec_f_gt": expec_f_gt})
144
 
145
 
146
  def compute_supervision_fine(data, config):
147
- data_source = data['dataset_name'][0]
148
- if data_source.lower() in ['scannet', 'megadepth']:
149
  spvs_fine(data, config)
150
  else:
151
  raise NotImplementedError
 
13
  @torch.no_grad()
14
  def mask_pts_at_padded_regions(grid_pt, mask):
15
  """For megadepth dataset, zero-padding exists in images"""
16
+ mask = repeat(mask, "n h w -> n (h w) c", c=2)
17
  grid_pt[~mask.bool()] = 0
18
  return grid_pt
19
 
 
30
  'spv_w_pt0_i': [N, hw0, 2], in original image resolution
31
  'spv_pt1_i': [N, hw1, 2], in original image resolution
32
  }
33
+
34
  NOTE:
35
  - for scannet dataset, there're 3 kinds of resolution {i, c, f}
36
  - for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f}
37
  """
38
  # 1. misc
39
+ device = data["image0"].device
40
+ N, _, H0, W0 = data["image0"].shape
41
+ _, _, H1, W1 = data["image1"].shape
42
+ scale = config["ASPAN"]["RESOLUTION"][0]
43
+ scale0 = scale * data["scale0"][:, None] if "scale0" in data else scale
44
+ scale1 = scale * data["scale1"][:, None] if "scale0" in data else scale
45
  h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
46
 
47
  # 2. warp grids
48
  # create kpts in meshgrid and resize them to image resolution
49
+ grid_pt0_c = (
50
+ create_meshgrid(h0, w0, False, device).reshape(1, h0 * w0, 2).repeat(N, 1, 1)
51
+ ) # [N, hw, 2]
52
  grid_pt0_i = scale0 * grid_pt0_c
53
+ grid_pt1_c = (
54
+ create_meshgrid(h1, w1, False, device).reshape(1, h1 * w1, 2).repeat(N, 1, 1)
55
+ )
56
  grid_pt1_i = scale1 * grid_pt1_c
57
 
58
  # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt
59
+ if "mask0" in data:
60
+ grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data["mask0"])
61
+ grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data["mask1"])
62
 
63
  # warp kpts bi-directionally and resize them to coarse-level resolution
64
  # (no depth consistency check, since it leads to worse results experimentally)
65
  # (unhandled edge case: points with 0-depth will be warped to the left-up corner)
66
+ _, w_pt0_i = warp_kpts(
67
+ grid_pt0_i,
68
+ data["depth0"],
69
+ data["depth1"],
70
+ data["T_0to1"],
71
+ data["K0"],
72
+ data["K1"],
73
+ )
74
+ _, w_pt1_i = warp_kpts(
75
+ grid_pt1_i,
76
+ data["depth1"],
77
+ data["depth0"],
78
+ data["T_1to0"],
79
+ data["K1"],
80
+ data["K0"],
81
+ )
82
  w_pt0_c = w_pt0_i / scale1
83
  w_pt1_c = w_pt1_i / scale0
84
 
 
90
 
91
  # corner case: out of boundary
92
  def out_bound_mask(pt, w, h):
93
+ return (
94
+ (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
95
+ )
96
+
97
  nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0
98
  nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0
99
 
100
+ loop_back = torch.stack(
101
+ [nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0
102
+ )
103
+ correct_0to1 = loop_back == torch.arange(h0 * w0, device=device)[None].repeat(N, 1)
104
  correct_0to1[:, 0] = False # ignore the top-left corner
105
 
106
  # 4. construct a gt conf_matrix
107
+ conf_matrix_gt = torch.zeros(N, h0 * w0, h1 * w1, device=device)
108
  b_ids, i_ids = torch.where(correct_0to1 != 0)
109
  j_ids = nearest_index1[b_ids, i_ids]
110
 
111
  conf_matrix_gt[b_ids, i_ids, j_ids] = 1
112
+ data.update({"conf_matrix_gt": conf_matrix_gt})
113
 
114
  # 5. save coarse matches(gt) for training fine level
115
  if len(b_ids) == 0:
 
119
  i_ids = torch.tensor([0], device=device)
120
  j_ids = torch.tensor([0], device=device)
121
 
122
+ data.update({"spv_b_ids": b_ids, "spv_i_ids": i_ids, "spv_j_ids": j_ids})
 
 
 
 
123
 
124
  # 6. save intermediate results (for fast fine-level computation)
125
+ data.update({"spv_w_pt0_i": w_pt0_i, "spv_pt1_i": grid_pt1_i})
 
 
 
126
 
127
 
128
  def compute_supervision_coarse(data, config):
129
+ assert (
130
+ len(set(data["dataset_name"])) == 1
131
+ ), "Do not support mixed datasets training!"
132
+ data_source = data["dataset_name"][0]
133
+ if data_source.lower() in ["scannet", "megadepth"]:
134
  spvs_coarse(data, config)
135
  else:
136
+ raise ValueError(f"Unknown data source: {data_source}")
137
 
138
 
139
  ############## ↓ Fine-Level supervision ↓ ##############
140
 
141
+
142
  @torch.no_grad()
143
  def spvs_fine(data, config):
144
  """
 
148
  """
149
  # 1. misc
150
  # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i')
151
+ w_pt0_i, pt1_i = data["spv_w_pt0_i"], data["spv_pt1_i"]
152
+ scale = config["ASPAN"]["RESOLUTION"][1]
153
+ radius = config["ASPAN"]["FINE_WINDOW_SIZE"] // 2
154
 
155
  # 2. get coarse prediction
156
+ b_ids, i_ids, j_ids = data["b_ids"], data["i_ids"], data["j_ids"]
157
 
158
  # 3. compute gt
159
+ scale = scale * data["scale1"][b_ids] if "scale0" in data else scale
160
  # `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later
161
+ expec_f_gt = (
162
+ (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius
163
+ ) # [M, 2]
164
  data.update({"expec_f_gt": expec_f_gt})
165
 
166
 
167
  def compute_supervision_fine(data, config):
168
+ data_source = data["dataset_name"][0]
169
+ if data_source.lower() in ["scannet", "megadepth"]:
170
  spvs_fine(data, config)
171
  else:
172
  raise NotImplementedError
third_party/ASpanFormer/src/config/default.py CHANGED
@@ -1,9 +1,10 @@
1
  from yacs.config import CfgNode as CN
 
2
  _CN = CN()
3
 
4
  ############## ↓ ASPAN Pipeline ↓ ##############
5
  _CN.ASPAN = CN()
6
- _CN.ASPAN.BACKBONE_TYPE = 'ResNetFPN'
7
  _CN.ASPAN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)]
8
  _CN.ASPAN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
9
  _CN.ASPAN.FINE_CONCAT_COARSE_FEAT = True
@@ -17,14 +18,14 @@ _CN.ASPAN.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3
17
  _CN.ASPAN.COARSE = CN()
18
  _CN.ASPAN.COARSE.D_MODEL = 256
19
  _CN.ASPAN.COARSE.D_FFN = 256
20
- _CN.ASPAN.COARSE.D_FLOW= 128
21
  _CN.ASPAN.COARSE.NHEAD = 8
22
- _CN.ASPAN.COARSE.NLEVEL= 3
23
- _CN.ASPAN.COARSE.INI_LAYER_NUM = 2
24
- _CN.ASPAN.COARSE.LAYER_NUM = 4
25
- _CN.ASPAN.COARSE.NSAMPLE = [2,8]
26
- _CN.ASPAN.COARSE.RADIUS_SCALE= 5
27
- _CN.ASPAN.COARSE.COARSEST_LEVEL= [26,26]
28
  _CN.ASPAN.COARSE.TRAIN_RES = None
29
  _CN.ASPAN.COARSE.TEST_RES = None
30
 
@@ -32,7 +33,9 @@ _CN.ASPAN.COARSE.TEST_RES = None
32
  _CN.ASPAN.MATCH_COARSE = CN()
33
  _CN.ASPAN.MATCH_COARSE.THR = 0.2
34
  _CN.ASPAN.MATCH_COARSE.BORDER_RM = 2
35
- _CN.ASPAN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax, 'sinkhorn']
 
 
36
  _CN.ASPAN.MATCH_COARSE.SKH_ITERS = 3
37
  _CN.ASPAN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
38
  _CN.ASPAN.MATCH_COARSE.SKH_PREFILTER = False
@@ -46,13 +49,13 @@ _CN.ASPAN.FINE = CN()
46
  _CN.ASPAN.FINE.D_MODEL = 128
47
  _CN.ASPAN.FINE.D_FFN = 128
48
  _CN.ASPAN.FINE.NHEAD = 8
49
- _CN.ASPAN.FINE.LAYER_NAMES = ['self', 'cross'] * 1
50
- _CN.ASPAN.FINE.ATTENTION = 'linear'
51
 
52
  # 5. ASPAN Losses
53
  # -- # coarse-level
54
  _CN.ASPAN.LOSS = CN()
55
- _CN.ASPAN.LOSS.COARSE_TYPE = 'focal' # ['focal', 'cross_entropy']
56
  _CN.ASPAN.LOSS.COARSE_WEIGHT = 1.0
57
  # _CN.ASPAN.LOSS.SPARSE_SPVS = False
58
  # -- - -- # focal loss (coarse)
@@ -64,7 +67,7 @@ _CN.ASPAN.LOSS.NEG_WEIGHT = 1.0
64
  # use `_CN.ASPAN.MATCH_COARSE.MATCH_TYPE`
65
 
66
  # -- # fine-level
67
- _CN.ASPAN.LOSS.FINE_TYPE = 'l2_with_std' # ['l2_with_std', 'l2']
68
  _CN.ASPAN.LOSS.FINE_WEIGHT = 1.0
69
  _CN.ASPAN.LOSS.FINE_CORRECT_THR = 1.0 # for filtering valid fine-level gts (some gt matches might fall out of the fine-level window)
70
 
@@ -85,24 +88,32 @@ _CN.DATASET.TRAIN_INTRINSIC_PATH = None
85
  _CN.DATASET.VAL_DATA_ROOT = None
86
  _CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses)
87
  _CN.DATASET.VAL_NPZ_ROOT = None
88
- _CN.DATASET.VAL_LIST_PATH = None # None if val data from all scenes are bundled into a single npz file
 
 
89
  _CN.DATASET.VAL_INTRINSIC_PATH = None
90
  # testing
91
  _CN.DATASET.TEST_DATA_SOURCE = None
92
  _CN.DATASET.TEST_DATA_ROOT = None
93
  _CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses)
94
  _CN.DATASET.TEST_NPZ_ROOT = None
95
- _CN.DATASET.TEST_LIST_PATH = None # None if test data from all scenes are bundled into a single npz file
 
 
96
  _CN.DATASET.TEST_INTRINSIC_PATH = None
97
 
98
  # 2. dataset config
99
  # general options
100
- _CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 # discard data with overlap_score < min_overlap_score
 
 
101
  _CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
102
  _CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'dark', 'mobile']
103
 
104
  # MegaDepth options
105
- _CN.DATASET.MGDPT_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square.
 
 
106
  _CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE
107
  _CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000
108
  _CN.DATASET.MGDPT_DF = 8
@@ -118,17 +129,17 @@ _CN.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning
118
  # optimizer
119
  _CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw]
120
  _CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime
121
- _CN.TRAINER.ADAM_DECAY = 0. # ADAM: for adam
122
  _CN.TRAINER.ADAMW_DECAY = 0.1
123
 
124
  # step-based warm-up
125
- _CN.TRAINER.WARMUP_TYPE = 'linear' # [linear, constant]
126
- _CN.TRAINER.WARMUP_RATIO = 0.
127
  _CN.TRAINER.WARMUP_STEP = 4800
128
 
129
  # learning rate scheduler
130
- _CN.TRAINER.SCHEDULER = 'MultiStepLR' # [MultiStepLR, CosineAnnealing, ExponentialLR]
131
- _CN.TRAINER.SCHEDULER_INTERVAL = 'epoch' # [epoch, step]
132
  _CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR
133
  _CN.TRAINER.MSLR_GAMMA = 0.5
134
  _CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing
@@ -136,25 +147,33 @@ _CN.TRAINER.ELR_GAMMA = 0.999992 # ELR: ExponentialLR, this value for 'step' in
136
 
137
  # plotting related
138
  _CN.TRAINER.ENABLE_PLOTTING = True
139
- _CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32 # number of val/test paris for plotting
140
- _CN.TRAINER.PLOT_MODE = 'evaluation' # ['evaluation', 'confidence']
141
- _CN.TRAINER.PLOT_MATCHES_ALPHA = 'dynamic'
142
 
143
  # geometric metrics and pose solver
144
- _CN.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
145
- _CN.TRAINER.POSE_GEO_MODEL = 'E' # ['E', 'F', 'H']
146
- _CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC' # [RANSAC, DEGENSAC, MAGSAC]
 
 
147
  _CN.TRAINER.RANSAC_PIXEL_THR = 0.5
148
  _CN.TRAINER.RANSAC_CONF = 0.99999
149
  _CN.TRAINER.RANSAC_MAX_ITERS = 10000
150
  _CN.TRAINER.USE_MAGSACPP = False
151
 
152
  # data sampler for train_dataloader
153
- _CN.TRAINER.DATA_SAMPLER = 'scene_balance' # options: ['scene_balance', 'random', 'normal']
 
 
154
  # 'scene_balance' config
155
  _CN.TRAINER.N_SAMPLES_PER_SUBSET = 200
156
- _CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True # whether sample each scene with replacement or not
157
- _CN.TRAINER.SB_SUBSET_SHUFFLE = True # after sampling from scenes, whether shuffle within the epoch or not
 
 
 
 
158
  _CN.TRAINER.SB_REPEAT = 1 # repeat N times for training the sampled data
159
  # 'random' config
160
  _CN.TRAINER.RDM_REPLACEMENT = True
 
1
  from yacs.config import CfgNode as CN
2
+
3
  _CN = CN()
4
 
5
  ############## ↓ ASPAN Pipeline ↓ ##############
6
  _CN.ASPAN = CN()
7
+ _CN.ASPAN.BACKBONE_TYPE = "ResNetFPN"
8
  _CN.ASPAN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)]
9
  _CN.ASPAN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
10
  _CN.ASPAN.FINE_CONCAT_COARSE_FEAT = True
 
18
  _CN.ASPAN.COARSE = CN()
19
  _CN.ASPAN.COARSE.D_MODEL = 256
20
  _CN.ASPAN.COARSE.D_FFN = 256
21
+ _CN.ASPAN.COARSE.D_FLOW = 128
22
  _CN.ASPAN.COARSE.NHEAD = 8
23
+ _CN.ASPAN.COARSE.NLEVEL = 3
24
+ _CN.ASPAN.COARSE.INI_LAYER_NUM = 2
25
+ _CN.ASPAN.COARSE.LAYER_NUM = 4
26
+ _CN.ASPAN.COARSE.NSAMPLE = [2, 8]
27
+ _CN.ASPAN.COARSE.RADIUS_SCALE = 5
28
+ _CN.ASPAN.COARSE.COARSEST_LEVEL = [26, 26]
29
  _CN.ASPAN.COARSE.TRAIN_RES = None
30
  _CN.ASPAN.COARSE.TEST_RES = None
31
 
 
33
  _CN.ASPAN.MATCH_COARSE = CN()
34
  _CN.ASPAN.MATCH_COARSE.THR = 0.2
35
  _CN.ASPAN.MATCH_COARSE.BORDER_RM = 2
36
+ _CN.ASPAN.MATCH_COARSE.MATCH_TYPE = (
37
+ "dual_softmax" # options: ['dual_softmax, 'sinkhorn']
38
+ )
39
  _CN.ASPAN.MATCH_COARSE.SKH_ITERS = 3
40
  _CN.ASPAN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
41
  _CN.ASPAN.MATCH_COARSE.SKH_PREFILTER = False
 
49
  _CN.ASPAN.FINE.D_MODEL = 128
50
  _CN.ASPAN.FINE.D_FFN = 128
51
  _CN.ASPAN.FINE.NHEAD = 8
52
+ _CN.ASPAN.FINE.LAYER_NAMES = ["self", "cross"] * 1
53
+ _CN.ASPAN.FINE.ATTENTION = "linear"
54
 
55
  # 5. ASPAN Losses
56
  # -- # coarse-level
57
  _CN.ASPAN.LOSS = CN()
58
+ _CN.ASPAN.LOSS.COARSE_TYPE = "focal" # ['focal', 'cross_entropy']
59
  _CN.ASPAN.LOSS.COARSE_WEIGHT = 1.0
60
  # _CN.ASPAN.LOSS.SPARSE_SPVS = False
61
  # -- - -- # focal loss (coarse)
 
67
  # use `_CN.ASPAN.MATCH_COARSE.MATCH_TYPE`
68
 
69
  # -- # fine-level
70
+ _CN.ASPAN.LOSS.FINE_TYPE = "l2_with_std" # ['l2_with_std', 'l2']
71
  _CN.ASPAN.LOSS.FINE_WEIGHT = 1.0
72
  _CN.ASPAN.LOSS.FINE_CORRECT_THR = 1.0 # for filtering valid fine-level gts (some gt matches might fall out of the fine-level window)
73
 
 
88
  _CN.DATASET.VAL_DATA_ROOT = None
89
  _CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses)
90
  _CN.DATASET.VAL_NPZ_ROOT = None
91
+ _CN.DATASET.VAL_LIST_PATH = (
92
+ None # None if val data from all scenes are bundled into a single npz file
93
+ )
94
  _CN.DATASET.VAL_INTRINSIC_PATH = None
95
  # testing
96
  _CN.DATASET.TEST_DATA_SOURCE = None
97
  _CN.DATASET.TEST_DATA_ROOT = None
98
  _CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses)
99
  _CN.DATASET.TEST_NPZ_ROOT = None
100
+ _CN.DATASET.TEST_LIST_PATH = (
101
+ None # None if test data from all scenes are bundled into a single npz file
102
+ )
103
  _CN.DATASET.TEST_INTRINSIC_PATH = None
104
 
105
  # 2. dataset config
106
  # general options
107
+ _CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = (
108
+ 0.4 # discard data with overlap_score < min_overlap_score
109
+ )
110
  _CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
111
  _CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'dark', 'mobile']
112
 
113
  # MegaDepth options
114
+ _CN.DATASET.MGDPT_IMG_RESIZE = (
115
+ 640 # resize the longer side, zero-pad bottom-right to square.
116
+ )
117
  _CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE
118
  _CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000
119
  _CN.DATASET.MGDPT_DF = 8
 
129
  # optimizer
130
  _CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw]
131
  _CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime
132
+ _CN.TRAINER.ADAM_DECAY = 0.0 # ADAM: for adam
133
  _CN.TRAINER.ADAMW_DECAY = 0.1
134
 
135
  # step-based warm-up
136
+ _CN.TRAINER.WARMUP_TYPE = "linear" # [linear, constant]
137
+ _CN.TRAINER.WARMUP_RATIO = 0.0
138
  _CN.TRAINER.WARMUP_STEP = 4800
139
 
140
  # learning rate scheduler
141
+ _CN.TRAINER.SCHEDULER = "MultiStepLR" # [MultiStepLR, CosineAnnealing, ExponentialLR]
142
+ _CN.TRAINER.SCHEDULER_INTERVAL = "epoch" # [epoch, step]
143
  _CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR
144
  _CN.TRAINER.MSLR_GAMMA = 0.5
145
  _CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing
 
147
 
148
  # plotting related
149
  _CN.TRAINER.ENABLE_PLOTTING = True
150
+ _CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32 # number of val/test paris for plotting
151
+ _CN.TRAINER.PLOT_MODE = "evaluation" # ['evaluation', 'confidence']
152
+ _CN.TRAINER.PLOT_MATCHES_ALPHA = "dynamic"
153
 
154
  # geometric metrics and pose solver
155
+ _CN.TRAINER.EPI_ERR_THR = (
156
+ 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
157
+ )
158
+ _CN.TRAINER.POSE_GEO_MODEL = "E" # ['E', 'F', 'H']
159
+ _CN.TRAINER.POSE_ESTIMATION_METHOD = "RANSAC" # [RANSAC, DEGENSAC, MAGSAC]
160
  _CN.TRAINER.RANSAC_PIXEL_THR = 0.5
161
  _CN.TRAINER.RANSAC_CONF = 0.99999
162
  _CN.TRAINER.RANSAC_MAX_ITERS = 10000
163
  _CN.TRAINER.USE_MAGSACPP = False
164
 
165
  # data sampler for train_dataloader
166
+ _CN.TRAINER.DATA_SAMPLER = (
167
+ "scene_balance" # options: ['scene_balance', 'random', 'normal']
168
+ )
169
  # 'scene_balance' config
170
  _CN.TRAINER.N_SAMPLES_PER_SUBSET = 200
171
+ _CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = (
172
+ True # whether sample each scene with replacement or not
173
+ )
174
+ _CN.TRAINER.SB_SUBSET_SHUFFLE = (
175
+ True # after sampling from scenes, whether shuffle within the epoch or not
176
+ )
177
  _CN.TRAINER.SB_REPEAT = 1 # repeat N times for training the sampled data
178
  # 'random' config
179
  _CN.TRAINER.RDM_REPLACEMENT = True
third_party/ASpanFormer/src/datasets/__init__.py CHANGED
@@ -1,3 +1,2 @@
1
  from .scannet import ScanNetDataset
2
  from .megadepth import MegaDepthDataset
3
-
 
1
  from .scannet import ScanNetDataset
2
  from .megadepth import MegaDepthDataset
 
third_party/ASpanFormer/src/datasets/megadepth.py CHANGED
@@ -9,20 +9,22 @@ from src.utils.dataset import read_megadepth_gray, read_megadepth_depth
9
 
10
 
11
  class MegaDepthDataset(Dataset):
12
- def __init__(self,
13
- root_dir,
14
- npz_path,
15
- mode='train',
16
- min_overlap_score=0.4,
17
- img_resize=None,
18
- df=None,
19
- img_padding=False,
20
- depth_padding=False,
21
- augment_fn=None,
22
- **kwargs):
 
 
23
  """
24
  Manage one scene(npz_path) of MegaDepth dataset.
25
-
26
  Args:
27
  root_dir (str): megadepth root directory that has `phoenix`.
28
  npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
@@ -38,28 +40,36 @@ class MegaDepthDataset(Dataset):
38
  super().__init__()
39
  self.root_dir = root_dir
40
  self.mode = mode
41
- self.scene_id = npz_path.split('.')[0]
42
 
43
  # prepare scene_info and pair_info
44
- if mode == 'test' and min_overlap_score != 0:
45
- logger.warning("You are using `min_overlap_score`!=0 in test mode. Set to 0.")
 
 
46
  min_overlap_score = 0
47
  self.scene_info = np.load(npz_path, allow_pickle=True)
48
- self.pair_infos = self.scene_info['pair_infos'].copy()
49
- del self.scene_info['pair_infos']
50
- self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score]
 
 
 
 
51
 
52
  # parameters for image resizing, padding and depthmap padding
53
- if mode == 'train':
54
  assert img_resize is not None and img_padding and depth_padding
55
  self.img_resize = img_resize
56
  self.df = df
57
  self.img_padding = img_padding
58
- self.depth_max_size = 2000 if depth_padding else None # the upperbound of depthmaps size in megadepth.
 
 
59
 
60
  # for training LoFTR
61
- self.augment_fn = augment_fn if mode == 'train' else None
62
- self.coarse_scale = getattr(kwargs, 'coarse_scale', 0.125)
63
 
64
  def __len__(self):
65
  return len(self.pair_infos)
@@ -68,60 +78,77 @@ class MegaDepthDataset(Dataset):
68
  (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx]
69
 
70
  # read grayscale image and mask. (1, h, w) and (h, w)
71
- img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0])
72
- img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1])
73
-
74
  # TODO: Support augmentation & handle seeds for each worker correctly.
75
  image0, mask0, scale0 = read_megadepth_gray(
76
- img_name0, self.img_resize, self.df, self.img_padding, None)
77
- # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
 
78
  image1, mask1, scale1 = read_megadepth_gray(
79
- img_name1, self.img_resize, self.df, self.img_padding, None)
80
- # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
 
81
 
82
  # read depth. shape: (h, w)
83
- if self.mode in ['train', 'val']:
84
  depth0 = read_megadepth_depth(
85
- osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size)
 
 
86
  depth1 = read_megadepth_depth(
87
- osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size)
 
 
88
  else:
89
  depth0 = depth1 = torch.tensor([])
90
 
91
  # read intrinsics of original size
92
- K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3)
93
- K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3)
 
 
 
 
94
 
95
  # read and compute relative poses
96
- T0 = self.scene_info['poses'][idx0]
97
- T1 = self.scene_info['poses'][idx1]
98
- T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4)
 
 
99
  T_1to0 = T_0to1.inverse()
100
 
101
  data = {
102
- 'image0': image0, # (1, h, w)
103
- 'depth0': depth0, # (h, w)
104
- 'image1': image1,
105
- 'depth1': depth1,
106
- 'T_0to1': T_0to1, # (4, 4)
107
- 'T_1to0': T_1to0,
108
- 'K0': K_0, # (3, 3)
109
- 'K1': K_1,
110
- 'scale0': scale0, # [scale_w, scale_h]
111
- 'scale1': scale1,
112
- 'dataset_name': 'MegaDepth',
113
- 'scene_id': self.scene_id,
114
- 'pair_id': idx,
115
- 'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]),
 
 
 
116
  }
117
 
118
  # for LoFTR training
119
  if mask0 is not None: # img_padding is True
120
  if self.coarse_scale:
121
- [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
122
- scale_factor=self.coarse_scale,
123
- mode='nearest',
124
- recompute_scale_factor=False)[0].bool()
125
- data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1})
 
 
126
 
127
  return data
 
9
 
10
 
11
  class MegaDepthDataset(Dataset):
12
+ def __init__(
13
+ self,
14
+ root_dir,
15
+ npz_path,
16
+ mode="train",
17
+ min_overlap_score=0.4,
18
+ img_resize=None,
19
+ df=None,
20
+ img_padding=False,
21
+ depth_padding=False,
22
+ augment_fn=None,
23
+ **kwargs
24
+ ):
25
  """
26
  Manage one scene(npz_path) of MegaDepth dataset.
27
+
28
  Args:
29
  root_dir (str): megadepth root directory that has `phoenix`.
30
  npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
 
40
  super().__init__()
41
  self.root_dir = root_dir
42
  self.mode = mode
43
+ self.scene_id = npz_path.split(".")[0]
44
 
45
  # prepare scene_info and pair_info
46
+ if mode == "test" and min_overlap_score != 0:
47
+ logger.warning(
48
+ "You are using `min_overlap_score`!=0 in test mode. Set to 0."
49
+ )
50
  min_overlap_score = 0
51
  self.scene_info = np.load(npz_path, allow_pickle=True)
52
+ self.pair_infos = self.scene_info["pair_infos"].copy()
53
+ del self.scene_info["pair_infos"]
54
+ self.pair_infos = [
55
+ pair_info
56
+ for pair_info in self.pair_infos
57
+ if pair_info[1] > min_overlap_score
58
+ ]
59
 
60
  # parameters for image resizing, padding and depthmap padding
61
+ if mode == "train":
62
  assert img_resize is not None and img_padding and depth_padding
63
  self.img_resize = img_resize
64
  self.df = df
65
  self.img_padding = img_padding
66
+ self.depth_max_size = (
67
+ 2000 if depth_padding else None
68
+ ) # the upperbound of depthmaps size in megadepth.
69
 
70
  # for training LoFTR
71
+ self.augment_fn = augment_fn if mode == "train" else None
72
+ self.coarse_scale = getattr(kwargs, "coarse_scale", 0.125)
73
 
74
  def __len__(self):
75
  return len(self.pair_infos)
 
78
  (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx]
79
 
80
  # read grayscale image and mask. (1, h, w) and (h, w)
81
+ img_name0 = osp.join(self.root_dir, self.scene_info["image_paths"][idx0])
82
+ img_name1 = osp.join(self.root_dir, self.scene_info["image_paths"][idx1])
83
+
84
  # TODO: Support augmentation & handle seeds for each worker correctly.
85
  image0, mask0, scale0 = read_megadepth_gray(
86
+ img_name0, self.img_resize, self.df, self.img_padding, None
87
+ )
88
+ # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
89
  image1, mask1, scale1 = read_megadepth_gray(
90
+ img_name1, self.img_resize, self.df, self.img_padding, None
91
+ )
92
+ # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
93
 
94
  # read depth. shape: (h, w)
95
+ if self.mode in ["train", "val"]:
96
  depth0 = read_megadepth_depth(
97
+ osp.join(self.root_dir, self.scene_info["depth_paths"][idx0]),
98
+ pad_to=self.depth_max_size,
99
+ )
100
  depth1 = read_megadepth_depth(
101
+ osp.join(self.root_dir, self.scene_info["depth_paths"][idx1]),
102
+ pad_to=self.depth_max_size,
103
+ )
104
  else:
105
  depth0 = depth1 = torch.tensor([])
106
 
107
  # read intrinsics of original size
108
+ K_0 = torch.tensor(
109
+ self.scene_info["intrinsics"][idx0].copy(), dtype=torch.float
110
+ ).reshape(3, 3)
111
+ K_1 = torch.tensor(
112
+ self.scene_info["intrinsics"][idx1].copy(), dtype=torch.float
113
+ ).reshape(3, 3)
114
 
115
  # read and compute relative poses
116
+ T0 = self.scene_info["poses"][idx0]
117
+ T1 = self.scene_info["poses"][idx1]
118
+ T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[
119
+ :4, :4
120
+ ] # (4, 4)
121
  T_1to0 = T_0to1.inverse()
122
 
123
  data = {
124
+ "image0": image0, # (1, h, w)
125
+ "depth0": depth0, # (h, w)
126
+ "image1": image1,
127
+ "depth1": depth1,
128
+ "T_0to1": T_0to1, # (4, 4)
129
+ "T_1to0": T_1to0,
130
+ "K0": K_0, # (3, 3)
131
+ "K1": K_1,
132
+ "scale0": scale0, # [scale_w, scale_h]
133
+ "scale1": scale1,
134
+ "dataset_name": "MegaDepth",
135
+ "scene_id": self.scene_id,
136
+ "pair_id": idx,
137
+ "pair_names": (
138
+ self.scene_info["image_paths"][idx0],
139
+ self.scene_info["image_paths"][idx1],
140
+ ),
141
  }
142
 
143
  # for LoFTR training
144
  if mask0 is not None: # img_padding is True
145
  if self.coarse_scale:
146
+ [ts_mask_0, ts_mask_1] = F.interpolate(
147
+ torch.stack([mask0, mask1], dim=0)[None].float(),
148
+ scale_factor=self.coarse_scale,
149
+ mode="nearest",
150
+ recompute_scale_factor=False,
151
+ )[0].bool()
152
+ data.update({"mask0": ts_mask_0, "mask1": ts_mask_1})
153
 
154
  return data
third_party/ASpanFormer/src/datasets/sampler.py CHANGED
@@ -3,10 +3,10 @@ from torch.utils.data import Sampler, ConcatDataset
3
 
4
 
5
  class RandomConcatSampler(Sampler):
6
- """ Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset
7
  in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement.
8
  However, it is impossible to sample data without replacement between epochs, unless bulding a stateful sampler lived along the entire training phase.
9
-
10
  For current implementation, the randomness of sampling is ensured no matter the sampler is recreated across epochs or not and call `torch.manual_seed()` or not.
11
  Args:
12
  shuffle (bool): shuffle the random sampled indices across all sub-datsets.
@@ -18,16 +18,19 @@ class RandomConcatSampler(Sampler):
18
  TODO: Add a `set_epoch()` method to fullfill sampling without replacement across epochs.
19
  ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
20
  """
21
- def __init__(self,
22
- data_source: ConcatDataset,
23
- n_samples_per_subset: int,
24
- subset_replacement: bool=True,
25
- shuffle: bool=True,
26
- repeat: int=1,
27
- seed: int=None):
 
 
 
28
  if not isinstance(data_source, ConcatDataset):
29
  raise TypeError("data_source should be torch.utils.data.ConcatDataset")
30
-
31
  self.data_source = data_source
32
  self.n_subset = len(self.data_source.datasets)
33
  self.n_samples_per_subset = n_samples_per_subset
@@ -37,27 +40,37 @@ class RandomConcatSampler(Sampler):
37
  self.shuffle = shuffle
38
  self.generator = torch.manual_seed(seed)
39
  assert self.repeat >= 1
40
-
41
  def __len__(self):
42
  return self.n_samples
43
-
44
  def __iter__(self):
45
  indices = []
46
  # sample from each sub-dataset
47
  for d_idx in range(self.n_subset):
48
- low = 0 if d_idx==0 else self.data_source.cumulative_sizes[d_idx-1]
49
  high = self.data_source.cumulative_sizes[d_idx]
50
  if self.subset_replacement:
51
- rand_tensor = torch.randint(low, high, (self.n_samples_per_subset, ),
52
- generator=self.generator, dtype=torch.int64)
 
 
 
 
 
53
  else: # sample without replacement
54
  len_subset = len(self.data_source.datasets[d_idx])
55
  rand_tensor = torch.randperm(len_subset, generator=self.generator) + low
56
  if len_subset >= self.n_samples_per_subset:
57
- rand_tensor = rand_tensor[:self.n_samples_per_subset]
58
- else: # padding with replacement
59
- rand_tensor_replacement = torch.randint(low, high, (self.n_samples_per_subset - len_subset, ),
60
- generator=self.generator, dtype=torch.int64)
 
 
 
 
 
61
  rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement])
62
  indices.append(rand_tensor)
63
  indices = torch.cat(indices)
@@ -72,6 +85,6 @@ class RandomConcatSampler(Sampler):
72
  _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)]
73
  repeat_indices = map(_choice, repeat_indices)
74
  indices = torch.cat([indices, *repeat_indices], 0)
75
-
76
  assert indices.shape[0] == self.n_samples
77
  return iter(indices.tolist())
 
3
 
4
 
5
  class RandomConcatSampler(Sampler):
6
+ """Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset
7
  in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement.
8
  However, it is impossible to sample data without replacement between epochs, unless bulding a stateful sampler lived along the entire training phase.
9
+
10
  For current implementation, the randomness of sampling is ensured no matter the sampler is recreated across epochs or not and call `torch.manual_seed()` or not.
11
  Args:
12
  shuffle (bool): shuffle the random sampled indices across all sub-datsets.
 
18
  TODO: Add a `set_epoch()` method to fullfill sampling without replacement across epochs.
19
  ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
20
  """
21
+
22
+ def __init__(
23
+ self,
24
+ data_source: ConcatDataset,
25
+ n_samples_per_subset: int,
26
+ subset_replacement: bool = True,
27
+ shuffle: bool = True,
28
+ repeat: int = 1,
29
+ seed: int = None,
30
+ ):
31
  if not isinstance(data_source, ConcatDataset):
32
  raise TypeError("data_source should be torch.utils.data.ConcatDataset")
33
+
34
  self.data_source = data_source
35
  self.n_subset = len(self.data_source.datasets)
36
  self.n_samples_per_subset = n_samples_per_subset
 
40
  self.shuffle = shuffle
41
  self.generator = torch.manual_seed(seed)
42
  assert self.repeat >= 1
43
+
44
  def __len__(self):
45
  return self.n_samples
46
+
47
  def __iter__(self):
48
  indices = []
49
  # sample from each sub-dataset
50
  for d_idx in range(self.n_subset):
51
+ low = 0 if d_idx == 0 else self.data_source.cumulative_sizes[d_idx - 1]
52
  high = self.data_source.cumulative_sizes[d_idx]
53
  if self.subset_replacement:
54
+ rand_tensor = torch.randint(
55
+ low,
56
+ high,
57
+ (self.n_samples_per_subset,),
58
+ generator=self.generator,
59
+ dtype=torch.int64,
60
+ )
61
  else: # sample without replacement
62
  len_subset = len(self.data_source.datasets[d_idx])
63
  rand_tensor = torch.randperm(len_subset, generator=self.generator) + low
64
  if len_subset >= self.n_samples_per_subset:
65
+ rand_tensor = rand_tensor[: self.n_samples_per_subset]
66
+ else: # padding with replacement
67
+ rand_tensor_replacement = torch.randint(
68
+ low,
69
+ high,
70
+ (self.n_samples_per_subset - len_subset,),
71
+ generator=self.generator,
72
+ dtype=torch.int64,
73
+ )
74
  rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement])
75
  indices.append(rand_tensor)
76
  indices = torch.cat(indices)
 
85
  _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)]
86
  repeat_indices = map(_choice, repeat_indices)
87
  indices = torch.cat([indices, *repeat_indices], 0)
88
+
89
  assert indices.shape[0] == self.n_samples
90
  return iter(indices.tolist())
third_party/ASpanFormer/src/datasets/scannet.py CHANGED
@@ -10,20 +10,22 @@ from src.utils.dataset import (
10
  read_scannet_gray,
11
  read_scannet_depth,
12
  read_scannet_pose,
13
- read_scannet_intrinsic
14
  )
15
 
16
 
17
  class ScanNetDataset(utils.data.Dataset):
18
- def __init__(self,
19
- root_dir,
20
- npz_path,
21
- intrinsic_path,
22
- mode='train',
23
- min_overlap_score=0.4,
24
- augment_fn=None,
25
- pose_dir=None,
26
- **kwargs):
 
 
27
  """Manage one scene of ScanNet Dataset.
28
  Args:
29
  root_dir (str): ScanNet root directory that contains scene folders.
@@ -41,73 +43,81 @@ class ScanNetDataset(utils.data.Dataset):
41
 
42
  # prepare data_names, intrinsics and extrinsics(T)
43
  with np.load(npz_path) as data:
44
- self.data_names = data['name']
45
- if 'score' in data.keys() and mode not in ['val' or 'test']:
46
- kept_mask = data['score'] > min_overlap_score
47
  self.data_names = self.data_names[kept_mask]
48
  self.intrinsics = dict(np.load(intrinsic_path))
49
 
50
  # for training LoFTR
51
- self.augment_fn = augment_fn if mode == 'train' else None
52
 
53
  def __len__(self):
54
  return len(self.data_names)
55
 
56
  def _read_abs_pose(self, scene_name, name):
57
- pth = osp.join(self.pose_dir,
58
- scene_name,
59
- 'pose', f'{name}.txt')
60
  return read_scannet_pose(pth)
61
 
62
  def _compute_rel_pose(self, scene_name, name0, name1):
63
  pose0 = self._read_abs_pose(scene_name, name0)
64
  pose1 = self._read_abs_pose(scene_name, name1)
65
-
66
  return np.matmul(pose1, inv(pose0)) # (4, 4)
67
 
68
  def __getitem__(self, idx):
69
  data_name = self.data_names[idx]
70
  scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name
71
- scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
72
 
73
  # read the grayscale image which will be resized to (1, 480, 640)
74
- img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg')
75
- img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg')
76
  # TODO: Support augmentation & handle seeds for each worker correctly.
77
  image0 = read_scannet_gray(img_name0, resize=(640, 480), augment_fn=None)
78
- # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
79
  image1 = read_scannet_gray(img_name1, resize=(640, 480), augment_fn=None)
80
- # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
81
 
82
  # read the depthmap which is stored as (480, 640)
83
- if self.mode in ['train', 'val']:
84
- depth0 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_0}.png'))
85
- depth1 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_1}.png'))
 
 
 
 
86
  else:
87
  depth0 = depth1 = torch.tensor([])
88
 
89
  # read the intrinsic of depthmap
90
- K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3)
 
 
91
 
92
  # read and compute relative poses
93
- T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1),
94
- dtype=torch.float32)
 
 
95
  T_1to0 = T_0to1.inverse()
96
 
97
  data = {
98
- 'image0': image0, # (1, h, w)
99
- 'depth0': depth0, # (h, w)
100
- 'image1': image1,
101
- 'depth1': depth1,
102
- 'T_0to1': T_0to1, # (4, 4)
103
- 'T_1to0': T_1to0,
104
- 'K0': K_0, # (3, 3)
105
- 'K1': K_1,
106
- 'dataset_name': 'ScanNet',
107
- 'scene_id': scene_name,
108
- 'pair_id': idx,
109
- 'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'),
110
- osp.join(scene_name, 'color', f'{stem_name_1}.jpg'))
 
 
111
  }
112
 
113
  return data
 
10
  read_scannet_gray,
11
  read_scannet_depth,
12
  read_scannet_pose,
13
+ read_scannet_intrinsic,
14
  )
15
 
16
 
17
  class ScanNetDataset(utils.data.Dataset):
18
+ def __init__(
19
+ self,
20
+ root_dir,
21
+ npz_path,
22
+ intrinsic_path,
23
+ mode="train",
24
+ min_overlap_score=0.4,
25
+ augment_fn=None,
26
+ pose_dir=None,
27
+ **kwargs,
28
+ ):
29
  """Manage one scene of ScanNet Dataset.
30
  Args:
31
  root_dir (str): ScanNet root directory that contains scene folders.
 
43
 
44
  # prepare data_names, intrinsics and extrinsics(T)
45
  with np.load(npz_path) as data:
46
+ self.data_names = data["name"]
47
+ if "score" in data.keys() and mode not in ["val" or "test"]:
48
+ kept_mask = data["score"] > min_overlap_score
49
  self.data_names = self.data_names[kept_mask]
50
  self.intrinsics = dict(np.load(intrinsic_path))
51
 
52
  # for training LoFTR
53
+ self.augment_fn = augment_fn if mode == "train" else None
54
 
55
  def __len__(self):
56
  return len(self.data_names)
57
 
58
  def _read_abs_pose(self, scene_name, name):
59
+ pth = osp.join(self.pose_dir, scene_name, "pose", f"{name}.txt")
 
 
60
  return read_scannet_pose(pth)
61
 
62
  def _compute_rel_pose(self, scene_name, name0, name1):
63
  pose0 = self._read_abs_pose(scene_name, name0)
64
  pose1 = self._read_abs_pose(scene_name, name1)
65
+
66
  return np.matmul(pose1, inv(pose0)) # (4, 4)
67
 
68
  def __getitem__(self, idx):
69
  data_name = self.data_names[idx]
70
  scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name
71
+ scene_name = f"scene{scene_name:04d}_{scene_sub_name:02d}"
72
 
73
  # read the grayscale image which will be resized to (1, 480, 640)
74
+ img_name0 = osp.join(self.root_dir, scene_name, "color", f"{stem_name_0}.jpg")
75
+ img_name1 = osp.join(self.root_dir, scene_name, "color", f"{stem_name_1}.jpg")
76
  # TODO: Support augmentation & handle seeds for each worker correctly.
77
  image0 = read_scannet_gray(img_name0, resize=(640, 480), augment_fn=None)
78
+ # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
79
  image1 = read_scannet_gray(img_name1, resize=(640, 480), augment_fn=None)
80
+ # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
81
 
82
  # read the depthmap which is stored as (480, 640)
83
+ if self.mode in ["train", "val"]:
84
+ depth0 = read_scannet_depth(
85
+ osp.join(self.root_dir, scene_name, "depth", f"{stem_name_0}.png")
86
+ )
87
+ depth1 = read_scannet_depth(
88
+ osp.join(self.root_dir, scene_name, "depth", f"{stem_name_1}.png")
89
+ )
90
  else:
91
  depth0 = depth1 = torch.tensor([])
92
 
93
  # read the intrinsic of depthmap
94
+ K_0 = K_1 = torch.tensor(
95
+ self.intrinsics[scene_name].copy(), dtype=torch.float
96
+ ).reshape(3, 3)
97
 
98
  # read and compute relative poses
99
+ T_0to1 = torch.tensor(
100
+ self._compute_rel_pose(scene_name, stem_name_0, stem_name_1),
101
+ dtype=torch.float32,
102
+ )
103
  T_1to0 = T_0to1.inverse()
104
 
105
  data = {
106
+ "image0": image0, # (1, h, w)
107
+ "depth0": depth0, # (h, w)
108
+ "image1": image1,
109
+ "depth1": depth1,
110
+ "T_0to1": T_0to1, # (4, 4)
111
+ "T_1to0": T_1to0,
112
+ "K0": K_0, # (3, 3)
113
+ "K1": K_1,
114
+ "dataset_name": "ScanNet",
115
+ "scene_id": scene_name,
116
+ "pair_id": idx,
117
+ "pair_names": (
118
+ osp.join(scene_name, "color", f"{stem_name_0}.jpg"),
119
+ osp.join(scene_name, "color", f"{stem_name_1}.jpg"),
120
+ ),
121
  }
122
 
123
  return data
third_party/ASpanFormer/src/lightning/data.py CHANGED
@@ -16,7 +16,7 @@ from torch.utils.data import (
16
  ConcatDataset,
17
  DistributedSampler,
18
  RandomSampler,
19
- dataloader
20
  )
21
 
22
  from src.utils.augment import build_augmentor
@@ -29,10 +29,11 @@ from src.datasets.sampler import RandomConcatSampler
29
 
30
 
31
  class MultiSceneDataModule(pl.LightningDataModule):
32
- """
33
  For distributed training, each training process is assgined
34
  only a part of the training scenes to reduce memory overhead.
35
  """
 
36
  def __init__(self, args, config):
37
  super().__init__()
38
 
@@ -60,47 +61,51 @@ class MultiSceneDataModule(pl.LightningDataModule):
60
 
61
  # 2. dataset config
62
  # general options
63
- self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST # 0.4, omit data with overlap_score < min_overlap_score
 
 
64
  self.min_overlap_score_train = config.DATASET.MIN_OVERLAP_SCORE_TRAIN
65
- self.augment_fn = build_augmentor(config.DATASET.AUGMENTATION_TYPE) # None, options: [None, 'dark', 'mobile']
 
 
66
 
67
  # MegaDepth options
68
  self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE # 840
69
- self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD # True
70
- self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD # True
71
  self.mgdpt_df = config.DATASET.MGDPT_DF # 8
72
  self.coarse_scale = 1 / config.ASPAN.RESOLUTION[0] # 0.125. for training loftr.
73
 
74
  # 3.loader parameters
75
  self.train_loader_params = {
76
- 'batch_size': args.batch_size,
77
- 'num_workers': args.num_workers,
78
- 'pin_memory': getattr(args, 'pin_memory', True)
79
  }
80
  self.val_loader_params = {
81
- 'batch_size': 1,
82
- 'shuffle': False,
83
- 'num_workers': args.num_workers,
84
- 'pin_memory': getattr(args, 'pin_memory', True)
85
  }
86
  self.test_loader_params = {
87
- 'batch_size': 1,
88
- 'shuffle': False,
89
- 'num_workers': args.num_workers,
90
- 'pin_memory': True
91
  }
92
-
93
  # 4. sampler
94
  self.data_sampler = config.TRAINER.DATA_SAMPLER
95
  self.n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET
96
  self.subset_replacement = config.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT
97
  self.shuffle = config.TRAINER.SB_SUBSET_SHUFFLE
98
  self.repeat = config.TRAINER.SB_REPEAT
99
-
100
  # (optional) RandomSampler for debugging
101
 
102
  # misc configurations
103
- self.parallel_load_data = getattr(args, 'parallel_load_data', False)
104
  self.seed = config.TRAINER.SEED # 66
105
 
106
  def setup(self, stage=None):
@@ -110,7 +115,7 @@ class MultiSceneDataModule(pl.LightningDataModule):
110
  stage (str): 'fit' in training phase, and 'test' in testing phase.
111
  """
112
 
113
- assert stage in ['fit', 'test'], "stage must be either fit or test"
114
 
115
  try:
116
  self.world_size = dist.get_world_size()
@@ -121,73 +126,94 @@ class MultiSceneDataModule(pl.LightningDataModule):
121
  self.rank = 0
122
  logger.warning(str(ae) + " (set wolrd_size=1 and rank=0)")
123
 
124
- if stage == 'fit':
125
  self.train_dataset = self._setup_dataset(
126
  self.train_data_root,
127
  self.train_npz_root,
128
  self.train_list_path,
129
  self.train_intrinsic_path,
130
- mode='train',
131
  min_overlap_score=self.min_overlap_score_train,
132
- pose_dir=self.train_pose_root)
 
133
  # setup multiple (optional) validation subsets
134
  if isinstance(self.val_list_path, (list, tuple)):
135
  self.val_dataset = []
136
  if not isinstance(self.val_npz_root, (list, tuple)):
137
- self.val_npz_root = [self.val_npz_root for _ in range(len(self.val_list_path))]
 
 
138
  for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root):
139
- self.val_dataset.append(self._setup_dataset(
140
- self.val_data_root,
141
- npz_root,
142
- npz_list,
143
- self.val_intrinsic_path,
144
- mode='val',
145
- min_overlap_score=self.min_overlap_score_test,
146
- pose_dir=self.val_pose_root))
 
 
 
147
  else:
148
  self.val_dataset = self._setup_dataset(
149
  self.val_data_root,
150
  self.val_npz_root,
151
  self.val_list_path,
152
  self.val_intrinsic_path,
153
- mode='val',
154
  min_overlap_score=self.min_overlap_score_test,
155
- pose_dir=self.val_pose_root)
156
- logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!')
 
157
  else: # stage == 'test
158
  self.test_dataset = self._setup_dataset(
159
  self.test_data_root,
160
  self.test_npz_root,
161
  self.test_list_path,
162
  self.test_intrinsic_path,
163
- mode='test',
164
  min_overlap_score=self.min_overlap_score_test,
165
- pose_dir=self.test_pose_root)
166
- logger.info(f'[rank:{self.rank}]: Test Dataset loaded!')
 
167
 
168
- def _setup_dataset(self,
169
- data_root,
170
- split_npz_root,
171
- scene_list_path,
172
- intri_path,
173
- mode='train',
174
- min_overlap_score=0.,
175
- pose_dir=None):
176
- """ Setup train / val / test set"""
177
- with open(scene_list_path, 'r') as f:
 
 
178
  npz_names = [name.split()[0] for name in f.readlines()]
179
 
180
- if mode == 'train':
181
- local_npz_names = get_local_split(npz_names, self.world_size, self.rank, self.seed)
 
 
182
  else:
183
  local_npz_names = npz_names
184
- logger.info(f'[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.')
185
-
186
- dataset_builder = self._build_concat_dataset_parallel \
187
- if self.parallel_load_data \
188
- else self._build_concat_dataset
189
- return dataset_builder(data_root, local_npz_names, split_npz_root, intri_path,
190
- mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir)
 
 
 
 
 
 
 
 
 
191
 
192
  def _build_concat_dataset(
193
  self,
@@ -196,49 +222,61 @@ class MultiSceneDataModule(pl.LightningDataModule):
196
  npz_dir,
197
  intrinsic_path,
198
  mode,
199
- min_overlap_score=0.,
200
- pose_dir=None
201
  ):
202
  datasets = []
203
- augment_fn = self.augment_fn if mode == 'train' else None
204
- data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source
205
- if data_source=='GL3D' and mode=='val':
206
- data_source='MegaDepth'
207
- if str(data_source).lower() == 'megadepth':
208
- npz_names = [f'{n}.npz' for n in npz_names]
209
- if str(data_source).lower() == 'gl3d':
210
- npz_names = [f'{n}.txt' for n in npz_names]
211
- #npz_names=npz_names[:8]
212
- for npz_name in tqdm(npz_names,
213
- desc=f'[rank:{self.rank}] loading {mode} datasets',
214
- disable=int(self.rank) != 0):
 
 
 
 
 
 
215
  # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time.
216
  npz_path = osp.join(npz_dir, npz_name)
217
- if data_source == 'ScanNet':
218
  datasets.append(
219
- ScanNetDataset(data_root,
220
- npz_path,
221
- intrinsic_path,
222
- mode=mode,
223
- min_overlap_score=min_overlap_score,
224
- augment_fn=augment_fn,
225
- pose_dir=pose_dir))
226
- elif data_source == 'MegaDepth':
 
 
 
227
  datasets.append(
228
- MegaDepthDataset(data_root,
229
- npz_path,
230
- mode=mode,
231
- min_overlap_score=min_overlap_score,
232
- img_resize=self.mgdpt_img_resize,
233
- df=self.mgdpt_df,
234
- img_padding=self.mgdpt_img_pad,
235
- depth_padding=self.mgdpt_depth_pad,
236
- augment_fn=augment_fn,
237
- coarse_scale=self.coarse_scale))
 
 
 
238
  else:
239
  raise NotImplementedError()
240
  return ConcatDataset(datasets)
241
-
242
  def _build_concat_dataset_parallel(
243
  self,
244
  data_root,
@@ -246,78 +284,119 @@ class MultiSceneDataModule(pl.LightningDataModule):
246
  npz_dir,
247
  intrinsic_path,
248
  mode,
249
- min_overlap_score=0.,
250
  pose_dir=None,
251
  ):
252
- augment_fn = self.augment_fn if mode == 'train' else None
253
- data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source
254
- if str(data_source).lower() == 'megadepth':
255
- npz_names = [f'{n}.npz' for n in npz_names]
256
- #npz_names=npz_names[:8]
257
- with tqdm_joblib(tqdm(desc=f'[rank:{self.rank}] loading {mode} datasets',
258
- total=len(npz_names), disable=int(self.rank) != 0)):
259
- if data_source == 'ScanNet':
260
- datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))(
261
- delayed(lambda x: _build_dataset(
262
- ScanNetDataset,
263
- data_root,
264
- osp.join(npz_dir, x),
265
- intrinsic_path,
266
- mode=mode,
267
- min_overlap_score=min_overlap_score,
268
- augment_fn=augment_fn,
269
- pose_dir=pose_dir))(name)
270
- for name in npz_names)
271
- elif data_source == 'MegaDepth':
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  # TODO: _pickle.PicklingError: Could not pickle the task to send it to the workers.
273
  raise NotImplementedError()
274
- datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))(
275
- delayed(lambda x: _build_dataset(
276
- MegaDepthDataset,
277
- data_root,
278
- osp.join(npz_dir, x),
279
- mode=mode,
280
- min_overlap_score=min_overlap_score,
281
- img_resize=self.mgdpt_img_resize,
282
- df=self.mgdpt_df,
283
- img_padding=self.mgdpt_img_pad,
284
- depth_padding=self.mgdpt_depth_pad,
285
- augment_fn=augment_fn,
286
- coarse_scale=self.coarse_scale))(name)
287
- for name in npz_names)
 
 
 
 
 
 
 
 
288
  else:
289
- raise ValueError(f'Unknown dataset: {data_source}')
290
  return ConcatDataset(datasets)
291
 
292
  def train_dataloader(self):
293
- """ Build training dataloader for ScanNet / MegaDepth. """
294
- assert self.data_sampler in ['scene_balance']
295
- logger.info(f'[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!).')
296
- if self.data_sampler == 'scene_balance':
297
- sampler = RandomConcatSampler(self.train_dataset,
298
- self.n_samples_per_subset,
299
- self.subset_replacement,
300
- self.shuffle, self.repeat, self.seed)
 
 
 
 
 
 
301
  else:
302
  sampler = None
303
- dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params)
 
 
304
  return dataloader
305
-
306
  def val_dataloader(self):
307
- """ Build validation dataloader for ScanNet / MegaDepth. """
308
- logger.info(f'[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init.')
 
 
309
  if not isinstance(self.val_dataset, abc.Sequence):
310
  sampler = DistributedSampler(self.val_dataset, shuffle=False)
311
- return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params)
 
 
312
  else:
313
  dataloaders = []
314
  for dataset in self.val_dataset:
315
  sampler = DistributedSampler(dataset, shuffle=False)
316
- dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params))
 
 
317
  return dataloaders
318
 
319
  def test_dataloader(self, *args, **kwargs):
320
- logger.info(f'[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init.')
 
 
321
  sampler = DistributedSampler(self.test_dataset, shuffle=False)
322
  return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params)
323
 
 
16
  ConcatDataset,
17
  DistributedSampler,
18
  RandomSampler,
19
+ dataloader,
20
  )
21
 
22
  from src.utils.augment import build_augmentor
 
29
 
30
 
31
  class MultiSceneDataModule(pl.LightningDataModule):
32
+ """
33
  For distributed training, each training process is assgined
34
  only a part of the training scenes to reduce memory overhead.
35
  """
36
+
37
  def __init__(self, args, config):
38
  super().__init__()
39
 
 
61
 
62
  # 2. dataset config
63
  # general options
64
+ self.min_overlap_score_test = (
65
+ config.DATASET.MIN_OVERLAP_SCORE_TEST
66
+ ) # 0.4, omit data with overlap_score < min_overlap_score
67
  self.min_overlap_score_train = config.DATASET.MIN_OVERLAP_SCORE_TRAIN
68
+ self.augment_fn = build_augmentor(
69
+ config.DATASET.AUGMENTATION_TYPE
70
+ ) # None, options: [None, 'dark', 'mobile']
71
 
72
  # MegaDepth options
73
  self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE # 840
74
+ self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD # True
75
+ self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD # True
76
  self.mgdpt_df = config.DATASET.MGDPT_DF # 8
77
  self.coarse_scale = 1 / config.ASPAN.RESOLUTION[0] # 0.125. for training loftr.
78
 
79
  # 3.loader parameters
80
  self.train_loader_params = {
81
+ "batch_size": args.batch_size,
82
+ "num_workers": args.num_workers,
83
+ "pin_memory": getattr(args, "pin_memory", True),
84
  }
85
  self.val_loader_params = {
86
+ "batch_size": 1,
87
+ "shuffle": False,
88
+ "num_workers": args.num_workers,
89
+ "pin_memory": getattr(args, "pin_memory", True),
90
  }
91
  self.test_loader_params = {
92
+ "batch_size": 1,
93
+ "shuffle": False,
94
+ "num_workers": args.num_workers,
95
+ "pin_memory": True,
96
  }
97
+
98
  # 4. sampler
99
  self.data_sampler = config.TRAINER.DATA_SAMPLER
100
  self.n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET
101
  self.subset_replacement = config.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT
102
  self.shuffle = config.TRAINER.SB_SUBSET_SHUFFLE
103
  self.repeat = config.TRAINER.SB_REPEAT
104
+
105
  # (optional) RandomSampler for debugging
106
 
107
  # misc configurations
108
+ self.parallel_load_data = getattr(args, "parallel_load_data", False)
109
  self.seed = config.TRAINER.SEED # 66
110
 
111
  def setup(self, stage=None):
 
115
  stage (str): 'fit' in training phase, and 'test' in testing phase.
116
  """
117
 
118
+ assert stage in ["fit", "test"], "stage must be either fit or test"
119
 
120
  try:
121
  self.world_size = dist.get_world_size()
 
126
  self.rank = 0
127
  logger.warning(str(ae) + " (set wolrd_size=1 and rank=0)")
128
 
129
+ if stage == "fit":
130
  self.train_dataset = self._setup_dataset(
131
  self.train_data_root,
132
  self.train_npz_root,
133
  self.train_list_path,
134
  self.train_intrinsic_path,
135
+ mode="train",
136
  min_overlap_score=self.min_overlap_score_train,
137
+ pose_dir=self.train_pose_root,
138
+ )
139
  # setup multiple (optional) validation subsets
140
  if isinstance(self.val_list_path, (list, tuple)):
141
  self.val_dataset = []
142
  if not isinstance(self.val_npz_root, (list, tuple)):
143
+ self.val_npz_root = [
144
+ self.val_npz_root for _ in range(len(self.val_list_path))
145
+ ]
146
  for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root):
147
+ self.val_dataset.append(
148
+ self._setup_dataset(
149
+ self.val_data_root,
150
+ npz_root,
151
+ npz_list,
152
+ self.val_intrinsic_path,
153
+ mode="val",
154
+ min_overlap_score=self.min_overlap_score_test,
155
+ pose_dir=self.val_pose_root,
156
+ )
157
+ )
158
  else:
159
  self.val_dataset = self._setup_dataset(
160
  self.val_data_root,
161
  self.val_npz_root,
162
  self.val_list_path,
163
  self.val_intrinsic_path,
164
+ mode="val",
165
  min_overlap_score=self.min_overlap_score_test,
166
+ pose_dir=self.val_pose_root,
167
+ )
168
+ logger.info(f"[rank:{self.rank}] Train & Val Dataset loaded!")
169
  else: # stage == 'test
170
  self.test_dataset = self._setup_dataset(
171
  self.test_data_root,
172
  self.test_npz_root,
173
  self.test_list_path,
174
  self.test_intrinsic_path,
175
+ mode="test",
176
  min_overlap_score=self.min_overlap_score_test,
177
+ pose_dir=self.test_pose_root,
178
+ )
179
+ logger.info(f"[rank:{self.rank}]: Test Dataset loaded!")
180
 
181
+ def _setup_dataset(
182
+ self,
183
+ data_root,
184
+ split_npz_root,
185
+ scene_list_path,
186
+ intri_path,
187
+ mode="train",
188
+ min_overlap_score=0.0,
189
+ pose_dir=None,
190
+ ):
191
+ """Setup train / val / test set"""
192
+ with open(scene_list_path, "r") as f:
193
  npz_names = [name.split()[0] for name in f.readlines()]
194
 
195
+ if mode == "train":
196
+ local_npz_names = get_local_split(
197
+ npz_names, self.world_size, self.rank, self.seed
198
+ )
199
  else:
200
  local_npz_names = npz_names
201
+ logger.info(f"[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.")
202
+
203
+ dataset_builder = (
204
+ self._build_concat_dataset_parallel
205
+ if self.parallel_load_data
206
+ else self._build_concat_dataset
207
+ )
208
+ return dataset_builder(
209
+ data_root,
210
+ local_npz_names,
211
+ split_npz_root,
212
+ intri_path,
213
+ mode=mode,
214
+ min_overlap_score=min_overlap_score,
215
+ pose_dir=pose_dir,
216
+ )
217
 
218
  def _build_concat_dataset(
219
  self,
 
222
  npz_dir,
223
  intrinsic_path,
224
  mode,
225
+ min_overlap_score=0.0,
226
+ pose_dir=None,
227
  ):
228
  datasets = []
229
+ augment_fn = self.augment_fn if mode == "train" else None
230
+ data_source = (
231
+ self.trainval_data_source
232
+ if mode in ["train", "val"]
233
+ else self.test_data_source
234
+ )
235
+ if data_source == "GL3D" and mode == "val":
236
+ data_source = "MegaDepth"
237
+ if str(data_source).lower() == "megadepth":
238
+ npz_names = [f"{n}.npz" for n in npz_names]
239
+ if str(data_source).lower() == "gl3d":
240
+ npz_names = [f"{n}.txt" for n in npz_names]
241
+ # npz_names=npz_names[:8]
242
+ for npz_name in tqdm(
243
+ npz_names,
244
+ desc=f"[rank:{self.rank}] loading {mode} datasets",
245
+ disable=int(self.rank) != 0,
246
+ ):
247
  # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time.
248
  npz_path = osp.join(npz_dir, npz_name)
249
+ if data_source == "ScanNet":
250
  datasets.append(
251
+ ScanNetDataset(
252
+ data_root,
253
+ npz_path,
254
+ intrinsic_path,
255
+ mode=mode,
256
+ min_overlap_score=min_overlap_score,
257
+ augment_fn=augment_fn,
258
+ pose_dir=pose_dir,
259
+ )
260
+ )
261
+ elif data_source == "MegaDepth":
262
  datasets.append(
263
+ MegaDepthDataset(
264
+ data_root,
265
+ npz_path,
266
+ mode=mode,
267
+ min_overlap_score=min_overlap_score,
268
+ img_resize=self.mgdpt_img_resize,
269
+ df=self.mgdpt_df,
270
+ img_padding=self.mgdpt_img_pad,
271
+ depth_padding=self.mgdpt_depth_pad,
272
+ augment_fn=augment_fn,
273
+ coarse_scale=self.coarse_scale,
274
+ )
275
+ )
276
  else:
277
  raise NotImplementedError()
278
  return ConcatDataset(datasets)
279
+
280
  def _build_concat_dataset_parallel(
281
  self,
282
  data_root,
 
284
  npz_dir,
285
  intrinsic_path,
286
  mode,
287
+ min_overlap_score=0.0,
288
  pose_dir=None,
289
  ):
290
+ augment_fn = self.augment_fn if mode == "train" else None
291
+ data_source = (
292
+ self.trainval_data_source
293
+ if mode in ["train", "val"]
294
+ else self.test_data_source
295
+ )
296
+ if str(data_source).lower() == "megadepth":
297
+ npz_names = [f"{n}.npz" for n in npz_names]
298
+ # npz_names=npz_names[:8]
299
+ with tqdm_joblib(
300
+ tqdm(
301
+ desc=f"[rank:{self.rank}] loading {mode} datasets",
302
+ total=len(npz_names),
303
+ disable=int(self.rank) != 0,
304
+ )
305
+ ):
306
+ if data_source == "ScanNet":
307
+ datasets = Parallel(
308
+ n_jobs=math.floor(
309
+ len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()
310
+ )
311
+ )(
312
+ delayed(
313
+ lambda x: _build_dataset(
314
+ ScanNetDataset,
315
+ data_root,
316
+ osp.join(npz_dir, x),
317
+ intrinsic_path,
318
+ mode=mode,
319
+ min_overlap_score=min_overlap_score,
320
+ augment_fn=augment_fn,
321
+ pose_dir=pose_dir,
322
+ )
323
+ )(name)
324
+ for name in npz_names
325
+ )
326
+ elif data_source == "MegaDepth":
327
  # TODO: _pickle.PicklingError: Could not pickle the task to send it to the workers.
328
  raise NotImplementedError()
329
+ datasets = Parallel(
330
+ n_jobs=math.floor(
331
+ len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()
332
+ )
333
+ )(
334
+ delayed(
335
+ lambda x: _build_dataset(
336
+ MegaDepthDataset,
337
+ data_root,
338
+ osp.join(npz_dir, x),
339
+ mode=mode,
340
+ min_overlap_score=min_overlap_score,
341
+ img_resize=self.mgdpt_img_resize,
342
+ df=self.mgdpt_df,
343
+ img_padding=self.mgdpt_img_pad,
344
+ depth_padding=self.mgdpt_depth_pad,
345
+ augment_fn=augment_fn,
346
+ coarse_scale=self.coarse_scale,
347
+ )
348
+ )(name)
349
+ for name in npz_names
350
+ )
351
  else:
352
+ raise ValueError(f"Unknown dataset: {data_source}")
353
  return ConcatDataset(datasets)
354
 
355
  def train_dataloader(self):
356
+ """Build training dataloader for ScanNet / MegaDepth."""
357
+ assert self.data_sampler in ["scene_balance"]
358
+ logger.info(
359
+ f"[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!)."
360
+ )
361
+ if self.data_sampler == "scene_balance":
362
+ sampler = RandomConcatSampler(
363
+ self.train_dataset,
364
+ self.n_samples_per_subset,
365
+ self.subset_replacement,
366
+ self.shuffle,
367
+ self.repeat,
368
+ self.seed,
369
+ )
370
  else:
371
  sampler = None
372
+ dataloader = DataLoader(
373
+ self.train_dataset, sampler=sampler, **self.train_loader_params
374
+ )
375
  return dataloader
376
+
377
  def val_dataloader(self):
378
+ """Build validation dataloader for ScanNet / MegaDepth."""
379
+ logger.info(
380
+ f"[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init."
381
+ )
382
  if not isinstance(self.val_dataset, abc.Sequence):
383
  sampler = DistributedSampler(self.val_dataset, shuffle=False)
384
+ return DataLoader(
385
+ self.val_dataset, sampler=sampler, **self.val_loader_params
386
+ )
387
  else:
388
  dataloaders = []
389
  for dataset in self.val_dataset:
390
  sampler = DistributedSampler(dataset, shuffle=False)
391
+ dataloaders.append(
392
+ DataLoader(dataset, sampler=sampler, **self.val_loader_params)
393
+ )
394
  return dataloaders
395
 
396
  def test_dataloader(self, *args, **kwargs):
397
+ logger.info(
398
+ f"[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init."
399
+ )
400
  sampler = DistributedSampler(self.test_dataset, shuffle=False)
401
  return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params)
402
 
third_party/ASpanFormer/src/lightning/lightning_aspanformer.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  from collections import defaultdict
3
  import pprint
4
  from loguru import logger
@@ -10,15 +9,19 @@ import pytorch_lightning as pl
10
  from matplotlib import pyplot as plt
11
 
12
  from src.ASpanFormer.aspanformer import ASpanFormer
13
- from src.ASpanFormer.utils.supervision import compute_supervision_coarse, compute_supervision_fine
 
 
 
14
  from src.losses.aspan_loss import ASpanLoss
15
  from src.optimizers import build_optimizer, build_scheduler
16
  from src.utils.metrics import (
17
- compute_symmetrical_epipolar_errors,compute_symmetrical_epipolar_errors_offset_bidirectional,
 
18
  compute_pose_errors,
19
- aggregate_metrics
20
  )
21
- from src.utils.plotting import make_matching_figures,make_matching_figures_offset
22
  from src.utils.comm import gather, all_gather
23
  from src.utils.misc import lower_config, flattenList
24
  from src.utils.profiler import PassThroughProfiler
@@ -34,200 +37,288 @@ class PL_ASpanFormer(pl.LightningModule):
34
  # Misc
35
  self.config = config # full config
36
  _config = lower_config(self.config)
37
- self.loftr_cfg = lower_config(_config['aspan'])
38
  self.profiler = profiler or PassThroughProfiler()
39
- self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1)
 
 
40
 
41
  # Matcher: LoFTR
42
- self.matcher = ASpanFormer(config=_config['aspan'])
43
  self.loss = ASpanLoss(_config)
44
 
45
  # Pretrained weights
46
  print(pretrained_ckpt)
47
  if pretrained_ckpt:
48
- print('load')
49
- state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict']
50
- msg=self.matcher.load_state_dict(state_dict, strict=False)
51
  print(msg)
52
- logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint")
53
-
54
  # Testing
55
  self.dump_dir = dump_dir
56
-
57
  def configure_optimizers(self):
58
  # FIXME: The scheduler did not work properly when `--resume_from_checkpoint`
59
  optimizer = build_optimizer(self, self.config)
60
  scheduler = build_scheduler(self.config, optimizer)
61
  return [optimizer], [scheduler]
62
-
63
  def optimizer_step(
64
- self, epoch, batch_idx, optimizer, optimizer_idx,
65
- optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
 
 
 
 
 
 
 
 
66
  # learning rate warm up
67
  warmup_step = self.config.TRAINER.WARMUP_STEP
68
  if self.trainer.global_step < warmup_step:
69
- if self.config.TRAINER.WARMUP_TYPE == 'linear':
70
  base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR
71
- lr = base_lr + \
72
- (self.trainer.global_step / self.config.TRAINER.WARMUP_STEP) * \
73
- abs(self.config.TRAINER.TRUE_LR - base_lr)
74
  for pg in optimizer.param_groups:
75
- pg['lr'] = lr
76
- elif self.config.TRAINER.WARMUP_TYPE == 'constant':
77
  pass
78
  else:
79
- raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}')
 
 
80
 
81
  # update params
82
  optimizer.step(closure=optimizer_closure)
83
  optimizer.zero_grad()
84
-
85
  def _trainval_inference(self, batch):
86
  with self.profiler.profile("Compute coarse supervision"):
87
- compute_supervision_coarse(batch, self.config)
88
-
89
  with self.profiler.profile("LoFTR"):
90
- self.matcher(batch)
91
-
92
  with self.profiler.profile("Compute fine supervision"):
93
- compute_supervision_fine(batch, self.config)
94
-
95
  with self.profiler.profile("Compute losses"):
96
- self.loss(batch)
97
-
98
  def _compute_metrics(self, batch):
99
  with self.profiler.profile("Copmute metrics"):
100
- compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match
101
- compute_symmetrical_epipolar_errors_offset_bidirectional(batch) # compute epi_errs for offset match
102
- compute_pose_errors(batch, self.config) # compute R_errs, t_errs, pose_errs for each pair
 
 
 
 
 
 
103
 
104
- rel_pair_names = list(zip(*batch['pair_names']))
105
- bs = batch['image0'].size(0)
106
  metrics = {
107
  # to filter duplicate pairs caused by DistributedSampler
108
- 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
109
- 'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)],
110
- 'epi_errs_offset': [batch['epi_errs_offset_left'][batch['offset_bids_left'] == b].cpu().numpy() for b in range(bs)], #only consider left side
111
- 'R_errs': batch['R_errs'],
112
- 't_errs': batch['t_errs'],
113
- 'inliers': batch['inliers']}
114
- ret_dict = {'metrics': metrics}
 
 
 
 
 
 
 
 
 
115
  return ret_dict, rel_pair_names
116
-
117
-
118
  def training_step(self, batch, batch_idx):
119
  self._trainval_inference(batch)
120
-
121
  # logging
122
- if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0:
 
 
 
123
  # scalars
124
- for k, v in batch['loss_scalars'].items():
125
- if not k.startswith('loss_flow') and not k.startswith('conf_'):
126
- self.logger.experiment.add_scalar(f'train/{k}', v, self.global_step)
127
-
128
- #log offset_loss and conf for each layer and level
129
- layer_num=self.loftr_cfg['coarse']['layer_num']
130
  for layer_index in range(layer_num):
131
- log_title='layer_'+str(layer_index)
132
- self.logger.experiment.add_scalar(log_title+'/offset_loss', batch['loss_scalars']['loss_flow_'+str(layer_index)], self.global_step)
133
- self.logger.experiment.add_scalar(log_title+'/conf_', batch['loss_scalars']['conf_'+str(layer_index)],self.global_step)
134
-
 
 
 
 
 
 
 
 
135
  # net-params
136
- if self.config.ASPAN.MATCH_COARSE.MATCH_TYPE == 'sinkhorn':
137
  self.logger.experiment.add_scalar(
138
- f'skh_bin_score', self.matcher.coarse_matching.bin_score.clone().detach().cpu().data, self.global_step)
 
 
 
139
 
140
  # figures
141
  if self.config.TRAINER.ENABLE_PLOTTING:
142
- compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match
143
- figures = make_matching_figures(batch, self.config, self.config.TRAINER.PLOT_MODE)
 
 
 
 
144
  for k, v in figures.items():
145
- self.logger.experiment.add_figure(f'train_match/{k}', v, self.global_step)
 
 
146
 
147
- #plot offset
148
- if self.global_step%200==0:
149
  compute_symmetrical_epipolar_errors_offset_bidirectional(batch)
150
- figures_left = make_matching_figures_offset(batch, self.config, self.config.TRAINER.PLOT_MODE,side='_left')
151
- figures_right = make_matching_figures_offset(batch, self.config, self.config.TRAINER.PLOT_MODE,side='_right')
 
 
 
 
152
  for k, v in figures_left.items():
153
- self.logger.experiment.add_figure(f'train_offset/{k}'+'_left', v, self.global_step)
154
- figures = make_matching_figures_offset(batch, self.config, self.config.TRAINER.PLOT_MODE,side='_right')
 
 
 
 
155
  for k, v in figures_right.items():
156
- self.logger.experiment.add_figure(f'train_offset/{k}'+'_right', v, self.global_step)
157
-
158
- return {'loss': batch['loss']}
 
 
159
 
160
  def training_epoch_end(self, outputs):
161
- avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
162
  if self.trainer.global_rank == 0:
163
  self.logger.experiment.add_scalar(
164
- 'train/avg_loss_on_epoch', avg_loss,
165
- global_step=self.current_epoch)
166
-
167
  def validation_step(self, batch, batch_idx):
168
  self._trainval_inference(batch)
169
-
170
- ret_dict, _ = self._compute_metrics(batch) #this func also compute the epi_errors
171
-
 
 
172
  val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1)
173
  figures = {self.config.TRAINER.PLOT_MODE: []}
174
  figures_offset = {self.config.TRAINER.PLOT_MODE: []}
175
  if batch_idx % val_plot_interval == 0:
176
- figures = make_matching_figures(batch, self.config, mode=self.config.TRAINER.PLOT_MODE)
177
- figures_offset=make_matching_figures_offset(batch, self.config, self.config.TRAINER.PLOT_MODE,'_left')
 
 
 
 
178
  return {
179
  **ret_dict,
180
- 'loss_scalars': batch['loss_scalars'],
181
- 'figures': figures,
182
- 'figures_offset_left':figures_offset
183
  }
184
-
185
  def validation_epoch_end(self, outputs):
186
  # handle multiple validation sets
187
- multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
 
 
188
  multi_val_metrics = defaultdict(list)
189
-
190
  for valset_idx, outputs in enumerate(multi_outputs):
191
  # since pl performs sanity_check at the very begining of the training
192
  cur_epoch = self.trainer.current_epoch
193
- if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check:
 
 
 
194
  cur_epoch = -1
195
 
196
  # 1. loss_scalars: dict of list, on cpu
197
- _loss_scalars = [o['loss_scalars'] for o in outputs]
198
- loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]}
 
 
 
199
 
200
  # 2. val metrics: dict of list, numpy
201
- _metrics = [o['metrics'] for o in outputs]
202
- metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
203
- # NOTE: all ranks need to `aggregate_merics`, but only log at rank-0
204
- val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
 
 
 
 
 
205
  for thr in [5, 10, 20]:
206
- multi_val_metrics[f'auc@{thr}'].append(val_metrics_4tb[f'auc@{thr}'])
207
-
208
  # 3. figures
209
- _figures = [o['figures'] for o in outputs]
210
- figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]}
 
 
 
211
 
212
  # tensorboard records only on rank 0
213
  if self.trainer.global_rank == 0:
214
  for k, v in loss_scalars.items():
215
  mean_v = torch.stack(v).mean()
216
- self.logger.experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch)
 
 
217
 
218
  for k, v in val_metrics_4tb.items():
219
- self.logger.experiment.add_scalar(f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch)
220
-
 
 
221
  for k, v in figures.items():
222
  if self.trainer.global_rank == 0:
223
  for plot_idx, fig in enumerate(v):
224
  self.logger.experiment.add_figure(
225
- f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True)
226
- plt.close('all')
 
 
 
 
227
 
228
  for thr in [5, 10, 20]:
229
  # log on all ranks for ModelCheckpoint callback to work properly
230
- self.log(f'auc@{thr}', torch.tensor(np.mean(multi_val_metrics[f'auc@{thr}']))) # ckpt monitors on this
 
 
231
 
232
  def test_step(self, batch, batch_idx):
233
  with self.profiler.profile("LoFTR"):
@@ -238,39 +329,46 @@ class PL_ASpanFormer(pl.LightningModule):
238
  with self.profiler.profile("dump_results"):
239
  if self.dump_dir is not None:
240
  # dump results for further analysis
241
- keys_to_save = {'mkpts0_f', 'mkpts1_f', 'mconf', 'epi_errs'}
242
- pair_names = list(zip(*batch['pair_names']))
243
- bs = batch['image0'].shape[0]
244
  dumps = []
245
  for b_id in range(bs):
246
  item = {}
247
- mask = batch['m_bids'] == b_id
248
- item['pair_names'] = pair_names[b_id]
249
- item['identifier'] = '#'.join(rel_pair_names[b_id])
250
  for key in keys_to_save:
251
  item[key] = batch[key][mask].cpu().numpy()
252
- for key in ['R_errs', 't_errs', 'inliers']:
253
  item[key] = batch[key][b_id]
254
  dumps.append(item)
255
- ret_dict['dumps'] = dumps
256
 
257
  return ret_dict
258
 
259
  def test_epoch_end(self, outputs):
260
  # metrics: dict of list, numpy
261
- _metrics = [o['metrics'] for o in outputs]
262
- metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
 
 
 
263
 
264
  # [{key: [{...}, *#bs]}, *#batch]
265
  if self.dump_dir is not None:
266
  Path(self.dump_dir).mkdir(parents=True, exist_ok=True)
267
- _dumps = flattenList([o['dumps'] for o in outputs]) # [{...}, #bs*#batch]
268
  dumps = flattenList(gather(_dumps)) # [{...}, #proc*#bs*#batch]
269
- logger.info(f'Prediction and evaluation results will be saved to: {self.dump_dir}')
 
 
270
 
271
  if self.trainer.global_rank == 0:
272
  print(self.profiler.summary())
273
- val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
274
- logger.info('\n' + pprint.pformat(val_metrics_4tb))
 
 
275
  if self.dump_dir is not None:
276
- np.save(Path(self.dump_dir) / 'LoFTR_pred_eval', dumps)
 
 
1
  from collections import defaultdict
2
  import pprint
3
  from loguru import logger
 
9
  from matplotlib import pyplot as plt
10
 
11
  from src.ASpanFormer.aspanformer import ASpanFormer
12
+ from src.ASpanFormer.utils.supervision import (
13
+ compute_supervision_coarse,
14
+ compute_supervision_fine,
15
+ )
16
  from src.losses.aspan_loss import ASpanLoss
17
  from src.optimizers import build_optimizer, build_scheduler
18
  from src.utils.metrics import (
19
+ compute_symmetrical_epipolar_errors,
20
+ compute_symmetrical_epipolar_errors_offset_bidirectional,
21
  compute_pose_errors,
22
+ aggregate_metrics,
23
  )
24
+ from src.utils.plotting import make_matching_figures, make_matching_figures_offset
25
  from src.utils.comm import gather, all_gather
26
  from src.utils.misc import lower_config, flattenList
27
  from src.utils.profiler import PassThroughProfiler
 
37
  # Misc
38
  self.config = config # full config
39
  _config = lower_config(self.config)
40
+ self.loftr_cfg = lower_config(_config["aspan"])
41
  self.profiler = profiler or PassThroughProfiler()
42
+ self.n_vals_plot = max(
43
+ config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1
44
+ )
45
 
46
  # Matcher: LoFTR
47
+ self.matcher = ASpanFormer(config=_config["aspan"])
48
  self.loss = ASpanLoss(_config)
49
 
50
  # Pretrained weights
51
  print(pretrained_ckpt)
52
  if pretrained_ckpt:
53
+ print("load")
54
+ state_dict = torch.load(pretrained_ckpt, map_location="cpu")["state_dict"]
55
+ msg = self.matcher.load_state_dict(state_dict, strict=False)
56
  print(msg)
57
+ logger.info(f"Load '{pretrained_ckpt}' as pretrained checkpoint")
58
+
59
  # Testing
60
  self.dump_dir = dump_dir
61
+
62
  def configure_optimizers(self):
63
  # FIXME: The scheduler did not work properly when `--resume_from_checkpoint`
64
  optimizer = build_optimizer(self, self.config)
65
  scheduler = build_scheduler(self.config, optimizer)
66
  return [optimizer], [scheduler]
67
+
68
  def optimizer_step(
69
+ self,
70
+ epoch,
71
+ batch_idx,
72
+ optimizer,
73
+ optimizer_idx,
74
+ optimizer_closure,
75
+ on_tpu,
76
+ using_native_amp,
77
+ using_lbfgs,
78
+ ):
79
  # learning rate warm up
80
  warmup_step = self.config.TRAINER.WARMUP_STEP
81
  if self.trainer.global_step < warmup_step:
82
+ if self.config.TRAINER.WARMUP_TYPE == "linear":
83
  base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR
84
+ lr = base_lr + (
85
+ self.trainer.global_step / self.config.TRAINER.WARMUP_STEP
86
+ ) * abs(self.config.TRAINER.TRUE_LR - base_lr)
87
  for pg in optimizer.param_groups:
88
+ pg["lr"] = lr
89
+ elif self.config.TRAINER.WARMUP_TYPE == "constant":
90
  pass
91
  else:
92
+ raise ValueError(
93
+ f"Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}"
94
+ )
95
 
96
  # update params
97
  optimizer.step(closure=optimizer_closure)
98
  optimizer.zero_grad()
99
+
100
  def _trainval_inference(self, batch):
101
  with self.profiler.profile("Compute coarse supervision"):
102
+ compute_supervision_coarse(batch, self.config)
103
+
104
  with self.profiler.profile("LoFTR"):
105
+ self.matcher(batch)
106
+
107
  with self.profiler.profile("Compute fine supervision"):
108
+ compute_supervision_fine(batch, self.config)
109
+
110
  with self.profiler.profile("Compute losses"):
111
+ self.loss(batch)
112
+
113
  def _compute_metrics(self, batch):
114
  with self.profiler.profile("Copmute metrics"):
115
+ compute_symmetrical_epipolar_errors(
116
+ batch
117
+ ) # compute epi_errs for each match
118
+ compute_symmetrical_epipolar_errors_offset_bidirectional(
119
+ batch
120
+ ) # compute epi_errs for offset match
121
+ compute_pose_errors(
122
+ batch, self.config
123
+ ) # compute R_errs, t_errs, pose_errs for each pair
124
 
125
+ rel_pair_names = list(zip(*batch["pair_names"]))
126
+ bs = batch["image0"].size(0)
127
  metrics = {
128
  # to filter duplicate pairs caused by DistributedSampler
129
+ "identifiers": ["#".join(rel_pair_names[b]) for b in range(bs)],
130
+ "epi_errs": [
131
+ batch["epi_errs"][batch["m_bids"] == b].cpu().numpy()
132
+ for b in range(bs)
133
+ ],
134
+ "epi_errs_offset": [
135
+ batch["epi_errs_offset_left"][batch["offset_bids_left"] == b]
136
+ .cpu()
137
+ .numpy()
138
+ for b in range(bs)
139
+ ], # only consider left side
140
+ "R_errs": batch["R_errs"],
141
+ "t_errs": batch["t_errs"],
142
+ "inliers": batch["inliers"],
143
+ }
144
+ ret_dict = {"metrics": metrics}
145
  return ret_dict, rel_pair_names
146
+
 
147
  def training_step(self, batch, batch_idx):
148
  self._trainval_inference(batch)
149
+
150
  # logging
151
+ if (
152
+ self.trainer.global_rank == 0
153
+ and self.global_step % self.trainer.log_every_n_steps == 0
154
+ ):
155
  # scalars
156
+ for k, v in batch["loss_scalars"].items():
157
+ if not k.startswith("loss_flow") and not k.startswith("conf_"):
158
+ self.logger.experiment.add_scalar(f"train/{k}", v, self.global_step)
159
+
160
+ # log offset_loss and conf for each layer and level
161
+ layer_num = self.loftr_cfg["coarse"]["layer_num"]
162
  for layer_index in range(layer_num):
163
+ log_title = "layer_" + str(layer_index)
164
+ self.logger.experiment.add_scalar(
165
+ log_title + "/offset_loss",
166
+ batch["loss_scalars"]["loss_flow_" + str(layer_index)],
167
+ self.global_step,
168
+ )
169
+ self.logger.experiment.add_scalar(
170
+ log_title + "/conf_",
171
+ batch["loss_scalars"]["conf_" + str(layer_index)],
172
+ self.global_step,
173
+ )
174
+
175
  # net-params
176
+ if self.config.ASPAN.MATCH_COARSE.MATCH_TYPE == "sinkhorn":
177
  self.logger.experiment.add_scalar(
178
+ f"skh_bin_score",
179
+ self.matcher.coarse_matching.bin_score.clone().detach().cpu().data,
180
+ self.global_step,
181
+ )
182
 
183
  # figures
184
  if self.config.TRAINER.ENABLE_PLOTTING:
185
+ compute_symmetrical_epipolar_errors(
186
+ batch
187
+ ) # compute epi_errs for each match
188
+ figures = make_matching_figures(
189
+ batch, self.config, self.config.TRAINER.PLOT_MODE
190
+ )
191
  for k, v in figures.items():
192
+ self.logger.experiment.add_figure(
193
+ f"train_match/{k}", v, self.global_step
194
+ )
195
 
196
+ # plot offset
197
+ if self.global_step % 200 == 0:
198
  compute_symmetrical_epipolar_errors_offset_bidirectional(batch)
199
+ figures_left = make_matching_figures_offset(
200
+ batch, self.config, self.config.TRAINER.PLOT_MODE, side="_left"
201
+ )
202
+ figures_right = make_matching_figures_offset(
203
+ batch, self.config, self.config.TRAINER.PLOT_MODE, side="_right"
204
+ )
205
  for k, v in figures_left.items():
206
+ self.logger.experiment.add_figure(
207
+ f"train_offset/{k}" + "_left", v, self.global_step
208
+ )
209
+ figures = make_matching_figures_offset(
210
+ batch, self.config, self.config.TRAINER.PLOT_MODE, side="_right"
211
+ )
212
  for k, v in figures_right.items():
213
+ self.logger.experiment.add_figure(
214
+ f"train_offset/{k}" + "_right", v, self.global_step
215
+ )
216
+
217
+ return {"loss": batch["loss"]}
218
 
219
  def training_epoch_end(self, outputs):
220
+ avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
221
  if self.trainer.global_rank == 0:
222
  self.logger.experiment.add_scalar(
223
+ "train/avg_loss_on_epoch", avg_loss, global_step=self.current_epoch
224
+ )
225
+
226
  def validation_step(self, batch, batch_idx):
227
  self._trainval_inference(batch)
228
+
229
+ ret_dict, _ = self._compute_metrics(
230
+ batch
231
+ ) # this func also compute the epi_errors
232
+
233
  val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1)
234
  figures = {self.config.TRAINER.PLOT_MODE: []}
235
  figures_offset = {self.config.TRAINER.PLOT_MODE: []}
236
  if batch_idx % val_plot_interval == 0:
237
+ figures = make_matching_figures(
238
+ batch, self.config, mode=self.config.TRAINER.PLOT_MODE
239
+ )
240
+ figures_offset = make_matching_figures_offset(
241
+ batch, self.config, self.config.TRAINER.PLOT_MODE, "_left"
242
+ )
243
  return {
244
  **ret_dict,
245
+ "loss_scalars": batch["loss_scalars"],
246
+ "figures": figures,
247
+ "figures_offset_left": figures_offset,
248
  }
249
+
250
  def validation_epoch_end(self, outputs):
251
  # handle multiple validation sets
252
+ multi_outputs = (
253
+ [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
254
+ )
255
  multi_val_metrics = defaultdict(list)
256
+
257
  for valset_idx, outputs in enumerate(multi_outputs):
258
  # since pl performs sanity_check at the very begining of the training
259
  cur_epoch = self.trainer.current_epoch
260
+ if (
261
+ not self.trainer.resume_from_checkpoint
262
+ and self.trainer.running_sanity_check
263
+ ):
264
  cur_epoch = -1
265
 
266
  # 1. loss_scalars: dict of list, on cpu
267
+ _loss_scalars = [o["loss_scalars"] for o in outputs]
268
+ loss_scalars = {
269
+ k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars]))
270
+ for k in _loss_scalars[0]
271
+ }
272
 
273
  # 2. val metrics: dict of list, numpy
274
+ _metrics = [o["metrics"] for o in outputs]
275
+ metrics = {
276
+ k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics])))
277
+ for k in _metrics[0]
278
+ }
279
+ # NOTE: all ranks need to `aggregate_merics`, but only log at rank-0
280
+ val_metrics_4tb = aggregate_metrics(
281
+ metrics, self.config.TRAINER.EPI_ERR_THR
282
+ )
283
  for thr in [5, 10, 20]:
284
+ multi_val_metrics[f"auc@{thr}"].append(val_metrics_4tb[f"auc@{thr}"])
285
+
286
  # 3. figures
287
+ _figures = [o["figures"] for o in outputs]
288
+ figures = {
289
+ k: flattenList(gather(flattenList([_me[k] for _me in _figures])))
290
+ for k in _figures[0]
291
+ }
292
 
293
  # tensorboard records only on rank 0
294
  if self.trainer.global_rank == 0:
295
  for k, v in loss_scalars.items():
296
  mean_v = torch.stack(v).mean()
297
+ self.logger.experiment.add_scalar(
298
+ f"val_{valset_idx}/avg_{k}", mean_v, global_step=cur_epoch
299
+ )
300
 
301
  for k, v in val_metrics_4tb.items():
302
+ self.logger.experiment.add_scalar(
303
+ f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch
304
+ )
305
+
306
  for k, v in figures.items():
307
  if self.trainer.global_rank == 0:
308
  for plot_idx, fig in enumerate(v):
309
  self.logger.experiment.add_figure(
310
+ f"val_match_{valset_idx}/{k}/pair-{plot_idx}",
311
+ fig,
312
+ cur_epoch,
313
+ close=True,
314
+ )
315
+ plt.close("all")
316
 
317
  for thr in [5, 10, 20]:
318
  # log on all ranks for ModelCheckpoint callback to work properly
319
+ self.log(
320
+ f"auc@{thr}", torch.tensor(np.mean(multi_val_metrics[f"auc@{thr}"]))
321
+ ) # ckpt monitors on this
322
 
323
  def test_step(self, batch, batch_idx):
324
  with self.profiler.profile("LoFTR"):
 
329
  with self.profiler.profile("dump_results"):
330
  if self.dump_dir is not None:
331
  # dump results for further analysis
332
+ keys_to_save = {"mkpts0_f", "mkpts1_f", "mconf", "epi_errs"}
333
+ pair_names = list(zip(*batch["pair_names"]))
334
+ bs = batch["image0"].shape[0]
335
  dumps = []
336
  for b_id in range(bs):
337
  item = {}
338
+ mask = batch["m_bids"] == b_id
339
+ item["pair_names"] = pair_names[b_id]
340
+ item["identifier"] = "#".join(rel_pair_names[b_id])
341
  for key in keys_to_save:
342
  item[key] = batch[key][mask].cpu().numpy()
343
+ for key in ["R_errs", "t_errs", "inliers"]:
344
  item[key] = batch[key][b_id]
345
  dumps.append(item)
346
+ ret_dict["dumps"] = dumps
347
 
348
  return ret_dict
349
 
350
  def test_epoch_end(self, outputs):
351
  # metrics: dict of list, numpy
352
+ _metrics = [o["metrics"] for o in outputs]
353
+ metrics = {
354
+ k: flattenList(gather(flattenList([_me[k] for _me in _metrics])))
355
+ for k in _metrics[0]
356
+ }
357
 
358
  # [{key: [{...}, *#bs]}, *#batch]
359
  if self.dump_dir is not None:
360
  Path(self.dump_dir).mkdir(parents=True, exist_ok=True)
361
+ _dumps = flattenList([o["dumps"] for o in outputs]) # [{...}, #bs*#batch]
362
  dumps = flattenList(gather(_dumps)) # [{...}, #proc*#bs*#batch]
363
+ logger.info(
364
+ f"Prediction and evaluation results will be saved to: {self.dump_dir}"
365
+ )
366
 
367
  if self.trainer.global_rank == 0:
368
  print(self.profiler.summary())
369
+ val_metrics_4tb = aggregate_metrics(
370
+ metrics, self.config.TRAINER.EPI_ERR_THR
371
+ )
372
+ logger.info("\n" + pprint.pformat(val_metrics_4tb))
373
  if self.dump_dir is not None:
374
+ np.save(Path(self.dump_dir) / "LoFTR_pred_eval", dumps)
third_party/ASpanFormer/src/losses/aspan_loss.py CHANGED
@@ -3,48 +3,55 @@ from loguru import logger
3
  import torch
4
  import torch.nn as nn
5
 
 
6
  class ASpanLoss(nn.Module):
7
  def __init__(self, config):
8
  super().__init__()
9
  self.config = config # config under the global namespace
10
- self.loss_config = config['aspan']['loss']
11
- self.match_type = self.config['aspan']['match_coarse']['match_type']
12
- self.sparse_spvs = self.config['aspan']['match_coarse']['sparse_spvs']
13
- self.flow_weight=self.config['aspan']['loss']['flow_weight']
14
 
15
  # coarse-level
16
- self.correct_thr = self.loss_config['fine_correct_thr']
17
- self.c_pos_w = self.loss_config['pos_weight']
18
- self.c_neg_w = self.loss_config['neg_weight']
19
  # fine-level
20
- self.fine_type = self.loss_config['fine_type']
21
-
22
- def compute_flow_loss(self,coarse_corr_gt,flow_list,h0,w0,h1,w1):
23
- #coarse_corr_gt:[[batch_indices],[left_indices],[right_indices]]
24
- #flow_list: [L,B,H,W,4]
25
- loss1=self.flow_loss_worker(flow_list[0],coarse_corr_gt[0],coarse_corr_gt[1],coarse_corr_gt[2],w1)
26
- loss2=self.flow_loss_worker(flow_list[1],coarse_corr_gt[0],coarse_corr_gt[2],coarse_corr_gt[1],w0)
27
- total_loss=(loss1+loss2)/2
 
 
 
 
28
  return total_loss
29
 
30
- def flow_loss_worker(self,flow,batch_indicies,self_indicies,cross_indicies,w):
31
- bs,layer_num=flow.shape[1],flow.shape[0]
32
- flow=flow.view(layer_num,bs,-1,4)
33
- gt_flow=torch.stack([cross_indicies%w,cross_indicies//w],dim=1)
34
 
35
- total_loss_list=[]
36
  for layer_index in range(layer_num):
37
- cur_flow_list=flow[layer_index]
38
- spv_flow=cur_flow_list[batch_indicies,self_indicies][:,:2]
39
- spv_conf=cur_flow_list[batch_indicies,self_indicies][:,2:]#[#coarse,2]
40
- l2_flow_dis=((gt_flow-spv_flow)**2) #[#coarse,2]
41
- total_loss=(spv_conf+torch.exp(-spv_conf)*l2_flow_dis) #[#coarse,2]
 
 
42
  total_loss_list.append(total_loss.mean())
43
- total_loss=torch.stack(total_loss_list,dim=-1)*self.flow_weight
44
  return total_loss
45
-
46
  def compute_coarse_loss(self, conf, conf_gt, weight=None):
47
- """ Point-wise CE / Focal Loss with 0 / 1 confidence as gt.
48
  Args:
49
  conf (torch.Tensor): (N, HW0, HW1) / (N, HW0+1, HW1+1)
50
  conf_gt (torch.Tensor): (N, HW0, HW1)
@@ -56,38 +63,44 @@ class ASpanLoss(nn.Module):
56
  if not pos_mask.any(): # assign a wrong gt
57
  pos_mask[0, 0, 0] = True
58
  if weight is not None:
59
- weight[0, 0, 0] = 0.
60
- c_pos_w = 0.
61
  if not neg_mask.any():
62
  neg_mask[0, 0, 0] = True
63
  if weight is not None:
64
- weight[0, 0, 0] = 0.
65
- c_neg_w = 0.
66
-
67
- if self.loss_config['coarse_type'] == 'cross_entropy':
68
- assert not self.sparse_spvs, 'Sparse Supervision for cross-entropy not implemented!'
69
- conf = torch.clamp(conf, 1e-6, 1-1e-6)
70
- loss_pos = - torch.log(conf[pos_mask])
71
- loss_neg = - torch.log(1 - conf[neg_mask])
 
 
72
  if weight is not None:
73
  loss_pos = loss_pos * weight[pos_mask]
74
  loss_neg = loss_neg * weight[neg_mask]
75
  return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()
76
- elif self.loss_config['coarse_type'] == 'focal':
77
- conf = torch.clamp(conf, 1e-6, 1-1e-6)
78
- alpha = self.loss_config['focal_alpha']
79
- gamma = self.loss_config['focal_gamma']
80
-
81
  if self.sparse_spvs:
82
- pos_conf = conf[:, :-1, :-1][pos_mask] \
83
- if self.match_type == 'sinkhorn' \
84
- else conf[pos_mask]
85
- loss_pos = - alpha * torch.pow(1 - pos_conf, gamma) * pos_conf.log()
 
 
86
  # calculate losses for negative samples
87
- if self.match_type == 'sinkhorn':
88
  neg0, neg1 = conf_gt.sum(-1) == 0, conf_gt.sum(1) == 0
89
- neg_conf = torch.cat([conf[:, :-1, -1][neg0], conf[:, -1, :-1][neg1]], 0)
90
- loss_neg = - alpha * torch.pow(1 - neg_conf, gamma) * neg_conf.log()
 
 
91
  else:
92
  # These is no dustbin for dual_softmax, so we left unmatchable patches without supervision.
93
  # we could also add 'pseudo negtive-samples'
@@ -97,32 +110,46 @@ class ASpanLoss(nn.Module):
97
  # Different from dense-spvs, the loss w.r.t. padded regions aren't directly zeroed out,
98
  # but only through manually setting corresponding regions in sim_matrix to '-inf'.
99
  loss_pos = loss_pos * weight[pos_mask]
100
- if self.match_type == 'sinkhorn':
101
  neg_w0 = (weight.sum(-1) != 0)[neg0]
102
  neg_w1 = (weight.sum(1) != 0)[neg1]
103
  neg_mask = torch.cat([neg_w0, neg_w1], 0)
104
  loss_neg = loss_neg[neg_mask]
105
-
106
- loss = c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean() \
107
- if self.match_type == 'sinkhorn' \
108
- else c_pos_w * loss_pos.mean()
 
 
109
  return loss
110
  # positive and negative elements occupy similar propotions. => more balanced loss weights needed
111
  else: # dense supervision (in the case of match_type=='sinkhorn', the dustbin is not supervised.)
112
- loss_pos = - alpha * torch.pow(1 - conf[pos_mask], gamma) * (conf[pos_mask]).log()
113
- loss_neg = - alpha * torch.pow(conf[neg_mask], gamma) * (1 - conf[neg_mask]).log()
 
 
 
 
 
 
 
 
114
  if weight is not None:
115
  loss_pos = loss_pos * weight[pos_mask]
116
  loss_neg = loss_neg * weight[neg_mask]
117
  return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()
118
  # each negative element occupy a smaller propotion than positive elements. => higher negative loss weight needed
119
  else:
120
- raise ValueError('Unknown coarse loss: {type}'.format(type=self.loss_config['coarse_type']))
121
-
 
 
 
 
122
  def compute_fine_loss(self, expec_f, expec_f_gt):
123
- if self.fine_type == 'l2_with_std':
124
  return self._compute_fine_loss_l2_std(expec_f, expec_f_gt)
125
- elif self.fine_type == 'l2':
126
  return self._compute_fine_loss_l2(expec_f, expec_f_gt)
127
  else:
128
  raise NotImplementedError()
@@ -133,9 +160,13 @@ class ASpanLoss(nn.Module):
133
  expec_f (torch.Tensor): [M, 2] <x, y>
134
  expec_f_gt (torch.Tensor): [M, 2] <x, y>
135
  """
136
- correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr
 
 
137
  if correct_mask.sum() == 0:
138
- if self.training: # this seldomly happen when training, since we pad prediction with gt
 
 
139
  logger.warning("assign a false supervision to avoid ddp deadlock")
140
  correct_mask[0] = True
141
  else:
@@ -150,20 +181,26 @@ class ASpanLoss(nn.Module):
150
  expec_f_gt (torch.Tensor): [M, 2] <x, y>
151
  """
152
  # correct_mask tells you which pair to compute fine-loss
153
- correct_mask = torch.linalg.norm(expec_f_gt, ord=float('inf'), dim=1) < self.correct_thr
 
 
154
 
155
  # use std as weight that measures uncertainty
156
  std = expec_f[:, 2]
157
- inverse_std = 1. / torch.clamp(std, min=1e-10)
158
- weight = (inverse_std / torch.mean(inverse_std)).detach() # avoid minizing loss through increase std
 
 
159
 
160
  # corner case: no correct coarse match found
161
  if not correct_mask.any():
162
- if self.training: # this seldomly happen during training, since we pad prediction with gt
163
- # sometimes there is not coarse-level gt at all.
 
 
164
  logger.warning("assign a false supervision to avoid ddp deadlock")
165
  correct_mask[0] = True
166
- weight[0] = 0.
167
  else:
168
  return None
169
 
@@ -172,12 +209,15 @@ class ASpanLoss(nn.Module):
172
  loss = (flow_l2 * weight[correct_mask]).mean()
173
 
174
  return loss
175
-
176
  @torch.no_grad()
177
  def compute_c_weight(self, data):
178
- """ compute element-wise weights for computing coarse-level loss. """
179
- if 'mask0' in data:
180
- c_weight = (data['mask0'].flatten(-2)[..., None] * data['mask1'].flatten(-2)[:, None]).float()
 
 
 
181
  else:
182
  c_weight = None
183
  return c_weight
@@ -196,36 +236,54 @@ class ASpanLoss(nn.Module):
196
 
197
  # 1. coarse-level loss
198
  loss_c = self.compute_coarse_loss(
199
- data['conf_matrix_with_bin'] if self.sparse_spvs and self.match_type == 'sinkhorn' \
200
- else data['conf_matrix'],
201
- data['conf_matrix_gt'],
202
- weight=c_weight)
203
- loss = loss_c * self.loss_config['coarse_weight']
 
 
204
  loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()})
205
 
206
  # 2. fine-level loss
207
- loss_f = self.compute_fine_loss(data['expec_f'], data['expec_f_gt'])
208
  if loss_f is not None:
209
- loss += loss_f * self.loss_config['fine_weight']
210
- loss_scalars.update({"loss_f": loss_f.clone().detach().cpu()})
211
  else:
212
  assert self.training is False
213
- loss_scalars.update({'loss_f': torch.tensor(1.)}) # 1 is the upper bound
214
-
215
  # 3. flow loss
216
- coarse_corr=[data['spv_b_ids'],data['spv_i_ids'],data['spv_j_ids']]
217
- loss_flow = self.compute_flow_loss(coarse_corr,data['predict_flow'],\
218
- data['hw0_c'][0],data['hw0_c'][1],data['hw1_c'][0],data['hw1_c'][1])
219
- loss_flow=loss_flow*self.flow_weight
220
- for index,loss_off in enumerate(loss_flow):
221
- loss_scalars.update({'loss_flow_'+str(index): loss_off.clone().detach().cpu()}) # 1 is the upper bound
222
- conf=data['predict_flow'][0][:,:,:,:,2:]
223
- layer_num=conf.shape[0]
 
 
 
 
 
 
 
 
224
  for layer_index in range(layer_num):
225
- loss_scalars.update({'conf_'+str(layer_index): conf[layer_index].mean().clone().detach().cpu()}) # 1 is the upper bound
226
-
227
-
228
- loss+=loss_flow.sum()
229
- #print((loss_c * self.loss_config['coarse_weight']).data,loss_flow.data)
230
- loss_scalars.update({'loss': loss.clone().detach().cpu()})
 
 
 
 
 
 
 
 
231
  data.update({"loss": loss, "loss_scalars": loss_scalars})
 
3
  import torch
4
  import torch.nn as nn
5
 
6
+
7
  class ASpanLoss(nn.Module):
8
  def __init__(self, config):
9
  super().__init__()
10
  self.config = config # config under the global namespace
11
+ self.loss_config = config["aspan"]["loss"]
12
+ self.match_type = self.config["aspan"]["match_coarse"]["match_type"]
13
+ self.sparse_spvs = self.config["aspan"]["match_coarse"]["sparse_spvs"]
14
+ self.flow_weight = self.config["aspan"]["loss"]["flow_weight"]
15
 
16
  # coarse-level
17
+ self.correct_thr = self.loss_config["fine_correct_thr"]
18
+ self.c_pos_w = self.loss_config["pos_weight"]
19
+ self.c_neg_w = self.loss_config["neg_weight"]
20
  # fine-level
21
+ self.fine_type = self.loss_config["fine_type"]
22
+
23
+ def compute_flow_loss(self, coarse_corr_gt, flow_list, h0, w0, h1, w1):
24
+ # coarse_corr_gt:[[batch_indices],[left_indices],[right_indices]]
25
+ # flow_list: [L,B,H,W,4]
26
+ loss1 = self.flow_loss_worker(
27
+ flow_list[0], coarse_corr_gt[0], coarse_corr_gt[1], coarse_corr_gt[2], w1
28
+ )
29
+ loss2 = self.flow_loss_worker(
30
+ flow_list[1], coarse_corr_gt[0], coarse_corr_gt[2], coarse_corr_gt[1], w0
31
+ )
32
+ total_loss = (loss1 + loss2) / 2
33
  return total_loss
34
 
35
+ def flow_loss_worker(self, flow, batch_indicies, self_indicies, cross_indicies, w):
36
+ bs, layer_num = flow.shape[1], flow.shape[0]
37
+ flow = flow.view(layer_num, bs, -1, 4)
38
+ gt_flow = torch.stack([cross_indicies % w, cross_indicies // w], dim=1)
39
 
40
+ total_loss_list = []
41
  for layer_index in range(layer_num):
42
+ cur_flow_list = flow[layer_index]
43
+ spv_flow = cur_flow_list[batch_indicies, self_indicies][:, :2]
44
+ spv_conf = cur_flow_list[batch_indicies, self_indicies][
45
+ :, 2:
46
+ ] # [#coarse,2]
47
+ l2_flow_dis = (gt_flow - spv_flow) ** 2 # [#coarse,2]
48
+ total_loss = spv_conf + torch.exp(-spv_conf) * l2_flow_dis # [#coarse,2]
49
  total_loss_list.append(total_loss.mean())
50
+ total_loss = torch.stack(total_loss_list, dim=-1) * self.flow_weight
51
  return total_loss
52
+
53
  def compute_coarse_loss(self, conf, conf_gt, weight=None):
54
+ """Point-wise CE / Focal Loss with 0 / 1 confidence as gt.
55
  Args:
56
  conf (torch.Tensor): (N, HW0, HW1) / (N, HW0+1, HW1+1)
57
  conf_gt (torch.Tensor): (N, HW0, HW1)
 
63
  if not pos_mask.any(): # assign a wrong gt
64
  pos_mask[0, 0, 0] = True
65
  if weight is not None:
66
+ weight[0, 0, 0] = 0.0
67
+ c_pos_w = 0.0
68
  if not neg_mask.any():
69
  neg_mask[0, 0, 0] = True
70
  if weight is not None:
71
+ weight[0, 0, 0] = 0.0
72
+ c_neg_w = 0.0
73
+
74
+ if self.loss_config["coarse_type"] == "cross_entropy":
75
+ assert (
76
+ not self.sparse_spvs
77
+ ), "Sparse Supervision for cross-entropy not implemented!"
78
+ conf = torch.clamp(conf, 1e-6, 1 - 1e-6)
79
+ loss_pos = -torch.log(conf[pos_mask])
80
+ loss_neg = -torch.log(1 - conf[neg_mask])
81
  if weight is not None:
82
  loss_pos = loss_pos * weight[pos_mask]
83
  loss_neg = loss_neg * weight[neg_mask]
84
  return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()
85
+ elif self.loss_config["coarse_type"] == "focal":
86
+ conf = torch.clamp(conf, 1e-6, 1 - 1e-6)
87
+ alpha = self.loss_config["focal_alpha"]
88
+ gamma = self.loss_config["focal_gamma"]
89
+
90
  if self.sparse_spvs:
91
+ pos_conf = (
92
+ conf[:, :-1, :-1][pos_mask]
93
+ if self.match_type == "sinkhorn"
94
+ else conf[pos_mask]
95
+ )
96
+ loss_pos = -alpha * torch.pow(1 - pos_conf, gamma) * pos_conf.log()
97
  # calculate losses for negative samples
98
+ if self.match_type == "sinkhorn":
99
  neg0, neg1 = conf_gt.sum(-1) == 0, conf_gt.sum(1) == 0
100
+ neg_conf = torch.cat(
101
+ [conf[:, :-1, -1][neg0], conf[:, -1, :-1][neg1]], 0
102
+ )
103
+ loss_neg = -alpha * torch.pow(1 - neg_conf, gamma) * neg_conf.log()
104
  else:
105
  # These is no dustbin for dual_softmax, so we left unmatchable patches without supervision.
106
  # we could also add 'pseudo negtive-samples'
 
110
  # Different from dense-spvs, the loss w.r.t. padded regions aren't directly zeroed out,
111
  # but only through manually setting corresponding regions in sim_matrix to '-inf'.
112
  loss_pos = loss_pos * weight[pos_mask]
113
+ if self.match_type == "sinkhorn":
114
  neg_w0 = (weight.sum(-1) != 0)[neg0]
115
  neg_w1 = (weight.sum(1) != 0)[neg1]
116
  neg_mask = torch.cat([neg_w0, neg_w1], 0)
117
  loss_neg = loss_neg[neg_mask]
118
+
119
+ loss = (
120
+ c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()
121
+ if self.match_type == "sinkhorn"
122
+ else c_pos_w * loss_pos.mean()
123
+ )
124
  return loss
125
  # positive and negative elements occupy similar propotions. => more balanced loss weights needed
126
  else: # dense supervision (in the case of match_type=='sinkhorn', the dustbin is not supervised.)
127
+ loss_pos = (
128
+ -alpha
129
+ * torch.pow(1 - conf[pos_mask], gamma)
130
+ * (conf[pos_mask]).log()
131
+ )
132
+ loss_neg = (
133
+ -alpha
134
+ * torch.pow(conf[neg_mask], gamma)
135
+ * (1 - conf[neg_mask]).log()
136
+ )
137
  if weight is not None:
138
  loss_pos = loss_pos * weight[pos_mask]
139
  loss_neg = loss_neg * weight[neg_mask]
140
  return c_pos_w * loss_pos.mean() + c_neg_w * loss_neg.mean()
141
  # each negative element occupy a smaller propotion than positive elements. => higher negative loss weight needed
142
  else:
143
+ raise ValueError(
144
+ "Unknown coarse loss: {type}".format(
145
+ type=self.loss_config["coarse_type"]
146
+ )
147
+ )
148
+
149
  def compute_fine_loss(self, expec_f, expec_f_gt):
150
+ if self.fine_type == "l2_with_std":
151
  return self._compute_fine_loss_l2_std(expec_f, expec_f_gt)
152
+ elif self.fine_type == "l2":
153
  return self._compute_fine_loss_l2(expec_f, expec_f_gt)
154
  else:
155
  raise NotImplementedError()
 
160
  expec_f (torch.Tensor): [M, 2] <x, y>
161
  expec_f_gt (torch.Tensor): [M, 2] <x, y>
162
  """
163
+ correct_mask = (
164
+ torch.linalg.norm(expec_f_gt, ord=float("inf"), dim=1) < self.correct_thr
165
+ )
166
  if correct_mask.sum() == 0:
167
+ if (
168
+ self.training
169
+ ): # this seldomly happen when training, since we pad prediction with gt
170
  logger.warning("assign a false supervision to avoid ddp deadlock")
171
  correct_mask[0] = True
172
  else:
 
181
  expec_f_gt (torch.Tensor): [M, 2] <x, y>
182
  """
183
  # correct_mask tells you which pair to compute fine-loss
184
+ correct_mask = (
185
+ torch.linalg.norm(expec_f_gt, ord=float("inf"), dim=1) < self.correct_thr
186
+ )
187
 
188
  # use std as weight that measures uncertainty
189
  std = expec_f[:, 2]
190
+ inverse_std = 1.0 / torch.clamp(std, min=1e-10)
191
+ weight = (
192
+ inverse_std / torch.mean(inverse_std)
193
+ ).detach() # avoid minizing loss through increase std
194
 
195
  # corner case: no correct coarse match found
196
  if not correct_mask.any():
197
+ if (
198
+ self.training
199
+ ): # this seldomly happen during training, since we pad prediction with gt
200
+ # sometimes there is not coarse-level gt at all.
201
  logger.warning("assign a false supervision to avoid ddp deadlock")
202
  correct_mask[0] = True
203
+ weight[0] = 0.0
204
  else:
205
  return None
206
 
 
209
  loss = (flow_l2 * weight[correct_mask]).mean()
210
 
211
  return loss
212
+
213
  @torch.no_grad()
214
  def compute_c_weight(self, data):
215
+ """compute element-wise weights for computing coarse-level loss."""
216
+ if "mask0" in data:
217
+ c_weight = (
218
+ data["mask0"].flatten(-2)[..., None]
219
+ * data["mask1"].flatten(-2)[:, None]
220
+ ).float()
221
  else:
222
  c_weight = None
223
  return c_weight
 
236
 
237
  # 1. coarse-level loss
238
  loss_c = self.compute_coarse_loss(
239
+ data["conf_matrix_with_bin"]
240
+ if self.sparse_spvs and self.match_type == "sinkhorn"
241
+ else data["conf_matrix"],
242
+ data["conf_matrix_gt"],
243
+ weight=c_weight,
244
+ )
245
+ loss = loss_c * self.loss_config["coarse_weight"]
246
  loss_scalars.update({"loss_c": loss_c.clone().detach().cpu()})
247
 
248
  # 2. fine-level loss
249
+ loss_f = self.compute_fine_loss(data["expec_f"], data["expec_f_gt"])
250
  if loss_f is not None:
251
+ loss += loss_f * self.loss_config["fine_weight"]
252
+ loss_scalars.update({"loss_f": loss_f.clone().detach().cpu()})
253
  else:
254
  assert self.training is False
255
+ loss_scalars.update({"loss_f": torch.tensor(1.0)}) # 1 is the upper bound
256
+
257
  # 3. flow loss
258
+ coarse_corr = [data["spv_b_ids"], data["spv_i_ids"], data["spv_j_ids"]]
259
+ loss_flow = self.compute_flow_loss(
260
+ coarse_corr,
261
+ data["predict_flow"],
262
+ data["hw0_c"][0],
263
+ data["hw0_c"][1],
264
+ data["hw1_c"][0],
265
+ data["hw1_c"][1],
266
+ )
267
+ loss_flow = loss_flow * self.flow_weight
268
+ for index, loss_off in enumerate(loss_flow):
269
+ loss_scalars.update(
270
+ {"loss_flow_" + str(index): loss_off.clone().detach().cpu()}
271
+ ) # 1 is the upper bound
272
+ conf = data["predict_flow"][0][:, :, :, :, 2:]
273
+ layer_num = conf.shape[0]
274
  for layer_index in range(layer_num):
275
+ loss_scalars.update(
276
+ {
277
+ "conf_"
278
+ + str(layer_index): conf[layer_index]
279
+ .mean()
280
+ .clone()
281
+ .detach()
282
+ .cpu()
283
+ }
284
+ ) # 1 is the upper bound
285
+
286
+ loss += loss_flow.sum()
287
+ # print((loss_c * self.loss_config['coarse_weight']).data,loss_flow.data)
288
+ loss_scalars.update({"loss": loss.clone().detach().cpu()})
289
  data.update({"loss": loss, "loss_scalars": loss_scalars})
third_party/ASpanFormer/src/optimizers/__init__.py CHANGED
@@ -7,9 +7,13 @@ def build_optimizer(model, config):
7
  lr = config.TRAINER.TRUE_LR
8
 
9
  if name == "adam":
10
- return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY)
 
 
11
  elif name == "adamw":
12
- return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY)
 
 
13
  else:
14
  raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!")
15
 
@@ -24,18 +28,27 @@ def build_scheduler(config, optimizer):
24
  'frequency': x, (optional)
25
  }
26
  """
27
- scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL}
28
  name = config.TRAINER.SCHEDULER
29
 
30
- if name == 'MultiStepLR':
31
  scheduler.update(
32
- {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)})
33
- elif name == 'CosineAnnealing':
 
 
 
 
 
 
 
34
  scheduler.update(
35
- {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)})
36
- elif name == 'ExponentialLR':
 
37
  scheduler.update(
38
- {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)})
 
39
  else:
40
  raise NotImplementedError()
41
 
 
7
  lr = config.TRAINER.TRUE_LR
8
 
9
  if name == "adam":
10
+ return torch.optim.Adam(
11
+ model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY
12
+ )
13
  elif name == "adamw":
14
+ return torch.optim.AdamW(
15
+ model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY
16
+ )
17
  else:
18
  raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!")
19
 
 
28
  'frequency': x, (optional)
29
  }
30
  """
31
+ scheduler = {"interval": config.TRAINER.SCHEDULER_INTERVAL}
32
  name = config.TRAINER.SCHEDULER
33
 
34
+ if name == "MultiStepLR":
35
  scheduler.update(
36
+ {
37
+ "scheduler": MultiStepLR(
38
+ optimizer,
39
+ config.TRAINER.MSLR_MILESTONES,
40
+ gamma=config.TRAINER.MSLR_GAMMA,
41
+ )
42
+ }
43
+ )
44
+ elif name == "CosineAnnealing":
45
  scheduler.update(
46
+ {"scheduler": CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)}
47
+ )
48
+ elif name == "ExponentialLR":
49
  scheduler.update(
50
+ {"scheduler": ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)}
51
+ )
52
  else:
53
  raise NotImplementedError()
54
 
third_party/ASpanFormer/src/utils/augment.py CHANGED
@@ -7,16 +7,21 @@ class DarkAug(object):
7
  """
8
 
9
  def __init__(self) -> None:
10
- self.augmentor = A.Compose([
11
- A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)),
12
- A.Blur(p=0.1, blur_limit=(3, 9)),
13
- A.MotionBlur(p=0.2, blur_limit=(3, 25)),
14
- A.RandomGamma(p=0.1, gamma_limit=(15, 65)),
15
- A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40))
16
- ], p=0.75)
 
 
 
 
 
17
 
18
  def __call__(self, x):
19
- return self.augmentor(image=x)['image']
20
 
21
 
22
  class MobileAug(object):
@@ -25,31 +30,36 @@ class MobileAug(object):
25
  """
26
 
27
  def __init__(self):
28
- self.augmentor = A.Compose([
29
- A.MotionBlur(p=0.25),
30
- A.ColorJitter(p=0.5),
31
- A.RandomRain(p=0.1), # random occlusion
32
- A.RandomSunFlare(p=0.1),
33
- A.JpegCompression(p=0.25),
34
- A.ISONoise(p=0.25)
35
- ], p=1.0)
 
 
 
36
 
37
  def __call__(self, x):
38
- return self.augmentor(image=x)['image']
39
 
40
 
41
  def build_augmentor(method=None, **kwargs):
42
  if method is not None:
43
- raise NotImplementedError('Using of augmentation functions are not supported yet!')
44
- if method == 'dark':
 
 
45
  return DarkAug()
46
- elif method == 'mobile':
47
  return MobileAug()
48
  elif method is None:
49
  return None
50
  else:
51
- raise ValueError(f'Invalid augmentation method: {method}')
52
 
53
 
54
- if __name__ == '__main__':
55
- augmentor = build_augmentor('FDA')
 
7
  """
8
 
9
  def __init__(self) -> None:
10
+ self.augmentor = A.Compose(
11
+ [
12
+ A.RandomBrightnessContrast(
13
+ p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)
14
+ ),
15
+ A.Blur(p=0.1, blur_limit=(3, 9)),
16
+ A.MotionBlur(p=0.2, blur_limit=(3, 25)),
17
+ A.RandomGamma(p=0.1, gamma_limit=(15, 65)),
18
+ A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40)),
19
+ ],
20
+ p=0.75,
21
+ )
22
 
23
  def __call__(self, x):
24
+ return self.augmentor(image=x)["image"]
25
 
26
 
27
  class MobileAug(object):
 
30
  """
31
 
32
  def __init__(self):
33
+ self.augmentor = A.Compose(
34
+ [
35
+ A.MotionBlur(p=0.25),
36
+ A.ColorJitter(p=0.5),
37
+ A.RandomRain(p=0.1), # random occlusion
38
+ A.RandomSunFlare(p=0.1),
39
+ A.JpegCompression(p=0.25),
40
+ A.ISONoise(p=0.25),
41
+ ],
42
+ p=1.0,
43
+ )
44
 
45
  def __call__(self, x):
46
+ return self.augmentor(image=x)["image"]
47
 
48
 
49
  def build_augmentor(method=None, **kwargs):
50
  if method is not None:
51
+ raise NotImplementedError(
52
+ "Using of augmentation functions are not supported yet!"
53
+ )
54
+ if method == "dark":
55
  return DarkAug()
56
+ elif method == "mobile":
57
  return MobileAug()
58
  elif method is None:
59
  return None
60
  else:
61
+ raise ValueError(f"Invalid augmentation method: {method}")
62
 
63
 
64
+ if __name__ == "__main__":
65
+ augmentor = build_augmentor("FDA")
third_party/ASpanFormer/src/utils/comm.py CHANGED
@@ -98,11 +98,11 @@ def _serialize_to_tensor(data, group):
98
  device = torch.device("cpu" if backend == "gloo" else "cuda")
99
 
100
  buffer = pickle.dumps(data)
101
- if len(buffer) > 1024 ** 3:
102
  logger = logging.getLogger(__name__)
103
  logger.warning(
104
  "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
105
- get_rank(), len(buffer) / (1024 ** 3), device
106
  )
107
  )
108
  storage = torch.ByteStorage.from_buffer(buffer)
@@ -122,7 +122,8 @@ def _pad_to_largest_tensor(tensor, group):
122
  ), "comm.gather/all_gather must be called from ranks within the given group!"
123
  local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
124
  size_list = [
125
- torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
 
126
  ]
127
  dist.all_gather(size_list, local_size, group=group)
128
 
@@ -133,7 +134,9 @@ def _pad_to_largest_tensor(tensor, group):
133
  # we pad the tensor because torch all_gather does not support
134
  # gathering tensors of different shapes
135
  if local_size != max_size:
136
- padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
 
 
137
  tensor = torch.cat((tensor, padding), dim=0)
138
  return size_list, tensor
139
 
@@ -164,7 +167,8 @@ def all_gather(data, group=None):
164
 
165
  # receiving Tensor from all ranks
166
  tensor_list = [
167
- torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
 
168
  ]
169
  dist.all_gather(tensor_list, tensor, group=group)
170
 
@@ -205,7 +209,8 @@ def gather(data, dst=0, group=None):
205
  if rank == dst:
206
  max_size = max(size_list)
207
  tensor_list = [
208
- torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
 
209
  ]
210
  dist.gather(tensor, tensor_list, dst=dst, group=group)
211
 
@@ -228,7 +233,7 @@ def shared_random_seed():
228
 
229
  All workers must call this function, otherwise it will deadlock.
230
  """
231
- ints = np.random.randint(2 ** 31)
232
  all_ints = all_gather(ints)
233
  return all_ints[0]
234
 
 
98
  device = torch.device("cpu" if backend == "gloo" else "cuda")
99
 
100
  buffer = pickle.dumps(data)
101
+ if len(buffer) > 1024**3:
102
  logger = logging.getLogger(__name__)
103
  logger.warning(
104
  "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
105
+ get_rank(), len(buffer) / (1024**3), device
106
  )
107
  )
108
  storage = torch.ByteStorage.from_buffer(buffer)
 
122
  ), "comm.gather/all_gather must be called from ranks within the given group!"
123
  local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
124
  size_list = [
125
+ torch.zeros([1], dtype=torch.int64, device=tensor.device)
126
+ for _ in range(world_size)
127
  ]
128
  dist.all_gather(size_list, local_size, group=group)
129
 
 
134
  # we pad the tensor because torch all_gather does not support
135
  # gathering tensors of different shapes
136
  if local_size != max_size:
137
+ padding = torch.zeros(
138
+ (max_size - local_size,), dtype=torch.uint8, device=tensor.device
139
+ )
140
  tensor = torch.cat((tensor, padding), dim=0)
141
  return size_list, tensor
142
 
 
167
 
168
  # receiving Tensor from all ranks
169
  tensor_list = [
170
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
171
+ for _ in size_list
172
  ]
173
  dist.all_gather(tensor_list, tensor, group=group)
174
 
 
209
  if rank == dst:
210
  max_size = max(size_list)
211
  tensor_list = [
212
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
213
+ for _ in size_list
214
  ]
215
  dist.gather(tensor, tensor_list, dst=dst, group=group)
216
 
 
233
 
234
  All workers must call this function, otherwise it will deadlock.
235
  """
236
+ ints = np.random.randint(2**31)
237
  all_ints = all_gather(ints)
238
  return all_ints[0]
239
 
third_party/ASpanFormer/src/utils/dataloader.py CHANGED
@@ -3,21 +3,22 @@ import numpy as np
3
 
4
  # --- PL-DATAMODULE ---
5
 
 
6
  def get_local_split(items: list, world_size: int, rank: int, seed: int):
7
- """ The local rank only loads a split of the dataset. """
8
  n_items = len(items)
9
  items_permute = np.random.RandomState(seed).permutation(items)
10
  if n_items % world_size == 0:
11
  padded_items = items_permute
12
  else:
13
  padding = np.random.RandomState(seed).choice(
14
- items,
15
- world_size - (n_items % world_size),
16
- replace=True)
17
  padded_items = np.concatenate([items_permute, padding])
18
- assert len(padded_items) % world_size == 0, \
19
- f'len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}'
 
20
  n_per_rank = len(padded_items) // world_size
21
- local_items = padded_items[n_per_rank * rank: n_per_rank * (rank+1)]
22
 
23
  return local_items
 
3
 
4
  # --- PL-DATAMODULE ---
5
 
6
+
7
  def get_local_split(items: list, world_size: int, rank: int, seed: int):
8
+ """The local rank only loads a split of the dataset."""
9
  n_items = len(items)
10
  items_permute = np.random.RandomState(seed).permutation(items)
11
  if n_items % world_size == 0:
12
  padded_items = items_permute
13
  else:
14
  padding = np.random.RandomState(seed).choice(
15
+ items, world_size - (n_items % world_size), replace=True
16
+ )
 
17
  padded_items = np.concatenate([items_permute, padding])
18
+ assert (
19
+ len(padded_items) % world_size == 0
20
+ ), f"len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}"
21
  n_per_rank = len(padded_items) // world_size
22
+ local_items = padded_items[n_per_rank * rank : n_per_rank * (rank + 1)]
23
 
24
  return local_items
third_party/ASpanFormer/src/utils/dataset.py CHANGED
@@ -15,8 +15,11 @@ except Exception:
15
 
16
  # --- DATA IO ---
17
 
 
18
  def load_array_from_s3(
19
- path, client, cv_type,
 
 
20
  use_h5py=False,
21
  ):
22
  byte_str = client.Get(path)
@@ -26,7 +29,7 @@ def load_array_from_s3(
26
  data = cv2.imdecode(raw_array, cv_type)
27
  else:
28
  f = io.BytesIO(byte_str)
29
- data = np.array(h5py.File(f, 'r')['/depth'])
30
  except Exception as ex:
31
  print(f"==> Data loading failure: {path}")
32
  raise ex
@@ -36,9 +39,8 @@ def load_array_from_s3(
36
 
37
 
38
  def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
39
- cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \
40
- else cv2.IMREAD_COLOR
41
- if str(path).startswith('s3://'):
42
  image = load_array_from_s3(str(path), client, cv_type)
43
  else:
44
  image = cv2.imread(str(path), cv_type)
@@ -54,7 +56,7 @@ def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
54
  def get_resized_wh(w, h, resize=None):
55
  if resize is not None: # resize the longer edge
56
  scale = resize / max(h, w)
57
- w_new, h_new = int(round(w*scale)), int(round(h*scale))
58
  else:
59
  w_new, h_new = w, h
60
  return w_new, h_new
@@ -69,20 +71,22 @@ def get_divisible_wh(w, h, df=None):
69
 
70
 
71
  def pad_bottom_right(inp, pad_size, ret_mask=False):
72
- assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
 
 
73
  mask = None
74
  if inp.ndim == 2:
75
  padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
76
- padded[:inp.shape[0], :inp.shape[1]] = inp
77
  if ret_mask:
78
  mask = np.zeros((pad_size, pad_size), dtype=bool)
79
- mask[:inp.shape[0], :inp.shape[1]] = True
80
  elif inp.ndim == 3:
81
  padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
82
- padded[:, :inp.shape[1], :inp.shape[2]] = inp
83
  if ret_mask:
84
  mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
85
- mask[:, :inp.shape[1], :inp.shape[2]] = True
86
  else:
87
  raise NotImplementedError()
88
  return padded, mask
@@ -90,6 +94,7 @@ def pad_bottom_right(inp, pad_size, ret_mask=False):
90
 
91
  # --- MEGADEPTH ---
92
 
 
93
  def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None):
94
  """
95
  Args:
@@ -99,7 +104,7 @@ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=No
99
  Returns:
100
  image (torch.tensor): (1, h, w)
101
  mask (torch.tensor): (h, w)
102
- scale (torch.tensor): [w/w_new, h/h_new]
103
  """
104
  # read image
105
  image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT)
@@ -110,7 +115,7 @@ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=No
110
  w_new, h_new = get_divisible_wh(w_new, h_new, df)
111
 
112
  image = cv2.resize(image, (w_new, h_new))
113
- scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)
114
 
115
  if padding: # padding
116
  pad_to = max(h_new, w_new)
@@ -118,7 +123,9 @@ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=No
118
  else:
119
  mask = None
120
 
121
- image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized
 
 
122
  if mask is not None:
123
  mask = torch.from_numpy(mask)
124
 
@@ -126,10 +133,10 @@ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=No
126
 
127
 
128
  def read_megadepth_depth(path, pad_to=None):
129
- if str(path).startswith('s3://'):
130
  depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True)
131
  else:
132
- depth = np.array(h5py.File(path, 'r')['depth'])
133
  if pad_to is not None:
134
  depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)
135
  depth = torch.from_numpy(depth).float() # (h, w)
@@ -138,6 +145,7 @@ def read_megadepth_depth(path, pad_to=None):
138
 
139
  # --- ScanNet ---
140
 
 
141
  def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
142
  """
143
  Args:
@@ -146,7 +154,7 @@ def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
146
  Returns:
147
  image (torch.tensor): (1, h, w)
148
  mask (torch.tensor): (h, w)
149
- scale (torch.tensor): [w/w_new, h/h_new]
150
  """
151
  # read and resize image
152
  image = imread_gray(path, augment_fn)
@@ -158,7 +166,7 @@ def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
158
 
159
 
160
  def read_scannet_depth(path):
161
- if str(path).startswith('s3://'):
162
  depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED)
163
  else:
164
  depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
@@ -168,55 +176,57 @@ def read_scannet_depth(path):
168
 
169
 
170
  def read_scannet_pose(path):
171
- """ Read ScanNet's Camera2World pose and transform it to World2Camera.
172
-
173
  Returns:
174
  pose_w2c (np.ndarray): (4, 4)
175
  """
176
- cam2world = np.loadtxt(path, delimiter=' ')
177
  world2cam = inv(cam2world)
178
  return world2cam
179
 
180
 
181
  def read_scannet_intrinsic(path):
182
- """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
183
- """
184
- intrinsic = np.loadtxt(path, delimiter=' ')
185
  return intrinsic[:-1, :-1]
186
 
187
 
188
- def read_gl3d_gray(path,resize):
189
- img=cv2.resize(cv2.imread(path,cv2.IMREAD_GRAYSCALE),(int(resize),int(resize)))
190
- img = torch.from_numpy(img).float()[None] / 255 # (h, w) -> (1, h, w) and normalized
 
 
191
  return img
192
 
 
193
  def read_gl3d_depth(file_path):
194
- with open(file_path, 'rb') as fin:
195
  color = None
196
  width = None
197
  height = None
198
  scale = None
199
  data_type = None
200
- header = str(fin.readline().decode('UTF-8')).rstrip()
201
- if header == 'PF':
202
  color = True
203
- elif header == 'Pf':
204
  color = False
205
  else:
206
- raise Exception('Not a PFM file.')
207
- dim_match = re.match(r'^(\d+)\s(\d+)\s$', fin.readline().decode('UTF-8'))
208
  if dim_match:
209
  width, height = map(int, dim_match.groups())
210
  else:
211
- raise Exception('Malformed PFM header.')
212
- scale = float((fin.readline().decode('UTF-8')).rstrip())
213
  if scale < 0: # little-endian
214
- data_type = '<f'
215
  else:
216
- data_type = '>f' # big-endian
217
  data_string = fin.read()
218
  data = np.fromstring(data_string, data_type)
219
  shape = (height, width, 3) if color else (height, width)
220
  data = np.reshape(data, shape)
221
  data = np.flip(data, 0)
222
- return torch.from_numpy(data.copy()).float()
 
15
 
16
  # --- DATA IO ---
17
 
18
+
19
  def load_array_from_s3(
20
+ path,
21
+ client,
22
+ cv_type,
23
  use_h5py=False,
24
  ):
25
  byte_str = client.Get(path)
 
29
  data = cv2.imdecode(raw_array, cv_type)
30
  else:
31
  f = io.BytesIO(byte_str)
32
+ data = np.array(h5py.File(f, "r")["/depth"])
33
  except Exception as ex:
34
  print(f"==> Data loading failure: {path}")
35
  raise ex
 
39
 
40
 
41
  def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
42
+ cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None else cv2.IMREAD_COLOR
43
+ if str(path).startswith("s3://"):
 
44
  image = load_array_from_s3(str(path), client, cv_type)
45
  else:
46
  image = cv2.imread(str(path), cv_type)
 
56
  def get_resized_wh(w, h, resize=None):
57
  if resize is not None: # resize the longer edge
58
  scale = resize / max(h, w)
59
+ w_new, h_new = int(round(w * scale)), int(round(h * scale))
60
  else:
61
  w_new, h_new = w, h
62
  return w_new, h_new
 
71
 
72
 
73
  def pad_bottom_right(inp, pad_size, ret_mask=False):
74
+ assert isinstance(pad_size, int) and pad_size >= max(
75
+ inp.shape[-2:]
76
+ ), f"{pad_size} < {max(inp.shape[-2:])}"
77
  mask = None
78
  if inp.ndim == 2:
79
  padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
80
+ padded[: inp.shape[0], : inp.shape[1]] = inp
81
  if ret_mask:
82
  mask = np.zeros((pad_size, pad_size), dtype=bool)
83
+ mask[: inp.shape[0], : inp.shape[1]] = True
84
  elif inp.ndim == 3:
85
  padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
86
+ padded[:, : inp.shape[1], : inp.shape[2]] = inp
87
  if ret_mask:
88
  mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
89
+ mask[:, : inp.shape[1], : inp.shape[2]] = True
90
  else:
91
  raise NotImplementedError()
92
  return padded, mask
 
94
 
95
  # --- MEGADEPTH ---
96
 
97
+
98
  def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None):
99
  """
100
  Args:
 
104
  Returns:
105
  image (torch.tensor): (1, h, w)
106
  mask (torch.tensor): (h, w)
107
+ scale (torch.tensor): [w/w_new, h/h_new]
108
  """
109
  # read image
110
  image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT)
 
115
  w_new, h_new = get_divisible_wh(w_new, h_new, df)
116
 
117
  image = cv2.resize(image, (w_new, h_new))
118
+ scale = torch.tensor([w / w_new, h / h_new], dtype=torch.float)
119
 
120
  if padding: # padding
121
  pad_to = max(h_new, w_new)
 
123
  else:
124
  mask = None
125
 
126
+ image = (
127
+ torch.from_numpy(image).float()[None] / 255
128
+ ) # (h, w) -> (1, h, w) and normalized
129
  if mask is not None:
130
  mask = torch.from_numpy(mask)
131
 
 
133
 
134
 
135
  def read_megadepth_depth(path, pad_to=None):
136
+ if str(path).startswith("s3://"):
137
  depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True)
138
  else:
139
+ depth = np.array(h5py.File(path, "r")["depth"])
140
  if pad_to is not None:
141
  depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)
142
  depth = torch.from_numpy(depth).float() # (h, w)
 
145
 
146
  # --- ScanNet ---
147
 
148
+
149
  def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
150
  """
151
  Args:
 
154
  Returns:
155
  image (torch.tensor): (1, h, w)
156
  mask (torch.tensor): (h, w)
157
+ scale (torch.tensor): [w/w_new, h/h_new]
158
  """
159
  # read and resize image
160
  image = imread_gray(path, augment_fn)
 
166
 
167
 
168
  def read_scannet_depth(path):
169
+ if str(path).startswith("s3://"):
170
  depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED)
171
  else:
172
  depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
 
176
 
177
 
178
  def read_scannet_pose(path):
179
+ """Read ScanNet's Camera2World pose and transform it to World2Camera.
180
+
181
  Returns:
182
  pose_w2c (np.ndarray): (4, 4)
183
  """
184
+ cam2world = np.loadtxt(path, delimiter=" ")
185
  world2cam = inv(cam2world)
186
  return world2cam
187
 
188
 
189
  def read_scannet_intrinsic(path):
190
+ """Read ScanNet's intrinsic matrix and return the 3x3 matrix."""
191
+ intrinsic = np.loadtxt(path, delimiter=" ")
 
192
  return intrinsic[:-1, :-1]
193
 
194
 
195
+ def read_gl3d_gray(path, resize):
196
+ img = cv2.resize(cv2.imread(path, cv2.IMREAD_GRAYSCALE), (int(resize), int(resize)))
197
+ img = (
198
+ torch.from_numpy(img).float()[None] / 255
199
+ ) # (h, w) -> (1, h, w) and normalized
200
  return img
201
 
202
+
203
  def read_gl3d_depth(file_path):
204
+ with open(file_path, "rb") as fin:
205
  color = None
206
  width = None
207
  height = None
208
  scale = None
209
  data_type = None
210
+ header = str(fin.readline().decode("UTF-8")).rstrip()
211
+ if header == "PF":
212
  color = True
213
+ elif header == "Pf":
214
  color = False
215
  else:
216
+ raise Exception("Not a PFM file.")
217
+ dim_match = re.match(r"^(\d+)\s(\d+)\s$", fin.readline().decode("UTF-8"))
218
  if dim_match:
219
  width, height = map(int, dim_match.groups())
220
  else:
221
+ raise Exception("Malformed PFM header.")
222
+ scale = float((fin.readline().decode("UTF-8")).rstrip())
223
  if scale < 0: # little-endian
224
+ data_type = "<f"
225
  else:
226
+ data_type = ">f" # big-endian
227
  data_string = fin.read()
228
  data = np.fromstring(data_string, data_type)
229
  shape = (height, width, 3) if color else (height, width)
230
  data = np.reshape(data, shape)
231
  data = np.flip(data, 0)
232
+ return torch.from_numpy(data.copy()).float()
third_party/ASpanFormer/src/utils/metrics.py CHANGED
@@ -9,6 +9,7 @@ from kornia.geometry.conversions import convert_points_to_homogeneous
9
 
10
  # --- METRICS ---
11
 
 
12
  def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
13
  # angle error between 2 vectors
14
  t_gt = T_0to1[:3, 3]
@@ -21,7 +22,7 @@ def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
21
  # angle error between 2 rotation matrices
22
  R_gt = T_0to1[:3, :3]
23
  cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2
24
- cos = np.clip(cos, -1., 1.) # handle numercial errors
25
  R_err = np.rad2deg(np.abs(np.arccos(cos)))
26
 
27
  return t_err, R_err
@@ -43,93 +44,108 @@ def symmetric_epipolar_distance(pts0, pts1, E, K0, K1):
43
  p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,]
44
  Etp1 = pts1 @ E # [N, 3]
45
 
46
- d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) # N
 
 
 
47
  return d
48
 
49
 
50
  def compute_symmetrical_epipolar_errors(data):
51
- """
52
  Update:
53
  data (dict):{"epi_errs": [M]}
54
  """
55
- Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3])
56
- E_mat = Tx @ data['T_0to1'][:, :3, :3]
57
 
58
- m_bids = data['m_bids']
59
- pts0 = data['mkpts0_f']
60
- pts1 = data['mkpts1_f']
61
 
62
  epi_errs = []
63
  for bs in range(Tx.size(0)):
64
  mask = m_bids == bs
65
  epi_errs.append(
66
- symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs]))
 
 
 
67
  epi_errs = torch.cat(epi_errs, dim=0)
68
 
69
- data.update({'epi_errs': epi_errs})
 
70
 
71
  def compute_symmetrical_epipolar_errors_offset(data):
72
- """
73
  Update:
74
  data (dict):{"epi_errs": [M]}
75
  """
76
- Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3])
77
- E_mat = Tx @ data['T_0to1'][:, :3, :3]
78
 
79
- m_bids = data['offset_bids']
80
- l_ids=data['offset_lids']
81
- pts0 = data['offset_kpts0_f']
82
- pts1 = data['offset_kpts1_f']
83
 
84
  epi_errs = []
85
- layer_num=data['predict_flow'][0].shape[0]
86
-
87
  for bs in range(Tx.size(0)):
88
  for ls in range(layer_num):
89
  mask_b = m_bids == bs
90
  mask_l = l_ids == ls
91
- mask=mask_b&mask_l
92
  epi_errs.append(
93
- symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs]))
 
 
 
94
  epi_errs = torch.cat(epi_errs, dim=0)
95
 
96
- data.update({'epi_errs_offset': epi_errs}) #[b*l*n]
 
97
 
98
  def compute_symmetrical_epipolar_errors_offset_bidirectional(data):
99
- """
100
  Update
101
  data (dict):{"epi_errs": [M]}
102
  """
103
- _compute_symmetrical_epipolar_errors_offset(data,'left')
104
- _compute_symmetrical_epipolar_errors_offset(data,'right')
105
 
106
 
107
- def _compute_symmetrical_epipolar_errors_offset(data,side):
108
- """
109
  Update
110
  data (dict):{"epi_errs": [M]}
111
  """
112
- assert side=='left' or side=='right', 'invalid side'
113
 
114
- Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3])
115
- E_mat = Tx @ data['T_0to1'][:, :3, :3]
116
 
117
- m_bids = data['offset_bids_'+side]
118
- l_ids=data['offset_lids_'+side]
119
- pts0 = data['offset_kpts0_f_'+side]
120
- pts1 = data['offset_kpts1_f_'+side]
121
 
122
  epi_errs = []
123
- layer_num=data['predict_flow'][0].shape[0]
124
  for bs in range(Tx.size(0)):
125
  for ls in range(layer_num):
126
  mask_b = m_bids == bs
127
  mask_l = l_ids == ls
128
- mask=mask_b&mask_l
129
  epi_errs.append(
130
- symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs]))
 
 
 
131
  epi_errs = torch.cat(epi_errs, dim=0)
132
- data.update({'epi_errs_offset_'+side: epi_errs}) #[b*l*n]
 
133
 
134
  def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
135
  if len(kpts0) < 5:
@@ -143,7 +159,8 @@ def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
143
 
144
  # compute pose with cv2
145
  E, mask = cv2.findEssentialMat(
146
- kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC)
 
147
  if E is None:
148
  print("\nE is None while trying to recover pose.\n")
149
  return None
@@ -161,7 +178,7 @@ def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
161
 
162
 
163
  def compute_pose_errors(data, config):
164
- """
165
  Update:
166
  data (dict):{
167
  "R_errs" List[float]: [N]
@@ -171,33 +188,36 @@ def compute_pose_errors(data, config):
171
  """
172
  pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5
173
  conf = config.TRAINER.RANSAC_CONF # 0.99999
174
- data.update({'R_errs': [], 't_errs': [], 'inliers': []})
175
 
176
- m_bids = data['m_bids'].cpu().numpy()
177
- pts0 = data['mkpts0_f'].cpu().numpy()
178
- pts1 = data['mkpts1_f'].cpu().numpy()
179
- K0 = data['K0'].cpu().numpy()
180
- K1 = data['K1'].cpu().numpy()
181
- T_0to1 = data['T_0to1'].cpu().numpy()
182
 
183
  for bs in range(K0.shape[0]):
184
  mask = m_bids == bs
185
- ret = estimate_pose(pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf)
 
 
186
 
187
  if ret is None:
188
- data['R_errs'].append(np.inf)
189
- data['t_errs'].append(np.inf)
190
- data['inliers'].append(np.array([]).astype(np.bool))
191
  else:
192
  R, t, inliers = ret
193
  t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0)
194
- data['R_errs'].append(R_err)
195
- data['t_errs'].append(t_err)
196
- data['inliers'].append(inliers)
197
 
198
 
199
  # --- METRIC AGGREGATION ---
200
 
 
201
  def error_auc(errors, thresholds):
202
  """
203
  Args:
@@ -211,14 +231,14 @@ def error_auc(errors, thresholds):
211
  thresholds = [5, 10, 20]
212
  for thr in thresholds:
213
  last_index = np.searchsorted(errors, thr)
214
- y = recall[:last_index] + [recall[last_index-1]]
215
  x = errors[:last_index] + [thr]
216
  aucs.append(np.trapz(y, x) / thr)
217
 
218
- return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)}
219
 
220
 
221
- def epidist_prec(errors, thresholds, ret_dict=False,offset=False):
222
  precs = []
223
  for thr in thresholds:
224
  prec_ = []
@@ -227,34 +247,47 @@ def epidist_prec(errors, thresholds, ret_dict=False,offset=False):
227
  prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0)
228
  precs.append(np.mean(prec_) if len(prec_) > 0 else 0)
229
  if ret_dict:
230
- return {f'prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)} if not offset else {f'prec_flow@{t:.0e}': prec for t, prec in zip(thresholds, precs)}
 
 
 
 
231
  else:
232
  return precs
233
 
234
 
235
  def aggregate_metrics(metrics, epi_err_thr=5e-4):
236
- """ Aggregate metrics for the whole dataset:
237
  (This method should be called once per dataset)
238
  1. AUC of the pose error (angular) at the threshold [5, 10, 20]
239
  2. Mean matching precision at the threshold 5e-4(ScanNet), 1e-4(MegaDepth)
240
  """
241
  # filter duplicates
242
- unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics['identifiers']))
243
  unq_ids = list(unq_ids.values())
244
- logger.info(f'Aggregating metrics over {len(unq_ids)} unique items...')
245
 
246
  # pose auc
247
  angular_thresholds = [5, 10, 20]
248
- pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0)[unq_ids]
 
 
249
  aucs = error_auc(pose_errors, angular_thresholds) # (auc@5, auc@10, auc@20)
250
 
251
  # matching precision
252
  dist_thresholds = [epi_err_thr]
253
- precs = epidist_prec(np.array(metrics['epi_errs'], dtype=object)[unq_ids], dist_thresholds, True) # (prec@err_thr)
254
-
255
- #offset precision
 
 
256
  try:
257
- precs_offset = epidist_prec(np.array(metrics['epi_errs_offset'], dtype=object)[unq_ids], [2e-3], True,offset=True)
258
- return {**aucs, **precs,**precs_offset}
 
 
 
 
 
259
  except:
260
  return {**aucs, **precs}
 
9
 
10
  # --- METRICS ---
11
 
12
+
13
  def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
14
  # angle error between 2 vectors
15
  t_gt = T_0to1[:3, 3]
 
22
  # angle error between 2 rotation matrices
23
  R_gt = T_0to1[:3, :3]
24
  cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2
25
+ cos = np.clip(cos, -1.0, 1.0) # handle numercial errors
26
  R_err = np.rad2deg(np.abs(np.arccos(cos)))
27
 
28
  return t_err, R_err
 
44
  p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,]
45
  Etp1 = pts1 @ E # [N, 3]
46
 
47
+ d = p1Ep0**2 * (
48
+ 1.0 / (Ep0[:, 0] ** 2 + Ep0[:, 1] ** 2)
49
+ + 1.0 / (Etp1[:, 0] ** 2 + Etp1[:, 1] ** 2)
50
+ ) # N
51
  return d
52
 
53
 
54
  def compute_symmetrical_epipolar_errors(data):
55
+ """
56
  Update:
57
  data (dict):{"epi_errs": [M]}
58
  """
59
+ Tx = numeric.cross_product_matrix(data["T_0to1"][:, :3, 3])
60
+ E_mat = Tx @ data["T_0to1"][:, :3, :3]
61
 
62
+ m_bids = data["m_bids"]
63
+ pts0 = data["mkpts0_f"]
64
+ pts1 = data["mkpts1_f"]
65
 
66
  epi_errs = []
67
  for bs in range(Tx.size(0)):
68
  mask = m_bids == bs
69
  epi_errs.append(
70
+ symmetric_epipolar_distance(
71
+ pts0[mask], pts1[mask], E_mat[bs], data["K0"][bs], data["K1"][bs]
72
+ )
73
+ )
74
  epi_errs = torch.cat(epi_errs, dim=0)
75
 
76
+ data.update({"epi_errs": epi_errs})
77
+
78
 
79
  def compute_symmetrical_epipolar_errors_offset(data):
80
+ """
81
  Update:
82
  data (dict):{"epi_errs": [M]}
83
  """
84
+ Tx = numeric.cross_product_matrix(data["T_0to1"][:, :3, 3])
85
+ E_mat = Tx @ data["T_0to1"][:, :3, :3]
86
 
87
+ m_bids = data["offset_bids"]
88
+ l_ids = data["offset_lids"]
89
+ pts0 = data["offset_kpts0_f"]
90
+ pts1 = data["offset_kpts1_f"]
91
 
92
  epi_errs = []
93
+ layer_num = data["predict_flow"][0].shape[0]
94
+
95
  for bs in range(Tx.size(0)):
96
  for ls in range(layer_num):
97
  mask_b = m_bids == bs
98
  mask_l = l_ids == ls
99
+ mask = mask_b & mask_l
100
  epi_errs.append(
101
+ symmetric_epipolar_distance(
102
+ pts0[mask], pts1[mask], E_mat[bs], data["K0"][bs], data["K1"][bs]
103
+ )
104
+ )
105
  epi_errs = torch.cat(epi_errs, dim=0)
106
 
107
+ data.update({"epi_errs_offset": epi_errs}) # [b*l*n]
108
+
109
 
110
  def compute_symmetrical_epipolar_errors_offset_bidirectional(data):
111
+ """
112
  Update
113
  data (dict):{"epi_errs": [M]}
114
  """
115
+ _compute_symmetrical_epipolar_errors_offset(data, "left")
116
+ _compute_symmetrical_epipolar_errors_offset(data, "right")
117
 
118
 
119
+ def _compute_symmetrical_epipolar_errors_offset(data, side):
120
+ """
121
  Update
122
  data (dict):{"epi_errs": [M]}
123
  """
124
+ assert side == "left" or side == "right", "invalid side"
125
 
126
+ Tx = numeric.cross_product_matrix(data["T_0to1"][:, :3, 3])
127
+ E_mat = Tx @ data["T_0to1"][:, :3, :3]
128
 
129
+ m_bids = data["offset_bids_" + side]
130
+ l_ids = data["offset_lids_" + side]
131
+ pts0 = data["offset_kpts0_f_" + side]
132
+ pts1 = data["offset_kpts1_f_" + side]
133
 
134
  epi_errs = []
135
+ layer_num = data["predict_flow"][0].shape[0]
136
  for bs in range(Tx.size(0)):
137
  for ls in range(layer_num):
138
  mask_b = m_bids == bs
139
  mask_l = l_ids == ls
140
+ mask = mask_b & mask_l
141
  epi_errs.append(
142
+ symmetric_epipolar_distance(
143
+ pts0[mask], pts1[mask], E_mat[bs], data["K0"][bs], data["K1"][bs]
144
+ )
145
+ )
146
  epi_errs = torch.cat(epi_errs, dim=0)
147
+ data.update({"epi_errs_offset_" + side: epi_errs}) # [b*l*n]
148
+
149
 
150
  def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
151
  if len(kpts0) < 5:
 
159
 
160
  # compute pose with cv2
161
  E, mask = cv2.findEssentialMat(
162
+ kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC
163
+ )
164
  if E is None:
165
  print("\nE is None while trying to recover pose.\n")
166
  return None
 
178
 
179
 
180
  def compute_pose_errors(data, config):
181
+ """
182
  Update:
183
  data (dict):{
184
  "R_errs" List[float]: [N]
 
188
  """
189
  pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5
190
  conf = config.TRAINER.RANSAC_CONF # 0.99999
191
+ data.update({"R_errs": [], "t_errs": [], "inliers": []})
192
 
193
+ m_bids = data["m_bids"].cpu().numpy()
194
+ pts0 = data["mkpts0_f"].cpu().numpy()
195
+ pts1 = data["mkpts1_f"].cpu().numpy()
196
+ K0 = data["K0"].cpu().numpy()
197
+ K1 = data["K1"].cpu().numpy()
198
+ T_0to1 = data["T_0to1"].cpu().numpy()
199
 
200
  for bs in range(K0.shape[0]):
201
  mask = m_bids == bs
202
+ ret = estimate_pose(
203
+ pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf
204
+ )
205
 
206
  if ret is None:
207
+ data["R_errs"].append(np.inf)
208
+ data["t_errs"].append(np.inf)
209
+ data["inliers"].append(np.array([]).astype(np.bool))
210
  else:
211
  R, t, inliers = ret
212
  t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0)
213
+ data["R_errs"].append(R_err)
214
+ data["t_errs"].append(t_err)
215
+ data["inliers"].append(inliers)
216
 
217
 
218
  # --- METRIC AGGREGATION ---
219
 
220
+
221
  def error_auc(errors, thresholds):
222
  """
223
  Args:
 
231
  thresholds = [5, 10, 20]
232
  for thr in thresholds:
233
  last_index = np.searchsorted(errors, thr)
234
+ y = recall[:last_index] + [recall[last_index - 1]]
235
  x = errors[:last_index] + [thr]
236
  aucs.append(np.trapz(y, x) / thr)
237
 
238
+ return {f"auc@{t}": auc for t, auc in zip(thresholds, aucs)}
239
 
240
 
241
+ def epidist_prec(errors, thresholds, ret_dict=False, offset=False):
242
  precs = []
243
  for thr in thresholds:
244
  prec_ = []
 
247
  prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0)
248
  precs.append(np.mean(prec_) if len(prec_) > 0 else 0)
249
  if ret_dict:
250
+ return (
251
+ {f"prec@{t:.0e}": prec for t, prec in zip(thresholds, precs)}
252
+ if not offset
253
+ else {f"prec_flow@{t:.0e}": prec for t, prec in zip(thresholds, precs)}
254
+ )
255
  else:
256
  return precs
257
 
258
 
259
  def aggregate_metrics(metrics, epi_err_thr=5e-4):
260
+ """Aggregate metrics for the whole dataset:
261
  (This method should be called once per dataset)
262
  1. AUC of the pose error (angular) at the threshold [5, 10, 20]
263
  2. Mean matching precision at the threshold 5e-4(ScanNet), 1e-4(MegaDepth)
264
  """
265
  # filter duplicates
266
+ unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics["identifiers"]))
267
  unq_ids = list(unq_ids.values())
268
+ logger.info(f"Aggregating metrics over {len(unq_ids)} unique items...")
269
 
270
  # pose auc
271
  angular_thresholds = [5, 10, 20]
272
+ pose_errors = np.max(np.stack([metrics["R_errs"], metrics["t_errs"]]), axis=0)[
273
+ unq_ids
274
+ ]
275
  aucs = error_auc(pose_errors, angular_thresholds) # (auc@5, auc@10, auc@20)
276
 
277
  # matching precision
278
  dist_thresholds = [epi_err_thr]
279
+ precs = epidist_prec(
280
+ np.array(metrics["epi_errs"], dtype=object)[unq_ids], dist_thresholds, True
281
+ ) # (prec@err_thr)
282
+
283
+ # offset precision
284
  try:
285
+ precs_offset = epidist_prec(
286
+ np.array(metrics["epi_errs_offset"], dtype=object)[unq_ids],
287
+ [2e-3],
288
+ True,
289
+ offset=True,
290
+ )
291
+ return {**aucs, **precs, **precs_offset}
292
  except:
293
  return {**aucs, **precs}
third_party/ASpanFormer/src/utils/misc.py CHANGED
@@ -11,6 +11,7 @@ from pytorch_lightning.utilities import rank_zero_only
11
  import cv2
12
  import numpy as np
13
 
 
14
  def lower_config(yacs_cfg):
15
  if not isinstance(yacs_cfg, CN):
16
  return yacs_cfg
@@ -25,7 +26,7 @@ def upper_config(dict_cfg):
25
 
26
  def log_on(condition, message, level):
27
  if condition:
28
- assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL']
29
  logger.log(level, message)
30
 
31
 
@@ -35,32 +36,35 @@ def get_rank_zero_only_logger(logger: _Logger):
35
  else:
36
  for _level in logger._core.levels.keys():
37
  level = _level.lower()
38
- setattr(logger, level,
39
- lambda x: None)
40
  logger._log = lambda x: None
41
  return logger
42
 
43
 
44
  def setup_gpus(gpus: Union[str, int]) -> int:
45
- """ A temporary fix for pytorch-lighting 1.3.x """
46
  gpus = str(gpus)
47
  gpu_ids = []
48
-
49
- if ',' not in gpus:
50
  n_gpus = int(gpus)
51
  return n_gpus if n_gpus != -1 else torch.cuda.device_count()
52
  else:
53
- gpu_ids = [i.strip() for i in gpus.split(',') if i != '']
54
-
55
  # setup environment variables
56
- visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
57
  if visible_devices is None:
58
  os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
59
- os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids)
60
- visible_devices = os.getenv('CUDA_VISIBLE_DEVICES')
61
- logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}')
 
 
62
  else:
63
- logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.')
 
 
64
  return len(gpu_ids)
65
 
66
 
@@ -71,11 +75,11 @@ def flattenList(x):
71
  @contextlib.contextmanager
72
  def tqdm_joblib(tqdm_object):
73
  """Context manager to patch joblib to report into tqdm progress bar given as argument
74
-
75
  Usage:
76
  with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
77
  Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))
78
-
79
  When iterating over a generator, directly use of tqdm is also a solutin (but monitor the task queuing, instead of finishing)
80
  ret_vals = Parallel(n_jobs=args.world_size)(
81
  delayed(lambda x: _compute_cov_score(pid, *x))(param)
@@ -84,6 +88,7 @@ def tqdm_joblib(tqdm_object):
84
  total=len(image_ids)*(len(image_ids)-1)/2))
85
  Src: https://stackoverflow.com/a/58936697
86
  """
 
87
  class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
88
  def __init__(self, *args, **kwargs):
89
  super().__init__(*args, **kwargs)
@@ -101,39 +106,79 @@ def tqdm_joblib(tqdm_object):
101
  tqdm_object.close()
102
 
103
 
104
- def draw_points(img,points,color=(0,255,0),radius=3):
105
  dp = [(int(points[i, 0]), int(points[i, 1])) for i in range(points.shape[0])]
106
  for i in range(points.shape[0]):
107
- cv2.circle(img, dp[i],radius=radius,color=color)
108
  return img
109
-
110
 
111
- def draw_match(img1, img2, corr1, corr2,inlier=[True],color=None,radius1=1,radius2=1,resize=None):
 
 
 
 
 
 
 
 
 
 
 
112
  if resize is not None:
113
- scale1,scale2=[img1.shape[1]/resize[0],img1.shape[0]/resize[1]],[img2.shape[1]/resize[0],img2.shape[0]/resize[1]]
114
- img1,img2=cv2.resize(img1, resize, interpolation=cv2.INTER_AREA),cv2.resize(img2, resize, interpolation=cv2.INTER_AREA)
115
- corr1,corr2=corr1/np.asarray(scale1)[np.newaxis],corr2/np.asarray(scale2)[np.newaxis]
116
- corr1_key = [cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0])]
117
- corr2_key = [cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0])]
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  assert len(corr1) == len(corr2)
120
 
121
  draw_matches = [cv2.DMatch(i, i, 0) for i in range(len(corr1))]
122
  if color is None:
123
- color = [(0, 255, 0) if cur_inlier else (0,0,255) for cur_inlier in inlier]
124
- if len(color)==1:
125
- display = cv2.drawMatches(img1, corr1_key, img2, corr2_key, draw_matches, None,
126
- matchColor=color[0],
127
- singlePointColor=color[0],
128
- flags=4
129
- )
 
 
 
 
 
 
130
  else:
131
- height,width=max(img1.shape[0],img2.shape[0]),img1.shape[1]+img2.shape[1]
132
- display=np.zeros([height,width,3],np.uint8)
133
- display[:img1.shape[0],:img1.shape[1]]=img1
134
- display[:img2.shape[0],img1.shape[1]:]=img2
135
  for i in range(len(corr1)):
136
- left_x,left_y,right_x,right_y=int(corr1[i][0]),int(corr1[i][1]),int(corr2[i][0]+img1.shape[1]),int(corr2[i][1])
137
- cur_color=(int(color[i][0]),int(color[i][1]),int(color[i][2]))
138
- cv2.line(display, (left_x,left_y), (right_x,right_y),cur_color,1,lineType=cv2.LINE_AA)
 
 
 
 
 
 
 
 
 
 
 
 
139
  return display
 
11
  import cv2
12
  import numpy as np
13
 
14
+
15
  def lower_config(yacs_cfg):
16
  if not isinstance(yacs_cfg, CN):
17
  return yacs_cfg
 
26
 
27
  def log_on(condition, message, level):
28
  if condition:
29
+ assert level in ["INFO", "DEBUG", "WARNING", "ERROR", "CRITICAL"]
30
  logger.log(level, message)
31
 
32
 
 
36
  else:
37
  for _level in logger._core.levels.keys():
38
  level = _level.lower()
39
+ setattr(logger, level, lambda x: None)
 
40
  logger._log = lambda x: None
41
  return logger
42
 
43
 
44
  def setup_gpus(gpus: Union[str, int]) -> int:
45
+ """A temporary fix for pytorch-lighting 1.3.x"""
46
  gpus = str(gpus)
47
  gpu_ids = []
48
+
49
+ if "," not in gpus:
50
  n_gpus = int(gpus)
51
  return n_gpus if n_gpus != -1 else torch.cuda.device_count()
52
  else:
53
+ gpu_ids = [i.strip() for i in gpus.split(",") if i != ""]
54
+
55
  # setup environment variables
56
+ visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
57
  if visible_devices is None:
58
  os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
59
+ os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
60
+ visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
61
+ logger.warning(
62
+ f"[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}"
63
+ )
64
  else:
65
+ logger.warning(
66
+ "[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process."
67
+ )
68
  return len(gpu_ids)
69
 
70
 
 
75
  @contextlib.contextmanager
76
  def tqdm_joblib(tqdm_object):
77
  """Context manager to patch joblib to report into tqdm progress bar given as argument
78
+
79
  Usage:
80
  with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar:
81
  Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10))
82
+
83
  When iterating over a generator, directly use of tqdm is also a solutin (but monitor the task queuing, instead of finishing)
84
  ret_vals = Parallel(n_jobs=args.world_size)(
85
  delayed(lambda x: _compute_cov_score(pid, *x))(param)
 
88
  total=len(image_ids)*(len(image_ids)-1)/2))
89
  Src: https://stackoverflow.com/a/58936697
90
  """
91
+
92
  class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
93
  def __init__(self, *args, **kwargs):
94
  super().__init__(*args, **kwargs)
 
106
  tqdm_object.close()
107
 
108
 
109
+ def draw_points(img, points, color=(0, 255, 0), radius=3):
110
  dp = [(int(points[i, 0]), int(points[i, 1])) for i in range(points.shape[0])]
111
  for i in range(points.shape[0]):
112
+ cv2.circle(img, dp[i], radius=radius, color=color)
113
  return img
 
114
 
115
+
116
+ def draw_match(
117
+ img1,
118
+ img2,
119
+ corr1,
120
+ corr2,
121
+ inlier=[True],
122
+ color=None,
123
+ radius1=1,
124
+ radius2=1,
125
+ resize=None,
126
+ ):
127
  if resize is not None:
128
+ scale1, scale2 = [img1.shape[1] / resize[0], img1.shape[0] / resize[1]], [
129
+ img2.shape[1] / resize[0],
130
+ img2.shape[0] / resize[1],
131
+ ]
132
+ img1, img2 = cv2.resize(img1, resize, interpolation=cv2.INTER_AREA), cv2.resize(
133
+ img2, resize, interpolation=cv2.INTER_AREA
134
+ )
135
+ corr1, corr2 = (
136
+ corr1 / np.asarray(scale1)[np.newaxis],
137
+ corr2 / np.asarray(scale2)[np.newaxis],
138
+ )
139
+ corr1_key = [
140
+ cv2.KeyPoint(corr1[i, 0], corr1[i, 1], radius1) for i in range(corr1.shape[0])
141
+ ]
142
+ corr2_key = [
143
+ cv2.KeyPoint(corr2[i, 0], corr2[i, 1], radius2) for i in range(corr2.shape[0])
144
+ ]
145
 
146
  assert len(corr1) == len(corr2)
147
 
148
  draw_matches = [cv2.DMatch(i, i, 0) for i in range(len(corr1))]
149
  if color is None:
150
+ color = [(0, 255, 0) if cur_inlier else (0, 0, 255) for cur_inlier in inlier]
151
+ if len(color) == 1:
152
+ display = cv2.drawMatches(
153
+ img1,
154
+ corr1_key,
155
+ img2,
156
+ corr2_key,
157
+ draw_matches,
158
+ None,
159
+ matchColor=color[0],
160
+ singlePointColor=color[0],
161
+ flags=4,
162
+ )
163
  else:
164
+ height, width = max(img1.shape[0], img2.shape[0]), img1.shape[1] + img2.shape[1]
165
+ display = np.zeros([height, width, 3], np.uint8)
166
+ display[: img1.shape[0], : img1.shape[1]] = img1
167
+ display[: img2.shape[0], img1.shape[1] :] = img2
168
  for i in range(len(corr1)):
169
+ left_x, left_y, right_x, right_y = (
170
+ int(corr1[i][0]),
171
+ int(corr1[i][1]),
172
+ int(corr2[i][0] + img1.shape[1]),
173
+ int(corr2[i][1]),
174
+ )
175
+ cur_color = (int(color[i][0]), int(color[i][1]), int(color[i][2]))
176
+ cv2.line(
177
+ display,
178
+ (left_x, left_y),
179
+ (right_x, right_y),
180
+ cur_color,
181
+ 1,
182
+ lineType=cv2.LINE_AA,
183
+ )
184
  return display
third_party/ASpanFormer/src/utils/plotting.py CHANGED
@@ -4,38 +4,51 @@ import matplotlib.pyplot as plt
4
  import matplotlib
5
  from copy import deepcopy
6
 
 
7
  def _compute_conf_thresh(data):
8
- dataset_name = data['dataset_name'][0].lower()
9
- if dataset_name == 'scannet':
10
  thr = 5e-4
11
- elif dataset_name == 'megadepth' or dataset_name=='gl3d':
12
  thr = 1e-4
13
  else:
14
- raise ValueError(f'Unknown dataset: {dataset_name}')
15
  return thr
16
 
17
 
18
  # --- VISUALIZATION --- #
19
 
 
20
  def make_matching_figure(
21
- img0, img1, mkpts0, mkpts1, color,
22
- kpts0=None, kpts1=None, text=[], dpi=75, path=None):
 
 
 
 
 
 
 
 
 
23
  # draw image pair
24
- assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
 
 
25
  fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
26
- axes[0].imshow(img0, cmap='gray')
27
- axes[1].imshow(img1, cmap='gray')
28
- for i in range(2): # clear all frames
29
  axes[i].get_yaxis().set_ticks([])
30
  axes[i].get_xaxis().set_ticks([])
31
  for spine in axes[i].spines.values():
32
  spine.set_visible(False)
33
  plt.tight_layout(pad=1)
34
-
35
  if kpts0 is not None:
36
  assert kpts1 is not None
37
- axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
38
- axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2)
39
 
40
  # draw matches
41
  if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
@@ -43,164 +56,181 @@ def make_matching_figure(
43
  transFigure = fig.transFigure.inverted()
44
  fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
45
  fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
46
- fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
47
- (fkpts0[i, 1], fkpts1[i, 1]),
48
- transform=fig.transFigure, c=color[i], linewidth=1)
49
- for i in range(len(mkpts0))]
50
-
 
 
 
 
 
 
51
  axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
52
  axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
53
 
54
  # put txts
55
- txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
56
  fig.text(
57
- 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
58
- fontsize=15, va='top', ha='left', color=txt_color)
 
 
 
 
 
 
 
59
 
60
  # save or return figure
61
  if path:
62
- plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
63
  plt.close()
64
  else:
65
  return fig
66
 
67
 
68
- def _make_evaluation_figure(data, b_id, alpha='dynamic'):
69
- b_mask = data['m_bids'] == b_id
70
  conf_thr = _compute_conf_thresh(data)
71
-
72
- img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
73
- img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
74
- kpts0 = data['mkpts0_f'][b_mask].cpu().numpy()
75
- kpts1 = data['mkpts1_f'][b_mask].cpu().numpy()
76
-
77
  # for megadepth, we visualize matches on the resized image
78
- if 'scale0' in data:
79
- kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]]
80
- kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]]
81
- epi_errs = data['epi_errs'][b_mask].cpu().numpy()
82
  correct_mask = epi_errs < conf_thr
83
  precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
84
  n_correct = np.sum(correct_mask)
85
- n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu())
86
  recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
87
  # recall might be larger than 1, since the calculation of conf_matrix_gt
88
  # uses groundtruth depths and camera poses, but epipolar distance is used here.
89
 
90
  # matching info
91
- if alpha == 'dynamic':
92
  alpha = dynamic_alpha(len(correct_mask))
93
  color = error_colormap(epi_errs, conf_thr, alpha=alpha)
94
-
95
  text = [
96
- f'#Matches {len(kpts0)}',
97
- f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
98
- f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}'
99
  ]
100
-
101
  # make the figure
102
- figure = make_matching_figure(img0, img1, kpts0, kpts1,
103
- color, text=text)
104
  return figure
105
 
106
- def _make_evaluation_figure_offset(data, b_id, alpha='dynamic',side=''):
107
- layer_num=data['predict_flow'][0].shape[0]
108
 
109
- b_mask = data['offset_bids'+side] == b_id
110
- conf_thr = 2e-3 #hardcode for scannet(coarse level)
111
- img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
112
- img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
113
-
114
- figure_list=[]
115
- #draw offset matches in different layers
 
 
 
116
  for layer_index in range(layer_num):
117
- l_mask=data['offset_lids'+side]==layer_index
118
- mask=l_mask&b_mask
119
- kpts0 = data['offset_kpts0_f'+side][mask].cpu().numpy()
120
- kpts1 = data['offset_kpts1_f'+side][mask].cpu().numpy()
121
-
122
- epi_errs = data['epi_errs_offset'+side][mask].cpu().numpy()
123
  correct_mask = epi_errs < conf_thr
124
-
125
  precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
126
  n_correct = np.sum(correct_mask)
127
- n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu())
128
  recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
129
  # recall might be larger than 1, since the calculation of conf_matrix_gt
130
  # uses groundtruth depths and camera poses, but epipolar distance is used here.
131
 
132
  # matching info
133
- if alpha == 'dynamic':
134
  alpha = dynamic_alpha(len(correct_mask))
135
  color = error_colormap(epi_errs, conf_thr, alpha=alpha)
136
-
137
  text = [
138
- f'#Matches {len(kpts0)}',
139
- f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
140
- f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}'
141
  ]
142
-
143
  # make the figure
144
- #import pdb;pdb.set_trace()
145
- figure = make_matching_figure(deepcopy(img0), deepcopy(img1) , kpts0, kpts1,
146
- color, text=text)
 
147
  figure_list.append(figure)
148
  return figure
149
 
 
150
  def _make_confidence_figure(data, b_id):
151
  # TODO: Implement confidence figure
152
  raise NotImplementedError()
153
 
154
 
155
- def make_matching_figures(data, config, mode='evaluation'):
156
- """ Make matching figures for a batch.
157
-
158
  Args:
159
  data (Dict): a batch updated by PL_LoFTR.
160
  config (Dict): matcher config
161
  Returns:
162
  figures (Dict[str, List[plt.figure]]
163
  """
164
- assert mode in ['evaluation', 'confidence'] # 'confidence'
165
  figures = {mode: []}
166
- for b_id in range(data['image0'].size(0)):
167
- if mode == 'evaluation':
168
  fig = _make_evaluation_figure(
169
- data, b_id,
170
- alpha=config.TRAINER.PLOT_MATCHES_ALPHA)
171
- elif mode == 'confidence':
172
  fig = _make_confidence_figure(data, b_id)
173
  else:
174
- raise ValueError(f'Unknown plot mode: {mode}')
175
  figures[mode].append(fig)
176
  return figures
177
 
178
- def make_matching_figures_offset(data, config, mode='evaluation',side=''):
179
- """ Make matching figures for a batch.
180
-
 
181
  Args:
182
  data (Dict): a batch updated by PL_LoFTR.
183
  config (Dict): matcher config
184
  Returns:
185
  figures (Dict[str, List[plt.figure]]
186
  """
187
- assert mode in ['evaluation', 'confidence'] # 'confidence'
188
  figures = {mode: []}
189
- for b_id in range(data['image0'].size(0)):
190
- if mode == 'evaluation':
191
  fig = _make_evaluation_figure_offset(
192
- data, b_id,
193
- alpha=config.TRAINER.PLOT_MATCHES_ALPHA,side=side)
194
- elif mode == 'confidence':
195
  fig = _make_evaluation_figure_offset(data, b_id)
196
  else:
197
- raise ValueError(f'Unknown plot mode: {mode}')
198
  figures[mode].append(fig)
199
  return figures
200
 
201
- def dynamic_alpha(n_matches,
202
- milestones=[0, 300, 1000, 2000],
203
- alphas=[1.0, 0.8, 0.4, 0.2]):
 
204
  if n_matches == 0:
205
  return 1.0
206
  ranges = list(zip(alphas, alphas[1:] + [None]))
@@ -209,11 +239,15 @@ def dynamic_alpha(n_matches,
209
  if _range[1] is None:
210
  return _range[0]
211
  return _range[1] + (milestones[loc + 1] - n_matches) / (
212
- milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1])
 
213
 
214
 
215
  def error_colormap(err, thr, alpha=1.0):
216
  assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}"
217
  x = 1 - np.clip(err / (thr * 2), 0, 1)
218
  return np.clip(
219
- np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1)
 
 
 
 
4
  import matplotlib
5
  from copy import deepcopy
6
 
7
+
8
  def _compute_conf_thresh(data):
9
+ dataset_name = data["dataset_name"][0].lower()
10
+ if dataset_name == "scannet":
11
  thr = 5e-4
12
+ elif dataset_name == "megadepth" or dataset_name == "gl3d":
13
  thr = 1e-4
14
  else:
15
+ raise ValueError(f"Unknown dataset: {dataset_name}")
16
  return thr
17
 
18
 
19
  # --- VISUALIZATION --- #
20
 
21
+
22
  def make_matching_figure(
23
+ img0,
24
+ img1,
25
+ mkpts0,
26
+ mkpts1,
27
+ color,
28
+ kpts0=None,
29
+ kpts1=None,
30
+ text=[],
31
+ dpi=75,
32
+ path=None,
33
+ ):
34
  # draw image pair
35
+ assert (
36
+ mkpts0.shape[0] == mkpts1.shape[0]
37
+ ), f"mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}"
38
  fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
39
+ axes[0].imshow(img0, cmap="gray")
40
+ axes[1].imshow(img1, cmap="gray")
41
+ for i in range(2): # clear all frames
42
  axes[i].get_yaxis().set_ticks([])
43
  axes[i].get_xaxis().set_ticks([])
44
  for spine in axes[i].spines.values():
45
  spine.set_visible(False)
46
  plt.tight_layout(pad=1)
47
+
48
  if kpts0 is not None:
49
  assert kpts1 is not None
50
+ axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c="w", s=2)
51
+ axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=2)
52
 
53
  # draw matches
54
  if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
 
56
  transFigure = fig.transFigure.inverted()
57
  fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
58
  fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
59
+ fig.lines = [
60
+ matplotlib.lines.Line2D(
61
+ (fkpts0[i, 0], fkpts1[i, 0]),
62
+ (fkpts0[i, 1], fkpts1[i, 1]),
63
+ transform=fig.transFigure,
64
+ c=color[i],
65
+ linewidth=1,
66
+ )
67
+ for i in range(len(mkpts0))
68
+ ]
69
+
70
  axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
71
  axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
72
 
73
  # put txts
74
+ txt_color = "k" if img0[:100, :200].mean() > 200 else "w"
75
  fig.text(
76
+ 0.01,
77
+ 0.99,
78
+ "\n".join(text),
79
+ transform=fig.axes[0].transAxes,
80
+ fontsize=15,
81
+ va="top",
82
+ ha="left",
83
+ color=txt_color,
84
+ )
85
 
86
  # save or return figure
87
  if path:
88
+ plt.savefig(str(path), bbox_inches="tight", pad_inches=0)
89
  plt.close()
90
  else:
91
  return fig
92
 
93
 
94
+ def _make_evaluation_figure(data, b_id, alpha="dynamic"):
95
+ b_mask = data["m_bids"] == b_id
96
  conf_thr = _compute_conf_thresh(data)
97
+
98
+ img0 = (data["image0"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
99
+ img1 = (data["image1"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
100
+ kpts0 = data["mkpts0_f"][b_mask].cpu().numpy()
101
+ kpts1 = data["mkpts1_f"][b_mask].cpu().numpy()
102
+
103
  # for megadepth, we visualize matches on the resized image
104
+ if "scale0" in data:
105
+ kpts0 = kpts0 / data["scale0"][b_id].cpu().numpy()[[1, 0]]
106
+ kpts1 = kpts1 / data["scale1"][b_id].cpu().numpy()[[1, 0]]
107
+ epi_errs = data["epi_errs"][b_mask].cpu().numpy()
108
  correct_mask = epi_errs < conf_thr
109
  precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
110
  n_correct = np.sum(correct_mask)
111
+ n_gt_matches = int(data["conf_matrix_gt"][b_id].sum().cpu())
112
  recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
113
  # recall might be larger than 1, since the calculation of conf_matrix_gt
114
  # uses groundtruth depths and camera poses, but epipolar distance is used here.
115
 
116
  # matching info
117
+ if alpha == "dynamic":
118
  alpha = dynamic_alpha(len(correct_mask))
119
  color = error_colormap(epi_errs, conf_thr, alpha=alpha)
120
+
121
  text = [
122
+ f"#Matches {len(kpts0)}",
123
+ f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}",
124
+ f"Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}",
125
  ]
126
+
127
  # make the figure
128
+ figure = make_matching_figure(img0, img1, kpts0, kpts1, color, text=text)
 
129
  return figure
130
 
 
 
131
 
132
+ def _make_evaluation_figure_offset(data, b_id, alpha="dynamic", side=""):
133
+ layer_num = data["predict_flow"][0].shape[0]
134
+
135
+ b_mask = data["offset_bids" + side] == b_id
136
+ conf_thr = 2e-3 # hardcode for scannet(coarse level)
137
+ img0 = (data["image0"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
138
+ img1 = (data["image1"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
139
+
140
+ figure_list = []
141
+ # draw offset matches in different layers
142
  for layer_index in range(layer_num):
143
+ l_mask = data["offset_lids" + side] == layer_index
144
+ mask = l_mask & b_mask
145
+ kpts0 = data["offset_kpts0_f" + side][mask].cpu().numpy()
146
+ kpts1 = data["offset_kpts1_f" + side][mask].cpu().numpy()
147
+
148
+ epi_errs = data["epi_errs_offset" + side][mask].cpu().numpy()
149
  correct_mask = epi_errs < conf_thr
150
+
151
  precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
152
  n_correct = np.sum(correct_mask)
153
+ n_gt_matches = int(data["conf_matrix_gt"][b_id].sum().cpu())
154
  recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
155
  # recall might be larger than 1, since the calculation of conf_matrix_gt
156
  # uses groundtruth depths and camera poses, but epipolar distance is used here.
157
 
158
  # matching info
159
+ if alpha == "dynamic":
160
  alpha = dynamic_alpha(len(correct_mask))
161
  color = error_colormap(epi_errs, conf_thr, alpha=alpha)
162
+
163
  text = [
164
+ f"#Matches {len(kpts0)}",
165
+ f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}",
166
+ f"Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}",
167
  ]
168
+
169
  # make the figure
170
+ # import pdb;pdb.set_trace()
171
+ figure = make_matching_figure(
172
+ deepcopy(img0), deepcopy(img1), kpts0, kpts1, color, text=text
173
+ )
174
  figure_list.append(figure)
175
  return figure
176
 
177
+
178
  def _make_confidence_figure(data, b_id):
179
  # TODO: Implement confidence figure
180
  raise NotImplementedError()
181
 
182
 
183
+ def make_matching_figures(data, config, mode="evaluation"):
184
+ """Make matching figures for a batch.
185
+
186
  Args:
187
  data (Dict): a batch updated by PL_LoFTR.
188
  config (Dict): matcher config
189
  Returns:
190
  figures (Dict[str, List[plt.figure]]
191
  """
192
+ assert mode in ["evaluation", "confidence"] # 'confidence'
193
  figures = {mode: []}
194
+ for b_id in range(data["image0"].size(0)):
195
+ if mode == "evaluation":
196
  fig = _make_evaluation_figure(
197
+ data, b_id, alpha=config.TRAINER.PLOT_MATCHES_ALPHA
198
+ )
199
+ elif mode == "confidence":
200
  fig = _make_confidence_figure(data, b_id)
201
  else:
202
+ raise ValueError(f"Unknown plot mode: {mode}")
203
  figures[mode].append(fig)
204
  return figures
205
 
206
+
207
+ def make_matching_figures_offset(data, config, mode="evaluation", side=""):
208
+ """Make matching figures for a batch.
209
+
210
  Args:
211
  data (Dict): a batch updated by PL_LoFTR.
212
  config (Dict): matcher config
213
  Returns:
214
  figures (Dict[str, List[plt.figure]]
215
  """
216
+ assert mode in ["evaluation", "confidence"] # 'confidence'
217
  figures = {mode: []}
218
+ for b_id in range(data["image0"].size(0)):
219
+ if mode == "evaluation":
220
  fig = _make_evaluation_figure_offset(
221
+ data, b_id, alpha=config.TRAINER.PLOT_MATCHES_ALPHA, side=side
222
+ )
223
+ elif mode == "confidence":
224
  fig = _make_evaluation_figure_offset(data, b_id)
225
  else:
226
+ raise ValueError(f"Unknown plot mode: {mode}")
227
  figures[mode].append(fig)
228
  return figures
229
 
230
+
231
+ def dynamic_alpha(
232
+ n_matches, milestones=[0, 300, 1000, 2000], alphas=[1.0, 0.8, 0.4, 0.2]
233
+ ):
234
  if n_matches == 0:
235
  return 1.0
236
  ranges = list(zip(alphas, alphas[1:] + [None]))
 
239
  if _range[1] is None:
240
  return _range[0]
241
  return _range[1] + (milestones[loc + 1] - n_matches) / (
242
+ milestones[loc + 1] - milestones[loc]
243
+ ) * (_range[0] - _range[1])
244
 
245
 
246
  def error_colormap(err, thr, alpha=1.0):
247
  assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}"
248
  x = 1 - np.clip(err / (thr * 2), 0, 1)
249
  return np.clip(
250
+ np.stack([2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x) * alpha], -1),
251
+ 0,
252
+ 1,
253
+ )
third_party/ASpanFormer/src/utils/profiler.py CHANGED
@@ -7,7 +7,7 @@ from pytorch_lightning.utilities import rank_zero_only
7
  class InferenceProfiler(SimpleProfiler):
8
  """
9
  This profiler records duration of actions with cuda.synchronize()
10
- Use this in test time.
11
  """
12
 
13
  def __init__(self):
@@ -28,12 +28,13 @@ class InferenceProfiler(SimpleProfiler):
28
 
29
 
30
  def build_profiler(name):
31
- if name == 'inference':
32
  return InferenceProfiler()
33
- elif name == 'pytorch':
34
  from pytorch_lightning.profiler import PyTorchProfiler
 
35
  return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100)
36
  elif name is None:
37
  return PassThroughProfiler()
38
  else:
39
- raise ValueError(f'Invalid profiler: {name}')
 
7
  class InferenceProfiler(SimpleProfiler):
8
  """
9
  This profiler records duration of actions with cuda.synchronize()
10
+ Use this in test time.
11
  """
12
 
13
  def __init__(self):
 
28
 
29
 
30
  def build_profiler(name):
31
+ if name == "inference":
32
  return InferenceProfiler()
33
+ elif name == "pytorch":
34
  from pytorch_lightning.profiler import PyTorchProfiler
35
+
36
  return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100)
37
  elif name is None:
38
  return PassThroughProfiler()
39
  else:
40
+ raise ValueError(f"Invalid profiler: {name}")
third_party/ASpanFormer/test.py CHANGED
@@ -10,33 +10,52 @@ from src.lightning.data import MultiSceneDataModule
10
  from src.lightning.lightning_aspanformer import PL_ASpanFormer
11
  import torch
12
 
 
13
  def parse_args():
14
  # init a costum parser which will be added into pl.Trainer parser
15
  # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
16
- parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
17
- parser.add_argument(
18
- 'data_cfg_path', type=str, help='data config path')
19
- parser.add_argument(
20
- 'main_cfg_path', type=str, help='main config path')
21
- parser.add_argument(
22
- '--ckpt_path', type=str, default="weights/indoor_ds.ckpt", help='path to the checkpoint')
23
- parser.add_argument(
24
- '--dump_dir', type=str, default=None, help="if set, the matching results will be dump to dump_dir")
25
  parser.add_argument(
26
- '--profiler_name', type=str, default=None, help='options: [inference, pytorch], or leave it unset')
 
 
 
 
27
  parser.add_argument(
28
- '--batch_size', type=int, default=1, help='batch_size per gpu')
 
 
 
 
29
  parser.add_argument(
30
- '--num_workers', type=int, default=2)
 
 
 
 
 
 
31
  parser.add_argument(
32
- '--thr', type=float, default=None, help='modify the coarse-level matching threshold.')
 
 
 
 
33
  parser.add_argument(
34
- '--mode', type=str, default='vanilla', help='modify the coarse-level matching threshold.')
 
 
 
 
35
  parser = pl.Trainer.add_argparse_args(parser)
36
  return parser.parse_args()
37
 
38
 
39
- if __name__ == '__main__':
40
  # parse arguments
41
  args = parse_args()
42
  pprint.pprint(vars(args))
@@ -55,7 +74,12 @@ if __name__ == '__main__':
55
 
56
  # lightning module
57
  profiler = build_profiler(args.profiler_name)
58
- model = PL_ASpanFormer(config, pretrained_ckpt=args.ckpt_path, profiler=profiler, dump_dir=args.dump_dir)
 
 
 
 
 
59
  loguru_logger.info(f"ASpanFormer-lightning initialized!")
60
 
61
  # lightning data
@@ -63,7 +87,9 @@ if __name__ == '__main__':
63
  loguru_logger.info(f"DataModule initialized!")
64
 
65
  # lightning trainer
66
- trainer = pl.Trainer.from_argparse_args(args, replace_sampler_ddp=False, logger=False)
 
 
67
 
68
  loguru_logger.info(f"Start testing!")
69
  trainer.test(model, datamodule=data_module, verbose=False)
 
10
  from src.lightning.lightning_aspanformer import PL_ASpanFormer
11
  import torch
12
 
13
+
14
  def parse_args():
15
  # init a costum parser which will be added into pl.Trainer parser
16
  # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags
17
+ parser = argparse.ArgumentParser(
18
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
19
+ )
20
+ parser.add_argument("data_cfg_path", type=str, help="data config path")
21
+ parser.add_argument("main_cfg_path", type=str, help="main config path")
 
 
 
 
22
  parser.add_argument(
23
+ "--ckpt_path",
24
+ type=str,
25
+ default="weights/indoor_ds.ckpt",
26
+ help="path to the checkpoint",
27
+ )
28
  parser.add_argument(
29
+ "--dump_dir",
30
+ type=str,
31
+ default=None,
32
+ help="if set, the matching results will be dump to dump_dir",
33
+ )
34
  parser.add_argument(
35
+ "--profiler_name",
36
+ type=str,
37
+ default=None,
38
+ help="options: [inference, pytorch], or leave it unset",
39
+ )
40
+ parser.add_argument("--batch_size", type=int, default=1, help="batch_size per gpu")
41
+ parser.add_argument("--num_workers", type=int, default=2)
42
  parser.add_argument(
43
+ "--thr",
44
+ type=float,
45
+ default=None,
46
+ help="modify the coarse-level matching threshold.",
47
+ )
48
  parser.add_argument(
49
+ "--mode",
50
+ type=str,
51
+ default="vanilla",
52
+ help="modify the coarse-level matching threshold.",
53
+ )
54
  parser = pl.Trainer.add_argparse_args(parser)
55
  return parser.parse_args()
56
 
57
 
58
+ if __name__ == "__main__":
59
  # parse arguments
60
  args = parse_args()
61
  pprint.pprint(vars(args))
 
74
 
75
  # lightning module
76
  profiler = build_profiler(args.profiler_name)
77
+ model = PL_ASpanFormer(
78
+ config,
79
+ pretrained_ckpt=args.ckpt_path,
80
+ profiler=profiler,
81
+ dump_dir=args.dump_dir,
82
+ )
83
  loguru_logger.info(f"ASpanFormer-lightning initialized!")
84
 
85
  # lightning data
 
87
  loguru_logger.info(f"DataModule initialized!")
88
 
89
  # lightning trainer
90
+ trainer = pl.Trainer.from_argparse_args(
91
+ args, replace_sampler_ddp=False, logger=False
92
+ )
93
 
94
  loguru_logger.info(f"Start testing!")
95
  trainer.test(model, datamodule=data_module, verbose=False)
third_party/ASpanFormer/tools/extract.py CHANGED
@@ -5,43 +5,77 @@ from tqdm import tqdm
5
  from multiprocessing import Pool
6
  from functools import partial
7
 
8
- scannet_dir='/root/data/ScanNet-v2-1.0.0/data/raw'
9
- dump_dir='/root/data/scannet_dump'
10
- num_process=32
11
-
12
- def extract(seq,scannet_dir,split,dump_dir):
13
- assert split=='train' or split=='test'
14
- if not os.path.exists(os.path.join(dump_dir,split,seq)):
15
- os.mkdir(os.path.join(dump_dir,split,seq))
16
- cmd='python reader.py --filename '+os.path.join(scannet_dir,'scans' if split=='train' else 'scans_test',seq,seq+'.sens')+' --output_path '+os.path.join(dump_dir,split,seq)+\
17
- ' --export_depth_images --export_color_images --export_poses --export_intrinsics'
 
 
 
 
 
 
 
 
 
 
 
18
  os.system(cmd)
19
 
20
- if __name__=='__main__':
 
21
  if not os.path.exists(dump_dir):
22
  os.mkdir(dump_dir)
23
- os.mkdir(os.path.join(dump_dir,'train'))
24
- os.mkdir(os.path.join(dump_dir,'test'))
25
 
26
- train_seq_list=[seq.split('/')[-1] for seq in glob.glob(os.path.join(scannet_dir,'scans','scene*'))]
27
- test_seq_list=[seq.split('/')[-1] for seq in glob.glob(os.path.join(scannet_dir,'scans_test','scene*'))]
 
 
 
 
 
 
28
 
29
- extract_train=partial(extract,scannet_dir=scannet_dir,split='train',dump_dir=dump_dir)
30
- extract_test=partial(extract,scannet_dir=scannet_dir,split='test',dump_dir=dump_dir)
 
 
 
 
31
 
32
- num_train_iter=len(train_seq_list)//num_process if len(train_seq_list)%num_process==0 else len(train_seq_list)//num_process+1
33
- num_test_iter=len(test_seq_list)//num_process if len(test_seq_list)%num_process==0 else len(test_seq_list)//num_process+1
 
 
 
 
 
 
 
 
34
 
35
  pool = Pool(num_process)
36
  for index in tqdm(range(num_train_iter)):
37
- seq_list=train_seq_list[index*num_process:min((index+1)*num_process,len(train_seq_list))]
38
- pool.map(extract_train,seq_list)
 
 
39
  pool.close()
40
  pool.join()
41
-
42
  pool = Pool(num_process)
43
  for index in tqdm(range(num_test_iter)):
44
- seq_list=test_seq_list[index*num_process:min((index+1)*num_process,len(test_seq_list))]
45
- pool.map(extract_test,seq_list)
 
 
46
  pool.close()
47
- pool.join()
 
5
  from multiprocessing import Pool
6
  from functools import partial
7
 
8
+ scannet_dir = "/root/data/ScanNet-v2-1.0.0/data/raw"
9
+ dump_dir = "/root/data/scannet_dump"
10
+ num_process = 32
11
+
12
+
13
+ def extract(seq, scannet_dir, split, dump_dir):
14
+ assert split == "train" or split == "test"
15
+ if not os.path.exists(os.path.join(dump_dir, split, seq)):
16
+ os.mkdir(os.path.join(dump_dir, split, seq))
17
+ cmd = (
18
+ "python reader.py --filename "
19
+ + os.path.join(
20
+ scannet_dir,
21
+ "scans" if split == "train" else "scans_test",
22
+ seq,
23
+ seq + ".sens",
24
+ )
25
+ + " --output_path "
26
+ + os.path.join(dump_dir, split, seq)
27
+ + " --export_depth_images --export_color_images --export_poses --export_intrinsics"
28
+ )
29
  os.system(cmd)
30
 
31
+
32
+ if __name__ == "__main__":
33
  if not os.path.exists(dump_dir):
34
  os.mkdir(dump_dir)
35
+ os.mkdir(os.path.join(dump_dir, "train"))
36
+ os.mkdir(os.path.join(dump_dir, "test"))
37
 
38
+ train_seq_list = [
39
+ seq.split("/")[-1]
40
+ for seq in glob.glob(os.path.join(scannet_dir, "scans", "scene*"))
41
+ ]
42
+ test_seq_list = [
43
+ seq.split("/")[-1]
44
+ for seq in glob.glob(os.path.join(scannet_dir, "scans_test", "scene*"))
45
+ ]
46
 
47
+ extract_train = partial(
48
+ extract, scannet_dir=scannet_dir, split="train", dump_dir=dump_dir
49
+ )
50
+ extract_test = partial(
51
+ extract, scannet_dir=scannet_dir, split="test", dump_dir=dump_dir
52
+ )
53
 
54
+ num_train_iter = (
55
+ len(train_seq_list) // num_process
56
+ if len(train_seq_list) % num_process == 0
57
+ else len(train_seq_list) // num_process + 1
58
+ )
59
+ num_test_iter = (
60
+ len(test_seq_list) // num_process
61
+ if len(test_seq_list) % num_process == 0
62
+ else len(test_seq_list) // num_process + 1
63
+ )
64
 
65
  pool = Pool(num_process)
66
  for index in tqdm(range(num_train_iter)):
67
+ seq_list = train_seq_list[
68
+ index * num_process : min((index + 1) * num_process, len(train_seq_list))
69
+ ]
70
+ pool.map(extract_train, seq_list)
71
  pool.close()
72
  pool.join()
73
+
74
  pool = Pool(num_process)
75
  for index in tqdm(range(num_test_iter)):
76
+ seq_list = test_seq_list[
77
+ index * num_process : min((index + 1) * num_process, len(test_seq_list))
78
+ ]
79
+ pool.map(extract_test, seq_list)
80
  pool.close()
81
+ pool.join()