import math
from pathlib import Path
# from pytorch3d.ops import create_sphere
import torchvision
import point_cloud_utils as pcu
from tqdm import trange
import random
import einops
from einops import rearrange
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from functools import partial
from torch.profiler import profile, record_function, ProfilerActivity

from nsr.networks_stylegan2 import Generator as StyleGAN2Backbone
from nsr.volumetric_rendering.renderer import ImportanceRenderer, ImportanceRendererfg_bg
from nsr.volumetric_rendering.ray_sampler import RaySampler
from nsr.triplane import OSGDecoder, Triplane, Triplane_fg_bg_plane
# from nsr.losses.helpers import ResidualBlock
from utils.dust3r.heads.dpt_head import create_dpt_head_ln3diff
from utils.nerf_utils import get_embedder
from vit.vision_transformer import TriplaneFusionBlockv4_nested, TriplaneFusionBlockv4_nested_init_from_dino_lite, TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout, VisionTransformer, TriplaneFusionBlockv4_nested_init_from_dino
from .vision_transformer import Block, VisionTransformer
from .utils import trunc_normal_
from guided_diffusion import dist_util, logger
from pdb import set_trace as st
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from torch_utils.components import PixelShuffleUpsample, ResidualBlock, Upsample, PixelUnshuffleUpsample, Conv3x3TriplaneTransformation
from torch_utils.distributions.distributions import DiagonalGaussianDistribution
from nsr.superresolution import SuperresolutionHybrid2X, SuperresolutionHybrid4X
from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer
from nsr.common_blks import ResMlp
from timm.models.vision_transformer import PatchEmbed, Mlp
from .vision_transformer import *
from dit.dit_models import get_2d_sincos_pos_embed
from dit.dit_decoder import DiTBlock2
from torch import _assert
from itertools import repeat
import collections.abc
from nsr.srt.layers import Transformer as SRT_TX
from nsr.srt.layers import PreNorm
# from diffusers.models.upsampling import Upsample2D
from torch_utils.components import NearestConvSR
from utils.general_utils import matrix_to_quaternion, quaternion_raw_multiply, build_rotation
# from nsr.gs import GaussianRenderer
from utils.dust3r.heads import create_dpt_head
from ldm.modules.attention import MemoryEfficientCrossAttention, CrossAttention
# from nsr.geometry.camera.perspective_camera import PerspectiveCamera
# from nsr.geometry.render.neural_render import NeuralRender
# from nsr.geometry.rep_3d.flexicubes_geometry import FlexiCubesGeometry
# from utils.mesh_util import xatlas_uvmap


# From PyTorch internals
def _ntuple(n):

    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
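# e.g., to_2tuple(32) -> (32, 32); to_2tuple((32, 16)) -> (32, 16)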


def approx_gelu():
    return nn.GELU(approximate="tanh")


def init_gaussian_prediction(gaussian_pred_mlp):
    # https://github.com/szymanowiczs/splatter-image/blob/98b465731c3273bf8f42a747d1b6ce1a93faf3d6/configs/dataset/chairs.yaml#L15
    out_channels = [3, 1, 3, 4, 3]  # xyz, opacity, scale, rotation, rgb

    scale_inits = [  # ! avoid affecting final value (offset)
        0,  # xyz_scale
        0.0,  # cfg.model.opacity_scale
        0,  # cfg.model.scale_scale
        1,  # rotation
        0,  # rgb
    ]

    bias_inits = [
        0.0,  # cfg.model.xyz_bias, no deformation here
        0,  # cfg.model.opacity_bias, sigmoid(0)=0.5 at init
        -2.5,  # scale_bias
        0.0,  # rotation
        0.5,  # rgb
    ]

    # per-attribute init: constant weight scale s and bias b for each slice
    start_channels = 0
    for out_channel, b, s in zip(out_channels, bias_inits, scale_inits):
        nn.init.constant_(
            gaussian_pred_mlp.weight[start_channels:start_channels +
                                     out_channel, ...], s)
        nn.init.constant_(
            gaussian_pred_mlp.bias[start_channels:start_channels +
                                   out_channel], b)
        start_channels += out_channel
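

# Minimal usage sketch (an illustrative assumption, not original code):
#   head = nn.Linear(768, 14)        # 3+1+3+4+3 gaussian channels
#   init_gaussian_prediction(head)   # per-attribute constant init as above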


class PatchEmbedTriplane(nn.Module):
    """GroupConv patch embedder on a triplane."""

    def __init__(
        self,
        img_size=32,
        patch_size=2,
        in_chans=4,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
        bias=True,
        plane_n=3,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.plane_n = plane_n

        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        self.proj = nn.Conv2d(in_chans,
                              embed_dim * self.plane_n,
                              kernel_size=patch_size,
                              stride=patch_size,
                              bias=bias,
                              groups=self.plane_n)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        _assert(
            H == self.img_size[0],
            f"Input image height ({H}) doesn't match model ({self.img_size[0]})."
        )
        _assert(
            W == self.img_size[1],
            f"Input image width ({W}) doesn't match model ({self.img_size[1]})."
        )
        x = self.proj(x)  # B 3*C token_H token_W
        x = x.reshape(B, x.shape[1] // self.plane_n, self.plane_n, x.shape[-2],
                      x.shape[-1])  # B C 3 H W
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # B C 3 H W -> B 3HW C
        x = self.norm(x)
        return x
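

# Shape sketch (hypothetical sizes): for a triplane stacked channel-wise,
# x: (B, 12, 32, 32) with in_chans=12 (3 planes x 4 channels), patch_size=2:
#   proj (grouped conv, one group per plane) -> (B, 3*768, 16, 16)
#   reshape -> (B, 768, 3, 16, 16); flatten -> (B, 3*16*16, 768) tokens.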


# https://github.com/facebookresearch/MCC/blob/main/mcc_model.py#L81
class XYZPosEmbed(nn.Module):
    """Positional embedding for xyz coordinates: a fixed sinusoidal PE
    followed by a linear projection to the token dimension."""

    def __init__(self, embed_dim, multires=10):
        super().__init__()
        self.embed_dim = embed_dim
        # no [cls] token here.
        # ! use fixed PE here
        self.embed_fn, self.embed_input_ch = get_embedder(multires)
        self.xyz_projection = nn.Linear(self.embed_input_ch, embed_dim)

    def forward(self, xyz):
        xyz = self.embed_fn(xyz)  # PE encoding
        xyz = self.xyz_projection(xyz)  # linear projection
        return xyz
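

# e.g., xyz: (B, N, 3) -> sinusoidal embedding at `multires` frequency bands
# -> (B, N, embed_input_ch) -> linear -> (B, N, embed_dim) tokens.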


class gaussian_prediction(nn.Module):

    def __init__(
        self,
        query_dim,
    ) -> None:
        super().__init__()
        self.gaussian_pred = nn.Sequential(
            nn.SiLU(), nn.Linear(query_dim, 14, bias=True))
        self.init_gaussian_prediction()

    def init_gaussian_prediction(self):
        # https://github.com/szymanowiczs/splatter-image/blob/98b465731c3273bf8f42a747d1b6ce1a93faf3d6/configs/dataset/chairs.yaml#L15
        out_channels = [3, 1, 3, 4, 3]  # xyz, opacity, scale, rotation, rgb

        scale_inits = [  # ! avoid affecting final value (offset)
            0,  # xyz_scale
            0.0,  # cfg.model.opacity_scale
            0,  # cfg.model.scale_scale
            1.0,  # rotation
            0,  # rgb
        ]

        bias_inits = [
            0.0,  # cfg.model.xyz_bias, no deformation here
            0,  # cfg.model.opacity_bias, sigmoid(0)=0.5 at init
            -2.5,  # scale_bias
            0.0,  # rotation
            0.5,  # rgb
        ]

        start_channels = 0
        for out_channel, b, s in zip(out_channels, bias_inits, scale_inits):
            nn.init.constant_(
                self.gaussian_pred[1].weight[start_channels:start_channels +
                                             out_channel, ...], s)
            nn.init.constant_(
                self.gaussian_pred[1].bias[start_channels:start_channels +
                                           out_channel], b)
            start_channels += out_channel

    def forward(self, x):
        return self.gaussian_pred(x)
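

# The 14 output channels follow the splatter-image layout used throughout this
# file: xyz(3) + opacity(1) + scale(3) + rotation(4) + rgb(3).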


class surfel_prediction(nn.Module):
    # for 2dgs

    def __init__(
        self,
        query_dim,
    ) -> None:
        super().__init__()
        self.gaussian_pred = nn.Sequential(
            nn.SiLU(), nn.Linear(query_dim, 13, bias=True))
        self.init_gaussian_prediction()

    def init_gaussian_prediction(self):
        # https://github.com/szymanowiczs/splatter-image/blob/98b465731c3273bf8f42a747d1b6ce1a93faf3d6/configs/dataset/chairs.yaml#L15
        out_channels = [3, 1, 2, 4, 3]  # xyz, opacity, scale, rotation, rgb

        scale_inits = [  # ! avoid affecting final value (offset)
            0,  # xyz_scale
            0.0,  # cfg.model.opacity_scale
            0,  # cfg.model.scale_scale
            1.0,  # rotation
            0,  # rgb
        ]

        bias_inits = [  # one entry per attribute, matching out_channels
            0.0,  # cfg.model.xyz_bias, no deformation here
            0,  # cfg.model.opacity_bias, sigmoid(0)=0.5 at init
            -2.5,  # scale_bias
            0.0,  # rotation
            0.5,  # rgb
        ]

        start_channels = 0
        for out_channel, b, s in zip(out_channels, bias_inits, scale_inits):
            nn.init.constant_(
                self.gaussian_pred[1].weight[start_channels:start_channels +
                                             out_channel, ...], s)
            nn.init.constant_(
                self.gaussian_pred[1].bias[start_channels:start_channels +
                                           out_channel], b)
            start_channels += out_channel

    def forward(self, x):
        return self.gaussian_pred(x)
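

# For 2dgs (surfels) the layout is 13 channels: xyz(3) + opacity(1) +
# scale(2) + rotation(4) + rgb(3).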


class pointInfinityWriteCA(gaussian_prediction):

    def __init__(self,
                 query_dim,
                 context_dim,
                 heads=8,
                 dim_head=64,
                 dropout=0.0) -> None:
        super().__init__(query_dim=query_dim)
        self.write_ca = MemoryEfficientCrossAttention(query_dim, context_dim,
                                                      heads, dim_head, dropout)

    def forward(self, x, z, return_x=False):
        # x: points to write to
        # z: extracted latent
        x = self.write_ca(x, z)  # write from z to x
        if return_x:
            return self.gaussian_pred(x), x  # ! integrate it into dit?
        else:
            return self.gaussian_pred(x)  # ! integrate it into dit?


class pointInfinityWriteCA_cascade(pointInfinityWriteCA):
    # gradually add deformation offsets to the initialized canonical pts
    # (one write block per hooked dit layer), following Point Infinity

    def __init__(self,
                 vit_depth,
                 query_dim,
                 context_dim,
                 heads=8,
                 dim_head=64,
                 dropout=0.0) -> None:
        super().__init__(query_dim, context_dim, heads, dim_head, dropout)
        del self.write_ca

        write_ca_interval = 12 // 4
        self.write_ca_blocks = nn.ModuleList([
            MemoryEfficientCrossAttention(query_dim, context_dim,
                                          heads=heads)  # make it lite
            for _ in range(write_ca_interval)
        ])
        self.hooks = [3, 7, 11]  # hard coded for now

    def forward(self, x: torch.Tensor, z: list):
        # x: the canonical points
        # z: extracted latents, one per dit layer
        # TODO, optimize memory: no need to return all layers?
        z = [z[hook] for hook in self.hooks]
        for idx, ca_blk in enumerate(self.write_ca_blocks):
            x = x + ca_blk(x, z[idx])  # learn residual feature
        return self.gaussian_pred(x)


def create_sphere(radius, num_points):
    # Generate spherical coordinates
    phi = torch.linspace(0, 2 * torch.pi, num_points)
    theta = torch.linspace(0, torch.pi, num_points)
    phi, theta = torch.meshgrid(phi, theta, indexing='xy')

    # Convert spherical coordinates to Cartesian coordinates
    x = radius * torch.sin(theta) * torch.cos(phi)
    y = radius * torch.sin(theta) * torch.sin(phi)
    z = radius * torch.cos(theta)

    # Stack x, y, z coordinates
    points = torch.stack([x.flatten(), y.flatten(), z.flatten()], dim=1)
    return points
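

# e.g., create_sphere(0.45, 64) -> a (64*64, 3) tensor of points on a
# radius-0.45 sphere, which can serve as a canonical point initialization.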


class GS_Adaptive_Write_CA(nn.Module):

    def __init__(
            self,
            query_dim,
            context_dim,
            f=4,  # upsampling ratio
            heads=8,
            dim_head=64,
            dropout=0.0) -> None:
        super().__init__()
        self.f = f
        self.write_ca = MemoryEfficientCrossAttention(query_dim, context_dim,
                                                      heads, dim_head, dropout)
        self.gaussian_residual_pred = nn.Sequential(
            nn.SiLU(),
            nn.Linear(query_dim, 14,
                      bias=True))  # predict residual, before activations

        # ! hard coded
        self.scene_extent = 0.9  # g-buffer, [-0.45, 0.45]
        self.percent_dense = 0.01  # 3dgs official value
        self.residual_offset_act = lambda x: torch.tanh(
            x) * self.scene_extent * 0.015  # avoid large deformation

        init_gaussian_prediction(self.gaussian_residual_pred[1])

    def forward(self,
                gaussians_base,
                gaussian_base_pre_activate,
                gaussian_base_feat,
                xyz_embed_fn,
                shrink_scale=True):
        # gaussians_base: xyz_base after activations and deform offset
        # gaussian_base_pre_activate: original features (before activations)

        # ! densify
        B, N = gaussians_base.shape[:2]  # gaussians upsample factor
        pos, opacity, scaling, rotation = gaussians_base[
            ..., 0:3], gaussians_base[..., 3:4], gaussians_base[
                ..., 4:7], gaussians_base[..., 7:11]

        # ! filter clone/densify based on scaling range
        split_mask = scaling.max(
            dim=-1)[0] > self.scene_extent * self.percent_dense  # shape: B N

        stds = scaling.repeat_interleave(self.f, dim=1)  # 0 0 1 1 2 2...
        means = torch.zeros_like(stds)
        samples = torch.normal(mean=means, std=stds)  # B f*N 3

        new_xyz = samples + pos.repeat_interleave(
            self.f, dim=1)  # ! no rotation for now
        # new_xyz: B f*N 3

        # ! new points to features
        new_xyz_embed = xyz_embed_fn(new_xyz)
        new_gaussian_embed = self.write_ca(
            new_xyz_embed, gaussian_base_feat)  # write from z to x

        # ! predict gaussians residuals
        gaussian_residual_pre_activate = self.gaussian_residual_pred(
            new_gaussian_embed)

        # ! add back. how to deal with new rotations? check the range first.
        if shrink_scale:
            # reduce the scale channels (4:7) of the split points; index the
            # channel slice first so the masked in-place update writes back
            gaussian_base_pre_activate[..., 4:7][split_mask] -= 1

        gaussian_base_pre_activate_repeat = gaussian_base_pre_activate.repeat_interleave(
            self.f, dim=1)

        # TODO wrong here, shall get new scaling before repeat
        gaussians = gaussian_residual_pre_activate + gaussian_base_pre_activate_repeat  # learn the residual
        new_gaussians_pos = new_xyz + self.residual_offset_act(
            gaussians[..., :3])

        return gaussians, new_gaussians_pos  # return positions independently
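

# Densify-and-write sketch: each base gaussian spawns f samples drawn around
# its position (std = its per-axis scale); the new xyz are embedded, written
# to via cross-attention from the base features, and a residual
# (pre-activation) gaussian is predicted and added to the repeated base.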


class GS_Adaptive_Read_Write_CA(nn.Module):

    def __init__(
            self,
            query_dim,
            context_dim,
            mlp_ratio,
            vit_heads,
            f=4,  # upsampling ratio
            heads=8,
            dim_head=64,
            dropout=0.0,
            depth=2,
            vit_blk=DiTBlock2) -> None:
        super().__init__()
        self.f = f
        self.read_ca = MemoryEfficientCrossAttention(query_dim, context_dim,
                                                     heads, dim_head, dropout)

        # more dit blocks
        self.point_infinity_blocks = nn.ModuleList([
            vit_blk(context_dim, num_heads=vit_heads, mlp_ratio=mlp_ratio)
            for _ in range(depth)  # since dit-b here
        ])

        self.write_ca = MemoryEfficientCrossAttention(query_dim, context_dim,
                                                      heads, dim_head, dropout)
        self.gaussian_residual_pred = nn.Sequential(
            nn.SiLU(),
            nn.Linear(query_dim, 14,
                      bias=True))  # predict residual, before activations

        # ! hard coded
        self.scene_extent = 0.9  # g-buffer, [-0.45, 0.45]
        self.percent_dense = 0.01  # 3dgs official value
        self.residual_offset_act = lambda x: torch.tanh(
            x) * self.scene_extent * 0.015  # avoid large deformation

        self.initialize_weights()

    def initialize_weights(self):
        init_gaussian_prediction(self.gaussian_residual_pred[1])
        for block in self.point_infinity_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    def forward(self, gaussians_base, gaussian_base_pre_activate,
                gaussian_base_feat, latent_from_vit, vae_latent, xyz_embed_fn):
        # gaussians_base: xyz_base after activations and deform offset
        # gaussian_base_pre_activate: original features (before activations)

        # ========= START read CA ========
        latent_from_vit = self.read_ca(latent_from_vit,
                                       gaussian_base_feat)  # z_i -> z_(i+1)
        for blk_idx, block in enumerate(self.point_infinity_blocks):
            latent_from_vit = block(latent_from_vit,
                                    vae_latent)  # vae_latent: c
        # ========= END read CA ========

        # ! densify
        B, N = gaussians_base.shape[:2]  # gaussians upsample factor
        pos, opacity, scaling, rotation = gaussians_base[
            ..., 0:3], gaussians_base[..., 3:4], gaussians_base[
                ..., 4:7], gaussians_base[..., 7:11]

        # ! filter clone/densify based on scaling range
        split_mask = scaling.max(
            dim=-1)[0] > self.scene_extent * self.percent_dense  # shape: B N

        stds = scaling.repeat_interleave(self.f, dim=1)  # 0 0 1 1 2 2...
        means = torch.zeros_like(stds)
        samples = torch.normal(mean=means, std=stds)  # B f*N 3

        rots = rearrange(build_rotation(
            rearrange(rotation, 'B N ... -> (B N) ...')),
                         '(B N) ... -> B N ...',
                         B=B,
                         N=N)
        rots = rots.repeat_interleave(self.f, dim=1)  # B f*N 3 3

        # torch.bmm only supports ndim=3 Tensors, so use matmul
        new_xyz = torch.matmul(
            rots, samples.unsqueeze(-1)).squeeze(-1) + pos.repeat_interleave(
                self.f, dim=1)
        # new_xyz: B f*N 3

        # ! new points to features
        new_xyz_embed = xyz_embed_fn(new_xyz)
        new_gaussian_embed = self.write_ca(
            new_xyz_embed, latent_from_vit
        )  # ! use z_(i+1), rather than gaussian_base_feat here

        # ! predict gaussians residuals
        gaussian_residual_pre_activate = self.gaussian_residual_pred(
            new_gaussian_embed)

        # ! add back. how to deal with new rotations? check the range first.
        # reduce the scale channels (4:7) of the split points; index the
        # channel slice first so the masked in-place update writes back
        gaussian_base_pre_activate[..., 4:7][split_mask] -= 1

        gaussian_base_pre_activate_repeat = gaussian_base_pre_activate.repeat_interleave(
            self.f, dim=1)

        # TODO wrong here, shall get new scaling before repeat
        gaussians = gaussian_residual_pre_activate + gaussian_base_pre_activate_repeat  # learn the residual
        new_gaussians_pos = new_xyz + self.residual_offset_act(
            gaussians[..., :3])

        return gaussians, new_gaussians_pos, latent_from_vit, new_gaussian_embed  # return positions independently
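

# Read-write sketch: the dit latent first *reads* from the base gaussian
# features via cross-attention (z_i -> z_(i+1)), runs `depth` DiT blocks
# conditioned on the vae latent, then *writes* to the densified points, whose
# predicted residuals are added on top of the repeated base features.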


class GS_Adaptive_Read_Write_CA_adaptive(GS_Adaptive_Read_Write_CA):

    def __init__(self,
                 query_dim,
                 context_dim,
                 mlp_ratio,
                 vit_heads,
                 f=4,
                 heads=8,
                 dim_head=64,
                 dropout=0.0,
                 depth=2,
                 vit_blk=DiTBlock2) -> None:
        super().__init__(query_dim, context_dim, mlp_ratio, vit_heads, f,
                         heads, dim_head, dropout, depth, vit_blk)
        # assert self.f == 6

    def forward(self, gaussians_base, gaussian_base_pre_activate,
                gaussian_base_feat, latent_from_vit, vae_latent, xyz_embed_fn):
        # gaussians_base: xyz_base after activations and deform offset
        # gaussian_base_pre_activate: original features (before activations)

        # ========= START read CA ========
        latent_from_vit = self.read_ca(latent_from_vit,
                                       gaussian_base_feat)  # z_i -> z_(i+1)
        for blk_idx, block in enumerate(self.point_infinity_blocks):
            latent_from_vit = block(latent_from_vit,
                                    vae_latent)  # vae_latent: c
        # ========= END read CA ========

        # ! densify
        B, N = gaussians_base.shape[:2]  # gaussians upsample factor
        pos, opacity, scaling, rotation = gaussians_base[
            ..., 0:3], gaussians_base[..., 3:4], gaussians_base[
                ..., 4:7], gaussians_base[..., 7:11]

        # ! filter clone/densify based on scaling range
        split_mask = scaling.max(
            dim=-1)[0] > self.scene_extent * self.percent_dense  # shape: B N

        # deterministic, axis-aligned sample creation (in matrix form)
        stds = scaling  # B N 3
        samples = torch.zeros(B, N, 3, 3).to(stds.device)
        samples[..., 0, 0] = stds[..., 0]
        samples[..., 1, 1] = stds[..., 1]
        samples[..., 2, 2] = stds[..., 2]

        eye_mat = torch.cat([torch.eye(3), -torch.eye(3)],
                            0)  # 6x3, to put gaussians along the +/- axes
        eye_mat = eye_mat.reshape(1, 1, 6, 3).repeat(B, N, 1,
                                                     1).to(stds.device)
        samples = (eye_mat @ samples).squeeze(-1)  # B N 6 3

        rots = rearrange(build_rotation(
            rearrange(rotation, 'B N ... -> (B N) ...')),
                         '(B N) ... -> B N ...',
                         B=B,
                         N=N)
        rots = rots.unsqueeze(2).repeat_interleave(self.f, dim=2)  # B N f 3 3

        new_xyz = (rots @ samples.unsqueeze(-1)).squeeze(-1) + pos.unsqueeze(
            2).repeat_interleave(self.f, dim=2)  # B N 6 3
        new_xyz = rearrange(new_xyz, 'b n f c -> b (n f) c')

        # ! new points to features
        new_xyz_embed = xyz_embed_fn(new_xyz)
        new_gaussian_embed = self.write_ca(
            new_xyz_embed, latent_from_vit
        )  # ! use z_(i+1), rather than gaussian_base_feat here

        # ! predict gaussians residuals
        gaussian_residual_pre_activate = self.gaussian_residual_pred(
            new_gaussian_embed)

        gaussian_base_pre_activate_repeat = gaussian_base_pre_activate.repeat_interleave(
            self.f, dim=1)

        # TODO wrong here, shall get new scaling before repeat
        gaussians = gaussian_residual_pre_activate + gaussian_base_pre_activate_repeat  # learn the residual

        return gaussians, new_xyz, latent_from_vit, new_gaussian_embed  # return positions independently


class GS_Adaptive_Read_Write_CA_adaptive_f14_prepend(
        GS_Adaptive_Read_Write_CA_adaptive):

    def __init__(self,
                 query_dim,
                 context_dim,
                 mlp_ratio,
                 vit_heads,
                 f=4,
                 heads=8,
                 dim_head=64,
                 dropout=0.0,
                 depth=2,
                 vit_blk=DiTBlock2,
                 no_flash_op=False) -> None:
        super().__init__(query_dim, context_dim, mlp_ratio, vit_heads, f,
                         heads, dim_head, dropout, depth, vit_blk)

        del self.read_ca, self.write_ca
        del self.point_infinity_blocks

        # ! a plain nn.Parameter (no .to()/.cuda() here), so it is saved to
        # the checkpoint and moved together with the module
        self.latent_embedding = nn.Parameter(torch.randn(1, f, query_dim),
                                             requires_grad=True)

        self.transformer = SRT_TX(
            context_dim,  # e.g., 12 * 64 = 768 for vit-b
            depth=depth,
            heads=context_dim // 64,  # vit-b default.
            mlp_dim=4 * context_dim,
            no_flash_op=no_flash_op,
        )

    def forward(self, gaussians_base, gaussian_base_pre_activate,
                gaussian_base_feat, latent_from_vit, vae_latent, xyz_embed_fn,
                offset_act):
        # gaussians_base: xyz_base after activations and deform offset
        # gaussian_base_pre_activate: original features (before activations)
        # no read CA here; the learnable queries attend to the local point
        # embedding directly.

        B, N = gaussians_base.shape[:2]  # gaussians upsample factor
        pos = gaussians_base[..., 0:3]

        # ! [local_emb, learnable_query_emb] self attention -> fetch the last
        # f tokens as the learned queries -> add to base
        global_local_query_emb = torch.cat(
            [
                rearrange(gaussian_base_feat,
                          'B N C -> (B N) 1 C'),  # B N C -> B*N 1 C
                self.latent_embedding.repeat(B * N, 1,
                                             1)  # 1 f C -> B*N f C
            ],
            dim=1)  # OOM if prepending the global feat
        global_local_query_emb = self.transformer(
            global_local_query_emb)  # (B*N) (1+f) C

        # ! predict gaussians residuals from the query tokens
        gaussian_residual_pre_activate = self.gaussian_residual_pred(
            global_local_query_emb[:, 1:, :])
        gaussian_residual_pre_activate = rearrange(
            gaussian_residual_pre_activate, '(B N) L C -> B N L C', B=B,
            N=N)  # B N f C

        offsets = offset_act(gaussian_residual_pre_activate[..., 0:3])
        new_xyz = offsets + pos.unsqueeze(2).repeat_interleave(
            self.f, dim=2)  # B N f 3
        new_xyz = rearrange(new_xyz, 'b n f c -> b (n f) c')

        gaussian_base_pre_activate_repeat = gaussian_base_pre_activate.unsqueeze(
            -2).expand(-1, -1, self.f, -1)  # avoid new memory allocation
        gaussians = rearrange(gaussian_residual_pre_activate +
                              gaussian_base_pre_activate_repeat,
                              'B N F C -> B (N F) C',
                              B=B,
                              N=N)  # learn the residual in the feature space

        return gaussians, new_xyz


class GS_Adaptive_Read_Write_CA_adaptive_2dgs(
        GS_Adaptive_Read_Write_CA_adaptive_f14_prepend):

    def __init__(self,
                 query_dim,
                 context_dim,
                 mlp_ratio,
                 vit_heads,
                 f=16,
                 heads=8,
                 dim_head=64,
                 dropout=0.0,
                 depth=2,
                 vit_blk=DiTBlock2,
                 no_flash_op=False,
                 cross_attention=False) -> None:
        super().__init__(query_dim, context_dim, mlp_ratio, vit_heads, f,
                         heads, dim_head, dropout, depth, vit_blk, no_flash_op)

        self.cross_attention = cross_attention
        if cross_attention:  # much more efficient than self attention: linear complexity
            # xformers fails on large batch sizes:
            # https://github.com/facebookresearch/xformers/issues/845
            self.sr_ca = CrossAttention(query_dim, context_dim,
                                        heads, dim_head, dropout,
                                        no_flash_op=no_flash_op)

        # predict residual over base (features); add PreNorm since the sr
        # module is a pre-norm transformer
        self.gaussian_residual_pred = PreNorm(
            query_dim, nn.Linear(query_dim, 13, bias=True))

        # init as full zero, since predicting residuals here
        nn.init.constant_(self.gaussian_residual_pred.fn.weight, 0)
        nn.init.constant_(self.gaussian_residual_pred.fn.bias, 0)

    def forward(self,
                latent_from_vit,
                base_gaussians,
                skip_weight,
                offset_act,
                gs_pred_fn,
                gs_act_fn,
                gaussian_base_pre_activate=None):

        B, N, C = latent_from_vit.shape  # e.g., B 768 768

        if not self.cross_attention:
            # ! queries from the local point emb
            global_local_query_emb = torch.cat(
                [
                    rearrange(latent_from_vit,
                              'B N C -> (B N) 1 C'),  # B N C -> B*N 1 C
                    self.latent_embedding.repeat(B * N, 1, 1).to(
                        latent_from_vit)  # 1 f C -> B*N f C
                ],
                dim=1)  # OOM if prepending the global feat
            global_local_query_emb = self.transformer(
                global_local_query_emb)  # (B*N) (1+f) C

            # drop the local token; keep the f query tokens
            global_local_query_emb = rearrange(global_local_query_emb[:, 1:],
                                               '(B N) L C -> B N L C',
                                               B=B,
                                               N=N)
        else:
            global_local_query_emb = self.sr_ca(
                self.latent_embedding.repeat(B * N, 1, 1).to(latent_from_vit),
                rearrange(latent_from_vit, 'B N C -> (B N) 1 C'))
            global_local_query_emb = self.transformer(global_local_query_emb)
            global_local_query_emb = rearrange(global_local_query_emb,
                                               '(B N) L C -> B N L C',
                                               B=B,
                                               N=N)

        # * predict residual features
        gaussian_residual_pre_activate = self.gaussian_residual_pred(
            global_local_query_emb)

        # ! directly add xyz offsets
        offsets = offset_act(gaussian_residual_pre_activate[..., :3])
        gaussians_upsampled_pos = offsets + einops.repeat(
            base_gaussians[..., :3], 'B N C -> B N F C',
            F=self.f)  # ! reasonable init

        # ! add residual features
        gaussian_residual_pre_activate = gaussian_residual_pre_activate + einops.repeat(
            gaussian_base_pre_activate, 'B N C -> B N F C', F=self.f)

        gaussians_upsampled = gs_act_fn(pos=gaussians_upsampled_pos,
                                        x=gaussian_residual_pre_activate)
        gaussians_upsampled = rearrange(gaussians_upsampled,
                                        'B N F C -> B (N F) C')

        return gaussians_upsampled, (rearrange(
            gaussian_residual_pre_activate, 'B N F C -> B (N F) C'), rearrange(
                global_local_query_emb, 'B N F C -> B (N F) C'))


class ViTTriplaneDecomposed(nn.Module):

    def __init__(
        self,
        vit_decoder,
        triplane_decoder: Triplane,
        cls_token=False,
        decoder_pred_size=-1,
        unpatchify_out_chans=-1,
        sr_ratio=2,
    ) -> None:
        super().__init__()
        self.superresolution = None
        self.decomposed_IN = False
        self.decoder_pred_3d = None
        self.transformer_3D_blk = None
        self.logvar = None

        self.cls_token = cls_token
        self.vit_decoder = vit_decoder
        self.triplane_decoder = triplane_decoder

        self.patch_size = 14  # TODO, hard coded here
        if isinstance(self.patch_size, tuple):  # dino-v2
            self.patch_size = self.patch_size[0]

        self.img_size = None  # TODO, hard coded

        if decoder_pred_size == -1:
            decoder_pred_size = self.patch_size**2 * self.triplane_decoder.out_chans

        if unpatchify_out_chans == -1:
            self.unpatchify_out_chans = self.triplane_decoder.out_chans
        else:
            self.unpatchify_out_chans = unpatchify_out_chans

        self.decoder_pred = nn.Linear(
            self.vit_decoder.embed_dim,
            decoder_pred_size,
            bias=True)  # decoder to patch

    def triplane_decode(self, latent, c):
        ret_dict = self.triplane_decoder(latent, c)  # triplane latent -> imgs
        ret_dict.update({'latent': latent})
        return ret_dict

    def triplane_renderer(self, latent, coordinates, directions):
        planes = latent.view(len(latent), 3,
                             self.triplane_decoder.decoder_in_chans,
                             latent.shape[-2],
                             latent.shape[-1])  # BS 96 256 256
        ret_dict = self.triplane_decoder.renderer.run_model(
            planes, self.triplane_decoder.decoder, coordinates, directions,
            self.triplane_decoder.rendering_kwargs)  # triplane latent -> imgs
        return ret_dict

    # * increase the encoded latent dim to match the decoder
    def forward_vit_decoder(self, x, img_size=None):
        # latent: (N, L, C) from a DINO/CLIP ViT encoder
        # add positional encoding to each token
        if img_size is None:
            img_size = self.img_size

        if self.cls_token:
            x = x + self.vit_decoder.interpolate_pos_encoding(
                x, img_size, img_size)[:, :]  # B, L, C
        else:
            x = x + self.vit_decoder.interpolate_pos_encoding(
                x, img_size, img_size)[:, 1:]  # B, L, C

        for blk in self.vit_decoder.blocks:
            x = blk(x)
        x = self.vit_decoder.norm(x)

        return x

    def unpatchify(self, x, p=None, unpatchify_out_chans=None):
        """
        x: (N, L, patch_size**2 * self.out_chans)
        imgs: (N, self.out_chans, H, W)
        """
        if unpatchify_out_chans is None:
            unpatchify_out_chans = self.unpatchify_out_chans

        if self.cls_token:  # TODO, how to better use the cls token
            x = x[:, 1:]

        if p is None:  # assign upsample patch size
            p = self.patch_size
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, unpatchify_out_chans))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], unpatchify_out_chans, h * p,
                                h * p))
        return imgs
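
    # e.g., x: (N, L, p*p*C) with L = h*w patch tokens -> imgs: (N, C, h*p, w*p)
    # (h == w here, since the tokens form a square grid).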

    def forward(self, latent, c, img_size):
        latent = self.forward_vit_decoder(latent, img_size)  # pred_vit_latent

        if self.cls_token:
            cls_token = latent[:, :1]
        else:
            cls_token = None

        # ViT decoder projection, from MAE
        latent = self.decoder_pred(
            latent)  # pred_vit_latent -> patch or original size
        latent = self.unpatchify(
            latent)  # spatial_vit_latent, B, C, H, W (B, 96, 256, 256)

        # * triplane rendering
        ret_dict = self.triplane_decoder(planes=latent, c=c)
        ret_dict.update({'latent': latent, 'cls_token': cls_token})

        return ret_dict


# merged the above class into a single class
class vae_3d(nn.Module):

    def __init__(
            self,
            vit_decoder: VisionTransformer,
            triplane_decoder: Triplane_fg_bg_plane,
            cls_token,
            ldm_z_channels,
            ldm_embed_dim,
            plane_n=1,
            vae_dit_token_size=16,
            **kwargs) -> None:
        super().__init__()
        self.reparameterization_soft_clamp = True  # some instability in training VAE

        self.plane_n = plane_n
        self.cls_token = cls_token
        self.vit_decoder = vit_decoder
        self.triplane_decoder = triplane_decoder

        self.patch_size = 14  # TODO, hard coded here
        if isinstance(self.patch_size, tuple):  # dino-v2
            self.patch_size = self.patch_size[0]
        self.img_size = None  # TODO, hard coded

        self.ldm_z_channels = ldm_z_channels
        self.ldm_embed_dim = ldm_embed_dim
        self.vae_p = 4  # resolution = 4 * 16
        self.token_size = vae_dit_token_size  # use the dino-v2 dim tradition here
        self.vae_res = self.vae_p * self.token_size

        self.superresolution = nn.ModuleDict({})  # put all the modules here

        self.embed_dim = vit_decoder.embed_dim

        # placeholders for compat
        self.decoder_pred = None
        self.decoder_pred_3d = None
        self.transformer_3D_blk = None
        self.logvar = None
        self.register_buffer('w_avg', torch.zeros([512]))

    def init_weights(self):
        # ! init (learnable) PE for DiT
        self.vit_decoder.pos_embed = nn.Parameter(
            torch.zeros(1, self.vit_decoder.embed_dim,
                        self.vit_decoder.embed_dim),
            requires_grad=True)  # token_size = embed_size by default.
        trunc_normal_(self.vit_decoder.pos_embed, std=.02)


# the base class
class pcd_structured_latent_space_vae_decoder(vae_3d):

    def __init__(
            self,
            vit_decoder: VisionTransformer,
            triplane_decoder: Triplane_fg_bg_plane,
            cls_token,
            **kwargs) -> None:
        super().__init__(vit_decoder, triplane_decoder, cls_token, **kwargs)
        # from splatting_dit_v4_PI_V1_trilatent_sphere
        self.D_roll_out_input = False

        # ! renderer
        self.gs = triplane_decoder  # compat
        self.rendering_kwargs = self.gs.rendering_kwargs
        self.scene_range = [
            self.rendering_kwargs['sampler_bbox_min'],
            self.rendering_kwargs['sampler_bbox_max']
        ]

        # hyper parameters
        self.skip_weight = torch.tensor(0.1).to(dist_util.dev())
        self.offset_act = lambda x: torch.tanh(x) * (self.scene_range[
            1]) * 0.5  # regularize small offsets

        self.vit_decoder.pos_embed = nn.Parameter(
            torch.zeros(1,
                        self.plane_n * (self.token_size**2 + self.cls_token),
                        vit_decoder.embed_dim))
        self.init_weights()  # re-init weights after re-writing token_size

        self.output_size = {
            'gaussians_base': 128,
        }
        self.rand_base_render = False  # overridden by cascaded subclasses

        # activations
        self.rot_act = lambda x: F.normalize(x, dim=-1)  # as fixed in lgm

        self.scene_extent = self.rendering_kwargs['sampler_bbox_max'] * 0.01
        scaling_factor = (self.scene_extent /
                          F.softplus(torch.tensor(0.0))).to(dist_util.dev())
        self.scale_act = lambda x: F.softplus(
            x
        ) * scaling_factor  # make sure F.softplus(0) is the average scale size

        self.rgb_act = lambda x: 0.5 * torch.tanh(
            x) + 0.5  # NOTE: may use sigmoid if trained again
        self.pos_act = lambda x: x.clamp(-0.45, 0.45)
        self.opacity_act = lambda x: torch.sigmoid(x)

        self.superresolution.update(
            dict(
                conv_sr=surfel_prediction(query_dim=vit_decoder.embed_dim),
                quant_conv=Mlp(in_features=2 * self.ldm_z_channels,
                               out_features=2 * self.ldm_embed_dim,
                               act_layer=approx_gelu,
                               drop=0),
                post_quant_conv=Mlp(in_features=self.ldm_z_channels,
                                    out_features=vit_decoder.embed_dim,
                                    act_layer=approx_gelu,
                                    drop=0),
                ldm_upsample=nn.Identity(),
                xyz_pos_embed=nn.Identity(),
            ))

        # for gs prediction
        self.superresolution.update(
            dict(
                ada_CA_f4_1=GS_Adaptive_Read_Write_CA_adaptive_2dgs(
                    self.embed_dim,
                    vit_decoder.embed_dim,
                    vit_heads=vit_decoder.num_heads,
                    mlp_ratio=vit_decoder.mlp_ratio,
                    depth=vit_decoder.depth // 6
                    if vit_decoder.depth == 12 else 2,
                    f=8,
                    heads=8),  # write
            ))

    def vae_reparameterization(self, latent, sample_posterior):
        # latent: B 24 32 32
        # ! do VAE here
        posterior = self.vae_encode(latent)  # B self.ldm_z_channels 3 L

        assert sample_posterior
        if sample_posterior:
            kl_latent = posterior.sample()
        else:
            kl_latent = posterior.mode()  # B C 3 L

        ret_dict = dict(
            latent_normalized=rearrange(kl_latent, 'B C L -> B L C'),
            posterior=posterior,
            query_pcd_xyz=latent['query_pcd_xyz'],
        )

        return ret_dict

    # from pcd_structured_latent_space_lion_learnoffset_surfel_sr_noptVAE.vae_encode
    def vae_encode(self, h):
        # * smooth convolution before triplane
        h, query_pcd_xyz = h['h'], h['query_pcd_xyz']
        moments = self.superresolution['quant_conv'](h)
        moments = rearrange(moments,
                            'B L C -> B C L')  # for sd vae code compat

        posterior = DiagonalGaussianDistribution(
            moments, soft_clamp=self.reparameterization_soft_clamp)
        return posterior

    # from pcd_structured_latent_space_lion_learnoffset_surfel_novaePT._get_base_gaussians
    def _get_base_gaussians(self, ret_after_decoder, c=None):
        x = ret_after_decoder['gaussian_base_pre_activate']
        B, N, C = x.shape  # B N C, 13-dim surfel features
        assert C == 13  # 2dgs

        offsets = self.offset_act(x[..., 0:3])  # ! model prediction
        vae_sampled_xyz = ret_after_decoder['query_pcd_xyz'].to(
            x.dtype)  # ! directly use fps pcd as "anchor points"
        pos = offsets * self.skip_weight + vae_sampled_xyz  # ! reasonable init

        opacity = self.opacity_act(x[..., 3:4])
        scale = self.scale_act(x[..., 4:6])
        rotation = self.rot_act(x[..., 6:10])
        rgbs = self.rgb_act(x[..., 10:])

        gaussians = torch.cat([pos, opacity, scale, rotation, rgbs],
                              dim=-1)  # [B, N, 13]
        return gaussians
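
    # Base gaussians sketch: the 13 pre-activation channels are split as
    # xyz-offset(3) / opacity(1) / scale(2) / rotation(4) / rgb(3); positions
    # are the fps anchor points plus skip_weight-scaled offsets.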

    # from pcd_structured_latent_space
    def vit_decode_backbone(self, latent, img_size):
        if isinstance(latent, dict):
            latent = latent['latent_normalized']  # B, C*3, H, W

        latent = self.superresolution['post_quant_conv'](
            latent)  # to the later dit embed dim

        # ! directly feed to vit_decoder
        return {
            'latent': latent,
            'latent_from_vit': self.forward_vit_decoder(latent, img_size)
        }  # pred_vit_latent

    # from pcd_structured_latent_space_lion_learnoffset_surfel_sr
    def _gaussian_pred_activations(self, pos, x):
        opacity = self.opacity_act(x[..., 3:4])
        scale = self.scale_act(x[..., 4:6])
        rotation = self.rot_act(x[..., 6:10])
        rgbs = self.rgb_act(x[..., 10:])

        gaussians = torch.cat([pos, opacity, scale, rotation, rgbs],
                              dim=-1)  # [B, N, 13]
        return gaussians.float()

    # from pcd_structured_latent_space_lion_learnoffset_surfel_sr
    def vis_gaussian(self, gaussians, file_name_base):
        B = gaussians.shape[0]
        pos, opacity, scale, rotation, rgbs = gaussians[
            ..., 0:3], gaussians[..., 3:4], gaussians[..., 4:6], gaussians[
                ..., 6:10], gaussians[..., 10:13]

        file_path = Path(logger.get_dir())

        for b in range(B):
            file_name = f'{file_name_base}-{b}'
            np.save(file_path / f'{file_name}_opacity.npy',
                    opacity[b].float().detach().cpu().numpy())
            np.save(file_path / f'{file_name}_scale.npy',
                    scale[b].float().detach().cpu().numpy())
            np.save(file_path / f'{file_name}_rotation.npy',
                    rotation[b].float().detach().cpu().numpy())

            pcu.save_mesh_vc(str(file_path / f'{file_name}.ply'),
                             pos[b].float().detach().cpu().numpy(),
                             rgbs[b].float().detach().cpu().numpy())

    def vit_decode_postprocess(self,
                               latent_from_vit,
                               ret_dict: dict,
                               return_upsampled_residual=False):
        # continues from vit_decode_backbone()
        gaussian_base_pre_activate = self.superresolution['conv_sr'](
            latent_from_vit['latent_from_vit'])  # B N 13

        gaussians_base = self._get_base_gaussians({
            **ret_dict,
            'gaussian_base_pre_activate': gaussian_base_pre_activate,
        })

        gaussians_upsampled, (gaussian_upsampled_residual_pre_activate,
                              upsampled_global_local_query_emb
                              ) = self.superresolution['ada_CA_f4_1'](
                                  latent_from_vit['latent_from_vit'],
                                  gaussians_base,
                                  skip_weight=self.skip_weight,
                                  gs_pred_fn=self.superresolution['conv_sr'],
                                  gs_act_fn=self._gaussian_pred_activations,
                                  offset_act=self.offset_act,
                                  gaussian_base_pre_activate=
                                  gaussian_base_pre_activate)

        ret_dict.update({
            'gaussians_upsampled': gaussians_upsampled,
            'gaussians_base': gaussians_base
        })

        if return_upsampled_residual:
            return ret_dict, (gaussian_upsampled_residual_pre_activate,
                              upsampled_global_local_query_emb)
        else:
            return ret_dict

    def vit_decode(self, latent, img_size, sample_posterior=True, c=None):
        ret_dict = self.vae_reparameterization(latent, sample_posterior)
        latent = self.vit_decode_backbone(ret_dict, img_size)
        ret_after_decoder = self.vit_decode_postprocess(latent, ret_dict)
        return self.forward_gaussians(ret_after_decoder, c=c)

    # from pcd_structured_latent_space_lion_learnoffset_surfel_novaePT_sr.forward_gaussians
    def forward_gaussians(self, ret_after_decoder, c=None):
        # ! currently, only the upsampled gaussians are used for training
        if False:  # also concat the base gaussians
            ret_after_decoder['gaussians'] = torch.cat([
                ret_after_decoder['gaussians_base'],
                ret_after_decoder['gaussians_upsampled'],
            ],
                                                       dim=1)
        else:  # only adopt SR
            ret_after_decoder['gaussians'] = ret_after_decoder[
                'gaussians_upsampled']

        ret_after_decoder.update({
            'gaussians':
            ret_after_decoder['gaussians'],
            'pos':
            ret_after_decoder['gaussians'][..., :3],
            'gaussians_base_opa':
            ret_after_decoder['gaussians_base'][..., 3:4]
        })

        # ! rendering happens later, in triplane_decode()
        return ret_after_decoder

    def forward_vit_decoder(self, x, img_size=None):
        return self.vit_decoder(x)

    # from pcd_structured_latent_space_lion_learnoffset_surfel_novaePT_sr_cascade.triplane_decode
    def triplane_decode(self,
                        ret_after_gaussian_forward,
                        c,
                        bg_color=None,
                        render_all_scale=False,
                        **kwargs):

        # ! render multi-res imgs with the different gaussian sets
        def render_gs(gaussians, c_data, output_size):
            results = self.gs.render(
                gaussians,  # type: ignore
                c_data['cam_view'],
                c_data['cam_view_proj'],
                c_data['cam_pos'],
                tanfov=c_data['tanfov'],
                bg_color=bg_color,
                output_size=output_size,
            )
            results['image_raw'] = results[
                'image'] * 2 - 1  # [0,1] -> [-1,1], match tradition
            results['image_depth'] = results['depth']
            results['image_mask'] = results['alpha']
            return results

        cascade_splatting_results = {}

        all_keys_to_render = list(self.output_size.keys())
        if self.rand_base_render and not render_all_scale:
            # render one randomly-chosen coarser scale plus the finest scale
            keys_to_render = [random.choice(all_keys_to_render[:-1])
                              ] + [all_keys_to_render[-1]]
        else:
            keys_to_render = all_keys_to_render

        for gaussians_key in keys_to_render:
            cascade_splatting_results[gaussians_key] = render_gs(
                ret_after_gaussian_forward[gaussians_key], c,
                self.output_size[gaussians_key])

        return cascade_splatting_results
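

# Cascaded decoder: stacks two more upsamplers on top of the base decoder and
# renders the extra gaussian sets at higher resolutions (256/384/512) in
# addition to the 128-res base set.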


class pcd_structured_latent_space_vae_decoder_cascaded(
        pcd_structured_latent_space_vae_decoder):
    # for 2dgs

    def __init__(
            self,
            vit_decoder: VisionTransformer,
            triplane_decoder: Triplane_fg_bg_plane,
            cls_token,
            **kwargs) -> None:
        super().__init__(vit_decoder, triplane_decoder, cls_token, **kwargs)

        self.output_size.update({
            'gaussians_upsampled': 256,
            'gaussians_upsampled_2': 384,
            'gaussians_upsampled_3': 512,
        })
        self.rand_base_render = True

        # further x8 up-sampling.
        self.superresolution.update(
            dict(
                ada_CA_f4_2=GS_Adaptive_Read_Write_CA_adaptive_2dgs(
                    self.embed_dim,
                    vit_decoder.embed_dim,
                    vit_heads=vit_decoder.num_heads,
                    mlp_ratio=vit_decoder.mlp_ratio,
                    depth=1,
                    f=4,
                    heads=8,
                    no_flash_op=True,  # fails when bs>1
                    cross_attention=False),  # write
                ada_CA_f4_3=GS_Adaptive_Read_Write_CA_adaptive_2dgs(
                    self.embed_dim,
                    vit_decoder.embed_dim,
                    vit_heads=vit_decoder.num_heads,
                    mlp_ratio=vit_decoder.mlp_ratio,
                    depth=1,
                    f=3,
                    heads=8,
                    no_flash_op=True,
                    cross_attention=False),  # write
            ))

    def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict):
        # further x8 upsampling on top of the parent class
        # TODO, merge this into the ln3diff open-source code.
        ret_dict, (gaussian_upsampled_residual_pre_activate,
                   upsampled_global_local_query_emb
                   ) = super().vit_decode_postprocess(
                       latent_from_vit,
                       ret_dict,
                       return_upsampled_residual=True)

        gaussians_upsampled_2, (
            gaussian_upsampled_residual_pre_activate_2,
            upsampled_global_local_query_emb_2
        ) = self.superresolution['ada_CA_f4_2'](
            upsampled_global_local_query_emb,
            ret_dict['gaussians_upsampled'],
            skip_weight=self.skip_weight,
            gs_pred_fn=self.superresolution['conv_sr'],
            gs_act_fn=self._gaussian_pred_activations,
            offset_act=self.offset_act,
            gaussian_base_pre_activate=gaussian_upsampled_residual_pre_activate)

        gaussians_upsampled_3, _ = self.superresolution['ada_CA_f4_3'](
            upsampled_global_local_query_emb_2,
            gaussians_upsampled_2,
            skip_weight=self.skip_weight,
            gs_pred_fn=self.superresolution['conv_sr'],
            gs_act_fn=self._gaussian_pred_activations,
            offset_act=self.offset_act,
            gaussian_base_pre_activate=gaussian_upsampled_residual_pre_activate_2)

        ret_dict.update({
            'gaussians_upsampled_2': gaussians_upsampled_2,
            'gaussians_upsampled_3': gaussians_upsampled_3,
        })

        return ret_dict