# PSHuman/lib/pymafx/utils/mesh_generation.py
import time
import torch
import trimesh
import numpy as np
import torch.optim as optim
from torch import autograd
from torch.utils.data import TensorDataset, DataLoader
from .common import make_3d_grid, transform_pointcloud
from .utils import libmcubes
from .utils.libmise import MISE
from .utils.libsimplify import simplify_mesh
class Generator3D(object):
''' Generator class for DVRs.
    It provides functions to generate the final mesh as well as refinement options.
Args:
model (nn.Module): trained DVR model
points_batch_size (int): batch size for points evaluation
threshold (float): threshold value
refinement_step (int): number of refinement steps
device (device): pytorch device
resolution0 (int): start resolution for MISE
        upsampling_steps (int): number of upsampling steps
with_normals (bool): whether normals should be estimated
padding (float): how much padding should be used for MISE
simplify_nfaces (int): number of faces the mesh should be simplified to
        refine_max_faces (int): maximum number of faces used as the batch
            size for the refinement process (we added this functionality
            in this work)
'''
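    # Example usage (a minimal sketch; assumes `model` is a trained DVR model
    # exposing encode_inputs() / decode(), and that `data` is a dict with an
    # 'inputs' tensor):
    #
    #   generator = Generator3D(model, device=torch.device('cuda'),
    #                           resolution0=32, upsampling_steps=2)
    #   mesh, stats = generator.generate_mesh(data)
    #   mesh.export('mesh.obj')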
def __init__(
self,
model,
points_batch_size=100000,
threshold=0.5,
refinement_step=0,
device=None,
resolution0=16,
upsampling_steps=3,
with_normals=False,
padding=0.1,
simplify_nfaces=None,
with_color=False,
refine_max_faces=10000
):
self.model = model.to(device)
self.points_batch_size = points_batch_size
self.refinement_step = refinement_step
self.threshold = threshold
self.device = device
self.resolution0 = resolution0
self.upsampling_steps = upsampling_steps
self.with_normals = with_normals
self.padding = padding
self.simplify_nfaces = simplify_nfaces
self.with_color = with_color
self.refine_max_faces = refine_max_faces
def generate_mesh(self, data, return_stats=True):
''' Generates the output mesh.
Args:
            data (dict): data dictionary
return_stats (bool): whether stats should be returned
'''
self.model.eval()
device = self.device
stats_dict = {}
inputs = data.get('inputs', torch.empty(1, 0)).to(device)
kwargs = {}
c = self.model.encode_inputs(inputs)
mesh = self.generate_from_latent(c, stats_dict=stats_dict, data=data, **kwargs)
        if return_stats:
            return mesh, stats_dict
        return mesh
def generate_meshes(self, data, return_stats=True):
        ''' Generates output meshes for a batch of inputs (batch size >= 1).
        Args:
            data (dict): data dictionary
            return_stats (bool): whether stats should be returned
                (currently unused; only the meshes are returned)
        '''
self.model.eval()
device = self.device
stats_dict = {}
inputs = data.get('inputs', torch.empty(1, 1, 0)).to(device)
meshes = []
for i in range(inputs.shape[0]):
input_i = inputs[i].unsqueeze(0)
c = self.model.encode_inputs(input_i)
mesh = self.generate_from_latent(c, stats_dict=stats_dict)
meshes.append(mesh)
return meshes
def generate_pointcloud(self, mesh, data=None, n_points=2000000, scale_back=True):
''' Generates a point cloud from the mesh.
Args:
mesh (trimesh): mesh
data (dict): data dictionary
n_points (int): number of point cloud points
scale_back (bool): whether to undo scaling (requires a scale
matrix in data dictionary)
'''
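        # trimesh samples points uniformly (by area) over the mesh surface.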
pcl = mesh.sample(n_points).astype(np.float32)
        if scale_back:
            scale_mat = data.get('camera.scale_mat_0', None) if data is not None else None
            if scale_mat is not None:
                pcl = transform_pointcloud(pcl, scale_mat[0])
            else:
                print('Warning: No scale_mat found!')
pcl_out = trimesh.Trimesh(vertices=pcl, process=False)
return pcl_out
    def generate_from_latent(self, c=None, pl=None, stats_dict=None, data=None, **kwargs):
''' Generates mesh from latent.
Args:
c (tensor): latent conditioned code c
pl (tensor): predicted plane parameters
stats_dict (dict): stats dictionary
        '''
        stats_dict = {} if stats_dict is None else stats_dict
        # The decoder outputs logits, so convert the probability threshold to
        # logit space: logit(t) = log(t) - log(1 - t).
        threshold = np.log(self.threshold) - np.log(1. - self.threshold)
t0 = time.time()
# Compute bounding box size
box_size = 1 + self.padding
# Shortcut
if self.upsampling_steps == 0:
nx = self.resolution0
pointsf = box_size * make_3d_grid((-0.5, ) * 3, (0.5, ) * 3, (nx, ) * 3)
values = self.eval_points(pointsf, c, pl, **kwargs).cpu().numpy()
value_grid = values.reshape(nx, nx, nx)
else:
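            # MISE (Multiresolution IsoSurface Extraction): start from a
            # coarse grid and repeatedly subdivide only the voxels near the
            # predicted surface, up to resolution0 * 2**upsampling_steps.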
mesh_extractor = MISE(self.resolution0, self.upsampling_steps, threshold)
points = mesh_extractor.query()
while points.shape[0] != 0:
# Query points
pointsf = torch.FloatTensor(points).to(self.device)
# Normalize to bounding box
pointsf = 2 * pointsf / mesh_extractor.resolution
pointsf = box_size * (pointsf - 1.0)
# Evaluate model and update
values = self.eval_points(pointsf, c, pl, **kwargs).cpu().numpy()
values = values.astype(np.float64)
mesh_extractor.update(points, values)
points = mesh_extractor.query()
value_grid = mesh_extractor.to_dense()
# Extract mesh
stats_dict['time (eval points)'] = time.time() - t0
mesh = self.extract_mesh(value_grid, c, stats_dict=stats_dict)
return mesh
def eval_points(self, p, c=None, pl=None, **kwargs):
''' Evaluates the occupancy values for the points.
Args:
p (tensor): points
c (tensor): latent conditioned code c
'''
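        # Evaluate in chunks of points_batch_size so a single forward pass
        # never has to hold all query points in GPU memory.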
p_split = torch.split(p, self.points_batch_size)
occ_hats = []
for pi in p_split:
pi = pi.unsqueeze(0).to(self.device)
with torch.no_grad():
occ_hat = self.model.decode(pi, c, pl, **kwargs).logits
occ_hats.append(occ_hat.squeeze(0).detach().cpu())
occ_hat = torch.cat(occ_hats, dim=0)
return occ_hat
    def extract_mesh(self, occ_hat, c=None, stats_dict=None):
''' Extracts the mesh from the predicted occupancy grid.
Args:
occ_hat (tensor): value grid of occupancies
c (tensor): latent conditioned code c
stats_dict (dict): stats dictionary
'''
        stats_dict = {} if stats_dict is None else stats_dict
        # Some shorthands
n_x, n_y, n_z = occ_hat.shape
box_size = 1 + self.padding
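        # Logit-space threshold, matching generate_from_latent().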
threshold = np.log(self.threshold) - np.log(1. - self.threshold)
# Make sure that mesh is watertight
t0 = time.time()
occ_hat_padded = np.pad(occ_hat, 1, 'constant', constant_values=-1e6)
vertices, triangles = libmcubes.marching_cubes(occ_hat_padded, threshold)
stats_dict['time (marching cubes)'] = time.time() - t0
# Strange behaviour in libmcubes: vertices are shifted by 0.5
vertices -= 0.5
# Undo padding
vertices -= 1
# Normalize to bounding box
vertices /= np.array([n_x - 1, n_y - 1, n_z - 1])
vertices *= 2
vertices = box_size * (vertices - 1)
# mesh_pymesh = pymesh.form_mesh(vertices, triangles)
# mesh_pymesh = fix_pymesh(mesh_pymesh)
# Estimate normals if needed
        if self.with_normals and vertices.shape[0] != 0:
t0 = time.time()
normals = self.estimate_normals(vertices, c)
stats_dict['time (normals)'] = time.time() - t0
else:
normals = None
# Create mesh
mesh = trimesh.Trimesh(
vertices,
triangles,
vertex_normals=normals,
# vertex_colors=vertex_colors,
process=False
)
# Directly return if mesh is empty
if vertices.shape[0] == 0:
return mesh
# TODO: normals are lost here
if self.simplify_nfaces is not None:
t0 = time.time()
mesh = simplify_mesh(mesh, self.simplify_nfaces, 5.)
stats_dict['time (simplify)'] = time.time() - t0
# Refine mesh
if self.refinement_step > 0:
t0 = time.time()
self.refine_mesh(mesh, occ_hat, c)
stats_dict['time (refine)'] = time.time() - t0
# Estimate Vertex Colors
        if self.with_color and vertices.shape[0] != 0:
t0 = time.time()
vertex_colors = self.estimate_colors(np.array(mesh.vertices), c)
stats_dict['time (color)'] = time.time() - t0
mesh = trimesh.Trimesh(
vertices=mesh.vertices,
faces=mesh.faces,
vertex_normals=mesh.vertex_normals,
vertex_colors=vertex_colors,
process=False
)
return mesh
def estimate_colors(self, vertices, c=None):
''' Estimates vertex colors by evaluating the texture field.
Args:
vertices (numpy array): vertices of the mesh
c (tensor): latent conditioned code c
'''
device = self.device
vertices = torch.FloatTensor(vertices)
vertices_split = torch.split(vertices, self.points_batch_size)
colors = []
for vi in vertices_split:
vi = vi.to(device)
with torch.no_grad():
ci = self.model.decode_color(vi.unsqueeze(0), c).squeeze(0).cpu()
colors.append(ci)
        colors = torch.cat(colors, dim=0).numpy()
colors = np.clip(colors, 0, 1)
colors = (colors * 255).astype(np.uint8)
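        # Append an opaque alpha channel so the per-vertex colors are RGBA.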
colors = np.concatenate(
[colors, np.full((colors.shape[0], 1), 255, dtype=np.uint8)], axis=1
)
return colors
def estimate_normals(self, vertices, c=None):
''' Estimates the normals by computing the gradient of the objective.
        Args:
            vertices (numpy array): vertices of the mesh
            c (tensor): latent conditioned code c
'''
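        # A normal is estimated as the negative gradient of the occupancy
        # logits with respect to the query point (the gradient points towards
        # the inside), normalized to unit length.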
device = self.device
vertices = torch.FloatTensor(vertices)
vertices_split = torch.split(vertices, self.points_batch_size)
normals = []
c = c.unsqueeze(0)
for vi in vertices_split:
vi = vi.unsqueeze(0).to(device)
vi.requires_grad_()
occ_hat = self.model.decode(vi, c).logits
out = occ_hat.sum()
out.backward()
ni = -vi.grad
ni = ni / torch.norm(ni, dim=-1, keepdim=True)
ni = ni.squeeze(0).cpu().numpy()
normals.append(ni)
normals = np.concatenate(normals, axis=0)
return normals
def refine_mesh(self, mesh, occ_hat, c=None):
''' Refines the predicted mesh.
Args:
mesh (trimesh object): predicted mesh
occ_hat (tensor): predicted occupancy grid
c (tensor): latent conditioned code c
'''
self.model.eval()
# Some shorthands
n_x, n_y, n_z = occ_hat.shape
assert (n_x == n_y == n_z)
        # face_value below is passed through a sigmoid, so the probability
        # threshold is used directly (no logit conversion needed).
        threshold = self.threshold
# Vertex parameter
v0 = torch.FloatTensor(mesh.vertices).to(self.device)
v = torch.nn.Parameter(v0.clone())
# Faces of mesh
faces = torch.LongTensor(mesh.faces)
# detach c; otherwise graph needs to be retained
# caused by new Pytorch version?
c = c.detach()
# Start optimization
optimizer = optim.RMSprop([v], lr=1e-5)
# Dataset
ds_faces = TensorDataset(faces)
dataloader = DataLoader(ds_faces, batch_size=self.refine_max_faces, shuffle=True)
        # We updated the refinement algorithm to subsample faces; this is
        # useful when using a high extraction resolution / when working on
        # small GPUs
it_r = 0
while it_r < self.refinement_step:
for f_it in dataloader:
f_it = f_it[0].to(self.device)
optimizer.zero_grad()
                # Loss: sample one random point per face with Dirichlet(0.5)
                # barycentric weights, push its occupancy towards the decision
                # boundary, and align face normals with the implicit surface.
face_vertex = v[f_it]
eps = np.random.dirichlet((0.5, 0.5, 0.5), size=f_it.shape[0])
eps = torch.FloatTensor(eps).to(self.device)
face_point = (face_vertex * eps[:, :, None]).sum(dim=1)
face_v1 = face_vertex[:, 1, :] - face_vertex[:, 0, :]
face_v2 = face_vertex[:, 2, :] - face_vertex[:, 1, :]
                face_normal = torch.cross(face_v1, face_v2, dim=1)
face_normal = face_normal / \
(face_normal.norm(dim=1, keepdim=True) + 1e-10)
face_value = torch.cat(
[
torch.sigmoid(self.model.decode(p_split, c).logits)
for p_split in torch.split(face_point.unsqueeze(0), 20000, dim=1)
],
dim=1
)
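                # The target normal is the negated gradient of the predicted
                # occupancy at the sampled face points; create_graph=True
                # keeps it differentiable so it can enter the loss.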
normal_target = -autograd.grad([face_value.sum()], [face_point],
create_graph=True)[0]
normal_target = \
normal_target / \
(normal_target.norm(dim=1, keepdim=True) + 1e-10)
loss_target = (face_value - threshold).pow(2).mean()
loss_normal = \
(face_normal - normal_target).pow(2).sum(dim=1).mean()
loss = loss_target + 0.01 * loss_normal
# Update
loss.backward()
optimizer.step()
# Update it_r
it_r += 1
if it_r >= self.refinement_step:
break
mesh.vertices = v.data.cpu().numpy()
return mesh