# Scraped from a Hugging Face Space file view (Space status: running on L40S).
# File size: 14,260 bytes; revision 2252f3d; 410 lines in the original listing.
import time
import torch
import trimesh
import numpy as np
import torch.optim as optim
from torch import autograd
from torch.utils.data import TensorDataset, DataLoader
from .common import make_3d_grid
from .utils import libmcubes
from .utils.libmise import MISE
from .utils.libsimplify import simplify_mesh
from .common import transform_pointcloud
class Generator3D(object):
    ''' Generator class for DVRs.

    It provides functions to generate the final mesh as well as refining
    options.

    Args:
        model (nn.Module): trained DVR model
        points_batch_size (int): batch size for points evaluation
        threshold (float): threshold value
        refinement_step (int): number of refinement steps
        device (device): pytorch device
        resolution0 (int): start resolution for MISE
        upsampling_steps (int): number of upsampling steps
        with_normals (bool): whether normals should be estimated
        padding (float): how much padding should be used for MISE
        simplify_nfaces (int): number of faces the mesh should be simplified to
        with_color (bool): whether vertex colors should be estimated
        refine_max_faces (int): max number of faces which are used as batch
            size for refinement process (we added this functionality in this
            work)
    '''

    def __init__(
        self,
        model,
        points_batch_size=100000,
        threshold=0.5,
        refinement_step=0,
        device=None,
        resolution0=16,
        upsampling_steps=3,
        with_normals=False,
        padding=0.1,
        simplify_nfaces=None,
        with_color=False,
        refine_max_faces=10000
    ):
        self.model = model.to(device)
        self.points_batch_size = points_batch_size
        self.refinement_step = refinement_step
        self.threshold = threshold
        self.device = device
        self.resolution0 = resolution0
        self.upsampling_steps = upsampling_steps
        self.with_normals = with_normals
        self.padding = padding
        self.simplify_nfaces = simplify_nfaces
        self.with_color = with_color
        self.refine_max_faces = refine_max_faces

    def generate_mesh(self, data, return_stats=True):
        ''' Generates the output mesh.

        Args:
            data (dict): data dictionary; its 'inputs' entry (if present) is
                fed to the model encoder
            return_stats (bool): accepted for interface compatibility; it is
                currently not consulted and the stats are always returned

        Returns:
            tuple: (mesh, stats_dict)
        '''
        self.model.eval()
        stats_dict = {}

        inputs = data.get('inputs', torch.empty(1, 0)).to(self.device)
        c = self.model.encode_inputs(inputs)
        mesh = self.generate_from_latent(c, stats_dict=stats_dict, data=data)
        return mesh, stats_dict

    def generate_meshes(self, data, return_stats=True):
        ''' Generates the output meshes with data of batch size >=1

        Args:
            data (dict): data dictionary; its 'inputs' entry (if present) is
                split along the first dimension, one mesh per slice
            return_stats (bool): accepted for interface compatibility; it is
                currently not consulted and no stats are returned

        Returns:
            list: one mesh per input sample
        '''
        self.model.eval()
        # Timings of every sample are written into the same dict, so only the
        # values of the last sample survive; the dict is not returned.
        stats_dict = {}

        inputs = data.get('inputs', torch.empty(1, 1, 0)).to(self.device)

        meshes = []
        for i in range(inputs.shape[0]):
            input_i = inputs[i].unsqueeze(0)
            c = self.model.encode_inputs(input_i)
            meshes.append(self.generate_from_latent(c, stats_dict=stats_dict))
        return meshes

    def generate_pointcloud(self, mesh, data=None, n_points=2000000, scale_back=True):
        ''' Generates a point cloud from the mesh.

        Args:
            mesh (trimesh): mesh
            data (dict): data dictionary (may be None)
            n_points (int): number of point cloud points
            scale_back (bool): whether to undo scaling (requires a scale
                matrix in data dictionary)

        Returns:
            trimesh.Trimesh: vertex-only mesh holding the sampled points
        '''
        pcl = mesh.sample(n_points).astype(np.float32)

        if scale_back:
            # data defaults to None; treat that the same as a missing matrix
            # instead of failing on None.get(...).
            scale_mat = None
            if data is not None:
                scale_mat = data.get('camera.scale_mat_0', None)
            if scale_mat is not None:
                pcl = transform_pointcloud(pcl, scale_mat[0])
            else:
                print('Warning: No scale_mat found!')

        return trimesh.Trimesh(vertices=pcl, process=False)

    def generate_from_latent(self, c=None, pl=None, stats_dict=None, data=None, **kwargs):
        ''' Generates mesh from latent.

        Args:
            c (tensor): latent conditioned code c
            pl (tensor): predicted plane parameters
            stats_dict (dict): stats dictionary; timings are written into it
                in place. A fresh dict is used when None is given.
            data (dict): accepted for interface compatibility; not used here
            **kwargs: forwarded to the model decoder via eval_points

        Returns:
            trimesh.Trimesh: the extracted mesh
        '''
        # "is None" (not truthiness) so a caller-supplied empty dict is filled.
        if stats_dict is None:
            stats_dict = {}

        # The decoder output is read as logits, so the probability threshold
        # is mapped to logit space.
        threshold = np.log(self.threshold) - np.log(1. - self.threshold)

        t0 = time.time()
        # Compute bounding box size
        box_size = 1 + self.padding

        if self.upsampling_steps == 0:
            # Shortcut: evaluate one dense grid at the start resolution
            nx = self.resolution0
            pointsf = box_size * make_3d_grid((-0.5, ) * 3, (0.5, ) * 3, (nx, ) * 3)
            values = self.eval_points(pointsf, c, pl, **kwargs).cpu().numpy()
            value_grid = values.reshape(nx, nx, nx)
        else:
            mesh_extractor = MISE(self.resolution0, self.upsampling_steps, threshold)

            points = mesh_extractor.query()
            while points.shape[0] != 0:
                # Query points
                pointsf = torch.FloatTensor(points).to(self.device)
                # Normalize to bounding box
                pointsf = 2 * pointsf / mesh_extractor.resolution
                pointsf = box_size * (pointsf - 1.0)
                # Evaluate model and update
                values = self.eval_points(pointsf, c, pl, **kwargs).cpu().numpy()
                values = values.astype(np.float64)
                mesh_extractor.update(points, values)
                points = mesh_extractor.query()

            value_grid = mesh_extractor.to_dense()

        # Extract mesh
        stats_dict['time (eval points)'] = time.time() - t0
        return self.extract_mesh(value_grid, c, stats_dict=stats_dict)

    def eval_points(self, p, c=None, pl=None, **kwargs):
        ''' Evaluates the occupancy values for the points.

        Args:
            p (tensor): points
            c (tensor): latent conditioned code c
            pl (tensor): predicted plane parameters
            **kwargs: forwarded to the model decoder

        Returns:
            tensor: concatenated decoder logits, on the CPU
        '''
        occ_hats = []
        # Evaluate in chunks of points_batch_size to bound memory use
        for pi in torch.split(p, self.points_batch_size):
            pi = pi.unsqueeze(0).to(self.device)
            with torch.no_grad():
                occ_hat = self.model.decode(pi, c, pl, **kwargs).logits
            occ_hats.append(occ_hat.squeeze(0).detach().cpu())

        return torch.cat(occ_hats, dim=0)

    def extract_mesh(self, occ_hat, c=None, stats_dict=None):
        ''' Extracts the mesh from the predicted occupancy grid.

        Args:
            occ_hat (numpy array): value grid of occupancies
            c (tensor): latent conditioned code c
            stats_dict (dict): stats dictionary; timings are written into it
                in place. A fresh dict is used when None is given.

        Returns:
            trimesh.Trimesh: the extracted (optionally simplified, refined
                and colored) mesh
        '''
        if stats_dict is None:
            stats_dict = {}

        # Some short hands
        n_x, n_y, n_z = occ_hat.shape
        box_size = 1 + self.padding
        threshold = np.log(self.threshold) - np.log(1. - self.threshold)

        # Make sure that mesh is watertight
        t0 = time.time()
        occ_hat_padded = np.pad(occ_hat, 1, 'constant', constant_values=-1e6)
        vertices, triangles = libmcubes.marching_cubes(occ_hat_padded, threshold)
        stats_dict['time (marching cubes)'] = time.time() - t0

        # Strange behaviour in libmcubes: vertices are shifted by 0.5
        vertices -= 0.5
        # Undo padding
        vertices -= 1
        # Normalize to bounding box
        vertices /= np.array([n_x - 1, n_y - 1, n_z - 1])
        vertices *= 2
        vertices = box_size * (vertices - 1)

        is_empty = vertices.shape[0] == 0

        # Estimate normals if needed
        normals = None
        if self.with_normals and not is_empty:
            t0 = time.time()
            normals = self.estimate_normals(vertices, c)
            stats_dict['time (normals)'] = time.time() - t0

        # Create mesh
        mesh = trimesh.Trimesh(
            vertices,
            triangles,
            vertex_normals=normals,
            process=False
        )

        # Directly return if mesh is empty
        if is_empty:
            return mesh

        # TODO: normals are lost here
        if self.simplify_nfaces is not None:
            t0 = time.time()
            mesh = simplify_mesh(mesh, self.simplify_nfaces, 5.)
            stats_dict['time (simplify)'] = time.time() - t0

        # Refine mesh (updates mesh.vertices in place)
        if self.refinement_step > 0:
            t0 = time.time()
            self.refine_mesh(mesh, occ_hat, c)
            stats_dict['time (refine)'] = time.time() - t0

        # Estimate vertex colors; the mesh is known to be non-empty here
        if self.with_color:
            t0 = time.time()
            vertex_colors = self.estimate_colors(np.array(mesh.vertices), c)
            stats_dict['time (color)'] = time.time() - t0
            mesh = trimesh.Trimesh(
                vertices=mesh.vertices,
                faces=mesh.faces,
                vertex_normals=mesh.vertex_normals,
                vertex_colors=vertex_colors,
                process=False
            )

        return mesh

    def estimate_colors(self, vertices, c=None):
        ''' Estimates vertex colors by evaluating the texture field.

        Args:
            vertices (numpy array): vertices of the mesh
            c (tensor): latent conditioned code c

        Returns:
            numpy array: uint8 colors with an opaque (255) alpha column
                appended as the last channel
        '''
        device = self.device
        vertices = torch.FloatTensor(vertices)

        colors = []
        for vi in torch.split(vertices, self.points_batch_size):
            vi = vi.to(device)
            with torch.no_grad():
                ci = self.model.decode_color(vi.unsqueeze(0), c).squeeze(0)
            # Convert explicitly rather than relying on numpy to coerce
            # torch tensors during concatenation.
            colors.append(ci.detach().cpu().numpy())

        colors = np.concatenate(colors, axis=0)
        colors = np.clip(colors, 0, 1)
        colors = (colors * 255).astype(np.uint8)
        alpha = np.full((colors.shape[0], 1), 255, dtype=np.uint8)
        return np.concatenate([colors, alpha], axis=1)

    def estimate_normals(self, vertices, c=None):
        ''' Estimates the normals by computing the gradient of the objective.

        Args:
            vertices (numpy array): vertices of the mesh
            c (tensor): latent conditioned code c (may be None)

        Returns:
            numpy array: one normalized negative-gradient vector per vertex;
                a vertex with zero gradient gets a zero vector
        '''
        device = self.device
        vertices = torch.FloatTensor(vertices)

        # c defaults to None; only add the extra leading dim when it exists.
        if c is not None:
            c = c.unsqueeze(0)

        normals = []
        for vi in torch.split(vertices, self.points_batch_size):
            vi = vi.unsqueeze(0).to(device)
            vi.requires_grad_()
            occ_hat = self.model.decode(vi, c).logits
            # Differentiate w.r.t. the points only, so that no gradients are
            # accumulated on the model parameters.
            grad = autograd.grad(occ_hat.sum(), vi)[0]
            ni = -grad
            # Clamp the norm so that a zero gradient does not produce NaNs.
            ni = ni / torch.norm(ni, dim=-1, keepdim=True).clamp_min(1e-10)
            normals.append(ni.squeeze(0).cpu().numpy())

        return np.concatenate(normals, axis=0)

    def refine_mesh(self, mesh, occ_hat, c=None):
        ''' Refines the predicted mesh.

        The vertices are optimized so that random points sampled on the faces
        evaluate to the threshold and the face normals agree with the
        negative gradient of the predicted occupancy.

        Args:
            mesh (trimesh object): predicted mesh; its vertices are
                overwritten in place
            occ_hat (numpy array): predicted occupancy grid (must be cubic)
            c (tensor): latent conditioned code c (may be None)

        Returns:
            trimesh object: the same mesh object that was passed in
        '''
        self.model.eval()

        # Some shorthands
        n_x, n_y, n_z = occ_hat.shape
        assert (n_x == n_y == n_z)
        # The loss compares sigmoid outputs, so the threshold stays in
        # probability space here (no logit transform).
        threshold = self.threshold

        # Without faces the face loader below would be empty and the
        # refinement loop would never advance its step counter.
        if len(mesh.faces) == 0:
            return mesh

        # Vertex parameter
        v0 = torch.FloatTensor(mesh.vertices).to(self.device)
        v = torch.nn.Parameter(v0.clone())

        # Faces of mesh
        faces = torch.LongTensor(mesh.faces)

        # detach c; otherwise graph needs to be retained
        if c is not None:
            c = c.detach()

        # Start optimization
        optimizer = optim.RMSprop([v], lr=1e-5)

        # We updated the refinement algorithm to subsample faces; this is
        # useful when using a high extraction resolution / when working on
        # small GPUs
        ds_faces = TensorDataset(faces)
        dataloader = DataLoader(ds_faces, batch_size=self.refine_max_faces, shuffle=True)

        it_r = 0
        while it_r < self.refinement_step:
            for f_it in dataloader:
                f_it = f_it[0].to(self.device)
                optimizer.zero_grad()

                # Sample one random point per face (barycentric weights)
                face_vertex = v[f_it]
                eps = np.random.dirichlet((0.5, 0.5, 0.5), size=f_it.shape[0])
                eps = torch.FloatTensor(eps).to(self.device)
                face_point = (face_vertex * eps[:, :, None]).sum(dim=1)

                # Face normals from two edge vectors. dim=1 is explicit: the
                # legacy default picks the first size-3 dimension, which is
                # the batch dimension for a batch of exactly three faces.
                face_v1 = face_vertex[:, 1, :] - face_vertex[:, 0, :]
                face_v2 = face_vertex[:, 2, :] - face_vertex[:, 1, :]
                face_normal = torch.cross(face_v1, face_v2, dim=1)
                face_normal = face_normal / \
                    (face_normal.norm(dim=1, keepdim=True) + 1e-10)

                face_value = torch.cat(
                    [
                        torch.sigmoid(self.model.decode(p_split, c).logits)
                        for p_split in torch.split(face_point.unsqueeze(0), 20000, dim=1)
                    ],
                    dim=1
                )

                # Target normals: negative gradient of the predicted value
                normal_target = -autograd.grad([face_value.sum()], [face_point],
                                               create_graph=True)[0]
                normal_target = \
                    normal_target / \
                    (normal_target.norm(dim=1, keepdim=True) + 1e-10)

                loss_target = (face_value - threshold).pow(2).mean()
                loss_normal = \
                    (face_normal - normal_target).pow(2).sum(dim=1).mean()
                loss = loss_target + 0.01 * loss_normal

                # Update
                loss.backward()
                optimizer.step()

                # Update it_r
                it_r += 1
                if it_r >= self.refinement_step:
                    break

        mesh.vertices = v.detach().cpu().numpy()
        return mesh
|