Spaces:
Runtime error
Runtime error
File size: 5,068 Bytes
c3d3e4a 487ee6d c3d3e4a 487ee6d c3d3e4a 487ee6d c3d3e4a fb140f6 c3d3e4a fb140f6 c3d3e4a fb140f6 c3d3e4a fb140f6 c3d3e4a de4d7c5 c3d3e4a fb140f6 c3d3e4a 487ee6d c3d3e4a de4d7c5 487ee6d c3d3e4a fb140f6 c3d3e4a fb140f6 c3d3e4a fb140f6 c3d3e4a de4d7c5 c3d3e4a e0ba903 c3d3e4a fb140f6 c3d3e4a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
# Copyright 2021 by Haozhe Wu, Tsinghua University, Department of Computer Science and Technology.
# All rights reserved.
# This file is part of the pytorch-nicp,
# and is released under the "MIT License Agreement". Please see the LICENSE
# file that should have been included as part of this package.
import torch
import torch.nn as nn
import trimesh
from pytorch3d.loss import chamfer_distance
from pytorch3d.structures import Meshes
from tqdm import tqdm
from lib.common.train_util import init_loss
from lib.dataset.mesh_util import update_mesh_shape_prior_losses
# reference: https://github.com/wuhaozhe/pytorch-nicp
class LocalAffine(nn.Module):
    '''
    Learnable per-point affine transform.

    Every point of every batch element owns a 3 x 3 matrix ``A`` (initialised
    to the identity) and a 3 x 1 offset ``b`` (initialised to zero), so a
    freshly constructed module maps its input onto itself.
    '''

    def __init__(self, num_points, batch_size=1, edges=None):
        '''
        num_points: number of points, constant across the batch
        batch_size: number of batch elements, must stay constant
        edges: optional torch.LongTensor with shape E * 2 holding pairs of
            point indices; required by stiffness() and therefore by forward()
        '''
        super().__init__()
        identity = torch.eye(3).unsqueeze(0).unsqueeze(0)
        self.A = nn.Parameter(identity.repeat(batch_size, num_points, 1, 1))
        zero_offset = torch.zeros(3).unsqueeze(0).unsqueeze(0).unsqueeze(3)
        self.b = nn.Parameter(zero_offset.repeat(batch_size, num_points, 1, 1))
        # Registered as a buffer so that .to(device) / .cuda() moves the index
        # tensor together with A and b. persistent=False keeps it out of
        # state_dict(), so saved checkpoints keep their original layout.
        self.register_buffer("edges", edges, persistent=False)
        self.num_points = num_points

    def stiffness(self):
        '''
        calculate the stiffness of local affine transformation

        Returns a tuple (w_diff, w_rigid):
        w_diff: B * E * 3 * 4, element-wise squared difference between the
            stacked [A | b] weights of the two end points of every edge.
            The squared difference is used because the f norm gets an
            infinite gradient when the difference is the zero matrix.
        w_rigid: B * N, squared deviation of det(A) from 1.

        Raises Exception when the module was built without edges.
        '''
        if self.edges is None:
            raise Exception("edges cannot be none when calculate stiff")
        # [A | b] -> B * N * 3 * 4
        affine_weight = torch.cat((self.A, self.b), dim=3)
        w1 = torch.index_select(affine_weight, dim=1, index=self.edges[:, 0])
        w2 = torch.index_select(affine_weight, dim=1, index=self.edges[:, 1])
        w_diff = (w1 - w2)**2
        w_rigid = (torch.linalg.det(self.A) - 1.0)**2
        return w_diff, w_rigid

    def forward(self, x):
        '''
        x should have shape of B * N * 3 (the trailing unit dimension needed
        for the matrix product is added here)

        Returns (out_x, stiffness, rigid) where out_x = A @ x + b has shape
        B * N * 3 and the other two values come from stiffness().
        '''
        x = x.unsqueeze(3)
        out_x = torch.matmul(self.A, x) + self.b
        out_x = out_x.squeeze(3)
        stiffness, rigid = self.stiffness()
        return out_x, stiffness, rigid
def trimesh2meshes(mesh):
    '''
    Wrap a trimesh mesh into a pytorch3d Meshes object holding one mesh.

    The vertices are converted to float32 and the faces to int64, and each
    gets a leading batch dimension of size one before being handed to Meshes.
    '''
    batched_verts = torch.from_numpy(mesh.vertices).float().unsqueeze(0)
    batched_faces = torch.from_numpy(mesh.faces).long().unsqueeze(0)
    return Meshes(batched_verts, batched_faces)
def register(target_mesh, src_mesh, device, verbose=True):
    '''
    Deform src_mesh towards target_mesh with a per-vertex local affine model.

    Runs 100 Adam steps on a LocalAffine module; the objective is the weighted
    sum of every entry of the dict returned by init_loss() whose weight is
    positive and whose value is non-zero.

    target_mesh: trimesh mesh, converted with trimesh2meshes and moved to device
    src_mesh: mesh object providing verts_padded / edges_packed /
        update_padded / verts_packed / faces_packed
        (presumably a pytorch3d Meshes already on `device` -- TODO confirm)
    device: device the optimisation runs on
    verbose: when true, iterate through a tqdm bar whose description shows
        the individual weighted terms and their total

    Returns a trimesh.Trimesh built from the deformed vertices and the source
    faces, with process=False and maintains_order=True.
    '''
    num_iters = 100

    # define local_affine deform verts
    tgt_mesh = trimesh2meshes(target_mesh).to(device)
    # the target never changes, so fetch its padded vertices once
    tgt_verts = tgt_mesh.verts_padded()
    src_verts = src_mesh.verts_padded().clone()
    batch_size, num_points = src_verts.shape[0], src_verts.shape[1]

    local_affine_model = LocalAffine(
        num_points, batch_size, src_mesh.edges_packed()
    ).to(device)

    optimizer_cloth = torch.optim.Adam(
        [{'params': local_affine_model.parameters()}],
        lr=1e-2,
        amsgrad=True,
    )
    # The `verbose` keyword of ReduceLROnPlateau is deprecated in recent
    # PyTorch releases; the default is already silent, so it is simply omitted.
    scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_cloth,
        mode="min",
        factor=0.1,
        min_lr=1e-5,
        patience=5,
    )

    losses = init_loss()

    loop_cloth = tqdm(range(num_iters)) if verbose else range(num_iters)

    for _ in loop_cloth:
        optimizer_cloth.zero_grad()

        deformed_verts, stiffness, rigid = local_affine_model(x=src_verts)
        src_mesh = src_mesh.update_padded(deformed_verts)

        # losses for laplacian, edge, normal consistency
        update_mesh_shape_prior_losses(src_mesh, losses)

        losses["cloth"]["value"] = chamfer_distance(
            x=src_mesh.verts_padded(), y=tgt_verts
        )[0]
        losses["stiff"]["value"] = torch.mean(stiffness)
        losses["rigid"]["value"] = torch.mean(rigid)

        # Weighted sum of the losses
        cloth_loss = torch.tensor(0.0, requires_grad=True).to(device)
        pbar_desc = "Register SMPL-X -> d-BiNI -- "

        for k, term in losses.items():
            if term["weight"] > 0.0 and term["value"] != 0.0:
                weighted = term["value"] * term["weight"]
                cloth_loss = cloth_loss + weighted
                pbar_desc += f"{k}:{weighted:.3f} | "

        if verbose:
            pbar_desc += f"TOTAL: {cloth_loss:.3f}"
            loop_cloth.set_description(pbar_desc)

        # update params
        # NOTE(review): retain_graph=True is kept from the original code; the
        # graph looks like it is rebuilt on every iteration -- confirm whether
        # it is actually required.
        cloth_loss.backward(retain_graph=True)
        optimizer_cloth.step()
        scheduler_cloth.step(cloth_loss)

    # description of the last iteration
    print(pbar_desc)

    final = trimesh.Trimesh(
        src_mesh.verts_packed().detach().squeeze(0).cpu(),
        src_mesh.faces_packed().detach().squeeze(0).cpu(),
        process=False,
        maintains_order=True,
    )

    return final
|