from lib.renderer.mesh import load_fit_body
from lib.dataset.hoppeMesh import HoppeMesh
from lib.dataset.body_model import TetraSMPLModel
from lib.common.render import Render
from lib.dataset.mesh_util import SMPLX, projection, cal_sdf_batch, get_visibility
from lib.pare.pare.utils.geometry import rotation_matrix_to_angle_axis
from termcolor import colored
import os.path as osp
import numpy as np
from PIL import Image
import random
import os
import zlib
import trimesh
import torch
from kaolin.ops.mesh import check_sign
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download, cached_download


class PIFuDataset():
    """Map-style dataset yielding one (subject, rotation) sample per index.

    Each item is a dict holding the calibration matrix, the input images
    listed in ``cfg.net.in_geo`` / ``cfg.net.in_nml``, sampled query points
    with occupancy/SDF labels, and the fitted SMPL-X body (plus the padded
    tetrahedral SMPL mesh when ``cfg.net.prior_type == 'pamir'``).
    """

    def __init__(self, cfg, split='train', vis=False):
        """Read the configuration and index the datasets on disk.

        Args:
            cfg: configuration object; the attributes read here are ``root``,
                ``batch_size``, ``overfit``, ``dataset``, ``num_threads``,
                ``net.prior_type``, ``net.in_geo``, ``net.in_nml``, ``sdf``,
                ``sdf_clip`` and ``gpus``.
            split: name of the split; ``'train'`` and ``'test'`` are the
                values treated specially in this class.
            vis: when True, extra visualisation data is kept in each item.
        """
        self.split = split
        self.root = cfg.root
        self.bsize = cfg.batch_size
        self.overfit = cfg.overfit

        # for debug, only used in visualize_sampling3D
        self.vis = vis

        self.opt = cfg.dataset
        self.datasets = self.opt.types
        self.input_size = self.opt.input_size
        self.scales = self.opt.scales
        self.workers = cfg.num_threads
        self.prior_type = cfg.net.prior_type

        self.noise_type = self.opt.noise_type
        self.noise_scale = self.opt.noise_scale

        # Joints whose pose parameters receive noise.  Every joint owns three
        # consecutive pose entries; the SMPL-X index list is shifted down by
        # one joint relative to the SMPL one.
        noise_joints = [4, 5, 7, 8, 13, 14, 16, 17, 18, 19, 20, 21]
        self.noise_smpl_idx = [
            idx * 3 + axis for idx in noise_joints for axis in range(3)
        ]
        self.noise_smplx_idx = [
            (idx - 1) * 3 + axis for idx in noise_joints for axis in range(3)
        ]

        self.use_sdf = cfg.sdf
        self.sdf_clip = cfg.sdf_clip

        # [(feat_name, channel_num),...]
        self.in_geo = [item[0] for item in cfg.net.in_geo]
        self.in_nml = [item[0] for item in cfg.net.in_nml]
        self.in_geo_dim = [item[1] for item in cfg.net.in_geo]
        self.in_nml_dim = [item[1] for item in cfg.net.in_nml]

        self.in_total = self.in_geo + self.in_nml
        self.in_total_dim = self.in_geo_dim + self.in_nml_dim

        if self.split == 'train':
            self.rotations = np.arange(
                0, 360, 360 / self.opt.rotation_num).astype(np.int32)
        else:
            self.rotations = range(0, 360, 120)

        self.datasets_dict = {}

        for dataset_id, dataset in enumerate(self.datasets):

            mesh_dir = None
            smplx_dir = None
            # Reset per dataset as well: otherwise a dataset other than
            # 'thuman2' hits a NameError or inherits the previous dataset's
            # directory.
            smpl_dir = None

            dataset_dir = osp.join(self.root, dataset)

            if dataset in ['thuman2']:
                mesh_dir = osp.join(dataset_dir, "scans")
                smplx_dir = osp.join(dataset_dir, "fits")
                smpl_dir = osp.join(dataset_dir, "smpl")

            self.datasets_dict[dataset] = {
                "subjects": np.loadtxt(osp.join(dataset_dir, "all.txt"), dtype=str),
                "smplx_dir": smplx_dir,
                "smpl_dir": smpl_dir,
                "mesh_dir": mesh_dir,
                "scale": self.scales[dataset_id]
            }

        self.subject_list = self.get_subject_list(split)
        self.smplx = SMPLX()

        # PIL to tensor
        self.image_to_tensor = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # PIL to tensor
        self.mask_to_tensor = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            transforms.Normalize((0.0, ), (1.0, ))
        ])

        self.device = torch.device(f"cuda:{cfg.gpus[0]}")
        self.render = Render(size=512, device=self.device)

    def render_normal(self, verts, faces):
        """Load a mesh into the renderer and return its rendered images.

        The result of ``self.render.get_rgb_image()`` is returned untouched;
        ``load_smpl`` unpacks it into a front/back pair.
        """
        # render optimized mesh (normal, T_normal, image [-1,1])
        self.render.load_meshes(verts, faces)
        return self.render.get_rgb_image()

    def get_subject_list(self, split):
        """Collect the "<dataset>/<subject>" entries of ``split``.

        When ``<root>/<dataset>/<split>.txt`` is missing, ``all.txt`` is cut
        into train (first 500 entries), test (next 5) and val (remainder),
        the three lists are written next to ``all.txt``, and the requested
        split is then loaded from disk.

        For every split except ``'test'`` (decided by ``self.split``) the
        list is padded with its own leading entries up to a multiple of the
        batch size and shuffled.
        """
        subject_list = []

        for dataset in self.datasets:

            dataset_dir = osp.join(self.root, dataset)
            split_txt = osp.join(dataset_dir, f'{split}.txt')

            if osp.exists(split_txt):
                print(f"load from {split_txt}")
                subject_list += np.loadtxt(split_txt, dtype=str).tolist()
            else:
                full_txt = osp.join(dataset_dir, 'all.txt')
                print(f"split {full_txt} into train/val/test")

                full_lst = np.loadtxt(full_txt, dtype=str)
                full_lst = [dataset + "/" + item for item in full_lst]
                train_lst, test_lst, val_lst = np.split(
                    full_lst, [500, 500 + 5, ])

                # Build the target paths explicitly.  A str.replace("all", ..)
                # on the full path would also rewrite any other "all" found
                # in the root or dataset name.
                np.savetxt(osp.join(dataset_dir, "train.txt"),
                           train_lst, fmt="%s")
                np.savetxt(osp.join(dataset_dir, "test.txt"),
                           test_lst, fmt="%s")
                np.savetxt(osp.join(dataset_dir, "val.txt"),
                           val_lst, fmt="%s")

                print(f"load from {split_txt}")
                subject_list += np.loadtxt(split_txt, dtype=str).tolist()

        if self.split != 'test':
            # Pad only as far as the next multiple of the batch size; nothing
            # is appended when the length is already a multiple.
            pad_num = (-len(subject_list)) % self.bsize
            subject_list += subject_list[:pad_num]
            print(colored(f"total: {len(subject_list)}", "yellow"))
            random.shuffle(subject_list)

        # subject_list = ["thuman2/0008"]
        return subject_list

    def __len__(self):
        """Number of (subject, rotation) pairs."""
        return len(self.subject_list) * len(self.rotations)

    def __getitem__(self, index):
        """Assemble the sample dict for a flat (subject, rotation) index.

        All ``*_path`` / ``*_dir`` entries are dropped before returning.
        ``verts``/``faces`` are kept only for the test split or when ``vis``
        is set; ``mesh`` is kept only when ``vis`` is set.
        """
        # only pick the first data if overfitting
        if self.overfit:
            index = 0

        rid = index % len(self.rotations)
        mid = index // len(self.rotations)

        rotation = self.rotations[rid]
        # entries look like "<dataset>/<subject>"
        name_parts = self.subject_list[mid].split("/")
        dataset = name_parts[0]
        subject = name_parts[1]
        render_folder = "/".join(
            [dataset + f"_{self.opt.rotation_num}views", subject])
        dataset_info = self.datasets_dict[dataset]

        # setup paths
        data_dict = {
            'dataset': dataset,
            'subject': subject,
            'rotation': rotation,
            'scale': dataset_info["scale"],
            'mesh_path': osp.join(dataset_info["mesh_dir"], f"{subject}/{subject}.obj"),
            'smplx_path': osp.join(dataset_info["smplx_dir"], f"{subject}/smplx_param.pkl"),
            'smpl_path': osp.join(dataset_info["smpl_dir"], f"{subject}.pkl"),
            'calib_path': osp.join(self.root, render_folder, 'calib', f'{rotation:03d}.txt'),
            'vis_path': osp.join(self.root, render_folder, 'vis', f'{rotation:03d}.pt'),
            'image_path': osp.join(self.root, render_folder, 'render', f'{rotation:03d}.png')
        }

        # load training data
        data_dict.update(self.load_calib(data_dict))

        # image/normal/depth loader
        for name, channel in zip(self.in_total, self.in_total_dim):

            if f'{name}_path' not in data_dict.keys():
                data_dict.update({
                    f'{name}_path': osp.join(self.root, render_folder, name, f'{rotation:03d}.png')
                })

            # tensor update
            data_dict.update({
                name: self.imagepath2tensor(
                    data_dict[f'{name}_path'], channel, inv=False)
            })

        data_dict.update(self.load_mesh(data_dict))

        data_dict.update(self.get_sampling_geo(
            data_dict, is_valid=self.split == "val", is_sdf=self.use_sdf))
        data_dict.update(self.load_smpl(data_dict, self.vis))

        if self.prior_type == 'pamir':
            data_dict.update(self.load_smpl_voxel(data_dict))

        if (self.split != 'test') and (not self.vis):
            del data_dict['verts']
            del data_dict['faces']

        if not self.vis:
            del data_dict['mesh']

        path_keys = [
            key for key in data_dict.keys() if '_path' in key or '_dir' in key
        ]
        for key in path_keys:
            del data_dict[key]

        return data_dict

    def imagepath2tensor(self, path, channel=3, inv=False):
        """Load an image, mask it with its alpha channel, keep ``channel`` planes.

        With ``inv=True`` the result is multiplied by -1, otherwise by +1.
        """
        rgba = Image.open(path).convert('RGBA')
        mask = rgba.split()[-1]
        image = rgba.convert('RGB')
        image = self.image_to_tensor(image)
        mask = self.mask_to_tensor(mask)
        image = (image * mask)[:channel]

        # (0.5 - inv) * 2.0 is +1.0 for inv=False and -1.0 for inv=True
        return (image * (0.5 - inv) * 2.0).float()

    def load_calib(self, data_dict):
        """Read the calibration file and return ``intrinsic @ extrinsic``.

        Rows 0-3 of the text file are used as the extrinsic matrix and rows
        4-7 as the intrinsic matrix (first four columns of each).
        """
        calib_data = np.loadtxt(data_dict['calib_path'], dtype=float)
        extrinsic = calib_data[:4, :4]
        intrinsic = calib_data[4:8, :4]
        calib_mat = np.matmul(intrinsic, extrinsic)
        calib_mat = torch.from_numpy(calib_mat).float()
        return {'calib': calib_mat}

    def load_mesh(self, data_dict):
        """Load the scan, apply the dataset scale and wrap it in a HoppeMesh."""
        mesh_path = data_dict['mesh_path']
        scale = data_dict['scale']

        mesh_ori = trimesh.load(mesh_path,
                                skip_materials=True,
                                process=False,
                                maintain_order=True)
        verts = mesh_ori.vertices * scale
        faces = mesh_ori.faces

        vert_normals = np.array(mesh_ori.vertex_normals)
        face_normals = np.array(mesh_ori.face_normals)

        mesh = HoppeMesh(verts, faces, vert_normals, face_normals)

        return {
            'mesh': mesh,
            'verts': torch.as_tensor(mesh.verts).float(),
            'faces': torch.as_tensor(mesh.faces).long()
        }

    @staticmethod
    def _noise_seed(data_dict):
        """Return a seed for a (subject, rotation) pair, stable across runs.

        The built-in ``hash()`` of a ``str`` is salted per interpreter
        process unless PYTHONHASHSEED is fixed, so it cannot provide a
        repeatable seed; CRC32 of the same key can.
        """
        key = f"{data_dict['subject']}_{data_dict['rotation']}"
        return zlib.crc32(key.encode("utf-8")) % (10**8)

    def add_noise(self, beta_num, smpl_pose, smpl_betas, noise_type, noise_scale, type, hashcode):
        """Perturb shape and/or pose parameters with seeded uniform noise.

        Args:
            beta_num: number of shape coefficients to draw noise for.
            smpl_pose: flat pose array; modified in place at the noise indices.
            smpl_betas: shape array; modified in place.
            noise_type: sequence that may contain ``"beta"`` and/or
                ``"pose"``; ``None`` disables noise.
            noise_scale: scales aligned with ``noise_type`` by position.
            type: ``'smplx'`` selects the SMPL-X joint indices and makes the
                result batched tensors; any other value uses the SMPL indices
                and returns the arrays as they are.
            hashcode: seed passed to ``np.random.seed``.

        Returns:
            ``(pose, betas)``.

        NOTE: this reseeds NumPy's global random generator.
        """
        np.random.seed(hashcode)

        if noise_type is None:
            # callers default noise_type to None: treat it as "no noise"
            noise_type = []

        if type == 'smplx':
            noise_idx = self.noise_smplx_idx
        else:
            noise_idx = self.noise_smpl_idx

        def _scale_of(name):
            # scale configured for `name`, or 0.0 when that noise is off
            if name in noise_type:
                return noise_scale[noise_type.index(name)]
            return 0.0

        beta_scale = _scale_of("beta")
        if beta_scale > 0.0:
            smpl_betas += (np.random.rand(beta_num) - 0.5) * 2.0 * beta_scale
            smpl_betas = smpl_betas.astype(np.float32)

        pose_scale = _scale_of("pose")
        if pose_scale > 0.0:
            smpl_pose[noise_idx] += (
                np.random.rand(len(noise_idx)) - 0.5) * 2.0 * np.pi * pose_scale
            smpl_pose = smpl_pose.astype(np.float32)

        if type == 'smplx':
            return torch.as_tensor(smpl_pose[None, ...]), torch.as_tensor(smpl_betas[None, ...])
        else:
            return smpl_pose, smpl_betas

    def compute_smpl_verts(self, data_dict, noise_type=None, noise_scale=None):
        """Rebuild the SMPL-X body of a sample from (optionally noised) params.

        Returns:
            ``(vertices, smplx_dict)`` where ``smplx_dict`` carries the model
            type, gender, body pose and betas actually used.
        """
        dataset = data_dict['dataset']
        smplx_dict = {}

        smplx_param = np.load(data_dict['smplx_path'], allow_pickle=True)
        smplx_pose = smplx_param["body_pose"]  # [1,63]
        smplx_betas = smplx_param["betas"]  # [1,10]
        smplx_pose, smplx_betas = self.add_noise(
            smplx_betas.shape[1],
            smplx_pose[0],
            smplx_betas[0],
            noise_type,
            noise_scale,
            type='smplx',
            hashcode=self._noise_seed(data_dict))

        smplx_out, _ = load_fit_body(fitted_path=data_dict['smplx_path'],
                                     scale=self.datasets_dict[dataset]['scale'],
                                     smpl_type='smplx',
                                     smpl_gender='male',
                                     noise_dict=dict(betas=smplx_betas, body_pose=smplx_pose))

        smplx_dict.update({"type": "smplx",
                           "gender": 'male',
                           "body_pose": torch.as_tensor(smplx_pose),
                           "betas": torch.as_tensor(smplx_betas)})

        return smplx_out.vertices, smplx_dict

    def compute_voxel_verts(self, data_dict, noise_type=None, noise_scale=None):
        """Build the tetrahedral SMPL mesh of a sample, padded to a fixed size.

        Vertices are padded with zeros to 8000 rows and faces to 25100 rows.

        Returns:
            ``(verts, faces, pad_v_num, pad_f_num)``.
        """
        smpl_param = np.load(data_dict['smpl_path'], allow_pickle=True)
        smplx_param = np.load(data_dict['smplx_path'], allow_pickle=True)

        smpl_pose = rotation_matrix_to_angle_axis(
            torch.as_tensor(smpl_param['full_pose'][0])).numpy()
        smpl_betas = smpl_param["betas"]

        smpl_path = cached_download(osp.join(self.smplx.model_dir, "smpl/SMPL_MALE.pkl"),
                                    use_auth_token=os.environ['ICON'])
        tetra_path = cached_download(osp.join(self.smplx.tedra_dir, "tetra_male_adult_smpl.npz"),
                                     use_auth_token=os.environ['ICON'])

        smpl_model = TetraSMPLModel(smpl_path, tetra_path, 'adult')

        smpl_pose, smpl_betas = self.add_noise(
            smpl_model.beta_shape[0],
            smpl_pose.flatten(),
            smpl_betas[0],
            noise_type,
            noise_scale,
            type='smpl',
            hashcode=self._noise_seed(data_dict))

        smpl_model.set_params(pose=smpl_pose.reshape(-1, 3),
                              beta=smpl_betas,
                              trans=smpl_param["transl"])

        all_verts = np.concatenate(
            [smpl_model.verts, smpl_model.verts_added], axis=0)
        verts = (all_verts * smplx_param["scale"] + smplx_param["translation"]
                 ) * self.datasets_dict[data_dict['dataset']]['scale']

        # indices in the text file are shifted down by one before use
        faces = np.loadtxt(cached_download(osp.join(self.smplx.tedra_dir, "tetrahedrons_male_adult.txt"),
                                           use_auth_token=os.environ['ICON']),
                           dtype=np.int32) - 1

        pad_v_num = int(8000 - verts.shape[0])
        pad_f_num = int(25100 - faces.shape[0])

        verts = np.pad(verts, ((0, pad_v_num), (0, 0)),
                       mode='constant',
                       constant_values=0.0).astype(np.float32)
        faces = np.pad(faces, ((0, pad_f_num), (0, 0)),
                       mode='constant',
                       constant_values=0.0).astype(np.int32)

        return verts, faces, pad_v_num, pad_f_num

    def load_smpl(self, data_dict, vis=False):
        """Return the projected SMPL-X body and per-query inside/outside signs.

        Requires ``data_dict['samples_geo']`` and ``data_dict['calib']``.
        When ``vis`` is True the visibility is recomputed, the body is
        rendered (``T_normal_F`` / ``T_normal_B``) and per-query SMPL features
        (``smpl_feat``) are added.
        """
        smplx_verts, smplx_dict = self.compute_smpl_verts(
            data_dict, self.noise_type,
            self.noise_scale)  # compute using smpl model

        smplx_verts = projection(smplx_verts, data_dict['calib']).float()
        smplx_faces = torch.as_tensor(self.smplx.faces).long()
        smplx_vis = torch.load(data_dict['vis_path']).float()
        smplx_cmap = torch.as_tensor(
            np.load(self.smplx.cmap_vert_path)).float()

        # get smpl_signs
        query_points = projection(data_dict['samples_geo'],
                                  data_dict['calib']).float()

        # check_sign output mapped from {0, 1} to {-1, +1}
        pts_signs = 2.0 * (check_sign(smplx_verts.unsqueeze(0),
                                      smplx_faces,
                                      query_points.unsqueeze(0)).float() - 0.5).squeeze(0)

        return_dict = {
            'smpl_verts': smplx_verts,
            'smpl_faces': smplx_faces,
            'smpl_vis': smplx_vis,
            'smpl_cmap': smplx_cmap,
            'pts_signs': pts_signs
        }
        if smplx_dict is not None:
            return_dict.update(smplx_dict)

        if vis:

            (xy, z) = torch.as_tensor(smplx_verts).to(
                self.device).split([2, 1], dim=1)
            smplx_vis = get_visibility(xy, z, torch.as_tensor(
                smplx_faces).to(self.device).long())

            # second coordinate is negated for the renderer
            T_normal_F, T_normal_B = self.render_normal(
                (smplx_verts * torch.tensor([1.0, -1.0, 1.0])).to(self.device),
                smplx_faces.to(self.device))

            return_dict.update({"T_normal_F": T_normal_F.squeeze(0),
                                "T_normal_B": T_normal_B.squeeze(0)})

            # query_points computed above are reused here (same inputs)
            smplx_sdf, smplx_norm, smplx_cmap, smplx_vis = cal_sdf_batch(
                smplx_verts.unsqueeze(0).to(self.device),
                smplx_faces.unsqueeze(0).to(self.device),
                smplx_cmap.unsqueeze(0).to(self.device),
                smplx_vis.unsqueeze(0).to(self.device),
                query_points.unsqueeze(0).contiguous().to(self.device))

            return_dict.update({
                'smpl_feat':
                torch.cat(
                    (smplx_sdf[0].detach().cpu(),
                     smplx_cmap[0].detach().cpu(),
                     smplx_norm[0].detach().cpu(),
                     smplx_vis[0].detach().cpu()),
                    dim=1)
            })

        return return_dict

    def load_smpl_voxel(self, data_dict):
        """Return the padded tetrahedral SMPL mesh, projected and halved."""
        smpl_verts, smpl_faces, pad_v_num, pad_f_num = self.compute_voxel_verts(
            data_dict, self.noise_type,
            self.noise_scale)  # compute using smpl model
        smpl_verts = projection(smpl_verts, data_dict['calib'])

        smpl_verts *= 0.5

        return {
            'voxel_verts': smpl_verts,
            'voxel_faces': smpl_faces,
            'pad_v_num': pad_v_num,
            'pad_f_num': pad_f_num
        }

    def get_sampling_geo(self, data_dict, is_valid=False, is_sdf=False):
        """Sample query points around the scan and label them.

        Points are drawn near the surface (vertices displaced along their
        normals), uniformly in the [-1, 1] cube mapped back through the
        inverse calibration, and, when ``self.opt.zray_type`` is set and
        ``is_valid`` is False, along the last projected axis.

        Args:
            data_dict: must provide ``'mesh'`` and ``'calib'``.
            is_valid: disables the z-ray samples.
            is_sdf: label with clipped, rescaled signed distances instead of
                binary occupancy, and append on-surface samples.

        Returns:
            ``{'samples_geo': ..., 'labels_geo': ...}`` as float tensors.
        """
        mesh = data_dict['mesh']
        calib = data_dict['calib']
        num_geo = self.opt.num_sample_geo
        half_geo = num_geo // 2

        # Samples are around the true surface with an offset
        n_samples_surface = 4 * num_geo
        vert_ids = np.arange(mesh.verts.shape[0])
        # uniform probability over all vertices
        thickness_sample_ratio = np.ones_like(vert_ids).astype(np.float32)
        thickness_sample_ratio /= thickness_sample_ratio.sum()

        samples_surface_ids = np.random.choice(vert_ids,
                                               n_samples_surface,
                                               replace=True,
                                               p=thickness_sample_ratio)

        samples_normal_ids = np.random.choice(vert_ids,
                                              half_geo,
                                              replace=False,
                                              p=thickness_sample_ratio)

        surf_samples = mesh.verts[samples_normal_ids, :]
        surf_normals = mesh.vert_normals[samples_normal_ids, :]

        samples_surface = mesh.verts[samples_surface_ids, :]

        # Sampling offsets are random noise with constant scale (15cm - 20cm)
        offset = np.random.normal(scale=self.opt.sigma_geo,
                                  size=(n_samples_surface, 1))
        samples_surface += mesh.vert_normals[samples_surface_ids, :] * offset

        # Uniform samples in [-1, 1]
        calib_inv = np.linalg.inv(calib)
        n_samples_space = num_geo // 4
        samples_space_img = 2.0 * np.random.rand(n_samples_space, 3) - 1.0
        samples_space = projection(samples_space_img, calib_inv)

        # z-ray direction samples
        if self.opt.zray_type and not is_valid:
            n_samples_rayz = self.opt.ray_sample_num
            samples_surface_cube = projection(samples_surface, calib)
            samples_surface_cube_repeat = np.repeat(samples_surface_cube,
                                                    n_samples_rayz,
                                                    axis=0)

            thickness_repeat = np.repeat(0.5 * np.ones_like(samples_surface_ids),
                                         n_samples_rayz,
                                         axis=0)

            noise_repeat = np.random.normal(scale=0.40,
                                            size=(n_samples_surface * n_samples_rayz, ))
            # jitter only the last coordinate of the projected points
            samples_surface_cube_repeat[:, -1] += thickness_repeat * noise_repeat
            samples_surface_rayz = projection(samples_surface_cube_repeat,
                                              calib_inv)

            samples = np.concatenate(
                [samples_surface, samples_space, samples_surface_rayz], 0)
        else:
            samples = np.concatenate([samples_surface, samples_space], 0)

        np.random.shuffle(samples)

        # labels: in->1.0; out->0.0.
        if is_sdf:
            sdfs = mesh.get_sdf(samples)
            inside_samples = samples[sdfs < 0]
            outside_samples = samples[sdfs >= 0]

            inside_sdfs = sdfs[sdfs < 0]
            outside_sdfs = sdfs[sdfs >= 0]
        else:
            inside = mesh.contains(samples)
            inside_samples = samples[inside >= 0.5]
            outside_samples = samples[inside < 0.5]

        nin = inside_samples.shape[0]

        if nin > half_geo:
            # more than enough inside points: keep half inside, half outside
            inside_samples = inside_samples[:half_geo]
            outside_samples = outside_samples[:half_geo]
            if is_sdf:
                inside_sdfs = inside_sdfs[:half_geo]
                outside_sdfs = outside_sdfs[:half_geo]
        else:
            # keep all inside points and fill the rest with outside points
            outside_samples = outside_samples[:(num_geo - nin)]
            if is_sdf:
                outside_sdfs = outside_sdfs[:(num_geo - nin)]

        if is_sdf:
            samples = np.concatenate(
                [inside_samples, outside_samples, surf_samples], 0)

            labels = np.concatenate([
                inside_sdfs, outside_sdfs, 0.0 * np.ones(surf_samples.shape[0])
            ])

            normals = np.zeros_like(samples)
            # The on-surface samples are the trailing rows.  Slice by their
            # actual count: `-num_sample_geo // 2` parses as
            # `(-num_sample_geo) // 2`, which is one row too many for odd
            # sizes.
            normals[normals.shape[0] - surf_normals.shape[0]:, :] = surf_normals

            # convert sdf from [-14, 130] to [0, 1]
            # outside: 0, inside: 1
            # Note: Marching cubes is defined on occupancy space (inside=1.0, outside=0.0)

            labels = -labels.clip(min=-self.sdf_clip, max=self.sdf_clip)
            labels += self.sdf_clip
            labels /= (self.sdf_clip * 2)

        else:
            samples = np.concatenate([inside_samples, outside_samples])
            labels = np.concatenate([
                np.ones(inside_samples.shape[0]),
                np.zeros(outside_samples.shape[0])
            ])

            normals = np.zeros_like(samples)

        samples = torch.from_numpy(samples).float()
        labels = torch.from_numpy(labels).float()
        # NOTE(review): normals are assembled but not part of the returned dict
        normals = torch.from_numpy(normals).float()

        return {'samples_geo': samples, 'labels_geo': labels}