Spaces:
Sleeping
Sleeping
Upload tsdf_optimizer.py
Browse files- tsdf_optimizer.py +273 -0
tsdf_optimizer.py
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import numpy as np
|
4 |
+
from tqdm import tqdm
|
5 |
+
from matplotlib import pyplot as pl
|
6 |
+
|
7 |
+
import mast3r.utils.path_to_dust3r # noqa
|
8 |
+
from dust3r.utils.geometry import depthmap_to_pts3d, geotrf, inv
|
9 |
+
from dust3r.cloud_opt.base_opt import clean_pointcloud
|
10 |
+
|
11 |
+
|
12 |
+
class TSDFPostProcess:
|
13 |
+
""" Optimizes a signed distance-function to improve depthmaps.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, optimizer, subsample=8, TSDF_thresh=0., TSDF_batchsize=int(1e7)):
|
17 |
+
self.TSDF_thresh = TSDF_thresh # None -> no TSDF
|
18 |
+
self.TSDF_batchsize = TSDF_batchsize
|
19 |
+
self.optimizer = optimizer
|
20 |
+
|
21 |
+
pts3d, depthmaps, confs = optimizer.get_dense_pts3d(clean_depth=False, subsample=subsample)
|
22 |
+
pts3d, depthmaps = self._TSDF_postprocess_or_not(pts3d, depthmaps, confs)
|
23 |
+
self.pts3d = pts3d
|
24 |
+
self.depthmaps = depthmaps
|
25 |
+
self.confs = confs
|
26 |
+
|
27 |
+
def _get_depthmaps(self, TSDF_filtering_thresh=None):
|
28 |
+
if TSDF_filtering_thresh:
|
29 |
+
self._refine_depths_with_TSDF(self.optimizer, TSDF_filtering_thresh) # compute refined depths if needed
|
30 |
+
dms = self.TSDF_im_depthmaps if TSDF_filtering_thresh else self.im_depthmaps
|
31 |
+
return [d.exp() for d in dms]
|
32 |
+
|
33 |
+
@torch.no_grad()
|
34 |
+
def _refine_depths_with_TSDF(self, TSDF_filtering_thresh, niter=1, nsamples=1000):
|
35 |
+
"""
|
36 |
+
Leverage TSDF to post-process estimated depths
|
37 |
+
for each pixel, find zero level of TSDF along ray (or closest to 0)
|
38 |
+
"""
|
39 |
+
print("Post-Processing Depths with TSDF fusion.")
|
40 |
+
self.TSDF_im_depthmaps = []
|
41 |
+
alldepths, allposes, allfocals, allpps, allimshapes = self._get_depthmaps(), self.optimizer.get_im_poses(
|
42 |
+
), self.optimizer.get_focals(), self.optimizer.get_principal_points(), self.imshapes
|
43 |
+
for vi in tqdm(range(self.optimizer.n_imgs)):
|
44 |
+
dm, pose, focal, pp, imshape = alldepths[vi], allposes[vi], allfocals[vi], allpps[vi], allimshapes[vi]
|
45 |
+
minvals = torch.full(dm.shape, 1e20)
|
46 |
+
|
47 |
+
for it in range(niter):
|
48 |
+
H, W = dm.shape
|
49 |
+
curthresh = (niter - it) * TSDF_filtering_thresh
|
50 |
+
dm_offsets = (torch.randn(H, W, nsamples).to(dm) - 1.) * \
|
51 |
+
curthresh # decreasing search std along with iterations
|
52 |
+
newdm = dm[..., None] + dm_offsets # [H,W,Nsamp]
|
53 |
+
curproj = self._backproj_pts3d(in_depths=[newdm], in_im_poses=pose[None], in_focals=focal[None], in_pps=pp[None], in_imshapes=[
|
54 |
+
imshape])[0] # [H,W,Nsamp,3]
|
55 |
+
# Batched TSDF eval
|
56 |
+
curproj = curproj.view(-1, 3)
|
57 |
+
tsdf_vals = []
|
58 |
+
valids = []
|
59 |
+
for batch in range(0, len(curproj), self.TSDF_batchsize):
|
60 |
+
values, valid = self._TSDF_query(
|
61 |
+
curproj[batch:min(batch + self.TSDF_batchsize, len(curproj))], curthresh)
|
62 |
+
tsdf_vals.append(values)
|
63 |
+
valids.append(valid)
|
64 |
+
tsdf_vals = torch.cat(tsdf_vals, dim=0)
|
65 |
+
valids = torch.cat(valids, dim=0)
|
66 |
+
|
67 |
+
tsdf_vals = tsdf_vals.view([H, W, nsamples])
|
68 |
+
valids = valids.view([H, W, nsamples])
|
69 |
+
|
70 |
+
# keep depth value that got us the closest to 0
|
71 |
+
tsdf_vals[~valids] = torch.inf # ignore invalid values
|
72 |
+
tsdf_vals = tsdf_vals.abs()
|
73 |
+
mins = torch.argmin(tsdf_vals, dim=-1, keepdim=True)
|
74 |
+
# when all samples live on a very flat zone, do nothing
|
75 |
+
allbad = (tsdf_vals == curthresh).sum(dim=-1) == nsamples
|
76 |
+
dm[~allbad] = torch.gather(newdm, -1, mins)[..., 0][~allbad]
|
77 |
+
|
78 |
+
# Save refined depth map
|
79 |
+
self.TSDF_im_depthmaps.append(dm.log())
|
80 |
+
|
81 |
+
def _TSDF_query(self, qpoints, TSDF_filtering_thresh, weighted=True):
|
82 |
+
"""
|
83 |
+
TSDF query call: returns the weighted TSDF value for each query point [N, 3]
|
84 |
+
"""
|
85 |
+
N, three = qpoints.shape
|
86 |
+
assert three == 3
|
87 |
+
qpoints = qpoints[None].repeat(self.optimizer.n_imgs, 1, 1) # [B,N,3]
|
88 |
+
# get projection coordinates and depths onto images
|
89 |
+
coords_and_depth = self._proj_pts3d(pts3d=qpoints, cam2worlds=self.optimizer.get_im_poses(
|
90 |
+
), focals=self.optimizer.get_focals(), pps=self.optimizer.get_principal_points())
|
91 |
+
image_coords = coords_and_depth[..., :2].round().to(int) # for now, there's no interpolation...
|
92 |
+
proj_depths = coords_and_depth[..., -1]
|
93 |
+
# recover depth values after scene optim
|
94 |
+
pred_depths, pred_confs, valids = self._get_pixel_depths(image_coords)
|
95 |
+
# Gather TSDF scores
|
96 |
+
all_SDF_scores = pred_depths - proj_depths # SDF
|
97 |
+
unseen = all_SDF_scores < -TSDF_filtering_thresh # handle visibility
|
98 |
+
# all_TSDF_scores = all_SDF_scores.clip(-TSDF_filtering_thresh,TSDF_filtering_thresh) # SDF -> TSDF
|
99 |
+
all_TSDF_scores = all_SDF_scores.clip(-TSDF_filtering_thresh, 1e20) # SDF -> TSDF
|
100 |
+
# Gather TSDF confidences and ignore points that are unseen, either OOB during reproj or too far behind seen depth
|
101 |
+
all_TSDF_weights = (~unseen).float() * valids.float()
|
102 |
+
if weighted:
|
103 |
+
all_TSDF_weights = pred_confs.exp() * all_TSDF_weights
|
104 |
+
# Aggregate all votes, ignoring zeros
|
105 |
+
TSDF_weights = all_TSDF_weights.sum(dim=0)
|
106 |
+
valids = TSDF_weights != 0.
|
107 |
+
TSDF_wsum = (all_TSDF_weights * all_TSDF_scores).sum(dim=0)
|
108 |
+
TSDF_wsum[valids] /= TSDF_weights[valids]
|
109 |
+
return TSDF_wsum, valids
|
110 |
+
|
111 |
+
def _get_pixel_depths(self, image_coords, TSDF_filtering_thresh=None, with_normals_conf=False):
|
112 |
+
""" Recover depth value for each input pixel coordinate, along with OOB validity mask
|
113 |
+
"""
|
114 |
+
B, N, two = image_coords.shape
|
115 |
+
assert B == self.optimizer.n_imgs and two == 2
|
116 |
+
depths = torch.zeros([B, N], device=image_coords.device)
|
117 |
+
valids = torch.zeros([B, N], dtype=bool, device=image_coords.device)
|
118 |
+
confs = torch.zeros([B, N], device=image_coords.device)
|
119 |
+
curconfs = self._get_confs_with_normals() if with_normals_conf else self.im_conf
|
120 |
+
for ni, (imc, depth, conf) in enumerate(zip(image_coords, self._get_depthmaps(TSDF_filtering_thresh), curconfs)):
|
121 |
+
H, W = depth.shape
|
122 |
+
valids[ni] = torch.logical_and(0 <= imc[:, 1], imc[:, 1] <
|
123 |
+
H) & torch.logical_and(0 <= imc[:, 0], imc[:, 0] < W)
|
124 |
+
imc[~valids[ni]] = 0
|
125 |
+
depths[ni] = depth[imc[:, 1], imc[:, 0]]
|
126 |
+
confs[ni] = conf.cuda()[imc[:, 1], imc[:, 0]]
|
127 |
+
return depths, confs, valids
|
128 |
+
|
129 |
+
def _get_confs_with_normals(self):
|
130 |
+
outconfs = []
|
131 |
+
# Confidence basedf on depth gradient
|
132 |
+
|
133 |
+
class Sobel(nn.Module):
|
134 |
+
def __init__(self):
|
135 |
+
super().__init__()
|
136 |
+
self.filter = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, stride=1, padding=1, bias=False)
|
137 |
+
Gx = torch.tensor([[2.0, 0.0, -2.0], [4.0, 0.0, -4.0], [2.0, 0.0, -2.0]])
|
138 |
+
Gy = torch.tensor([[2.0, 4.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -4.0, -2.0]])
|
139 |
+
G = torch.cat([Gx.unsqueeze(0), Gy.unsqueeze(0)], 0)
|
140 |
+
G = G.unsqueeze(1)
|
141 |
+
self.filter.weight = nn.Parameter(G, requires_grad=False)
|
142 |
+
|
143 |
+
def forward(self, img):
|
144 |
+
x = self.filter(img)
|
145 |
+
x = torch.mul(x, x)
|
146 |
+
x = torch.sum(x, dim=1, keepdim=True)
|
147 |
+
x = torch.sqrt(x)
|
148 |
+
return x
|
149 |
+
|
150 |
+
grad_op = Sobel().to(self.im_depthmaps[0].device)
|
151 |
+
for conf, depth in zip(self.im_conf, self.im_depthmaps):
|
152 |
+
grad_confs = (1. - grad_op(depth[None, None])[0, 0]).clip(0)
|
153 |
+
if not 'dbg show':
|
154 |
+
pl.imshow(grad_confs.cpu())
|
155 |
+
pl.show()
|
156 |
+
outconfs.append(conf * grad_confs.to(conf))
|
157 |
+
return outconfs
|
158 |
+
|
159 |
+
def _proj_pts3d(self, pts3d, cam2worlds, focals, pps):
|
160 |
+
"""
|
161 |
+
Projection operation: from 3D points to 2D coordinates + depths
|
162 |
+
"""
|
163 |
+
B = pts3d.shape[0]
|
164 |
+
assert pts3d.shape[0] == cam2worlds.shape[0]
|
165 |
+
# prepare Extrinsincs
|
166 |
+
R, t = cam2worlds[:, :3, :3], cam2worlds[:, :3, -1]
|
167 |
+
Rinv = R.transpose(-2, -1)
|
168 |
+
tinv = -Rinv @ t[..., None]
|
169 |
+
|
170 |
+
# prepare intrinsics
|
171 |
+
intrinsics = torch.eye(3).to(cam2worlds)[None].repeat(focals.shape[0], 1, 1)
|
172 |
+
if len(focals.shape) == 1:
|
173 |
+
focals = torch.stack([focals, focals], dim=-1)
|
174 |
+
intrinsics[:, 0, 0] = focals[:, 0]
|
175 |
+
intrinsics[:, 1, 1] = focals[:, 1]
|
176 |
+
intrinsics[:, :2, -1] = pps
|
177 |
+
# Project
|
178 |
+
projpts = intrinsics @ (Rinv @ pts3d.transpose(-2, -1) + tinv) # I(RX+t) : [B,3,N]
|
179 |
+
projpts = projpts.transpose(-2, -1) # [B,N,3]
|
180 |
+
projpts[..., :2] /= projpts[..., [-1]] # [B,N,3] (X/Z , Y/Z, Z)
|
181 |
+
return projpts
|
182 |
+
|
183 |
+
def _backproj_pts3d(self, in_depths=None, in_im_poses=None,
|
184 |
+
in_focals=None, in_pps=None, in_imshapes=None):
|
185 |
+
"""
|
186 |
+
Backprojection operation: from image depths to 3D points
|
187 |
+
"""
|
188 |
+
# Get depths and projection params if not provided
|
189 |
+
focals = self.optimizer.get_focals() if in_focals is None else in_focals
|
190 |
+
im_poses = self.optimizer.get_im_poses() if in_im_poses is None else in_im_poses
|
191 |
+
depth = self._get_depthmaps() if in_depths is None else in_depths
|
192 |
+
pp = self.optimizer.get_principal_points() if in_pps is None else in_pps
|
193 |
+
imshapes = self.imshapes if in_imshapes is None else in_imshapes
|
194 |
+
def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *imshapes[i])
|
195 |
+
dm_to_3d = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[[i]]) for i in range(im_poses.shape[0])]
|
196 |
+
|
197 |
+
def autoprocess(x):
|
198 |
+
x = x[0]
|
199 |
+
return x.transpose(-2, -1) if len(x.shape) == 4 else x
|
200 |
+
return [geotrf(pose, autoprocess(pt)) for pose, pt in zip(im_poses, dm_to_3d)]
|
201 |
+
|
202 |
+
def _pts3d_to_depth(self, pts3d, cam2worlds, focals, pps):
|
203 |
+
"""
|
204 |
+
Projection operation: from 3D points to 2D coordinates + depths
|
205 |
+
"""
|
206 |
+
B = pts3d.shape[0]
|
207 |
+
assert pts3d.shape[0] == cam2worlds.shape[0]
|
208 |
+
# prepare Extrinsincs
|
209 |
+
R, t = cam2worlds[:, :3, :3], cam2worlds[:, :3, -1]
|
210 |
+
Rinv = R.transpose(-2, -1)
|
211 |
+
tinv = -Rinv @ t[..., None]
|
212 |
+
|
213 |
+
# prepare intrinsics
|
214 |
+
intrinsics = torch.eye(3).to(cam2worlds)[None].repeat(self.optimizer.n_imgs, 1, 1)
|
215 |
+
if len(focals.shape) == 1:
|
216 |
+
focals = torch.stack([focals, focals], dim=-1)
|
217 |
+
intrinsics[:, 0, 0] = focals[:, 0]
|
218 |
+
intrinsics[:, 1, 1] = focals[:, 1]
|
219 |
+
intrinsics[:, :2, -1] = pps
|
220 |
+
# Project
|
221 |
+
projpts = intrinsics @ (Rinv @ pts3d.transpose(-2, -1) + tinv) # I(RX+t) : [B,3,N]
|
222 |
+
projpts = projpts.transpose(-2, -1) # [B,N,3]
|
223 |
+
projpts[..., :2] /= projpts[..., [-1]] # [B,N,3] (X/Z , Y/Z, Z)
|
224 |
+
return projpts
|
225 |
+
|
226 |
+
def _depth_to_pts3d(self, in_depths=None, in_im_poses=None, in_focals=None, in_pps=None, in_imshapes=None):
|
227 |
+
"""
|
228 |
+
Backprojection operation: from image depths to 3D points
|
229 |
+
"""
|
230 |
+
# Get depths and projection params if not provided
|
231 |
+
focals = self.optimizer.get_focals() if in_focals is None else in_focals
|
232 |
+
im_poses = self.optimizer.get_im_poses() if in_im_poses is None else in_im_poses
|
233 |
+
depth = self._get_depthmaps() if in_depths is None else in_depths
|
234 |
+
pp = self.optimizer.get_principal_points() if in_pps is None else in_pps
|
235 |
+
imshapes = self.imshapes if in_imshapes is None else in_imshapes
|
236 |
+
|
237 |
+
def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *imshapes[i])
|
238 |
+
|
239 |
+
dm_to_3d = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[i:i + 1]) for i in range(im_poses.shape[0])]
|
240 |
+
|
241 |
+
def autoprocess(x):
|
242 |
+
x = x[0]
|
243 |
+
H, W, three = x.shape[:3]
|
244 |
+
return x.transpose(-2, -1) if len(x.shape) == 4 else x
|
245 |
+
return [geotrf(pp, autoprocess(pt)) for pp, pt in zip(im_poses, dm_to_3d)]
|
246 |
+
|
247 |
+
def _get_pts3d(self, TSDF_filtering_thresh=None, **kw):
|
248 |
+
"""
|
249 |
+
return 3D points (possibly filtering depths with TSDF)
|
250 |
+
"""
|
251 |
+
return self._backproj_pts3d(in_depths=self._get_depthmaps(TSDF_filtering_thresh=TSDF_filtering_thresh), **kw)
|
252 |
+
|
253 |
+
def _TSDF_postprocess_or_not(self, pts3d, depthmaps, confs, niter=1):
|
254 |
+
# Setup inner variables
|
255 |
+
self.imshapes = [im.shape[:2] for im in self.optimizer.imgs]
|
256 |
+
self.im_depthmaps = [dd.log().view(imshape) for dd, imshape in zip(depthmaps, self.imshapes)]
|
257 |
+
self.im_conf = confs
|
258 |
+
|
259 |
+
if self.TSDF_thresh > 0.:
|
260 |
+
# Create or update self.TSDF_im_depthmaps that contain logdepths filtered with TSDF
|
261 |
+
self._refine_depths_with_TSDF(self.TSDF_thresh, niter=niter)
|
262 |
+
depthmaps = [dd.exp() for dd in self.TSDF_im_depthmaps]
|
263 |
+
# Turn them into 3D points
|
264 |
+
pts3d = self._backproj_pts3d(in_depths=depthmaps)
|
265 |
+
depthmaps = [dd.flatten() for dd in depthmaps]
|
266 |
+
pts3d = [pp.view(-1, 3) for pp in pts3d]
|
267 |
+
return pts3d, depthmaps
|
268 |
+
|
269 |
+
def get_dense_pts3d(self, clean_depth=True):
|
270 |
+
if clean_depth:
|
271 |
+
confs = clean_pointcloud(self.confs, self.optimizer.intrinsics, inv(self.optimizer.cam2w),
|
272 |
+
self.depthmaps, self.pts3d)
|
273 |
+
return self.pts3d, self.depthmaps, confs
|