# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

import torch
import torch.nn.functional as F
import nvdiffrast.torch as dr

from . import Renderer
from . import util
from . import renderutils as ru

_FG_LUT = None


def interpolate(attr, rast, attr_idx, rast_db=None):
    """Interpolate per-vertex attributes over a rasterization result.

    Thin wrapper around ``dr.interpolate``: ``attr`` is made contiguous first,
    and derivatives are requested for all attributes only when ``rast_db`` is
    supplied.

    Args:
        attr: Per-vertex attribute tensor.
        rast: Rasterizer output passed through to ``dr.interpolate``.
        attr_idx: Triangle index tensor used to look up ``attr``.
        rast_db: Optional rasterizer derivative output.

    Returns:
        Whatever ``dr.interpolate`` returns; callers in this file unpack it as
        a ``(values, derivatives)`` pair.
    """
    return dr.interpolate(
        attr.contiguous(), rast, attr_idx, rast_db=rast_db,
        diff_attrs=None if rast_db is None else 'all')


def xfm_points(points, matrix, use_python=True):
    """Transform points.

    Args:
        points: Tensor containing 3D points with shape
            [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
        matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
        use_python: Not read by this implementation (``torch.matmul`` is
            always used); kept so existing call sites keep working.

    Returns:
        Transformed points in homogeneous 4D with shape
        [minibatch_size, num_vertices, 4].
    """
    # Append w = 1 so the 4x4 matrix can be applied to row vectors.
    points_h = F.pad(points, pad=(0, 1), mode='constant', value=1.0)
    out = torch.matmul(points_h, torch.transpose(matrix, 1, 2))
    # The finiteness check is only performed while anomaly detection is on.
    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN"
    return out


def dot(x, y):
    """Return the dot product of ``x`` and ``y`` over the last axis (kept as size 1)."""
    return torch.sum(x * y, -1, keepdim=True)


def compute_vertex_normal(v_pos, t_pos_idx):
    """Compute per-vertex normals by accumulating face normals.

    Args:
        v_pos: Vertex positions, indexed as ``v_pos[i, :]`` with three
            components per vertex.
        t_pos_idx: Triangle vertex indices with three columns. They are used
            as ``scatter_add_`` indices (callers in this file pass ``.long()``).

    Returns:
        Unit-length vertex normals with the same shape as ``v_pos``. Vertices
        whose accumulated normal is (near) zero get ``(0, 0, 1)``.
    """
    v0 = v_pos[t_pos_idx[:, 0], :]
    v1 = v_pos[t_pos_idx[:, 1], :]
    v2 = v_pos[t_pos_idx[:, 2], :]

    # dim must be explicit: without it torch.cross picks the *first* axis of
    # size 3, which is the face axis whenever the mesh has exactly 3 faces.
    face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)

    # Splat face normals to vertices
    v_nrm = torch.zeros_like(v_pos)
    for corner in range(3):
        corner_idx = t_pos_idx[:, corner]
        v_nrm.scatter_add_(0, corner_idx[:, None].repeat(1, 3), face_normals)

    # Normalize, replace zero (degenerated) normals with some default value
    default_normal = torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
    v_nrm = torch.where(dot(v_nrm, v_nrm) > 1e-20, v_nrm, default_normal)
    v_nrm = F.normalize(v_nrm, dim=1)

    assert torch.all(torch.isfinite(v_nrm))

    return v_nrm


class NeuralRender(Renderer):
    """Rasterizes a mesh with nvdiffrast and returns per-pixel feature buffers."""

    def __init__(self, device='cuda', camera_model=None):
        """Create the CUDA rasterization context and store the camera model.

        Args:
            device: Device handed to ``dr.RasterizeCudaContext``.
            camera_model: Object whose ``project`` method is called on the
                transformed vertices in ``render_mesh`` / ``render_mesh_light``.
        """
        super().__init__()
        self.device = device
        self.ctx = dr.RasterizeCudaContext(device=device)
        self.projection_mtx = None
        self.camera = camera_model

    # NOTE: a commented-out legacy ``shade`` pixel shader (texture lookup and
    # BSDF evaluation) used to sit here. It referenced names that are not
    # defined in this file and contained a stray ``breakpoint()``; it was dead
    # code and has been removed.

    # ==============================================================================================
    #  Render a depth slice of the mesh (scene), some limitations:
    #  - Single mesh
    #  - Single light
    #  - Single material
    # ==============================================================================================
    def render_layer(
            self,
            rast,
            rast_deriv,
            mesh,
            view_pos,
            resolution,
            spp,
            msaa
    ):
        """Interpolate position and shading normal for one rasterized layer.

        Args:
            rast: Rasterizer output for the layer.
            rast_deriv: Rasterizer derivative output (not read here).
            mesh: Object providing ``v_pos``, ``t_pos_idx``, ``v_nrm``,
                ``t_nrm_idx``, ``v_tng`` and ``t_tng_idx``.
            view_pos: Camera positions; indexed as ``view_pos[:, None, None, :]``.
            resolution: Not read here.
            spp: Not read here.
            msaa: Not read here; shading always happens at the resolution of ``rast``.

        Returns:
            Tuple ``(gb_pos, gb_normal)``: the interpolated positions and the
            result of ``ru.prepare_shading_normal``.
        """
        ################################################################################
        # Interpolate attributes
        ################################################################################

        # Interpolate world space position
        gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast, mesh.t_pos_idx.int())

        # Compute geometric normals. We need those because of bent normals trick (for bump mapping)
        v0 = mesh.v_pos[mesh.t_pos_idx[:, 0], :]
        v1 = mesh.v_pos[mesh.t_pos_idx[:, 1], :]
        v2 = mesh.v_pos[mesh.t_pos_idx[:, 2], :]
        # Explicit dim: the implicit default picks the first size-3 axis,
        # which is wrong for a mesh with exactly three faces.
        face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0, dim=-1))
        # One index per face, repeated for its three corners. Built on the same
        # device as the normals rather than a hard-coded 'cuda'.
        face_normal_indices = torch.arange(
            0, face_normals.shape[0], dtype=torch.int64, device=face_normals.device
        )[:, None].repeat(1, 3)
        gb_geometric_normal, _ = interpolate(
            face_normals[None, ...], rast, face_normal_indices.int())

        # Compute tangent space
        assert mesh.v_nrm is not None and mesh.v_tng is not None
        gb_normal, _ = interpolate(mesh.v_nrm[None, ...], rast, mesh.t_nrm_idx.int())
        gb_tangent, _ = interpolate(mesh.v_tng[None, ...], rast, mesh.t_tng_idx.int())

        # No normal map is applied, so the perturbed normal is None.
        perturbed_nrm = None
        gb_normal = ru.prepare_shading_normal(
            gb_pos, view_pos[:, None, None, :], perturbed_nrm, gb_normal, gb_tangent,
            gb_geometric_normal, two_sided_shading=True, opengl=True)

        return gb_pos, gb_normal

    def _rasterize_features(self, v_pos_clip, tri_idx, v_feat, v_nrm, resolution, spp):
        """Rasterize one layer and gather the feature, mask, depth and normal buffers.

        Args:
            v_pos_clip: Projected vertex positions handed to the rasterizer.
            tri_idx: Triangle index tensor used for rasterization and interpolation.
            v_feat: Per-vertex features whose last four channels are the
                homogeneous transformed positions appended by the callers.
            v_nrm: Per-vertex normals (a batch axis is added here).
            resolution: Base image resolution.
            spp: Multiplier applied to ``resolution`` for the raster size.

        Returns:
            ``(ori_mesh_feature, antialias_mask, hard_mask, rast, depth, normal)``.
        """
        raster_size = [resolution * spp, resolution * spp]
        with dr.DepthPeeler(self.ctx, v_pos_clip, tri_idx, raster_size) as peeler:
            # Only the first depth layer is used.
            rast, _ = peeler.rasterize_next_layer()
            gb_feat, _ = interpolate(v_feat, rast, tri_idx)

        hard_mask = torch.clamp(rast[..., -1:], 0, 1)
        antialias_mask = dr.antialias(
            hard_mask.clone().contiguous(), rast, v_pos_clip, tri_idx)

        # The last four channels are the appended (x, y, z, w) position:
        # channel -2 is z, everything before the last four is the caller's feature.
        depth = gb_feat[..., -2:-1]
        ori_mesh_feature = gb_feat[..., :-4]

        normal, _ = interpolate(v_nrm[None, ...], rast, tri_idx)
        normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, tri_idx)
        return ori_mesh_feature, antialias_mask, hard_mask, rast, depth, normal

    def render_mesh(
            self,
            mesh_v_pos_bxnx3,
            mesh_t_pos_idx_fx3,
            mesh,
            camera_mv_bx4x4,
            camera_pos,
            mesh_v_feat_bxnxd,
            resolution=256,
            spp=1,
            device='cuda',
            hierarchical_mask=False
    ):
        """Render per-pixel features, masks, depth and normals for a mesh.

        Args:
            mesh_v_pos_bxnx3: Vertex positions passed to ``xfm_points``.
            mesh_t_pos_idx_fx3: Triangle indices used for the feature pass.
            mesh: Mesh object; supplies ``v_nrm`` and the attributes read by
                ``render_layer`` (its own ``t_pos_idx`` is used for that pass).
            camera_mv_bx4x4: Transform matrices; converted to a float32 tensor
                on ``device`` when not already a tensor.
            camera_pos: Camera positions forwarded to ``render_layer`` as ``view_pos``.
            mesh_v_feat_bxnxd: Per-vertex features, repeated along the batch
                axis to match the transformed positions.
            resolution: Base image resolution.
            spp: Multiplier applied to ``resolution`` for the raster size.
            device: Device used only when ``camera_mv_bx4x4`` must be converted.
            hierarchical_mask: Must be False; hierarchical masks are not supported.

        Returns:
            ``(ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip,
            mask_pyramid, depth, normal, gb_normal)``. ``mask_pyramid`` is
            always None. ``normal`` is the interpolated, antialiased vertex
            normal; unlike ``render_mesh_light`` it is neither re-normalized
            nor remapped.
        """
        assert not hierarchical_mask

        mtx_in = (
            camera_mv_bx4x4 if torch.is_tensor(camera_mv_bx4x4)
            else torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device)
        )
        v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in)  # Rotate it to camera coordinates
        v_pos_clip = self.camera.project(v_pos)  # Projection in the camera

        view_pos = camera_pos
        v_nrm = mesh.v_nrm

        # Render the image,
        # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render
        mask_pyramid = None
        assert mesh_t_pos_idx_fx3.shape[0] > 0  # Make sure we have shapes
        # Concatenate the pos [org_pos, clip space pose for rasterization]
        mesh_v_feat_bxnxd = torch.cat(
            [mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1)

        # First pass: shading normals, rasterized with the mesh's own triangle indices.
        raster_size = [resolution * spp, resolution * spp]
        with dr.DepthPeeler(self.ctx, v_pos_clip, mesh.t_pos_idx.int(), raster_size) as peeler:
            rast, db = peeler.rasterize_next_layer()
            _, gb_normal = self.render_layer(
                rast, db, mesh, view_pos, resolution, spp, msaa=False)

        # Second pass: features, masks, depth and interpolated vertex normals.
        ori_mesh_feature, antialias_mask, hard_mask, rast, depth, normal = \
            self._rasterize_features(
                v_pos_clip, mesh_t_pos_idx_fx3, mesh_v_feat_bxnxd, v_nrm, resolution, spp)

        return (ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip,
                mask_pyramid, depth, normal, gb_normal)

    def render_mesh_light(
            self,
            mesh_v_pos_bxnx3,
            mesh_t_pos_idx_fx3,
            mesh,
            camera_mv_bx4x4,
            mesh_v_feat_bxnxd,
            resolution=256,
            spp=1,
            device='cuda',
            hierarchical_mask=False
    ):
        """Render features, masks, depth and a colour-mapped normal image.

        Vertex normals are recomputed from the first batch element of
        ``mesh_v_pos_bxnx3`` rather than read from ``mesh``.

        Args:
            mesh_v_pos_bxnx3: Vertex positions passed to ``xfm_points``.
            mesh_t_pos_idx_fx3: Triangle indices used for rasterization.
            mesh: Not read by this method.
            camera_mv_bx4x4: Transform matrices; converted to a float32 tensor
                on ``device`` when not already a tensor.
            mesh_v_feat_bxnxd: Per-vertex features, repeated along the batch
                axis to match the transformed positions.
            resolution: Base image resolution.
            spp: Multiplier applied to ``resolution`` for the raster size.
            device: Device used only when ``camera_mv_bx4x4`` must be converted.
            hierarchical_mask: Must be False; hierarchical masks are not supported.

        Returns:
            ``(ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip,
            mask_pyramid, depth, normal)``. ``mask_pyramid`` is always None.
            ``normal`` is normalized, mapped from [-1, 1] to [0, 1] and set to
            zero where ``hard_mask`` is zero.
        """
        assert not hierarchical_mask

        mtx_in = (
            camera_mv_bx4x4 if torch.is_tensor(camera_mv_bx4x4)
            else torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device)
        )
        v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in)  # Rotate it to camera coordinates
        v_pos_clip = self.camera.project(v_pos)  # Projection in the camera

        # vertex normals in world coordinates
        v_nrm = compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long())

        # Render the image,
        # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render
        mask_pyramid = None
        assert mesh_t_pos_idx_fx3.shape[0] > 0  # Make sure we have shapes
        # Concatenate the pos
        mesh_v_feat_bxnxd = torch.cat(
            [mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1)

        ori_mesh_feature, antialias_mask, hard_mask, rast, depth, normal = \
            self._rasterize_features(
                v_pos_clip, mesh_t_pos_idx_fx3, mesh_v_feat_bxnxd, v_nrm, resolution, spp)

        normal = F.normalize(normal, dim=-1)
        # black background
        normal = torch.lerp(
            torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float())

        return (ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip,
                mask_pyramid, depth, normal)