import os
import tyro
import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from lgm.options import AllConfigs, Options
from lgm.gs import GaussianRenderer

import mcubes
import nerfacc
import nvdiffrast.torch as dr

import kiui
from kiui.mesh import Mesh
from kiui.mesh_utils import clean_mesh, decimate_mesh
from kiui.mesh_utils import laplacian_smooth_loss, normal_consistency
from kiui.op import uv_padding, safe_normalize, inverse_sigmoid
from kiui.cam import orbit_camera, get_perspective
from kiui.nn import MLP, trunc_exp
from kiui.gridencoder import GridEncoder


def get_rays(pose, h, w, fovy, opengl=True):
    """Generate one ray per pixel for a pinhole camera.

    Args:
        pose: camera-to-world matrix as a torch tensor; only ``pose[:3, :3]``
            (rotation) and ``pose[:3, 3]`` (origin) are read, and its device
            is used for every generated tensor.
        h: image height in pixels (also determines the focal length).
        w: image width in pixels.
        fovy: vertical field of view in degrees.
        opengl: if True, flip the image y axis and point the camera down -z
            (OpenGL convention); otherwise both keep a positive sign.

    Returns:
        ``(rays_o, rays_d)``, each ``[h * w, 3]`` in row-major pixel order.
        ``rays_d`` is normalised with ``safe_normalize``.
    """
    sign = -1.0 if opengl else 1.0
    focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))

    # Row-major pixel grid: ``rows`` varies slowest, ``cols`` fastest.
    rows, cols = torch.meshgrid(
        torch.arange(h, device=pose.device),
        torch.arange(w, device=pose.device),
        indexing="ij",
    )
    # Sample through pixel centres (+0.5), measured from the image centre.
    u = (cols.flatten() - w * 0.5 + 0.5) / focal
    v = (rows.flatten() - h * 0.5 + 0.5) / focal * sign
    camera_dirs = torch.stack([u, v, torch.full_like(u, sign)], dim=-1)  # [hw, 3]

    rotation, origin = pose[:3, :3], pose[:3, 3]
    rays_d = camera_dirs @ rotation.transpose(0, 1)  # [hw, 3]
    rays_o = origin.unsqueeze(0).expand_as(rays_d)  # [hw, 3]
    return rays_o, safe_normalize(rays_d)


# Triple renderer of gaussians, nerf, and mesh.
# gaussian --> nerf --> mesh
class Converter(nn.Module):
    """Convert a saved gaussian ``.ply`` into a textured mesh.

    Every stage is supervised by renders of the loaded gaussians from random
    orbit cameras (``render_gs``):

    1. ``fit_nerf``: fit a grid-encoded density field and colour field.
    2. ``fit_mesh``: extract a mesh from the density with marching cubes, then
       optimise per-vertex offsets and the colour field, remeshing periodically.
    3. ``fit_mesh_uv``: UV-unwrap the mesh, bake the colour field into a
       texture and fine-tune that texture.

    ``export_mesh`` writes the result to disk.
    """

    def __init__(self, opt: Options):
        super().__init__()

        self.opt = opt
        self.device = torch.device("cuda")

        # gs renderer: projection matrix consumed by the gaussian rasterizer
        self.tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
        self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=self.device)
        self.proj_matrix[0, 0] = 1 / self.tan_half_fov
        self.proj_matrix[1, 1] = 1 / self.tan_half_fov
        self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
        self.proj_matrix[3, 2] = -(opt.zfar * opt.znear) / (opt.zfar - opt.znear)
        self.proj_matrix[2, 3] = 1

        self.gs_renderer = GaussianRenderer(opt)
        self.gaussians = self.gs_renderer.load_ply(opt.test_path).to(self.device)

        # nerf renderer (the rasterizer context is also used by the mesh
        # renderer and by UV baking)
        if not self.opt.force_cuda_rast:
            self.glctx = dr.RasterizeGLContext()
        else:
            self.glctx = dr.RasterizeCudaContext()

        self.step = 0  # number of training-mode render_nerf calls
        self.render_step_size = 5e-3
        self.aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=self.device)
        self.estimator = nerfacc.OccGridEstimator(
            roi_aabb=self.aabb, resolution=64, levels=1
        )

        self.encoder_density = GridEncoder(num_levels=12)
        self.encoder = GridEncoder(num_levels=12)
        self.mlp_density = MLP(self.encoder_density.output_dim, 1, 32, 2, bias=False)
        self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=False)

        # mesh renderer
        self.proj = (
            torch.from_numpy(get_perspective(self.opt.fovy)).float().to(self.device)
        )

        # Mesh state, populated by fit_mesh / fit_mesh_uv.
        self.v = self.f = None
        self.vt = self.ft = None
        self.deform = None
        self.albedo = None

    @torch.no_grad()
    def render_gs(self, pose):
        """Render the loaded gaussians from a camera-to-world ``pose``.

        Args:
            pose: 4x4 camera-to-world matrix as a NumPy array (as returned by
                ``orbit_camera``). It is never modified.

        Returns:
            ``(image, alpha)``: image ``[C, H, W]`` and alpha ``[H, W]``.
        """
        # torch.tensor always copies. torch.from_numpy(...).to(device) shares
        # memory with ``pose`` whenever .to() is a no-op, and the in-place flip
        # below would then corrupt the pose the caller reuses for
        # render_nerf / render_mesh. float32 also matches proj_matrix's dtype.
        cam_poses = torch.tensor(pose, dtype=torch.float32, device=self.device).unsqueeze(0)
        cam_poses[:, :3, 1:3] *= -1  # invert up & forward direction

        # cameras needed by gaussian rasterizer
        cam_view = torch.inverse(cam_poses).transpose(1, 2)  # [V, 4, 4]
        cam_view_proj = cam_view @ self.proj_matrix  # [V, 4, 4]
        cam_pos = -cam_poses[:, :3, 3]  # [V, 3]

        out = self.gs_renderer.render(
            self.gaussians.unsqueeze(0),
            cam_view.unsqueeze(0),
            cam_view_proj.unsqueeze(0),
            cam_pos.unsqueeze(0),
        )
        image = out["image"].squeeze(1).squeeze(0)  # [C, H, W]
        alpha = out["alpha"].squeeze(2).squeeze(1).squeeze(0)  # [H, W]

        return image, alpha

    def get_density(self, xs):
        """Return ``trunc_exp`` density of shape ``[..., 1]`` at points ``xs`` ``[..., 3]``."""
        prefix = xs.shape[:-1]
        xs = xs.view(-1, 3)
        feats = self.encoder_density(xs)
        density = trunc_exp(self.mlp_density(feats))
        density = density.view(*prefix, 1)
        return density

    def render_nerf(self, pose):
        """Volume-render the NeRF from a camera-to-world ``pose`` (NumPy 4x4).

        Renders at ``self.opt.output_size`` squared. In training mode the
        occupancy grid estimator is refreshed through
        ``update_every_n_steps(n=8)`` and ray sampling is stratified.

        Returns:
            ``(color, alpha)``: color ``[3, H, W]`` composited over a white
            background and alpha ``[H, W]``, both clamped to ``[0, 1]``.
        """
        pose = torch.from_numpy(pose.astype(np.float32)).to(self.device)

        # get rays
        resolution = self.opt.output_size
        rays_o, rays_d = get_rays(pose, resolution, resolution, self.opt.fovy)

        # update occ grid
        if self.training:

            def occ_eval_fn(xs):
                sigmas = self.get_density(xs)
                return self.render_step_size * sigmas

            self.estimator.update_every_n_steps(
                self.step, occ_eval_fn=occ_eval_fn, occ_thre=0.01, n=8
            )
            self.step += 1

        # render
        def sigma_fn(t_starts, t_ends, ray_indices):
            t_origins = rays_o[ray_indices]
            t_dirs = rays_d[ray_indices]
            xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
            sigmas = self.get_density(xs)
            return sigmas.squeeze(-1)

        with torch.no_grad():
            ray_indices, t_starts, t_ends = self.estimator.sampling(
                rays_o,
                rays_d,
                sigma_fn=sigma_fn,
                near_plane=0.01,
                far_plane=100,
                render_step_size=self.render_step_size,
                stratified=self.training,
                cone_angle=0,
            )

        # Re-evaluate density and colour at the sample midpoints with gradients.
        t_origins = rays_o[ray_indices]
        t_dirs = rays_d[ray_indices]
        xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
        sigmas = self.get_density(xs).squeeze(-1)
        rgbs = torch.sigmoid(self.mlp(self.encoder(xs)))

        n_rays = rays_o.shape[0]
        weights, trans, alphas = nerfacc.render_weight_from_density(
            t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=n_rays
        )
        color = nerfacc.accumulate_along_rays(
            weights, values=rgbs, ray_indices=ray_indices, n_rays=n_rays
        )
        alpha = nerfacc.accumulate_along_rays(
            weights, values=None, ray_indices=ray_indices, n_rays=n_rays
        )

        color = color + 1 * (1.0 - alpha)  # white background

        color = (
            color.view(resolution, resolution, 3)
            .clamp(0, 1)
            .permute(2, 0, 1)
            .contiguous()
        )
        alpha = alpha.view(resolution, resolution).clamp(0, 1).contiguous()

        return color, alpha

    def fit_nerf(self, iters=512, resolution=128):
        """Fit the NeRF to gaussian renders from random orbit cameras.

        Args:
            iters: number of optimisation steps.
            resolution: render resolution; written to ``self.opt.output_size``.
        """
        self.opt.output_size = resolution

        optimizer = torch.optim.Adam(
            [
                {"params": self.encoder_density.parameters(), "lr": 1e-2},
                {"params": self.encoder.parameters(), "lr": 1e-2},
                {"params": self.mlp_density.parameters(), "lr": 1e-3},
                {"params": self.mlp.parameters(), "lr": 1e-3},
            ]
        )

        print(f"[INFO] fitting nerf...")
        pbar = tqdm.trange(iters)
        for i in pbar:
            ver = np.random.randint(-45, 45)
            hor = np.random.randint(-180, 180)
            rad = np.random.uniform(1.5, 3.0)

            pose = orbit_camera(ver, hor, rad)

            image_gt, alpha_gt = self.render_gs(pose)
            image_pred, alpha_pred = self.render_nerf(pose)

            loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(
                alpha_pred, alpha_gt
            )
            loss = loss_mse

            loss.backward()
            self.encoder_density.grad_total_variation(1e-8)

            optimizer.step()
            optimizer.zero_grad()

            pbar.set_description(f"MSE = {loss_mse.item():.6f}")

        print(f"[INFO] finished fitting nerf!")

    def render_mesh(self, pose):
        """Rasterize the (deformed) mesh from a camera-to-world ``pose`` (NumPy 4x4).

        Colours come from the neural colour field queried at the rasterized
        surface points while ``self.albedo`` is None (the points are detached,
        so only the field receives colour gradients), and from the UV texture
        afterwards.

        Returns:
            ``(image, alpha)``: image ``[3, H, W]`` composited over a white
            background and alpha ``[H, W]``.
        """
        h = w = self.opt.output_size

        v = self.v + self.deform
        f = self.f

        pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)

        # get v_clip and render rgb
        v_cam = (
            torch.matmul(
                F.pad(v, pad=(0, 1), mode="constant", value=1.0), torch.inverse(pose).T
            )
            .float()
            .unsqueeze(0)
        )
        v_clip = v_cam @ self.proj.T

        rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))

        alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous()  # [1, H, W, 1]
        # antialias is what lets silhouette gradients reach the vertices
        alpha = (
            dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(-1).squeeze(0)
        )  # [H, W]

        if self.albedo is None:
            xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f)  # [1, H, W, 3]
            xyzs = xyzs.view(-1, 3)
            mask = (alpha > 0).view(-1)
            image = torch.zeros_like(xyzs, dtype=torch.float32)
            if mask.any():
                masked_albedo = torch.sigmoid(
                    self.mlp(self.encoder(xyzs[mask].detach(), bound=1))
                )
                image[mask] = masked_albedo.float()
        else:
            texc, texc_db = dr.interpolate(
                self.vt.unsqueeze(0), rast, self.ft, rast_db=rast_db, diff_attrs="all"
            )
            image = torch.sigmoid(
                dr.texture(self.albedo.unsqueeze(0), texc, uv_da=texc_db)
            )  # [1, H, W, 3]

        image = image.view(1, h, w, 3)
        image = image.squeeze(0).permute(2, 0, 1).contiguous()  # [3, H, W]
        image = alpha * image + (1 - alpha)

        return image, alpha

    def _set_mesh(self, vertices, triangles):
        """Upload NumPy ``vertices``/``triangles`` to the device and zero the vertex offsets."""
        self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
        self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
        # Move before wrapping: Parameter(...).to(device) may return a non-leaf
        # copy that torch.optim refuses to optimise.
        self.deform = nn.Parameter(torch.zeros_like(self.v).to(self.device))

    @staticmethod
    def _remesh(vertices, triangles, decimate_target):
        """Clean with remeshing (size 0.01), then decimate if above ``decimate_target`` faces."""
        vertices, triangles = clean_mesh(
            vertices, triangles, remesh=True, remesh_size=0.01
        )
        if triangles.shape[0] > decimate_target:
            vertices, triangles = decimate_mesh(
                vertices, triangles, decimate_target, optimalplacement=False
            )
        return vertices, triangles

    def _mesh_optimizer(self, lr_factor):
        """Build Adam over the colour field (lr scaled by ``lr_factor``) and the vertex offsets."""
        return torch.optim.Adam(
            [
                {"params": self.encoder.parameters(), "lr": 1e-3 * lr_factor},
                {"params": self.mlp.parameters(), "lr": 1e-3 * lr_factor},
                {"params": self.deform, "lr": 1e-4},
            ]
        )

    def fit_mesh(self, iters=2048, resolution=512, decimate_target=5e4):
        """Extract a mesh from the NeRF density and refine it against gaussian renders.

        Marching cubes runs at density 10 on a 256^3 grid spanning
        ``[-1, 1]^3``. Vertex offsets and the colour field are then optimised.
        Every 512 iterations the mesh is baked, remeshed and the colour-field
        learning rate halved.

        Args:
            iters: number of optimisation steps.
            resolution: render resolution; written to ``self.opt.output_size``.
            decimate_target: face count above which the mesh is decimated.

        Raises:
            RuntimeError: if marching cubes finds no surface.
        """
        self.opt.output_size = resolution

        # init mesh from nerf
        grid_size = 256
        sigmas = np.zeros([grid_size, grid_size, grid_size], dtype=np.float32)

        S = 128  # points per axis per chunk (S**3 points per density query)
        density_thresh = 10

        X = torch.linspace(-1, 1, grid_size).split(S)
        Y = torch.linspace(-1, 1, grid_size).split(S)
        Z = torch.linspace(-1, 1, grid_size).split(S)

        # Pure inference: without no_grad each chunk would build (and then
        # throw away) an autograd graph through the encoder and MLP.
        with torch.no_grad():
            for xi, xs in enumerate(X):
                for yi, ys in enumerate(Y):
                    for zi, zs in enumerate(Z):
                        xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing="ij")
                        pts = torch.cat(
                            [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],
                            dim=-1,
                        )  # [S**3, 3]
                        val = self.get_density(pts.to(self.device))  # [S**3, 1]
                        sigmas[
                            xi * S : xi * S + len(xs),
                            yi * S : yi * S + len(ys),
                            zi * S : zi * S + len(zs),
                        ] = (
                            val.reshape(len(xs), len(ys), len(zs)).cpu().numpy()
                        )

        print(
            f"[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})"
        )

        vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)
        if len(triangles) == 0:
            raise RuntimeError(
                f"marching cubes found no surface at density threshold {density_thresh} "
                f"(density range {sigmas.min()} ~ {sigmas.max()}); the NeRF fit may have failed"
            )
        # grid indices [0, grid_size - 1] -> world coordinates [-1, 1]
        vertices = vertices / (grid_size - 1.0) * 2 - 1

        # clean
        vertices = vertices.astype(np.float32)
        triangles = triangles.astype(np.int32)
        self._set_mesh(*self._remesh(vertices, triangles, decimate_target))

        # fit mesh from gs
        lr_factor = 1
        optimizer = self._mesh_optimizer(lr_factor)

        print(f"[INFO] fitting mesh...")
        pbar = tqdm.trange(iters)
        for i in pbar:
            ver = np.random.randint(-10, 10)
            hor = np.random.randint(-180, 180)
            rad = self.opt.cam_radius

            pose = orbit_camera(ver, hor, rad)

            image_gt, alpha_gt = self.render_gs(pose)
            image_pred, alpha_pred = self.render_mesh(pose)

            loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(
                alpha_pred, alpha_gt
            )
            loss_normal = normal_consistency(self.v + self.deform, self.f)
            loss_offsets = (self.deform**2).sum(-1).mean()
            loss = loss_mse + 0.001 * loss_normal + 0.1 * loss_offsets

            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            # remesh periodically: bake offsets into the vertices, rebuild the
            # mesh, and restart the optimizer with a halved colour-field lr
            if i > 0 and i % 512 == 0:
                vertices = (self.v + self.deform).detach().cpu().numpy()
                triangles = self.f.detach().cpu().numpy()
                self._set_mesh(*self._remesh(vertices, triangles, decimate_target))

                lr_factor *= 0.5
                optimizer = self._mesh_optimizer(lr_factor)

            pbar.set_description(f"MSE = {loss_mse.item():.6f}")

        # last clean (no remeshing or decimation)
        vertices = (self.v + self.deform).detach().cpu().numpy()
        triangles = self.f.detach().cpu().numpy()
        vertices, triangles = clean_mesh(vertices, triangles, remesh=False)
        self._set_mesh(vertices, triangles)

        print(f"[INFO] finished fitting mesh!")

    # uv mesh refine
    def fit_mesh_uv(
        self, iters=512, resolution=512, texture_resolution=1024, padding=2
    ):
        """UV-unwrap the mesh, bake the colour field into a texture and fine-tune it.

        Args:
            iters: number of texture optimisation steps.
            resolution: render resolution; written to ``self.opt.output_size``.
            texture_resolution: side length of the square texture.
            padding: texel padding passed to ``uv_padding``.
        """
        self.opt.output_size = resolution

        # unwrap uv
        print(f"[INFO] uv unwrapping...")
        mesh = Mesh(v=self.v, f=self.f, albedo=None, device=self.device)
        mesh.auto_normal()
        mesh.auto_uv()

        self.vt = mesh.vt
        self.ft = mesh.ft

        # render uv maps: rasterize the mesh in UV space to get a 3D position per texel
        h = w = texture_resolution
        uv = mesh.vt * 2.0 - 1.0  # uvs to range [-1, 1]
        uv = torch.cat(
            (uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1
        )  # [N, 4]

        rast, _ = dr.rasterize(
            self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)
        )  # [1, h, w, 4]
        xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f)  # [1, h, w, 3]
        mask, _ = dr.interpolate(
            torch.ones_like(mesh.v[:, :1]).unsqueeze(0), rast, mesh.f
        )  # [1, h, w, 1]

        # masked query
        xyzs = xyzs.view(-1, 3)
        mask = (mask > 0).view(-1)

        albedo = torch.zeros(h * w, 3, device=self.device, dtype=torch.float32)

        if mask.any():
            print(f"[INFO] querying texture...")

            xyzs = xyzs[mask]  # [M, 3]

            # batched inference to avoid OOM; no_grad because the baked texture
            # becomes a fresh leaf Parameter and needs no graph
            chunk = 640000
            with torch.no_grad():
                batch = [
                    torch.sigmoid(self.mlp(self.encoder(xyzs[head : head + chunk]))).float()
                    for head in range(0, xyzs.shape[0], chunk)
                ]

            albedo[mask] = torch.cat(batch, dim=0)

        albedo = albedo.view(h, w, -1)
        mask = mask.view(h, w)
        albedo = uv_padding(albedo, mask, padding)

        # optimize texture (stored as logits; render_mesh applies sigmoid).
        # Move before wrapping so self.albedo is always a leaf Parameter.
        self.albedo = nn.Parameter(inverse_sigmoid(albedo).to(self.device))

        optimizer = torch.optim.Adam(
            [
                {"params": self.albedo, "lr": 1e-3},
            ]
        )

        print(f"[INFO] fitting mesh texture...")
        pbar = tqdm.trange(iters)
        for i in pbar:
            # shrink to front view as we care more about it...
            ver = np.random.randint(-5, 5)
            hor = np.random.randint(-15, 15)
            rad = self.opt.cam_radius

            pose = orbit_camera(ver, hor, rad)

            image_gt, alpha_gt = self.render_gs(pose)
            image_pred, alpha_pred = self.render_mesh(pose)

            loss_mse = F.mse_loss(image_pred, image_gt)
            loss = loss_mse

            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            pbar.set_description(f"MSE = {loss_mse.item():.6f}")

        print(f"[INFO] finished fitting mesh texture!")

    @torch.no_grad()
    def export_mesh(self, path):
        """Write the current mesh (and texture, if one was fitted) to ``path``."""
        # Without this guard torch.sigmoid(None) crashes when fit_mesh_uv was
        # skipped; whether an untextured mesh suits ``path``'s format is
        # decided by Mesh.write.
        albedo = torch.sigmoid(self.albedo) if self.albedo is not None else None
        mesh = Mesh(
            v=self.v,
            f=self.f,
            vt=self.vt,
            ft=self.ft,
            albedo=albedo,
            device=self.device,
        )
        mesh.auto_normal()
        mesh.write(path)


if __name__ == "__main__":
    opt = tyro.cli(AllConfigs)

    # load a saved ply and convert to mesh
    if not opt.test_path.endswith(".ply"):
        raise ValueError("--test_path must be a .ply file saved by infer.py")

    converter = Converter(opt).cuda()
    converter.fit_nerf()
    converter.fit_mesh()
    converter.fit_mesh_uv()
    # Swap only the trailing extension; str.replace would also rewrite a
    # ".ply" appearing earlier in the path (e.g. in a directory name).
    converter.export_mesh(os.path.splitext(opt.test_path)[0] + ".glb")