import time

import numpy as np
import torch
import torch.optim as optim
import trimesh
from torch import autograd
from torch.utils.data import TensorDataset, DataLoader

from .common import make_3d_grid, transform_pointcloud
from .utils import libmcubes
from .utils.libmise import MISE
from .utils.libsimplify import simplify_mesh


class Generator3D(object):
    ''' Generator class for DVR models.

    It provides functions to generate the final mesh as well as refinement
    options.

    Args:
        model (nn.Module): trained DVR model
        points_batch_size (int): batch size for point evaluation
        threshold (float): threshold value for the occupancy decision
        refinement_step (int): number of refinement steps
        device (device): PyTorch device
        resolution0 (int): start resolution for MISE
        upsampling_steps (int): number of upsampling steps
        with_normals (bool): whether normals should be estimated
        padding (float): how much padding should be used for MISE
        simplify_nfaces (int): number of faces the mesh should be simplified to
        with_color (bool): whether vertex colors should be estimated
        refine_max_faces (int): maximum number of faces used as batch size
            during the refinement process (we added this functionality in
            this work)
    '''

    def __init__(
        self,
        model,
        points_batch_size=100000,
        threshold=0.5,
        refinement_step=0,
        device=None,
        resolution0=16,
        upsampling_steps=3,
        with_normals=False,
        padding=0.1,
        simplify_nfaces=None,
        with_color=False,
        refine_max_faces=10000
    ):
        self.model = model.to(device)
        self.points_batch_size = points_batch_size
        self.refinement_step = refinement_step
        self.threshold = threshold
        self.device = device
        self.resolution0 = resolution0
        self.upsampling_steps = upsampling_steps
        self.with_normals = with_normals
        self.padding = padding
        self.simplify_nfaces = simplify_nfaces
        self.with_color = with_color
        self.refine_max_faces = refine_max_faces

    def generate_mesh(self, data, return_stats=True):
        ''' Generates the output mesh.

        Args:
            data (dict): data dictionary
            return_stats (bool): whether stats should be returned
        '''
        self.model.eval()
        device = self.device
        stats_dict = {}

        inputs = data.get('inputs', torch.empty(1, 0)).to(device)
        kwargs = {}

        # Encode the inputs into the latent conditioned code c and generate
        # the mesh from it
        c = self.model.encode_inputs(inputs)
        mesh = self.generate_from_latent(
            c, stats_dict=stats_dict, data=data, **kwargs)

        if return_stats:
            return mesh, stats_dict
        return mesh

    def generate_meshes(self, data, return_stats=True):
        ''' Generates the output meshes for a data batch of size >= 1.

        Args:
            data (dict): data dictionary
            return_stats (bool): whether stats should be returned
        '''
        self.model.eval()
        device = self.device
        stats_dict = {}

        inputs = data.get('inputs', torch.empty(1, 1, 0)).to(device)

        # Encode and decode each batch element separately; stats_dict keeps
        # the timings of the last generated mesh
        meshes = []
        for i in range(inputs.shape[0]):
            input_i = inputs[i].unsqueeze(0)
            c = self.model.encode_inputs(input_i)
            mesh = self.generate_from_latent(c, stats_dict=stats_dict)
            meshes.append(mesh)

        if return_stats:
            return meshes, stats_dict
        return meshes

    def generate_pointcloud(self, mesh, data=None, n_points=2000000,
                            scale_back=True):
        ''' Generates a point cloud from the mesh.

        Args:
            mesh (trimesh): mesh
            data (dict): data dictionary
            n_points (int): number of point cloud points
            scale_back (bool): whether to undo scaling (requires a scale
                matrix in the data dictionary)
        '''
        pcl = mesh.sample(n_points).astype(np.float32)

        # Transform the point cloud back to the original (world) scale
        if scale_back:
            scale_mat = data.get('camera.scale_mat_0', None) \
                if data is not None else None
            if scale_mat is not None:
                pcl = transform_pointcloud(pcl, scale_mat[0])
            else:
                print('Warning: No scale_mat found!')

        pcl_out = trimesh.Trimesh(vertices=pcl, process=False)
        return pcl_out

    def generate_from_latent(self, c=None, pl=None, stats_dict={},
                             data=None, **kwargs):
        ''' Generates a mesh from the latent code.

        Args:
            c (tensor): latent conditioned code c
            pl (tensor): predicted plane parameters
            stats_dict (dict): stats dictionary
            data (dict): data dictionary
        '''
        # The decoder outputs logits, so convert the probability threshold
        # to logit space
        threshold = np.log(self.threshold) - np.log(1. - self.threshold)

        t0 = time.time()
        # Compute bounding box size
        box_size = 1 + self.padding

        # Shortcut: evaluate the full grid at once if no upsampling is used
        if self.upsampling_steps == 0:
            nx = self.resolution0
            # Use the same [-box_size, box_size] normalization as the MISE
            # branch below and as extract_mesh
            pointsf = box_size * make_3d_grid((-1, ) * 3, (1, ) * 3, (nx, ) * 3)
            values = self.eval_points(pointsf, c, pl, **kwargs).cpu().numpy()
            value_grid = values.reshape(nx, nx, nx)
        else:
            # Multiresolution IsoSurface Extraction (MISE): query the decoder
            # coarse-to-fine until no active points remain
            mesh_extractor = MISE(
                self.resolution0, self.upsampling_steps, threshold)

            points = mesh_extractor.query()

            while points.shape[0] != 0:
                # Query points
                pointsf = torch.FloatTensor(points).to(self.device)

                # Normalize to bounding box [-box_size, box_size]
                pointsf = 2 * pointsf / mesh_extractor.resolution
                pointsf = box_size * (pointsf - 1.0)

                # Evaluate model and update the extractor
                values = self.eval_points(pointsf, c, pl, **kwargs).cpu().numpy()
                values = values.astype(np.float64)
                mesh_extractor.update(points, values)
                points = mesh_extractor.query()

            value_grid = mesh_extractor.to_dense()

        stats_dict['time (eval points)'] = time.time() - t0

        mesh = self.extract_mesh(value_grid, c, stats_dict=stats_dict)
        return mesh
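
    # Note on the threshold conversion used above and in extract_mesh: a
    # probability threshold tau maps to the logit log(tau) - log(1 - tau).
    # For example, the default tau = 0.5 gives a logit threshold of 0, while
    # tau = 0.2 gives log(0.2) - log(0.8) ~ -1.386.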

    def eval_points(self, p, c=None, pl=None, **kwargs):
        ''' Evaluates the occupancy values for the points.

        Args:
            p (tensor): points
            c (tensor): latent conditioned code c
            pl (tensor): predicted plane parameters
        '''
        p_split = torch.split(p, self.points_batch_size)
        occ_hats = []

        # Evaluate the decoder in chunks of points_batch_size points to
        # bound memory usage
        for pi in p_split:
            pi = pi.unsqueeze(0).to(self.device)
            with torch.no_grad():
                occ_hat = self.model.decode(pi, c, pl, **kwargs).logits
            occ_hats.append(occ_hat.squeeze(0).detach().cpu())

        occ_hat = torch.cat(occ_hats, dim=0)
        return occ_hat
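
    # For scale: with the default points_batch_size of 100000, a dense
    # 129^3 value grid (about 2.1M points) is evaluated in 22 chunks.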

    def extract_mesh(self, occ_hat, c=None, stats_dict=dict()):
        ''' Extracts the mesh from the predicted occupancy grid.

        Args:
            occ_hat (tensor): value grid of occupancies
            c (tensor): latent conditioned code c
            stats_dict (dict): stats dictionary
        '''
        # Some shortcuts
        n_x, n_y, n_z = occ_hat.shape
        box_size = 1 + self.padding
        threshold = np.log(self.threshold) - np.log(1. - self.threshold)

        # Pad the grid with a large negative (empty) value so that marching
        # cubes produces a watertight mesh
        t0 = time.time()
        occ_hat_padded = np.pad(occ_hat, 1, 'constant', constant_values=-1e6)
        vertices, triangles = libmcubes.marching_cubes(
            occ_hat_padded, threshold)
        stats_dict['time (marching cubes)'] = time.time() - t0

        # Undo the 0.5 vertex shift introduced by libmcubes
        vertices -= 0.5
        # Undo padding
        vertices -= 1
        # Normalize grid indices to the bounding box [-box_size, box_size]
        vertices /= np.array([n_x - 1, n_y - 1, n_z - 1])
        vertices *= 2
        vertices = box_size * (vertices - 1)

        # Estimate normals if needed
        if self.with_normals and vertices.shape[0] != 0:
            t0 = time.time()
            normals = self.estimate_normals(vertices, c)
            stats_dict['time (normals)'] = time.time() - t0
        else:
            normals = None

        # Create the mesh
        mesh = trimesh.Trimesh(
            vertices,
            triangles,
            vertex_normals=normals,
            process=False
        )

        # Directly return if the mesh is empty
        if vertices.shape[0] == 0:
            return mesh

        # Simplify mesh if requested
        if self.simplify_nfaces is not None:
            t0 = time.time()
            mesh = simplify_mesh(mesh, self.simplify_nfaces, 5.)
            stats_dict['time (simplify)'] = time.time() - t0

        # Refine mesh (updates mesh.vertices in place)
        if self.refinement_step > 0:
            t0 = time.time()
            self.refine_mesh(mesh, occ_hat, c)
            stats_dict['time (refine)'] = time.time() - t0

        # Estimate vertex colors if needed
        if self.with_color and vertices.shape[0] != 0:
            t0 = time.time()
            vertex_colors = self.estimate_colors(np.array(mesh.vertices), c)
            stats_dict['time (color)'] = time.time() - t0
            mesh = trimesh.Trimesh(
                vertices=mesh.vertices,
                faces=mesh.faces,
                vertex_normals=mesh.vertex_normals,
                vertex_colors=vertex_colors,
                process=False
            )

        return mesh
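
    # Worked example of the normalization above: with the default
    # padding = 0.1 (box_size = 1.1) and a 129^3 value grid, grid index 0
    # maps to -1.1, index 64 to 0.0, and index 128 to 1.1, so the mesh
    # lives in [-box_size, box_size]^3.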

    def estimate_colors(self, vertices, c=None):
        ''' Estimates vertex colors by evaluating the texture field.

        Args:
            vertices (numpy array): vertices of the mesh
            c (tensor): latent conditioned code c
        '''
        device = self.device
        vertices = torch.FloatTensor(vertices)
        vertices_split = torch.split(vertices, self.points_batch_size)

        colors = []
        for vi in vertices_split:
            vi = vi.to(device)
            with torch.no_grad():
                ci = self.model.decode_color(
                    vi.unsqueeze(0), c).squeeze(0).cpu().numpy()
            colors.append(ci)

        # Convert RGB values in [0, 1] to uint8 RGBA with full opacity
        colors = np.concatenate(colors, axis=0)
        colors = np.clip(colors, 0, 1)
        colors = (colors * 255).astype(np.uint8)
        colors = np.concatenate(
            [colors, np.full((colors.shape[0], 1), 255, dtype=np.uint8)],
            axis=1)
        return colors

    def estimate_normals(self, vertices, c=None):
        ''' Estimates the normals by computing the gradient of the objective.

        Args:
            vertices (numpy array): vertices of the mesh
            c (tensor): latent conditioned code c
        '''
        device = self.device
        vertices = torch.FloatTensor(vertices)
        vertices_split = torch.split(vertices, self.points_batch_size)

        normals = []
        c = c.unsqueeze(0)
        for vi in vertices_split:
            vi = vi.unsqueeze(0).to(device)
            vi.requires_grad_()
            occ_hat = self.model.decode(vi, c).logits
            out = occ_hat.sum()
            out.backward()
            # The normal is the negative occupancy gradient, normalized to
            # unit length
            ni = -vi.grad
            ni = ni / torch.norm(ni, dim=-1, keepdim=True)
            ni = ni.squeeze(0).cpu().numpy()
            normals.append(ni)

        normals = np.concatenate(normals, axis=0)
        return normals

    def refine_mesh(self, mesh, occ_hat, c=None):
        ''' Refines the predicted mesh.

        Args:
            mesh (trimesh object): predicted mesh
            occ_hat (tensor): predicted occupancy grid
            c (tensor): latent conditioned code c
        '''
        self.model.eval()

        # Some shortcuts
        n_x, n_y, n_z = occ_hat.shape
        assert (n_x == n_y == n_z)
        threshold = self.threshold

        # Vertex parameters to optimize
        v0 = torch.FloatTensor(mesh.vertices).to(self.device)
        v = torch.nn.Parameter(v0.clone())

        # Faces of the mesh
        faces = torch.LongTensor(mesh.faces)

        # Detach c; only the vertices are optimized
        c = c.detach()

        # Start optimization
        optimizer = optim.RMSprop([v], lr=1e-5)

        # Process the faces in random batches of at most refine_max_faces
        ds_faces = TensorDataset(faces)
        dataloader = DataLoader(
            ds_faces, batch_size=self.refine_max_faces, shuffle=True)

        it_r = 0
        while it_r < self.refinement_step:
            for f_it in dataloader:
                f_it = f_it[0].to(self.device)
                optimizer.zero_grad()

                # Sample a random point on each face using Dirichlet-
                # distributed barycentric coordinates
                face_vertex = v[f_it]
                eps = np.random.dirichlet((0.5, 0.5, 0.5), size=f_it.shape[0])
                eps = torch.FloatTensor(eps).to(self.device)
                face_point = (face_vertex * eps[:, :, None]).sum(dim=1)

                # Compute the unit normal of each face
                face_v1 = face_vertex[:, 1, :] - face_vertex[:, 0, :]
                face_v2 = face_vertex[:, 2, :] - face_vertex[:, 1, :]
                face_normal = torch.cross(face_v1, face_v2, dim=1)
                face_normal = face_normal / \
                    (face_normal.norm(dim=1, keepdim=True) + 1e-10)

                # Evaluate the occupancy probability at the sampled points
                # (in chunks to bound memory usage)
                face_value = torch.cat([
                    torch.sigmoid(self.model.decode(p_split, c).logits)
                    for p_split in torch.split(
                        face_point.unsqueeze(0), 20000, dim=1)
                ], dim=1)

                # The target normal is the negative occupancy gradient,
                # normalized to unit length
                normal_target = -autograd.grad(
                    [face_value.sum()], [face_point], create_graph=True)[0]
                normal_target = \
                    normal_target / \
                    (normal_target.norm(dim=1, keepdim=True) + 1e-10)

                # Pull the sampled points onto the decision boundary and
                # align the face normals with the occupancy gradient
                loss_target = (face_value - threshold).pow(2).mean()
                loss_normal = \
                    (face_normal - normal_target).pow(2).sum(dim=1).mean()
                loss = loss_target + 0.01 * loss_normal

                # Update the vertices
                loss.backward()
                optimizer.step()

                # Stop after refinement_step optimization steps
                it_r += 1
                if it_r >= self.refinement_step:
                    break

        mesh.vertices = v.data.cpu().numpy()
        return mesh
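

# -----------------------------------------------------------------------------
# Usage sketch (illustrative only; `model`, `data`, and the output path are
# assumptions, not part of this module). Given a trained DVR model and a data
# dictionary with an 'inputs' tensor, mesh extraction typically looks like:
#
#     generator = Generator3D(
#         model,
#         device=torch.device('cuda'),
#         threshold=0.5,
#         resolution0=32,
#         upsampling_steps=2,
#         refinement_step=30,
#     )
#     mesh, stats_dict = generator.generate_mesh(data)
#     mesh.export('mesh.off')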