import torch
import torch.nn as nn
import trimesh
from pytorch3d.loss import chamfer_distance
from pytorch3d.structures import Meshes
from tqdm import tqdm

from lib.common.train_util import init_loss
from lib.dataset.mesh_util import update_mesh_shape_prior_losses


class LocalAffine(nn.Module):
    def __init__(self, num_points, batch_size=1, edges=None):
        '''
        num_points: number of points, which must be constant across the batch
        edges: torch.LongTensor of shape E x 2 (vertex indices of each mesh edge)
        the local affine operator supports batched operation; the batch size must
        be constant, and additional pooling can be added on top of the w matrix
        '''
        super(LocalAffine, self).__init__()
        # per-point affine transform: A is B x N x 3 x 3 (initialized to identity),
        # b is B x N x 3 x 1 (initialized to zero)
        self.A = nn.Parameter(
            torch.eye(3).unsqueeze(0).unsqueeze(0).repeat(batch_size, num_points, 1, 1)
        )
        self.b = nn.Parameter(
            torch.zeros(3).unsqueeze(0).unsqueeze(0).unsqueeze(3).repeat(
                batch_size, num_points, 1, 1
            )
        )
        self.edges = edges
        self.num_points = num_points

    def stiffness(self):
        '''
        calculate the stiffness of the local affine transformation;
        the Frobenius norm has an infinite gradient when w is the zero matrix,
        so the squared differences are returned without taking the square root
        '''
        if self.edges is None:
            raise Exception("edges cannot be None when calculating stiffness")
        affine_weight = torch.cat((self.A, self.b), dim=3)
        w1 = torch.index_select(affine_weight, dim=1, index=self.edges[:, 0])
        w2 = torch.index_select(affine_weight, dim=1, index=self.edges[:, 1])
        w_diff = (w1 - w2)**2
        w_rigid = (torch.linalg.det(self.A) - 1.0)**2
        return w_diff, w_rigid
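
    # The two residuals returned above are reduced by the caller (torch.mean in
    # register() below): the first penalizes differences between the [A | b]
    # transforms of vertices joined by a mesh edge, the second penalizes each
    # det(A) deviating from 1 (i.e. locally non-rigid scaling / shearing).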

    def forward(self, x):
        '''
        x: tensor of shape B x N x 3 (unsqueezed to B x N x 3 x 1 internally)
        '''
        x = x.unsqueeze(3)
        out_x = torch.matmul(self.A, x)
        out_x = out_x + self.b
        out_x.squeeze_(3)
        stiffness, rigid = self.stiffness()

        return out_x, stiffness, rigid
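

# Usage sketch for LocalAffine (illustrative only; `mesh` stands for any
# pytorch3d Meshes object, as used in register() below):
#
#   verts = mesh.verts_padded()                       # (B, N, 3)
#   model = LocalAffine(num_points=verts.shape[1],
#                       batch_size=verts.shape[0],
#                       edges=mesh.edges_packed())    # edges: (E, 2)
#   new_verts, stiffness, rigid = model(x=verts)      # new_verts: (B, N, 3)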


def trimesh2meshes(mesh):
    '''
    convert a trimesh.Trimesh into a pytorch3d Meshes object (batch of one)
    '''
    verts = torch.from_numpy(mesh.vertices).float()
    faces = torch.from_numpy(mesh.faces).long()
    mesh = Meshes(verts.unsqueeze(0), faces.unsqueeze(0))
    return mesh
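
# Example (hypothetical file name, for illustration only):
#
#   scan = trimesh.load("scan.obj", process=False)
#   scan_p3d = trimesh2meshes(scan)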


def register(target_mesh, src_mesh, device, verbose=True):
    '''
    non-rigidly register src_mesh (a pytorch3d Meshes, e.g. the SMPL-X body)
    onto target_mesh (a trimesh.Trimesh, e.g. the d-BiNI surface) by optimizing
    one local affine transform per source vertex
    '''

    tgt_mesh = trimesh2meshes(target_mesh).to(device)
    src_verts = src_mesh.verts_padded().clone()

    # one learnable 3x4 affine transform per source vertex
    local_affine_model = LocalAffine(
        src_mesh.verts_padded().shape[1],
        src_mesh.verts_padded().shape[0],
        src_mesh.edges_packed()
    ).to(device)

    optimizer_cloth = torch.optim.Adam(
        [{'params': local_affine_model.parameters()}], lr=1e-2, amsgrad=True
    )
    scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_cloth,
        mode="min",
        factor=0.1,
        verbose=0,
        min_lr=1e-5,
        patience=5,
    )
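
    # With this setup, Adam refines the per-vertex affine parameters over the
    # 100 iterations below, and ReduceLROnPlateau divides the learning rate by
    # 10 (down to 1e-5) whenever the total loss plateaus for 5 iterations.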

    losses = init_loss()

    if verbose:
        loop_cloth = tqdm(range(100))
    else:
        loop_cloth = range(100)

    for i in loop_cloth:

        optimizer_cloth.zero_grad()

        deformed_verts, stiffness, rigid = local_affine_model(x=src_verts)
        src_mesh = src_mesh.update_padded(deformed_verts)

        # mesh shape priors on the deformed source mesh
        update_mesh_shape_prior_losses(src_mesh, losses)

        # data term: chamfer distance between the deformed source and the target
        losses["cloth"]["value"] = chamfer_distance(
            x=src_mesh.verts_padded(), y=tgt_mesh.verts_padded()
        )[0]
        losses["stiff"]["value"] = torch.mean(stiffness)
        losses["rigid"]["value"] = torch.mean(rigid)

        # weighted sum of the active losses
        cloth_loss = torch.tensor(0.0, requires_grad=True).to(device)
        pbar_desc = "Register SMPL-X -> d-BiNI -- "

        for k in losses.keys():
            if losses[k]["weight"] > 0.0 and losses[k]["value"] != 0.0:
                cloth_loss = cloth_loss + \
                    losses[k]["value"] * losses[k]["weight"]
                pbar_desc += f"{k}: {losses[k]['value'] * losses[k]['weight']:.3f} | "

        if verbose:
            pbar_desc += f"TOTAL: {cloth_loss:.3f}"
            loop_cloth.set_description(pbar_desc)

        # update the local affine parameters
        cloth_loss.backward(retain_graph=True)
        optimizer_cloth.step()
        scheduler_cloth.step(cloth_loss)

    print(pbar_desc)

    final = trimesh.Trimesh(
        src_mesh.verts_packed().detach().squeeze(0).cpu(),
        src_mesh.faces_packed().detach().squeeze(0).cpu(),
        process=False,
        maintains_order=True
    )

    return final
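

# End-to-end usage sketch (hypothetical file names, for illustration only):
# register a posed SMPL-X body mesh onto a d-BiNI surface reconstruction.
#
#   device = torch.device("cuda:0")
#   smplx_body = trimesh.load("smplx_body.obj", process=False)
#   dbini_surface = trimesh.load("dbini_surface.obj", process=False)
#   registered = register(dbini_surface, trimesh2meshes(smplx_body).to(device), device)
#   registered.export("smplx_registered.obj")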