Spaces:
Runtime error
Runtime error
NICP for SMPL-X completion
Browse files- README.md +2 -2
- apps/avatarizer.py +109 -8
- apps/infer.py +11 -9
- lib/common/local_affine.py +136 -0
- lib/common/train_util.py +1 -1
- lib/dataset/mesh_util.py +23 -1
- lib/smplx/lbs.py +72 -0
README.md
CHANGED
@@ -4,7 +4,7 @@
|
|
4 |
|
5 |
<h1 align="center">ECON: Explicit Clothed humans Obtained from Normals</h1>
|
6 |
<p align="center">
|
7 |
-
<a href="
|
8 |
路
|
9 |
<a href="https://ps.is.tuebingen.mpg.de/person/jyang"><strong>Jinlong Yang</strong></a>
|
10 |
路
|
@@ -28,7 +28,7 @@
|
|
28 |
<img src='https://img.shields.io/badge/Paper-PDF (coming soon)-green?style=for-the-badge&logo=arXiv&logoColor=green' alt='Paper PDF'>
|
29 |
</a>
|
30 |
<a href='https://xiuyuliang.cn/econ/'>
|
31 |
-
<img src='https://img.shields.io/badge/ECON-Page-orange?style=for-the-badge&logo=Google%20chrome&logoColor=
|
32 |
<a href="https://discord.gg/Vqa7KBGRyk"><img src="https://img.shields.io/discord/940240966844035082?color=7289DA&labelColor=4a64bd&logo=discord&logoColor=white&style=for-the-badge"></a>
|
33 |
<a href="https://youtu.be/j5hw4tsWpoY"><img alt="youtube views" title="Subscribe to my YouTube channel" src="https://img.shields.io/youtube/views/j5hw4tsWpoY?logo=youtube&labelColor=ce4630&style=for-the-badge"/></a>
|
34 |
</p>
|
|
|
4 |
|
5 |
<h1 align="center">ECON: Explicit Clothed humans Obtained from Normals</h1>
|
6 |
<p align="center">
|
7 |
+
<a href="http://xiuyuliang.cn/"><strong>Yuliang Xiu</strong></a>
|
8 |
路
|
9 |
<a href="https://ps.is.tuebingen.mpg.de/person/jyang"><strong>Jinlong Yang</strong></a>
|
10 |
路
|
|
|
28 |
<img src='https://img.shields.io/badge/Paper-PDF (coming soon)-green?style=for-the-badge&logo=arXiv&logoColor=green' alt='Paper PDF'>
|
29 |
</a>
|
30 |
<a href='https://xiuyuliang.cn/econ/'>
|
31 |
+
<img src='https://img.shields.io/badge/ECON-Page-orange?style=for-the-badge&logo=Google%20chrome&logoColor=white' alt='Project Page'></a>
|
32 |
<a href="https://discord.gg/Vqa7KBGRyk"><img src="https://img.shields.io/discord/940240966844035082?color=7289DA&labelColor=4a64bd&logo=discord&logoColor=white&style=for-the-badge"></a>
|
33 |
<a href="https://youtu.be/j5hw4tsWpoY"><img alt="youtube views" title="Subscribe to my YouTube channel" src="https://img.shields.io/youtube/views/j5hw4tsWpoY?logo=youtube&labelColor=ce4630&style=for-the-badge"/></a>
|
34 |
</p>
|
apps/avatarizer.py
CHANGED
@@ -3,12 +3,27 @@ import trimesh
|
|
3 |
import torch
|
4 |
import os.path as osp
|
5 |
import lib.smplx as smplx
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from lib.dataset.mesh_util import SMPLX
|
|
|
7 |
|
8 |
smplx_container = SMPLX()
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
12 |
|
13 |
for key in smplx_param.keys():
|
14 |
smplx_param[key] = smplx_param[key].cpu().view(1, -1)
|
@@ -28,20 +43,106 @@ smpl_model = smplx.create(
|
|
28 |
smpl_out = smpl_model(
|
29 |
body_pose=smplx_param["body_pose"],
|
30 |
global_orient=smplx_param["global_orient"],
|
31 |
-
# transl=smplx_param["transl"],
|
32 |
betas=smplx_param["betas"],
|
33 |
expression=smplx_param["expression"],
|
34 |
jaw_pose=smplx_param["jaw_pose"],
|
35 |
left_hand_pose=smplx_param["left_hand_pose"],
|
36 |
right_hand_pose=smplx_param["right_hand_pose"],
|
37 |
return_verts=True,
|
|
|
38 |
return_joint_transformation=True,
|
39 |
return_vertex_transformation=True)
|
40 |
|
41 |
smpl_verts = smpl_out.vertices.detach()[0]
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
-
trimesh.Trimesh(
|
|
|
|
3 |
import torch
|
4 |
import os.path as osp
|
5 |
import lib.smplx as smplx
|
6 |
+
from pytorch3d.ops import SubdivideMeshes
|
7 |
+
from pytorch3d.structures import Meshes
|
8 |
+
|
9 |
+
from lib.smplx.lbs import general_lbs
|
10 |
+
from lib.dataset.mesh_util import keep_largest, poisson
|
11 |
+
from scipy.spatial import cKDTree
|
12 |
from lib.dataset.mesh_util import SMPLX
|
13 |
+
from lib.common.local_affine import register
|
14 |
|
15 |
smplx_container = SMPLX()
|
16 |
+
device = torch.device("cuda:0")
|
17 |
+
|
18 |
+
prefix = "./results/github/econ/obj/304e9c4798a8c3967de7c74c24ef2e38"
|
19 |
+
smpl_path = f"{prefix}_smpl_00.npy"
|
20 |
+
econ_path = f"{prefix}_0_full.obj"
|
21 |
|
22 |
+
smplx_param = np.load(smpl_path, allow_pickle=True).item()
|
23 |
+
econ_obj = trimesh.load(econ_path)
|
24 |
+
econ_obj.vertices *= np.array([1.0, -1.0, -1.0])
|
25 |
+
econ_obj.vertices /= smplx_param["scale"].cpu().numpy()
|
26 |
+
econ_obj.vertices -= smplx_param["transl"].cpu().numpy()
|
27 |
|
28 |
for key in smplx_param.keys():
|
29 |
smplx_param[key] = smplx_param[key].cpu().view(1, -1)
|
|
|
43 |
smpl_out = smpl_model(
|
44 |
body_pose=smplx_param["body_pose"],
|
45 |
global_orient=smplx_param["global_orient"],
|
|
|
46 |
betas=smplx_param["betas"],
|
47 |
expression=smplx_param["expression"],
|
48 |
jaw_pose=smplx_param["jaw_pose"],
|
49 |
left_hand_pose=smplx_param["left_hand_pose"],
|
50 |
right_hand_pose=smplx_param["right_hand_pose"],
|
51 |
return_verts=True,
|
52 |
+
return_full_pose=True,
|
53 |
return_joint_transformation=True,
|
54 |
return_vertex_transformation=True)
|
55 |
|
56 |
smpl_verts = smpl_out.vertices.detach()[0]
|
57 |
+
smpl_tree = cKDTree(smpl_verts.cpu().numpy())
|
58 |
+
dist, idx = smpl_tree.query(econ_obj.vertices, k=5)
|
59 |
+
|
60 |
+
if not osp.exists(f"{prefix}_econ_cano.obj") or not osp.exists(f"{prefix}_smpl_cano.obj"):
|
61 |
+
|
62 |
+
# canonicalize for ECON
|
63 |
+
econ_verts = torch.tensor(econ_obj.vertices).float()
|
64 |
+
inv_mat = torch.inverse(smpl_out.vertex_transformation.detach()[0][idx[:, 0]])
|
65 |
+
homo_coord = torch.ones_like(econ_verts)[..., :1]
|
66 |
+
econ_cano_verts = inv_mat @ torch.cat([econ_verts, homo_coord], dim=1).unsqueeze(-1)
|
67 |
+
econ_cano_verts = econ_cano_verts[:, :3, 0].cpu()
|
68 |
+
econ_cano = trimesh.Trimesh(econ_cano_verts, econ_obj.faces)
|
69 |
+
|
70 |
+
# canonicalize for SMPL-X
|
71 |
+
inv_mat = torch.inverse(smpl_out.vertex_transformation.detach()[0])
|
72 |
+
homo_coord = torch.ones_like(smpl_verts)[..., :1]
|
73 |
+
smpl_cano_verts = inv_mat @ torch.cat([smpl_verts, homo_coord], dim=1).unsqueeze(-1)
|
74 |
+
smpl_cano_verts = smpl_cano_verts[:, :3, 0].cpu()
|
75 |
+
smpl_cano = trimesh.Trimesh(smpl_cano_verts, smpl_model.faces, maintain_orders=True, process=False)
|
76 |
+
smpl_cano.export(f"{prefix}_smpl_cano.obj")
|
77 |
+
|
78 |
+
# remove hands from ECON for next registeration
|
79 |
+
econ_cano_body = econ_cano.copy()
|
80 |
+
mano_mask = ~np.isin(idx[:, 0], smplx_container.smplx_mano_vid)
|
81 |
+
econ_cano_body.update_faces(mano_mask[econ_cano.faces].all(axis=1))
|
82 |
+
econ_cano_body.remove_unreferenced_vertices()
|
83 |
+
econ_cano_body = keep_largest(econ_cano_body)
|
84 |
+
|
85 |
+
# remove SMPL-X hand and face
|
86 |
+
register_mask = ~np.isin(
|
87 |
+
np.arange(smpl_cano_verts.shape[0]),
|
88 |
+
np.concatenate([smplx_container.smplx_mano_vid, smplx_container.smplx_front_flame_vid]))
|
89 |
+
register_mask *= ~smplx_container.eyeball_vertex_mask.bool().numpy()
|
90 |
+
smpl_cano_body = smpl_cano.copy()
|
91 |
+
smpl_cano_body.update_faces(register_mask[smpl_cano.faces].all(axis=1))
|
92 |
+
smpl_cano_body.remove_unreferenced_vertices()
|
93 |
+
smpl_cano_body = keep_largest(smpl_cano_body)
|
94 |
+
|
95 |
+
# upsample the smpl_cano_body and do registeration
|
96 |
+
smpl_cano_body = Meshes(
|
97 |
+
verts=[torch.tensor(smpl_cano_body.vertices).float()],
|
98 |
+
faces=[torch.tensor(smpl_cano_body.faces).long()],
|
99 |
+
).to(device)
|
100 |
+
sm = SubdivideMeshes(smpl_cano_body)
|
101 |
+
smpl_cano_body = register(econ_cano_body, sm(smpl_cano_body), device)
|
102 |
+
|
103 |
+
# remove over-streched+hand faces from ECON
|
104 |
+
econ_cano_body = econ_cano.copy()
|
105 |
+
edge_before = np.sqrt(
|
106 |
+
((econ_obj.vertices[econ_cano.edges[:, 0]] - econ_obj.vertices[econ_cano.edges[:, 1]])**2).sum(axis=1))
|
107 |
+
edge_after = np.sqrt(
|
108 |
+
((econ_cano.vertices[econ_cano.edges[:, 0]] - econ_cano.vertices[econ_cano.edges[:, 1]])**2).sum(axis=1))
|
109 |
+
edge_diff = edge_after / edge_before.clip(1e-2)
|
110 |
+
streched_mask = np.unique(econ_cano.edges[edge_diff > 6])
|
111 |
+
mano_mask = ~np.isin(idx[:, 0], smplx_container.smplx_mano_vid)
|
112 |
+
mano_mask[streched_mask] = False
|
113 |
+
econ_cano_body.update_faces(mano_mask[econ_cano.faces].all(axis=1))
|
114 |
+
econ_cano_body.remove_unreferenced_vertices()
|
115 |
+
|
116 |
+
# stitch the registered SMPL-X body and floating hands to ECON
|
117 |
+
econ_cano_tree = cKDTree(econ_cano.vertices)
|
118 |
+
dist, idx = econ_cano_tree.query(smpl_cano_body.vertices, k=1)
|
119 |
+
smpl_cano_body.update_faces((dist > 0.02)[smpl_cano_body.faces].all(axis=1))
|
120 |
+
smpl_cano_body.remove_unreferenced_vertices()
|
121 |
+
|
122 |
+
smpl_hand = smpl_cano.copy()
|
123 |
+
smpl_hand.update_faces(smplx_container.mano_vertex_mask.numpy()[smpl_hand.faces].all(axis=1))
|
124 |
+
smpl_hand.remove_unreferenced_vertices()
|
125 |
+
econ_cano = sum([smpl_hand, smpl_cano_body, econ_cano_body])
|
126 |
+
econ_cano = poisson(econ_cano, f"{prefix}_econ_cano.obj")
|
127 |
+
else:
|
128 |
+
econ_cano = trimesh.load(f"{prefix}_econ_cano.obj")
|
129 |
+
smpl_cano = trimesh.load(f"{prefix}_smpl_cano.obj", maintain_orders=True, process=False)
|
130 |
+
|
131 |
+
smpl_tree = cKDTree(smpl_cano.vertices)
|
132 |
+
dist, idx = smpl_tree.query(econ_cano.vertices, k=2)
|
133 |
+
knn_weights = np.exp(-dist**2)
|
134 |
+
knn_weights /= knn_weights.sum(axis=1, keepdims=True)
|
135 |
+
econ_J_regressor = (smpl_model.J_regressor[:, idx] * knn_weights[None]).sum(axis=-1)
|
136 |
+
econ_lbs_weights = (smpl_model.lbs_weights.T[:, idx] * knn_weights[None]).sum(axis=-1).T
|
137 |
+
econ_J_regressor /= econ_J_regressor.sum(axis=1, keepdims=True)
|
138 |
+
econ_lbs_weights /= econ_lbs_weights.sum(axis=1, keepdims=True)
|
139 |
+
|
140 |
+
posed_econ_verts, _ = general_lbs(
|
141 |
+
pose=smpl_out.full_pose,
|
142 |
+
v_template=torch.tensor(econ_cano.vertices).unsqueeze(0),
|
143 |
+
J_regressor=econ_J_regressor,
|
144 |
+
parents=smpl_model.parents,
|
145 |
+
lbs_weights=econ_lbs_weights)
|
146 |
|
147 |
+
econ_pose = trimesh.Trimesh(posed_econ_verts[0].detach(), econ_cano.faces)
|
148 |
+
econ_pose.export(f"{prefix}_econ_pose.obj")
|
apps/infer.py
CHANGED
@@ -37,6 +37,7 @@ from lib.common.train_util import init_loss, load_normal_networks, load_networks
|
|
37 |
from lib.common.BNI import BNI
|
38 |
from lib.common.BNI_utils import save_normal_tensor
|
39 |
from lib.dataset.TestDataset import TestDataset
|
|
|
40 |
from lib.net.geometry import rot6d_to_rotmat, rotation_matrix_to_angle_axis
|
41 |
from lib.dataset.mesh_util import *
|
42 |
from lib.common.voxelize import VoxelGrid
|
@@ -156,8 +157,8 @@ if __name__ == "__main__":
|
|
156 |
|
157 |
N_body, N_pose = optimed_pose.shape[:2]
|
158 |
|
159 |
-
smpl_path =
|
160 |
-
|
161 |
if osp.exists(smpl_path):
|
162 |
|
163 |
smpl_verts_lst = []
|
@@ -182,6 +183,7 @@ if __name__ == "__main__":
|
|
182 |
|
183 |
in_tensor["smpl_verts"] = batch_smpl_verts * torch.tensor([1., -1., 1.]).to(device)
|
184 |
in_tensor["smpl_faces"] = batch_smpl_faces[:, :, [0, 2, 1]]
|
|
|
185 |
else:
|
186 |
# smpl optimization
|
187 |
loop_smpl = tqdm(range(args.loop_smpl))
|
@@ -447,15 +449,15 @@ if __name__ == "__main__":
|
|
447 |
(SMPLX_object.front_flame_vertex_mask + SMPLX_object.mano_vertex_mask +
|
448 |
SMPLX_object.eyeball_vertex_mask).eq(0).float(),
|
449 |
)
|
450 |
-
|
451 |
-
#
|
452 |
-
|
453 |
verts=[torch.tensor(side_mesh.vertices).float()],
|
454 |
faces=[torch.tensor(side_mesh.faces).long()],
|
455 |
-
)
|
456 |
-
sm = SubdivideMeshes(
|
457 |
-
|
458 |
-
|
459 |
|
460 |
side_verts = torch.tensor(side_mesh.vertices).float().to(device)
|
461 |
side_faces = torch.tensor(side_mesh.faces).long().to(device)
|
|
|
37 |
from lib.common.BNI import BNI
|
38 |
from lib.common.BNI_utils import save_normal_tensor
|
39 |
from lib.dataset.TestDataset import TestDataset
|
40 |
+
from lib.common.local_affine import register
|
41 |
from lib.net.geometry import rot6d_to_rotmat, rotation_matrix_to_angle_axis
|
42 |
from lib.dataset.mesh_util import *
|
43 |
from lib.common.voxelize import VoxelGrid
|
|
|
157 |
|
158 |
N_body, N_pose = optimed_pose.shape[:2]
|
159 |
|
160 |
+
smpl_path = f"{args.out_dir}/{cfg.name}/obj/{data['name']}_smpl_00.obj"
|
161 |
+
|
162 |
if osp.exists(smpl_path):
|
163 |
|
164 |
smpl_verts_lst = []
|
|
|
183 |
|
184 |
in_tensor["smpl_verts"] = batch_smpl_verts * torch.tensor([1., -1., 1.]).to(device)
|
185 |
in_tensor["smpl_faces"] = batch_smpl_faces[:, :, [0, 2, 1]]
|
186 |
+
|
187 |
else:
|
188 |
# smpl optimization
|
189 |
loop_smpl = tqdm(range(args.loop_smpl))
|
|
|
449 |
(SMPLX_object.front_flame_vertex_mask + SMPLX_object.mano_vertex_mask +
|
450 |
SMPLX_object.eyeball_vertex_mask).eq(0).float(),
|
451 |
)
|
452 |
+
|
453 |
+
#register side_mesh to BNI surfaces
|
454 |
+
side_mesh = Meshes(
|
455 |
verts=[torch.tensor(side_mesh.vertices).float()],
|
456 |
faces=[torch.tensor(side_mesh.faces).long()],
|
457 |
+
).to(device)
|
458 |
+
sm = SubdivideMeshes(side_mesh)
|
459 |
+
side_mesh = register(BNI_object.F_B_trimesh, sm(side_mesh), device)
|
460 |
+
|
461 |
|
462 |
side_verts = torch.tensor(side_mesh.vertices).float().to(device)
|
463 |
side_faces = torch.tensor(side_mesh.faces).long().to(device)
|
lib/common/local_affine.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 by Haozhe Wu, Tsinghua University, Department of Computer Science and Technology.
|
2 |
+
# All rights reserved.
|
3 |
+
# This file is part of the pytorch-nicp,
|
4 |
+
# and is released under the "MIT License Agreement". Please see the LICENSE
|
5 |
+
# file that should have been included as part of this package.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import trimesh
|
9 |
+
import torch.nn as nn
|
10 |
+
from tqdm import tqdm
|
11 |
+
from pytorch3d.structures import Meshes
|
12 |
+
from pytorch3d.loss import chamfer_distance
|
13 |
+
from lib.dataset.mesh_util import update_mesh_shape_prior_losses
|
14 |
+
from lib.common.train_util import init_loss
|
15 |
+
|
16 |
+
|
17 |
+
# reference: https://github.com/wuhaozhe/pytorch-nicp
|
18 |
+
class LocalAffine(nn.Module):
|
19 |
+
|
20 |
+
def __init__(self, num_points, batch_size=1, edges=None):
|
21 |
+
'''
|
22 |
+
specify the number of points, the number of points should be constant across the batch
|
23 |
+
and the edges torch.Longtensor() with shape N * 2
|
24 |
+
the local affine operator supports batch operation
|
25 |
+
batch size must be constant
|
26 |
+
add additional pooling on top of w matrix
|
27 |
+
'''
|
28 |
+
super(LocalAffine, self).__init__()
|
29 |
+
self.A = nn.Parameter(torch.eye(3).unsqueeze(0).unsqueeze(0).repeat(batch_size, num_points, 1, 1))
|
30 |
+
self.b = nn.Parameter(torch.zeros(3).unsqueeze(0).unsqueeze(0).unsqueeze(3).repeat(batch_size, num_points, 1, 1))
|
31 |
+
self.edges = edges
|
32 |
+
self.num_points = num_points
|
33 |
+
|
34 |
+
def stiffness(self):
|
35 |
+
'''
|
36 |
+
calculate the stiffness of local affine transformation
|
37 |
+
f norm get infinity gradient when w is zero matrix,
|
38 |
+
'''
|
39 |
+
if self.edges is None:
|
40 |
+
raise Exception("edges cannot be none when calculate stiff")
|
41 |
+
idx1 = self.edges[:, 0]
|
42 |
+
idx2 = self.edges[:, 1]
|
43 |
+
affine_weight = torch.cat((self.A, self.b), dim=3)
|
44 |
+
w1 = torch.index_select(affine_weight, dim=1, index=idx1)
|
45 |
+
w2 = torch.index_select(affine_weight, dim=1, index=idx2)
|
46 |
+
w_diff = (w1 - w2)**2
|
47 |
+
w_rigid = (torch.linalg.det(self.A) - 1.0)**2
|
48 |
+
return w_diff, w_rigid
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
'''
|
52 |
+
x should have shape of B * N * 3
|
53 |
+
'''
|
54 |
+
x = x.unsqueeze(3)
|
55 |
+
out_x = torch.matmul(self.A, x)
|
56 |
+
out_x = out_x + self.b
|
57 |
+
stiffness, rigid = self.stiffness()
|
58 |
+
out_x.squeeze_(3)
|
59 |
+
return out_x, stiffness, rigid
|
60 |
+
|
61 |
+
|
62 |
+
def trimesh2meshes(mesh):
|
63 |
+
'''
|
64 |
+
convert trimesh mesh to pytorch3d mesh
|
65 |
+
'''
|
66 |
+
verts = torch.from_numpy(mesh.vertices).float()
|
67 |
+
faces = torch.from_numpy(mesh.faces).long()
|
68 |
+
mesh = Meshes(verts.unsqueeze(0), faces.unsqueeze(0))
|
69 |
+
return mesh
|
70 |
+
|
71 |
+
|
72 |
+
def register(target_mesh, src_mesh, device):
|
73 |
+
|
74 |
+
# define local_affine deform verts
|
75 |
+
tgt_mesh = trimesh2meshes(target_mesh).to(device)
|
76 |
+
src_verts = src_mesh.verts_padded().clone()
|
77 |
+
|
78 |
+
local_affine_model = LocalAffine(src_mesh.verts_padded().shape[1],
|
79 |
+
src_mesh.verts_padded().shape[0], src_mesh.edges_packed()).to(device)
|
80 |
+
|
81 |
+
optimizer_cloth = torch.optim.Adam([{'params': local_affine_model.parameters()}], lr=1e-2, amsgrad=True)
|
82 |
+
scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
83 |
+
optimizer_cloth,
|
84 |
+
mode="min",
|
85 |
+
factor=0.1,
|
86 |
+
verbose=0,
|
87 |
+
min_lr=1e-5,
|
88 |
+
patience=5,
|
89 |
+
)
|
90 |
+
|
91 |
+
losses = init_loss()
|
92 |
+
|
93 |
+
loop_cloth = tqdm(range(200))
|
94 |
+
|
95 |
+
for i in loop_cloth:
|
96 |
+
|
97 |
+
optimizer_cloth.zero_grad()
|
98 |
+
|
99 |
+
deformed_verts, stiffness, rigid = local_affine_model(src_verts)
|
100 |
+
src_mesh = src_mesh.update_padded(deformed_verts)
|
101 |
+
|
102 |
+
# losses for laplacian, edge, normal consistency
|
103 |
+
update_mesh_shape_prior_losses(src_mesh, losses)
|
104 |
+
|
105 |
+
losses["cloth"]["value"] = chamfer_distance(
|
106 |
+
x=src_mesh.verts_padded(),
|
107 |
+
y=tgt_mesh.verts_padded())[0]
|
108 |
+
|
109 |
+
losses["stiffness"]["value"] = torch.mean(stiffness)
|
110 |
+
losses["rigid"]["value"] = torch.mean(rigid)
|
111 |
+
|
112 |
+
# Weighted sum of the losses
|
113 |
+
cloth_loss = torch.tensor(0.0, requires_grad=True).to(device)
|
114 |
+
pbar_desc = "Register SMPL-X towards ECON --- "
|
115 |
+
|
116 |
+
for k in losses.keys():
|
117 |
+
if losses[k]["weight"] > 0.0 and losses[k]["value"] != 0.0:
|
118 |
+
cloth_loss = cloth_loss + \
|
119 |
+
losses[k]["value"] * losses[k]["weight"]
|
120 |
+
pbar_desc += f"{k}:{losses[k]['value']* losses[k]['weight']:.3f} | "
|
121 |
+
|
122 |
+
pbar_desc += f"Total: {cloth_loss:.5f}"
|
123 |
+
loop_cloth.set_description(pbar_desc)
|
124 |
+
|
125 |
+
# update params
|
126 |
+
cloth_loss.backward(retain_graph=True)
|
127 |
+
optimizer_cloth.step()
|
128 |
+
scheduler_cloth.step(cloth_loss)
|
129 |
+
|
130 |
+
final = trimesh.Trimesh(
|
131 |
+
src_mesh.verts_packed().detach().squeeze(0).cpu(),
|
132 |
+
src_mesh.faces_packed().detach().squeeze(0).cpu(),
|
133 |
+
process=False,
|
134 |
+
maintains_order=True)
|
135 |
+
|
136 |
+
return final
|
lib/common/train_util.py
CHANGED
@@ -32,7 +32,7 @@ def init_loss():
|
|
32 |
losses = {
|
33 |
# Cloth: Normal_recon - Normal_pred
|
34 |
"cloth": {
|
35 |
-
"weight":
|
36 |
"value": 0.0
|
37 |
},
|
38 |
# Cloth: [RT]_v1 - [RT]_v2 (v1-edge-v2)
|
|
|
32 |
losses = {
|
33 |
# Cloth: Normal_recon - Normal_pred
|
34 |
"cloth": {
|
35 |
+
"weight": 1e3,
|
36 |
"value": 0.0
|
37 |
},
|
38 |
# Cloth: [RT]_v1 - [RT]_v2 (v1-edge-v2)
|
lib/dataset/mesh_util.py
CHANGED
@@ -552,7 +552,7 @@ def poisson_remesh(obj_path):
|
|
552 |
ms.meshing_decimation_quadric_edge_collapse(targetfacenum=50000)
|
553 |
# ms.apply_coord_laplacian_smoothing()
|
554 |
ms.save_current_mesh(obj_path)
|
555 |
-
ms.save_current_mesh(obj_path.replace(".obj", ".ply"))
|
556 |
polished_mesh = trimesh.load_mesh(obj_path)
|
557 |
|
558 |
return polished_mesh
|
@@ -1013,6 +1013,15 @@ def clean_floats(mesh):
|
|
1013 |
return sum(clean_mesh_lst)
|
1014 |
|
1015 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1016 |
def mesh_move(mesh_lst, step, scale=1.0):
|
1017 |
|
1018 |
trans = np.array([1.0, 0.0, 0.0]) * step
|
@@ -1036,3 +1045,16 @@ def rescale_smpl(fitted_path, scale=100, translate=(0, 0, 0)):
|
|
1036 |
fitted_body.apply_transform(resize_matrix)
|
1037 |
|
1038 |
return np.array(fitted_body.vertices)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
552 |
ms.meshing_decimation_quadric_edge_collapse(targetfacenum=50000)
|
553 |
# ms.apply_coord_laplacian_smoothing()
|
554 |
ms.save_current_mesh(obj_path)
|
555 |
+
# ms.save_current_mesh(obj_path.replace(".obj", ".ply"))
|
556 |
polished_mesh = trimesh.load_mesh(obj_path)
|
557 |
|
558 |
return polished_mesh
|
|
|
1013 |
return sum(clean_mesh_lst)
|
1014 |
|
1015 |
|
1016 |
+
def keep_largest(mesh):
|
1017 |
+
mesh_lst = mesh.split(only_watertight=False)
|
1018 |
+
keep_mesh = mesh_lst[0]
|
1019 |
+
for mesh in mesh_lst:
|
1020 |
+
if mesh.vertices.shape[0] > keep_mesh.vertices.shape[0]:
|
1021 |
+
keep_mesh = mesh
|
1022 |
+
return keep_mesh
|
1023 |
+
|
1024 |
+
|
1025 |
def mesh_move(mesh_lst, step, scale=1.0):
|
1026 |
|
1027 |
trans = np.array([1.0, 0.0, 0.0]) * step
|
|
|
1045 |
fitted_body.apply_transform(resize_matrix)
|
1046 |
|
1047 |
return np.array(fitted_body.vertices)
|
1048 |
+
|
1049 |
+
|
1050 |
+
def get_joint_mesh(joints, radius=2.0):
|
1051 |
+
|
1052 |
+
ball = trimesh.creation.icosphere(radius=radius)
|
1053 |
+
combined = None
|
1054 |
+
for joint in joints:
|
1055 |
+
ball_new = trimesh.Trimesh(vertices=ball.vertices + joint, faces=ball.faces, process=False)
|
1056 |
+
if combined is None:
|
1057 |
+
combined = ball_new
|
1058 |
+
else:
|
1059 |
+
combined = sum([combined, ball_new])
|
1060 |
+
return combined
|
lib/smplx/lbs.py
CHANGED
@@ -194,6 +194,7 @@ def lbs(
|
|
194 |
# 3. Add pose blend shapes
|
195 |
# N x J x 3 x 3
|
196 |
ident = torch.eye(3, dtype=dtype, device=device)
|
|
|
197 |
if pose2rot:
|
198 |
rot_mats = batch_rodrigues(pose.view(-1, 3)).view([batch_size, -1, 3, 3])
|
199 |
|
@@ -229,6 +230,77 @@ def lbs(
|
|
229 |
return verts, J_transformed
|
230 |
|
231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
def vertices2joints(J_regressor: Tensor, vertices: Tensor) -> Tensor:
|
233 |
"""Calculates the 3D joint locations from the vertices
|
234 |
|
|
|
194 |
# 3. Add pose blend shapes
|
195 |
# N x J x 3 x 3
|
196 |
ident = torch.eye(3, dtype=dtype, device=device)
|
197 |
+
|
198 |
if pose2rot:
|
199 |
rot_mats = batch_rodrigues(pose.view(-1, 3)).view([batch_size, -1, 3, 3])
|
200 |
|
|
|
230 |
return verts, J_transformed
|
231 |
|
232 |
|
233 |
+
def general_lbs(
|
234 |
+
pose: Tensor,
|
235 |
+
v_template: Tensor,
|
236 |
+
J_regressor: Tensor,
|
237 |
+
parents: Tensor,
|
238 |
+
lbs_weights: Tensor,
|
239 |
+
pose2rot: bool = True,
|
240 |
+
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
|
241 |
+
"""Performs Linear Blend Skinning with the given shape and pose parameters
|
242 |
+
|
243 |
+
Parameters
|
244 |
+
----------
|
245 |
+
pose : torch.tensor Bx(J + 1) * 3
|
246 |
+
The pose parameters in axis-angle format
|
247 |
+
v_template torch.tensor BxVx3
|
248 |
+
The template mesh that will be deformed
|
249 |
+
J_regressor : torch.tensor JxV
|
250 |
+
The regressor array that is used to calculate the joints from
|
251 |
+
the position of the vertices
|
252 |
+
parents: torch.tensor J
|
253 |
+
The array that describes the kinematic tree for the model
|
254 |
+
lbs_weights: torch.tensor N x V x (J + 1)
|
255 |
+
The linear blend skinning weights that represent how much the
|
256 |
+
rotation matrix of each part affects each vertex
|
257 |
+
pose2rot: bool, optional
|
258 |
+
Flag on whether to convert the input pose tensor to rotation
|
259 |
+
matrices. The default value is True. If False, then the pose tensor
|
260 |
+
should already contain rotation matrices and have a size of
|
261 |
+
Bx(J + 1)x9
|
262 |
+
dtype: torch.dtype, optional
|
263 |
+
|
264 |
+
Returns
|
265 |
+
-------
|
266 |
+
verts: torch.tensor BxVx3
|
267 |
+
The vertices of the mesh after applying the shape and pose
|
268 |
+
displacements.
|
269 |
+
joints: torch.tensor BxJx3
|
270 |
+
The joints of the model
|
271 |
+
"""
|
272 |
+
|
273 |
+
batch_size = pose.shape[0]
|
274 |
+
device, dtype = pose.device, pose.dtype
|
275 |
+
|
276 |
+
# Get the joints
|
277 |
+
# NxJx3 array
|
278 |
+
J = vertices2joints(J_regressor, v_template)
|
279 |
+
|
280 |
+
if pose2rot:
|
281 |
+
rot_mats = batch_rodrigues(pose.view(-1, 3)).view([batch_size, -1, 3, 3])
|
282 |
+
else:
|
283 |
+
rot_mats = pose.view(batch_size, -1, 3, 3)
|
284 |
+
|
285 |
+
# 4. Get the global joint location
|
286 |
+
J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)
|
287 |
+
|
288 |
+
# 5. Do skinning:
|
289 |
+
# W is N x V x (J + 1)
|
290 |
+
W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
|
291 |
+
# (N x V x (J + 1)) x (N x (J + 1) x 16)
|
292 |
+
num_joints = J_regressor.shape[0]
|
293 |
+
T = torch.matmul(W, A.view(batch_size, num_joints, 16)).view(batch_size, -1, 4, 4)
|
294 |
+
|
295 |
+
homogen_coord = torch.ones([batch_size, v_template.shape[1], 1], dtype=dtype, device=device)
|
296 |
+
v_posed_homo = torch.cat([v_template, homogen_coord], dim=2)
|
297 |
+
v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
|
298 |
+
|
299 |
+
verts = v_homo[:, :, :3, 0]
|
300 |
+
|
301 |
+
return verts, J
|
302 |
+
|
303 |
+
|
304 |
def vertices2joints(J_regressor: Tensor, vertices: Tensor) -> Tensor:
|
305 |
"""Calculates the 3D joint locations from the vertices
|
306 |
|