haofeixu commited on
Commit
9ef8038
1 Parent(s): f5d2232

unimatch demo

Browse files
app.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from PIL import Image
5
+ import gradio as gr
6
+
7
+ from unimatch.unimatch import UniMatch
8
+ from utils.flow_viz import flow_to_image
9
+ from dataloader.stereo import transforms
10
+ from utils.visualization import vis_disparity
11
+
12
+ IMAGENET_MEAN = [0.485, 0.456, 0.406]
13
+ IMAGENET_STD = [0.229, 0.224, 0.225]
14
+
15
+
16
+ @torch.no_grad()
17
+ def inference(image1, image2, task='flow'):
18
+ """Inference on an image pair for optical flow or stereo disparity prediction"""
19
+
20
+ model = UniMatch(feature_channels=128,
21
+ num_scales=2,
22
+ upsample_factor=4,
23
+ ffn_dim_expansion=4,
24
+ num_transformer_layers=6,
25
+ reg_refine=True,
26
+ task=task)
27
+
28
+ model.eval()
29
+
30
+ if task == 'flow':
31
+ checkpoint_path = 'pretrained/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.pth'
32
+ else:
33
+ checkpoint_path = 'pretrained/gmstereo-scale2-regrefine3-resumeflowthings-mixdata-train320x640-ft640x960-e4e291fd.pth'
34
+
35
+ checkpoint_flow = torch.load(checkpoint_path)
36
+ model.load_state_dict(checkpoint_flow['model'], strict=True)
37
+
38
+ padding_factor = 32
39
+ attn_type = 'swin' if task == 'flow' else 'self_swin2d_cross_swin1d'
40
+ attn_splits_list = [2, 8]
41
+ corr_radius_list = [-1, 4]
42
+ prop_radius_list = [-1, 1]
43
+ num_reg_refine = 6 if task == 'flow' else 3
44
+
45
+ # smaller inference size for faster speed
46
+ max_inference_size = [384, 768] if task == 'flow' else [640, 960]
47
+
48
+ transpose_img = False
49
+
50
+ image1 = np.array(image1).astype(np.float32)
51
+ image2 = np.array(image2).astype(np.float32)
52
+
53
+ if len(image1.shape) == 2: # gray image
54
+ image1 = np.tile(image1[..., None], (1, 1, 3))
55
+ image2 = np.tile(image2[..., None], (1, 1, 3))
56
+ else:
57
+ image1 = image1[..., :3]
58
+ image2 = image2[..., :3]
59
+
60
+ if task == 'flow':
61
+ image1 = torch.from_numpy(image1).permute(2, 0, 1).float().unsqueeze(0)
62
+ image2 = torch.from_numpy(image2).permute(2, 0, 1).float().unsqueeze(0)
63
+ else:
64
+ val_transform_list = [transforms.ToTensor(),
65
+ transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
66
+ ]
67
+
68
+ val_transform = transforms.Compose(val_transform_list)
69
+
70
+ sample = {'left': image1, 'right': image2}
71
+ sample = val_transform(sample)
72
+
73
+ image1 = sample['left'].unsqueeze(0) # [1, 3, H, W]
74
+ image2 = sample['right'].unsqueeze(0) # [1, 3, H, W]
75
+
76
+ # the model is trained with size: width > height
77
+ if task == 'flow' and image1.size(-2) > image1.size(-1):
78
+ image1 = torch.transpose(image1, -2, -1)
79
+ image2 = torch.transpose(image2, -2, -1)
80
+ transpose_img = True
81
+
82
+ nearest_size = [int(np.ceil(image1.size(-2) / padding_factor)) * padding_factor,
83
+ int(np.ceil(image1.size(-1) / padding_factor)) * padding_factor]
84
+
85
+ inference_size = [min(max_inference_size[0], nearest_size[0]), min(max_inference_size[1], nearest_size[1])]
86
+
87
+ assert isinstance(inference_size, list) or isinstance(inference_size, tuple)
88
+ ori_size = image1.shape[-2:]
89
+
90
+ # resize before inference
91
+ if inference_size[0] != ori_size[0] or inference_size[1] != ori_size[1]:
92
+ image1 = F.interpolate(image1, size=inference_size, mode='bilinear',
93
+ align_corners=True)
94
+ image2 = F.interpolate(image2, size=inference_size, mode='bilinear',
95
+ align_corners=True)
96
+
97
+ results_dict = model(image1, image2,
98
+ attn_type=attn_type,
99
+ attn_splits_list=attn_splits_list,
100
+ corr_radius_list=corr_radius_list,
101
+ prop_radius_list=prop_radius_list,
102
+ num_reg_refine=num_reg_refine,
103
+ task=task,
104
+ )
105
+
106
+ flow_pr = results_dict['flow_preds'][-1] # [1, 2, H, W] or [1, H, W]
107
+
108
+ # resize back
109
+ if task == 'flow':
110
+ if inference_size[0] != ori_size[0] or inference_size[1] != ori_size[1]:
111
+ flow_pr = F.interpolate(flow_pr, size=ori_size, mode='bilinear',
112
+ align_corners=True)
113
+ flow_pr[:, 0] = flow_pr[:, 0] * ori_size[-1] / inference_size[-1]
114
+ flow_pr[:, 1] = flow_pr[:, 1] * ori_size[-2] / inference_size[-2]
115
+ else:
116
+ if inference_size[0] != ori_size[0] or inference_size[1] != ori_size[1]:
117
+ pred_disp = F.interpolate(flow_pr.unsqueeze(1), size=ori_size,
118
+ mode='bilinear',
119
+ align_corners=True).squeeze(1) # [1, H, W]
120
+ pred_disp = pred_disp * ori_size[-1] / float(inference_size[-1])
121
+
122
+ if task == 'flow':
123
+ if transpose_img:
124
+ flow_pr = torch.transpose(flow_pr, -2, -1)
125
+
126
+ flow = flow_pr[0].permute(1, 2, 0).cpu().numpy() # [H, W, 2]
127
+
128
+ output = flow_to_image(flow) # [H, W, 3]
129
+ else:
130
+ disp = pred_disp[0].cpu().numpy()
131
+
132
+ output = vis_disparity(disp, return_rgb=True)
133
+
134
+ return Image.fromarray(output)
135
+
136
+
137
+ title = "UniMatch"
138
+
139
+ description = "<p style='text-align: center'>Optical flow and stereo matching demo for <a href='https://haofeixu.github.io/unimatch/' target='_blank'>Unifying Flow, Stereo and Depth Estimation</a> | <a href='https://arxiv.org/abs/2211.05783' target='_blank'>Paper</a> | <a href='https://github.com/autonomousvision/unimatch' target='_blank'>Code</a> | <a href='https://colab.research.google.com/drive/1r5m-xVy3Kw60U-m5VB-aQ98oqqg_6cab?usp=sharing' target='_blank'>Colab</a><br>Simply upload your images or click one of the provided examples.<br>The <strong>first three</strong> examples are video frames for <strong>flow</strong> task, and the <strong>last three</strong> are stereo pairs for <strong>stereo</strong> task.<br><strong>Select the task type according to your input images</strong>.</p>"
140
+
141
+ examples = [
142
+ ['demo/flow_kitti_test_000197_10.png', 'demo/flow_kitti_test_000197_11.png'],
143
+ ['demo/flow_sintel_cave_3_frame_0049.png', 'demo/flow_sintel_cave_3_frame_0050.png'],
144
+ ['demo/flow_davis_skate-jump_00059.jpg', 'demo/flow_davis_skate-jump_00060.jpg'],
145
+ ['demo/stereo_drivingstereo_test_2018-07-11-14-48-52_2018-07-11-14-58-34-673_left.jpg',
146
+ 'demo/stereo_drivingstereo_test_2018-07-11-14-48-52_2018-07-11-14-58-34-673_right.jpg'],
147
+ ['demo/stereo_middlebury_plants_im0.png', 'demo/stereo_middlebury_plants_im1.png'],
148
+ ['demo/stereo_holopix_left.png', 'demo/stereo_holopix_right.png']
149
+ ]
150
+
151
+ gr.Interface(
152
+ inference,
153
+ [gr.Image(type="pil", label="Image1"), gr.Image(type="pil", label="Image2"), gr.Radio(choices=['flow', 'stereo'], value='flow', label='Task')],
154
+ gr.Image(type="pil", label="Flow/Disparity"),
155
+ title=title,
156
+ description=description,
157
+ examples=examples,
158
+ thumbnail="https://haofeixu.github.io/unimatch/resources/teaser.svg",
159
+ allow_flagging="auto",
160
+ ).launch(debug=True)
dataloader/__init__.py ADDED
File without changes
dataloader/stereo/transforms.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import division
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torchvision.transforms.functional as F
6
+ import random
7
+ import cv2
8
+
9
+
10
+ class Compose(object):
11
+ def __init__(self, transforms):
12
+ self.transforms = transforms
13
+
14
+ def __call__(self, sample):
15
+ for t in self.transforms:
16
+ sample = t(sample)
17
+ return sample
18
+
19
+
20
+ class ToTensor(object):
21
+ """Convert numpy array to torch tensor"""
22
+
23
+ def __init__(self, no_normalize=False):
24
+ self.no_normalize = no_normalize
25
+
26
+ def __call__(self, sample):
27
+ left = np.transpose(sample['left'], (2, 0, 1)) # [3, H, W]
28
+ if self.no_normalize:
29
+ sample['left'] = torch.from_numpy(left)
30
+ else:
31
+ sample['left'] = torch.from_numpy(left) / 255.
32
+ right = np.transpose(sample['right'], (2, 0, 1))
33
+
34
+ if self.no_normalize:
35
+ sample['right'] = torch.from_numpy(right)
36
+ else:
37
+ sample['right'] = torch.from_numpy(right) / 255.
38
+
39
+ # disp = np.expand_dims(sample['disp'], axis=0) # [1, H, W]
40
+ if 'disp' in sample.keys():
41
+ disp = sample['disp'] # [H, W]
42
+ sample['disp'] = torch.from_numpy(disp)
43
+
44
+ return sample
45
+
46
+
47
+ class Normalize(object):
48
+ """Normalize image, with type tensor"""
49
+
50
+ def __init__(self, mean, std):
51
+ self.mean = mean
52
+ self.std = std
53
+
54
+ def __call__(self, sample):
55
+
56
+ norm_keys = ['left', 'right']
57
+
58
+ for key in norm_keys:
59
+ # Images have converted to tensor, with shape [C, H, W]
60
+ for t, m, s in zip(sample[key], self.mean, self.std):
61
+ t.sub_(m).div_(s)
62
+
63
+ return sample
64
+
65
+
66
+ class RandomCrop(object):
67
+ def __init__(self, img_height, img_width):
68
+ self.img_height = img_height
69
+ self.img_width = img_width
70
+
71
+ def __call__(self, sample):
72
+ ori_height, ori_width = sample['left'].shape[:2]
73
+
74
+ # pad zero when crop size is larger than original image size
75
+ if self.img_height > ori_height or self.img_width > ori_width:
76
+
77
+ # can be used for only pad one side
78
+ top_pad = max(self.img_height - ori_height, 0)
79
+ right_pad = max(self.img_width - ori_width, 0)
80
+
81
+ # try edge padding
82
+ sample['left'] = np.lib.pad(sample['left'],
83
+ ((top_pad, 0), (0, right_pad), (0, 0)),
84
+ mode='edge')
85
+ sample['right'] = np.lib.pad(sample['right'],
86
+ ((top_pad, 0), (0, right_pad), (0, 0)),
87
+ mode='edge')
88
+
89
+ if 'disp' in sample.keys():
90
+ sample['disp'] = np.lib.pad(sample['disp'],
91
+ ((top_pad, 0), (0, right_pad)),
92
+ mode='constant',
93
+ constant_values=0)
94
+
95
+ # update image resolution
96
+ ori_height, ori_width = sample['left'].shape[:2]
97
+
98
+ assert self.img_height <= ori_height and self.img_width <= ori_width
99
+
100
+ # Training: random crop
101
+ self.offset_x = np.random.randint(ori_width - self.img_width + 1)
102
+
103
+ start_height = 0
104
+ assert ori_height - start_height >= self.img_height
105
+
106
+ self.offset_y = np.random.randint(start_height, ori_height - self.img_height + 1)
107
+
108
+ sample['left'] = self.crop_img(sample['left'])
109
+ sample['right'] = self.crop_img(sample['right'])
110
+ if 'disp' in sample.keys():
111
+ sample['disp'] = self.crop_img(sample['disp'])
112
+
113
+ return sample
114
+
115
+ def crop_img(self, img):
116
+ return img[self.offset_y:self.offset_y + self.img_height,
117
+ self.offset_x:self.offset_x + self.img_width]
118
+
119
+
120
+ class RandomVerticalFlip(object):
121
+ """Randomly vertically filps"""
122
+
123
+ def __call__(self, sample):
124
+ if np.random.random() < 0.5:
125
+ sample['left'] = np.copy(np.flipud(sample['left']))
126
+ sample['right'] = np.copy(np.flipud(sample['right']))
127
+
128
+ sample['disp'] = np.copy(np.flipud(sample['disp']))
129
+
130
+ return sample
131
+
132
+
133
+ class ToPILImage(object):
134
+
135
+ def __call__(self, sample):
136
+ sample['left'] = Image.fromarray(sample['left'].astype('uint8'))
137
+ sample['right'] = Image.fromarray(sample['right'].astype('uint8'))
138
+
139
+ return sample
140
+
141
+
142
+ class ToNumpyArray(object):
143
+
144
+ def __call__(self, sample):
145
+ sample['left'] = np.array(sample['left']).astype(np.float32)
146
+ sample['right'] = np.array(sample['right']).astype(np.float32)
147
+
148
+ return sample
149
+
150
+
151
+ # Random coloring
152
+ class RandomContrast(object):
153
+ """Random contrast"""
154
+
155
+ def __init__(self,
156
+ asymmetric_color_aug=True,
157
+ ):
158
+
159
+ self.asymmetric_color_aug = asymmetric_color_aug
160
+
161
+ def __call__(self, sample):
162
+ if np.random.random() < 0.5:
163
+ contrast_factor = np.random.uniform(0.8, 1.2)
164
+
165
+ sample['left'] = F.adjust_contrast(sample['left'], contrast_factor)
166
+
167
+ if self.asymmetric_color_aug and np.random.random() < 0.5:
168
+ contrast_factor = np.random.uniform(0.8, 1.2)
169
+
170
+ sample['right'] = F.adjust_contrast(sample['right'], contrast_factor)
171
+
172
+ return sample
173
+
174
+
175
+ class RandomGamma(object):
176
+
177
+ def __init__(self,
178
+ asymmetric_color_aug=True,
179
+ ):
180
+
181
+ self.asymmetric_color_aug = asymmetric_color_aug
182
+
183
+ def __call__(self, sample):
184
+ if np.random.random() < 0.5:
185
+ gamma = np.random.uniform(0.7, 1.5) # adopted from FlowNet
186
+
187
+ sample['left'] = F.adjust_gamma(sample['left'], gamma)
188
+
189
+ if self.asymmetric_color_aug and np.random.random() < 0.5:
190
+ gamma = np.random.uniform(0.7, 1.5) # adopted from FlowNet
191
+
192
+ sample['right'] = F.adjust_gamma(sample['right'], gamma)
193
+
194
+ return sample
195
+
196
+
197
+ class RandomBrightness(object):
198
+
199
+ def __init__(self,
200
+ asymmetric_color_aug=True,
201
+ ):
202
+
203
+ self.asymmetric_color_aug = asymmetric_color_aug
204
+
205
+ def __call__(self, sample):
206
+ if np.random.random() < 0.5:
207
+ brightness = np.random.uniform(0.5, 2.0)
208
+
209
+ sample['left'] = F.adjust_brightness(sample['left'], brightness)
210
+
211
+ if self.asymmetric_color_aug and np.random.random() < 0.5:
212
+ brightness = np.random.uniform(0.5, 2.0)
213
+
214
+ sample['right'] = F.adjust_brightness(sample['right'], brightness)
215
+
216
+ return sample
217
+
218
+
219
+ class RandomHue(object):
220
+
221
+ def __init__(self,
222
+ asymmetric_color_aug=True,
223
+ ):
224
+
225
+ self.asymmetric_color_aug = asymmetric_color_aug
226
+
227
+ def __call__(self, sample):
228
+ if np.random.random() < 0.5:
229
+ hue = np.random.uniform(-0.1, 0.1)
230
+
231
+ sample['left'] = F.adjust_hue(sample['left'], hue)
232
+
233
+ if self.asymmetric_color_aug and np.random.random() < 0.5:
234
+ hue = np.random.uniform(-0.1, 0.1)
235
+
236
+ sample['right'] = F.adjust_hue(sample['right'], hue)
237
+
238
+ return sample
239
+
240
+
241
+ class RandomSaturation(object):
242
+
243
+ def __init__(self,
244
+ asymmetric_color_aug=True,
245
+ ):
246
+
247
+ self.asymmetric_color_aug = asymmetric_color_aug
248
+
249
+ def __call__(self, sample):
250
+ if np.random.random() < 0.5:
251
+ saturation = np.random.uniform(0.8, 1.2)
252
+
253
+ sample['left'] = F.adjust_saturation(sample['left'], saturation)
254
+
255
+ if self.asymmetric_color_aug and np.random.random() < 0.5:
256
+ saturation = np.random.uniform(0.8, 1.2)
257
+
258
+ sample['right'] = F.adjust_saturation(sample['right'], saturation)
259
+
260
+ return sample
261
+
262
+
263
+ class RandomColor(object):
264
+
265
+ def __init__(self,
266
+ asymmetric_color_aug=True,
267
+ ):
268
+
269
+ self.asymmetric_color_aug = asymmetric_color_aug
270
+
271
+ def __call__(self, sample):
272
+ transforms = [RandomContrast(asymmetric_color_aug=self.asymmetric_color_aug),
273
+ RandomGamma(asymmetric_color_aug=self.asymmetric_color_aug),
274
+ RandomBrightness(asymmetric_color_aug=self.asymmetric_color_aug),
275
+ RandomHue(asymmetric_color_aug=self.asymmetric_color_aug),
276
+ RandomSaturation(asymmetric_color_aug=self.asymmetric_color_aug)]
277
+
278
+ sample = ToPILImage()(sample)
279
+
280
+ if np.random.random() < 0.5:
281
+ # A single transform
282
+ t = random.choice(transforms)
283
+ sample = t(sample)
284
+ else:
285
+ # Combination of transforms
286
+ # Random order
287
+ random.shuffle(transforms)
288
+ for t in transforms:
289
+ sample = t(sample)
290
+
291
+ sample = ToNumpyArray()(sample)
292
+
293
+ return sample
294
+
295
+
296
+ class RandomScale(object):
297
+ def __init__(self,
298
+ min_scale=-0.4,
299
+ max_scale=0.4,
300
+ crop_width=512,
301
+ nearest_interp=False, # for sparse gt
302
+ ):
303
+ self.min_scale = min_scale
304
+ self.max_scale = max_scale
305
+ self.crop_width = crop_width
306
+ self.nearest_interp = nearest_interp
307
+
308
+ def __call__(self, sample):
309
+ if np.random.rand() < 0.5:
310
+ h, w = sample['disp'].shape
311
+
312
+ scale_x = 2 ** np.random.uniform(self.min_scale, self.max_scale)
313
+
314
+ scale_x = np.clip(scale_x, self.crop_width / float(w), None)
315
+
316
+ # only random scale x axis
317
+ sample['left'] = cv2.resize(sample['left'], None, fx=scale_x, fy=1., interpolation=cv2.INTER_LINEAR)
318
+ sample['right'] = cv2.resize(sample['right'], None, fx=scale_x, fy=1., interpolation=cv2.INTER_LINEAR)
319
+
320
+ sample['disp'] = cv2.resize(
321
+ sample['disp'], None, fx=scale_x, fy=1.,
322
+ interpolation=cv2.INTER_LINEAR if not self.nearest_interp else cv2.INTER_NEAREST
323
+ ) * scale_x
324
+
325
+ if 'pseudo_disp' in sample and sample['pseudo_disp'] is not None:
326
+ sample['pseudo_disp'] = cv2.resize(sample['pseudo_disp'], None, fx=scale_x, fy=1.,
327
+ interpolation=cv2.INTER_LINEAR) * scale_x
328
+
329
+ return sample
330
+
331
+
332
+ class Resize(object):
333
+ def __init__(self,
334
+ scale_x=1,
335
+ scale_y=1,
336
+ nearest_interp=True, # for sparse gt
337
+ ):
338
+ """
339
+ Resize low-resolution data to high-res for mixed dataset training
340
+ """
341
+ self.scale_x = scale_x
342
+ self.scale_y = scale_y
343
+ self.nearest_interp = nearest_interp
344
+
345
+ def __call__(self, sample):
346
+ scale_x = self.scale_x
347
+ scale_y = self.scale_y
348
+
349
+ sample['left'] = cv2.resize(sample['left'], None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
350
+ sample['right'] = cv2.resize(sample['right'], None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
351
+
352
+ sample['disp'] = cv2.resize(
353
+ sample['disp'], None, fx=scale_x, fy=scale_y,
354
+ interpolation=cv2.INTER_LINEAR if not self.nearest_interp else cv2.INTER_NEAREST
355
+ ) * scale_x
356
+
357
+ return sample
358
+
359
+
360
+ class RandomGrayscale(object):
361
+ def __init__(self, p=0.2):
362
+ self.p = p
363
+
364
+ def __call__(self, sample):
365
+ if np.random.random() < self.p:
366
+ sample = ToPILImage()(sample)
367
+
368
+ # only supported in higher version pytorch
369
+ # default output channels is 1
370
+ sample['left'] = F.rgb_to_grayscale(sample['left'], num_output_channels=3)
371
+ sample['right'] = F.rgb_to_grayscale(sample['right'], num_output_channels=3)
372
+
373
+ sample = ToNumpyArray()(sample)
374
+
375
+ return sample
376
+
377
+
378
+ class RandomRotateShiftRight(object):
379
+ def __init__(self, p=0.5):
380
+ self.p = p
381
+
382
+ def __call__(self, sample):
383
+ if np.random.random() < self.p:
384
+ angle, pixel = 0.1, 2
385
+ px = np.random.uniform(-pixel, pixel)
386
+ ag = np.random.uniform(-angle, angle)
387
+
388
+ right_img = sample['right']
389
+
390
+ image_center = (
391
+ np.random.uniform(0, right_img.shape[0]),
392
+ np.random.uniform(0, right_img.shape[1])
393
+ )
394
+
395
+ rot_mat = cv2.getRotationMatrix2D(image_center, ag, 1.0)
396
+ right_img = cv2.warpAffine(
397
+ right_img, rot_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR
398
+ )
399
+ trans_mat = np.float32([[1, 0, 0], [0, 1, px]])
400
+ right_img = cv2.warpAffine(
401
+ right_img, trans_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR
402
+ )
403
+
404
+ sample['right'] = right_img
405
+
406
+ return sample
407
+
408
+
409
+ class RandomOcclusion(object):
410
+ def __init__(self, p=0.5,
411
+ occlusion_mask_zero=False):
412
+ self.p = p
413
+ self.occlusion_mask_zero = occlusion_mask_zero
414
+
415
+ def __call__(self, sample):
416
+ bounds = [50, 100]
417
+ if np.random.random() < self.p:
418
+ img2 = sample['right']
419
+ ht, wd = img2.shape[:2]
420
+
421
+ if self.occlusion_mask_zero:
422
+ mean_color = 0
423
+ else:
424
+ mean_color = np.mean(img2.reshape(-1, 3), axis=0)
425
+
426
+ x0 = np.random.randint(0, wd)
427
+ y0 = np.random.randint(0, ht)
428
+ dx = np.random.randint(bounds[0], bounds[1])
429
+ dy = np.random.randint(bounds[0], bounds[1])
430
+ img2[y0:y0 + dy, x0:x0 + dx, :] = mean_color
431
+
432
+ sample['right'] = img2
433
+
434
+ return sample
demo/flow_davis_skate-jump_00059.jpg ADDED
demo/flow_davis_skate-jump_00060.jpg ADDED
demo/flow_kitti_test_000197_10.png ADDED
demo/flow_kitti_test_000197_11.png ADDED
demo/flow_sintel_cave_3_frame_0049.png ADDED
demo/flow_sintel_cave_3_frame_0050.png ADDED
demo/stereo_drivingstereo_test_2018-07-11-14-48-52_2018-07-11-14-58-34-673_left.jpg ADDED
demo/stereo_drivingstereo_test_2018-07-11-14-48-52_2018-07-11-14-58-34-673_right.jpg ADDED
pretrained/tmp.txt ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ matplotlib
4
+ opencv-python
5
+ pillow
unimatch/__init__.py ADDED
File without changes
unimatch/attention.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .utils import split_feature, merge_splits, split_feature_1d, merge_splits_1d
6
+
7
+
8
+ def single_head_full_attention(q, k, v):
9
+ # q, k, v: [B, L, C]
10
+ assert q.dim() == k.dim() == v.dim() == 3
11
+
12
+ scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5) # [B, L, L]
13
+ attn = torch.softmax(scores, dim=2) # [B, L, L]
14
+ out = torch.matmul(attn, v) # [B, L, C]
15
+
16
+ return out
17
+
18
+
19
+ def single_head_full_attention_1d(q, k, v,
20
+ h=None,
21
+ w=None,
22
+ ):
23
+ # q, k, v: [B, L, C]
24
+
25
+ assert h is not None and w is not None
26
+ assert q.size(1) == h * w
27
+
28
+ b, _, c = q.size()
29
+
30
+ q = q.view(b, h, w, c) # [B, H, W, C]
31
+ k = k.view(b, h, w, c)
32
+ v = v.view(b, h, w, c)
33
+
34
+ scale_factor = c ** 0.5
35
+
36
+ scores = torch.matmul(q, k.permute(0, 1, 3, 2)) / scale_factor # [B, H, W, W]
37
+
38
+ attn = torch.softmax(scores, dim=-1)
39
+
40
+ out = torch.matmul(attn, v).view(b, -1, c) # [B, H*W, C]
41
+
42
+ return out
43
+
44
+
45
+ def single_head_split_window_attention(q, k, v,
46
+ num_splits=1,
47
+ with_shift=False,
48
+ h=None,
49
+ w=None,
50
+ attn_mask=None,
51
+ ):
52
+ # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
53
+ # q, k, v: [B, L, C]
54
+ assert q.dim() == k.dim() == v.dim() == 3
55
+
56
+ assert h is not None and w is not None
57
+ assert q.size(1) == h * w
58
+
59
+ b, _, c = q.size()
60
+
61
+ b_new = b * num_splits * num_splits
62
+
63
+ window_size_h = h // num_splits
64
+ window_size_w = w // num_splits
65
+
66
+ q = q.view(b, h, w, c) # [B, H, W, C]
67
+ k = k.view(b, h, w, c)
68
+ v = v.view(b, h, w, c)
69
+
70
+ scale_factor = c ** 0.5
71
+
72
+ if with_shift:
73
+ assert attn_mask is not None # compute once
74
+ shift_size_h = window_size_h // 2
75
+ shift_size_w = window_size_w // 2
76
+
77
+ q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
78
+ k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
79
+ v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2))
80
+
81
+ q = split_feature(q, num_splits=num_splits, channel_last=True) # [B*K*K, H/K, W/K, C]
82
+ k = split_feature(k, num_splits=num_splits, channel_last=True)
83
+ v = split_feature(v, num_splits=num_splits, channel_last=True)
84
+
85
+ scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
86
+ ) / scale_factor # [B*K*K, H/K*W/K, H/K*W/K]
87
+
88
+ if with_shift:
89
+ scores += attn_mask.repeat(b, 1, 1)
90
+
91
+ attn = torch.softmax(scores, dim=-1)
92
+
93
+ out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*K*K, H/K*W/K, C]
94
+
95
+ out = merge_splits(out.view(b_new, h // num_splits, w // num_splits, c),
96
+ num_splits=num_splits, channel_last=True) # [B, H, W, C]
97
+
98
+ # shift back
99
+ if with_shift:
100
+ out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))
101
+
102
+ out = out.view(b, -1, c)
103
+
104
+ return out
105
+
106
+
107
+ def single_head_split_window_attention_1d(q, k, v,
108
+ relative_position_bias=None,
109
+ num_splits=1,
110
+ with_shift=False,
111
+ h=None,
112
+ w=None,
113
+ attn_mask=None,
114
+ ):
115
+ # q, k, v: [B, L, C]
116
+
117
+ assert h is not None and w is not None
118
+ assert q.size(1) == h * w
119
+
120
+ b, _, c = q.size()
121
+
122
+ b_new = b * num_splits * h
123
+
124
+ window_size_w = w // num_splits
125
+
126
+ q = q.view(b * h, w, c) # [B*H, W, C]
127
+ k = k.view(b * h, w, c)
128
+ v = v.view(b * h, w, c)
129
+
130
+ scale_factor = c ** 0.5
131
+
132
+ if with_shift:
133
+ assert attn_mask is not None # compute once
134
+ shift_size_w = window_size_w // 2
135
+
136
+ q = torch.roll(q, shifts=-shift_size_w, dims=1)
137
+ k = torch.roll(k, shifts=-shift_size_w, dims=1)
138
+ v = torch.roll(v, shifts=-shift_size_w, dims=1)
139
+
140
+ q = split_feature_1d(q, num_splits=num_splits) # [B*H*K, W/K, C]
141
+ k = split_feature_1d(k, num_splits=num_splits)
142
+ v = split_feature_1d(v, num_splits=num_splits)
143
+
144
+ scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1)
145
+ ) / scale_factor # [B*H*K, W/K, W/K]
146
+
147
+ if with_shift:
148
+ # attn_mask: [K, W/K, W/K]
149
+ scores += attn_mask.repeat(b * h, 1, 1) # [B*H*K, W/K, W/K]
150
+
151
+ attn = torch.softmax(scores, dim=-1)
152
+
153
+ out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*H*K, W/K, C]
154
+
155
+ out = merge_splits_1d(out, h, num_splits=num_splits) # [B, H, W, C]
156
+
157
+ # shift back
158
+ if with_shift:
159
+ out = torch.roll(out, shifts=shift_size_w, dims=2)
160
+
161
+ out = out.view(b, -1, c)
162
+
163
+ return out
164
+
165
+
166
+ class SelfAttnPropagation(nn.Module):
167
+ """
168
+ flow propagation with self-attention on feature
169
+ query: feature0, key: feature0, value: flow
170
+ """
171
+
172
+ def __init__(self, in_channels,
173
+ **kwargs,
174
+ ):
175
+ super(SelfAttnPropagation, self).__init__()
176
+
177
+ self.q_proj = nn.Linear(in_channels, in_channels)
178
+ self.k_proj = nn.Linear(in_channels, in_channels)
179
+
180
+ for p in self.parameters():
181
+ if p.dim() > 1:
182
+ nn.init.xavier_uniform_(p)
183
+
184
+ def forward(self, feature0, flow,
185
+ local_window_attn=False,
186
+ local_window_radius=1,
187
+ **kwargs,
188
+ ):
189
+ # q, k: feature [B, C, H, W], v: flow [B, 2, H, W]
190
+ if local_window_attn:
191
+ return self.forward_local_window_attn(feature0, flow,
192
+ local_window_radius=local_window_radius)
193
+
194
+ b, c, h, w = feature0.size()
195
+
196
+ query = feature0.view(b, c, h * w).permute(0, 2, 1) # [B, H*W, C]
197
+
198
+ # a note: the ``correct'' implementation should be:
199
+ # ``query = self.q_proj(query), key = self.k_proj(query)''
200
+ # this problem is observed while cleaning up the code
201
+ # however, this doesn't affect the performance since the projection is a linear operation,
202
+ # thus the two projection matrices for key can be merged
203
+ # so I just leave it as is in order to not re-train all models :)
204
+ query = self.q_proj(query) # [B, H*W, C]
205
+ key = self.k_proj(query) # [B, H*W, C]
206
+
207
+ value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1) # [B, H*W, 2]
208
+
209
+ scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5) # [B, H*W, H*W]
210
+ prob = torch.softmax(scores, dim=-1)
211
+
212
+ out = torch.matmul(prob, value) # [B, H*W, 2]
213
+ out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2) # [B, 2, H, W]
214
+
215
+ return out
216
+
217
+ def forward_local_window_attn(self, feature0, flow,
218
+ local_window_radius=1,
219
+ ):
220
+ assert flow.size(1) == 2 or flow.size(1) == 1 # flow or disparity or depth
221
+ assert local_window_radius > 0
222
+
223
+ b, c, h, w = feature0.size()
224
+
225
+ value_channel = flow.size(1)
226
+
227
+ feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1)
228
+ ).reshape(b * h * w, 1, c) # [B*H*W, 1, C]
229
+
230
+ kernel_size = 2 * local_window_radius + 1
231
+
232
+ feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w)
233
+
234
+ feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size,
235
+ padding=local_window_radius) # [B, C*(2R+1)^2), H*W]
236
+
237
+ feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute(
238
+ 0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2) # [B*H*W, C, (2R+1)^2]
239
+
240
+ flow_window = F.unfold(flow, kernel_size=kernel_size,
241
+ padding=local_window_radius) # [B, 2*(2R+1)^2), H*W]
242
+
243
+ flow_window = flow_window.view(b, value_channel, kernel_size ** 2, h, w).permute(
244
+ 0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, value_channel) # [B*H*W, (2R+1)^2, 2]
245
+
246
+ scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5) # [B*H*W, 1, (2R+1)^2]
247
+
248
+ prob = torch.softmax(scores, dim=-1)
249
+
250
+ out = torch.matmul(prob, flow_window).view(b, h, w, value_channel
251
+ ).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W]
252
+
253
+ return out
unimatch/backbone.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ from .trident_conv import MultiScaleTridentConv
4
+
5
+
6
+ class ResidualBlock(nn.Module):
7
+ def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1,
8
+ ):
9
+ super(ResidualBlock, self).__init__()
10
+
11
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
12
+ dilation=dilation, padding=dilation, stride=stride, bias=False)
13
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
14
+ dilation=dilation, padding=dilation, bias=False)
15
+ self.relu = nn.ReLU(inplace=True)
16
+
17
+ self.norm1 = norm_layer(planes)
18
+ self.norm2 = norm_layer(planes)
19
+ if not stride == 1 or in_planes != planes:
20
+ self.norm3 = norm_layer(planes)
21
+
22
+ if stride == 1 and in_planes == planes:
23
+ self.downsample = None
24
+ else:
25
+ self.downsample = nn.Sequential(
26
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
27
+
28
+ def forward(self, x):
29
+ y = x
30
+ y = self.relu(self.norm1(self.conv1(y)))
31
+ y = self.relu(self.norm2(self.conv2(y)))
32
+
33
+ if self.downsample is not None:
34
+ x = self.downsample(x)
35
+
36
+ return self.relu(x + y)
37
+
38
+
39
+ class CNNEncoder(nn.Module):
40
+ def __init__(self, output_dim=128,
41
+ norm_layer=nn.InstanceNorm2d,
42
+ num_output_scales=1,
43
+ **kwargs,
44
+ ):
45
+ super(CNNEncoder, self).__init__()
46
+ self.num_branch = num_output_scales
47
+
48
+ feature_dims = [64, 96, 128]
49
+
50
+ self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2
51
+ self.norm1 = norm_layer(feature_dims[0])
52
+ self.relu1 = nn.ReLU(inplace=True)
53
+
54
+ self.in_planes = feature_dims[0]
55
+ self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2
56
+ self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4
57
+
58
+ # highest resolution 1/4 or 1/8
59
+ stride = 2 if num_output_scales == 1 else 1
60
+ self.layer3 = self._make_layer(feature_dims[2], stride=stride,
61
+ norm_layer=norm_layer,
62
+ ) # 1/4 or 1/8
63
+
64
+ self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0)
65
+
66
+ if self.num_branch > 1:
67
+ if self.num_branch == 4:
68
+ strides = (1, 2, 4, 8)
69
+ elif self.num_branch == 3:
70
+ strides = (1, 2, 4)
71
+ elif self.num_branch == 2:
72
+ strides = (1, 2)
73
+ else:
74
+ raise ValueError
75
+
76
+ self.trident_conv = MultiScaleTridentConv(output_dim, output_dim,
77
+ kernel_size=3,
78
+ strides=strides,
79
+ paddings=1,
80
+ num_branch=self.num_branch,
81
+ )
82
+
83
+ for m in self.modules():
84
+ if isinstance(m, nn.Conv2d):
85
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
86
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
87
+ if m.weight is not None:
88
+ nn.init.constant_(m.weight, 1)
89
+ if m.bias is not None:
90
+ nn.init.constant_(m.bias, 0)
91
+
92
+ def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d):
93
+ layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation)
94
+ layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation)
95
+
96
+ layers = (layer1, layer2)
97
+
98
+ self.in_planes = dim
99
+ return nn.Sequential(*layers)
100
+
101
+ def forward(self, x):
102
+ x = self.conv1(x)
103
+ x = self.norm1(x)
104
+ x = self.relu1(x)
105
+
106
+ x = self.layer1(x) # 1/2
107
+ x = self.layer2(x) # 1/4
108
+ x = self.layer3(x) # 1/8 or 1/4
109
+
110
+ x = self.conv2(x)
111
+
112
+ if self.num_branch > 1:
113
+ out = self.trident_conv([x] * self.num_branch) # high to low res
114
+ else:
115
+ out = [x]
116
+
117
+ return out
unimatch/geometry.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def coords_grid(b, h, w, homogeneous=False, device=None):
6
+ y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W]
7
+
8
+ stacks = [x, y]
9
+
10
+ if homogeneous:
11
+ ones = torch.ones_like(x) # [H, W]
12
+ stacks.append(ones)
13
+
14
+ grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W]
15
+
16
+ grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W]
17
+
18
+ if device is not None:
19
+ grid = grid.to(device)
20
+
21
+ return grid
22
+
23
+
24
+ def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
25
+ assert device is not None
26
+
27
+ x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device),
28
+ torch.linspace(h_min, h_max, len_h, device=device)],
29
+ )
30
+ grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2]
31
+
32
+ return grid
33
+
34
+
35
+ def normalize_coords(coords, h, w):
36
+ # coords: [B, H, W, 2]
37
+ c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device)
38
+ return (coords - c) / c # [-1, 1]
39
+
40
+
41
+ def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False):
42
+ # img: [B, C, H, W]
43
+ # sample_coords: [B, 2, H, W] in image scale
44
+ if sample_coords.size(1) != 2: # [B, H, W, 2]
45
+ sample_coords = sample_coords.permute(0, 3, 1, 2)
46
+
47
+ b, _, h, w = sample_coords.shape
48
+
49
+ # Normalize to [-1, 1]
50
+ x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1
51
+ y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1
52
+
53
+ grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2]
54
+
55
+ img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True)
56
+
57
+ if return_mask:
58
+ mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W]
59
+
60
+ return img, mask
61
+
62
+ return img
63
+
64
+
65
+ def flow_warp(feature, flow, mask=False, padding_mode='zeros'):
66
+ b, c, h, w = feature.size()
67
+ assert flow.size(1) == 2
68
+
69
+ grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W]
70
+
71
+ return bilinear_sample(feature, grid, padding_mode=padding_mode,
72
+ return_mask=mask)
73
+
74
+
75
+ def forward_backward_consistency_check(fwd_flow, bwd_flow,
76
+ alpha=0.01,
77
+ beta=0.5
78
+ ):
79
+ # fwd_flow, bwd_flow: [B, 2, H, W]
80
+ # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837)
81
+ assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
82
+ assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2
83
+ flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W]
84
+
85
+ warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W]
86
+ warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W]
87
+
88
+ diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W]
89
+ diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1)
90
+
91
+ threshold = alpha * flow_mag + beta
92
+
93
+ fwd_occ = (diff_fwd > threshold).float() # [B, H, W]
94
+ bwd_occ = (diff_bwd > threshold).float()
95
+
96
+ return fwd_occ, bwd_occ
97
+
98
+
99
+ def back_project(depth, intrinsics):
100
+ # Back project 2D pixel coords to 3D points
101
+ # depth: [B, H, W]
102
+ # intrinsics: [B, 3, 3]
103
+ b, h, w = depth.shape
104
+ grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W]
105
+
106
+ intrinsics_inv = torch.inverse(intrinsics) # [B, 3, 3]
107
+
108
+ points = intrinsics_inv.bmm(grid.view(b, 3, -1)).view(b, 3, h, w) * depth.unsqueeze(1) # [B, 3, H, W]
109
+
110
+ return points
111
+
112
+
113
+ def camera_transform(points_ref, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None):
114
+ # Transform 3D points from reference camera to target camera
115
+ # points_ref: [B, 3, H, W]
116
+ # extrinsics_ref: [B, 4, 4]
117
+ # extrinsics_tgt: [B, 4, 4]
118
+ # extrinsics_rel: [B, 4, 4], relative pose transform
119
+ b, _, h, w = points_ref.shape
120
+
121
+ if extrinsics_rel is None:
122
+ extrinsics_rel = torch.bmm(extrinsics_tgt, torch.inverse(extrinsics_ref)) # [B, 4, 4]
123
+
124
+ points_tgt = torch.bmm(extrinsics_rel[:, :3, :3],
125
+ points_ref.view(b, 3, -1)) + extrinsics_rel[:, :3, -1:] # [B, 3, H*W]
126
+
127
+ points_tgt = points_tgt.view(b, 3, h, w) # [B, 3, H, W]
128
+
129
+ return points_tgt
130
+
131
+
132
+ def reproject(points_tgt, intrinsics, return_mask=False):
133
+ # reproject to target view
134
+ # points_tgt: [B, 3, H, W]
135
+ # intrinsics: [B, 3, 3]
136
+
137
+ b, _, h, w = points_tgt.shape
138
+
139
+ proj_points = torch.bmm(intrinsics, points_tgt.view(b, 3, -1)).view(b, 3, h, w) # [B, 3, H, W]
140
+
141
+ X = proj_points[:, 0]
142
+ Y = proj_points[:, 1]
143
+ Z = proj_points[:, 2].clamp(min=1e-3)
144
+
145
+ pixel_coords = torch.stack([X / Z, Y / Z], dim=1).view(b, 2, h, w) # [B, 2, H, W] in image scale
146
+
147
+ if return_mask:
148
+ # valid mask in pixel space
149
+ mask = (pixel_coords[:, 0] >= 0) & (pixel_coords[:, 0] <= (w - 1)) & (
150
+ pixel_coords[:, 1] >= 0) & (pixel_coords[:, 1] <= (h - 1)) # [B, H, W]
151
+
152
+ return pixel_coords, mask
153
+
154
+ return pixel_coords
155
+
156
+
157
+ def reproject_coords(depth_ref, intrinsics, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None,
158
+ return_mask=False):
159
+ # Compute reprojection sample coords
160
+ points_ref = back_project(depth_ref, intrinsics) # [B, 3, H, W]
161
+ points_tgt = camera_transform(points_ref, extrinsics_ref, extrinsics_tgt, extrinsics_rel=extrinsics_rel)
162
+
163
+ if return_mask:
164
+ reproj_coords, mask = reproject(points_tgt, intrinsics,
165
+ return_mask=return_mask) # [B, 2, H, W] in image scale
166
+
167
+ return reproj_coords, mask
168
+
169
+ reproj_coords = reproject(points_tgt, intrinsics,
170
+ return_mask=return_mask) # [B, 2, H, W] in image scale
171
+
172
+ return reproj_coords
173
+
174
+
175
+ def compute_flow_with_depth_pose(depth_ref, intrinsics,
176
+ extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None,
177
+ return_mask=False):
178
+ b, h, w = depth_ref.shape
179
+ coords_init = coords_grid(b, h, w, device=depth_ref.device) # [B, 2, H, W]
180
+
181
+ if return_mask:
182
+ reproj_coords, mask = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt,
183
+ extrinsics_rel=extrinsics_rel,
184
+ return_mask=return_mask) # [B, 2, H, W]
185
+ rigid_flow = reproj_coords - coords_init
186
+
187
+ return rigid_flow, mask
188
+
189
+ reproj_coords = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt,
190
+ extrinsics_rel=extrinsics_rel,
191
+ return_mask=return_mask) # [B, 2, H, W]
192
+
193
+ rigid_flow = reproj_coords - coords_init
194
+
195
+ return rigid_flow
unimatch/matching.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from .geometry import coords_grid, generate_window_grid, normalize_coords
5
+
6
+
7
+ def global_correlation_softmax(feature0, feature1,
8
+ pred_bidir_flow=False,
9
+ ):
10
+ # global correlation
11
+ b, c, h, w = feature0.shape
12
+ feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C]
13
+ feature1 = feature1.view(b, c, -1) # [B, C, H*W]
14
+
15
+ correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (c ** 0.5) # [B, H, W, H, W]
16
+
17
+ # flow from softmax
18
+ init_grid = coords_grid(b, h, w).to(correlation.device) # [B, 2, H, W]
19
+ grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2]
20
+
21
+ correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W]
22
+
23
+ if pred_bidir_flow:
24
+ correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0) # [2*B, H*W, H*W]
25
+ init_grid = init_grid.repeat(2, 1, 1, 1) # [2*B, 2, H, W]
26
+ grid = grid.repeat(2, 1, 1) # [2*B, H*W, 2]
27
+ b = b * 2
28
+
29
+ prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W]
30
+
31
+ correspondence = torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W]
32
+
33
+ # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow
34
+ flow = correspondence - init_grid
35
+
36
+ return flow, prob
37
+
38
+
39
+ def local_correlation_softmax(feature0, feature1, local_radius,
40
+ padding_mode='zeros',
41
+ ):
42
+ b, c, h, w = feature0.size()
43
+ coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W]
44
+ coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2]
45
+
46
+ local_h = 2 * local_radius + 1
47
+ local_w = 2 * local_radius + 1
48
+
49
+ window_grid = generate_window_grid(-local_radius, local_radius,
50
+ -local_radius, local_radius,
51
+ local_h, local_w, device=feature0.device) # [2R+1, 2R+1, 2]
52
+ window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2]
53
+ sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1)^2, 2]
54
+
55
+ sample_coords_softmax = sample_coords
56
+
57
+ # exclude coords that are out of image space
58
+ valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2]
59
+ valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2]
60
+
61
+ valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax
62
+
63
+ # normalize coordinates to [-1, 1]
64
+ sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1]
65
+ window_feature = F.grid_sample(feature1, sample_coords_norm,
66
+ padding_mode=padding_mode, align_corners=True
67
+ ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2]
68
+ feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C]
69
+
70
+ corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2]
71
+
72
+ # mask invalid locations
73
+ corr[~valid] = -1e9
74
+
75
+ prob = F.softmax(corr, -1) # [B, H*W, (2R+1)^2]
76
+
77
+ correspondence = torch.matmul(prob.unsqueeze(-2), sample_coords_softmax).squeeze(-2).view(
78
+ b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W]
79
+
80
+ flow = correspondence - coords_init
81
+ match_prob = prob
82
+
83
+ return flow, match_prob
84
+
85
+
86
+ def local_correlation_with_flow(feature0, feature1,
87
+ flow,
88
+ local_radius,
89
+ padding_mode='zeros',
90
+ dilation=1,
91
+ ):
92
+ b, c, h, w = feature0.size()
93
+ coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W]
94
+ coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2]
95
+
96
+ local_h = 2 * local_radius + 1
97
+ local_w = 2 * local_radius + 1
98
+
99
+ window_grid = generate_window_grid(-local_radius, local_radius,
100
+ -local_radius, local_radius,
101
+ local_h, local_w, device=feature0.device) # [2R+1, 2R+1, 2]
102
+ window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2]
103
+ sample_coords = coords.unsqueeze(-2) + window_grid * dilation # [B, H*W, (2R+1)^2, 2]
104
+
105
+ # flow can be zero when using features after transformer
106
+ if not isinstance(flow, float):
107
+ sample_coords = sample_coords + flow.view(
108
+ b, 2, -1).permute(0, 2, 1).unsqueeze(-2) # [B, H*W, (2R+1)^2, 2]
109
+ else:
110
+ assert flow == 0.
111
+
112
+ # normalize coordinates to [-1, 1]
113
+ sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1]
114
+ window_feature = F.grid_sample(feature1, sample_coords_norm,
115
+ padding_mode=padding_mode, align_corners=True
116
+ ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2]
117
+ feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C]
118
+
119
+ corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2]
120
+
121
+ corr = corr.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous() # [B, (2R+1)^2, H, W]
122
+
123
+ return corr
124
+
125
+
126
+ def global_correlation_softmax_stereo(feature0, feature1,
127
+ ):
128
+ # global correlation on horizontal direction
129
+ b, c, h, w = feature0.shape
130
+
131
+ x_grid = torch.linspace(0, w - 1, w, device=feature0.device) # [W]
132
+
133
+ feature0 = feature0.permute(0, 2, 3, 1) # [B, H, W, C]
134
+ feature1 = feature1.permute(0, 2, 1, 3) # [B, H, C, W]
135
+
136
+ correlation = torch.matmul(feature0, feature1) / (c ** 0.5) # [B, H, W, W]
137
+
138
+ # mask subsequent positions to make disparity positive
139
+ mask = torch.triu(torch.ones((w, w)), diagonal=1).type_as(feature0) # [W, W]
140
+ valid_mask = (mask == 0).unsqueeze(0).unsqueeze(0).repeat(b, h, 1, 1) # [B, H, W, W]
141
+
142
+ correlation[~valid_mask] = -1e9
143
+
144
+ prob = F.softmax(correlation, dim=-1) # [B, H, W, W]
145
+
146
+ correspondence = (x_grid.view(1, 1, 1, w) * prob).sum(-1) # [B, H, W]
147
+
148
+ # NOTE: unlike flow, disparity is typically positive
149
+ disparity = x_grid.view(1, 1, w).repeat(b, h, 1) - correspondence # [B, H, W]
150
+
151
+ return disparity.unsqueeze(1), prob # feature resolution
152
+
153
+
154
+ def local_correlation_softmax_stereo(feature0, feature1, local_radius,
155
+ ):
156
+ b, c, h, w = feature0.size()
157
+ coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W]
158
+ coords = coords_init.view(b, 2, -1).permute(0, 2, 1).contiguous() # [B, H*W, 2]
159
+
160
+ local_h = 1
161
+ local_w = 2 * local_radius + 1
162
+
163
+ window_grid = generate_window_grid(0, 0,
164
+ -local_radius, local_radius,
165
+ local_h, local_w, device=feature0.device) # [1, 2R+1, 2]
166
+ window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1), 2]
167
+ sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1), 2]
168
+
169
+ sample_coords_softmax = sample_coords
170
+
171
+ # exclude coords that are out of image space
172
+ valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2]
173
+ valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2]
174
+
175
+ valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax
176
+
177
+ # normalize coordinates to [-1, 1]
178
+ sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1]
179
+ window_feature = F.grid_sample(feature1, sample_coords_norm,
180
+ padding_mode='zeros', align_corners=True
181
+ ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)]
182
+ feature0_view = feature0.permute(0, 2, 3, 1).contiguous().view(b, h * w, 1, c) # [B, H*W, 1, C]
183
+
184
+ corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)]
185
+
186
+ # mask invalid locations
187
+ corr[~valid] = -1e9
188
+
189
+ prob = F.softmax(corr, -1) # [B, H*W, (2R+1)]
190
+
191
+ correspondence = torch.matmul(prob.unsqueeze(-2),
192
+ sample_coords_softmax).squeeze(-2).view(
193
+ b, h, w, 2).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W]
194
+
195
+ flow = correspondence - coords_init # flow at feature resolution
196
+ match_prob = prob
197
+
198
+ flow_x = -flow[:, :1] # [B, 1, H, W]
199
+
200
+ return flow_x, match_prob
201
+
202
+
203
+ def correlation_softmax_depth(feature0, feature1,
204
+ intrinsics,
205
+ pose,
206
+ depth_candidates,
207
+ depth_from_argmax=False,
208
+ pred_bidir_depth=False,
209
+ ):
210
+ b, c, h, w = feature0.size()
211
+ assert depth_candidates.dim() == 4 # [B, D, H, W]
212
+ scale_factor = c ** 0.5
213
+
214
+ if pred_bidir_depth:
215
+ feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0)
216
+ intrinsics = intrinsics.repeat(2, 1, 1)
217
+ pose = torch.cat((pose, torch.inverse(pose)), dim=0)
218
+ depth_candidates = depth_candidates.repeat(2, 1, 1, 1)
219
+
220
+ # depth candidates are actually inverse depth
221
+ warped_feature1 = warp_with_pose_depth_candidates(feature1, intrinsics, pose,
222
+ 1. / depth_candidates,
223
+ ) # [B, C, D, H, W]
224
+
225
+ correlation = (feature0.unsqueeze(2) * warped_feature1).sum(1) / scale_factor # [B, D, H, W]
226
+
227
+ match_prob = F.softmax(correlation, dim=1) # [B, D, H, W]
228
+
229
+ # for cross-task transfer (flow -> depth), extract depth with argmax at test time
230
+ if depth_from_argmax:
231
+ index = torch.argmax(match_prob, dim=1, keepdim=True)
232
+ depth = torch.gather(depth_candidates, dim=1, index=index)
233
+ else:
234
+ depth = (match_prob * depth_candidates).sum(dim=1, keepdim=True) # [B, 1, H, W]
235
+
236
+ return depth, match_prob
237
+
238
+
239
+ def warp_with_pose_depth_candidates(feature1, intrinsics, pose, depth,
240
+ clamp_min_depth=1e-3,
241
+ ):
242
+ """
243
+ feature1: [B, C, H, W]
244
+ intrinsics: [B, 3, 3]
245
+ pose: [B, 4, 4]
246
+ depth: [B, D, H, W]
247
+ """
248
+
249
+ assert intrinsics.size(1) == intrinsics.size(2) == 3
250
+ assert pose.size(1) == pose.size(2) == 4
251
+ assert depth.dim() == 4
252
+
253
+ b, d, h, w = depth.size()
254
+ c = feature1.size(1)
255
+
256
+ with torch.no_grad():
257
+ # pixel coordinates
258
+ grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W]
259
+ # back project to 3D and transform viewpoint
260
+ points = torch.inverse(intrinsics).bmm(grid.view(b, 3, -1)) # [B, 3, H*W]
261
+ points = torch.bmm(pose[:, :3, :3], points).unsqueeze(2).repeat(
262
+ 1, 1, d, 1) * depth.view(b, 1, d, h * w) # [B, 3, D, H*W]
263
+ points = points + pose[:, :3, -1:].unsqueeze(-1) # [B, 3, D, H*W]
264
+ # reproject to 2D image plane
265
+ points = torch.bmm(intrinsics, points.view(b, 3, -1)).view(b, 3, d, h * w) # [B, 3, D, H*W]
266
+ pixel_coords = points[:, :2] / points[:, -1:].clamp(min=clamp_min_depth) # [B, 2, D, H*W]
267
+
268
+ # normalize to [-1, 1]
269
+ x_grid = 2 * pixel_coords[:, 0] / (w - 1) - 1
270
+ y_grid = 2 * pixel_coords[:, 1] / (h - 1) - 1
271
+
272
+ grid = torch.stack([x_grid, y_grid], dim=-1) # [B, D, H*W, 2]
273
+
274
+ # sample features
275
+ warped_feature = F.grid_sample(feature1, grid.view(b, d * h, w, 2), mode='bilinear',
276
+ padding_mode='zeros',
277
+ align_corners=True).view(b, c, d, h, w) # [B, C, D, H, W]
278
+
279
+ return warped_feature
unimatch/position.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ # https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import math
7
+
8
+
9
+ class PositionEmbeddingSine(nn.Module):
10
+ """
11
+ This is a more standard version of the position embedding, very similar to the one
12
+ used by the Attention is all you need paper, generalized to work on images.
13
+ """
14
+
15
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
16
+ super().__init__()
17
+ self.num_pos_feats = num_pos_feats
18
+ self.temperature = temperature
19
+ self.normalize = normalize
20
+ if scale is not None and normalize is False:
21
+ raise ValueError("normalize should be True if scale is passed")
22
+ if scale is None:
23
+ scale = 2 * math.pi
24
+ self.scale = scale
25
+
26
+ def forward(self, x):
27
+ # x = tensor_list.tensors # [B, C, H, W]
28
+ # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0
29
+ b, c, h, w = x.size()
30
+ mask = torch.ones((b, h, w), device=x.device) # [B, H, W]
31
+ y_embed = mask.cumsum(1, dtype=torch.float32)
32
+ x_embed = mask.cumsum(2, dtype=torch.float32)
33
+ if self.normalize:
34
+ eps = 1e-6
35
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
36
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
37
+
38
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
39
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
40
+
41
+ pos_x = x_embed[:, :, :, None] / dim_t
42
+ pos_y = y_embed[:, :, :, None] / dim_t
43
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
44
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
45
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
46
+ return pos
unimatch/reg_refine.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class FlowHead(nn.Module):
7
+ def __init__(self, input_dim=128, hidden_dim=256,
8
+ out_dim=2,
9
+ ):
10
+ super(FlowHead, self).__init__()
11
+
12
+ self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
13
+ self.conv2 = nn.Conv2d(hidden_dim, out_dim, 3, padding=1)
14
+ self.relu = nn.ReLU(inplace=True)
15
+
16
+ def forward(self, x):
17
+ out = self.conv2(self.relu(self.conv1(x)))
18
+
19
+ return out
20
+
21
+
22
+ class SepConvGRU(nn.Module):
23
+ def __init__(self, hidden_dim=128, input_dim=192 + 128,
24
+ kernel_size=5,
25
+ ):
26
+ padding = (kernel_size - 1) // 2
27
+
28
+ super(SepConvGRU, self).__init__()
29
+ self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding))
30
+ self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding))
31
+ self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding))
32
+
33
+ self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0))
34
+ self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0))
35
+ self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0))
36
+
37
+ def forward(self, h, x):
38
+ # horizontal
39
+ hx = torch.cat([h, x], dim=1)
40
+ z = torch.sigmoid(self.convz1(hx))
41
+ r = torch.sigmoid(self.convr1(hx))
42
+ q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
43
+ h = (1 - z) * h + z * q
44
+
45
+ # vertical
46
+ hx = torch.cat([h, x], dim=1)
47
+ z = torch.sigmoid(self.convz2(hx))
48
+ r = torch.sigmoid(self.convr2(hx))
49
+ q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
50
+ h = (1 - z) * h + z * q
51
+
52
+ return h
53
+
54
+
55
+ class BasicMotionEncoder(nn.Module):
56
+ def __init__(self, corr_channels=324,
57
+ flow_channels=2,
58
+ ):
59
+ super(BasicMotionEncoder, self).__init__()
60
+
61
+ self.convc1 = nn.Conv2d(corr_channels, 256, 1, padding=0)
62
+ self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
63
+ self.convf1 = nn.Conv2d(flow_channels, 128, 7, padding=3)
64
+ self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
65
+ self.conv = nn.Conv2d(64 + 192, 128 - flow_channels, 3, padding=1)
66
+
67
+ def forward(self, flow, corr):
68
+ cor = F.relu(self.convc1(corr))
69
+ cor = F.relu(self.convc2(cor))
70
+ flo = F.relu(self.convf1(flow))
71
+ flo = F.relu(self.convf2(flo))
72
+
73
+ cor_flo = torch.cat([cor, flo], dim=1)
74
+ out = F.relu(self.conv(cor_flo))
75
+ return torch.cat([out, flow], dim=1)
76
+
77
+
78
+ class BasicUpdateBlock(nn.Module):
79
+ def __init__(self, corr_channels=324,
80
+ hidden_dim=128,
81
+ context_dim=128,
82
+ downsample_factor=8,
83
+ flow_dim=2,
84
+ bilinear_up=False,
85
+ ):
86
+ super(BasicUpdateBlock, self).__init__()
87
+
88
+ self.encoder = BasicMotionEncoder(corr_channels=corr_channels,
89
+ flow_channels=flow_dim,
90
+ )
91
+
92
+ self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=context_dim + hidden_dim)
93
+
94
+ self.flow_head = FlowHead(hidden_dim, hidden_dim=256,
95
+ out_dim=flow_dim,
96
+ )
97
+
98
+ if bilinear_up:
99
+ self.mask = None
100
+ else:
101
+ self.mask = nn.Sequential(
102
+ nn.Conv2d(hidden_dim, 256, 3, padding=1),
103
+ nn.ReLU(inplace=True),
104
+ nn.Conv2d(256, downsample_factor ** 2 * 9, 1, padding=0))
105
+
106
+ def forward(self, net, inp, corr, flow):
107
+ motion_features = self.encoder(flow, corr)
108
+
109
+ inp = torch.cat([inp, motion_features], dim=1)
110
+
111
+ net = self.gru(net, inp)
112
+ delta_flow = self.flow_head(net)
113
+
114
+ if self.mask is not None:
115
+ mask = self.mask(net)
116
+ else:
117
+ mask = None
118
+
119
+ return net, mask, delta_flow
unimatch/transformer.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .attention import (single_head_full_attention, single_head_split_window_attention,
5
+ single_head_full_attention_1d, single_head_split_window_attention_1d)
6
+ from .utils import generate_shift_window_attn_mask, generate_shift_window_attn_mask_1d
7
+
8
+
9
+ class TransformerLayer(nn.Module):
10
+ def __init__(self,
11
+ d_model=128,
12
+ nhead=1,
13
+ no_ffn=False,
14
+ ffn_dim_expansion=4,
15
+ ):
16
+ super(TransformerLayer, self).__init__()
17
+
18
+ self.dim = d_model
19
+ self.nhead = nhead
20
+ self.no_ffn = no_ffn
21
+
22
+ # multi-head attention
23
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
24
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
25
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
26
+
27
+ self.merge = nn.Linear(d_model, d_model, bias=False)
28
+
29
+ self.norm1 = nn.LayerNorm(d_model)
30
+
31
+ # no ffn after self-attn, with ffn after cross-attn
32
+ if not self.no_ffn:
33
+ in_channels = d_model * 2
34
+ self.mlp = nn.Sequential(
35
+ nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
36
+ nn.GELU(),
37
+ nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
38
+ )
39
+
40
+ self.norm2 = nn.LayerNorm(d_model)
41
+
42
+ def forward(self, source, target,
43
+ height=None,
44
+ width=None,
45
+ shifted_window_attn_mask=None,
46
+ shifted_window_attn_mask_1d=None,
47
+ attn_type='swin',
48
+ with_shift=False,
49
+ attn_num_splits=None,
50
+ ):
51
+ # source, target: [B, L, C]
52
+ query, key, value = source, target, target
53
+
54
+ # for stereo: 2d attn in self-attn, 1d attn in cross-attn
55
+ is_self_attn = (query - key).abs().max() < 1e-6
56
+
57
+ # single-head attention
58
+ query = self.q_proj(query) # [B, L, C]
59
+ key = self.k_proj(key) # [B, L, C]
60
+ value = self.v_proj(value) # [B, L, C]
61
+
62
+ if attn_type == 'swin' and attn_num_splits > 1: # self, cross-attn: both swin 2d
63
+ if self.nhead > 1:
64
+ # we observe that multihead attention slows down the speed and increases the memory consumption
65
+ # without bringing obvious performance gains and thus the implementation is removed
66
+ raise NotImplementedError
67
+ else:
68
+ message = single_head_split_window_attention(query, key, value,
69
+ num_splits=attn_num_splits,
70
+ with_shift=with_shift,
71
+ h=height,
72
+ w=width,
73
+ attn_mask=shifted_window_attn_mask,
74
+ )
75
+
76
+ elif attn_type == 'self_swin2d_cross_1d': # self-attn: swin 2d, cross-attn: full 1d
77
+ if self.nhead > 1:
78
+ raise NotImplementedError
79
+ else:
80
+ if is_self_attn:
81
+ if attn_num_splits > 1:
82
+ message = single_head_split_window_attention(query, key, value,
83
+ num_splits=attn_num_splits,
84
+ with_shift=with_shift,
85
+ h=height,
86
+ w=width,
87
+ attn_mask=shifted_window_attn_mask,
88
+ )
89
+ else:
90
+ # full 2d attn
91
+ message = single_head_full_attention(query, key, value) # [N, L, C]
92
+
93
+ else:
94
+ # cross attn 1d
95
+ message = single_head_full_attention_1d(query, key, value,
96
+ h=height,
97
+ w=width,
98
+ )
99
+
100
+ elif attn_type == 'self_swin2d_cross_swin1d': # self-attn: swin 2d, cross-attn: swin 1d
101
+ if self.nhead > 1:
102
+ raise NotImplementedError
103
+ else:
104
+ if is_self_attn:
105
+ if attn_num_splits > 1:
106
+ # self attn shift window
107
+ message = single_head_split_window_attention(query, key, value,
108
+ num_splits=attn_num_splits,
109
+ with_shift=with_shift,
110
+ h=height,
111
+ w=width,
112
+ attn_mask=shifted_window_attn_mask,
113
+ )
114
+ else:
115
+ # full 2d attn
116
+ message = single_head_full_attention(query, key, value) # [N, L, C]
117
+ else:
118
+ if attn_num_splits > 1:
119
+ assert shifted_window_attn_mask_1d is not None
120
+ # cross attn 1d shift
121
+ message = single_head_split_window_attention_1d(query, key, value,
122
+ num_splits=attn_num_splits,
123
+ with_shift=with_shift,
124
+ h=height,
125
+ w=width,
126
+ attn_mask=shifted_window_attn_mask_1d,
127
+ )
128
+ else:
129
+ message = single_head_full_attention_1d(query, key, value,
130
+ h=height,
131
+ w=width,
132
+ )
133
+
134
+ else:
135
+ message = single_head_full_attention(query, key, value) # [B, L, C]
136
+
137
+ message = self.merge(message) # [B, L, C]
138
+ message = self.norm1(message)
139
+
140
+ if not self.no_ffn:
141
+ message = self.mlp(torch.cat([source, message], dim=-1))
142
+ message = self.norm2(message)
143
+
144
+ return source + message
145
+
146
+
147
+ class TransformerBlock(nn.Module):
148
+ """self attention + cross attention + FFN"""
149
+
150
+ def __init__(self,
151
+ d_model=128,
152
+ nhead=1,
153
+ ffn_dim_expansion=4,
154
+ ):
155
+ super(TransformerBlock, self).__init__()
156
+
157
+ self.self_attn = TransformerLayer(d_model=d_model,
158
+ nhead=nhead,
159
+ no_ffn=True,
160
+ ffn_dim_expansion=ffn_dim_expansion,
161
+ )
162
+
163
+ self.cross_attn_ffn = TransformerLayer(d_model=d_model,
164
+ nhead=nhead,
165
+ ffn_dim_expansion=ffn_dim_expansion,
166
+ )
167
+
168
+ def forward(self, source, target,
169
+ height=None,
170
+ width=None,
171
+ shifted_window_attn_mask=None,
172
+ shifted_window_attn_mask_1d=None,
173
+ attn_type='swin',
174
+ with_shift=False,
175
+ attn_num_splits=None,
176
+ ):
177
+ # source, target: [B, L, C]
178
+
179
+ # self attention
180
+ source = self.self_attn(source, source,
181
+ height=height,
182
+ width=width,
183
+ shifted_window_attn_mask=shifted_window_attn_mask,
184
+ attn_type=attn_type,
185
+ with_shift=with_shift,
186
+ attn_num_splits=attn_num_splits,
187
+ )
188
+
189
+ # cross attention and ffn
190
+ source = self.cross_attn_ffn(source, target,
191
+ height=height,
192
+ width=width,
193
+ shifted_window_attn_mask=shifted_window_attn_mask,
194
+ shifted_window_attn_mask_1d=shifted_window_attn_mask_1d,
195
+ attn_type=attn_type,
196
+ with_shift=with_shift,
197
+ attn_num_splits=attn_num_splits,
198
+ )
199
+
200
+ return source
201
+
202
+
203
+ class FeatureTransformer(nn.Module):
204
+ def __init__(self,
205
+ num_layers=6,
206
+ d_model=128,
207
+ nhead=1,
208
+ ffn_dim_expansion=4,
209
+ ):
210
+ super(FeatureTransformer, self).__init__()
211
+
212
+ self.d_model = d_model
213
+ self.nhead = nhead
214
+
215
+ self.layers = nn.ModuleList([
216
+ TransformerBlock(d_model=d_model,
217
+ nhead=nhead,
218
+ ffn_dim_expansion=ffn_dim_expansion,
219
+ )
220
+ for i in range(num_layers)])
221
+
222
+ for p in self.parameters():
223
+ if p.dim() > 1:
224
+ nn.init.xavier_uniform_(p)
225
+
226
+ def forward(self, feature0, feature1,
227
+ attn_type='swin',
228
+ attn_num_splits=None,
229
+ **kwargs,
230
+ ):
231
+
232
+ b, c, h, w = feature0.shape
233
+ assert self.d_model == c
234
+
235
+ feature0 = feature0.flatten(-2).permute(0, 2, 1) # [B, H*W, C]
236
+ feature1 = feature1.flatten(-2).permute(0, 2, 1) # [B, H*W, C]
237
+
238
+ # 2d attention
239
+ if 'swin' in attn_type and attn_num_splits > 1:
240
+ # global and refine use different number of splits
241
+ window_size_h = h // attn_num_splits
242
+ window_size_w = w // attn_num_splits
243
+
244
+ # compute attn mask once
245
+ shifted_window_attn_mask = generate_shift_window_attn_mask(
246
+ input_resolution=(h, w),
247
+ window_size_h=window_size_h,
248
+ window_size_w=window_size_w,
249
+ shift_size_h=window_size_h // 2,
250
+ shift_size_w=window_size_w // 2,
251
+ device=feature0.device,
252
+ ) # [K*K, H/K*W/K, H/K*W/K]
253
+ else:
254
+ shifted_window_attn_mask = None
255
+
256
+ # 1d attention
257
+ if 'swin1d' in attn_type and attn_num_splits > 1:
258
+ window_size_w = w // attn_num_splits
259
+
260
+ # compute attn mask once
261
+ shifted_window_attn_mask_1d = generate_shift_window_attn_mask_1d(
262
+ input_w=w,
263
+ window_size_w=window_size_w,
264
+ shift_size_w=window_size_w // 2,
265
+ device=feature0.device,
266
+ ) # [K, W/K, W/K]
267
+ else:
268
+ shifted_window_attn_mask_1d = None
269
+
270
+ # concat feature0 and feature1 in batch dimension to compute in parallel
271
+ concat0 = torch.cat((feature0, feature1), dim=0) # [2B, H*W, C]
272
+ concat1 = torch.cat((feature1, feature0), dim=0) # [2B, H*W, C]
273
+
274
+ for i, layer in enumerate(self.layers):
275
+ concat0 = layer(concat0, concat1,
276
+ height=h,
277
+ width=w,
278
+ attn_type=attn_type,
279
+ with_shift='swin' in attn_type and attn_num_splits > 1 and i % 2 == 1,
280
+ attn_num_splits=attn_num_splits,
281
+ shifted_window_attn_mask=shifted_window_attn_mask,
282
+ shifted_window_attn_mask_1d=shifted_window_attn_mask_1d,
283
+ )
284
+
285
+ # update feature1
286
+ concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0)
287
+
288
+ feature0, feature1 = concat0.chunk(chunks=2, dim=0) # [B, H*W, C]
289
+
290
+ # reshape back
291
+ feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W]
292
+ feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W]
293
+
294
+ return feature0, feature1
unimatch/trident_conv.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ from torch.nn.modules.utils import _pair
8
+
9
+
10
+ class MultiScaleTridentConv(nn.Module):
11
+ def __init__(
12
+ self,
13
+ in_channels,
14
+ out_channels,
15
+ kernel_size,
16
+ stride=1,
17
+ strides=1,
18
+ paddings=0,
19
+ dilations=1,
20
+ dilation=1,
21
+ groups=1,
22
+ num_branch=1,
23
+ test_branch_idx=-1,
24
+ bias=False,
25
+ norm=None,
26
+ activation=None,
27
+ ):
28
+ super(MultiScaleTridentConv, self).__init__()
29
+ self.in_channels = in_channels
30
+ self.out_channels = out_channels
31
+ self.kernel_size = _pair(kernel_size)
32
+ self.num_branch = num_branch
33
+ self.stride = _pair(stride)
34
+ self.groups = groups
35
+ self.with_bias = bias
36
+ self.dilation = dilation
37
+ if isinstance(paddings, int):
38
+ paddings = [paddings] * self.num_branch
39
+ if isinstance(dilations, int):
40
+ dilations = [dilations] * self.num_branch
41
+ if isinstance(strides, int):
42
+ strides = [strides] * self.num_branch
43
+ self.paddings = [_pair(padding) for padding in paddings]
44
+ self.dilations = [_pair(dilation) for dilation in dilations]
45
+ self.strides = [_pair(stride) for stride in strides]
46
+ self.test_branch_idx = test_branch_idx
47
+ self.norm = norm
48
+ self.activation = activation
49
+
50
+ assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1
51
+
52
+ self.weight = nn.Parameter(
53
+ torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
54
+ )
55
+ if bias:
56
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
57
+ else:
58
+ self.bias = None
59
+
60
+ nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
61
+ if self.bias is not None:
62
+ nn.init.constant_(self.bias, 0)
63
+
64
+ def forward(self, inputs):
65
+ num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1
66
+ assert len(inputs) == num_branch
67
+
68
+ if self.training or self.test_branch_idx == -1:
69
+ outputs = [
70
+ F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups)
71
+ for input, stride, padding in zip(inputs, self.strides, self.paddings)
72
+ ]
73
+ else:
74
+ outputs = [
75
+ F.conv2d(
76
+ inputs[0],
77
+ self.weight,
78
+ self.bias,
79
+ self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1],
80
+ self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1],
81
+ self.dilation,
82
+ self.groups,
83
+ )
84
+ ]
85
+
86
+ if self.norm is not None:
87
+ outputs = [self.norm(x) for x in outputs]
88
+ if self.activation is not None:
89
+ outputs = [self.activation(x) for x in outputs]
90
+ return outputs
unimatch/unimatch.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .backbone import CNNEncoder
6
+ from .transformer import FeatureTransformer
7
+ from .matching import (global_correlation_softmax, local_correlation_softmax, local_correlation_with_flow,
8
+ global_correlation_softmax_stereo, local_correlation_softmax_stereo,
9
+ correlation_softmax_depth)
10
+ from .attention import SelfAttnPropagation
11
+ from .geometry import flow_warp, compute_flow_with_depth_pose
12
+ from .reg_refine import BasicUpdateBlock
13
+ from .utils import normalize_img, feature_add_position, upsample_flow_with_mask
14
+
15
+
16
+ class UniMatch(nn.Module):
17
+ def __init__(self,
18
+ num_scales=1,
19
+ feature_channels=128,
20
+ upsample_factor=8,
21
+ num_head=1,
22
+ ffn_dim_expansion=4,
23
+ num_transformer_layers=6,
24
+ reg_refine=False, # optional local regression refinement
25
+ task='flow',
26
+ ):
27
+ super(UniMatch, self).__init__()
28
+
29
+ self.feature_channels = feature_channels
30
+ self.num_scales = num_scales
31
+ self.upsample_factor = upsample_factor
32
+ self.reg_refine = reg_refine
33
+
34
+ # CNN
35
+ self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales)
36
+
37
+ # Transformer
38
+ self.transformer = FeatureTransformer(num_layers=num_transformer_layers,
39
+ d_model=feature_channels,
40
+ nhead=num_head,
41
+ ffn_dim_expansion=ffn_dim_expansion,
42
+ )
43
+
44
+ # propagation with self-attn
45
+ self.feature_flow_attn = SelfAttnPropagation(in_channels=feature_channels)
46
+
47
+ if not self.reg_refine or task == 'depth':
48
+ # convex upsampling simiar to RAFT
49
+ # concat feature0 and low res flow as input
50
+ self.upsampler = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
51
+ nn.ReLU(inplace=True),
52
+ nn.Conv2d(256, upsample_factor ** 2 * 9, 1, 1, 0))
53
+ # thus far, all the learnable parameters are task-agnostic
54
+
55
+ if reg_refine:
56
+ # optional task-specific local regression refinement
57
+ self.refine_proj = nn.Conv2d(128, 256, 1)
58
+ self.refine = BasicUpdateBlock(corr_channels=(2 * 4 + 1) ** 2,
59
+ downsample_factor=upsample_factor,
60
+ flow_dim=2 if task == 'flow' else 1,
61
+ bilinear_up=task == 'depth',
62
+ )
63
+
64
+ def extract_feature(self, img0, img1):
65
+ concat = torch.cat((img0, img1), dim=0) # [2B, C, H, W]
66
+ features = self.backbone(concat) # list of [2B, C, H, W], resolution from high to low
67
+
68
+ # reverse: resolution from low to high
69
+ features = features[::-1]
70
+
71
+ feature0, feature1 = [], []
72
+
73
+ for i in range(len(features)):
74
+ feature = features[i]
75
+ chunks = torch.chunk(feature, 2, 0) # tuple
76
+ feature0.append(chunks[0])
77
+ feature1.append(chunks[1])
78
+
79
+ return feature0, feature1
80
+
81
+ def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8,
82
+ is_depth=False):
83
+ if bilinear:
84
+ multiplier = 1 if is_depth else upsample_factor
85
+ up_flow = F.interpolate(flow, scale_factor=upsample_factor,
86
+ mode='bilinear', align_corners=True) * multiplier
87
+ else:
88
+ concat = torch.cat((flow, feature), dim=1)
89
+ mask = self.upsampler(concat)
90
+ up_flow = upsample_flow_with_mask(flow, mask, upsample_factor=self.upsample_factor,
91
+ is_depth=is_depth)
92
+
93
+ return up_flow
94
+
95
+ def forward(self, img0, img1,
96
+ attn_type=None,
97
+ attn_splits_list=None,
98
+ corr_radius_list=None,
99
+ prop_radius_list=None,
100
+ num_reg_refine=1,
101
+ pred_bidir_flow=False,
102
+ task='flow',
103
+ intrinsics=None,
104
+ pose=None, # relative pose transform
105
+ min_depth=1. / 0.5, # inverse depth range
106
+ max_depth=1. / 10,
107
+ num_depth_candidates=64,
108
+ depth_from_argmax=False,
109
+ pred_bidir_depth=False,
110
+ **kwargs,
111
+ ):
112
+
113
+ if pred_bidir_flow:
114
+ assert task == 'flow'
115
+
116
+ if task == 'depth':
117
+ assert self.num_scales == 1 # multi-scale depth model is not supported yet
118
+
119
+ results_dict = {}
120
+ flow_preds = []
121
+
122
+ if task == 'flow':
123
+ # stereo and depth tasks have normalized img in dataloader
124
+ img0, img1 = normalize_img(img0, img1) # [B, 3, H, W]
125
+
126
+ # list of features, resolution low to high
127
+ feature0_list, feature1_list = self.extract_feature(img0, img1) # list of features
128
+
129
+ flow = None
130
+
131
+ if task != 'depth':
132
+ assert len(attn_splits_list) == len(corr_radius_list) == len(prop_radius_list) == self.num_scales
133
+ else:
134
+ assert len(attn_splits_list) == len(prop_radius_list) == self.num_scales == 1
135
+
136
+ for scale_idx in range(self.num_scales):
137
+ feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]
138
+
139
+ if pred_bidir_flow and scale_idx > 0:
140
+ # predicting bidirectional flow with refinement
141
+ feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0)
142
+
143
+ feature0_ori, feature1_ori = feature0, feature1
144
+
145
+ upsample_factor = self.upsample_factor * (2 ** (self.num_scales - 1 - scale_idx))
146
+
147
+ if task == 'depth':
148
+ # scale intrinsics
149
+ intrinsics_curr = intrinsics.clone()
150
+ intrinsics_curr[:, :2] = intrinsics_curr[:, :2] / upsample_factor
151
+
152
+ if scale_idx > 0:
153
+ assert task != 'depth' # not supported for multi-scale depth model
154
+ flow = F.interpolate(flow, scale_factor=2, mode='bilinear', align_corners=True) * 2
155
+
156
+ if flow is not None:
157
+ assert task != 'depth'
158
+ flow = flow.detach()
159
+
160
+ if task == 'stereo':
161
+ # construct flow vector for disparity
162
+ # flow here is actually disparity
163
+ zeros = torch.zeros_like(flow) # [B, 1, H, W]
164
+ # NOTE: reverse disp, disparity is positive
165
+ displace = torch.cat((-flow, zeros), dim=1) # [B, 2, H, W]
166
+ feature1 = flow_warp(feature1, displace) # [B, C, H, W]
167
+ elif task == 'flow':
168
+ feature1 = flow_warp(feature1, flow) # [B, C, H, W]
169
+ else:
170
+ raise NotImplementedError
171
+
172
+ attn_splits = attn_splits_list[scale_idx]
173
+ if task != 'depth':
174
+ corr_radius = corr_radius_list[scale_idx]
175
+ prop_radius = prop_radius_list[scale_idx]
176
+
177
+ # add position to features
178
+ feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels)
179
+
180
+ # Transformer
181
+ feature0, feature1 = self.transformer(feature0, feature1,
182
+ attn_type=attn_type,
183
+ attn_num_splits=attn_splits,
184
+ )
185
+
186
+ # correlation and softmax
187
+ if task == 'depth':
188
+ # first generate depth candidates
189
+ b, _, h, w = feature0.size()
190
+ depth_candidates = torch.linspace(min_depth, max_depth, num_depth_candidates).type_as(feature0)
191
+ depth_candidates = depth_candidates.view(1, num_depth_candidates, 1, 1).repeat(b, 1, h,
192
+ w) # [B, D, H, W]
193
+
194
+ flow_pred = correlation_softmax_depth(feature0, feature1,
195
+ intrinsics_curr,
196
+ pose,
197
+ depth_candidates=depth_candidates,
198
+ depth_from_argmax=depth_from_argmax,
199
+ pred_bidir_depth=pred_bidir_depth,
200
+ )[0]
201
+
202
+ else:
203
+ if corr_radius == -1: # global matching
204
+ if task == 'flow':
205
+ flow_pred = global_correlation_softmax(feature0, feature1, pred_bidir_flow)[0]
206
+ elif task == 'stereo':
207
+ flow_pred = global_correlation_softmax_stereo(feature0, feature1)[0]
208
+ else:
209
+ raise NotImplementedError
210
+ else: # local matching
211
+ if task == 'flow':
212
+ flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[0]
213
+ elif task == 'stereo':
214
+ flow_pred = local_correlation_softmax_stereo(feature0, feature1, corr_radius)[0]
215
+ else:
216
+ raise NotImplementedError
217
+
218
+ # flow or residual flow
219
+ flow = flow + flow_pred if flow is not None else flow_pred
220
+
221
+ if task == 'stereo':
222
+ flow = flow.clamp(min=0) # positive disparity
223
+
224
+ # upsample to the original resolution for supervison at training time only
225
+ if self.training:
226
+ flow_bilinear = self.upsample_flow(flow, None, bilinear=True, upsample_factor=upsample_factor,
227
+ is_depth=task == 'depth')
228
+ flow_preds.append(flow_bilinear)
229
+
230
+ # flow propagation with self-attn
231
+ if (pred_bidir_flow or pred_bidir_depth) and scale_idx == 0:
232
+ feature0 = torch.cat((feature0, feature1), dim=0) # [2*B, C, H, W] for propagation
233
+
234
+ flow = self.feature_flow_attn(feature0, flow.detach(),
235
+ local_window_attn=prop_radius > 0,
236
+ local_window_radius=prop_radius,
237
+ )
238
+
239
+ # bilinear exclude the last one
240
+ if self.training and scale_idx < self.num_scales - 1:
241
+ flow_up = self.upsample_flow(flow, feature0, bilinear=True,
242
+ upsample_factor=upsample_factor,
243
+ is_depth=task == 'depth')
244
+ flow_preds.append(flow_up)
245
+
246
+ if scale_idx == self.num_scales - 1:
247
+ if not self.reg_refine:
248
+ # upsample to the original image resolution
249
+
250
+ if task == 'stereo':
251
+ flow_pad = torch.cat((-flow, torch.zeros_like(flow)), dim=1) # [B, 2, H, W]
252
+ flow_up_pad = self.upsample_flow(flow_pad, feature0)
253
+ flow_up = -flow_up_pad[:, :1] # [B, 1, H, W]
254
+ elif task == 'depth':
255
+ depth_pad = torch.cat((flow, torch.zeros_like(flow)), dim=1) # [B, 2, H, W]
256
+ depth_up_pad = self.upsample_flow(depth_pad, feature0,
257
+ is_depth=True).clamp(min=min_depth, max=max_depth)
258
+ flow_up = depth_up_pad[:, :1] # [B, 1, H, W]
259
+ else:
260
+ flow_up = self.upsample_flow(flow, feature0)
261
+
262
+ flow_preds.append(flow_up)
263
+ else:
264
+ # task-specific local regression refinement
265
+ # supervise current flow
266
+ if self.training:
267
+ flow_up = self.upsample_flow(flow, feature0, bilinear=True,
268
+ upsample_factor=upsample_factor,
269
+ is_depth=task == 'depth')
270
+ flow_preds.append(flow_up)
271
+
272
+ assert num_reg_refine > 0
273
+ for refine_iter_idx in range(num_reg_refine):
274
+ flow = flow.detach()
275
+
276
+ if task == 'stereo':
277
+ zeros = torch.zeros_like(flow) # [B, 1, H, W]
278
+ # NOTE: reverse disp, disparity is positive
279
+ displace = torch.cat((-flow, zeros), dim=1) # [B, 2, H, W]
280
+ correlation = local_correlation_with_flow(
281
+ feature0_ori,
282
+ feature1_ori,
283
+ flow=displace,
284
+ local_radius=4,
285
+ ) # [B, (2R+1)^2, H, W]
286
+ elif task == 'depth':
287
+ if pred_bidir_depth and refine_iter_idx == 0:
288
+ intrinsics_curr = intrinsics_curr.repeat(2, 1, 1)
289
+ pose = torch.cat((pose, torch.inverse(pose)), dim=0)
290
+
291
+ feature0_ori, feature1_ori = torch.cat((feature0_ori, feature1_ori),
292
+ dim=0), torch.cat((feature1_ori,
293
+ feature0_ori), dim=0)
294
+
295
+ flow_from_depth = compute_flow_with_depth_pose(1. / flow.squeeze(1),
296
+ intrinsics_curr,
297
+ extrinsics_rel=pose,
298
+ )
299
+
300
+ correlation = local_correlation_with_flow(
301
+ feature0_ori,
302
+ feature1_ori,
303
+ flow=flow_from_depth,
304
+ local_radius=4,
305
+ ) # [B, (2R+1)^2, H, W]
306
+
307
+ else:
308
+ correlation = local_correlation_with_flow(
309
+ feature0_ori,
310
+ feature1_ori,
311
+ flow=flow,
312
+ local_radius=4,
313
+ ) # [B, (2R+1)^2, H, W]
314
+
315
+ proj = self.refine_proj(feature0)
316
+
317
+ net, inp = torch.chunk(proj, chunks=2, dim=1)
318
+
319
+ net = torch.tanh(net)
320
+ inp = torch.relu(inp)
321
+
322
+ net, up_mask, residual_flow = self.refine(net, inp, correlation, flow.clone(),
323
+ )
324
+
325
+ if task == 'depth':
326
+ flow = (flow - residual_flow).clamp(min=min_depth, max=max_depth)
327
+ else:
328
+ flow = flow + residual_flow
329
+
330
+ if task == 'stereo':
331
+ flow = flow.clamp(min=0) # positive
332
+
333
+ if self.training or refine_iter_idx == num_reg_refine - 1:
334
+ if task == 'depth':
335
+ if refine_iter_idx < num_reg_refine - 1:
336
+ # bilinear upsampling
337
+ flow_up = self.upsample_flow(flow, feature0, bilinear=True,
338
+ upsample_factor=upsample_factor,
339
+ is_depth=True)
340
+ else:
341
+ # last one convex upsampling
342
+ # NOTE: clamp depth due to the zero padding in the unfold in the convex upsampling
343
+ # pad depth to 2 channels as flow
344
+ depth_pad = torch.cat((flow, torch.zeros_like(flow)), dim=1) # [B, 2, H, W]
345
+ depth_up_pad = self.upsample_flow(depth_pad, feature0,
346
+ is_depth=True).clamp(min=min_depth,
347
+ max=max_depth)
348
+ flow_up = depth_up_pad[:, :1] # [B, 1, H, W]
349
+
350
+ else:
351
+ flow_up = upsample_flow_with_mask(flow, up_mask, upsample_factor=self.upsample_factor,
352
+ is_depth=task == 'depth')
353
+
354
+ flow_preds.append(flow_up)
355
+
356
+ if task == 'stereo':
357
+ for i in range(len(flow_preds)):
358
+ flow_preds[i] = flow_preds[i].squeeze(1) # [B, H, W]
359
+
360
+ # convert inverse depth to depth
361
+ if task == 'depth':
362
+ for i in range(len(flow_preds)):
363
+ flow_preds[i] = 1. / flow_preds[i].squeeze(1) # [B, H, W]
364
+
365
+ results_dict.update({'flow_preds': flow_preds})
366
+
367
+ return results_dict
unimatch/utils.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from .position import PositionEmbeddingSine
4
+
5
+
6
+ def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
7
+ assert device is not None
8
+
9
+ x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device),
10
+ torch.linspace(h_min, h_max, len_h, device=device)],
11
+ )
12
+ grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2]
13
+
14
+ return grid
15
+
16
+
17
+ def normalize_coords(coords, h, w):
18
+ # coords: [B, H, W, 2]
19
+ c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device)
20
+ return (coords - c) / c # [-1, 1]
21
+
22
+
23
+ def normalize_img(img0, img1):
24
+ # loaded images are in [0, 255]
25
+ # normalize by ImageNet mean and std
26
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device)
27
+ std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device)
28
+ img0 = (img0 / 255. - mean) / std
29
+ img1 = (img1 / 255. - mean) / std
30
+
31
+ return img0, img1
32
+
33
+
34
+ def split_feature(feature,
35
+ num_splits=2,
36
+ channel_last=False,
37
+ ):
38
+ if channel_last: # [B, H, W, C]
39
+ b, h, w, c = feature.size()
40
+ assert h % num_splits == 0 and w % num_splits == 0
41
+
42
+ b_new = b * num_splits * num_splits
43
+ h_new = h // num_splits
44
+ w_new = w // num_splits
45
+
46
+ feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c
47
+ ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c) # [B*K*K, H/K, W/K, C]
48
+ else: # [B, C, H, W]
49
+ b, c, h, w = feature.size()
50
+ assert h % num_splits == 0 and w % num_splits == 0
51
+
52
+ b_new = b * num_splits * num_splits
53
+ h_new = h // num_splits
54
+ w_new = w // num_splits
55
+
56
+ feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits
57
+ ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new) # [B*K*K, C, H/K, W/K]
58
+
59
+ return feature
60
+
61
+
62
+ def merge_splits(splits,
63
+ num_splits=2,
64
+ channel_last=False,
65
+ ):
66
+ if channel_last: # [B*K*K, H/K, W/K, C]
67
+ b, h, w, c = splits.size()
68
+ new_b = b // num_splits // num_splits
69
+
70
+ splits = splits.view(new_b, num_splits, num_splits, h, w, c)
71
+ merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view(
72
+ new_b, num_splits * h, num_splits * w, c) # [B, H, W, C]
73
+ else: # [B*K*K, C, H/K, W/K]
74
+ b, c, h, w = splits.size()
75
+ new_b = b // num_splits // num_splits
76
+
77
+ splits = splits.view(new_b, num_splits, num_splits, c, h, w)
78
+ merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view(
79
+ new_b, c, num_splits * h, num_splits * w) # [B, C, H, W]
80
+
81
+ return merge
82
+
83
+
84
+ def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w,
85
+ shift_size_h, shift_size_w, device=torch.device('cuda')):
86
+ # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
87
+ # calculate attention mask for SW-MSA
88
+ h, w = input_resolution
89
+ img_mask = torch.zeros((1, h, w, 1)).to(device) # 1 H W 1
90
+ h_slices = (slice(0, -window_size_h),
91
+ slice(-window_size_h, -shift_size_h),
92
+ slice(-shift_size_h, None))
93
+ w_slices = (slice(0, -window_size_w),
94
+ slice(-window_size_w, -shift_size_w),
95
+ slice(-shift_size_w, None))
96
+ cnt = 0
97
+ for h in h_slices:
98
+ for w in w_slices:
99
+ img_mask[:, h, w, :] = cnt
100
+ cnt += 1
101
+
102
+ mask_windows = split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True)
103
+
104
+ mask_windows = mask_windows.view(-1, window_size_h * window_size_w)
105
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
106
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
107
+
108
+ return attn_mask
109
+
110
+
111
+ def feature_add_position(feature0, feature1, attn_splits, feature_channels):
112
+ pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)
113
+
114
+ if attn_splits > 1: # add position in splited window
115
+ feature0_splits = split_feature(feature0, num_splits=attn_splits)
116
+ feature1_splits = split_feature(feature1, num_splits=attn_splits)
117
+
118
+ position = pos_enc(feature0_splits)
119
+
120
+ feature0_splits = feature0_splits + position
121
+ feature1_splits = feature1_splits + position
122
+
123
+ feature0 = merge_splits(feature0_splits, num_splits=attn_splits)
124
+ feature1 = merge_splits(feature1_splits, num_splits=attn_splits)
125
+ else:
126
+ position = pos_enc(feature0)
127
+
128
+ feature0 = feature0 + position
129
+ feature1 = feature1 + position
130
+
131
+ return feature0, feature1
132
+
133
+
134
+ def upsample_flow_with_mask(flow, up_mask, upsample_factor,
135
+ is_depth=False):
136
+ # convex upsampling following raft
137
+
138
+ mask = up_mask
139
+ b, flow_channel, h, w = flow.shape
140
+ mask = mask.view(b, 1, 9, upsample_factor, upsample_factor, h, w) # [B, 1, 9, K, K, H, W]
141
+ mask = torch.softmax(mask, dim=2)
142
+
143
+ multiplier = 1 if is_depth else upsample_factor
144
+ up_flow = F.unfold(multiplier * flow, [3, 3], padding=1)
145
+ up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w) # [B, 2, 9, 1, 1, H, W]
146
+
147
+ up_flow = torch.sum(mask * up_flow, dim=2) # [B, 2, K, K, H, W]
148
+ up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) # [B, 2, K, H, K, W]
149
+ up_flow = up_flow.reshape(b, flow_channel, upsample_factor * h,
150
+ upsample_factor * w) # [B, 2, K*H, K*W]
151
+
152
+ return up_flow
153
+
154
+
155
+ def split_feature_1d(feature,
156
+ num_splits=2,
157
+ ):
158
+ # feature: [B, W, C]
159
+ b, w, c = feature.size()
160
+ assert w % num_splits == 0
161
+
162
+ b_new = b * num_splits
163
+ w_new = w // num_splits
164
+
165
+ feature = feature.view(b, num_splits, w // num_splits, c
166
+ ).view(b_new, w_new, c) # [B*K, W/K, C]
167
+
168
+ return feature
169
+
170
+
171
+ def merge_splits_1d(splits,
172
+ h,
173
+ num_splits=2,
174
+ ):
175
+ b, w, c = splits.size()
176
+ new_b = b // num_splits // h
177
+
178
+ splits = splits.view(new_b, h, num_splits, w, c)
179
+ merge = splits.view(
180
+ new_b, h, num_splits * w, c) # [B, H, W, C]
181
+
182
+ return merge
183
+
184
+
185
+ def window_partition_1d(x, window_size_w):
186
+ """
187
+ Args:
188
+ x: (B, W, C)
189
+ window_size (int): window size
190
+
191
+ Returns:
192
+ windows: (num_windows*B, window_size, C)
193
+ """
194
+ B, W, C = x.shape
195
+ x = x.view(B, W // window_size_w, window_size_w, C).view(-1, window_size_w, C)
196
+ return x
197
+
198
+
199
+ def generate_shift_window_attn_mask_1d(input_w, window_size_w,
200
+ shift_size_w, device=torch.device('cuda')):
201
+ # calculate attention mask for SW-MSA
202
+ img_mask = torch.zeros((1, input_w, 1)).to(device) # 1 W 1
203
+ w_slices = (slice(0, -window_size_w),
204
+ slice(-window_size_w, -shift_size_w),
205
+ slice(-shift_size_w, None))
206
+ cnt = 0
207
+ for w in w_slices:
208
+ img_mask[:, w, :] = cnt
209
+ cnt += 1
210
+
211
+ mask_windows = window_partition_1d(img_mask, window_size_w) # nW, window_size, 1
212
+ mask_windows = mask_windows.view(-1, window_size_w)
213
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # nW, window_size, window_size
214
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
215
+
216
+ return attn_mask
utils/flow_viz.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License
2
+ #
3
+ # Copyright (c) 2018 Tom Runia
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to conditions.
11
+ #
12
+ # Author: Tom Runia
13
+ # Date Created: 2018-08-03
14
+
15
+ from __future__ import absolute_import
16
+ from __future__ import division
17
+ from __future__ import print_function
18
+
19
+ import numpy as np
20
+ from PIL import Image
21
+
22
+
23
+ def make_colorwheel():
24
+ '''
25
+ Generates a color wheel for optical flow visualization as presented in:
26
+ Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
27
+ URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
28
+ According to the C++ source code of Daniel Scharstein
29
+ According to the Matlab source code of Deqing Sun
30
+ '''
31
+
32
+ RY = 15
33
+ YG = 6
34
+ GC = 4
35
+ CB = 11
36
+ BM = 13
37
+ MR = 6
38
+
39
+ ncols = RY + YG + GC + CB + BM + MR
40
+ colorwheel = np.zeros((ncols, 3))
41
+ col = 0
42
+
43
+ # RY
44
+ colorwheel[0:RY, 0] = 255
45
+ colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
46
+ col = col + RY
47
+ # YG
48
+ colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
49
+ colorwheel[col:col + YG, 1] = 255
50
+ col = col + YG
51
+ # GC
52
+ colorwheel[col:col + GC, 1] = 255
53
+ colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
54
+ col = col + GC
55
+ # CB
56
+ colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
57
+ colorwheel[col:col + CB, 2] = 255
58
+ col = col + CB
59
+ # BM
60
+ colorwheel[col:col + BM, 2] = 255
61
+ colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
62
+ col = col + BM
63
+ # MR
64
+ colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
65
+ colorwheel[col:col + MR, 0] = 255
66
+ return colorwheel
67
+
68
+
69
+ def flow_compute_color(u, v, convert_to_bgr=False):
70
+ '''
71
+ Applies the flow color wheel to (possibly clipped) flow components u and v.
72
+ According to the C++ source code of Daniel Scharstein
73
+ According to the Matlab source code of Deqing Sun
74
+ :param u: np.ndarray, input horizontal flow
75
+ :param v: np.ndarray, input vertical flow
76
+ :param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB
77
+ :return:
78
+ '''
79
+
80
+ flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
81
+
82
+ colorwheel = make_colorwheel() # shape [55x3]
83
+ ncols = colorwheel.shape[0]
84
+
85
+ rad = np.sqrt(np.square(u) + np.square(v))
86
+ a = np.arctan2(-v, -u) / np.pi
87
+
88
+ fk = (a + 1) / 2 * (ncols - 1) + 1
89
+ k0 = np.floor(fk).astype(np.int32)
90
+ k1 = k0 + 1
91
+ k1[k1 == ncols] = 1
92
+ f = fk - k0
93
+
94
+ for i in range(colorwheel.shape[1]):
95
+ tmp = colorwheel[:, i]
96
+ col0 = tmp[k0] / 255.0
97
+ col1 = tmp[k1] / 255.0
98
+ col = (1 - f) * col0 + f * col1
99
+
100
+ idx = (rad <= 1)
101
+ col[idx] = 1 - rad[idx] * (1 - col[idx])
102
+ col[~idx] = col[~idx] * 0.75 # out of range?
103
+
104
+ # Note the 2-i => BGR instead of RGB
105
+ ch_idx = 2 - i if convert_to_bgr else i
106
+ flow_image[:, :, ch_idx] = np.floor(255 * col)
107
+
108
+ return flow_image
109
+
110
+
111
+ def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
112
+ '''
113
+ Expects a two dimensional flow image of shape [H,W,2]
114
+ According to the C++ source code of Daniel Scharstein
115
+ According to the Matlab source code of Deqing Sun
116
+ :param flow_uv: np.ndarray of shape [H,W,2]
117
+ :param clip_flow: float, maximum clipping value for flow
118
+ :return:
119
+ '''
120
+
121
+ assert flow_uv.ndim == 3, 'input flow must have three dimensions'
122
+ assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
123
+
124
+ if clip_flow is not None:
125
+ flow_uv = np.clip(flow_uv, 0, clip_flow)
126
+
127
+ u = flow_uv[:, :, 0]
128
+ v = flow_uv[:, :, 1]
129
+
130
+ rad = np.sqrt(np.square(u) + np.square(v))
131
+ rad_max = np.max(rad)
132
+
133
+ epsilon = 1e-5
134
+ u = u / (rad_max + epsilon)
135
+ v = v / (rad_max + epsilon)
136
+
137
+ return flow_compute_color(u, v, convert_to_bgr)
138
+
139
+
140
+ UNKNOWN_FLOW_THRESH = 1e7
141
+ SMALLFLOW = 0.0
142
+ LARGEFLOW = 1e8
143
+
144
+
145
+ def make_color_wheel():
146
+ """
147
+ Generate color wheel according Middlebury color code
148
+ :return: Color wheel
149
+ """
150
+ RY = 15
151
+ YG = 6
152
+ GC = 4
153
+ CB = 11
154
+ BM = 13
155
+ MR = 6
156
+
157
+ ncols = RY + YG + GC + CB + BM + MR
158
+
159
+ colorwheel = np.zeros([ncols, 3])
160
+
161
+ col = 0
162
+
163
+ # RY
164
+ colorwheel[0:RY, 0] = 255
165
+ colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY))
166
+ col += RY
167
+
168
+ # YG
169
+ colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG))
170
+ colorwheel[col:col + YG, 1] = 255
171
+ col += YG
172
+
173
+ # GC
174
+ colorwheel[col:col + GC, 1] = 255
175
+ colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC))
176
+ col += GC
177
+
178
+ # CB
179
+ colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB))
180
+ colorwheel[col:col + CB, 2] = 255
181
+ col += CB
182
+
183
+ # BM
184
+ colorwheel[col:col + BM, 2] = 255
185
+ colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM))
186
+ col += + BM
187
+
188
+ # MR
189
+ colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
190
+ colorwheel[col:col + MR, 0] = 255
191
+
192
+ return colorwheel
193
+
194
+
195
+ def compute_color(u, v):
196
+ """
197
+ compute optical flow color map
198
+ :param u: optical flow horizontal map
199
+ :param v: optical flow vertical map
200
+ :return: optical flow in color code
201
+ """
202
+ [h, w] = u.shape
203
+ img = np.zeros([h, w, 3])
204
+ nanIdx = np.isnan(u) | np.isnan(v)
205
+ u[nanIdx] = 0
206
+ v[nanIdx] = 0
207
+
208
+ colorwheel = make_color_wheel()
209
+ ncols = np.size(colorwheel, 0)
210
+
211
+ rad = np.sqrt(u ** 2 + v ** 2)
212
+
213
+ a = np.arctan2(-v, -u) / np.pi
214
+
215
+ fk = (a + 1) / 2 * (ncols - 1) + 1
216
+
217
+ k0 = np.floor(fk).astype(int)
218
+
219
+ k1 = k0 + 1
220
+ k1[k1 == ncols + 1] = 1
221
+ f = fk - k0
222
+
223
+ for i in range(0, np.size(colorwheel, 1)):
224
+ tmp = colorwheel[:, i]
225
+ col0 = tmp[k0 - 1] / 255
226
+ col1 = tmp[k1 - 1] / 255
227
+ col = (1 - f) * col0 + f * col1
228
+
229
+ idx = rad <= 1
230
+ col[idx] = 1 - rad[idx] * (1 - col[idx])
231
+ notidx = np.logical_not(idx)
232
+
233
+ col[notidx] *= 0.75
234
+ img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))
235
+
236
+ return img
237
+
238
+
239
+ # from https://github.com/gengshan-y/VCN
240
+ def flow_to_image(flow):
241
+ """
242
+ Convert flow into middlebury color code image
243
+ :param flow: optical flow map
244
+ :return: optical flow image in middlebury color
245
+ """
246
+ u = flow[:, :, 0]
247
+ v = flow[:, :, 1]
248
+
249
+ maxu = -999.
250
+ maxv = -999.
251
+ minu = 999.
252
+ minv = 999.
253
+
254
+ idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
255
+ u[idxUnknow] = 0
256
+ v[idxUnknow] = 0
257
+
258
+ maxu = max(maxu, np.max(u))
259
+ minu = min(minu, np.min(u))
260
+
261
+ maxv = max(maxv, np.max(v))
262
+ minv = min(minv, np.min(v))
263
+
264
+ rad = np.sqrt(u ** 2 + v ** 2)
265
+ maxrad = max(-1, np.max(rad))
266
+
267
+ u = u / (maxrad + np.finfo(float).eps)
268
+ v = v / (maxrad + np.finfo(float).eps)
269
+
270
+ img = compute_color(u, v)
271
+
272
+ idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
273
+ img[idx] = 0
274
+
275
+ return np.uint8(img)
276
+
277
+
278
+ def save_vis_flow_tofile(flow, output_path):
279
+ vis_flow = flow_to_image(flow)
280
+ Image.fromarray(vis_flow).save(output_path)
281
+
282
+
283
+ def flow_tensor_to_image(flow):
284
+ """Used for tensorboard visualization"""
285
+ flow = flow.permute(1, 2, 0) # [H, W, 2]
286
+ flow = flow.detach().cpu().numpy()
287
+ flow = flow_to_image(flow) # [H, W, 3]
288
+ flow = np.transpose(flow, (2, 0, 1)) # [3, H, W]
289
+
290
+ return flow
utils/visualization.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.utils.data
3
+ import numpy as np
4
+ import torchvision.utils as vutils
5
+ import cv2
6
+ from matplotlib.cm import get_cmap
7
+ import matplotlib as mpl
8
+ import matplotlib.cm as cm
9
+
10
+
11
+ def vis_disparity(disp, return_rgb=False):
12
+ disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0
13
+ disp_vis = disp_vis.astype("uint8")
14
+ disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
15
+
16
+ if return_rgb:
17
+ disp_vis = cv2.cvtColor(disp_vis, cv2.COLOR_BGR2RGB)
18
+
19
+ return disp_vis
20
+
21
+
22
+ def gen_error_colormap():
23
+ cols = np.array(
24
+ [[0 / 3.0, 0.1875 / 3.0, 49, 54, 149],
25
+ [0.1875 / 3.0, 0.375 / 3.0, 69, 117, 180],
26
+ [0.375 / 3.0, 0.75 / 3.0, 116, 173, 209],
27
+ [0.75 / 3.0, 1.5 / 3.0, 171, 217, 233],
28
+ [1.5 / 3.0, 3 / 3.0, 224, 243, 248],
29
+ [3 / 3.0, 6 / 3.0, 254, 224, 144],
30
+ [6 / 3.0, 12 / 3.0, 253, 174, 97],
31
+ [12 / 3.0, 24 / 3.0, 244, 109, 67],
32
+ [24 / 3.0, 48 / 3.0, 215, 48, 39],
33
+ [48 / 3.0, np.inf, 165, 0, 38]], dtype=np.float32)
34
+ cols[:, 2: 5] /= 255.
35
+ return cols
36
+
37
+
38
+ def disp_error_img(D_est_tensor, D_gt_tensor, abs_thres=3., rel_thres=0.05, dilate_radius=1):
39
+ D_gt_np = D_gt_tensor.detach().cpu().numpy()
40
+ D_est_np = D_est_tensor.detach().cpu().numpy()
41
+ B, H, W = D_gt_np.shape
42
+ # valid mask
43
+ mask = D_gt_np > 0
44
+ # error in percentage. When error <= 1, the pixel is valid since <= 3px & 5%
45
+ error = np.abs(D_gt_np - D_est_np)
46
+ error[np.logical_not(mask)] = 0
47
+ error[mask] = np.minimum(error[mask] / abs_thres, (error[mask] / D_gt_np[mask]) / rel_thres)
48
+ # get colormap
49
+ cols = gen_error_colormap()
50
+ # create error image
51
+ error_image = np.zeros([B, H, W, 3], dtype=np.float32)
52
+ for i in range(cols.shape[0]):
53
+ error_image[np.logical_and(error >= cols[i][0], error < cols[i][1])] = cols[i, 2:]
54
+ # TODO: imdilate
55
+ # error_image = cv2.imdilate(D_err, strel('disk', dilate_radius));
56
+ error_image[np.logical_not(mask)] = 0.
57
+ # show color tag in the top-left cornor of the image
58
+ for i in range(cols.shape[0]):
59
+ distance = 20
60
+ error_image[:, :10, i * distance:(i + 1) * distance, :] = cols[i, 2:]
61
+
62
+ return torch.from_numpy(np.ascontiguousarray(error_image.transpose([0, 3, 1, 2])))
63
+
64
+
65
+ def save_images(logger, mode_tag, images_dict, global_step):
66
+ images_dict = tensor2numpy(images_dict)
67
+ for tag, values in images_dict.items():
68
+ if not isinstance(values, list) and not isinstance(values, tuple):
69
+ values = [values]
70
+ for idx, value in enumerate(values):
71
+ if len(value.shape) == 3:
72
+ value = value[:, np.newaxis, :, :]
73
+ value = value[:1]
74
+ value = torch.from_numpy(value)
75
+
76
+ image_name = '{}/{}'.format(mode_tag, tag)
77
+ if len(values) > 1:
78
+ image_name = image_name + "_" + str(idx)
79
+ logger.add_image(image_name, vutils.make_grid(value, padding=0, nrow=1, normalize=True, scale_each=True),
80
+ global_step)
81
+
82
+
83
+ def tensor2numpy(var_dict):
84
+ for key, vars in var_dict.items():
85
+ if isinstance(vars, np.ndarray):
86
+ var_dict[key] = vars
87
+ elif isinstance(vars, torch.Tensor):
88
+ var_dict[key] = vars.data.cpu().numpy()
89
+ else:
90
+ raise NotImplementedError("invalid input type for tensor2numpy")
91
+
92
+ return var_dict
93
+
94
+
95
+ def viz_depth_tensor_from_monodepth2(disp, return_numpy=False, colormap='plasma'):
96
+ # visualize inverse depth
97
+ assert isinstance(disp, torch.Tensor)
98
+
99
+ disp = disp.numpy()
100
+ vmax = np.percentile(disp, 95)
101
+ normalizer = mpl.colors.Normalize(vmin=disp.min(), vmax=vmax)
102
+ mapper = cm.ScalarMappable(norm=normalizer, cmap=colormap)
103
+ colormapped_im = (mapper.to_rgba(disp)[:, :, :3] * 255).astype(np.uint8) # [H, W, 3]
104
+
105
+ if return_numpy:
106
+ return colormapped_im
107
+
108
+ viz = torch.from_numpy(colormapped_im).permute(2, 0, 1) # [3, H, W]
109
+
110
+ return viz