import torch

from .gaussian_utils import render, GaussianModel


class GaussianRenderer:
    """Batch renderer that feeds per-sample Gaussian parameters to ``render``.

    The constructor builds a ``GaussianModel`` from ``renderer_config``;
    ``render`` then splits a packed latent tensor into Gaussian attributes
    and rasterizes every (batch, view) pair into pre-allocated buffers.
    """

    # (key in the dict returned by ``render``, key in our result dict,
    # number of channels of the corresponding output buffer).
    # The order defines the key order of the dict returned by ``render()``.
    _OUTPUT_SPECS = (
        ('render', 'image', 3),
        ('alpha', 'alpha', 1),
        ('depth', 'depth', 1),
        ('surf_normal', 'surf_normals', 3),
        ('rend_normal', 'rend_normals', 3),
        ('dist', 'dists', 1),
    )

    def __init__(self, renderer_config=None):
        """Create the underlying Gaussian model and cache render settings.

        Args:
            renderer_config: Config object that must support ``in``,
                item assignment and attribute access (presumably an
                OmegaConf-style config -- TODO confirm). Required
                attributes: ``sh_degree``, ``img_height``, ``img_width``.
                Optional: ``scaling_activation_type`` (default ``'exp'``),
                ``scale_min_act`` / ``scale_max_act`` / ``scale_multi_act``
                and ``bg_color`` (default ``(1.0, 1.0, 1.0)``).

        Raises:
            TypeError: If ``renderer_config`` is ``None``.

        Note:
            Missing optional entries are written back into
            ``renderer_config`` in place, so the caller's object is
            modified.
        """
        if renderer_config is None:
            # The original code failed here with an opaque
            # "argument of type 'NoneType' is not iterable"; keep the
            # exception type but say what is actually wrong.
            raise TypeError(
                "renderer_config is required (got None); it must provide at "
                "least 'sh_degree', 'img_height' and 'img_width'"
            )

        if 'scaling_activation_type' not in renderer_config:
            renderer_config['scaling_activation_type'] = 'exp'
        # The three scale settings are defaulted as a group, keyed only on
        # the presence of 'scale_min_act' (behavior kept from the original).
        if 'scale_min_act' not in renderer_config:
            renderer_config['scale_min_act'] = 1
            renderer_config['scale_max_act'] = 1
            renderer_config['scale_multi_act'] = 0.1

        self.gaussian_model = GaussianModel(
            sh_degree=renderer_config.sh_degree,
            scaling_activation_type=renderer_config.scaling_activation_type,
            scale_min_act=renderer_config.scale_min_act,
            scale_max_act=renderer_config.scale_max_act,
            scale_multi_act=renderer_config.scale_multi_act,
        )
        self.img_height = renderer_config.img_height
        self.img_width = renderer_config.img_width
        if 'bg_color' in renderer_config:
            self.bg_color = renderer_config.bg_color
        else:
            self.bg_color = (1.0, 1.0, 1.0)

    def render(self, latent, output_fxfycxcy, output_c2ws, render_size=None):
        """Render every view of every batch element from a packed latent.

        Args:
            latent: Tensor whose last dimension is split into
                ``[3, shs_dim, 1, 2, 4]`` chunks, used as xyz, SH features,
                opacity, scaling and rotation respectively, where
                ``shs_dim = (sh_degree + 1) ** 2 * 3``. The first dimension
                is indexed per batch element.
            output_fxfycxcy: Per-view tensor whose first two dimensions are
                (batch, views); ``output_fxfycxcy[b, v]`` is passed to
                ``render`` (presumably fx, fy, cx, cy intrinsics -- TODO
                confirm).
            output_c2ws: Per-view tensor indexed as ``output_c2ws[b, v]``
                and passed to ``render`` (presumably camera-to-world
                matrices -- TODO confirm). Its device is used for all
                output buffers.
            render_size: Optional ``(height, width)`` pair overriding the
                size configured at construction time.

        Returns:
            dict with keys ``'image'``, ``'alpha'``, ``'depth'``,
            ``'surf_normals'``, ``'rend_normals'`` and ``'dists'``; each
            value is a float32 tensor of shape
            ``(batch, views, C, height, width)`` with ``C`` equal to 3 for
            ``'image'`` / ``'surf_normals'`` / ``'rend_normals'`` and 1 for
            the others.
        """
        if render_size is None:
            img_height, img_width = self.img_height, self.img_width
        else:
            img_height, img_width = render_size

        shs_dim = (self.gaussian_model.sh_degree + 1) ** 2 * 3
        xyz, features, opacity, scaling, rotation = latent.split(
            [3, shs_dim, 1, 2, 4], dim=-1
        )
        # Unpack the flat SH vector into (batch, N, coeffs, 3).
        features = features.reshape(features.shape[0], -1, shs_dim // 3, 3)

        bs, vs = output_fxfycxcy.shape[:2]
        device = output_c2ws.device

        # One zero-initialized buffer per output, filled in place below.
        results = {
            out_key: torch.zeros(
                bs, vs, channels, img_height, img_width,
                dtype=torch.float32, device=device,
            )
            for _, out_key, channels in self._OUTPUT_SPECS
        }

        for idx in range(bs):
            # The Gaussian data is set once per batch element and reused
            # for all of its views.
            pc = self.gaussian_model.set_data(
                xyz[idx], features[idx], scaling[idx], rotation[idx], opacity[idx]
            )
            for vidx in range(vs):
                render_results = render(
                    pc,
                    img_height,
                    img_width,
                    output_c2ws[idx, vidx],
                    output_fxfycxcy[idx, vidx],
                    self.bg_color,
                )
                for src_key, out_key, _ in self._OUTPUT_SPECS:
                    results[out_key][idx, vidx] = render_results[src_key]

        return results