Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Gaussian Splatting. | |
Partially borrowed from https://github.com/graphdeco-inria/gaussian-splatting. | |
""" | |
import os | |
import torch | |
from torch import nn | |
import numpy as np | |
from diff_surfel_rasterization import ( | |
GaussianRasterizationSettings, | |
GaussianRasterizer, | |
) | |
from plyfile import PlyData, PlyElement | |
from scipy.spatial.transform import Rotation as R | |
def strip_lowerdiag(L): | |
uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device=L.device) | |
uncertainty[:, 0] = L[:, 0, 0] | |
uncertainty[:, 1] = L[:, 0, 1] | |
uncertainty[:, 2] = L[:, 0, 2] | |
uncertainty[:, 3] = L[:, 1, 1] | |
uncertainty[:, 4] = L[:, 1, 2] | |
uncertainty[:, 5] = L[:, 2, 2] | |
return uncertainty | |
def strip_symmetric(sym): | |
return strip_lowerdiag(sym) | |
def build_rotation(r): | |
norm = torch.sqrt( | |
r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3] | |
) | |
q = r / norm[:, None] | |
R = torch.zeros((q.size(0), 3, 3), device=r.device) | |
r = q[:, 0] | |
x = q[:, 1] | |
y = q[:, 2] | |
z = q[:, 3] | |
R[:, 0, 0] = 1 - 2 * (y * y + z * z) | |
R[:, 0, 1] = 2 * (x * y - r * z) | |
R[:, 0, 2] = 2 * (x * z + r * y) | |
R[:, 1, 0] = 2 * (x * y + r * z) | |
R[:, 1, 1] = 1 - 2 * (x * x + z * z) | |
R[:, 1, 2] = 2 * (y * z - r * x) | |
R[:, 2, 0] = 2 * (x * z - r * y) | |
R[:, 2, 1] = 2 * (y * z + r * x) | |
R[:, 2, 2] = 1 - 2 * (x * x + y * y) | |
return R | |
def build_scaling_rotation(s, r): | |
L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device=s.device) | |
R = build_rotation(r) | |
L[:, 0, 0] = s[:, 0] | |
L[:, 1, 1] = s[:, 1] | |
L[:, 2, 2] = s[:, 2] | |
L = R @ L | |
return L | |
def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): | |
L = build_scaling_rotation(scaling_modifier * scaling, rotation) | |
actual_covariance = L @ L.transpose(1, 2) | |
symm = strip_symmetric(actual_covariance) | |
return symm | |
def depths_to_points(view, depthmap): | |
c2w = (view.world_view_transform.T).inverse() | |
W, H = view.w, view.h | |
ndc2pix = torch.tensor([ | |
[W / 2, 0, 0, (W) / 2], | |
[0, H / 2, 0, (H) / 2], | |
[0, 0, 0, 1]]).float().cuda().T | |
projection_matrix = c2w.T @ view.full_proj_transform | |
intrins = (projection_matrix @ ndc2pix)[:3,:3].T | |
grid_x, grid_y = torch.meshgrid(torch.arange(W, device='cuda').float(), torch.arange(H, device='cuda').float(), indexing='xy') | |
points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape(-1, 3) | |
rays_d = points @ intrins.inverse().T @ c2w[:3,:3].T | |
rays_o = c2w[:3,3] | |
points = depthmap.reshape(-1, 1) * rays_d + rays_o | |
return points | |
def depth_to_normal(view, depth): | |
""" | |
view: view camera | |
depth: depthmap | |
""" | |
points = depths_to_points(view, depth).reshape(*depth.shape[1:], 3) | |
output = torch.zeros_like(points) | |
dx = torch.cat([points[2:, 1:-1] - points[:-2, 1:-1]], dim=0) | |
dy = torch.cat([points[1:-1, 2:] - points[1:-1, :-2]], dim=1) | |
normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1) | |
output[1:-1, 1:-1, :] = normal_map | |
return output | |
class Camera(nn.Module): | |
def __init__(self, C2W, fxfycxcy, h, w): | |
""" | |
C2W: 4x4 camera-to-world matrix; opencv convention | |
fxfycxcy: 4 | |
""" | |
super().__init__() | |
self.C2W = C2W.float() | |
self.W2C = self.C2W.inverse() | |
self.znear = 0.01 | |
self.zfar = 100.0 | |
self.h = h | |
self.w = w | |
fx, fy, cx, cy = fxfycxcy[0], fxfycxcy[1], fxfycxcy[2], fxfycxcy[3] | |
self.tanfovX = 1 / (2 * fx) | |
self.tanfovY = 1 / (2 * fy) | |
self.fovX = 2 * torch.atan(self.tanfovX) | |
self.fovY = 2 * torch.atan(self.tanfovY) | |
self.shiftX = 2 * cx - 1 | |
self.shiftY = 2 * cy - 1 | |
def getProjectionMatrix(znear, zfar, fovX, fovY, shiftX, shiftY): | |
tanHalfFovY = torch.tan((fovY / 2)) | |
tanHalfFovX = torch.tan((fovX / 2)) | |
top = tanHalfFovY * znear | |
bottom = -top | |
right = tanHalfFovX * znear | |
left = -right | |
P = torch.zeros(4, 4, dtype=torch.float32, device=fovX.device) | |
z_sign = 1.0 | |
P[0, 0] = 2.0 * znear / (right - left) | |
P[1, 1] = 2.0 * znear / (top - bottom) | |
P[0, 2] = (right + left) / (right - left) + shiftX | |
P[1, 2] = (top + bottom) / (top - bottom) + shiftY | |
P[3, 2] = z_sign | |
P[2, 2] = z_sign * zfar / (zfar - znear) | |
P[2, 3] = -(zfar * znear) / (zfar - znear) | |
return P | |
self.world_view_transform = self.W2C.transpose(0, 1) | |
self.projection_matrix = getProjectionMatrix( | |
znear=self.znear, zfar=self.zfar, fovX=self.fovX, fovY=self.fovY, shiftX=self.shiftX, shiftY=self.shiftY | |
).transpose(0, 1) | |
self.full_proj_transform = ( | |
self.world_view_transform.unsqueeze(0).bmm( | |
self.projection_matrix.unsqueeze(0) | |
) | |
).squeeze(0) | |
self.camera_center = self.C2W[:3, 3] | |
class GaussianModel: | |
def setup_functions(self, scaling_activation_type='sigmoid', scale_min_act=0.001, scale_max_act=0.3, scale_multi_act=0.1): | |
if scaling_activation_type == 'exp': | |
self.scaling_activation = torch.exp | |
elif scaling_activation_type == 'softplus': | |
self.scaling_activation = torch.nn.functional.softplus | |
self.scale_multi_act = scale_multi_act | |
elif scaling_activation_type == 'sigmoid': | |
self.scale_min_act = scale_min_act | |
self.scale_max_act = scale_max_act | |
self.scaling_activation = torch.sigmoid | |
else: | |
raise NotImplementedError | |
self.scaling_activation_type = scaling_activation_type | |
self.rotation_activation = torch.nn.functional.normalize | |
self.opacity_activation = torch.sigmoid | |
self.feature_activation = torch.sigmoid | |
self.covariance_activation = build_covariance_from_scaling_rotation | |
def __init__(self, sh_degree: int, scaling_activation_type='exp', scale_min_act=0.001, scale_max_act=0.3, scale_multi_act=0.1): | |
self.sh_degree = sh_degree | |
self._xyz = torch.empty(0) | |
self._features_dc = torch.empty(0) | |
if self.sh_degree > 0: | |
self._features_rest = torch.empty(0) | |
else: | |
self._features_rest = None | |
self._scaling = torch.empty(0) | |
self._rotation = torch.empty(0) | |
self._opacity = torch.empty(0) | |
self.setup_functions(scaling_activation_type=scaling_activation_type, scale_min_act=scale_min_act, scale_max_act=scale_max_act, scale_multi_act=scale_multi_act) | |
def set_data(self, xyz, features, scaling, rotation, opacity): | |
self._xyz = xyz | |
self._features_dc = features[:, 0, :].contiguous() if self.sh_degree == 0 else features[:, 0:1, :].contiguous() | |
if self.sh_degree > 0: | |
self._features_rest = features[:, 1:, :].contiguous() | |
else: | |
self._features_rest = None | |
self._scaling = scaling | |
self._rotation = rotation | |
self._opacity = opacity | |
return self | |
def to(self, device): | |
self._xyz = self._xyz.to(device) | |
self._features_dc = self._features_dc.to(device) | |
if self.sh_degree > 0: | |
self._features_rest = self._features_rest.to(device) | |
self._scaling = self._scaling.to(device) | |
self._rotation = self._rotation.to(device) | |
self._opacity = self._opacity.to(device) | |
return self | |
def get_scaling(self): | |
if self.scaling_activation_type == 'exp': | |
scales = self.scaling_activation(self._scaling) | |
elif self.scaling_activation_type == 'softplus': | |
scales = self.scaling_activation(self._scaling) * self.scale_multi_act | |
elif self.scaling_activation_type == 'sigmoid': | |
scales = self.scale_min_act + (self.scale_max_act - self.scale_min_act) * self.scaling_activation(self._scaling) | |
return scales | |
def get_rotation(self): | |
return self.rotation_activation(self._rotation) | |
def get_xyz(self): | |
return self._xyz | |
def get_features(self): | |
if self.sh_degree > 0: | |
features_dc = self._features_dc | |
features_rest = self._features_rest | |
return torch.cat((features_dc, features_rest), dim=1) | |
else: | |
return self.feature_activation(self._features_dc) | |
def get_opacity(self): | |
return self.opacity_activation(self._opacity) | |
def get_covariance(self, scaling_modifier=1): | |
return self.covariance_activation( | |
self.get_scaling, scaling_modifier, self._rotation | |
) | |
def construct_list_of_attributes(self, num_rest=0): | |
l = ['x', 'y', 'z'] | |
# All channels except the 3 DC | |
for i in range(3): | |
l.append('f_dc_{}'.format(i)) | |
for i in range(num_rest): | |
l.append('f_rest_{}'.format(i)) | |
l.append('opacity') | |
for i in range(self._scaling.shape[1]): | |
l.append('scale_{}'.format(i)) | |
for i in range(self._rotation.shape[1]): | |
l.append('rot_{}'.format(i)) | |
return l | |
def save_ply_vis(self, path): | |
os.makedirs(os.path.dirname(path), exist_ok=True) | |
xyzs = self._xyz.detach().cpu().numpy() | |
f_dc = self._features_dc.detach().flatten(start_dim=1).contiguous().cpu().numpy() | |
opacities = self._opacity.detach().cpu().numpy() | |
scales = torch.log(self.get_scaling) | |
scales = scales.detach().cpu().numpy() | |
rot_mat_vis = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]) | |
xyzs = xyzs @ rot_mat_vis.T | |
rotations = self._rotation.detach().cpu().numpy() | |
rotations = R.from_quat(rotations[:, [1,2,3,0]]).as_matrix() | |
rotations = rot_mat_vis @ rotations | |
rotations = R.from_matrix(rotations).as_quat()[:, [3,0,1,2]] | |
dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes(0)] | |
elements = np.empty(xyzs.shape[0], dtype=dtype_full) | |
attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1) | |
elements[:] = list(map(tuple, attributes)) | |
el = PlyElement.describe(elements, 'vertex') | |
PlyData([el]).write(path) | |
def save_ply(self, path): | |
os.makedirs(os.path.dirname(path), exist_ok=True) | |
xyzs = self._xyz.detach().cpu().numpy() | |
f_dc = self._features_dc.detach().flatten(start_dim=1).contiguous().cpu().numpy() | |
if self.sh_degree > 0: | |
f_rest = self._features_rest.detach().flatten(start_dim=1).contiguous().cpu().numpy() | |
else: | |
f_rest = np.zeros((f_dc.shape[0], 0), dtype=f_dc.dtype) | |
opacities = self._opacity.detach().cpu().numpy() | |
scales = torch.log(self.get_scaling) | |
scales = scales.detach().cpu().numpy() | |
rotations = self._rotation.detach().cpu().numpy() | |
dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes(f_rest.shape[-1])] | |
elements = np.empty(xyzs.shape[0], dtype=dtype_full) | |
attributes = np.concatenate((xyzs, f_dc, f_rest, opacities, scales, rotations), axis=1) | |
elements[:] = list(map(tuple, attributes)) | |
el = PlyElement.describe(elements, "vertex") | |
PlyData([el]).write(path) | |
# def load_ply(self, path): | |
# plydata = PlyData.read(path) | |
# xyz = np.stack((np.asarray(plydata.elements[0]["x"]), | |
# np.asarray(plydata.elements[0]["y"]), | |
# np.asarray(plydata.elements[0]["z"])), axis=1) | |
# opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] | |
# features_dc = np.zeros((xyz.shape[0], 3, 1)) | |
# features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) | |
# features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) | |
# features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) | |
# scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] | |
# scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1])) | |
# scales = np.zeros((xyz.shape[0], len(scale_names))) | |
# for idx, attr_name in enumerate(scale_names): | |
# scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) | |
# rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] | |
# rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1])) | |
# rots = np.zeros((xyz.shape[0], len(rot_names))) | |
# for idx, attr_name in enumerate(rot_names): | |
# rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) | |
# self._xyz = torch.from_numpy(xyz.astype(np.float32)) | |
# self._features_dc = torch.from_numpy(features_dc.astype(np.float32)).transpose(1, 2).contiguous() | |
# self._opacity = torch.from_numpy(opacities.astype(np.float32)).contiguous() | |
# self._scaling = torch.from_numpy(scales.astype(np.float32)).contiguous() | |
# self._rotation = torch.from_numpy(rots.astype(np.float32)).contiguous() | |
def render( | |
pc: GaussianModel, | |
height: int, | |
width: int, | |
C2W: torch.Tensor, | |
fxfycxcy: torch.Tensor, | |
bg_color=(1.0, 1.0, 1.0), | |
scaling_modifier=1.0, | |
): | |
""" | |
Render the scene. | |
""" | |
screenspace_points = ( | |
torch.zeros_like( | |
pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda" | |
) | |
+ 0 | |
) | |
try: | |
screenspace_points.retain_grad() | |
except: | |
pass | |
viewpoint_camera = Camera(C2W=C2W, fxfycxcy=fxfycxcy, h=height, w=width) | |
bg_color = torch.tensor(list(bg_color), dtype=torch.float32, device=C2W.device) | |
raster_settings = GaussianRasterizationSettings( | |
image_height=int(viewpoint_camera.h), | |
image_width=int(viewpoint_camera.w), | |
tanfovx=viewpoint_camera.tanfovX, | |
tanfovy=viewpoint_camera.tanfovY, | |
bg=bg_color, | |
scale_modifier=scaling_modifier, | |
viewmatrix=viewpoint_camera.world_view_transform, | |
projmatrix=viewpoint_camera.full_proj_transform, | |
sh_degree=pc.sh_degree, | |
campos=viewpoint_camera.camera_center, | |
prefiltered=False, | |
debug=False, | |
) | |
rasterizer = GaussianRasterizer(raster_settings=raster_settings) | |
means3D = pc.get_xyz | |
means2D = screenspace_points | |
opacity = pc.get_opacity | |
scales = pc.get_scaling | |
rotations = pc.get_rotation | |
shs = pc.get_features | |
rendered_image, _, allmap = rasterizer( | |
means3D=means3D, | |
means2D=means2D, | |
shs=None if pc.sh_degree == 0 else shs, | |
colors_precomp=shs if pc.sh_degree == 0 else None, | |
opacities=opacity, | |
scales=scales, | |
rotations=rotations, | |
cov3D_precomp=None, | |
) | |
# additional regularizations | |
render_alpha = allmap[1:2] | |
# get normal map | |
# transform normal from view space to world space | |
render_normal = allmap[2:5] | |
render_normal = (render_normal.permute(1, 2, 0) @ (viewpoint_camera.world_view_transform[:3, :3].T)).permute(2, 0, 1) | |
# get median depth map | |
render_depth_median = allmap[5:6] | |
render_depth_median = torch.nan_to_num(render_depth_median, 0, 0) | |
# get expected depth map | |
render_depth_expected = allmap[0:1] | |
render_depth_expected = (render_depth_expected / render_alpha) | |
render_depth_expected = torch.nan_to_num(render_depth_expected, 0, 0) | |
# get depth distortion map | |
render_dist = allmap[6:7] | |
# psedo surface attributes | |
# surf depth is either median or expected by setting depth_ratio to 1 or 0 | |
# for bounded scene, use median depth, i.e., depth_ratio = 1; | |
# for unbounded scene, use expected depth, i.e., depth_ration = 0, to reduce disk anliasing. | |
depth_ratio = 0.0 | |
surf_depth = render_depth_expected * (1 - depth_ratio) + depth_ratio * render_depth_median | |
# assume the depth points form the 'surface' and generate psudo surface normal for regularizations. | |
surf_normal = depth_to_normal(viewpoint_camera, surf_depth) | |
surf_normal = surf_normal.permute(2, 0, 1) | |
# remember to multiply with accum_alpha since render_normal is unnormalized. | |
surf_normal = surf_normal * (render_alpha).detach() | |
return { | |
"render": rendered_image, | |
"depth": surf_depth, | |
"alpha": render_alpha, | |
'surf_normal': surf_normal, | |
'rend_normal': render_normal, | |
'dist': render_dist, | |
} | |