from core.remesh import calc_vertex_normals
from core.opt import MeshOptimizer
from utils.func import make_sparse_camera, make_round_views
from utils.render import NormalsRenderer
from tqdm import tqdm
from utils.video_utils import write_video
from omegaconf import OmegaConf
import numpy as np
import os
from PIL import Image
import kornia
import torch
import torch.nn as nn
import trimesh
from utils.project_mesh import multiview_color_projection, project_color, get_cameras_list
from utils.mesh_utils import to_py3d_mesh, rot6d_to_rotmat, tensor2variable
from utils.smpl_util import SMPLX
from lib.dataset.mesh_util import apply_vertex_mask, part_removal, poisson, keep_largest
from scipy.spatial.transform import Rotation as R
from scipy.spatial import KDTree
import argparse

bg_color = np.array([1., 1., 1.])

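
# colorModel exposes per-vertex colors as the only trainable parameter and
# re-renders the mesh each forward pass, compositing over bg_color, so texture
# can be refined with any torch optimizer. (Defined here but not referenced
# elsewhere in this file.)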
class colorModel(nn.Module):
    def __init__(self, renderer, v, f, c):
        super().__init__()
        self.renderer = renderer
        self.v = v
        self.f = f
        self.colors = nn.Parameter(c, requires_grad=True)
        self.bg_color = torch.from_numpy(bg_color).float().to(self.colors.device)

    def forward(self, return_mask=False):
        rgba = self.renderer.render(self.v, self.f, colors=self.colors)
        if return_mask:
            return rgba
        else:
            mask = rgba[..., 3:]
            return rgba[..., :3] * mask + self.bg_color * (1 - mask)

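
# scale_mesh returns the translation and scale that center the vertices and
# bring their farthest point to unit distance from the origin.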
def scale_mesh(vert):
    min_bbox, max_bbox = vert.min(0)[0], vert.max(0)[0]
    center = (min_bbox + max_bbox) / 2
    offset = -center
    vert = vert + offset

    max_dist = torch.max(torch.sqrt(torch.sum(vert**2, dim=1)))
    scale = 1.0 / max_dist
    return scale, offset


def save_mesh(save_name, vertices, faces, color=None):
    trimesh.Trimesh(
        vertices.detach().cpu().numpy(),
        faces.detach().cpu().numpy(),
        vertex_colors=(color.detach().cpu().numpy() * 255).astype(np.uint8) if color is not None else None) \
        .export(save_name)

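
# ReMesh drives the reconstruction pipeline: fit SMPL-X to multi-view normal
# maps, remesh the fitted body against those normals, then project and
# inpaint per-vertex colors.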
class ReMesh:
    def __init__(self, opt, econ_dataset):
        self.opt = opt
        self.device = torch.device(f"cuda:{opt.gpu_id}" if torch.cuda.is_available() else "cpu")
        self.num_view = opt.num_view

        self.out_path = opt.res_path
        os.makedirs(self.out_path, exist_ok=True)
        self.resolution = opt.resolution
        self.views = ['front_face', 'front_right', 'right', 'back', 'left', 'front_left']
        self.weights = torch.Tensor([1., 0.4, 0.8, 1.0, 0.8, 0.4]).view(6, 1, 1, 1).to(self.device)

        self.renderer = self.prepare_render()

        self.econ_dataset = econ_dataset
        self.smplx_face = torch.Tensor(econ_dataset.faces.astype(np.int64)).long().to(self.device)

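    # The six camera indices [0, 1, 2, 4, 6, 7] pick azimuths 0, 45, 90, 180,
    # 270 and 315 degrees out of eight 45-degree steps, which lines up with the
    # azimuth list used for texture projection in proj_texture below.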
    def prepare_render(self):
        mv, proj = make_sparse_camera(self.opt.cam_path, self.opt.scale, views=[0, 1, 2, 4, 6, 7], device=self.device)
        renderer = NormalsRenderer(mv, proj, [self.resolution, self.resolution], device=self.device)
        return renderer

    def proj_texture(self, fused_images, vertices, faces):
        mesh = to_py3d_mesh(vertices, faces)
        mesh = mesh.to(self.device)
        camera_focal = 1 / 2
        cameras_list = get_cameras_list([0, 45, 90, 180, 270, 315], device=self.device, focal=camera_focal)
        mesh = multiview_color_projection(mesh, fused_images, camera_focal=camera_focal,
                                          resolution=self.resolution,
                                          weights=self.weights.squeeze().cpu().numpy(),
                                          device=self.device, complete_unseen=True,
                                          confidence_threshold=0.2, cameras_list=cameras_list)
        return mesh

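    # get_invisible_idx accumulates weighted per-view vertex colors and flags
    # vertices whose accumulated visibility weight stays below 1 as invisible.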
    def get_invisible_idx(self, imgs, vertices, faces):
        mesh = to_py3d_mesh(vertices, faces)
        mesh = mesh.to(self.device)
        camera_focal = 1 / 2
        if self.num_view == 6:
            cameras_list = get_cameras_list([0, 45, 90, 180, 270, 315], device=self.device, focal=camera_focal)
        elif self.num_view == 4:
            cameras_list = get_cameras_list([0, 90, 180, 270], device=self.device, focal=camera_focal)
        vertices_colors = torch.zeros((vertices.shape[0], 3)).float().to(self.device)
        valid_cnt = torch.zeros((vertices.shape[0])).to(self.device)
        for cam, img, weight in zip(cameras_list, imgs, self.weights.squeeze()):
            ret = project_color(mesh, cam, img, eps=0.01, resolution=self.resolution, device=self.device)
            valid_cnt[ret['valid_verts']] += weight
            vertices_colors[ret['valid_verts']] += ret['valid_colors'] * weight
        valid_mask = valid_cnt > 1
        vertices_colors[valid_mask] /= valid_cnt[valid_mask][:, None]

        invisible_vert = valid_cnt < 1
        invisible_vert_indices = torch.nonzero(invisible_vert).squeeze()

        return vertices_colors, invisible_vert_indices

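    # Vertices no view observed are filled with the color of their nearest
    # visible neighbor via a KDTree query (nearest-neighbor fill rather than a
    # true interpolation).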
    def inpaint_missed_colors(self, all_vertices, all_colors, missing_indices):
        all_vertices = all_vertices.detach().cpu().numpy()
        all_colors = all_colors.detach().cpu().numpy()
        missing_indices = missing_indices.detach().cpu().numpy()

        non_missing_indices = np.setdiff1d(np.arange(len(all_vertices)), missing_indices)
        kdtree = KDTree(all_vertices[non_missing_indices])

        for missing_index in missing_indices:
            missing_vertex = all_vertices[missing_index]
            _, nearest_index = kdtree.query(missing_vertex.reshape(1, -1))
            all_colors[missing_index] = all_colors[non_missing_indices[nearest_index]]

        return torch.from_numpy(all_colors).to(self.device)

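    # load_training_data reads the per-view normal/color PNGs rendered for a
    # case; each normal map's alpha channel is eroded before use as the
    # supervision mask so that unreliable silhouette pixels are trimmed.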
    def load_training_data(self, case):
        kernel = torch.ones(3, 3)
        erode_iters = 2
        normals = []
        masks = []
        colors = []
        for view in self.views:
            normal = Image.open(f'{self.opt.mv_path}/{case}/normals_{view}_masked.png')
            normal = normal.convert('RGBA').resize((self.resolution, self.resolution), Image.BILINEAR)
            normal = np.array(normal).astype(np.float32) / 255.
            mask = normal[..., 3:]
            # Erode the alpha mask iteratively. kornia expects (B, C, H, W), so the
            # channel axis is moved first; the original code re-eroded the un-eroded
            # input on every pass and never compounded the iterations.
            mask_torch = torch.from_numpy(mask).permute(2, 0, 1).unsqueeze(0)
            for _ in range(erode_iters):
                mask_torch = kornia.morphology.erosion(mask_torch, kernel)
            mask_erode = mask_torch.squeeze(0).permute(1, 2, 0).numpy()
            masks.append(mask_erode)
            normal = normal[..., :3] * mask_erode
            normals.append(normal)

            color = Image.open(f'{self.opt.mv_path}/{case}/color_{view}_masked.png')
            color = color.convert('RGBA').resize((self.resolution, self.resolution), Image.BILINEAR)
            color = np.array(color).astype(np.float32) / 255.
            color_dilate = color[..., :3] * mask_erode + bg_color * (1 - mask_erode)
            colors.append(color_dilate)

        masks = torch.from_numpy(np.stack(masks, 0)).to(self.device)
        target_normals = torch.from_numpy(np.stack(normals, 0)).to(self.device)
        target_colors = torch.from_numpy(np.stack(colors, 0)).to(self.device)
        return masks, target_colors, target_normals

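    # preprocess mirrors load_training_data but consumes in-memory RGBA PIL
    # images instead of files on disk.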
    def preprocess(self, color_pils, normal_pils):
        kernel = torch.ones(3, 3)
        erode_iters = 2
        normals = []
        masks = []
        colors = []
        for normal, color in zip(normal_pils, color_pils):
            normal = normal.resize((self.resolution, self.resolution), Image.BILINEAR)
            normal = np.array(normal).astype(np.float32) / 255.
            mask = normal[..., 3:]
            # Same iterative (B, C, H, W) erosion as in load_training_data.
            mask_torch = torch.from_numpy(mask).permute(2, 0, 1).unsqueeze(0)
            for _ in range(erode_iters):
                mask_torch = kornia.morphology.erosion(mask_torch, kernel)
            mask_erode = mask_torch.squeeze(0).permute(1, 2, 0).numpy()
            masks.append(mask_erode)
            normal = normal[..., :3] * mask_erode
            normals.append(normal)

            color = color.resize((self.resolution, self.resolution), Image.BILINEAR)
            color = np.array(color).astype(np.float32) / 255.
            color_dilate = color[..., :3] * mask_erode + bg_color * (1 - mask_erode)
            colors.append(color_dilate)

        masks = torch.from_numpy(np.stack(masks, 0)).to(self.device)
        target_normals = torch.from_numpy(np.stack(normals, 0)).to(self.device)
        target_colors = torch.from_numpy(np.stack(colors, 0)).to(self.device)
        return masks, target_colors, target_normals

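    # optimize_case runs the full per-subject pipeline:
    #   1. fit SMPL-X pose/shape to the target masks and normal maps,
    #   2. remesh the fitted body against the normals with MeshOptimizer,
    #   3. optionally splice the SMPL-X hands back in, run Poisson
    #      reconstruction, and project multi-view colors onto the result.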
    def optimize_case(self, case, pose, clr_img, nrm_img, opti_texture=True):
        case_path = f'{self.out_path}/{case}'
        os.makedirs(case_path, exist_ok=True)

        if clr_img is not None:
            masks, target_colors, target_normals = self.preprocess(clr_img, nrm_img)
        else:
            masks, target_colors, target_normals = self.load_training_data(case)

        # Rotations that bring the SMPL-X output into the renderer's frame.
        rz = R.from_euler('z', 180, degrees=True).as_matrix()
        ry = R.from_euler('y', 180, degrees=True).as_matrix()
        rz = torch.from_numpy(rz).float().to(self.device)
        ry = torch.from_numpy(ry).float().to(self.device)

        scale, offset = None, None

        global_orient = pose["global_orient"]
        body_pose = pose["body_pose"]
        left_hand_pose = pose["left_hand_pose"]
        right_hand_pose = pose["right_hand_pose"]
        beta = pose["betas"]

        optimed_pose = torch.tensor(body_pose, device=self.device, requires_grad=True)
        optimed_trans = torch.tensor(pose["trans"], device=self.device, requires_grad=True)
        optimed_betas = torch.tensor(beta, device=self.device, requires_grad=True)
        optimed_orient = torch.tensor(global_orient, device=self.device, requires_grad=True)
        optimed_rhand = torch.tensor(right_hand_pose, device=self.device, requires_grad=True)
        optimed_lhand = torch.tensor(left_hand_pose, device=self.device, requires_grad=True)

        optimed_params = [
            {'params': [optimed_lhand, optimed_rhand], 'lr': 1e-3},
            {'params': [optimed_betas, optimed_trans, optimed_orient, optimed_pose], 'lr': 3e-3},
        ]
        optimizer_smpl = torch.optim.Adam(
            optimed_params,
            amsgrad=True,
        )
        scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_smpl,
            mode="min",
            factor=0.5,
            min_lr=1e-5,
            patience=5,
        )
        smpl_steps = 100

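        # Orientation and body pose are optimized in the continuous 6D rotation
        # representation and converted to rotation matrices each step via
        # rot6d_to_rotmat.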
        for i in tqdm(range(smpl_steps)):
            optimizer_smpl.zero_grad()

            optimed_orient_mat = rot6d_to_rotmat(optimed_orient.view(-1, 6)).unsqueeze(0)
            optimed_pose_mat = rot6d_to_rotmat(optimed_pose.view(-1, 6)).unsqueeze(0)

            smpl_verts, smpl_landmarks, smpl_joints = self.econ_dataset.smpl_model(
                shape_params=optimed_betas,
                expression_params=tensor2variable(pose["exp"], self.device),
                body_pose=optimed_pose_mat,
                global_pose=optimed_orient_mat,
                jaw_pose=tensor2variable(pose["jaw_pose"], self.device),
                left_hand_pose=optimed_lhand,
                right_hand_pose=optimed_rhand,
            )

            smpl_verts = smpl_verts + optimed_trans

            v_smpl = torch.matmul(torch.matmul(smpl_verts.squeeze(0), rz.T), ry.T)
            # Freeze the normalization on the first step so the target frame stays fixed.
            if scale is None:
                scale, offset = scale_mesh(v_smpl.detach())
            v_smpl = (v_smpl + offset) * scale * 2

            normals = calc_vertex_normals(v_smpl, self.smplx_face)
            nrm = self.renderer.render(v_smpl, self.smplx_face, normals=normals)

            # Supervise both the silhouette (alpha) and the normals, weighted per view.
            masks_ = nrm[..., 3:]
            smpl_mask_loss = ((masks_ - masks) * self.weights).abs().mean()
            smpl_nrm_loss = ((nrm[..., :3] - target_normals) * self.weights).abs().mean()

            smpl_loss = smpl_mask_loss + smpl_nrm_loss

            smpl_loss.backward()
            optimizer_smpl.step()
            scheduler_smpl.step(smpl_loss)

        mesh_smpl = trimesh.Trimesh(vertices=v_smpl.detach().cpu().numpy(), faces=self.smplx_face.detach().cpu().numpy())

        nrm_opt = MeshOptimizer(v_smpl.detach(), self.smplx_face.detach(), edge_len_lims=[0.01, 0.1])
        vertices, faces = nrm_opt.vertices, nrm_opt.faces

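        # Remeshing stage: every step renders the current vertex normals, takes
        # an L1 loss against the target normals and masks, and lets
        # MeshOptimizer update the connectivity within edge_len_lims.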
        for i in tqdm(range(self.opt.iters)):
            nrm_opt.zero_grad()

            normals = calc_vertex_normals(vertices, faces)
            nrm = self.renderer.render(vertices, faces, normals=normals)
            normals = nrm[..., :3]

            loss = ((normals - target_normals) * self.weights).abs().mean()

            alpha = nrm[..., 3:]
            loss += ((alpha - masks) * self.weights).abs().mean()

            loss.backward()
            nrm_opt.step()

            vertices, faces = nrm_opt.remesh()

            if self.opt.debug and i % self.opt.snapshot_step == 0:
                import imageio
                os.makedirs(f'{case_path}/normals', exist_ok=True)
                imageio.imwrite(f'{case_path}/normals/{i:02d}.png',
                                (nrm.detach()[0, :, :, :3] * 255).clamp(max=255).type(torch.uint8).cpu().numpy())

        torch.cuda.empty_cache()

        mesh_remeshed = trimesh.Trimesh(vertices=vertices.detach().cpu().numpy(), faces=faces.detach().cpu().numpy())
        mesh_remeshed.export(f'{case_path}/{case}_remeshed.obj')

        vertices = vertices.detach()
        faces = faces.detach()

        smpl_data = SMPLX()
        if self.opt.replace_hand and True in pose['hands_visibility'][0]:
            # Replace the reconstructed hands with the fitted SMPL-X hands where visible.
            hand_mask = torch.zeros(smpl_data.smplx_verts.shape[0], )
            if pose['hands_visibility'][0][0]:
                hand_mask.index_fill_(
                    0, torch.tensor(smpl_data.smplx_mano_vid_dict["left_hand"]), 1.0
                )
            if pose['hands_visibility'][0][1]:
                hand_mask.index_fill_(
                    0, torch.tensor(smpl_data.smplx_mano_vid_dict["right_hand"]), 1.0
                )

            hand_mesh = apply_vertex_mask(mesh_smpl.copy(), hand_mask)
            body_mesh = part_removal(
                mesh_remeshed.copy(),
                hand_mesh,
                0.08,
                self.device,
                mesh_smpl.copy(),
                region="hand"
            )
            final = poisson(sum([hand_mesh, body_mesh]), f'{case_path}/{case}_final.obj', 10, False)
        else:
            final = poisson(mesh_remeshed, f'{case_path}/{case}_final.obj', 10, False)
        vertices = torch.from_numpy(final.vertices).float().to(self.device)
        faces = torch.from_numpy(final.faces).long().to(self.device)

        if clr_img is not None:
            masked_color = []
            for tmp in clr_img:
                tmp = tmp.resize((self.resolution, self.resolution), Image.BILINEAR)
                tmp = np.array(tmp).astype(np.float32) / 255.
                masked_color.append(torch.from_numpy(tmp).permute(2, 0, 1).to(self.device))
        else:
            # No in-memory color views (the run() path); assumes the colors loaded
            # from disk can stand in for them, in the same (C, H, W) layout. The
            # original code crashed here when clr_img was None.
            masked_color = [img.permute(2, 0, 1) for img in target_colors]

        meshes = self.proj_texture(masked_color, vertices, faces)
        vertices = meshes.verts_packed().float()
        faces = meshes.faces_packed().long()
        colors = meshes.textures.verts_features_packed().float()
        save_mesh(f'{case_path}/result_clr_scale{self.opt.scale}_{case}.obj', vertices, faces, colors)
        self.evaluate(vertices, colors, faces, save_path=f'{case_path}/result_clr_scale{self.opt.scale}_{case}.mp4', save_nrm=True)
        return vertices, faces, colors

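    # evaluate renders a 60-view turntable of the colored mesh (optionally with
    # a side-by-side normal render) and writes it to a video when save_path is
    # given.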
    def evaluate(self, target_vertices, target_colors, target_faces, save_path=None, save_nrm=False):
        mv, proj = make_round_views(60, self.opt.scale, device=self.device)
        renderer = NormalsRenderer(mv, proj, [512, 512], device=self.device)

        target_images = renderer.render(target_vertices, target_faces, colors=target_colors)
        target_images = target_images.detach().cpu().numpy()
        target_images = target_images[..., :3] * target_images[..., 3:4] + bg_color * (1 - target_images[..., 3:4])
        target_images = (target_images.clip(0, 1) * 255).astype(np.uint8)

        if save_nrm:
            target_normals = calc_vertex_normals(target_vertices, target_faces)
            target_normals = renderer.render(target_vertices, target_faces, normals=target_normals)
            target_normals = target_normals.detach().cpu().numpy()
            target_normals = target_normals[..., :3] * target_normals[..., 3:4] + bg_color * (1 - target_normals[..., 3:4])
            target_normals = (target_normals.clip(0, 1) * 255).astype(np.uint8)
            frames = [np.concatenate([img, nrm], 1) for img, nrm in zip(target_images, target_normals)]
        else:
            frames = list(target_images)
        if save_path is not None:
            write_video(frames, fps=25, save_path=save_path)
        return frames

    def run(self):
        cases = sorted(os.listdir(self.opt.imgs_path))
        for idx in range(len(cases)):
            case = cases[idx].split('.')[0]
            print(f'Processing {case}')
            pose = self.econ_dataset.__getitem__(idx)
            # optimize_case saves the textured mesh and renders the turntable
            # video itself, so no second evaluate pass is needed here.
            self.optimize_case(case, pose, None, None, opti_texture=True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", help="path to the yaml config file", default='config.yaml')
    args, extras = parser.parse_known_args()

    opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))
    # Deferred import, presumably so the heavy HPS dependencies behind
    # SMPLDataset are only loaded once the config has been parsed.
    from econdataset import SMPLDataset
    dataset_param = {'image_dir': opt.imgs_path, 'seg_dir': None, 'colab': False, 'has_det': True, 'hps_type': 'pixie'}
    econdata = SMPLDataset(dataset_param, device='cuda')
    EHuman = ReMesh(opt, econdata)
    EHuman.run()
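
# Example invocation (illustrative; assumes this file is saved as remesh.py and
# that config.yaml provides the keys referenced above: imgs_path, res_path,
# mv_path, cam_path, gpu_id, num_view, resolution, scale, iters, etc.):
#   python remesh.py --config config.yaml gpu_id=0 imgs_path=./examples
# Any config key can be overridden from the command line via OmegaConf.from_cli.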