import torch
import spaces
import numpy as np
import trimesh
from PIL import Image
from typing import List
from tqdm import tqdm
from sklearn.neighbors import KDTree
from refine.func import (
    from_py3d_mesh,
    get_cameras_list,
    make_star_cameras_orthographic,
    multiview_color_projection,
    simple_clean_mesh,
    to_py3d_mesh,
    to_pyml_mesh,
)
from refine.opt import MeshOptimizer
from refine.render import NormalsRenderer, calc_vertex_normals
import pytorch3d
from pytorch3d.structures import Meshes
import xatlas
import cv2


def mesh_uv_wrap(vertices, faces):
    """UV-unwrap a mesh with xatlas.

    Args:
        vertices: Vertex positions, indexable by the ``vmapping`` array that
            ``xatlas.parametrize`` returns (e.g. a NumPy array).
        faces: Triangle vertex indices.

    Returns:
        ``(vertices[vmapping], indices, uvs)``: vertices re-indexed by xatlas'
        vertex mapping, the new face indices, and the per-vertex UVs.

    Raises:
        ValueError: If the mesh has more than 50,000 faces.
    """
    if len(faces) > 50000:
        raise ValueError("The mesh has more than 50,000 faces, which is not supported.")
    vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
    return vertices[vmapping], indices, uvs


def stride_from_shape(shape):
    """Return row-major (C-order) element strides for ``shape``."""
    stride = [1]
    for x in reversed(shape[1:]):
        stride.append(stride[-1] * x)
    return list(reversed(stride))


def scatter_add_nd_with_count(input, count, indices, values, weights=None):
    """Scatter-add ``values`` (and per-value ``weights``) into a dense D-d grid.

    Args:
        input: Accumulator of shape ``[*size, C]`` (D grid dims + C channels);
            updated in place through a flattened view.
        count: Weight accumulator of shape ``[*size, 1]``; updated in place.
        indices: ``[N, D]`` long tensor of integer grid coordinates.
        values: ``[N, C]`` values to add at each coordinate.
        weights: Optional ``[N, 1]`` weights added to ``count`` (default 1).

    Returns:
        ``(input, count)`` viewed back as ``[*size, C]`` and ``[*size, 1]``.
    """
    D = indices.shape[-1]
    C = input.shape[-1]
    size = input.shape[:-1]
    stride = stride_from_shape(size)

    assert len(size) == D

    input = input.view(-1, C)  # [prod(size), C]
    count = count.view(-1, 1)

    # Row-major linear index of every coordinate.
    flatten_indices = (indices * torch.tensor(stride, dtype=torch.long, device=indices.device)).sum(-1)  # [N]

    if weights is None:
        weights = torch.ones_like(values[..., :1])

    input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values)
    count.scatter_add_(0, flatten_indices.unsqueeze(1), weights)

    return input.view(*size, C), count.view(*size, 1)


def linear_grid_put_2d(H, W, coords, values, return_count=False):
    """Splat scattered values onto an ``H x W`` grid with bilinear weights.

    Each sample is distributed over the 2x2 grid cells around its continuous
    position, weighted bilinearly.

    Args:
        H, W: Output grid size.
        coords: ``[N, 2]`` float coordinates in ``[0, 1]``; column 0 is scaled
            by ``H - 1`` (rows), column 1 by ``W - 1`` (columns).
        values: ``[N, C]`` values to splat.
        return_count: If True, return the raw weighted sums and accumulated
            weights instead of the normalised average.

    Returns:
        If ``return_count``: ``(sums [H, W, C], weights [H, W, 1])``.
        Otherwise: ``(average [H, W, C], empty [H, W] bool)`` where ``empty``
        marks cells that received zero total weight.
    """
    C = values.shape[-1]

    indices = coords * torch.tensor([H - 1, W - 1], dtype=torch.float32, device=coords.device)
    indices_00 = indices.floor().long()  # [N, 2]
    # Clamp so that the +1 neighbours stay inside the grid.
    indices_00[:, 0].clamp_(0, H - 2)
    indices_00[:, 1].clamp_(0, W - 2)

    h = indices[..., 0] - indices_00[..., 0].float()
    w = indices[..., 1] - indices_00[..., 1].float()

    result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype)  # [H, W, C]
    count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype)  # [H, W, 1]
    weights = torch.ones_like(values[..., :1])  # [N, 1]

    # (row offset, column offset, bilinear weight) for corners 00, 01, 10, 11.
    corners = (
        (0, 0, (1 - h) * (1 - w)),
        (0, 1, (1 - h) * w),
        (1, 0, h * (1 - w)),
        (1, 1, h * w),
    )
    for dh, dw, corner_weight in corners:
        corner_indices = indices_00 + torch.tensor([dh, dw], dtype=torch.long, device=indices.device)
        corner_weight = corner_weight.unsqueeze(1)
        result, count = scatter_add_nd_with_count(
            result, count, corner_indices, values * corner_weight, weights * corner_weight)

    if return_count:
        return result, count

    mask = count.squeeze(-1) > 0
    result[mask] = result[mask] / count[mask].repeat(1, C)

    return result, count.squeeze(-1) == 0


def remove_color(arr):
    """Key out the background colour of an image array.

    The top-left pixel's colour is taken as the background. Pixels whose
    summed absolute RGB difference from it is at most 80 are painted white
    and given alpha 0; all other pixels get alpha 255. Any existing alpha
    channel is discarded.

    Note: the RGB channels of ``arr`` are overwritten in place.

    Args:
        arr: ``[H, W, 3]`` or ``[H, W, 4]`` image array.

    Returns:
        ``[H, W, 4]`` RGBA array.
    """
    if arr.shape[-1] == 4:
        arr = arr[..., :3]
    # calc diffs
    base = arr[0, 0]
    diffs = np.abs(arr.astype(np.int32) - base.astype(np.int32)).sum(axis=-1)
    alpha = (diffs <= 80)

    arr[alpha] = 255
    alpha = ~alpha
    arr = np.concatenate([arr, alpha[..., None].astype(np.int32) * 255], axis=-1)
    return arr


def simple_remove(imgs):
    """Apply ``remove_color`` to one image or a list of images.

    Only works for normal maps (background keyed from the top-left pixel).

    Returns:
        A single RGBA PIL image if a single image was given, else a list.
    """
    if not isinstance(imgs, list):
        imgs = [imgs]
        single_input = True
    else:
        single_input = False
    rets = []
    for img in imgs:
        arr = np.array(img)
        arr = remove_color(arr)
        rets.append(Image.fromarray(arr.astype(np.uint8)))
    if single_input:
        return rets[0]
    return rets


def erode_alpha(img_list):
    """Binarise and erode the alpha channel of RGBA images by one pixel.

    Alpha is thresholded at 127, eroded once with a 3x3 kernel and rescaled
    to {0, 255}; RGB is kept unchanged.

    Args:
        img_list: Iterable of images convertible to ``[H, W, >=4]`` arrays.

    Returns:
        A new list of PIL images; the inputs are not modified.
    """
    out_img_list = []
    for img in img_list:
        arr = np.array(img)
        alpha = (arr[:, :, 3] > 127).astype(np.uint8)
        # erode 1px
        alpha = cv2.erode(alpha, np.ones((3, 3), np.uint8), iterations=1)
        alpha = (alpha * 255).astype(np.uint8)
        out_img_list.append(Image.fromarray(np.concatenate([arr[:, :, :3], alpha[:, :, None]], axis=-1)))
    return out_img_list


def merge_small_faces(mesh, thres=1e-5):
    """Collapse faces with area below ``thres`` into a single new vertex.

    For each small face a vertex is appended at its centroid and the face's
    three corners are remapped to it everywhere (if a corner is shared by
    several small faces, the last one processed wins). The small faces are
    dropped, then unreferenced vertices, degenerate and duplicate faces are
    removed.

    The input ``mesh`` is not modified.

    Args:
        mesh: ``trimesh.Trimesh`` to clean.
        thres: Face-area threshold.

    Returns:
        A new ``trimesh.Trimesh``.
    """
    area_faces = mesh.area_faces
    small_faces = area_faces < thres
    vertices = np.asarray(mesh.vertices)
    # Work on a copy: remapping in place would corrupt the caller's mesh.faces.
    faces = np.array(mesh.faces)

    new_vertices = vertices.tolist()
    vertex_mapping = {}
    for face_idx in np.where(small_faces)[0]:
        face = faces[face_idx]
        center = np.mean(vertices[face], axis=0)
        new_vertex_idx = len(new_vertices)
        new_vertices.append(center)
        for v in face:
            vertex_mapping[v] = new_vertex_idx

    # One lookup table instead of a full scan of `faces` per mapped vertex.
    # Keys are original indices and values are new (larger) indices, so no
    # mapping chains exist and a single pass is exact.
    remap = np.arange(len(vertices))
    if vertex_mapping:
        keys = np.fromiter(vertex_mapping.keys(), dtype=np.int64, count=len(vertex_mapping))
        vals = np.fromiter(vertex_mapping.values(), dtype=np.int64, count=len(vertex_mapping))
        remap[keys] = vals
    faces = remap[faces][~small_faces]

    # NOTE(review): `postprocess` is not a documented trimesh.Trimesh keyword;
    # `process=False` may have been intended. Left as-is to preserve current
    # behaviour — confirm.
    new_mesh = trimesh.Trimesh(vertices=new_vertices, faces=faces, postprocess=False)
    new_mesh.remove_unreferenced_vertices()
    new_mesh.remove_degenerate_faces()
    new_mesh.remove_duplicate_faces()
    return new_mesh


def init_target(img_pils, new_bkgd=(0., 0., 0.), device="cuda"):
    """Stack RGBA images into one tensor, re-composited on ``new_bkgd``.

    Each image's colour is recovered by inverting alpha compositing against
    its own top-left pixel (taken as the original background), clamped to
    [0, 1], and composited over ``new_bkgd``. Pixels with alpha <= 0.05
    become exactly ``new_bkgd``.

    Args:
        img_pils: Same-sized images with at least 4 channels (presumably RGBA).
        new_bkgd: RGB background colour in [0, 1].
        device: Device for the result.

    Returns:
        ``[B, H, W, 4]`` float32 tensor: composited RGB plus alpha, in [0, 1].
    """
    # Convert the background color to a PyTorch tensor
    new_bkgd = torch.tensor(new_bkgd, dtype=torch.float32).view(1, 1, 3).to(device)

    # Convert all images to PyTorch tensors and process them
    imgs = torch.stack([torch.from_numpy(np.array(img, dtype=np.float32)) for img in img_pils]).to(device) / 255
    img_nps = imgs[..., :3]
    alpha_nps = imgs[..., 3]
    ori_bkgds = img_nps[:, :1, :1]

    # Avoid divide by zero and calculate the original image
    alpha_nps_clamp = torch.clamp(alpha_nps, 1e-6, 1)
    ori_img_nps = (img_nps - ori_bkgds * (1 - alpha_nps.unsqueeze(-1))) / alpha_nps_clamp.unsqueeze(-1)
    ori_img_nps = torch.clamp(ori_img_nps, 0, 1)
    img_nps = torch.where(alpha_nps.unsqueeze(-1) > 0.05,
                          ori_img_nps * alpha_nps.unsqueeze(-1) + new_bkgd * (1 - alpha_nps.unsqueeze(-1)),
                          new_bkgd)

    rgba_img_np = torch.cat([img_nps, alpha_nps.unsqueeze(-1)], dim=-1)
    return rgba_img_np


def _six_view_renderer_and_targets(pils):
    """Build the six-view orthographic normal renderer and target tensor for ``pils``."""
    assert len(pils) == 6
    mv, proj = make_star_cameras_orthographic(8, 1, r=1.2)
    # Select six of the eight star cameras (presumably in the order of `pils`).
    mv = mv[[4, 3, 2, 0, 6, 5]]
    renderer = NormalsRenderer(mv, proj, list(pils[0].size))
    target_images = init_target(pils, new_bkgd=(0., 0., 0.))
    return renderer, target_images


def reconstruct_stage1(pils: List[Image.Image], steps=100, vertices=None, faces=None, fixed_v=None, fixed_f=None,
                       lr=0.03, start_edge_len=0.15, end_edge_len=0.005, decay=0.995, loss_expansion_weight=0.1,
                       gain=0.1, remesh_interval=1, remesh_start=0, distract_mask=None, distract_bbox=None):
    """Stage 1: fit a mesh to six target normal maps with remeshing.

    Args:
        pils: Six target normal-map images with alpha (see ``init_target``).
        steps: Number of optimisation steps.
        vertices, faces: Initial mesh tensors (required); moved to CUDA.
        fixed_v, fixed_f: Optional mesh rendered together with the optimised
            one but never updated. When both are given, optimised vertices
            near it (< 0.05) are penalised for lying behind their nearest
            fixed vertex along that vertex's normal.
        lr, decay: Initial learning rate and its per-step multiplicative decay.
        start_edge_len, end_edge_len: Edge-length limits for the optimiser.
        loss_expansion_weight: Weight of the term pulling vertices along
            their normals.
        gain: Passed to ``MeshOptimizer``.
        remesh_interval, remesh_start: Remesh every ``remesh_interval`` steps
            once ``i >= remesh_start``.
        distract_mask: Optional boolean mask over view 0. If given, view 0 is
            re-rendered with the fixed part's normals set to -1 and compared
            to the target inside the mask and to black outside it.
        distract_bbox: Unused.

    Returns:
        ``(vertices, faces)`` of the optimised mesh only (fixed part
        excluded), detached.
    """
    vertices, faces = vertices.cuda(), faces.cuda()
    renderer, target_images = _six_view_renderer_and_targets(pils)

    # init from coarse mesh
    opt = MeshOptimizer(vertices, faces, local_edgelen=False, gain=gain, edge_len_lims=(end_edge_len, start_edge_len),
                        lr=lr, remesh_interval=remesh_interval, remesh_start=remesh_start)

    _vertices = opt.vertices
    _faces = opt.faces

    use_anchor = fixed_v is not None and fixed_f is not None
    if use_anchor:
        kdtree = KDTree(fixed_v.cpu().numpy())
        # The fixed mesh never changes, so its normals are loop-invariant.
        normals_fixed = calc_vertex_normals(fixed_v, fixed_f).detach()

    # Target background pixels: rendered alpha is pulled towards 0 there.
    mask = target_images[..., -1] < 0.5

    for i in tqdm(range(steps)):
        faces = torch.cat([_faces, fixed_f + len(_vertices)], dim=0) if fixed_f is not None else _faces
        vertices = torch.cat([_vertices, fixed_v], dim=0) if fixed_v is not None else _vertices
        opt.zero_grad()
        opt._lr *= decay

        normals = calc_vertex_normals(vertices, faces)
        # Negate x and z (presumably to match the renderer's normal convention).
        normals[:, 0] *= -1
        normals[:, 2] *= -1
        images = renderer.render(vertices, normals, faces)

        loss_expand = 0.5 * ((vertices + normals).detach() - vertices).pow(2).mean()

        t_mask = images[..., -1] > 0.5
        loss_target_l2 = (images[t_mask] - target_images[t_mask]).abs().pow(2).mean()
        loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean()

        loss = loss_target_l2 + loss_alpha_target_mask_l2 + loss_expand * loss_expansion_weight

        if distract_mask is not None:
            # Clone so the in-place write below neither aliases `normals` nor
            # bumps the version of a tensor the first render may have saved
            # for backward.
            hair_visible_normals = normals.clone()
            hair_visible_normals[len(_vertices):] = -1.
            _images = renderer.render(vertices, hair_visible_normals, faces)
            loss_distract = (_images[0][distract_mask] - target_images[0][distract_mask]).pow(2).mean()
            target_outside = target_images[0][..., :3].clone()
            target_outside[~distract_mask] = 0.
            loss_outside_distract = (_images[0][..., :3][~distract_mask]
                                     - target_outside[..., :3][~distract_mask]).pow(2).mean()
            loss = loss + loss_distract * 1. + loss_outside_distract * 10.

        if use_anchor:
            _, idx = kdtree.query(_vertices.detach().cpu().numpy(), k=1)
            idx = idx[:, 0]
            anchors = fixed_v[idx].detach()
            loss_anchor = torch.clamp(((anchors - _vertices) * normals_fixed[idx]).sum(-1), min=0).pow(3)
            loss_anchor_dist_mask = (anchors - _vertices).norm(dim=-1) < 0.05
            # mean() over an empty selection is NaN and would poison the step.
            if loss_anchor_dist_mask.any():
                loss = loss + loss_anchor[loss_anchor_dist_mask].mean() * 100.

        # out of box
        # NOTE(review): this term is piecewise constant in `vertices`, so it
        # adds to the loss value but contributes no gradient.
        loss_oob = (vertices.abs() > 0.99).float().mean() * 10
        loss = loss + loss_oob

        loss.backward()
        opt.step()

        if i % remesh_interval == 0 and i >= remesh_start:
            _vertices, _faces = opt.remesh(poisson=False)

    vertices, faces = opt._vertices.detach(), opt._faces.detach()

    return vertices, faces


def run_mesh_refine(vertices, faces, pils: List[Image.Image], fixed_v=None, fixed_f=None, steps=100,
                    start_edge_len=0.02, end_edge_len=0.005, decay=0.99, update_normal_interval=10,
                    update_warmup=10, return_mesh=True, process_inputs=True, process_outputs=True,
                    remesh_interval=20):
    """Stage 2: refine a mesh towards normals projected from ``pils``.

    For steps ``i < update_warmup`` and every ``update_normal_interval``-th
    step, ``pils`` are projected onto the current mesh and rendered back as
    the target image; in between, the previous target image is reused.

    Args:
        vertices, faces: Initial mesh tensors.
        pils: Six normal-map images with alpha.
        fixed_v, fixed_f: Optional mesh rendered alongside but not optimised.
        steps, decay: Step count and per-step learning-rate decay.
        start_edge_len, end_edge_len: Edge-length limits for the optimiser.
        update_normal_interval, update_warmup: Target refresh schedule.
        return_mesh, process_inputs: Unused; kept for API compatibility.
        process_outputs: If True, scale vertices by 1.35 / 2 and negate x and z.
        remesh_interval: Remesh every ``remesh_interval`` steps.

    Returns:
        ``(vertices, faces)`` of the optimised mesh only, detached.
    """
    poisson_steps = []  # steps whose remesh gets poisson=True (none)

    renderer, target_images = _six_view_renderer_and_targets(pils)  # 4s

    # init from coarse mesh
    opt = MeshOptimizer(vertices, faces, ramp=5, edge_len_lims=(end_edge_len, start_edge_len), local_edgelen=False,
                        laplacian_weight=0.02)

    _vertices = opt.vertices
    _faces = opt.faces

    mask = target_images[..., -1] < 0.5
    # Loop-invariant: the projection cameras depend only on the device.
    cameras = get_cameras_list(azim_list=[180, 225, 270, 0, 90, 135], device=_vertices.device, focal=1/1.2)

    for i in tqdm(range(steps)):
        faces = torch.cat([_faces, fixed_f + len(_vertices)], dim=0) if fixed_f is not None else _faces
        vertices = torch.cat([_vertices, fixed_v], dim=0) if fixed_v is not None else _vertices
        opt.zero_grad()
        opt._lr *= decay
        normals = calc_vertex_normals(vertices, faces)
        images = renderer.render(vertices, normals, faces)

        if i < update_warmup or i % update_normal_interval == 0:
            with torch.no_grad():
                py3d_mesh = to_py3d_mesh(vertices, faces, normals)
                _, _, target_normal = from_py3d_mesh(multiview_color_projection(
                    py3d_mesh, pils, cameras_list=cameras, weights=[2, 0.8, 0.8, 2, 0.8, 0.8],
                    confidence_threshold=0.1, complete_unseen=False, below_confidence_strategy='original',
                    reweight_with_cosangle='linear'))
                # Presumably colours in [0, 1] encode normals; decode to unit vectors.
                target_normal = target_normal * 2 - 1
                target_normal = torch.nn.functional.normalize(target_normal, dim=-1)
                target_normal[:, 0] *= -1
                target_normal[:, 2] *= -1
                debug_images = renderer.render(vertices, target_normal, faces)

        d_mask = images[..., -1] > 0.5
        loss_debug_l2 = (images[..., :3][d_mask] - debug_images[..., :3][d_mask]).pow(2).mean()

        loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean()

        loss = loss_debug_l2 + loss_alpha_target_mask_l2

        # out of box
        loss_oob = (vertices.abs() > 0.99).float().mean() * 10
        loss = loss + loss_oob

        loss.backward()
        opt.step()

        if i % remesh_interval == 0:
            _vertices, _faces = opt.remesh(poisson=(i in poisson_steps))

    vertices, faces = opt._vertices.detach(), opt._faces.detach()

    if process_outputs:
        vertices = vertices / 2 * 1.35
        vertices[..., [0, 2]] = - vertices[..., [0, 2]]

    return vertices, faces


def geo_refine(mesh_v, mesh_f, rgb_ls, normal_ls, expansion_weight=0.1, fixed_v=None, fixed_f=None,
               distract_mask=None, distract_bbox=None, thres=3e-6, no_decompose=False):
    """Run geometry refinement, mesh cleanup and texturing in sequence.

    Note: ``geo_refine_1`` replaces the alpha of each ``rgb_ls`` entry in
    place with the background-removed normals' alpha; ``geo_refine_3`` then
    textures from that modified list.

    Returns:
        ``geo_refine_3``'s result: ``(trimesh.Trimesh, vertices, faces)``.
    """
    print(mesh_v.device, mesh_f.device)
    if fixed_v is not None:
        print('fixed_v', fixed_v.shape, fixed_v.device)
    if fixed_f is not None:
        print('fixed_f', fixed_f.shape, fixed_f.device)
    vertices, faces = geo_refine_1(mesh_v, mesh_f, rgb_ls, normal_ls, expansion_weight=expansion_weight,
                                   fixed_v=fixed_v, fixed_f=fixed_f, distract_mask=distract_mask,
                                   distract_bbox=distract_bbox, thres=thres, no_decompose=no_decompose)
    vertices, faces = geo_refine_2(vertices, faces, fixed_v=fixed_v)
    return geo_refine_3(vertices, faces, rgb_ls, fixed_v=fixed_v, fixed_f=fixed_f, distract_mask=distract_mask)


def geo_refine_1(mesh_v, mesh_f, rgb_ls, normal_ls, expansion_weight=0.1, fixed_v=None, fixed_f=None,
                 distract_mask=None, distract_bbox=None, thres=3e-6, no_decompose=False):
    """Background-remove the normal maps and run both optimisation stages.

    Side effect: each ``rgb_ls[idx]`` is replaced in place by an RGBA image
    made of its RGB and the alpha of the background-removed normal map.

    Args:
        thres: Unused.

    Returns:
        ``(vertices, faces)`` after stage 1 and stage 2, or ``(mesh_v,
        mesh_f)`` unchanged when ``no_decompose`` is set and ``fixed_v`` is
        given.
    """
    rm_normals = simple_remove(normal_ls)

    # Transfer the alpha channel of rm_normals to rgb_ls (in place: the caller sees it).
    for idx, img in enumerate(rm_normals):
        rgb_ls[idx] = Image.fromarray(np.concatenate([np.array(rgb_ls[idx])[..., :3], np.array(img)[:, :, 3:4]],
                                                     axis=-1))
    assert np.mean(np.array(rgb_ls[0])[..., 3]) < 250

    # NOTE(review): the eroded images are not used below — confirm intent.
    rgb_ls = erode_alpha(rgb_ls)

    stage1_lr = 0.08 if fixed_v is None else 0.01
    stage1_remesh_interval = 1 if fixed_v is None else 30

    if no_decompose:
        stage1_lr = 0.03
        stage1_remesh_interval = 30
        # Without decomposition, meshes that come with a fixed companion are
        # returned unrefined.
        if fixed_v is not None:
            return mesh_v, mesh_f

    vertices, faces = reconstruct_stage1(rm_normals, steps=200, vertices=mesh_v, faces=mesh_f, fixed_v=fixed_v,
                                         fixed_f=fixed_f, lr=stage1_lr, remesh_interval=stage1_remesh_interval,
                                         start_edge_len=0.04, end_edge_len=0.010, gain=0.05,
                                         loss_expansion_weight=expansion_weight, distract_mask=distract_mask,
                                         distract_bbox=distract_bbox)
    vertices, faces = run_mesh_refine(vertices, faces, rm_normals, fixed_v=fixed_v, fixed_f=fixed_f, steps=100,
                                      start_edge_len=0.010, end_edge_len=0.001, decay=0.99,
                                      update_normal_interval=20, update_warmup=5, process_inputs=False,
                                      process_outputs=False, remesh_interval=1)
    return vertices, faces


def geo_refine_2(vertices, faces, fixed_v=None):
    """Clean the mesh (with 2 smoothing steps) and return NumPy arrays.

    ``fixed_v`` is unused.
    """
    meshes = simple_clean_mesh(to_pyml_mesh(vertices, faces), apply_smooth=True, stepsmoothnum=2,
                               apply_sub_divide=False, sub_divide_threshold=0.25)
    simp_vertices, simp_faces = meshes.verts_packed(), meshes.faces_packed()
    vertices, faces = simp_vertices.detach().cpu().numpy(), simp_faces.detach().cpu().numpy()
    return vertices, faces


def geo_refine_3(vertices_, faces_, rgb_ls, fixed_v=None, fixed_f=None, distract_mask=None):
    """Texture the mesh by projecting ``rgb_ls`` and baking into a UV atlas.

    Steps:
      1. UV-unwrap with xatlas.
      2. Subdivide twice, both the atlas mesh (carrying UVs) and the original
         mesh (used for colour projection).
      3. Project the six views onto the dense mesh's vertices; if a fixed
         mesh is given it is appended for the projection only.
      4. Colour each dense atlas vertex from its nearest dense vertex, splat
         into a 1024x1024 texture, and inpaint empty texels.

    Args:
        vertices_, faces_: NumPy mesh arrays.
        rgb_ls: Six images to project.
        fixed_v, fixed_f: Optional extra mesh tensors included in projection.
        distract_mask: Forwarded to ``multiview_color_projection``; also
            selects the alternative per-view weights.

    Returns:
        ``(mesh, vertices, faces)``: a ``trimesh.Trimesh`` with a PBR
        base-colour texture, and its atlas vertices / int64 faces as CUDA
        tensors.
    """
    vmapping, indices, uvs = xatlas.parametrize(vertices_, faces_)
    vertices, faces = vertices_[vmapping], indices

    def subdivide(vertices, faces, uvs):
        # Subdivide positions and UVs together so new midpoints get averaged UVs.
        vertices, faces = trimesh.remesh.subdivide(
            vertices=np.hstack((vertices, uvs.copy())),
            faces=faces
        )
        return vertices[:, :3], faces, vertices[:, 3:]

    dense_atlas_vertices, dense_atlas_faces, dense_atlas_uvs = subdivide(vertices, faces, uvs)
    dense_atlas_vertices, dense_atlas_faces, dense_atlas_uvs = subdivide(
        dense_atlas_vertices, dense_atlas_faces, dense_atlas_uvs)
    dense_vertices, dense_faces = trimesh.remesh.subdivide(vertices_, faces_)
    dense_vertices, dense_faces = trimesh.remesh.subdivide(dense_vertices, dense_faces)

    origin_len_v = len(dense_vertices)

    # concatenate fixed_v and fixed_f
    has_fixed = fixed_v is not None and fixed_f is not None
    if has_fixed:
        dense_vertices, dense_faces = (
            np.concatenate([dense_vertices, fixed_v.detach().cpu().numpy()], axis=0),
            np.concatenate([dense_faces, fixed_f.detach().cpu().numpy() + origin_len_v], axis=0),
        )

    dense_vertices = torch.from_numpy(dense_vertices).cuda()
    dense_faces = torch.from_numpy(dense_faces.astype('int32')).cuda()

    # reconstruct meshes
    meshes = Meshes(verts=[dense_vertices], faces=[dense_faces],
                    textures=pytorch3d.renderer.mesh.textures.TexturesVertex(
                        [torch.zeros_like(dense_vertices).float()]))
    new_meshes = multiview_color_projection(
        meshes, rgb_ls, resolution=1024, device="cuda", complete_unseen=True, confidence_threshold=0.2,
        cameras_list=get_cameras_list([180, 225, 270, 0, 90, 135], "cuda", focal=1/1.2),
        weights=[2.0, 0.5, 0.0, 1.0, 0.0, 0.5] if distract_mask is None else [2.0, 0.0, 0.5, 1.0, 0.5, 0.0],
        distract_mask=distract_mask)

    textures = new_meshes.textures.verts_features_packed()
    if has_fixed:
        # Drop the fixed mesh: it only took part in the projection.
        dense_vertices = dense_vertices[:origin_len_v]
        textures = textures[:origin_len_v]

    # Nearest dense vertex for every atlas vertex, in chunks to bound cdist memory.
    atlas_query = torch.tensor(dense_atlas_vertices).cuda()
    chunk_size = 500
    atlas_textures_chunks = []
    for start in range(0, len(atlas_query), chunk_size):
        distances = torch.cdist(atlas_query[start:start + chunk_size], dense_vertices)
        nearest_indices = torch.argmin(distances, dim=1)
        atlas_textures_chunks.append(textures[nearest_indices])
    atlas_textures = torch.cat(atlas_textures_chunks, dim=0)

    dense_atlas_uvs = torch.tensor(dense_atlas_uvs, dtype=torch.float32).cuda()
    tex_img, mask = linear_grid_put_2d(1024, 1024, dense_atlas_uvs, atlas_textures)
    tex_img, mask = tex_img.cpu().numpy(), mask.cpu().numpy()
    # Clip before the uint8 cast so out-of-range colours saturate instead of wrapping.
    tex_img = np.clip(tex_img * 255, 0, 255).astype(np.uint8)
    # Inpaint texels that received no samples.
    tex_img = cv2.inpaint(tex_img, (mask * 255).astype('uint8'), 3, cv2.INPAINT_NS)
    # Grid rows follow u and columns v; transpose and flip vertically
    # (presumably into the image orientation trimesh expects for `uv`).
    tex_img = Image.fromarray(np.transpose(tex_img, (1, 0, 2))[::-1])

    mesh = trimesh.Trimesh(vertices, faces, process=False)
    material = trimesh.visual.material.PBRMaterial(
        roughnessFactor=1.0,
        baseColorTexture=tex_img,
        baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8),
    )
    mesh.visual = trimesh.visual.TextureVisuals(uv=uvs, image=tex_img, material=material)
    return mesh, torch.tensor(vertices).cuda(), torch.tensor(faces.astype('int64')).cuda()