zino36 committed on
Commit fc5081c · verified · 1 Parent(s): f3cc519

Upload tsdf_optimizer.py

Files changed (1)
  1. tsdf_optimizer.py +273 -0
tsdf_optimizer.py ADDED
@@ -0,0 +1,273 @@
+ import torch
+ from torch import nn
+ import numpy as np
+ from tqdm import tqdm
+ from matplotlib import pyplot as pl
+
+ import mast3r.utils.path_to_dust3r  # noqa
+ from dust3r.utils.geometry import depthmap_to_pts3d, geotrf, inv
+ from dust3r.cloud_opt.base_opt import clean_pointcloud
+
+
+ class TSDFPostProcess:
+     """ Optimizes a signed distance function to improve depthmaps.
+     """
+
+     def __init__(self, optimizer, subsample=8, TSDF_thresh=0., TSDF_batchsize=int(1e7)):
+         self.TSDF_thresh = TSDF_thresh  # <= 0 -> no TSDF filtering
+         self.TSDF_batchsize = TSDF_batchsize
+         self.optimizer = optimizer
+
+         pts3d, depthmaps, confs = optimizer.get_dense_pts3d(clean_depth=False, subsample=subsample)
+         pts3d, depthmaps = self._TSDF_postprocess_or_not(pts3d, depthmaps, confs)
+         self.pts3d = pts3d
+         self.depthmaps = depthmaps
+         self.confs = confs
+
+     def _get_depthmaps(self, TSDF_filtering_thresh=None):
+         if TSDF_filtering_thresh:
+             self._refine_depths_with_TSDF(TSDF_filtering_thresh)  # compute refined depths if needed
+         dms = self.TSDF_im_depthmaps if TSDF_filtering_thresh else self.im_depthmaps
+         return [d.exp() for d in dms]
+
+     @torch.no_grad()
+     def _refine_depths_with_TSDF(self, TSDF_filtering_thresh, niter=1, nsamples=1000):
+         """
+         Leverage the TSDF to post-process estimated depths:
+         for each pixel, find the zero level of the TSDF along its ray (or the sample closest to 0).
+         """
+         print("Post-Processing Depths with TSDF fusion.")
+         self.TSDF_im_depthmaps = []
+         alldepths = self._get_depthmaps()
+         allposes = self.optimizer.get_im_poses()
+         allfocals = self.optimizer.get_focals()
+         allpps = self.optimizer.get_principal_points()
+         allimshapes = self.imshapes
+         for vi in tqdm(range(self.optimizer.n_imgs)):
+             dm, pose, focal, pp, imshape = alldepths[vi], allposes[vi], allfocals[vi], allpps[vi], allimshapes[vi]
+
+             for it in range(niter):
+                 H, W = dm.shape
+                 curthresh = (niter - it) * TSDF_filtering_thresh  # decreasing search std along with iterations
+                 dm_offsets = (torch.randn(H, W, nsamples).to(dm) - 1.) * curthresh
+                 newdm = dm[..., None] + dm_offsets  # [H,W,Nsamp]
+                 curproj = self._backproj_pts3d(in_depths=[newdm], in_im_poses=pose[None], in_focals=focal[None],
+                                                in_pps=pp[None], in_imshapes=[imshape])[0]  # [H,W,Nsamp,3]
+                 # Batched TSDF eval
+                 curproj = curproj.view(-1, 3)
+                 tsdf_vals = []
+                 valids = []
+                 for batch in range(0, len(curproj), self.TSDF_batchsize):
+                     values, valid = self._TSDF_query(
+                         curproj[batch:min(batch + self.TSDF_batchsize, len(curproj))], curthresh)
+                     tsdf_vals.append(values)
+                     valids.append(valid)
+                 tsdf_vals = torch.cat(tsdf_vals, dim=0)
+                 valids = torch.cat(valids, dim=0)
+
+                 tsdf_vals = tsdf_vals.view([H, W, nsamples])
+                 valids = valids.view([H, W, nsamples])
+
+                 # keep the depth value that got us the closest to 0
+                 tsdf_vals[~valids] = torch.inf  # ignore invalid values
+                 tsdf_vals = tsdf_vals.abs()
+                 mins = torch.argmin(tsdf_vals, dim=-1, keepdim=True)
+                 # when all samples live on a very flat zone, do nothing
+                 allbad = (tsdf_vals == curthresh).sum(dim=-1) == nsamples
+                 dm[~allbad] = torch.gather(newdm, -1, mins)[..., 0][~allbad]
+
+             # Save refined depth map
+             self.TSDF_im_depthmaps.append(dm.log())
+
+     def _TSDF_query(self, qpoints, TSDF_filtering_thresh, weighted=True):
+         """
+         TSDF query call: returns the weighted TSDF value for each query point [N, 3]
+         """
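+         # Reference restatement of what the code below computes, per query point q,
+         # over views i, with truncation tau = TSDF_filtering_thresh:
+         #   TSDF(q) = sum_i w_i * clip(d_pred_i(q) - d_proj_i(q), -tau, inf) / sum_i w_i
+         # where w_i = exp(conf_i) * 1[not too far behind the seen surface] * 1[projection in bounds].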
+         N, three = qpoints.shape
+         assert three == 3
+         qpoints = qpoints[None].repeat(self.optimizer.n_imgs, 1, 1)  # [B,N,3]
+         # get projection coordinates and depths onto images
+         coords_and_depth = self._proj_pts3d(pts3d=qpoints, cam2worlds=self.optimizer.get_im_poses(),
+                                             focals=self.optimizer.get_focals(),
+                                             pps=self.optimizer.get_principal_points())
+         image_coords = coords_and_depth[..., :2].round().to(int)  # for now, there's no interpolation...
+         proj_depths = coords_and_depth[..., -1]
+         # recover depth values after scene optim
+         pred_depths, pred_confs, valids = self._get_pixel_depths(image_coords)
+         # Gather TSDF scores
+         all_SDF_scores = pred_depths - proj_depths  # SDF
+         unseen = all_SDF_scores < -TSDF_filtering_thresh  # handle visibility
+         # all_TSDF_scores = all_SDF_scores.clip(-TSDF_filtering_thresh, TSDF_filtering_thresh)  # SDF -> TSDF
+         all_TSDF_scores = all_SDF_scores.clip(-TSDF_filtering_thresh, 1e20)  # SDF -> TSDF
+         # Gather TSDF confidences and ignore points that are unseen,
+         # either OOB during reprojection or too far behind the seen depth
+         all_TSDF_weights = (~unseen).float() * valids.float()
+         if weighted:
+             all_TSDF_weights = pred_confs.exp() * all_TSDF_weights
+         # Aggregate all votes, ignoring zeros
+         TSDF_weights = all_TSDF_weights.sum(dim=0)
+         valids = TSDF_weights != 0.
+         TSDF_wsum = (all_TSDF_weights * all_TSDF_scores).sum(dim=0)
+         TSDF_wsum[valids] /= TSDF_weights[valids]
+         return TSDF_wsum, valids
+
+     def _get_pixel_depths(self, image_coords, TSDF_filtering_thresh=None, with_normals_conf=False):
+         """ Recover depth value for each input pixel coordinate, along with OOB validity mask
+         """
+         B, N, two = image_coords.shape
+         assert B == self.optimizer.n_imgs and two == 2
+         depths = torch.zeros([B, N], device=image_coords.device)
+         valids = torch.zeros([B, N], dtype=bool, device=image_coords.device)
+         confs = torch.zeros([B, N], device=image_coords.device)
+         curconfs = self._get_confs_with_normals() if with_normals_conf else self.im_conf
+         for ni, (imc, depth, conf) in enumerate(zip(image_coords, self._get_depthmaps(TSDF_filtering_thresh), curconfs)):
+             H, W = depth.shape
+             valids[ni] = (torch.logical_and(0 <= imc[:, 1], imc[:, 1] < H)
+                           & torch.logical_and(0 <= imc[:, 0], imc[:, 0] < W))
+             imc[~valids[ni]] = 0
+             depths[ni] = depth[imc[:, 1], imc[:, 0]]
+             confs[ni] = conf.cuda()[imc[:, 1], imc[:, 0]]
+         return depths, confs, valids
+
+     def _get_confs_with_normals(self):
+         outconfs = []
+         # Confidence based on the depth gradient
+
+         class Sobel(nn.Module):
+             def __init__(self):
+                 super().__init__()
+                 self.filter = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, stride=1, padding=1, bias=False)
+                 Gx = torch.tensor([[2.0, 0.0, -2.0], [4.0, 0.0, -4.0], [2.0, 0.0, -2.0]])
+                 Gy = torch.tensor([[2.0, 4.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -4.0, -2.0]])
+                 G = torch.cat([Gx.unsqueeze(0), Gy.unsqueeze(0)], 0)
+                 G = G.unsqueeze(1)
+                 self.filter.weight = nn.Parameter(G, requires_grad=False)
+
+             def forward(self, img):
+                 x = self.filter(img)
+                 x = torch.mul(x, x)
+                 x = torch.sum(x, dim=1, keepdim=True)
+                 x = torch.sqrt(x)
+                 return x
+
+         grad_op = Sobel().to(self.im_depthmaps[0].device)
+         for conf, depth in zip(self.im_conf, self.im_depthmaps):
+             grad_confs = (1. - grad_op(depth[None, None])[0, 0]).clip(0)
+             if False:  # debug: visualize the gradient-based confidence map
+                 pl.imshow(grad_confs.cpu())
+                 pl.show()
+             outconfs.append(conf * grad_confs.to(conf))
+         return outconfs
+
+     def _proj_pts3d(self, pts3d, cam2worlds, focals, pps):
+         """
+         Projection operation: from 3D points to 2D coordinates + depths
+         """
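+         # For a world point X and a cam2world pose (R, t), the code below computes
+         # x = K @ (R.T @ X - R.T @ t) and returns (x0/x2, x1/x2, x2),
+         # i.e. pixel coordinates plus the camera-frame depth.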
+         B = pts3d.shape[0]
+         assert B == cam2worlds.shape[0]
+         # prepare extrinsics
+         R, t = cam2worlds[:, :3, :3], cam2worlds[:, :3, -1]
+         Rinv = R.transpose(-2, -1)
+         tinv = -Rinv @ t[..., None]
+
+         # prepare intrinsics
+         intrinsics = torch.eye(3).to(cam2worlds)[None].repeat(focals.shape[0], 1, 1)
+         if len(focals.shape) == 1:
+             focals = torch.stack([focals, focals], dim=-1)
+         intrinsics[:, 0, 0] = focals[:, 0]
+         intrinsics[:, 1, 1] = focals[:, 1]
+         intrinsics[:, :2, -1] = pps
+         # Project
+         projpts = intrinsics @ (Rinv @ pts3d.transpose(-2, -1) + tinv)  # K(R^T X - R^T t) : [B,3,N]
+         projpts = projpts.transpose(-2, -1)  # [B,N,3]
+         projpts[..., :2] /= projpts[..., [-1]]  # [B,N,3] (X/Z, Y/Z, Z)
+         return projpts
+
+     def _backproj_pts3d(self, in_depths=None, in_im_poses=None,
+                         in_focals=None, in_pps=None, in_imshapes=None):
+         """
+         Backprojection operation: from image depths to 3D points
+         """
+         # Get depths and projection params if not provided
+         focals = self.optimizer.get_focals() if in_focals is None else in_focals
+         im_poses = self.optimizer.get_im_poses() if in_im_poses is None else in_im_poses
+         depth = self._get_depthmaps() if in_depths is None else in_depths
+         pp = self.optimizer.get_principal_points() if in_pps is None else in_pps
+         imshapes = self.imshapes if in_imshapes is None else in_imshapes
+
+         def focal_ex(i):
+             return focals[i][..., None, None].expand(1, *focals[i].shape, *imshapes[i])
+
+         dm_to_3d = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[[i]]) for i in range(im_poses.shape[0])]
+
+         def autoprocess(x):
+             x = x[0]
+             return x.transpose(-2, -1) if len(x.shape) == 4 else x
+         return [geotrf(pose, autoprocess(pt)) for pose, pt in zip(im_poses, dm_to_3d)]
+
+     def _pts3d_to_depth(self, pts3d, cam2worlds, focals, pps):
+         """
+         Projection operation: from 3D points to 2D coordinates + depths
+         """
+         B = pts3d.shape[0]
+         assert B == cam2worlds.shape[0]
+         # prepare extrinsics
+         R, t = cam2worlds[:, :3, :3], cam2worlds[:, :3, -1]
+         Rinv = R.transpose(-2, -1)
+         tinv = -Rinv @ t[..., None]
+
+         # prepare intrinsics
+         intrinsics = torch.eye(3).to(cam2worlds)[None].repeat(self.optimizer.n_imgs, 1, 1)
+         if len(focals.shape) == 1:
+             focals = torch.stack([focals, focals], dim=-1)
+         intrinsics[:, 0, 0] = focals[:, 0]
+         intrinsics[:, 1, 1] = focals[:, 1]
+         intrinsics[:, :2, -1] = pps
+         # Project
+         projpts = intrinsics @ (Rinv @ pts3d.transpose(-2, -1) + tinv)  # K(R^T X - R^T t) : [B,3,N]
+         projpts = projpts.transpose(-2, -1)  # [B,N,3]
+         projpts[..., :2] /= projpts[..., [-1]]  # [B,N,3] (X/Z, Y/Z, Z)
+         return projpts
+
+     def _depth_to_pts3d(self, in_depths=None, in_im_poses=None, in_focals=None, in_pps=None, in_imshapes=None):
+         """
+         Backprojection operation: from image depths to 3D points
+         """
+         # Get depths and projection params if not provided
+         focals = self.optimizer.get_focals() if in_focals is None else in_focals
+         im_poses = self.optimizer.get_im_poses() if in_im_poses is None else in_im_poses
+         depth = self._get_depthmaps() if in_depths is None else in_depths
+         pp = self.optimizer.get_principal_points() if in_pps is None else in_pps
+         imshapes = self.imshapes if in_imshapes is None else in_imshapes
+
+         def focal_ex(i):
+             return focals[i][..., None, None].expand(1, *focals[i].shape, *imshapes[i])
+
+         dm_to_3d = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[i:i + 1]) for i in range(im_poses.shape[0])]
+
+         def autoprocess(x):
+             x = x[0]
+             return x.transpose(-2, -1) if len(x.shape) == 4 else x
+         return [geotrf(pose, autoprocess(pt)) for pose, pt in zip(im_poses, dm_to_3d)]
+
+     def _get_pts3d(self, TSDF_filtering_thresh=None, **kw):
+         """
+         return 3D points (possibly filtering depths with TSDF)
+         """
+         return self._backproj_pts3d(in_depths=self._get_depthmaps(TSDF_filtering_thresh=TSDF_filtering_thresh), **kw)
+
+     def _TSDF_postprocess_or_not(self, pts3d, depthmaps, confs, niter=1):
+         # Setup inner variables
+         self.imshapes = [im.shape[:2] for im in self.optimizer.imgs]
+         self.im_depthmaps = [dd.log().view(imshape) for dd, imshape in zip(depthmaps, self.imshapes)]
+         self.im_conf = confs
+
+         if self.TSDF_thresh > 0.:
+             # Create or update self.TSDF_im_depthmaps that contain logdepths filtered with TSDF
+             self._refine_depths_with_TSDF(self.TSDF_thresh, niter=niter)
+             depthmaps = [dd.exp() for dd in self.TSDF_im_depthmaps]
+             # Turn them into 3D points
+             pts3d = self._backproj_pts3d(in_depths=depthmaps)
+             depthmaps = [dd.flatten() for dd in depthmaps]
+             pts3d = [pp.view(-1, 3) for pp in pts3d]
+         return pts3d, depthmaps
+
+     def get_dense_pts3d(self, clean_depth=True):
+         confs = self.confs
+         if clean_depth:
+             confs = clean_pointcloud(self.confs, self.optimizer.intrinsics, inv(self.optimizer.cam2w),
+                                      self.depthmaps, self.pts3d)
+         return self.pts3d, self.depthmaps, confs
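+
+
+ # ----------------------------------------------------------------------------
+ # Illustrative sketch: the core idea behind `_refine_depths_with_TSDF`, reduced
+ # to a single synthetic ray with plain torch. We sample candidate depths around
+ # an initial estimate, evaluate a toy truncated SDF against a known surface
+ # depth, and keep the candidate whose TSDF value is closest to zero. All names
+ # and numbers below are local to this sketch and made up for demonstration.
+ # ----------------------------------------------------------------------------
+ if __name__ == '__main__':
+     torch.manual_seed(0)
+     true_depth = 2.0                    # depth of the (toy) surface along the ray
+     init_depth = torch.tensor(2.3)      # noisy initial depth estimate
+     thresh, nsamples = 0.5, 1000        # truncation threshold and number of samples
+
+     # sample candidate depths around the estimate (cf. dm_offsets above)
+     offsets = (torch.randn(nsamples) - 1.) * thresh
+     candidates = init_depth + offsets
+
+     # toy truncated SDF: positive in front of the surface, clipped far behind it
+     tsdf = (true_depth - candidates).clip(-thresh, 1e20)
+
+     # keep the candidate whose TSDF value is closest to zero
+     refined = candidates[tsdf.abs().argmin()]
+     print(f'initial depth {init_depth:.3f} -> refined depth {refined:.3f} (true depth {true_depth})')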