yocabon commited on
Commit
15cb3c1
·
unverified ·
1 Parent(s): 67aa04c

Update sparse_ga.py

Browse files
Files changed (1) hide show
  1. mast3r/cloud_opt/sparse_ga.py +6 -5
mast3r/cloud_opt/sparse_ga.py CHANGED
@@ -311,10 +311,8 @@ def sparse_scene_optimizer(imgs, subsample, imsizes, pps, base_focals, core_dept
311
  for imk, imv in preds_21.items():
312
  subsamp_preds_21[imk] = {}
313
  for im2k, (pred, conf) in preds_21[imk].items():
314
- subpred = pred[::subsample, ::subsample].reshape(-1, 3) # original subsample
315
- subconf = conf[::subsample, ::subsample].ravel() # for both ptmaps and confs
316
  idxs = anchors[imgs.index(im2k)][1]
317
- subsamp_preds_21[imk][im2k] = (subpred[idxs], subconf[idxs]) # anchors subsample
318
 
319
  # Prepare slices and corres for losses
320
  dust3r_slices = [s for s in imgs_slices if not is_matching_ok[s.img1, s.img2]]
@@ -344,6 +342,8 @@ def sparse_scene_optimizer(imgs, subsample, imsizes, pps, base_focals, core_dept
344
  loss = 0.
345
  cf_sum = 0.
346
  for s in dust3r_slices:
 
 
347
  # fallback to dust3r regression
348
  tgt_pts, tgt_confs = subsamp_preds_21[imgs[s.img2]][imgs[s.img1]]
349
  tgt_pts = geotrf(cam2w[s.img2], tgt_pts)
@@ -669,7 +669,8 @@ def prepare_canonical_data(imgs, tmp_pairs, subsample, order_imgs=False, min_con
669
  pixels[img2] = xy1, confs
670
  if img not in preds_21:
671
  preds_21[img] = {}
672
- preds_21[img][img2] = X2, C2
 
673
 
674
  if img == img2:
675
  X, C, X2, C2 = torch.load(path2, map_location=device)
@@ -677,7 +678,7 @@ def prepare_canonical_data(imgs, tmp_pairs, subsample, order_imgs=False, min_con
677
  pixels[img1] = xy2, confs
678
  if img not in preds_21:
679
  preds_21[img] = {}
680
- preds_21[img][img1] = X2, C2
681
 
682
  if score is not None:
683
  i, j = imgs.index(img1), imgs.index(img2)
 
311
  for imk, imv in preds_21.items():
312
  subsamp_preds_21[imk] = {}
313
  for im2k, (pred, conf) in preds_21[imk].items():
 
 
314
  idxs = anchors[imgs.index(im2k)][1]
315
+ subsamp_preds_21[imk][im2k] = (pred[idxs], conf[idxs]) # anchors subsample
316
 
317
  # Prepare slices and corres for losses
318
  dust3r_slices = [s for s in imgs_slices if not is_matching_ok[s.img1, s.img2]]
 
342
  loss = 0.
343
  cf_sum = 0.
344
  for s in dust3r_slices:
345
+ if init[imgs[s.img1]].get('freeze') and init[imgs[s.img2]].get('freeze'):
346
+ continue
347
  # fallback to dust3r regression
348
  tgt_pts, tgt_confs = subsamp_preds_21[imgs[s.img2]][imgs[s.img1]]
349
  tgt_pts = geotrf(cam2w[s.img2], tgt_pts)
 
669
  pixels[img2] = xy1, confs
670
  if img not in preds_21:
671
  preds_21[img] = {}
672
+ # Subsample preds_21
673
+ preds_21[img][img2] = X2[::subsample, ::subsample].reshape(-1, 3), C2[::subsample, ::subsample].ravel()
674
 
675
  if img == img2:
676
  X, C, X2, C2 = torch.load(path2, map_location=device)
 
678
  pixels[img1] = xy2, confs
679
  if img not in preds_21:
680
  preds_21[img] = {}
681
+ preds_21[img][img1] = X2[::subsample, ::subsample].reshape(-1, 3), C2[::subsample, ::subsample].ravel()
682
 
683
  if score is not None:
684
  i, j = imgs.index(img1), imgs.index(img2)