# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

import torch


def index(feat, uv):
    '''
    Sample features at the given normalized coordinates.

    :param feat: [B, C, H, W] image features, or [B, C, D, H, W] volume
        features when uv carries three coordinates per point
    :param uv: [B, 2, N] (or [B, 3, N]) coordinates, normalized to the
        range [-1, 1] as expected by torch.nn.functional.grid_sample
    :return: [B, C, N] features at the uv coordinates
    '''
    uv = uv.transpose(1, 2)  # [B, N, 2] or [B, N, 3]
    (B, N, _) = uv.shape
    C = feat.shape[1]

    if uv.shape[-1] == 3:
        # Optional axis re-ordering / flipping that was left disabled:
        # uv = uv[:,:,[2,1,0]]
        # uv = uv * torch.tensor([1.0,-1.0,1.0]).type_as(uv)[None,None,...]
        uv = uv.unsqueeze(2).unsqueeze(3)  # [B, N, 1, 1, 3]
    else:
        uv = uv.unsqueeze(2)  # [B, N, 1, 2]

    # NOTE: for newer PyTorch, it seems that training results are degraded
    # due to implementation diff in F.grid_sample.
    # For old versions, simply remove the align_corners argument.
    samples = torch.nn.functional.grid_sample(
        feat, uv, align_corners=True)  # [B, C, N, 1] (or [B, C, N, 1, 1])
    # Alternative for the 2D case, using the pure-PyTorch sampler below:
    # samples = grid_sample(feat, uv)  # [B, C, N, 1]
    return samples.view(B, C, N)  # [B, C, N]


def grid_sample(image, optical):
    '''
    Bilinear sampling of a 2D feature map written with basic tensor ops.

    Grid coordinates are mapped to pixels with the align_corners=True
    convention (-1 -> 0, +1 -> size - 1). Interpolation weights are computed
    from the unclamped corner positions; corner indices are clamped into the
    image afterwards so that the gather never reads out of bounds.

    :param image: [N, C, IH, IW] feature map
    :param optical: [N, H, W, 2] sampling grid, (x, y) in the last dimension
    :return: [N, C, H, W] sampled features
    '''
    N, C, IH, IW = image.shape
    _, H, W, _ = optical.shape

    ix = optical[..., 0]
    iy = optical[..., 1]

    # Normalized [-1, 1] -> pixel coordinates.
    ix = ((ix + 1) / 2) * (IW - 1)
    iy = ((iy + 1) / 2) * (IH - 1)

    # Corner positions are piecewise constant, so no gradient flows through.
    with torch.no_grad():
        ix_nw = torch.floor(ix)
        iy_nw = torch.floor(iy)
        ix_ne = ix_nw + 1
        iy_ne = iy_nw
        ix_sw = ix_nw
        iy_sw = iy_nw + 1
        ix_se = ix_nw + 1
        iy_se = iy_nw + 1

    # Bilinear weights: each corner is weighted by the area of the opposite
    # rectangle. Must be computed before the corner indices are clamped.
    nw = (ix_se - ix) * (iy_se - iy)
    ne = (ix - ix_sw) * (iy_sw - iy)
    sw = (ix_ne - ix) * (iy - iy_ne)
    se = (ix - ix_nw) * (iy - iy_nw)

    # Clamp indices in place (some names alias the same tensor; clamping is
    # idempotent, so clamping an alias twice is harmless).
    with torch.no_grad():
        torch.clamp(ix_nw, 0, IW - 1, out=ix_nw)
        torch.clamp(iy_nw, 0, IH - 1, out=iy_nw)

        torch.clamp(ix_ne, 0, IW - 1, out=ix_ne)
        torch.clamp(iy_ne, 0, IH - 1, out=iy_ne)

        torch.clamp(ix_sw, 0, IW - 1, out=ix_sw)
        torch.clamp(iy_sw, 0, IH - 1, out=iy_sw)

        torch.clamp(ix_se, 0, IW - 1, out=ix_se)
        torch.clamp(iy_se, 0, IH - 1, out=iy_se)

    # Flatten the spatial dimensions so each corner is one gather.
    image = image.reshape(N, C, IH * IW)

    nw_val = torch.gather(
        image, 2,
        (iy_nw * IW + ix_nw).long().reshape(N, 1, H * W).repeat(1, C, 1))
    ne_val = torch.gather(
        image, 2,
        (iy_ne * IW + ix_ne).long().reshape(N, 1, H * W).repeat(1, C, 1))
    sw_val = torch.gather(
        image, 2,
        (iy_sw * IW + ix_sw).long().reshape(N, 1, H * W).repeat(1, C, 1))
    se_val = torch.gather(
        image, 2,
        (iy_se * IW + ix_se).long().reshape(N, 1, H * W).repeat(1, C, 1))

    out_val = (nw_val.view(N, C, H, W) * nw.view(N, 1, H, W) +
               ne_val.view(N, C, H, W) * ne.view(N, 1, H, W) +
               sw_val.view(N, C, H, W) * sw.view(N, 1, H, W) +
               se_val.view(N, C, H, W) * se.view(N, 1, H, W))

    return out_val


def _apply_image_transform(xy, transforms):
    '''
    Apply a batched 2D affine image transform to xy coordinates.

    :param xy: [B, 2, N] Tensor of xy coordinates
    :param transforms: [B, 2, 3] Tensor; [:, :, :2] is the linear part and
        [:, :, 2:3] the translation
    :return: [B, 2, N] Tensor of transformed xy coordinates
    '''
    # Slice along the row/column axes; the leading axis is the batch.
    scale = transforms[:, :2, :2]
    shift = transforms[:, :2, 2:3]
    return torch.baddbmm(shift, scale, xy)


def orthogonal(points, calibrations, transforms=None):
    '''
    Compute the orthogonal projections of 3D points into the image plane
    by given projection matrix
    :param points: [B, 3, N] Tensor of 3D points
    :param calibrations: [B, 3, 4] Tensor of projection matrix
    :param transforms: [B, 2, 3] Tensor of image transform matrix
    :return: xyz: [B, 3, N] Tensor of xyz coordinates in the image plane
    '''
    rot = calibrations[:, :3, :3]
    trans = calibrations[:, :3, 3:4]
    pts = torch.baddbmm(trans, rot, points)  # [B, 3, N]
    if transforms is not None:
        xy = _apply_image_transform(pts[:, :2, :], transforms)
        # Build a new tensor instead of writing into pts in place, since
        # pts is itself an input of the transform above.
        pts = torch.cat([xy, pts[:, 2:3, :]], 1)
    return pts


def perspective(points, calibrations, transforms=None):
    '''
    Compute the perspective projections of 3D points into the image plane
    by given projection matrix
    :param points: [Bx3xN] Tensor of 3D points
    :param calibrations: [Bx3x4] Tensor of projection matrix
    :param transforms: [Bx2x3] Tensor of image transform matrix
    :return: xyz: [Bx3xN] Tensor; channels 0-1 are the xy coordinates in
        the image plane, channel 2 is the depth before the perspective divide
    '''
    rot = calibrations[:, :3, :3]
    trans = calibrations[:, :3, 3:4]
    homo = torch.baddbmm(trans, rot, points)  # [B, 3, N]
    xy = homo[:, :2, :] / homo[:, 2:3, :]
    if transforms is not None:
        xy = _apply_image_transform(xy, transforms)

    xyz = torch.cat([xy, homo[:, 2:3, :]], 1)
    return xyz