# Author: bluestyle97
# Last commit: "Update freesplatter/models/renderer/gaussian_utils.py" (33aea3f, verified)
"""
Gaussian Splatting.
Partially borrowed from https://github.com/graphdeco-inria/gaussian-splatting.
"""
import os
import torch
from torch import nn
import numpy as np
from diff_gaussian_rasterization import (
GaussianRasterizationSettings,
GaussianRasterizer,
)
from plyfile import PlyData, PlyElement
from scipy.spatial.transform import Rotation as R
def strip_lowerdiag(L):
    """Pack the upper-triangular entries of a batch of 3x3 matrices.

    Despite its name, this keeps the diagonal and the entries ABOVE it.

    Args:
        L: batch of 3x3 matrices, indexed as ``L[:, i, j]``.

    Returns:
        float32 tensor of shape ``(N, 6)`` on ``L.device`` whose rows are
        ``[m00, m01, m02, m11, m12, m22]``.
    """
    rows = [0, 0, 0, 1, 1, 2]
    cols = [0, 1, 2, 1, 2, 2]
    # Paired advanced indices gather the six entries per matrix in one shot.
    packed = L[:, rows, cols]
    return packed.to(dtype=torch.float)
def strip_symmetric(sym):
    """Compact a batch of symmetric 3x3 matrices into their 6 unique entries.

    For a symmetric matrix the upper triangle carries all the information,
    so this simply delegates to ``strip_lowerdiag``.

    Returns:
        ``(N, 6)`` tensor ``[m00, m01, m02, m11, m12, m22]`` per matrix.
    """
    unique_entries = strip_lowerdiag(sym)
    return unique_entries
def build_rotation(r):
    """Convert a batch of quaternions into 3x3 rotation matrices.

    Args:
        r: tensor of shape ``(N, 4)`` holding quaternions with the real part
            first, i.e. ``(w, x, y, z)``. Rows need not be unit length; each
            is normalized here (an all-zero row yields NaNs).

    Returns:
        ``(N, 3, 3)`` tensor on ``r.device`` with torch's default floating
        dtype (``torch.zeros`` is called without an explicit dtype).
    """
    # Explicit sum of squares, kept to reproduce the reference arithmetic.
    length = torch.sqrt(
        r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]
    )
    unit = r / length[:, None]
    w, x, y, z = (unit[:, k] for k in range(4))

    # (row, col) -> matrix entry for the standard unit-quaternion rotation.
    entries = {
        (0, 0): 1 - 2 * (y * y + z * z),
        (0, 1): 2 * (x * y - w * z),
        (0, 2): 2 * (x * z + w * y),
        (1, 0): 2 * (x * y + w * z),
        (1, 1): 1 - 2 * (x * x + z * z),
        (1, 2): 2 * (y * z - w * x),
        (2, 0): 2 * (x * z - w * y),
        (2, 1): 2 * (y * z + w * x),
        (2, 2): 1 - 2 * (x * x + y * y),
    }
    rot = torch.zeros((unit.size(0), 3, 3), device=r.device)
    for (row, col), value in entries.items():
        rot[:, row, col] = value
    return rot
def build_scaling_rotation(s, r):
    """Return ``R @ diag(s)`` per Gaussian: the rotation times a scale matrix.

    Args:
        s: per-Gaussian scales; only the first three columns are used and
            they are cast to float32.
        r: per-Gaussian quaternions, converted by ``build_rotation``.

    Returns:
        ``(N, 3, 3)`` tensor ``M`` such that ``M @ M^T`` is the covariance.
    """
    scale_mat = torch.diag_embed(s[:, :3].to(torch.float))
    rot = build_rotation(r)
    return rot @ scale_mat
def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
    """Build packed 3D covariances ``(R S)(R S)^T`` for each Gaussian.

    Args:
        scaling: per-Gaussian scales (the diagonal of ``S``).
        scaling_modifier: factor multiplied into ``scaling`` first.
        rotation: per-Gaussian quaternions (normalized in build_rotation).

    Returns:
        ``(N, 6)`` tensor with the upper triangle of each covariance.
    """
    rs = build_scaling_rotation(scaling_modifier * scaling, rotation)
    covariance = torch.matmul(rs, rs.transpose(1, 2))
    return strip_symmetric(covariance)
class Camera(nn.Module):
    """Pinhole camera holding the matrices handed to the Gaussian rasterizer.

    NOTE(review): the intrinsics appear to be normalized by image size
    (``tan(fov / 2) = 1 / (2 * fx)`` and ``shift = 2 * c - 1`` below) —
    confirm with callers.
    """

    def __init__(self, C2W, fxfycxcy, h, w):
        """
        C2W: 4x4 camera-to-world matrix; opencv convention
        fxfycxcy: 4 values (fx, fy, cx, cy); must be tensors because
            torch.atan is applied to quantities derived from them.
        h, w: image height and width, stored unchanged.
        """
        super().__init__()
        self.C2W = C2W.float()
        self.W2C = self.C2W.inverse()
        self.znear = 0.01
        self.zfar = 100.0
        self.h = h
        self.w = w

        fx, fy, cx, cy = (fxfycxcy[k] for k in range(4))
        # Half-FOV tangents from the focal lengths.
        self.tanfovX = 1 / (2 * fx)
        self.tanfovY = 1 / (2 * fy)
        self.fovX = 2 * torch.atan(self.tanfovX)
        self.fovY = 2 * torch.atan(self.tanfovY)
        # Principal point mapped from [0, 1] into NDC [-1, 1].
        self.shiftX = 2 * cx - 1
        self.shiftY = 2 * cy - 1

        # Stored transposed (row-vector convention); render() passes these
        # to GaussianRasterizationSettings.
        self.world_view_transform = self.W2C.transpose(0, 1)
        self.projection_matrix = self._projection_matrix(
            self.znear, self.zfar, self.fovX, self.fovY, self.shiftX, self.shiftY
        ).transpose(0, 1)
        self.full_proj_transform = torch.bmm(
            self.world_view_transform[None], self.projection_matrix[None]
        )[0]
        self.camera_center = self.C2W[:3, 3]

    @staticmethod
    def _projection_matrix(znear, zfar, fovX, fovY, shiftX, shiftY):
        """Return the untransposed 4x4 perspective matrix with a principal-point shift.

        ``P[3, 2] = 1`` copies camera z into w, and the third row maps
        z = znear to 0 and z = zfar to 1 after the perspective divide.
        """
        tan_half_y = torch.tan(fovY / 2)
        tan_half_x = torch.tan(fovX / 2)
        top = tan_half_y * znear
        bottom = -top
        right = tan_half_x * znear
        left = -right
        z_sign = 1.0

        proj = torch.zeros(4, 4, dtype=torch.float32, device=fovX.device)
        proj[0, 0] = 2.0 * znear / (right - left)
        proj[1, 1] = 2.0 * znear / (top - bottom)
        proj[0, 2] = (right + left) / (right - left) + shiftX
        proj[1, 2] = (top + bottom) / (top - bottom) + shiftY
        proj[2, 2] = z_sign * zfar / (zfar - znear)
        proj[2, 3] = -(zfar * znear) / (zfar - znear)
        proj[3, 2] = z_sign
        return proj
class GaussianModel:
    """A set of 3D Gaussians kept as raw (pre-activation) tensors.

    The ``get_*`` properties apply the configured activations and the global
    ``_rescale`` factor; ``render()`` and the PLY writers read through them.
    """

    def setup_functions(
        self,
        scaling_activation_type='sigmoid',
        scale_min_act=0.001,
        scale_max_act=0.3,
        scale_multi_act=0.1,
    ):
        """Choose the activation applied to each raw parameter group.

        Args:
            scaling_activation_type: ``'exp'`` (plain exponent), ``'softplus'``
                (result multiplied by ``scale_multi_act``) or ``'sigmoid'``
                (result mapped into ``(scale_min_act, scale_max_act)``).
            scale_min_act, scale_max_act: bounds used by the sigmoid mode.
            scale_multi_act: multiplier used by the softplus mode.

        Raises:
            NotImplementedError: for an unknown ``scaling_activation_type``.
        """
        if scaling_activation_type == 'exp':
            self.scaling_activation = torch.exp
        elif scaling_activation_type == 'softplus':
            self.scaling_activation = torch.nn.functional.softplus
            self.scale_multi_act = scale_multi_act
        elif scaling_activation_type == 'sigmoid':
            self.scale_min_act = scale_min_act
            self.scale_max_act = scale_max_act
            self.scaling_activation = torch.sigmoid
        else:
            raise NotImplementedError(
                f"Unknown scaling activation type: {scaling_activation_type!r}"
            )
        self.scaling_activation_type = scaling_activation_type
        self.rotation_activation = torch.nn.functional.normalize
        self.opacity_activation = torch.sigmoid
        # Applied to the DC colour only when sh_degree == 0 (see get_features).
        self.feature_activation = torch.sigmoid
        self.covariance_activation = build_covariance_from_scaling_rotation

    def __init__(self, sh_degree: int, scaling_activation_type='exp', scale_min_act=0.001, scale_max_act=0.3, scale_multi_act=0.1):
        """Create an empty model; populate it with ``set_data``.

        Args:
            sh_degree: spherical-harmonics degree; 0 means a single colour
                per Gaussian (activated with a sigmoid in get_features).
            scaling_activation_type, scale_min_act, scale_max_act,
            scale_multi_act: forwarded to ``setup_functions``.
        """
        self.sh_degree = sh_degree
        self._xyz = torch.empty(0)
        self._features_dc = torch.empty(0)
        self._features_rest = torch.empty(0) if self.sh_degree > 0 else None
        self._scaling = torch.empty(0)
        self._rotation = torch.empty(0)
        self._opacity = torch.empty(0)
        # Neutral global scale so get_xyz / get_scaling / to() work before
        # set_data() has been called.
        self._rescale = torch.ones(1)
        self.setup_functions(
            scaling_activation_type=scaling_activation_type,
            scale_min_act=scale_min_act,
            scale_max_act=scale_max_act,
            scale_multi_act=scale_multi_act,
        )

    def set_data(self, xyz, features, scaling, rotation, opacity, rescale=None):
        """Assign raw Gaussian parameters and return ``self``.

        Args:
            xyz: Gaussian centres before rescaling.
            features: indexed as ``features[:, k, :]``; slot 0 is the DC term
                and slots ``1:`` the higher-order coefficients. Presumably
                ``(N, (sh_degree + 1) ** 2, 3)`` — confirm with callers.
            scaling, rotation, opacity: raw (pre-activation) parameters.
            rescale: global factor multiplied into positions and scales;
                defaults to a one-element tensor of 1 on xyz's device/dtype.
        """
        self._xyz = xyz
        if self.sh_degree == 0:
            self._features_dc = features[:, 0, :].contiguous()
            self._features_rest = None
        else:
            self._features_dc = features[:, 0:1, :].contiguous()
            self._features_rest = features[:, 1:, :].contiguous()
        self._scaling = scaling
        self._rotation = rotation
        self._opacity = opacity
        if rescale is None:
            rescale = torch.ones(1).to(xyz)
        self._rescale = rescale
        return self

    def to(self, device):
        """Apply ``Tensor.to(device)`` to every stored tensor and return ``self``.

        ``_rescale`` is moved as well; leaving it behind made get_xyz and
        get_scaling fail with a device mismatch after moving the model.
        """
        self._xyz = self._xyz.to(device)
        self._features_dc = self._features_dc.to(device)
        if self.sh_degree > 0:
            self._features_rest = self._features_rest.to(device)
        self._scaling = self._scaling.to(device)
        self._rotation = self._rotation.to(device)
        self._opacity = self._opacity.to(device)
        self._rescale = self._rescale.to(device)
        return self

    @property
    def get_scaling(self):
        """Activated per-axis scales, multiplied by ``_rescale``."""
        if self.scaling_activation_type == 'exp':
            scales = self.scaling_activation(self._scaling)
        elif self.scaling_activation_type == 'softplus':
            scales = self.scaling_activation(self._scaling) * self.scale_multi_act
        elif self.scaling_activation_type == 'sigmoid':
            scales = self.scale_min_act + (self.scale_max_act - self.scale_min_act) * self.scaling_activation(self._scaling)
        scales = scales * self._rescale
        return scales

    @property
    def get_rotation(self):
        """Raw quaternions normalized with ``torch.nn.functional.normalize``."""
        return self.rotation_activation(self._rotation)

    @property
    def get_xyz(self):
        """Gaussian centres multiplied by ``_rescale``."""
        xyz = self._xyz * self._rescale
        return xyz

    @property
    def get_features(self):
        """Colour features as consumed by render().

        With ``sh_degree > 0`` this is the raw DC and higher-order coefficients
        concatenated along dim 1; with ``sh_degree == 0`` it is
        ``sigmoid(features_dc)``, used as precomputed colours.
        """
        if self.sh_degree > 0:
            return torch.cat((self._features_dc, self._features_rest), dim=1)
        return self.feature_activation(self._features_dc)

    @property
    def get_opacity(self):
        """Sigmoid of the raw opacity logits."""
        return self.opacity_activation(self._opacity)

    def get_covariance(self, scaling_modifier=1):
        """Packed ``(N, 6)`` covariances from the activated scales and rotations."""
        return self.covariance_activation(
            self.get_scaling, scaling_modifier, self._rotation
        )

    def construct_list_of_attributes(self, num_rest=0):
        """Return PLY property names in the column order used by the writers.

        Args:
            num_rest: number of ``f_rest_*`` columns to include.
        """
        l = ['x', 'y', 'z']
        # All channels except the 3 DC
        for i in range(3):
            l.append('f_dc_{}'.format(i))
        for i in range(num_rest):
            l.append('f_rest_{}'.format(i))
        l.append('opacity')
        for i in range(self._scaling.shape[1]):
            l.append('scale_{}'.format(i))
        for i in range(self._rotation.shape[1]):
            l.append('rot_{}'.format(i))
        return l

    @staticmethod
    def _write_vertex_ply(path, attribute_names, attributes):
        """Write an ``(N, K)`` array as a PLY 'vertex' element of float32 properties."""
        parent = os.path.dirname(path)
        # os.makedirs('') raises, so only create a directory when the path has one.
        if parent:
            os.makedirs(parent, exist_ok=True)
        dtype_full = [(attribute, 'f4') for attribute in attribute_names]
        elements = np.empty(attributes.shape[0], dtype=dtype_full)
        elements[:] = list(map(tuple, attributes))
        el = PlyElement.describe(elements, 'vertex')
        PlyData([el]).write(path)

    def save_ply_vis(self, path):
        """Save a DC-only PLY with positions and orientations rotated for viewing.

        Positions use ``get_xyz`` so they carry the same ``_rescale`` factor
        as the saved (log) scales.
        """
        xyzs = self.get_xyz.detach().cpu().numpy()
        f_dc = self._features_dc.detach().flatten(start_dim=1).contiguous().cpu().numpy()
        opacities = self._opacity.detach().cpu().numpy()
        scales = torch.log(self.get_scaling).detach().cpu().numpy()
        # +90 degree rotation about x: (x, y, z) -> (x, -z, y); presumably to
        # match a z-up viewer convention.
        rot_mat_vis = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
        xyzs = xyzs @ rot_mat_vis.T
        # Stored quaternions are (w, x, y, z); scipy uses (x, y, z, w).
        rotations = self._rotation.detach().cpu().numpy()
        rotations = R.from_quat(rotations[:, [1, 2, 3, 0]]).as_matrix()
        rotations = rot_mat_vis @ rotations
        rotations = R.from_matrix(rotations).as_quat()[:, [3, 0, 1, 2]]
        attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1)
        self._write_vertex_ply(path, self.construct_list_of_attributes(0), attributes)

    def _ply_attributes(self):
        """Return ``(names, (N, K) array)`` of everything ``save_ply`` writes."""
        xyzs = self.get_xyz.detach().cpu().numpy()
        f_dc = self._features_dc.detach().flatten(start_dim=1).contiguous().cpu().numpy()
        if self.sh_degree > 0:
            # Assumes _features_rest is (N, K, 3). Transposing to (N, 3, K)
            # gives the channel-major f_rest_* order of the 3DGS PLY format
            # (all coefficients of channel 0, then channel 1, then channel 2).
            f_rest = (
                self._features_rest.detach().transpose(1, 2)
                .flatten(start_dim=1).contiguous().cpu().numpy()
            )
        else:
            f_rest = np.zeros((f_dc.shape[0], 0), dtype=f_dc.dtype)
        opacities = self._opacity.detach().cpu().numpy()
        scales = torch.log(self.get_scaling).detach().cpu().numpy()
        rotations = self._rotation.detach().cpu().numpy()
        names = self.construct_list_of_attributes(f_rest.shape[-1])
        attributes = np.concatenate((xyzs, f_dc, f_rest, opacities, scales, rotations), axis=1)
        return names, attributes

    def save_ply(self, path):
        """Save all Gaussian parameters to a PLY file at ``path``.

        Positions and log-scales both include ``_rescale``; opacities and
        rotations are written raw (pre-activation).
        """
        names, attributes = self._ply_attributes()
        self._write_vertex_ply(path, names, attributes)
# def load_ply(self, path):
# plydata = PlyData.read(path)
# xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
# np.asarray(plydata.elements[0]["y"]),
# np.asarray(plydata.elements[0]["z"])), axis=1)
# opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
# features_dc = np.zeros((xyz.shape[0], 3, 1))
# features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
# features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
# features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
# scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
# scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
# scales = np.zeros((xyz.shape[0], len(scale_names)))
# for idx, attr_name in enumerate(scale_names):
# scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
# rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
# rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1]))
# rots = np.zeros((xyz.shape[0], len(rot_names)))
# for idx, attr_name in enumerate(rot_names):
# rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
# self._xyz = torch.from_numpy(xyz.astype(np.float32))
# self._features_dc = torch.from_numpy(features_dc.astype(np.float32)).transpose(1, 2).contiguous()
# self._opacity = torch.from_numpy(opacities.astype(np.float32)).contiguous()
# self._scaling = torch.from_numpy(scales.astype(np.float32)).contiguous()
# self._rotation = torch.from_numpy(rots.astype(np.float32)).contiguous()
def render(
    pc: GaussianModel,
    height: int,
    width: int,
    C2W: torch.Tensor,
    fxfycxcy: torch.Tensor,
    bg_color=(1.0, 1.0, 1.0),
    scaling_modifier=1.0,
):
    """Rasterize the Gaussians in ``pc`` from a single camera.

    Args:
        pc: Gaussians to draw.
        height, width: output image size in pixels.
        C2W: 4x4 camera-to-world matrix (OpenCV convention, see Camera).
        fxfycxcy: intrinsics ``(fx, fy, cx, cy)`` in the form Camera expects.
        bg_color: background colour, converted to a float32 tensor on
            ``C2W.device``.
        scaling_modifier: global scale multiplier forwarded to the rasterizer.

    Returns:
        dict with ``"render"``, ``"alpha"`` and ``"depth"`` taken from the
        rasterizer outputs (its second output is discarded).
    """
    means3D = pc.get_xyz
    # Zero stand-in for the 2D means, on the same device as the Gaussians
    # (previously hard-coded to "cuda"). "+ 0" makes it a non-leaf so that
    # retain_grad() keeps its gradient.
    screenspace_points = (
        torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True) + 0
    )
    # With autograd disabled (e.g. under torch.no_grad()) the "+ 0" result
    # does not require grad and retain_grad() would raise; nothing to keep.
    if screenspace_points.requires_grad:
        screenspace_points.retain_grad()

    viewpoint_camera = Camera(C2W=C2W, fxfycxcy=fxfycxcy, h=height, w=width)
    background = torch.tensor(list(bg_color), dtype=torch.float32, device=C2W.device)
    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.h),
        image_width=int(viewpoint_camera.w),
        tanfovx=viewpoint_camera.tanfovX,
        tanfovy=viewpoint_camera.tanfovY,
        bg=background,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc.sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=False,
    )
    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    # With sh_degree == 0, get_features already returns activated colours and
    # they are passed as precomputed colours; otherwise as SH coefficients.
    features = pc.get_features
    colors_only = pc.sh_degree == 0
    rendered_image, _, rendered_depth, rendered_alpha = rasterizer(
        means3D=means3D,
        means2D=screenspace_points,
        shs=None if colors_only else features,
        colors_precomp=features if colors_only else None,
        opacities=pc.get_opacity,
        scales=pc.get_scaling,
        rotations=pc.get_rotation,
        cov3D_precomp=None,
    )
    return {
        "render": rendered_image,
        "alpha": rendered_alpha,
        "depth": rendered_depth,
    }