Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn | |
import torch.nn.functional as F | |
from .sh import eval_sh_bases | |
import numpy as np | |
import time | |
def get_ray_directions_blender(H, W, focal, center=None):
    """
    Get ray directions for all pixels in camera coordinate.
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems

    The camera looks down -z with +y pointing up (the row offset is negated and
    the z component is -1). Rays pass through pixel centres, i.e. integer pixel
    index + 0.5.

    Inputs:
        H, W: image height and width in pixels
        focal: focal length in pixels, either one number used for both axes or
            an indexable (fx, fy) pair (list, tuple or tensor)
        center: optional principal point (cx, cy) in pixels; defaults to
            (W / 2, H / 2)
    Outputs:
        directions: (H, W, 3), the direction of the rays in camera coordinate
    """
    try:
        fx, fy = focal[0], focal[1]
    except (TypeError, IndexError):
        # Scalar focal length (Python number, NumPy scalar or 0-d tensor).
        fx = fy = focal
    # Pixel-centre coordinates: i varies along columns (x), j along rows (y).
    # Same values as kornia's create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5,
    # built with plain torch so no extra dependency is required.
    i = (torch.arange(W, dtype=torch.float32) + 0.5).expand(H, W)
    j = (torch.arange(H, dtype=torch.float32) + 0.5).unsqueeze(1).expand(H, W)
    cent = center if center is not None else [W / 2, H / 2]
    directions = torch.stack([(i - cent[0]) / fx, -(j - cent[1]) / fy, -torch.ones_like(i)],
                             -1)  # (H, W, 3)
    return directions
def get_rays(directions, c2w):
    """
    Get ray origins and directions in world coordinate for all pixels in one image.
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems

    Inputs:
        directions: (H, W, 3) precomputed ray directions in camera coordinate
            (any leading shape (..., 3) is accepted)
        c2w: (3, 4) or (4, 4) camera-to-world matrix; only its top three rows are read
    Outputs:
        rays_o: (H*W, 3), the origin of the rays in world coordinate
        rays_d: (H*W, 3), the direction of the rays in world coordinate.
            These are NOT normalized here; they are the camera-space
            directions rotated by c2w[:3, :3].
    """
    # Rotate ray directions from camera coordinate to the world coordinate.
    rays_d = directions @ c2w[:3, :3].T  # (H, W, 3)
    # Every ray starts at the camera centre: the translation column of c2w.
    rays_o = c2w[:3, 3].expand(rays_d.shape)  # (H, W, 3)
    return rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)
def positional_encoding(positions, freqs):
    """Sinusoidal encoding of the last axis of ``positions`` at octave frequencies.

    Every coordinate is multiplied by 2**k for k in [0, freqs). The sines of all
    scaled coordinates come first, followed by their cosines.

    Args:
        positions: tensor of shape (..., D).
        freqs: number of frequency bands F.

    Returns:
        Tensor of shape (..., 2 * D * F). Within each half, the F scaled copies
        of coordinate 0 come first, then those of coordinate 1, and so on.
    """
    bands = torch.pow(2.0, torch.arange(freqs, dtype=torch.float32, device=positions.device))  # (F,)
    scaled = torch.flatten(positions.unsqueeze(-1) * bands, start_dim=-2)  # (..., D*F)
    return torch.cat((scaled.sin(), scaled.cos()), dim=-1)
def raw2alpha(sigma, dist):
    """Convert densities and segment lengths into per-sample opacity and blend weights.

    Works for any number of leading dimensions; samples run along the last axis.

    Args:
        sigma: densities, shape (..., N_samples).
        dist: segment lengths, broadcastable against ``sigma``.

    Returns:
        alpha: 1 - exp(-sigma * dist), shape (..., N_samples).
        weights: alpha times the transmittance accumulated before each sample,
            shape (..., N_samples).
        T_last: transmittance remaining after the final sample, shape (..., 1).
    """
    alpha = 1. - torch.exp(-sigma * dist)
    # Exclusive cumulative product: transmittance is 1 before the first sample.
    # The 1e-10 keeps individual factors from being exactly zero.
    T = torch.cumprod(torch.cat([torch.ones_like(alpha[..., :1]), 1. - alpha + 1e-10], -1), -1)
    weights = alpha * T[..., :-1]  # (..., N_samples)
    return alpha, weights, T[..., -1:]
def SHRender(xyz_sampled, viewdirs, features):
    """Colour head based on spherical-harmonic coefficients.

    The basis comes from ``eval_sh_bases(2, viewdirs)`` (presumably SH degree 2).
    ``features`` is viewed as (-1, 3, n_basis) coefficients, one set per RGB
    channel. The output is the basis-weighted sum plus 0.5, clipped below at 0.
    ``xyz_sampled`` is unused; it only keeps the signature aligned with the
    other render heads.
    """
    basis = eval_sh_bases(2, viewdirs).unsqueeze(1)
    n_basis = basis.shape[-1]
    coeffs = features.view(-1, 3, n_basis)
    return (basis * coeffs).sum(dim=-1).add(0.5).relu()
def RGBRender(xyz_sampled, viewdirs, features):
    """Pass-through colour head: return ``features`` unchanged as the RGB values.

    ``xyz_sampled`` and ``viewdirs`` are ignored; they are accepted only so this
    head has the same call signature as the other render heads.
    """
    return features
class AlphaGridMask(torch.nn.Module):
    """Lookup into a dense alpha volume that spans an axis-aligned box.

    Args:
        device: device that the box, the grid size and the alpha volume are moved to.
        aabb: box corners, indexed as ``aabb[0]`` (min corner) and ``aabb[1]`` (max corner).
        alpha_volume: tensor whose last three dims hold the alpha grid in the
            (D, H, W) layout that ``F.grid_sample`` expects, i.e. indexed [z, y, x].
    """

    def __init__(self, device, aabb, alpha_volume):
        super(AlphaGridMask, self).__init__()
        self.device = device
        self.aabb = aabb.to(self.device)
        self.aabbSize = self.aabb[1] - self.aabb[0]
        # Per-axis scale so that normalize_coord maps [aabb[0], aabb[1]] onto [-1, 1].
        self.invgridSize = 1.0 / self.aabbSize * 2
        # Move the volume together with the box so grid_sample sees a single device.
        self.alpha_volume = alpha_volume.to(self.device).reshape(1, 1, *alpha_volume.shape[-3:])
        # Resolution in (x, y, z) = (W, H, D) order.
        self.gridSize = torch.LongTensor(
            [alpha_volume.shape[-1], alpha_volume.shape[-2], alpha_volume.shape[-3]]
        ).to(self.device)

    def sample_alpha(self, xyz_sampled):
        """Interpolate alpha at world-space points ``xyz_sampled`` (..., 3).

        Returns a flat tensor with one value per input point. ``grid_sample``
        uses zero padding, so points outside the box read 0.
        """
        xyz_sampled = self.normalize_coord(xyz_sampled)
        # reshape (not view) so non-contiguous inputs are accepted too.
        grid = xyz_sampled.reshape(1, -1, 1, 1, 3)
        alpha_vals = F.grid_sample(self.alpha_volume, grid, align_corners=True).view(-1)
        return alpha_vals

    def normalize_coord(self, xyz_sampled):
        """Map world coordinates so that aabb[0] -> -1 and aabb[1] -> +1 on every axis."""
        return (xyz_sampled - self.aabb[0]) * self.invgridSize - 1
class MLPRender_Fea(torch.nn.Module):
    """Colour head fed with appearance features, view directions and their encodings.

    The per-sample input is [features, viewdirs, PE(features), PE(viewdirs)].
    Each encoding is present only when its ``feape`` / ``viewpe`` band count is
    positive. A 3-layer MLP followed by a sigmoid maps it to RGB in (0, 1).
    ``pts`` is ignored by ``forward``.
    """

    def __init__(self, inChanel, viewpe=6, feape=6, featureC=128):
        super(MLPRender_Fea, self).__init__()
        self.viewpe = viewpe
        self.feape = feape
        # raw features + raw viewdirs + sin/cos encodings of both
        self.in_mlpC = inChanel + 3 + 2 * feape * inChanel + 2 * viewpe * 3
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(self.in_mlpC, featureC),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(featureC, featureC),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(featureC, 3),
        )
        # Output layer starts with zero bias.
        torch.nn.init.constant_(self.mlp[-1].bias, 0)

    def forward(self, pts, viewdirs, features):
        parts = [features, viewdirs]
        if self.feape > 0:
            parts.append(positional_encoding(features, self.feape))
        if self.viewpe > 0:
            parts.append(positional_encoding(viewdirs, self.viewpe))
        return torch.sigmoid(self.mlp(torch.cat(parts, dim=-1)))
class MLPRender_PE(torch.nn.Module):
    """Colour head fed with appearance features, view directions and sample positions.

    The per-sample input is [features, viewdirs, pts, PE(pts), PE(viewdirs)].
    Each encoding is present only when its ``pospe`` / ``viewpe`` band count is
    positive. A 3-layer MLP followed by a sigmoid maps it to RGB in (0, 1).
    """

    def __init__(self, inChanel, viewpe=6, pospe=6, featureC=128):
        super(MLPRender_PE, self).__init__()
        # (raw viewdirs + encoding) + (raw positions + encoding) + features
        self.in_mlpC = (3 + 2 * viewpe * 3) + (3 + 2 * pospe * 3) + inChanel
        self.viewpe = viewpe
        self.pospe = pospe
        layer1 = torch.nn.Linear(self.in_mlpC, featureC)
        layer2 = torch.nn.Linear(featureC, featureC)
        layer3 = torch.nn.Linear(featureC, 3)
        self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)
        torch.nn.init.constant_(self.mlp[-1].bias, 0)

    def forward(self, pts, viewdirs, features):
        # in_mlpC reserves 3 channels for the raw positions, so they must be
        # concatenated next to the raw view directions. Without them the first
        # Linear receives 3 fewer features than it was built for.
        indata = [features, viewdirs, pts]
        if self.pospe > 0:
            indata += [positional_encoding(pts, self.pospe)]
        if self.viewpe > 0:
            indata += [positional_encoding(viewdirs, self.viewpe)]
        mlp_in = torch.cat(indata, dim=-1)
        rgb = self.mlp(mlp_in)
        rgb = torch.sigmoid(rgb)
        return rgb
class MLPRender(torch.nn.Module):
    """Colour head fed with appearance features and view directions only.

    The per-sample input is [features, viewdirs, PE(viewdirs)]. The encoding is
    present only when ``viewpe`` is positive. A 3-layer MLP followed by a
    sigmoid maps it to RGB in (0, 1). ``pts`` is ignored by ``forward``.
    """

    def __init__(self, inChanel, viewpe=6, featureC=128):
        super(MLPRender, self).__init__()
        self.viewpe = viewpe
        # raw features + raw viewdirs + sin/cos encoding of viewdirs
        self.in_mlpC = inChanel + 3 + 2 * viewpe * 3
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(self.in_mlpC, featureC),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(featureC, featureC),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(featureC, 3),
        )
        # Output layer starts with zero bias.
        torch.nn.init.constant_(self.mlp[-1].bias, 0)

    def forward(self, pts, viewdirs, features):
        parts = [features, viewdirs]
        if self.viewpe > 0:
            parts.append(positional_encoding(viewdirs, self.viewpe))
        return torch.sigmoid(self.mlp(torch.cat(parts, dim=-1)))
class TensorBase(torch.nn.Module):
    """Base class for tensor-decomposed radiance fields rendered by ray marching.

    Handles ray sampling inside the axis-aligned box ``aabb``, density
    activation, alpha compositing and the choice of colour head. The feature
    hooks (``compute_densityfeature``, ``compute_appfeature``, ...) return
    ``None`` here and are meant to be overridden by subclasses.

    ``forward`` receives the decomposed volume as a dict on every call and
    stores its entries on ``self`` (``app_plane``, ``app_line``, ``basis_mat``,
    ``density_plane``, ``density_line``), presumably for the subclass feature
    hooks to read.
    """

    def __init__(self, aabb, gridSize, density_n_comp=16, appearance_n_comp=48, app_dim=27, density_dim=8,
                 shadingMode='MLP_PE', alphaMask=None, near_far=None,
                 density_shift=-10, alphaMask_thres=0.0001, distance_scale=25, rayMarch_weight_thres=0.0001,
                 pos_pe=6, view_pe=6, fea_pe=6, featureC=128, step_ratio=0.5,
                 fea2denseAct='softplus'):
        """
        Args:
            aabb: scene box; ``aabb[0]`` is the min corner and ``aabb[1]`` the max corner.
            gridSize: per-axis grid resolution (tensor or sequence), used to derive
                the ray-marching step size.
            density_n_comp, appearance_n_comp, density_dim: stored only; not read by
                this class.
            app_dim: width of the appearance features fed to the colour head.
            shadingMode: colour head, one of 'MLP_PE', 'MLP_Fea', 'MLP', 'SH', 'RGB'.
            alphaMask: optional object with ``sample_alpha(xyz)``; ``compute_alpha``
                skips locations where the sampled alpha is not positive.
            near_far: ``[near, far]`` range used by the ray samplers; defaults to
                ``[2.0, 6.0]``.
            density_shift: offset added to density features before softplus.
            alphaMask_thres, distance_scale, rayMarch_weight_thres: stored only; not
                read by this class.
            pos_pe, view_pe, fea_pe, featureC: colour-head configuration.
            step_ratio: currently ignored (see the note in the body).
            fea2denseAct: density activation, 'softplus' or 'relu'.
        """
        super(TensorBase, self).__init__()
        # None instead of a list literal so instances never share one default list.
        if near_far is None:
            near_far = [2.0, 6.0]
        self.density_n_comp = density_n_comp
        self.app_n_comp = appearance_n_comp
        self.app_dim = app_dim
        self.density_dim = density_dim
        self.aabb = aabb
        self.alphaMask = alphaMask
        # NOTE(review): self.device is never set by this class.
        self.density_shift = density_shift
        self.alphaMask_thres = alphaMask_thres
        self.distance_scale = distance_scale
        self.rayMarch_weight_thres = rayMarch_weight_thres
        self.fea2denseAct = fea2denseAct
        self.near_far = near_far
        # NOTE(review): the step_ratio argument is deliberately ignored; the original
        # implementation used 0.5 and this code hard-codes 0.9.
        self.step_ratio = 0.9
        self.update_stepSize(gridSize)
        self.matMode = [[0, 1], [0, 2], [1, 2]]
        self.vecMode = [2, 1, 0]
        self.comp_w = [1, 1, 1]
        self.shadingMode, self.pos_pe, self.view_pe, self.fea_pe, self.featureC = shadingMode, pos_pe, view_pe, fea_pe, featureC
        self.init_render_func(shadingMode, pos_pe, view_pe, fea_pe, featureC)

    def init_render_func(self, shadingMode, pos_pe, view_pe, fea_pe, featureC):
        """Create ``self.renderModule``, the colour head selected by ``shadingMode``.

        Raises:
            ValueError: if ``shadingMode`` is not one of the supported names.
        """
        if shadingMode == 'MLP_PE':
            self.renderModule = MLPRender_PE(self.app_dim, view_pe, pos_pe, featureC)
        elif shadingMode == 'MLP_Fea':
            self.renderModule = MLPRender_Fea(self.app_dim, view_pe, fea_pe, featureC)
        elif shadingMode == 'MLP':
            self.renderModule = MLPRender(self.app_dim, view_pe, featureC)
        elif shadingMode == 'SH':
            self.renderModule = SHRender
        elif shadingMode == 'RGB':
            # Features are returned as colours unchanged, so they must be 3-wide.
            assert self.app_dim == 3
            self.renderModule = RGBRender
        else:
            # Raise rather than exit(): library code should not kill the interpreter.
            raise ValueError(f"Unrecognized shading module: {shadingMode!r}")
        print("pos_pe", pos_pe, "view_pe", view_pe, "fea_pe", fea_pe)
        print(self.renderModule)

    def update_stepSize(self, gridSize):
        """Recompute box extents, voxel size, marching step and samples per ray.

        ``stepSize`` is the mean voxel edge length times ``step_ratio``.
        ``nSamples`` is the number of such steps needed to cross the box
        diagonal, plus one.
        """
        self.aabbSize = self.aabb[1] - self.aabb[0]
        self.invaabbSize = 2.0 / self.aabbSize
        # Accept lists as well as tensors; keep the grid size on the box's device.
        self.gridSize = torch.as_tensor(gridSize, dtype=torch.float32, device=self.aabbSize.device)
        self.units = self.aabbSize / (self.gridSize - 1)
        self.stepSize = torch.mean(self.units) * self.step_ratio
        self.aabbDiag = torch.sqrt(torch.sum(torch.square(self.aabbSize)))
        self.nSamples = int((self.aabbDiag / self.stepSize).item()) + 1
        print("sampling step size: ", self.stepSize)
        print("sampling number: ", self.nSamples)

    def init_svd_volume(self, res, device):
        """Hook for subclasses; does nothing here."""

    def compute_features(self, xyz_sampled):
        """Hook for subclasses; returns None here."""

    def compute_densityfeature(self, xyz_sampled):
        """Hook for subclasses: density features at normalized points; returns None here."""

    def compute_appfeature(self, xyz_sampled):
        """Hook for subclasses: appearance features at normalized points; returns None here."""

    def normalize_coord(self, xyz_sampled):
        """Map world coordinates so that aabb[0] -> -1 and aabb[1] -> +1 on every axis."""
        if xyz_sampled.device != self.invaabbSize.device:
            self.invaabbSize = self.invaabbSize.to(xyz_sampled.device)
        return (xyz_sampled - self.aabb[0]) * self.invaabbSize - 1

    def get_optparam_groups(self, lr_init_spatial=0.02, lr_init_network=0.001):
        """Hook for subclasses; returns None here."""

    def sample_ray_ndc(self, rays_o, rays_d, is_train=True, N_samples=-1):
        """Sample depths evenly in [near, far], with the same depths for every ray.

        In training, each sample depth gets its own uniform offset of up to one
        bin width ((far - near) / N_samples); that offset is shared by all rays.

        Returns:
            (points, depths of shape (1, N_samples), mask of points inside aabb).
        """
        N_samples = N_samples if N_samples > 0 else self.nSamples
        near, far = self.near_far
        interpx = torch.linspace(near, far, N_samples).unsqueeze(0).to(rays_o)
        if is_train:
            interpx += torch.rand_like(interpx).to(rays_o) * ((far - near) / N_samples)
        rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * interpx[..., None]
        mask_outbbox = ((self.aabb[0] > rays_pts) | (rays_pts > self.aabb[1])).any(dim=-1)
        return rays_pts, interpx, ~mask_outbbox

    def sample_ray(self, rays_o, rays_d, is_train=True, N_samples=-1):
        """March each ray from its entry into aabb in steps of ``self.stepSize``.

        Args:
            rays_o, rays_d: ray origins and directions, shape (..., 3).
            is_train: if True, shift all samples of a ray by one uniform offset in
                [0, 1) steps (a separate offset per ray).
            N_samples: samples per ray; values <= 0 use ``self.nSamples``.

        Returns:
            rays_pts (..., N_samples, 3), depths (..., N_samples), and a boolean
            mask (..., N_samples) that is True where the point lies inside aabb.
        """
        N_samples = N_samples if N_samples > 0 else self.nSamples
        stepsize = self.stepSize
        near, far = self.near_far
        # Avoid division by zero for axis-parallel ray components.
        vec = torch.where(rays_d == 0, torch.full_like(rays_d, 1e-6), rays_d)
        rate_a = (self.aabb[1] - rays_o) / vec
        rate_b = (self.aabb[0] - rays_o) / vec
        # Entry distance: the largest of the per-axis near-plane hits, clamped to [near, far].
        t_min = torch.minimum(rate_a, rate_b).amax(-1).clamp(min=near, max=far)
        rng = torch.arange(N_samples).float()
        if is_train:
            batch_shape = tuple(rays_d.shape[:-1])
            # One uniform offset per ray, shared by all of that ray's samples.
            rng = rng.expand(batch_shape + (N_samples,)) + torch.rand(batch_shape + (1,), dtype=rng.dtype)
        step = stepsize * rng.to(rays_o.device)
        interpx = t_min[..., None] + step
        rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * interpx[..., None]
        mask_outbbox = ((self.aabb[0] > rays_pts) | (rays_pts > self.aabb[1])).any(dim=-1)
        return rays_pts, interpx, ~mask_outbbox

    def shrink(self, new_aabb, voxel_size):
        """Hook for subclasses; does nothing here."""

    def getDenseAlpha(self, gridSize=None):
        """Evaluate ``compute_alpha`` on a regular grid spanning aabb.

        Args:
            gridSize: per-axis sample counts; defaults to ``self.gridSize``.

        Returns:
            alpha of shape (X, Y, Z) and the sample positions of shape (X, Y, Z, 3).
        """
        gridSize = self.gridSize if gridSize is None else gridSize
        # self.gridSize is stored as a float tensor; linspace/range/view need ints.
        gridSize = [int(s) for s in gridSize]
        samples = torch.stack(torch.meshgrid(
            torch.linspace(0, 1, gridSize[0]),
            torch.linspace(0, 1, gridSize[1]),
            torch.linspace(0, 1, gridSize[2]),
        ), -1).to(self.aabb.device)
        dense_xyz = self.aabb[0] * (1 - samples) + self.aabb[1] * samples
        alpha = torch.zeros_like(dense_xyz[..., 0])
        # One x-slice at a time (presumably to bound peak memory).
        for i in range(gridSize[0]):
            alpha[i] = self.compute_alpha(dense_xyz[i].view(-1, 3), self.stepSize).view((gridSize[1], gridSize[2]))
        return alpha, dense_xyz

    def feature2density(self, density_features):
        """Activate raw density features into non-negative densities.

        Raises:
            ValueError: if ``self.fea2denseAct`` is neither 'softplus' nor 'relu'.
        """
        if self.fea2denseAct == "softplus":
            return F.softplus(density_features + self.density_shift)
        elif self.fea2denseAct == "relu":
            return F.relu(density_features)
        raise ValueError(f"Unrecognized fea2denseAct: {self.fea2denseAct!r}")

    def compute_alpha(self, xyz_locs, length=1):
        """Opacity 1 - exp(-sigma * length) at world-space points ``xyz_locs`` (N, 3).

        When an alpha mask is set, density is evaluated only where the mask's
        sampled alpha is positive; elsewhere sigma stays 0.
        """
        if self.alphaMask is not None:
            alphas = self.alphaMask.sample_alpha(xyz_locs)
            alpha_mask = alphas > 0
        else:
            alpha_mask = torch.ones_like(xyz_locs[:, 0], dtype=bool)
        sigma = torch.zeros(xyz_locs.shape[:-1], device=xyz_locs.device)
        if alpha_mask.any():
            xyz_sampled = self.normalize_coord(xyz_locs[alpha_mask])
            sigma_feature = self.compute_densityfeature(xyz_sampled)
            validsigma = self.feature2density(sigma_feature)
            sigma[alpha_mask] = validsigma
        alpha = 1 - torch.exp(-sigma * length).view(xyz_locs.shape[:-1])
        return alpha

    def forward(self, svd_volume, rays_o, rays_d, bg_color, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1):
        """Render images for a batch of multi-view rays.

        Args:
            svd_volume: dict with keys 'app_planes', 'app_lines', 'basis_mat',
                'density_planes' and 'density_lines'; the values are stored on self.
            rays_o: ray origins, shape (B, V, H, W, 3).
            rays_d: ray directions with the same layout as ``rays_o``.
            bg_color: currently unused.
            white_bg: composite onto a white background. In training, white is
                also used at random with probability 0.5.
            is_train: enables jittered sampling and the random white background.
            ndc_ray: not supported; raises NotImplementedError.
            N_samples: samples per ray; values <= 0 use ``self.nSamples``.

        Returns:
            dict with 'image' (B, V, 3, H, W) clamped to [0, 1], 'alpha'
            (B, V, 1, H, W) accumulated weights, and 'depth_map' (B, V, 1, H, W),
            the weighted sum of sample depths computed without gradients.
        """
        self.svd_volume = svd_volume
        self.app_plane = svd_volume['app_planes']
        self.app_line = svd_volume['app_lines']
        self.basis_mat = svd_volume['basis_mat']
        self.density_plane = svd_volume['density_planes']
        self.density_line = svd_volume['density_lines']

        B, V, H, W, _ = rays_o.shape
        rays_o = rays_o.reshape(B, -1, 3)
        rays_d = rays_d.reshape(B, -1, 3)
        if ndc_ray:
            # This branch was an empty `pass` that fell through to an undefined
            # xyz_sampled; fail with a clear message instead.
            raise NotImplementedError("ndc_ray=True is not supported by TensorBase.forward")

        # xyz_sampled: (B, V*H*W, n_samples, 3); z_vals: (B, V*H*W, n_samples)
        xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_o, rays_d, is_train=is_train, N_samples=N_samples)
        # NOTE(review): ray_valid and self.alphaMask are not applied here, so density is
        # evaluated for every sample, including samples outside aabb.
        # Distances between consecutive samples; the last sample gets length 0.
        dists = torch.cat((z_vals[..., 1:] - z_vals[..., :-1], torch.zeros_like(z_vals[..., :1])), dim=-1)
        rays_d = rays_d.unsqueeze(-2).expand(xyz_sampled.shape)

        # The feature hooks and the colour head see box-normalized coordinates.
        xyz_sampled = self.normalize_coord(xyz_sampled)
        sigma_feature = self.compute_densityfeature(xyz_sampled)
        sigma = self.feature2density(sigma_feature)
        _, weight, _ = raw2alpha(sigma, dists)

        app_features = self.compute_appfeature(xyz_sampled)
        rgbs = self.renderModule(xyz_sampled, rays_d, app_features)

        acc_map = torch.sum(weight, -1)
        rgb_map = torch.sum(weight[..., None] * rgbs, -2)
        if white_bg or (is_train and torch.rand((1,)) < 0.5):
            rgb_map = rgb_map + (1. - acc_map[..., None])
        rgb_map = rgb_map.clamp(0, 1)
        rgb_map = rgb_map.view(B, V, H, W, 3).permute(0, 1, 4, 2, 3)

        with torch.no_grad():
            depth_map = torch.sum(weight * z_vals, -1)
            depth_map = depth_map.view(B, V, H, W, 1).permute(0, 1, 4, 2, 3)
        # Reshaped outside no_grad so the returned alpha keeps its autograd graph.
        acc_map = acc_map.view(B, V, H, W, 1).permute(0, 1, 4, 2, 3)
        results = {
            'image': rgb_map,
            'alpha': acc_map,
            'depth_map': depth_map,
        }
        return results