|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import torch |
|
|
|
|
|
def batched_triangulate(pts2d,
                        proj_mats):
    """Triangulate batches of 2D correspondences into 3D points via DLT.

    Args:
        pts2d: (B, Ncams, Npts, 2) pixel coordinates of each point as seen
            by each camera.
        proj_mats: (B, Ncams, 3, 4) projection matrices (intrinsics applied
            to the 3x4 extrinsics).

    Returns:
        (B, Npts, 3) least-squares 3D points in world coordinates. Entries
        may be non-finite where the linear system is rank-deficient.
    """
    B, Ncams, Npts, two = pts2d.shape
    assert two==2
    assert proj_mats.shape == (B, Ncams, 3, 4)

    # Rows of each projection matrix, broadcast over the points dimension.
    row0 = proj_mats[:, :, 0, :].unsqueeze(2)   # (B, Ncams, 1, 4)
    row1 = proj_mats[:, :, 1, :].unsqueeze(2)
    row2 = proj_mats[:, :, 2, :].unsqueeze(2)
    u = pts2d[..., 0].unsqueeze(-1)             # (B, Ncams, Npts, 1)
    v = pts2d[..., 1].unsqueeze(-1)

    # DLT constraints per camera: (P0 - u*P2) X = 0 and (P1 - v*P2) X = 0.
    eq_u = row0 - u * row2                      # (B, Ncams, Npts, 4)
    eq_v = row1 - v * row2

    # Stack both constraints from every camera, grouped per point.
    system = torch.cat([eq_u, eq_v], dim=1).transpose(1, 2)  # (B, Npts, 2*Ncams, 4)

    # Solve A[..., :3] @ X = -A[..., 3] in the least-squares sense
    # (homogeneous coordinate fixed to 1).
    return torch.linalg.lstsq(system[..., :3], -system[..., 3]).solution
|
|
|
def matches_to_depths(intrinsics,
                      extrinsics,
                      matches,
                      batchsize=16,
                      min_num_valids_ratio=.3
                      ):
    """Triangulate dense pairwise matches into per-pixel 3D points and depths.

    Each of the Nv match maps pairs the shared reference view (camera 0) with
    one secondary view (cameras 1..Nv). Every pixel of every pair is
    triangulated as a 2-camera system, then aggregated across pairs by a
    confidence-weighted average in world space.

    Args:
        intrinsics: (B, Nv+1, 3, 3) camera intrinsics; index 0 is the
            reference view.
        extrinsics: (B, Nv+1, 4, 4) world-to-camera transforms (only the top
            3x4 rows are used).
        matches: (B, Nv, H, W, 5) per-pixel matches. Channels [0:2] are pixel
            coordinates in the reference view, [2:4] in the secondary view,
            and [4] is the match confidence.
        batchsize: number of scenes processed per chunk (memory knob only;
            results are independent of it).
        min_num_valids_ratio: pixels with fewer than floor(Nv * ratio) valid
            triangulations get NaN aggregated points.

    Returns:
        Tuple of CPU tensors:
        - aggregated_points: (B, H, W, 3) confidence-weighted world points
          (NaN where too few pairs triangulated successfully).
        - depths: (B, H, W) z of the aggregated points in the reference
          camera's frame.
        - confs: (B, H, W) per-pixel sum of (validity-masked) confidences
          over the Nv pairs.
    """
    B, Nv, H, W, five = matches.shape
    min_num_valids = np.floor(Nv*min_num_valids_ratio)
    out_aggregated_points, out_depths, out_confs = [], [], []
    for start in range(0, B, batchsize):  # chunk over scenes to bound memory
        stop = min(B, start+batchsize)
        sub_batch = slice(start, stop)
        sub_batchsize = stop-start
        points1 = matches[sub_batch, ..., :2]
        points2 = matches[sub_batch, ..., 2:4]
        # BUGFIX: basic indexing returns a *view* of `matches`; clone so that
        # zeroing invalid confidences below does not mutate the caller's input.
        confs = matches[sub_batch, ..., -1].clone()

        # Fold (scene, pair) into one batch dim: each pair is a 2-camera system.
        allpoints = torch.cat([points1.view([sub_batchsize*Nv, 1, H*W, 2]),
                               points2.view([sub_batchsize*Nv, 1, H*W, 2])], dim=1)

        # Full projection matrices P = K @ [R|t] for every camera.
        allcam_Ps = intrinsics[sub_batch] @ extrinsics[sub_batch, :, :3, :]
        # Pair the reference camera (index 0) with each secondary camera.
        cam_Ps1, cam_Ps2 = allcam_Ps[:, [0]].repeat([1, Nv, 1, 1]), allcam_Ps[:, 1:]
        formatted_camPs = torch.cat([cam_Ps1.reshape([sub_batchsize*Nv, 1, 3, 4]),
                                     cam_Ps2.reshape([sub_batchsize*Nv, 1, 3, 4])], dim=1)

        # Triangulate every pixel of every pair independently.
        points_3d_world = batched_triangulate(allpoints, formatted_camPs)
        points_3d_world = points_3d_world.view([sub_batchsize, Nv, H, W, 3])

        # lstsq produces non-finite coords where the system is degenerate; a
        # pixel must be fully valid (3 finite coords) or fully invalid (0).
        valids = points_3d_world.isfinite()
        valids_sum = valids.sum(dim=-1)
        validsuni = valids_sum.unique()
        assert torch.all(torch.logical_or(validsuni == 0, validsuni == 3)), "Error, can only be nan for none or all XYZ values, not a subset"

        # Confidence-weighted average over pairs; invalid pairs contribute 0.
        confs[valids_sum == 0] = 0.
        points_3d_world = points_3d_world*confs[..., None]
        normalization = confs.sum(dim=1)[:, None].repeat(1, Nv, 1, 1)
        normalization[normalization <= 1e-5] = 1.  # avoid div-by-zero on all-invalid pixels
        points_3d_world[valids] /= normalization[valids_sum == 3][:, None].repeat(1, 3).view(-1)
        points_3d_world[~valids] = 0.
        aggregated_points = points_3d_world.sum(dim=1)

        # Reject pixels supported by too few successful triangulations.
        aggregated_points[valids_sum.sum(dim=1)/3 <= min_num_valids] = torch.nan

        # Depth = z coordinate of the aggregated point in the reference camera.
        refcamE = extrinsics[sub_batch, 0]
        points_3d_camera = (refcamE[:, :3, :3] @ aggregated_points.view(sub_batchsize, -1, 3).transpose(-2, -1)
                            + refcamE[:, :3, [3]]).transpose(-2, -1)
        depths = points_3d_camera.view(sub_batchsize, H, W, 3)[..., 2]

        # Move chunk results off the GPU as we go to keep peak memory low.
        out_aggregated_points.append(aggregated_points.cpu())
        out_depths.append(depths.cpu())
        out_confs.append(confs.sum(dim=1).cpu())

    out_aggregated_points = torch.cat(out_aggregated_points, dim=0)
    out_depths = torch.cat(out_depths, dim=0)
    out_confs = torch.cat(out_confs, dim=0)

    return out_aggregated_points, out_depths, out_confs
|
|