File size: 13,525 Bytes
83ae704
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
import torch
from torch import nn
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as pl

import mast3r.utils.path_to_dust3r  # noqa
from dust3r.utils.geometry import depthmap_to_pts3d, geotrf, inv
from dust3r.cloud_opt.base_opt import clean_pointcloud


class TSDFPostProcess:
    """ Optimizes a signed distance-function to improve depthmaps.
    """

    def __init__(self, optimizer, subsample=8, TSDF_thresh=0., TSDF_batchsize=int(1e7)):
        self.TSDF_thresh = TSDF_thresh  # None -> no TSDF
        self.TSDF_batchsize = TSDF_batchsize
        self.optimizer = optimizer

        pts3d, depthmaps, confs = optimizer.get_dense_pts3d(clean_depth=False, subsample=subsample)
        pts3d, depthmaps = self._TSDF_postprocess_or_not(pts3d, depthmaps, confs)
        self.pts3d = pts3d
        self.depthmaps = depthmaps
        self.confs = confs

    def _get_depthmaps(self, TSDF_filtering_thresh=None):
        if TSDF_filtering_thresh:
            self._refine_depths_with_TSDF(self.optimizer, TSDF_filtering_thresh)  # compute refined depths if needed
        dms = self.TSDF_im_depthmaps if TSDF_filtering_thresh else self.im_depthmaps
        return [d.exp() for d in dms]

    @torch.no_grad()
    def _refine_depths_with_TSDF(self, TSDF_filtering_thresh, niter=1, nsamples=1000):
        """
        Leverage TSDF to post-process estimated depths:
        for each pixel, find zero level of TSDF along ray (or closest to 0).

        For every image, ``nsamples`` candidate depths are drawn around the
        current depth of each pixel, back-projected to 3D and scored with
        ``_TSDF_query``; the candidate whose |TSDF| is smallest replaces the
        pixel's depth. The search width shrinks linearly over ``niter`` passes.

        Args:
            TSDF_filtering_thresh: base truncation / search threshold.
            niter: number of refinement passes per image.
            nsamples: number of candidate depths per pixel and per pass.

        Side effects:
            Resets ``self.TSDF_im_depthmaps`` and fills it with one refined
            log-depth map per image.
        """
        print("Post-Processing Depths with TSDF fusion.")
        self.TSDF_im_depthmaps = []
        alldepths = self._get_depthmaps()
        allposes = self.optimizer.get_im_poses()
        allfocals = self.optimizer.get_focals()
        allpps = self.optimizer.get_principal_points()
        allimshapes = self.imshapes
        for vi in tqdm(range(self.optimizer.n_imgs)):
            dm, pose, focal, pp, imshape = alldepths[vi], allposes[vi], allfocals[vi], allpps[vi], allimshapes[vi]
            H, W = dm.shape  # constant across passes: dm is only updated in place

            for it in range(niter):
                # decreasing search std along with iterations
                curthresh = (niter - it) * TSDF_filtering_thresh
                # NOTE(review): offsets are (N(0,1) - 1) * curthresh, i.e. centred
                # on -curthresh rather than on 0 -- confirm this bias is intended.
                dm_offsets = (torch.randn(H, W, nsamples).to(dm) - 1.) * curthresh
                newdm = dm[..., None] + dm_offsets  # [H,W,Nsamp]
                curproj = self._backproj_pts3d(in_depths=[newdm], in_im_poses=pose[None], in_focals=focal[None],
                                               in_pps=pp[None], in_imshapes=[imshape])[0]  # [H,W,Nsamp,3]
                # Batched TSDF eval (bounds the memory used by each query)
                curproj = curproj.view(-1, 3)
                tsdf_vals = []
                valids = []
                for start in range(0, len(curproj), self.TSDF_batchsize):
                    # slicing clamps to the tensor length, no explicit min() needed
                    values, valid = self._TSDF_query(curproj[start:start + self.TSDF_batchsize], curthresh)
                    tsdf_vals.append(values)
                    valids.append(valid)
                tsdf_vals = torch.cat(tsdf_vals, dim=0).view([H, W, nsamples])
                valids = torch.cat(valids, dim=0).view([H, W, nsamples])

                # keep depth value that got us the closest to 0
                tsdf_vals[~valids] = torch.inf  # ignore invalid values
                tsdf_vals = tsdf_vals.abs()
                mins = torch.argmin(tsdf_vals, dim=-1, keepdim=True)
                # when all samples live on a very flat zone (every |TSDF| sits
                # exactly on the threshold), do nothing
                allbad = (tsdf_vals == curthresh).all(dim=-1)
                dm[~allbad] = torch.gather(newdm, -1, mins)[..., 0][~allbad]

            # Save refined depth map (stored in log scale, like self.im_depthmaps)
            self.TSDF_im_depthmaps.append(dm.log())

    def _TSDF_query(self, qpoints, TSDF_filtering_thresh, weighted=True):
        """
        TSDF query call: returns the weighted TSDF value for each query point [N, 3]
        """
        N, three = qpoints.shape
        assert three == 3
        qpoints = qpoints[None].repeat(self.optimizer.n_imgs, 1, 1)  # [B,N,3]
        # get projection coordinates and depths onto images
        coords_and_depth = self._proj_pts3d(pts3d=qpoints, cam2worlds=self.optimizer.get_im_poses(
        ), focals=self.optimizer.get_focals(), pps=self.optimizer.get_principal_points())
        image_coords = coords_and_depth[..., :2].round().to(int)  # for now, there's no interpolation...
        proj_depths = coords_and_depth[..., -1]
        # recover depth values after scene optim
        pred_depths, pred_confs, valids = self._get_pixel_depths(image_coords)
        # Gather TSDF scores
        all_SDF_scores = pred_depths - proj_depths  # SDF
        unseen = all_SDF_scores < -TSDF_filtering_thresh  # handle visibility
        # all_TSDF_scores = all_SDF_scores.clip(-TSDF_filtering_thresh,TSDF_filtering_thresh) # SDF -> TSDF
        all_TSDF_scores = all_SDF_scores.clip(-TSDF_filtering_thresh, 1e20)  # SDF -> TSDF
        # Gather TSDF confidences and ignore points that are unseen, either OOB during reproj or too far behind seen depth
        all_TSDF_weights = (~unseen).float() * valids.float()
        if weighted:
            all_TSDF_weights = pred_confs.exp() * all_TSDF_weights
        # Aggregate all votes, ignoring zeros
        TSDF_weights = all_TSDF_weights.sum(dim=0)
        valids = TSDF_weights != 0.
        TSDF_wsum = (all_TSDF_weights * all_TSDF_scores).sum(dim=0)
        TSDF_wsum[valids] /= TSDF_weights[valids]
        return TSDF_wsum, valids

    def _get_pixel_depths(self, image_coords, TSDF_filtering_thresh=None, with_normals_conf=False):
        """ Recover depth value for each input pixel coordinate, along with OOB validity mask
        """
        B, N, two = image_coords.shape
        assert B == self.optimizer.n_imgs and two == 2
        depths = torch.zeros([B, N], device=image_coords.device)
        valids = torch.zeros([B, N], dtype=bool, device=image_coords.device)
        confs = torch.zeros([B, N], device=image_coords.device)
        curconfs = self._get_confs_with_normals() if with_normals_conf else self.im_conf
        for ni, (imc, depth, conf) in enumerate(zip(image_coords, self._get_depthmaps(TSDF_filtering_thresh), curconfs)):
            H, W = depth.shape
            valids[ni] = torch.logical_and(0 <= imc[:, 1], imc[:, 1] <
                                           H) & torch.logical_and(0 <= imc[:, 0], imc[:, 0] < W)
            imc[~valids[ni]] = 0
            depths[ni] = depth[imc[:, 1], imc[:, 0]]
            confs[ni] = conf.cuda()[imc[:, 1], imc[:, 0]]
        return depths, confs, valids

    def _get_confs_with_normals(self):
        outconfs = []
        # Confidence basedf on depth gradient

        class Sobel(nn.Module):
            def __init__(self):
                super().__init__()
                self.filter = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, stride=1, padding=1, bias=False)
                Gx = torch.tensor([[2.0, 0.0, -2.0], [4.0, 0.0, -4.0], [2.0, 0.0, -2.0]])
                Gy = torch.tensor([[2.0, 4.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -4.0, -2.0]])
                G = torch.cat([Gx.unsqueeze(0), Gy.unsqueeze(0)], 0)
                G = G.unsqueeze(1)
                self.filter.weight = nn.Parameter(G, requires_grad=False)

            def forward(self, img):
                x = self.filter(img)
                x = torch.mul(x, x)
                x = torch.sum(x, dim=1, keepdim=True)
                x = torch.sqrt(x)
                return x

        grad_op = Sobel().to(self.im_depthmaps[0].device)
        for conf, depth in zip(self.im_conf, self.im_depthmaps):
            grad_confs = (1. - grad_op(depth[None, None])[0, 0]).clip(0)
            if not 'dbg show':
                pl.imshow(grad_confs.cpu())
                pl.show()
            outconfs.append(conf * grad_confs.to(conf))
        return outconfs

    def _proj_pts3d(self, pts3d, cam2worlds, focals, pps):
        """
        Projection operation: from 3D points to 2D coordinates + depths
        """
        B = pts3d.shape[0]
        assert pts3d.shape[0] == cam2worlds.shape[0]
        # prepare Extrinsincs
        R, t = cam2worlds[:, :3, :3], cam2worlds[:, :3, -1]
        Rinv = R.transpose(-2, -1)
        tinv = -Rinv @ t[..., None]

        # prepare intrinsics
        intrinsics = torch.eye(3).to(cam2worlds)[None].repeat(focals.shape[0], 1, 1)
        if len(focals.shape) == 1:
            focals = torch.stack([focals, focals], dim=-1)
        intrinsics[:, 0, 0] = focals[:, 0]
        intrinsics[:, 1, 1] = focals[:, 1]
        intrinsics[:, :2, -1] = pps
        # Project
        projpts = intrinsics @ (Rinv @ pts3d.transpose(-2, -1) + tinv)  # I(RX+t) : [B,3,N]
        projpts = projpts.transpose(-2, -1)  # [B,N,3]
        projpts[..., :2] /= projpts[..., [-1]]  # [B,N,3] (X/Z , Y/Z, Z)
        return projpts

    def _backproj_pts3d(self, in_depths=None, in_im_poses=None,
                        in_focals=None, in_pps=None, in_imshapes=None):
        """
        Back-project per-image depth maps to world-space 3D points.

        Any argument left as None is fetched from the optimizer (focals,
        poses, principal points) or from this object (depth maps, image
        shapes).

        Returns:
            A list with one point tensor per pose, obtained by lifting the
            depth map with ``depthmap_to_pts3d`` and applying the pose with
            ``geotrf``.
        """
        # Resolve defaults (same lookup order as the parameters are consumed)
        focals = in_focals if in_focals is not None else self.optimizer.get_focals()
        im_poses = in_im_poses if in_im_poses is not None else self.optimizer.get_im_poses()
        depth = in_depths if in_depths is not None else self._get_depthmaps()
        pp = in_pps if in_pps is not None else self.optimizer.get_principal_points()
        imshapes = in_imshapes if in_imshapes is not None else self.imshapes

        world_pts = []
        for i in range(im_poses.shape[0]):
            # Broadcast the focal of image i to a per-pixel map with a leading batch dim
            focal_map = focals[i][..., None, None].expand(1, *focals[i].shape, *imshapes[i])
            cam_pts = depthmap_to_pts3d(depth[i][None], focal_map, pp=pp[[i]])[0]  # drop batch dim
            # 4-D results get their last two axes swapped before the pose is applied
            # (presumably the multi-sample-per-pixel case -- TODO confirm layout)
            if cam_pts.ndim == 4:
                cam_pts = cam_pts.transpose(-2, -1)
            world_pts.append(geotrf(im_poses[i], cam_pts))
        return world_pts

    def _pts3d_to_depth(self, pts3d, cam2worlds, focals, pps):
        """
        Projection operation: from 3D points to 2D coordinates + depths
        """
        B = pts3d.shape[0]
        assert pts3d.shape[0] == cam2worlds.shape[0]
        # prepare Extrinsincs
        R, t = cam2worlds[:, :3, :3], cam2worlds[:, :3, -1]
        Rinv = R.transpose(-2, -1)
        tinv = -Rinv @ t[..., None]

        # prepare intrinsics
        intrinsics = torch.eye(3).to(cam2worlds)[None].repeat(self.optimizer.n_imgs, 1, 1)
        if len(focals.shape) == 1:
            focals = torch.stack([focals, focals], dim=-1)
        intrinsics[:, 0, 0] = focals[:, 0]
        intrinsics[:, 1, 1] = focals[:, 1]
        intrinsics[:, :2, -1] = pps
        # Project
        projpts = intrinsics @ (Rinv @ pts3d.transpose(-2, -1) + tinv)  # I(RX+t) : [B,3,N]
        projpts = projpts.transpose(-2, -1)  # [B,N,3]
        projpts[..., :2] /= projpts[..., [-1]]  # [B,N,3] (X/Z , Y/Z, Z)
        return projpts

    def _depth_to_pts3d(self, in_depths=None, in_im_poses=None, in_focals=None, in_pps=None, in_imshapes=None):
        """
        Backprojection operation: from image depths to 3D points
        """
        # Get depths and  projection params if not provided
        focals = self.optimizer.get_focals() if in_focals is None else in_focals
        im_poses = self.optimizer.get_im_poses() if in_im_poses is None else in_im_poses
        depth = self._get_depthmaps() if in_depths is None else in_depths
        pp = self.optimizer.get_principal_points() if in_pps is None else in_pps
        imshapes = self.imshapes if in_imshapes is None else in_imshapes

        def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *imshapes[i])

        dm_to_3d = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[i:i + 1]) for i in range(im_poses.shape[0])]

        def autoprocess(x):
            x = x[0]
            H, W, three = x.shape[:3]
            return x.transpose(-2, -1) if len(x.shape) == 4 else x
        return [geotrf(pp, autoprocess(pt)) for pp, pt in zip(im_poses, dm_to_3d)]

    def _get_pts3d(self, TSDF_filtering_thresh=None, **kw):
        """
        return 3D points (possibly filtering depths with TSDF) 
        """
        return self._backproj_pts3d(in_depths=self._get_depthmaps(TSDF_filtering_thresh=TSDF_filtering_thresh), **kw)

    def _TSDF_postprocess_or_not(self, pts3d, depthmaps, confs, niter=1):
        # Setup inner variables
        self.imshapes = [im.shape[:2] for im in self.optimizer.imgs]
        self.im_depthmaps = [dd.log().view(imshape) for dd, imshape in zip(depthmaps, self.imshapes)]
        self.im_conf = confs

        if self.TSDF_thresh > 0.:
            # Create or update self.TSDF_im_depthmaps that contain logdepths filtered with TSDF
            self._refine_depths_with_TSDF(self.TSDF_thresh, niter=niter)
            depthmaps = [dd.exp() for dd in self.TSDF_im_depthmaps]
            # Turn them into 3D points
            pts3d = self._backproj_pts3d(in_depths=depthmaps)
            depthmaps = [dd.flatten() for dd in depthmaps]
            pts3d = [pp.view(-1, 3) for pp in pts3d]
        return pts3d, depthmaps

    def get_dense_pts3d(self, clean_depth=True):
        if clean_depth:
            confs = clean_pointcloud(self.confs, self.optimizer.intrinsics, inv(self.optimizer.cam2w),
                                     self.depthmaps, self.pts3d)
        return self.pts3d, self.depthmaps, confs