Update sparse_ga.py
Browse files
mast3r/cloud_opt/sparse_ga.py
CHANGED
@@ -311,10 +311,8 @@ def sparse_scene_optimizer(imgs, subsample, imsizes, pps, base_focals, core_dept
|
|
311 |
for imk, imv in preds_21.items():
|
312 |
subsamp_preds_21[imk] = {}
|
313 |
for im2k, (pred, conf) in preds_21[imk].items():
|
314 |
-
subpred = pred[::subsample, ::subsample].reshape(-1, 3) # original subsample
|
315 |
-
subconf = conf[::subsample, ::subsample].ravel() # for both ptmaps and confs
|
316 |
idxs = anchors[imgs.index(im2k)][1]
|
317 |
-
subsamp_preds_21[imk][im2k] = (
|
318 |
|
319 |
# Prepare slices and corres for losses
|
320 |
dust3r_slices = [s for s in imgs_slices if not is_matching_ok[s.img1, s.img2]]
|
@@ -344,6 +342,8 @@ def sparse_scene_optimizer(imgs, subsample, imsizes, pps, base_focals, core_dept
|
|
344 |
loss = 0.
|
345 |
cf_sum = 0.
|
346 |
for s in dust3r_slices:
|
|
|
|
|
347 |
# fallback to dust3r regression
|
348 |
tgt_pts, tgt_confs = subsamp_preds_21[imgs[s.img2]][imgs[s.img1]]
|
349 |
tgt_pts = geotrf(cam2w[s.img2], tgt_pts)
|
@@ -669,7 +669,8 @@ def prepare_canonical_data(imgs, tmp_pairs, subsample, order_imgs=False, min_con
|
|
669 |
pixels[img2] = xy1, confs
|
670 |
if img not in preds_21:
|
671 |
preds_21[img] = {}
|
672 |
-
|
|
|
673 |
|
674 |
if img == img2:
|
675 |
X, C, X2, C2 = torch.load(path2, map_location=device)
|
@@ -677,7 +678,7 @@ def prepare_canonical_data(imgs, tmp_pairs, subsample, order_imgs=False, min_con
|
|
677 |
pixels[img1] = xy2, confs
|
678 |
if img not in preds_21:
|
679 |
preds_21[img] = {}
|
680 |
-
preds_21[img][img1] = X2, C2
|
681 |
|
682 |
if score is not None:
|
683 |
i, j = imgs.index(img1), imgs.index(img2)
|
|
|
311 |
for imk, imv in preds_21.items():
|
312 |
subsamp_preds_21[imk] = {}
|
313 |
for im2k, (pred, conf) in preds_21[imk].items():
|
|
|
|
|
314 |
idxs = anchors[imgs.index(im2k)][1]
|
315 |
+
subsamp_preds_21[imk][im2k] = (pred[idxs], conf[idxs]) # anchors subsample
|
316 |
|
317 |
# Prepare slices and corres for losses
|
318 |
dust3r_slices = [s for s in imgs_slices if not is_matching_ok[s.img1, s.img2]]
|
|
|
342 |
loss = 0.
|
343 |
cf_sum = 0.
|
344 |
for s in dust3r_slices:
|
345 |
+
if init[imgs[s.img1]].get('freeze') and init[imgs[s.img2]].get('freeze'):
|
346 |
+
continue
|
347 |
# fallback to dust3r regression
|
348 |
tgt_pts, tgt_confs = subsamp_preds_21[imgs[s.img2]][imgs[s.img1]]
|
349 |
tgt_pts = geotrf(cam2w[s.img2], tgt_pts)
|
|
|
669 |
pixels[img2] = xy1, confs
|
670 |
if img not in preds_21:
|
671 |
preds_21[img] = {}
|
672 |
+
# Subsample preds_21
|
673 |
+
preds_21[img][img2] = X2[::subsample, ::subsample].reshape(-1, 3), C2[::subsample, ::subsample].ravel()
|
674 |
|
675 |
if img == img2:
|
676 |
X, C, X2, C2 = torch.load(path2, map_location=device)
|
|
|
678 |
pixels[img1] = xy2, confs
|
679 |
if img not in preds_21:
|
680 |
preds_21[img] = {}
|
681 |
+
preds_21[img][img1] = X2[::subsample, ::subsample].reshape(-1, 3), C2[::subsample, ::subsample].ravel()
|
682 |
|
683 |
if score is not None:
|
684 |
i, j = imgs.index(img1), imgs.index(img2)
|