import torch
import torch.nn.functional as F
import numpy as np


def ppts_to_pts(pts, bw, A):
    """Transform points from the pose space to the T pose.

    Each point is mapped through the inverse of its blended transform. The
    per-joint 4x4 matrices in ``A`` are mixed with the point's blend weights.
    The blended translation is then subtracted, and the inverse of the
    blended 3x3 block is applied.

    Args:
        pts: posed points, shape (B, N, 3).
        bw: blend weights, shape (B, J, N). Presumably each point's weights
            sum to 1 -- TODO confirm against callers.
        A: per-joint 4x4 transforms; anything that reshapes to (B, J, 16),
            e.g. (B, J, 4, 4).

    Returns:
        Points in the T pose, shape (B, N, 3).
    """
    batch_size = pts.shape[0]
    # The joint count used to be hard-coded to 24; derive it from the weights
    # so other skeleton sizes also work.
    n_joints = bw.shape[1]
    bw = bw.permute(0, 2, 1)  # (B, N, J)
    # Blend the flattened per-joint matrices: (B, N, J) @ (B, J, 16).
    A = torch.bmm(bw, A.reshape(batch_size, n_joints, -1))
    A = A.view(batch_size, -1, 4, 4)  # (B, N, 4, 4)
    pts = pts - A[..., :3, 3]
    # A weighted sum of rotations is generally not a rotation, so a true
    # inverse is required (a transpose would be wrong).
    R_inv = torch.inverse(A[..., :3, :3])
    # Batched matrix-vector product R_inv @ pts, per point.
    pts = torch.sum(R_inv * pts[:, :, None], dim=3)
    return pts


def grid_sample_blend_weights(grid_coords, bw):
    """Sample a volume at normalized grid coordinates.

    Args:
        grid_coords: (B, N, 3) coordinates already normalized to [-1, 1] and
            ordered as ``F.grid_sample`` expects (the first component indexes
            the last spatial axis of ``bw``).
        bw: 5-D volume, shape (B, C, D, H, W).

    Returns:
        Sampled values, shape (B, C, N). Out-of-range coordinates are clamped
        to the border.
    """
    # grid_sample on a 5-D input needs a 5-D grid: (B, 1, 1, N, 3).
    grid_coords = grid_coords[:, None, None]
    bw = F.grid_sample(bw,
                       grid_coords,
                       padding_mode='border',
                       align_corners=True)
    # Output is (B, C, 1, 1, N); drop the two singleton spatial dims.
    bw = bw[:, :, 0, 0]
    return bw


def bounds_grid_sample_blend_weights(pts, bw, bounds):
    """Sample a blend-weight volume at world-space points inside ``bounds``.

    Args:
        pts: (B, N, 3) points in xyz order.
        bw: (B, X, Y, Z, C) volume indexed by x, y, z with channels last.
        bounds: (B, 2, 3) tensor; ``bounds[:, 0]`` is the min xyz corner and
            ``bounds[:, 1]`` is the max xyz corner of the volume.

    Returns:
        Sampled weights, shape (B, C, N).
    """
    min_xyz = bounds[:, 0]
    max_xyz = bounds[:, 1]
    extent = max_xyz[:, None] - min_xyz[:, None]
    # Map the box onto grid_sample's [-1, 1] range (align_corners=True puts
    # the min/max corners on the first/last voxel centers).
    grid_coords = (pts - min_xyz[:, None]) / extent
    grid_coords = grid_coords * 2 - 1
    # The volume is indexed [x][y][z], but grid_sample's first coordinate
    # addresses the LAST spatial axis, so reorder xyz -> zyx.
    grid_coords = grid_coords[..., [2, 1, 0]]
    # Move channels first: (B, C, X, Y, Z).
    bw = bw.permute(0, 4, 1, 2, 3)
    return grid_sample_blend_weights(grid_coords, bw)


def grid_sample_A_blend_weights(nf_grid_coords, bw):
    """Sample each joint's weight volume at that joint's own coordinates.

    Args:
        nf_grid_coords: batch_size x N_samples x J x 3, normalized to [-1, 1];
            entry ``[:, :, j]`` is used to sample channel ``j`` of ``bw``.
        bw: batch_size x J x D x H x W (e.g. J=24, 64^3 volumes).

    Returns:
        batch_size x J x N_samples sampled weights.
    """
    batch_size, n_joints = bw.shape[:2]
    n_samples = nf_grid_coords.shape[1]
    # Fold joints into the batch dimension so every joint's single-channel
    # volume is sampled with its own grid in one grid_sample call.
    volumes = bw.reshape(batch_size * n_joints, 1, *bw.shape[2:])
    grids = nf_grid_coords.permute(0, 2, 1, 3).reshape(
        batch_size * n_joints, 1, 1, n_samples, 3)
    sampled = F.grid_sample(volumes,
                            grids,
                            padding_mode='border',
                            align_corners=True)
    # sampled: (B*J, 1, 1, 1, N) -> (B, J, N)
    return sampled.reshape(batch_size, n_joints, n_samples)