|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
from util.flexicubes import FlexiCubes
|
|
|
|
import torch.nn.functional as F
|
|
|
|
def get_center_boundary_index(grid_res, device):
    """Return flat indices of the grid's "center" vertex and of its boundary shell.

    The grid has ``grid_res + 1`` vertices per axis; indices are positions in the
    row-major flattening of that ``(grid_res + 1)**3`` vertex array.
    NOTE(review): this assumes the flattening order matches the vertex order produced by
    ``FlexiCubes.construct_voxel_grid`` -- confirm against that implementation.

    Args:
        grid_res: number of cubes per axis.
        device: torch device on which the masks (and returned indices) live.

    Returns:
        ``(center_indices, boundary_indices)``, both ``torch.nonzero`` results of shape
        ``(K, 1)``:
          * center: the single vertex at ``(grid_res // 2 + 1)`` on every axis (one step
            past ``grid_res // 2``; kept as-is because callers may rely on it);
          * boundary: every vertex lying in the two outermost layers on any face.
    """
    side = grid_res + 1
    mid = grid_res // 2 + 1

    center_mask = torch.zeros((side, side, side), dtype=torch.bool, device=device)
    center_mask[mid, mid, mid] = True
    center_indices = torch.nonzero(center_mask.reshape(-1))

    # Two-vertex-thick shell: first two and last two slabs along each axis.
    shell = torch.zeros_like(center_mask)
    for axis in range(3):
        shell.narrow(axis, 0, 2).fill_(True)
        shell.narrow(axis, side - 2, 2).fill_(True)
    boundary_indices = torch.nonzero(shell.flatten())

    return center_indices, boundary_indices
|
|
|
|
|
|
|
|
|
|
class FlexiCubesGeometry(object):
    """FlexiCubes mesh extraction over a fixed voxel grid, plus mesh rendering.

    Construction builds the voxel grid once (``self.verts`` / ``self.indices``) and
    optionally scales it. It also precomputes the unique grid edges and the
    center/boundary vertex indices.
    """

    def __init__(
            self, grid_res=64, scale=2.0, device='cuda', renderer=None,
            render_type='neural_render', args=None):
        """Build the voxel grid and precompute edge / center / boundary indices.

        Args:
            grid_res: number of cubes per axis, passed to ``construct_voxel_grid``.
            scale: either a scalar multiplier for all vertex coordinates, or a per-axis
                list/tuple ``(sx, sy, sz)``.
            device: torch device for the FlexiCubes helper and the index masks.
            renderer: object exposing ``render_mesh(...)``; used by ``render_mesh``.
            render_type: rendering backend; only ``'neural_render'`` is implemented.
            args: opaque configuration, stored as ``self.args``.
        """
        super().__init__()
        self.grid_res = grid_res
        self.device = device
        self.args = args
        self.fc = FlexiCubes(device, weight_scale=0.5)
        self.verts, self.indices = self.fc.construct_voxel_grid(grid_res)
        if isinstance(scale, (list, tuple)):
            # Per-axis scaling. z previously used scale[1] by mistake.
            self.verts[:, 0] = self.verts[:, 0] * scale[0]
            self.verts[:, 1] = self.verts[:, 1] * scale[1]
            self.verts[:, 2] = self.verts[:, 2] * scale[2]
        else:
            self.verts = self.verts * scale

        # Gather every cube's edges as vertex-index pairs (via fc.cube_edges) and keep
        # the distinct pairs. Note: torch.unique(dim=0) only merges identically ordered pairs.
        all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2)
        self.all_edges = torch.unique(all_edges, dim=0)

        self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device)
        self.renderer = renderer
        self.render_type = render_type

    def getAABB(self):
        """Return ``(min, max)`` per-axis corners of the (scaled) grid vertices."""
        return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values

    def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False):
        """Extract a mesh from per-vertex positions and SDF values with FlexiCubes.

        Args:
            v_deformed_nx3: (possibly deformed) grid vertex positions, presumably (N, 3).
            sdf_n: one SDF value per grid vertex.
            weight_n: optional per-cube weights. Columns 0-11 are forwarded as
                ``beta_fx12``, 12-19 as ``alpha_fx8`` and column 20 as ``gamma_f``.
                When ``None``, all three are forwarded as ``None``.
                NOTE(review): this assumes FlexiCubes substitutes default weights for
                ``None`` (the upstream implementation does) -- confirm for util.flexicubes.
            with_uv: unused; kept for interface compatibility.
            indices: cube-to-vertex index tensor; defaults to ``self.indices``.
            is_training: forwarded to FlexiCubes as ``training``.

        Returns:
            ``(verts, faces, v_reg_loss)`` exactly as returned by the FlexiCubes call.
        """
        if indices is None:
            indices = self.indices
        # Previously weight_n=None (the default) crashed here on None[:, :12].
        if weight_n is None:
            beta, alpha, gamma = None, None, None
        else:
            beta, alpha, gamma = weight_n[:, :12], weight_n[:, 12:20], weight_n[:, 20]
        verts, faces, v_reg_loss = self.fc(
            v_deformed_nx3, sdf_n, indices, self.grid_res,
            beta_fx12=beta, alpha_fx8=alpha, gamma_f=gamma, training=is_training,
        )
        return verts, faces, v_reg_loss

    def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False):
        """Render one mesh with ``self.renderer`` and return its outputs by name.

        Returns:
            dict with keys ``tex_pos``, ``mask``, ``hard_mask``, ``rast``,
            ``v_pos_clip``, ``mask_pyramid`` and ``depth``, in the renderer's return order.

        Raises:
            NotImplementedError: if ``self.render_type`` is not ``'neural_render'``.
        """
        if self.render_type != 'neural_render':
            raise NotImplementedError

        # The vertex positions are passed a second time as the 4th positional argument,
        # presumably the attribute the renderer interpolates into ``tex_pos``.
        tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth = self.renderer.render_mesh(
            mesh_v_nx3.unsqueeze(dim=0),
            mesh_f_fx3.int(),
            camera_mv_bx4x4,
            mesh_v_nx3.unsqueeze(dim=0),
            resolution=resolution,
            device=self.device,
            hierarchical_mask=hierarchical_mask,
        )
        return {
            'tex_pos': tex_pos,
            'mask': mask,
            'hard_mask': hard_mask,
            'rast': rast,
            'v_pos_clip': v_pos_clip,
            'mask_pyramid': mask_pyramid,
            'depth': depth,
        }

    def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256):
        """Extract and render a mesh for every batch element.

        Returns:
            dict mapping each ``render_mesh`` output key to a list with one entry per
            batch element. Returns an empty dict for an empty batch.
        """
        all_render_output = []
        for i_batch in range(v_deformed_bxnx3.shape[0]):
            # get_mesh returns (verts, faces, v_reg_loss); the regulariser is not used here.
            verts_nx3, faces_fx3, _ = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch])
            all_render_output.append(
                self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution)
            )

        if not all_render_output:
            return {}
        return {k: [out[k] for out in all_render_output] for k in all_render_output[0]}
|
|
|