Spaces:
Build error
Build error
import torch | |
def index(feat, uv):
    '''
    Bilinearly sample per-point features from an image feature map.

    :param feat: [B, C, H, W] image features
    :param uv: [B, 2, N] uv coordinates in the image plane, range [-1, 1]
    :return: [B, C, N] image features at the uv coordinates
    '''
    # grid_sample wants a sampling grid of shape [B, H_out, W_out, 2]; the N
    # query points are laid out as an N x 1 output "image".
    grid = uv.permute(0, 2, 1).unsqueeze(2)  # [B, N, 1, 2]
    # NOTE: align_corners=True is passed explicitly; the original author
    # observed degraded training results with newer PyTorch caused by an
    # implementation difference in F.grid_sample. For old PyTorch versions
    # that lack the argument, simply remove align_corners.
    sampled = torch.nn.functional.grid_sample(feat, grid, align_corners=True)  # [B, C, N, 1]
    # Drop the singleton output-width axis.
    return sampled.squeeze(3)  # [B, C, N]
def orthogonal(points, calibrations, transforms=None):
    '''
    Compute the orthogonal projections of 3D points into the image plane by given projection matrix.

    The top [3, 4] block of each calibration matrix is applied as ``R @ p + t``;
    no perspective division is performed.

    :param points: [B, 3, N] Tensor of 3D points
    :param calibrations: [B, 4, 4] Tensor of projection matrix (only rows 0-2 are used)
    :param transforms: optional [B, 2, 3] Tensor of image transform matrix, applied
        per batch item as a 2D affine map to the projected x, y coordinates
    :return: xyz: [B, 3, N] Tensor of xyz coordinates in the image plane
    '''
    rot = calibrations[:, :3, :3]  # [B, 3, 3]
    trans = calibrations[:, :3, 3:4]  # [B, 3, 1]
    pts = torch.baddbmm(trans, rot, points)  # [B, 3, N]
    if transforms is not None:
        # transforms is [B, 2, 3]: keep the batch axis and slice rows/columns.
        # (Slicing as transforms[:2, ...] would cut the batch axis instead.)
        scale = transforms[:, :2, :2]  # [B, 2, 2]
        shift = transforms[:, :2, 2:3]  # [B, 2, 1]
        xy = torch.baddbmm(shift, scale, pts[:, :2, :])  # [B, 2, N]
        # Build a new tensor rather than writing into pts in place, so the
        # view of pts consumed by baddbmm above is not modified afterwards.
        pts = torch.cat([xy, pts[:, 2:3, :]], 1)  # [B, 3, N]
    return pts
def perspective(points, calibrations, transforms=None):
    '''
    Compute the perspective projections of 3D points into the image plane by given projection matrix.

    The top [3, 4] block of each calibration matrix is applied as ``R @ p + t``.
    The resulting x, y are then divided by the resulting z.

    :param points: [B, 3, N] Tensor of 3D points
    :param calibrations: [B, 4, 4] Tensor of projection matrix (only rows 0-2 are used)
    :param transforms: optional [B, 2, 3] Tensor of image transform matrix, applied
        per batch item as a 2D affine map to the projected x, y coordinates
    :return: xyz: [B, 3, N] Tensor; rows 0-1 are the (transformed) image-plane x, y
        after division by depth, row 2 is the depth (z before division)
    '''
    rot = calibrations[:, :3, :3]  # [B, 3, 3]
    trans = calibrations[:, :3, 3:4]  # [B, 3, 1]
    homo = torch.baddbmm(trans, rot, points)  # [B, 3, N]
    # No guard against zero depth: points with z == 0 produce inf/nan here.
    xy = homo[:, :2, :] / homo[:, 2:3, :]  # [B, 2, N]
    if transforms is not None:
        # transforms is [B, 2, 3]: keep the batch axis and slice rows/columns.
        # (Slicing as transforms[:2, ...] would cut the batch axis instead.)
        scale = transforms[:, :2, :2]  # [B, 2, 2]
        shift = transforms[:, :2, 2:3]  # [B, 2, 1]
        xy = torch.baddbmm(shift, scale, xy)
    xyz = torch.cat([xy, homo[:, 2:3, :]], 1)  # [B, 3, N]
    return xyz