# HuggingFace Spaces banner ("Spaces: Running on Zero") — page-scrape artifact, kept as a comment.
import math | |
import random | |
from einops import rearrange | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
import numpy as np | |
from tqdm import trange | |
from functools import partial | |
from nsr.networks_stylegan2 import Generator as StyleGAN2Backbone | |
from nsr.volumetric_rendering.renderer import ImportanceRenderer, ImportanceRendererfg_bg | |
from nsr.volumetric_rendering.ray_sampler import RaySampler | |
from nsr.triplane import OSGDecoder, Triplane, Triplane_fg_bg_plane | |
# from nsr.losses.helpers import ResidualBlock | |
# from vit.vision_transformer import TriplaneFusionBlockv4_nested, VisionTransformer, TriplaneFusionBlockv4_nested_init_from_dino | |
from vit.vision_transformer import TriplaneFusionBlockv4_nested, TriplaneFusionBlockv4_nested_init_from_dino_lite, TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout, VisionTransformer, TriplaneFusionBlockv4_nested_init_from_dino | |
from .vision_transformer import Block, VisionTransformer | |
from .utils import trunc_normal_ | |
from guided_diffusion import dist_util, logger | |
from pdb import set_trace as st | |
from ldm.modules.diffusionmodules.model import Encoder, Decoder | |
from utils.torch_utils.components import PixelShuffleUpsample, ResidualBlock, Upsample, PixelUnshuffleUpsample, Conv3x3TriplaneTransformation | |
from utils.torch_utils.distributions.distributions import DiagonalGaussianDistribution | |
from nsr.superresolution import SuperresolutionHybrid2X, SuperresolutionHybrid4X | |
from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer | |
from nsr.common_blks import ResMlp | |
from .vision_transformer import * | |
from dit.dit_models import get_2d_sincos_pos_embed | |
from torch import _assert | |
from itertools import repeat | |
import collections.abc | |
# From PyTorch internals
def _ntuple(n):
    """Build a converter that coerces a value into an n-tuple.

    Non-string iterables are materialized as-is via ``tuple``; scalars
    (and strings, which are deliberately treated as scalars) are
    replicated ``n`` times.
    """

    def parse(value):
        is_seq = isinstance(value, collections.abc.Iterable) and not isinstance(value, str)
        return tuple(value) if is_seq else (value,) * n

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
class PatchEmbedTriplane(nn.Module):
    """Grouped-convolution patch embedder for triplane inputs.

    The input stacks the three planes along the channel axis,
    (B, 3*C_in, H, W); a ``groups=3`` Conv2d embeds each plane
    independently so features never mix across planes at embedding time.
    With ``flatten=True`` the output is (B, 3*L, embed_dim) where
    L = grid_h * grid_w.
    """

    def __init__(
        self,
        img_size=32,
        patch_size=2,
        in_chans=4,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
        bias=True,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        # one embedding head per plane, realized as a single grouped conv
        self.proj = nn.Conv2d(in_chans,
                              embed_dim * 3,
                              kernel_size=patch_size,
                              stride=patch_size,
                              bias=bias,
                              groups=3)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, _, H, W = x.shape
        _assert(
            H == self.img_size[0],
            f"Input image height ({H}) doesn't match model ({self.img_size[0]})."
        )
        _assert(
            W == self.img_size[1],
            f"Input image width ({W}) doesn't match model ({self.img_size[1]})."
        )
        x = self.proj(x)  # B 3*C token_H token_W
        gh, gw = x.shape[-2], x.shape[-1]
        # split the grouped channels back into (C, plane): B C 3 H W
        x = x.reshape(B, x.shape[1] // 3, 3, gh, gw)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # B C 3HW -> B 3HW C
        return self.norm(x)
class PatchEmbedTriplaneRodin(PatchEmbedTriplane):
    """PatchEmbedTriplane variant whose projection is a Rodin roll-out 3D group conv.

    Inherits the grid bookkeeping and forward() from PatchEmbedTriplane,
    but replaces the plain grouped Conv2d projection with
    RodinRollOutConv3D_GroupConv (defined elsewhere in the project).
    """

    def __init__(self,
                 img_size=32,
                 patch_size=2,
                 in_chans=4,
                 embed_dim=768,
                 norm_layer=None,
                 flatten=True,
                 bias=True):
        super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer,
                         flatten, bias)
        # Discard the base-class Conv2d and use the Rodin roll-out conv instead.
        # NOTE(review): the base class passes `bias=bias` to its conv; this
        # replacement does not forward `bias` — confirm that is intentional.
        self.proj = RodinRollOutConv3D_GroupConv(in_chans,
                                                 embed_dim * 3,
                                                 kernel_size=patch_size,
                                                 stride=patch_size,
                                                 padding=0)
class ViTTriplaneDecomposed(nn.Module):
    """Base ViT-decoder + EG3D-style triplane renderer.

    Pipeline: ViT tokens (from a DINO/CLIP-style encoder) -> decoder blocks
    (optionally grouped into triplane fusion blocks) -> linear prediction
    head -> unpatchify into triplane feature maps -> volumetric rendering
    via `triplane_decoder`.
    """

    def __init__(
        self,
        vit_decoder,
        triplane_decoder: Triplane,
        cls_token=False,
        decoder_pred_size=-1,
        unpatchify_out_chans=-1,
        # * uvit arch
        channel_multiplier=4,
        use_fusion_blk=True,
        fusion_blk_depth=4,
        fusion_blk=TriplaneFusionBlock,
        fusion_blk_start=0,  # index of the first decoder block wrapped in a fusion block
        ldm_z_channels=4,  # latent-diffusion z channels
        ldm_embed_dim=4,
        vae_p=2,  # VAE patch factor; vae_res = vae_p * token_size
        token_size=None,
        w_avg=torch.zeros([512]),  # StyleGAN w-average placeholder, replaced externally
        patch_size=None,
        **kwargs,
    ) -> None:
        super().__init__()
        # container for SR/VAE sub-modules registered by subclasses
        self.superresolution = nn.ModuleDict({})
        self.decomposed_IN = False
        self.decoder_pred_3d = None
        self.transformer_3D_blk = None
        self.logvar = None
        self.channel_multiplier = channel_multiplier
        self.cls_token = cls_token
        self.vit_decoder = vit_decoder
        self.triplane_decoder = triplane_decoder
        if patch_size is None:
            self.patch_size = self.vit_decoder.patch_embed.patch_size
        else:
            self.patch_size = patch_size
        if isinstance(self.patch_size, tuple):  # dino-v2 stores patch_size as a tuple
            self.patch_size = self.patch_size[0]
        # self.img_size = self.vit_decoder.patch_embed.img_size
        if unpatchify_out_chans == -1:
            self.unpatchify_out_chans = self.triplane_decoder.out_chans
        else:
            self.unpatchify_out_chans = unpatchify_out_chans
        # ! mlp decoder from mae/dino
        if decoder_pred_size == -1:
            decoder_pred_size = self.patch_size**2 * self.triplane_decoder.out_chans
        # token embedding -> per-patch pixel prediction (MAE-style head)
        self.decoder_pred = nn.Linear(
            self.vit_decoder.embed_dim,
            decoder_pred_size,
            bias=True)  # decoder to pat
        # triplane
        self.plane_n = 3
        # ! vae
        self.ldm_z_channels = ldm_z_channels
        self.ldm_embed_dim = ldm_embed_dim
        self.vae_p = vae_p
        # NOTE(review): the `token_size` argument is currently ignored; 16 is hard-coded.
        self.token_size = 16  # use dino-v2 dim tradition here
        self.vae_res = self.vae_p * self.token_size
        # ! uvit: one positional embedding per plane token (plus optional cls slot)
        self.vit_decoder.pos_embed = nn.Parameter(
            torch.zeros(1, 3 * (self.token_size**2 + self.cls_token),
                        vit_decoder.embed_dim))
        self.fusion_blk_start = fusion_blk_start
        self.create_fusion_blks(fusion_blk_depth, use_fusion_blk, fusion_blk)
        # ! placeholder, not used here
        self.register_buffer('w_avg', w_avg)  # will replace externally
        self.rendering_kwargs = self.triplane_decoder.rendering_kwargs

    def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**16):
        """Query the triplane at 3D points, in chunks to bound memory.

        planes: (N, 3, C, H', W'), or (N, 3*C, H', W') which is reshaped
        points: (N, P, 3)
        returns: dict of tensors, each (N, P, ...), concatenated over chunks
        """
        N, P = points.shape[:2]
        if planes.ndim == 4:
            planes = planes.reshape(
                len(planes),
                3,
                -1,  # ! support background plane
                planes.shape[-2],
                planes.shape[-1])  # BS 96 256 256
        # query triplane in chunks
        outs = []
        for i in trange(0, points.shape[1], chunk_size):
            chunk_points = points[:, i:i+chunk_size]
            # query triplane
            chunk_out = self.triplane_decoder.renderer._run_model(  # type: ignore
                planes=planes,
                decoder=self.triplane_decoder.decoder,
                sample_coordinates=chunk_points,
                sample_directions=torch.zeros_like(chunk_points),
                options=self.rendering_kwargs,
            )
            outs.append(chunk_out)
            torch.cuda.empty_cache()  # keep peak GPU memory low for large point sets
        # concatenate the per-chunk outputs along the point dimension
        point_features = {
            k: torch.cat([out[k] for out in outs], dim=1)
            for k in outs[0].keys()
        }
        return point_features

    def triplane_decode_grid(self, vit_decode_out, grid_size, aabb: torch.Tensor = None, **kwargs):
        """Sample the decoded triplane on a dense cubic grid.

        vit_decode_out: dict holding 'latent_after_vit' planes
        grid_size: resolution of the cubic sampling grid
        aabb: optional (N, 2, 3) min/max bounds; defaults derived from
              rendering_kwargs (ObjaVerse-style bbox or eg3d box_warp)
        returns: dict of tensors reshaped to (N, grid, grid, grid, -1)
        """
        assert isinstance(vit_decode_out, dict)
        planes = vit_decode_out['latent_after_vit']
        # aabb: (N, 2, 3)
        if aabb is None:
            if 'sampler_bbox_min' in self.rendering_kwargs:
                aabb = torch.tensor([
                    [self.rendering_kwargs['sampler_bbox_min']] * 3,
                    [self.rendering_kwargs['sampler_bbox_max']] * 3,
                ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1)
            else:  # shapenet dataset, follow eg3d
                aabb = torch.tensor([  # https://github.com/NVlabs/eg3d/blob/7cf1fd1e99e1061e8b6ba850f91c94fe56e7afe4/eg3d/gen_samples.py#L188
                    [-self.rendering_kwargs['box_warp']/2] * 3,
                    [self.rendering_kwargs['box_warp']/2] * 3,
                ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1)
        assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb"
        N = planes.shape[0]
        # create grid points for triplane query
        grid_points = []
        for i in range(N):
            grid_points.append(torch.stack(torch.meshgrid(
                torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device),
                torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device),
                torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device),
                indexing='ij',
            ), dim=-1).reshape(-1, 3))
        cube_grid = torch.stack(grid_points, dim=0).to(planes.device)  # (N, grid_size**3, 3)
        features = self.forward_points(planes, cube_grid)
        # reshape flat point features back into the cubic grid
        features = {
            k: v.reshape(N, grid_size, grid_size, grid_size, -1)
            for k, v in features.items()
        }
        return features

    def create_uvit_arch(self):
        """Add zero-initialized long-skip linears to the second half of decoder blocks (uvit style)."""
        # create skip linear
        logger.log(
            f'length of vit_decoder.blocks: {len(self.vit_decoder.blocks)}')
        for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2:]:
            blk.skip_linear = nn.Linear(2 * self.vit_decoder.embed_dim,
                                        self.vit_decoder.embed_dim)
            # zero init so the skip path starts as identity-like (no contribution)
            nn.init.constant_(blk.skip_linear.weight, 0)
            if isinstance(blk.skip_linear,
                          nn.Linear) and blk.skip_linear.bias is not None:
                nn.init.constant_(blk.skip_linear.bias, 0)

    def vit_decode_backbone(self, latent, img_size):
        """Run the ViT decoder over the latent tokens."""
        return self.forward_vit_decoder(latent, img_size)  # pred_vit_latent

    def init_weights(self):
        """Initialize (and effectively freeze) pos_embed with a fixed 2D sin-cos embedding."""
        p = self.token_size
        D = self.vit_decoder.pos_embed.shape[-1]
        # the three planes are stacked along one axis of the sin-cos grid
        grid_size = (3 * p, p)
        pos_embed = get_2d_sincos_pos_embed(D,
                                            grid_size).reshape(3 * p * p,
                                                               D)  # H*W, D
        self.vit_decoder.pos_embed.data.copy_(
            torch.from_numpy(pos_embed).float().unsqueeze(0))
        logger.log('init pos_embed with sincos')

    def create_fusion_blks(self, fusion_blk_depth, use_fusion_blk, fusion_blk):
        """Regroup the flat ViT decoder blocks into triplane fusion blocks.

        Blocks before `fusion_blk_start` are kept as plain ViT blocks; the
        rest are grouped `fusion_blk_depth` at a time into `fusion_blk`
        wrappers. Mutates `self.vit_decoder.blocks` in place.
        """
        vit_decoder_blks = self.vit_decoder.blocks
        assert len(vit_decoder_blks) == 12, 'ViT-B by default'
        nh = self.vit_decoder.blocks[0].attn.num_heads
        dim = self.vit_decoder.embed_dim
        fusion_blk_start = self.fusion_blk_start
        triplane_fusion_vit_blks = nn.ModuleList()
        if fusion_blk_start != 0:
            for i in range(0, fusion_blk_start):
                triplane_fusion_vit_blks.append(
                    vit_decoder_blks[i])  # append all vit blocks in the front
        for i in range(fusion_blk_start, len(vit_decoder_blks),
                       fusion_blk_depth):
            vit_blks_group = vit_decoder_blks[i:i +
                                              fusion_blk_depth]  # moduleList
            triplane_fusion_vit_blks.append(
                fusion_blk(vit_blks_group, nh, dim, use_fusion_blk))
        self.vit_decoder.blocks = triplane_fusion_vit_blks

    def triplane_decode(self, latent, c):
        """Render images from triplane latents given camera parameters c."""
        ret_dict = self.triplane_decoder(latent, c)  # triplane latent -> imgs
        ret_dict.update({'latent': latent})
        return ret_dict

    def triplane_renderer(self, latent, coordinates, directions):
        """Run the raw triplane renderer at explicit coordinates/directions."""
        planes = latent.view(len(latent), 3,
                             self.triplane_decoder.decoder_in_chans,
                             latent.shape[-2],
                             latent.shape[-1])  # BS 96 256 256
        ret_dict = self.triplane_decoder.renderer.run_model(
            planes, self.triplane_decoder.decoder, coordinates, directions,
            self.triplane_decoder.rendering_kwargs)  # triplane latent -> imgs
        return ret_dict

    # ! util functions
    def unpatchify_triplane(self, x, p=None, unpatchify_out_chans=None):
        """
        Inverse of patchify for triplane tokens.

        x: (N, L, patch_size**2 * self.out_chans), tokens ordered plane-major
        returns: (N, unpatchify_out_chans*3, H, W) triplane feature maps
        """
        if unpatchify_out_chans is None:
            unpatchify_out_chans = self.unpatchify_out_chans // 3
        if self.cls_token:  # TODO, how to better use cls token
            x = x[:, 1:]
        if p is None:  # assign upsample patch size
            p = self.patch_size
        h = w = int((x.shape[1] // 3)**.5)
        assert h * w * 3 == x.shape[1]
        x = x.reshape(shape=(x.shape[0], 3, h, w, p, p, unpatchify_out_chans))
        x = torch.einsum('ndhwpqc->ndchpwq',
                         x)  # nplanes, C order in the renderer.py
        triplanes = x.reshape(shape=(x.shape[0], unpatchify_out_chans * 3,
                                     h * p, h * p))
        return triplanes

    def interpolate_pos_encoding(self, x, w, h):
        """Return the positional embedding (fixed resolution).

        NOTE(review): the unconditional `return` below makes everything after
        it unreachable, and the dead code references an undefined name
        `pos_embed`; as written this method simply returns
        self.vit_decoder.pos_embed unchanged.
        """
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.vit_decoder.pos_embed.shape[1] - 1  # type: ignore
        # assert npatch == N and w == h
        return self.vit_decoder.pos_embed
        # --- dead code below this point (kept for reference) ---
        class_pos_embed = pos_embed[:, 0]  # type: ignore
        patch_pos_embed = pos_embed[:, 1:]  # type: ignore
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed),
                         dim=1).to(previous_dtype)

    def forward_vit_decoder(self, x, img_size=None):
        """Add positional encoding and run all decoder blocks + final norm.

        x: (N, L, C) tokens from a DINO/CLIP ViT encoder.
        """
        if img_size is None:
            # NOTE(review): self.img_size is never assigned in this class
            # (the assignment in __init__ is commented out) — callers are
            # expected to always pass img_size. TODO confirm.
            img_size = self.img_size
        if self.cls_token:
            x = x + self.vit_decoder.interpolate_pos_encoding(
                x, img_size, img_size)[:, :]  # B, L, C
        else:
            x = x + self.vit_decoder.interpolate_pos_encoding(
                x, img_size, img_size)[:, 1:]  # skip the [cls] position
        for blk in self.vit_decoder.blocks:
            x = blk(x)
        x = self.vit_decoder.norm(x)
        return x

    def unpatchify(self, x, p=None, unpatchify_out_chans=None):
        """
        Inverse of patchify for a single image grid.

        x: (N, L, patch_size**2 * self.out_chans)
        returns: (N, self.out_chans, H, W)
        """
        if unpatchify_out_chans is None:
            unpatchify_out_chans = self.unpatchify_out_chans
        if self.cls_token:  # TODO, how to better use cls token
            x = x[:, 1:]
        if p is None:  # assign upsample patch size
            p = self.patch_size
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]
        x = x.reshape(shape=(x.shape[0], h, w, p, p, unpatchify_out_chans))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], unpatchify_out_chans, h * p,
                                h * p))
        return imgs

    def forward(self, latent, c, img_size):
        """Full forward: ViT decode -> prediction head -> unpatchify -> triplane render."""
        latent = self.forward_vit_decoder(latent, img_size)  # pred_vit_latent
        if self.cls_token:
            cls_token = latent[:, :1]
        else:
            cls_token = None
        # ViT decoder projection, from MAE
        latent = self.decoder_pred(
            latent)  # pred_vit_latent -> patch or original size
        latent = self.unpatchify(
            latent)  # spatial_vit_latent, B, C, H, W (B, 96, 256,256)
        # * triplane rendering
        ret_dict = self.triplane_decoder(planes=latent, c=c)
        ret_dict.update({'latent': latent, 'cls_token': cls_token})
        return ret_dict
class VAE_LDM_V4_vit3D_v3_conv3D_depth2_xformer_mha_PEinit_2d_sincos_uvit_RodinRollOutConv_4x4_lite_mlp_unshuffle_4XC_final(
        ViTTriplaneDecomposed):
    """uvit-style triplane VAE decoder with a 4x4 Rodin-conv SR head (2X channels).

    1. reuse attention proj layer from dino
    2. reuse attention; first self then 3D cross attention

    4*4 SR with 2X channels.
    """

    def __init__(
            self,
            vit_decoder: VisionTransformer,
            triplane_decoder: Triplane,
            cls_token,
            use_fusion_blk=True,
            fusion_blk_depth=2,
            channel_multiplier=4,
            fusion_blk=TriplaneFusionBlockv3,
            **kwargs) -> None:
        super().__init__(
            vit_decoder,
            triplane_decoder,
            cls_token,
            fusion_blk=fusion_blk,  # type: ignore
            use_fusion_blk=use_fusion_blk,
            fusion_blk_depth=fusion_blk_depth,
            channel_multiplier=channel_multiplier,
            # prediction head emits 4x4 patches, channel_multiplier-wide features
            decoder_pred_size=(4 // 1)**2 *
            int(triplane_decoder.out_chans // 3 * channel_multiplier),
            **kwargs)
        patch_size = vit_decoder.patch_embed.patch_size  # type: ignore
        self.reparameterization_soft_clamp = False
        if isinstance(patch_size, tuple):  # dino-v2
            patch_size = patch_size[0]
        # ! todo, hard coded
        # BUG FIX: a trailing comma previously made this a 1-tuple, so the
        # `== -1` fallback below could never fire. (The local is currently
        # not used further down, so behavior is unchanged either way.)
        unpatchify_out_chans = triplane_decoder.out_chans * 1
        if unpatchify_out_chans == -1:
            unpatchify_out_chans = triplane_decoder.out_chans * 3
        # VAE latent width follows the triplane channel count; cf.
        # https://github.com/CompVis/latent-diffusion/blob/e66308c7f2e64cb581c6d27ab6fbeb846828253b/models/first_stage_models/kl-f16/config.yaml
        ldm_z_channels = ldm_embed_dim = triplane_decoder.out_chans
        self.superresolution.update(
            dict(
                after_vit_conv=nn.Conv2d(
                    int(triplane_decoder.out_chans * 2),
                    triplane_decoder.out_chans * 2,  # for vae features
                    3,
                    padding=1),
                quant_conv=torch.nn.Conv2d(2 * ldm_z_channels,
                                           2 * ldm_embed_dim, 1),
                ldm_downsample=nn.Linear(
                    384,
                    self.vae_p * self.vae_p * 3 * self.ldm_z_channels *
                    2,  # 48
                    bias=True),
                ldm_upsample=nn.Linear(self.vae_p * self.vae_p *
                                       self.ldm_z_channels * 1,
                                       vit_decoder.embed_dim,
                                       bias=True),  # ? too high dim upsample
                quant_mlp=Mlp(2 * self.ldm_z_channels,
                              out_features=2 * self.ldm_embed_dim),
                conv_sr=RodinConv3D4X_lite_mlp_as_residual(
                    int(triplane_decoder.out_chans * channel_multiplier),
                    int(triplane_decoder.out_chans * 1))))
        has_token = bool(self.cls_token)
        # fixed 16x16 token grid per plane (+ optional cls slot)
        self.vit_decoder.pos_embed = nn.Parameter(
            torch.zeros(1, 3 * 16 * 16 + has_token, vit_decoder.embed_dim))
        self.init_weights()
        self.reparameterization_soft_clamp = True  # some instability in training VAE
        self.create_uvit_arch()

    def vae_reparameterization(self, latent, sample_posterior):
        """Downsample ViT-encoder latents, build the diagonal-Gaussian posterior, and sample.

        latent: (B, L, C) tokens from the ViT encoder.
        Returns a dict with the sampled latent in both token layout
        ('latent_normalized', B 3L C) and 2D-diffusion layout
        ('latent_normalized_2Ddiffusion', B C' H W), plus posterior stats.
        """
        # ! first downsample for VAE
        latents3D = self.superresolution['ldm_downsample'](latent)  # B L 24
        if self.vae_p > 1:
            latents3D = self.unpatchify3D(
                latents3D,
                p=self.vae_p,
                unpatchify_out_chans=self.ldm_z_channels *
                2)  # B 3 H W unpatchify_out_chans, H=W=16 now
            latents3D = latents3D.reshape(
                latents3D.shape[0], 3, -1, latents3D.shape[-1]
            )  # B 3 H*W C (H=self.vae_p*self.token_size)
        else:
            latents3D = latents3D.reshape(latents3D.shape[0],
                                          latents3D.shape[1], 3,
                                          2 * self.ldm_z_channels)  # B L 3 C
            latents3D = latents3D.permute(0, 2, 1, 3)  # B 3 L C
        # ! do VAE here
        posterior = self.vae_encode(latents3D)  # B self.ldm_z_channels 3 L
        if sample_posterior:
            latent = posterior.sample()
        else:
            latent = posterior.mode()  # B C 3 L
        log_q = posterior.log_p(latent)  # same shape as latent
        # ! for LSGM KL code
        latent_normalized_2Ddiffusion = latent.reshape(
            latent.shape[0], -1, self.token_size * self.vae_p,
            self.token_size * self.vae_p)  # B, 3*4, 16 16
        log_q_2Ddiffusion = log_q.reshape(
            latent.shape[0], -1, self.token_size * self.vae_p,
            self.token_size * self.vae_p)  # B, 3*4, 16 16
        latent = latent.permute(0, 2, 3, 1)  # B C 3 L -> B 3 L C
        latent = latent.reshape(latent.shape[0], -1,
                                latent.shape[-1])  # B 3*L C
        ret_dict = dict(
            normal_entropy=posterior.normal_entropy(),
            latent_normalized=latent,
            latent_normalized_2Ddiffusion=latent_normalized_2Ddiffusion,  #
            log_q_2Ddiffusion=log_q_2Ddiffusion,
            log_q=log_q,
            posterior=posterior,
            latent_name=
            'latent_normalized'  # for which latent to decode; could be modified externally
        )
        return ret_dict

    def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict):
        """Project decoder tokens back to triplane space and 4x super-resolve."""
        if self.cls_token:
            cls_token = latent_from_vit[:, :1]
        else:
            cls_token = None
        # ViT decoder projection, from MAE
        latent = self.decoder_pred(
            latent_from_vit
        )  # pred_vit_latent -> patch or original size; B 768 384
        latent = self.unpatchify_triplane(
            latent,
            p=4,
            unpatchify_out_chans=int(
                self.channel_multiplier * self.unpatchify_out_chans //
                3))  # spatial_vit_latent, B, C, H, W (B, 96*2, 16, 16)
        # 4X SR with Rodin Conv 3D
        latent = self.superresolution['conv_sr'](latent)  # still B 3C H W
        ret_dict.update(dict(cls_token=cls_token, latent_after_vit=latent))
        # include the w_avg for now (placeholder StyleGAN w for the SR module)
        sr_w_code = self.w_avg
        assert sr_w_code is not None
        ret_dict.update(
            dict(sr_w_code=sr_w_code.reshape(1, 1, -1).repeat_interleave(
                latent_from_vit.shape[0], 0), ))  # type: ignore
        return ret_dict

    def forward_vit_decoder(self, x, img_size=None):
        """uvit forward: first-half blocks collect skips, second-half blocks consume them.

        x: (N, L, C) tokens; internally reshaped to (B, 3, L/3, C) per plane.
        """
        if img_size is None:
            img_size = self.img_size
        x = x + self.interpolate_pos_encoding(x, img_size,
                                              img_size)[:, :]  # B, L, C
        B, L, C = x.shape  # has [cls] token in N
        x = x.view(B, 3, L // 3, C)
        skips = [x]
        assert self.fusion_blk_start == 0
        # in blks
        for blk in self.vit_decoder.blocks[0:len(self.vit_decoder.blocks) //
                                           2 - 1]:
            x = blk(x)  # B 3 N C
            skips.append(x)
        # mid blks
        for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2 -
                                           1:len(self.vit_decoder.blocks) //
                                           2]:
            x = blk(x)  # B 3 N C
        # out blks
        for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2:]:
            x = x + blk.skip_linear(torch.cat([x, skips.pop()],
                                              dim=-1))  # long skip connections
            x = blk(x)  # B 3 N C
        x = self.vit_decoder.norm(x)
        # post process shape
        x = x.view(B, L, C)
        return x

    def triplane_decode(self,
                        vit_decode_out,
                        c,
                        return_raw_only=False,
                        **kwargs):
        """Render images from decoded triplane features (dict or raw tensor input)."""
        if isinstance(vit_decode_out, dict):
            latent_after_vit, sr_w_code = (vit_decode_out.get(k, None)
                                           for k in ('latent_after_vit',
                                                     'sr_w_code'))
        else:
            latent_after_vit = vit_decode_out
            sr_w_code = None
            vit_decode_out = dict(latent_normalized=latent_after_vit
                                  )  # for later dict update compatability
        # * triplane rendering
        ret_dict = self.triplane_decoder(latent_after_vit,
                                         c,
                                         ws=sr_w_code,
                                         return_raw_only=return_raw_only,
                                         **kwargs)  # triplane latent -> imgs
        ret_dict.update({
            'latent_after_vit': latent_after_vit,
            **vit_decode_out
        })
        return ret_dict

    def vit_decode_backbone(self, latent, img_size):
        """Upsample VAE latents to ViT token dim and run the decoder backbone."""
        if isinstance(latent, dict):
            if 'latent_normalized' not in latent:
                latent = latent[
                    'latent_normalized_2Ddiffusion']  # B, C*3, H, W
            else:
                latent = latent[
                    'latent_normalized']  # TODO, just for compatability now
        if latent.ndim != 3:  # B 3*4 16 16 -> token layout
            latent = latent.reshape(latent.shape[0], latent.shape[1] // 3, 3,
                                    (self.vae_p * self.token_size)**2).permute(
                                        0, 2, 3, 1)  # B C 3 L => B 3 L C
            latent = latent.reshape(latent.shape[0], -1,
                                    latent.shape[-1])  # B 3*L C
        assert latent.shape == (
            latent.shape[0],
            3 * ((self.vae_p * self.token_size)**2),
            self.ldm_z_channels), f'latent.shape: {latent.shape}'
        latent = self.superresolution['ldm_upsample'](latent)
        return super().vit_decode_backbone(
            latent, img_size)  # torch.Size([8, 3072, 768])
class RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn( | |
ViTTriplaneDecomposed): | |
# lite version, no sd-bg, use TriplaneFusionBlockv4_nested_init_from_dino | |
    def __init__(
            self,
            vit_decoder: VisionTransformer,
            triplane_decoder: Triplane_fg_bg_plane,
            cls_token,
            use_fusion_blk=True,
            fusion_blk_depth=2,
            fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino,
            channel_multiplier=4,
            ldm_z_channels=4,  # latent-diffusion z channels
            ldm_embed_dim=4,
            vae_p=2,  # VAE patch factor
            **kwargs) -> None:
        """Lite SR variant: no sd-bg, DINO-initialized 3D-attention fusion blocks.

        Registers the latent-VAE modules (downsample/upsample/quant) and a
        lite Rodin 4X SR conv, then builds the uvit skip architecture.
        """
        super().__init__(
            vit_decoder,
            triplane_decoder,
            cls_token,
            channel_multiplier=channel_multiplier,
            use_fusion_blk=use_fusion_blk,
            fusion_blk_depth=fusion_blk_depth,
            fusion_blk=fusion_blk,
            ldm_z_channels=ldm_z_channels,
            ldm_embed_dim=ldm_embed_dim,
            vae_p=vae_p,
            # prediction head emits 4x4 patches, channel_multiplier-wide features
            decoder_pred_size=(4 // 1)**2 *
            int(triplane_decoder.out_chans // 3 * channel_multiplier),
            **kwargs)
        logger.log(
            f'length of vit_decoder.blocks: {len(self.vit_decoder.blocks)}')
        # latent vae modules
        self.superresolution.update(
            dict(
                ldm_downsample=nn.Linear(
                    384,
                    self.vae_p * self.vae_p * 3 * self.ldm_z_channels *
                    2,  # 48
                    bias=True),
                # grouped patch-embed acts as the latent upsampler here
                ldm_upsample=PatchEmbedTriplane(
                    self.vae_p * self.token_size,
                    self.vae_p,
                    3 * self.ldm_embed_dim,  # B 3 L C
                    vit_decoder.embed_dim,
                    bias=True),
                quant_conv=nn.Conv2d(2 * 3 * self.ldm_z_channels,
                                     2 * self.ldm_embed_dim * 3,
                                     kernel_size=1,
                                     groups=3),
                conv_sr=RodinConv3D4X_lite_mlp_as_residual_lite(
                    int(triplane_decoder.out_chans * channel_multiplier),
                    int(triplane_decoder.out_chans * 1))))
        # ! initialize weights
        self.init_weights()
        self.reparameterization_soft_clamp = True  # some instability in training VAE
        # adds zero-initialized skip linears to the second half of decoder blocks
        self.create_uvit_arch()
def vit_decode(self, latent, img_size, sample_posterior=True): | |
ret_dict = self.vae_reparameterization(latent, sample_posterior) | |
# latent = ret_dict['latent_normalized'] | |
latent = self.vit_decode_backbone(ret_dict, img_size) | |
return self.vit_decode_postprocess(latent, ret_dict) | |
# # ! merge? | |
def unpatchify3D(self, x, p, unpatchify_out_chans, plane_n=3): | |
""" | |
x: (N, L, patch_size**2 * self.out_chans) | |
return: 3D latents | |
""" | |
if self.cls_token: # TODO, how to better use cls token | |
x = x[:, 1:] | |
h = w = int(x.shape[1]**.5) | |
assert h * w == x.shape[1] | |
x = x.reshape(shape=(x.shape[0], h, w, p, p, plane_n, | |
unpatchify_out_chans)) | |
x = torch.einsum( | |
'nhwpqdc->ndhpwqc', x | |
) # nplanes, C little endian tradiition, as defined in the renderer.py | |
latents3D = x.reshape(shape=(x.shape[0], plane_n, h * p, h * p, | |
unpatchify_out_chans)) | |
return latents3D | |
# ! merge? | |
def vae_encode(self, h): | |
# * smooth convolution before triplane | |
# h = self.superresolution['after_vit_conv'](h) | |
# h = h.permute(0, 2, 3, 1) # B 64 64 6 | |
B, _, H, W = h.shape | |
moments = self.superresolution['quant_conv'](h) | |
moments = moments.reshape( | |
B, | |
# moments.shape[1] // 3, | |
moments.shape[1] // self.plane_n, | |
# 3, | |
self.plane_n, | |
H, | |
W, | |
) # B C 3 H W | |
moments = moments.flatten(-2) # B C 3 L | |
posterior = DiagonalGaussianDistribution( | |
moments, soft_clamp=self.reparameterization_soft_clamp) | |
return posterior | |
    def vae_reparameterization(self, latent, sample_posterior):
        """Downsample ViT-encoder latents, build the posterior, and sample.

        latent: (B, L, C) tokens from the ViT encoder (e.g. B 256 384).
        sample_posterior: sample() if True, else mode().
        Returns a dict with the sampled latent in token layout
        ('latent_normalized', B 3L C) and 2D-diffusion layout
        ('latent_normalized_2Ddiffusion', B C' H W), plus posterior stats.
        """
        # ! first downsample for VAE
        latents3D = self.superresolution['ldm_downsample'](
            latent)  # latents3D: B 256 96
        assert self.vae_p > 1
        latents3D = self.unpatchify3D(
            latents3D,
            p=self.vae_p,
            unpatchify_out_chans=self.ldm_z_channels *
            2)  # B 3 H W unpatchify_out_chans, H=W=16 now
        B, _, H, W, C = latents3D.shape
        # channel-first, then merge (plane, C) so quant_conv's groups=3 sees
        # each plane's channels contiguously: B 3C H W
        latents3D = latents3D.permute(0, 1, 4, 2, 3).reshape(B, -1, H,
                                                             W)  # B 3C H W
        # ! do VAE here
        posterior = self.vae_encode(latents3D)  # B self.ldm_z_channels 3 L
        if sample_posterior:
            latent = posterior.sample()
        else:
            latent = posterior.mode()  # B C 3 L
        log_q = posterior.log_p(latent)  # same shape as latent
        # ! for LSGM KL code
        latent_normalized_2Ddiffusion = latent.reshape(
            latent.shape[0], -1, self.token_size * self.vae_p,
            self.token_size * self.vae_p)  # B, 3*4, 16 16
        log_q_2Ddiffusion = log_q.reshape(
            latent.shape[0], -1, self.token_size * self.vae_p,
            self.token_size * self.vae_p)  # B, 3*4, 16 16
        latent = latent.permute(0, 2, 3, 1)  # B C 3 L -> B 3 L C
        latent = latent.reshape(latent.shape[0], -1,
                                latent.shape[-1])  # B 3*L C
        ret_dict = dict(
            normal_entropy=posterior.normal_entropy(),
            latent_normalized=latent,
            latent_normalized_2Ddiffusion=latent_normalized_2Ddiffusion,  #
            log_q_2Ddiffusion=log_q_2Ddiffusion,
            log_q=log_q,
            posterior=posterior,
        )
        return ret_dict
def vit_decode_backbone(self, latent, img_size): | |
# assert x.ndim == 3 # N L C | |
if isinstance(latent, dict): | |
latent = latent['latent_normalized_2Ddiffusion'] # B, C*3, H, W | |
# assert latent.shape == ( | |
# latent.shape[0], 3 * (self.token_size * self.vae_p)**2, | |
# self.ldm_z_channels), f'latent.shape: {latent.shape}' | |
# st() # latent: B 12 32 32 | |
latent = self.superresolution['ldm_upsample']( # ! B 768 (3*256) 768 | |
latent) # torch.Size([8, 12, 32, 32]) => torch.Size([8, 256, 768]) | |
# latent: torch.Size([8, 768, 768]) | |
# ! directly feed to vit_decoder | |
return self.forward_vit_decoder(latent, img_size) # pred_vit_latent | |
def triplane_decode(self, | |
vit_decode_out, | |
c, | |
return_raw_only=False, | |
**kwargs): | |
if isinstance(vit_decode_out, dict): | |
latent_after_vit, sr_w_code = (vit_decode_out.get(k, None) | |
for k in ('latent_after_vit', | |
'sr_w_code')) | |
else: | |
latent_after_vit = vit_decode_out | |
sr_w_code = None | |
vit_decode_out = dict(latent_normalized=latent_after_vit | |
) # for later dict update compatability | |
# * triplane rendering | |
ret_dict = self.triplane_decoder(latent_after_vit, | |
c, | |
ws=sr_w_code, | |
return_raw_only=return_raw_only, | |
**kwargs) # triplane latent -> imgs | |
ret_dict.update({ | |
'latent_after_vit': latent_after_vit, | |
**vit_decode_out | |
}) | |
return ret_dict | |
    def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict):
        """Turn ViT-decoder tokens into spatial triplane features.

        Splits off the [cls] token (when present), projects tokens with the
        MAE-style ``decoder_pred`` head, un-patchifies them into B C H W
        planes, and applies the Rodin conv-3D super-resolution.  Results
        (plus a per-sample ``sr_w_code`` broadcast from ``self.w_avg``) are
        merged into ``ret_dict``, which is also returned.
        """
        if self.cls_token:
            cls_token = latent_from_vit[:, :1]
        else:
            cls_token = None

        # ViT decoder projection, from MAE
        latent = self.decoder_pred(
            latent_from_vit
        )  # pred_vit_latent -> patch or original size; B 768 384

        latent = self.unpatchify_triplane(
            latent,
            p=4,
            unpatchify_out_chans=int(
                self.channel_multiplier * self.unpatchify_out_chans //
                3))  # spatial_vit_latent, B, C, H, W (B, 96*2, 16, 16)

        # 4X SR with Rodin Conv 3D
        latent = self.superresolution['conv_sr'](latent)  # still B 3C H W

        ret_dict.update(dict(cls_token=cls_token, latent_after_vit=latent))

        # include the w_avg for now: one copy of the average w code per
        # batch element, shaped (B, 1, w_dim)
        sr_w_code = self.w_avg
        assert sr_w_code is not None
        ret_dict.update(
            dict(sr_w_code=sr_w_code.reshape(1, 1, -1).repeat_interleave(
                latent_from_vit.shape[0], 0), ))  # type: ignore

        return ret_dict
    def forward_vit_decoder(self, x, img_size=None):
        """Run the fused ViT decoder with U-Net-style long skip connections.

        ``x`` is a (B, L, C) token sequence (DINO/CLIP ViT latents).  After
        positional encoding the sequence is viewed as (B, 3, L//3, C) so
        each triplane plane is its own sub-sequence.  The first
        ``len(blocks)//2 - 1`` blocks record skip activations, one middle
        block bridges, and the remaining blocks consume the recorded skips
        through their ``skip_linear`` before attention.  Output is reshaped
        back to (B, L, C).
        """
        # latent: (N, L, C) from DINO/CLIP ViT encoder
        # * also dino ViT
        # add positional encoding to each token
        if img_size is None:
            img_size = self.img_size

        x = x + self.interpolate_pos_encoding(x, img_size,
                                              img_size)[:, :]  # B, L, C

        B, L, C = x.shape  # has [cls] token in N
        x = x.view(B, 3, L // 3, C)

        # the input itself is the outermost skip; with fusion_blk_start == 0
        # the skip count matches the number of "out" blocks below
        skips = [x]
        assert self.fusion_blk_start == 0

        # in blks
        for blk in self.vit_decoder.blocks[0:len(self.vit_decoder.blocks) //
                                           2 - 1]:
            x = blk(x)  # B 3 N C
            skips.append(x)

        # mid blks (a single bridging block; no skip recorded)
        for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2 -
                                           1:len(self.vit_decoder.blocks) //
                                           2]:
            x = blk(x)  # B 3 N C

        # out blks: fuse the popped skip via skip_linear, then run the block
        for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2:]:
            x = x + blk.skip_linear(torch.cat([x, skips.pop()],
                                              dim=-1))  # long skip connections
            x = blk(x)  # B 3 N C

        x = self.vit_decoder.norm(x)

        # post process shape
        x = x.view(B, L, C)
        return x
# ! SD version | |
class RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn_SD(
        RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn):
    """SD(-VAE)-flavoured variant of the Rodin SR model.

    Differences from the parent: the ``ldm_downsample`` module is removed
    from ``self.superresolution``, and ``vae_reparameterization`` packages
    the sampled latent as a (B, 3*L, C) token sequence for the ViT decoder.
    """

    def __init__(self,
                 vit_decoder: VisionTransformer,
                 triplane_decoder: Triplane_fg_bg_plane,
                 cls_token,
                 normalize_feat=True,
                 sr_ratio=2,
                 use_fusion_blk=True,
                 fusion_blk_depth=2,
                 fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino,
                 channel_multiplier=4,
                 **kwargs) -> None:
        # NOTE(review): normalize_feat and sr_ratio are accepted for
        # signature compatibility but not forwarded to the parent
        # (sr_ratio is explicitly left out as unused).
        super().__init__(vit_decoder,
                         triplane_decoder,
                         cls_token,
                         # sr_ratio=sr_ratio, # not used
                         use_fusion_blk=use_fusion_blk,
                         fusion_blk_depth=fusion_blk_depth,
                         fusion_blk=fusion_blk,
                         channel_multiplier=channel_multiplier,
                         **kwargs)

        # the SD path does not use the LDM downsampler
        for k in [
                'ldm_downsample',
                # 'conv_sr'
        ]:
            del self.superresolution[k]

    def vae_reparameterization(self, latent, sample_posterior):
        """Encode ``latent`` to the VAE posterior, sample (or take the
        mode), and return all latent layouts used downstream.

        Returns a dict with:
          - ``latent_normalized``: (B, 3*L, C) tokens for the ViT decoder
          - ``latent_normalized_2Ddiffusion`` / ``log_q_2Ddiffusion``: the
            same values reshaped to (B, C', H, W) with
            H = W = token_size * vae_p, for the LSGM-KL / 2D-diffusion path
          - ``log_q``, ``posterior``, ``normal_entropy``: posterior stats
        """
        # latent: B 24 32 32
        assert self.vae_p > 1

        # ! do VAE here
        posterior = self.vae_encode(latent)  # B self.ldm_z_channels 3 L

        if sample_posterior:
            latent = posterior.sample()
        else:
            latent = posterior.mode()  # B C 3 L

        log_q = posterior.log_p(latent)  # same shape as latent

        # ! for LSGM KL code
        latent_normalized_2Ddiffusion = latent.reshape(
            latent.shape[0], -1, self.token_size * self.vae_p,
            self.token_size * self.vae_p)  # B, 3*4, 16 16
        log_q_2Ddiffusion = log_q.reshape(
            latent.shape[0], -1, self.token_size * self.vae_p,
            self.token_size * self.vae_p)  # B, 3*4, 16 16

        latent = latent.permute(0, 2, 3, 1)  # B C 3 L -> B 3 L C
        latent = latent.reshape(latent.shape[0], -1,
                                latent.shape[-1])  # B 3*L C

        ret_dict = dict(
            normal_entropy=posterior.normal_entropy(),
            latent_normalized=latent,
            latent_normalized_2Ddiffusion=latent_normalized_2Ddiffusion,  #
            log_q_2Ddiffusion=log_q_2Ddiffusion,
            log_q=log_q,
            posterior=posterior,
        )
        return ret_dict
class RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn_SD_D(
        RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn_SD):
    """SD variant with a Stable-Diffusion-style ``Decoder`` as the
    super-resolution stage.

    ``decoder_pred`` is dropped; ViT tokens are un-flattened into per-plane
    feature maps and fed straight into the SD decoder.
    """

    def __init__(self,
                 vit_decoder: VisionTransformer,
                 triplane_decoder: Triplane_fg_bg_plane,
                 cls_token,
                 normalize_feat=True,
                 sr_ratio=2,
                 use_fusion_blk=True,
                 fusion_blk_depth=2,
                 fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino,
                 channel_multiplier=4,
                 **kwargs) -> None:
        super().__init__(vit_decoder, triplane_decoder, cls_token,
                         normalize_feat, sr_ratio, use_fusion_blk,
                         fusion_blk_depth, fusion_blk, channel_multiplier,
                         **kwargs)

        self.decoder_pred = None  # directly un-patchembed

        # SD decoder acting as the deconv / super-resolution stage;
        # consumes vit_decoder.embed_dim channels per plane.
        self.superresolution.update(
            dict(conv_sr=Decoder(  # serve as Deconv
                resolution=128,
                in_channels=3,
                # ch=64,
                ch=32,
                ch_mult=[1, 2, 2, 4],
                # num_res_blocks=2,
                # ch_mult=[1,2,4],
                num_res_blocks=1,
                dropout=0.0,
                attn_resolutions=[],
                out_ch=32,
                # z_channels=vit_decoder.embed_dim//4,
                z_channels=vit_decoder.embed_dim,
            )))

    # ''' # for SD Decoder, verify encoder first
    def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict):
        """Un-flatten ViT tokens into plane maps, run the SD decoder, and
        merge the result into ``ret_dict``.

        Output ``latent_after_vit`` is laid out B (3*C) H W.
        """
        if self.cls_token:
            cls_token = latent_from_vit[:, :1]
        else:
            cls_token = None

        def unflatten_token(x, p=None):
            # (B, 3*N, C) tokens -> (B*3, C, H, W) plane feature maps;
            # with p given, each token expands into a p x p patch.
            B, L, C = x.shape
            x = x.reshape(B, 3, L // 3, C)
            if self.cls_token:  # TODO, how to better use cls token
                x = x[:, :, 1:]  # B 3 256 C
            h = w = int((x.shape[2])**.5)
            assert h * w == x.shape[2]

            if p is None:
                x = x.reshape(shape=(B, 3, h, w, -1))
                x = rearrange(
                    x, 'b n h w c->(b n) c h w'
                )  # merge plane into Batch and prepare for rendering
            else:
                x = x.reshape(shape=(B, 3, h, w, p, p, -1))
                x = rearrange(
                    x, 'b n h w p1 p2 c->(b n) c (h p1) (w p2)'
                )  # merge plane into Batch and prepare for rendering
            return x

        latent = unflatten_token(latent_from_vit)
        # latent = unflatten_token(latent_from_vit, p=2)

        # ! SD SR
        latent = self.superresolution['conv_sr'](latent)  # still B 3C H W
        latent = rearrange(latent, '(b n) c h w->b (n c) h w', n=3)

        ret_dict.update(dict(cls_token=cls_token, latent_after_vit=latent))
        # NOTE(review): unlike the parent, no sr_w_code is emitted here
        # (the w_avg broadcast is intentionally commented out upstream).
        return ret_dict

    # '''
# ''' | |
class RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_lite3DAttn(
        RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn_SD):
    """'lite' 3D-attention variant.

    Token layout changes relative to the parent:
    1. plane tokens travel as B L (3*C//3) through the ViT decoder,
    2. fusion blocks are built per-plane with dim = embed_dim // 3,
    3. the decoder head consumes C//3-dim per-plane tokens.
    """

    def __init__(self,
                 vit_decoder: VisionTransformer,
                 triplane_decoder: Triplane_fg_bg_plane,
                 cls_token,
                 normalize_feat=True,
                 sr_ratio=2,
                 use_fusion_blk=True,
                 fusion_blk_depth=2,
                 fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino_lite,
                 channel_multiplier=4,
                 **kwargs) -> None:
        super().__init__(vit_decoder, triplane_decoder, cls_token,
                         normalize_feat, sr_ratio, use_fusion_blk,
                         fusion_blk_depth, fusion_blk, channel_multiplier,
                         **kwargs)

        # 1. convert output plane token to B L 3 C//3 shape
        # 2. change vit decoder fusion arch (fusion block)
        # 3. output follow B L 3 C//3 with decoder input dim C//3
        # TODO: ablate basic decoder design, on the metrics (input/novelview both)
        self.decoder_pred = nn.Linear(self.vit_decoder.embed_dim // 3,
                                      2048,
                                      bias=True)  # decoder to patch

        # per-plane patch embedding replaces the default upsampler
        self.superresolution.update(
            dict(ldm_upsample=PatchEmbedTriplaneRodin(
                self.vae_p * self.token_size,
                self.vae_p,
                3 * self.ldm_embed_dim,  # B 3 L C
                vit_decoder.embed_dim // 3,
                bias=True)))

        # ! original pos_embed, sized for a 16x16 token grid (+cls)
        has_token = bool(self.cls_token)
        self.vit_decoder.pos_embed = nn.Parameter(
            torch.zeros(1, 16 * 16 + has_token, vit_decoder.embed_dim))

    def forward(self, latent, c, img_size):
        """Decode ``latent`` with the ViT pipeline and render for camera ``c``."""
        latent_normalized = self.vit_decode(latent, img_size)
        return self.triplane_decode(latent_normalized, c)

    def vae_reparameterization(self, latent, sample_posterior):
        """VAE-encode and sample; pack the latent as (B, L, C*3) tokens
        (plane channels concatenated last) plus the 2D-diffusion layouts.
        """
        # latent: B 24 32 32
        assert self.vae_p > 1

        # ! do VAE here
        posterior = self.vae_encode(latent)  # B self.ldm_z_channels 3 L
        if sample_posterior:
            latent = posterior.sample()
        else:
            latent = posterior.mode()  # B C 3 L
        log_q = posterior.log_p(latent)  # same shape as latent

        # ! for LSGM KL code
        latent_normalized_2Ddiffusion = latent.reshape(
            latent.shape[0], -1, self.token_size * self.vae_p,
            self.token_size * self.vae_p)  # B, 3*4, 16 16
        log_q_2Ddiffusion = log_q.reshape(
            latent.shape[0], -1, self.token_size * self.vae_p,
            self.token_size * self.vae_p)  # B, 3*4, 16 16

        # TODO, add a conv_after_quant
        # ! reshape for ViT decoder
        latent = latent.permute(0, 3, 1, 2)  # B C 3 L -> B L C 3
        latent = latent.reshape(*latent.shape[:2], -1)  # B L C3

        ret_dict = dict(
            normal_entropy=posterior.normal_entropy(),
            latent_normalized=latent,
            latent_normalized_2Ddiffusion=latent_normalized_2Ddiffusion,  #
            log_q_2Ddiffusion=log_q_2Ddiffusion,
            log_q=log_q,
            posterior=posterior,
        )
        return ret_dict

    def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict):
        """Split fused tokens back into per-plane C//3 features, project,
        un-patchify and super-resolve; merge results into ``ret_dict``."""
        if self.cls_token:
            cls_token = latent_from_vit[:, :1]
        else:
            cls_token = None

        B, N, C = latent_from_vit.shape
        latent_from_vit = latent_from_vit.reshape(B, N, C // 3, 3).permute(
            0, 3, 1, 2)  # -> B 3 N C//3

        # ! remaining unchanged
        # ViT decoder projection, from MAE
        latent = self.decoder_pred(
            latent_from_vit
        )  # pred_vit_latent -> patch or original size; B 768 384

        latent = latent.reshape(B, 3 * N, -1)  # B L C

        latent = self.unpatchify_triplane(
            latent,
            p=4,
            unpatchify_out_chans=int(
                self.channel_multiplier * self.unpatchify_out_chans //
                3))  # spatial_vit_latent, B, C, H, W (B, 96*2, 16, 16)

        # 4X SR with Rodin Conv 3D
        latent = self.superresolution['conv_sr'](latent)  # still B 3C H W

        ret_dict.update(dict(cls_token=cls_token, latent_after_vit=latent))

        # include the w_avg for now, broadcast per batch element
        sr_w_code = self.w_avg
        assert sr_w_code is not None
        ret_dict.update(
            dict(sr_w_code=sr_w_code.reshape(1, 1, -1).repeat_interleave(
                latent_from_vit.shape[0], 0), ))  # type: ignore

        return ret_dict

    def vit_decode_backbone(self, latent, img_size):
        """Patch-embed the diffusion latent per plane and fold the three
        plane token streams into a single B HW (C*3) sequence for the ViT."""
        if isinstance(latent, dict):
            latent = latent['latent_normalized_2Ddiffusion']  # B, C*3, H, W

        latent = self.superresolution['ldm_upsample'](  # ! B 768 (3*256) 768
            latent)  # torch.Size([8, 12, 32, 32]) => torch.Size([8, 256, 768])

        B, N3, C = latent.shape
        latent = latent.reshape(B, 3, N3 // 3,
                                C).permute(0, 2, 3, 1)  # B 3HW C -> B HW C 3
        latent = latent.reshape(*latent.shape[:2], -1)  # B HW C3

        # ! directly feed to vit_decoder
        return self.forward_vit_decoder(latent, img_size)  # pred_vit_latent

    def forward_vit_decoder(self, x, img_size=None):
        """Same U-ViT skip scheme as the parent, but tokens stay flat
        (B, L, C) — no per-plane reshape before the blocks."""
        # latent: (N, L, C) from DINO/CLIP ViT encoder
        # add positional encoding to each token
        if img_size is None:
            img_size = self.img_size

        x = x + self.interpolate_pos_encoding(x, img_size,
                                              img_size)[:, :]  # B, L, C

        B, L, C = x.shape  # has [cls] token in N
        # ! no need to reshape here
        # x = x.view(B, 3, L // 3, C)

        skips = [x]
        assert self.fusion_blk_start == 0

        # in blks
        for blk in self.vit_decoder.blocks[0:len(self.vit_decoder.blocks) //
                                           2 - 1]:
            x = blk(x)  # B 3 N C
            skips.append(x)

        # mid blks (single bridging block)
        for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2 -
                                           1:len(self.vit_decoder.blocks) //
                                           2]:
            x = blk(x)  # B 3 N C

        # out blks with long skip connections
        for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2:]:
            x = x + blk.skip_linear(torch.cat([x, skips.pop()],
                                              dim=-1))  # long skip connections
            x = blk(x)  # B 3 N C

        x = self.vit_decoder.norm(x)

        # post process shape
        x = x.view(B, L, C)
        return x

    def create_fusion_blks(self, fusion_blk_depth, use_fusion_blk, fusion_blk):
        """Group the plain ViT blocks into triplane fusion blocks.

        Heads and dim are divided by 3 so each plane gets a lighter,
        separate attention budget.
        """
        vit_decoder_blks = self.vit_decoder.blocks
        assert len(vit_decoder_blks) == 12, 'ViT-B by default'

        nh = self.vit_decoder.blocks[
            0].attn.num_heads // 3  # ! lighter, actually divisible by 4
        dim = self.vit_decoder.embed_dim // 3  # ! separate

        fusion_blk_start = self.fusion_blk_start
        triplane_fusion_vit_blks = nn.ModuleList()

        if fusion_blk_start != 0:
            for i in range(0, fusion_blk_start):
                triplane_fusion_vit_blks.append(
                    vit_decoder_blks[i])  # append all vit blocks in the front

        for i in range(fusion_blk_start, len(vit_decoder_blks),
                       fusion_blk_depth):
            vit_blks_group = vit_decoder_blks[i:i +
                                              fusion_blk_depth]  # moduleList
            triplane_fusion_vit_blks.append(
                # TriplaneFusionBlockv2(vit_blks_group, nh, dim, use_fusion_blk))
                fusion_blk(vit_blks_group, nh, dim, use_fusion_blk))

        self.vit_decoder.blocks = triplane_fusion_vit_blks
# default for objaverse | |
class RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout_withSD_D_ditDecoder_S(
        RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn):
    """DiT-decoder variant (default for Objaverse).

    The ViT decoder is used as-is (no fusion-block rewrap, no skip_linear);
    the SD ``Decoder`` performs super-resolution, optionally on a rolled-out
    (planes concatenated along width) layout controlled by
    ``self.D_roll_out_input``.
    """

    def __init__(
            self,
            vit_decoder: VisionTransformer,
            triplane_decoder: Triplane_fg_bg_plane,
            cls_token,
            normalize_feat=True,
            sr_ratio=2,
            use_fusion_blk=True,
            fusion_blk_depth=2,
            fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout,
            channel_multiplier=4,
            **kwargs) -> None:
        super().__init__(
            vit_decoder,
            triplane_decoder,
            cls_token,
            use_fusion_blk=use_fusion_blk,
            fusion_blk_depth=fusion_blk_depth,
            fusion_blk=fusion_blk,
            channel_multiplier=channel_multiplier,
            patch_size=-1,  # placeholder, since we use dit here
            token_size=2,
            **kwargs)

        self.D_roll_out_input = False

        # DiT path does not use the LDM downsampler
        for k in [
                'ldm_downsample',
                # 'conv_sr'
        ]:
            del self.superresolution[k]

        self.decoder_pred = None  # directly un-patchembed

        # SD decoder acting as the deconv / super-resolution stage
        self.superresolution.update(
            dict(
                conv_sr=Decoder(  # serve as Deconv
                    resolution=128,
                    # resolution=256,
                    in_channels=3,
                    # ch=64,
                    ch=32,
                    # ch=16,
                    ch_mult=[1, 2, 2, 4],
                    # ch_mult=[1, 1, 2, 2],
                    # num_res_blocks=2,
                    # ch_mult=[1,2,4],
                    # num_res_blocks=0,
                    num_res_blocks=1,
                    dropout=0.0,
                    attn_resolutions=[],
                    out_ch=32,
                    # z_channels=vit_decoder.embed_dim//4,
                    z_channels=vit_decoder.embed_dim,
                    # z_channels=vit_decoder.embed_dim//2,
                ),
                # after_vit_upsampler=Upsample2D(channels=vit_decoder.embed_dim,use_conv=True, use_conv_transpose=False, out_channels=vit_decoder.embed_dim//2)
            ))

        # del skip_lienar: the DiT decoder has no long skip connections
        for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2:]:
            del blk.skip_linear

    def forward_points(self,
                       planes,
                       points: torch.Tensor,
                       chunk_size: int = 2**16):
        """Query the triplane decoder at arbitrary 3D points, chunked to
        bound peak memory.

        planes: (N, 3, D', H', W') (or (N, 3*D', H', W'), reshaped below)
        points: (N, P, 3)
        Returns a dict of per-point features, each (N, P, ...).
        """
        N, P = points.shape[:2]
        if planes.ndim == 4:
            planes = planes.reshape(
                len(planes),
                3,
                -1,  # ! support background plane
                planes.shape[-2],
                planes.shape[-1])  # BS 96 256 256

        # query triplane in chunks of chunk_size points
        outs = []
        for i in trange(0, points.shape[1], chunk_size):
            chunk_points = points[:, i:i + chunk_size]

            # query triplane
            chunk_out = self.triplane_decoder.renderer._run_model(  # type: ignore
                planes=planes,
                decoder=self.triplane_decoder.decoder,
                sample_coordinates=chunk_points,
                sample_directions=torch.zeros_like(chunk_points),
                options=self.rendering_kwargs,
            )

            outs.append(chunk_out)
            # free GPU scratch between chunks to keep peak memory low
            torch.cuda.empty_cache()

        # concatenate the chunk outputs along the point dimension
        point_features = {
            k: torch.cat([out[k] for out in outs], dim=1)
            for k in outs[0].keys()
        }

        return point_features

    def triplane_decode_grid(self,
                             vit_decode_out,
                             grid_size,
                             aabb: torch.Tensor = None,
                             **kwargs):
        """Sample the decoded triplane on a regular 3D grid.

        vit_decode_out: dict with 'latent_after_vit' planes (N, 3, D', H', W')
        grid_size: number of samples per axis
        aabb: optional (N, 2, 3) min/max bounds; defaults come from
              rendering_kwargs (sampler bbox, or eg3d-style box_warp).
        Returns a dict of features reshaped to (N, G, G, G, -1).
        """
        assert isinstance(vit_decode_out, dict)
        planes = vit_decode_out['latent_after_vit']

        # aabb: (N, 2, 3)
        if aabb is None:
            if 'sampler_bbox_min' in self.rendering_kwargs:
                aabb = torch.tensor([
                    [self.rendering_kwargs['sampler_bbox_min']] * 3,
                    [self.rendering_kwargs['sampler_bbox_max']] * 3,
                ],
                                    device=planes.device,
                                    dtype=planes.dtype).unsqueeze(0).repeat(
                                        planes.shape[0], 1, 1)
            else:  # shapenet dataset, follow eg3d
                aabb = torch.tensor(
                    [  # https://github.com/NVlabs/eg3d/blob/7cf1fd1e99e1061e8b6ba850f91c94fe56e7afe4/eg3d/gen_samples.py#L188
                        [-self.rendering_kwargs['box_warp'] / 2] * 3,
                        [self.rendering_kwargs['box_warp'] / 2] * 3,
                    ],
                    device=planes.device,
                    dtype=planes.dtype).unsqueeze(0).repeat(
                        planes.shape[0], 1, 1)

        assert planes.shape[0] == aabb.shape[
            0], "Batch size mismatch for planes and aabb"

        N = planes.shape[0]

        # create grid points for triplane query, one (G^3, 3) grid per sample
        grid_points = []
        for i in range(N):
            grid_points.append(
                torch.stack(torch.meshgrid(
                    torch.linspace(aabb[i, 0, 0],
                                   aabb[i, 1, 0],
                                   grid_size,
                                   device=planes.device),
                    torch.linspace(aabb[i, 0, 1],
                                   aabb[i, 1, 1],
                                   grid_size,
                                   device=planes.device),
                    torch.linspace(aabb[i, 0, 2],
                                   aabb[i, 1, 2],
                                   grid_size,
                                   device=planes.device),
                    indexing='ij',
                ),
                            dim=-1).reshape(-1, 3))
        cube_grid = torch.stack(grid_points, dim=0).to(
            planes.device)  # (N, grid_size**3, 3)

        features = self.forward_points(planes, cube_grid)

        # reshape flat point features back into the (G, G, G) grid
        features = {
            k: v.reshape(N, grid_size, grid_size, grid_size, -1)
            for k, v in features.items()
        }

        return features

    def create_fusion_blks(self, fusion_blk_depth, use_fusion_blk, fusion_blk):
        # no need to fuse anymore: the DiT decoder consumes tokens directly
        pass

    def forward_vit_decoder(self, x, img_size=None):
        """Run the (DiT-style) ViT decoder directly; img_size is unused."""
        return self.vit_decoder(x)

    def vit_decode_backbone(self, latent, img_size):
        """Upsample the diffusion-space latent into tokens and run the
        DiT decoder."""
        if isinstance(latent, dict):
            latent = latent['latent_normalized_2Ddiffusion']  # B, C*3, H, W

        latent = self.superresolution['ldm_upsample'](  # ! B 768 (3*256) 768
            latent)  # torch.Size([8, 12, 32, 32]) => torch.Size([8, 256, 768])

        # ! directly feed to vit_decoder
        return self.forward_vit_decoder(latent, img_size)  # pred_vit_latent

    def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict):
        """Un-flatten tokens into plane maps (batched or rolled-out along
        width per ``self.D_roll_out_input``), run the SD decoder, and fold
        the planes back into a B (3*C) H W tensor in ``ret_dict``."""
        if self.cls_token:
            cls_token = latent_from_vit[:, :1]
        else:
            cls_token = None

        def unflatten_token(x, p=None):
            # (B, 3*N, C) tokens -> plane maps; rolled-out layout keeps the
            # batch dim and concatenates planes along width instead.
            B, L, C = x.shape
            x = x.reshape(B, 3, L // 3, C)
            if self.cls_token:  # TODO, how to better use cls token
                x = x[:, :, 1:]  # B 3 256 C
            h = w = int((x.shape[2])**.5)
            assert h * w == x.shape[2]

            if p is None:
                x = x.reshape(shape=(B, 3, h, w, -1))
                if not self.D_roll_out_input:
                    x = rearrange(
                        x, 'b n h w c->(b n) c h w'
                    )  # merge plane into Batch and prepare for rendering
                else:
                    x = rearrange(
                        x, 'b n h w c->b c h (n w)'
                    )  # merge plane into Batch and prepare for rendering
            else:
                x = x.reshape(shape=(B, 3, h, w, p, p, -1))
                if self.D_roll_out_input:
                    x = rearrange(
                        x, 'b n h w p1 p2 c->b c (h p1) (n w p2)'
                    )  # merge plane into Batch and prepare for rendering
                else:
                    x = rearrange(
                        x, 'b n h w p1 p2 c->(b n) c (h p1) (w p2)'
                    )  # merge plane into Batch and prepare for rendering
            return x

        latent = unflatten_token(
            latent_from_vit)  # B 3 h w vit_decoder.embed_dim

        # ! x2 upsapmle, 16 -32 before sending into SD Decoder
        # latent = self.superresolution['after_vit_upsampler'](latent) # B*3 192 32 32
        # latent = unflatten_token(latent_from_vit, p=2)

        # ! SD SR
        latent = self.superresolution['conv_sr'](latent)  # still B 3C H W

        if not self.D_roll_out_input:
            latent = rearrange(latent, '(b n) c h w->b (n c) h w', n=3)
        else:
            latent = rearrange(latent, 'b c h (n w)->b (n c) h w', n=3)

        ret_dict.update(dict(cls_token=cls_token, latent_after_vit=latent))
        # NOTE(review): no sr_w_code emitted in this variant.
        return ret_dict

    def vae_reparameterization(self, latent, sample_posterior):
        """Encode to the VAE posterior, sample (or take the mode), and
        return both the (B, 3*L, C) token layout and the (B, C', H, W)
        2D-diffusion layouts plus posterior statistics."""
        # latent: B 24 32 32
        assert self.vae_p > 1

        # ! do VAE here
        posterior = self.vae_encode(latent)  # B self.ldm_z_channels 3 L
        if sample_posterior:
            latent = posterior.sample()
        else:
            latent = posterior.mode()  # B C 3 L
        log_q = posterior.log_p(latent)  # same shape as latent

        # ! for LSGM KL code
        latent_normalized_2Ddiffusion = latent.reshape(
            latent.shape[0], -1, self.token_size * self.vae_p,
            self.token_size * self.vae_p)  # B, 3*4, 16 16
        log_q_2Ddiffusion = log_q.reshape(
            latent.shape[0], -1, self.token_size * self.vae_p,
            self.token_size * self.vae_p)  # B, 3*4, 16 16

        latent = latent.permute(0, 2, 3, 1)  # B C 3 L -> B 3 L C
        latent = latent.reshape(latent.shape[0], -1,
                                latent.shape[-1])  # B 3*L C

        ret_dict = dict(
            normal_entropy=posterior.normal_entropy(),
            latent_normalized=latent,
            latent_normalized_2Ddiffusion=latent_normalized_2Ddiffusion,  #
            log_q_2Ddiffusion=log_q_2Ddiffusion,
            log_q=log_q,
            posterior=posterior,
        )
        return ret_dict

    def vit_decode(self, latent, img_size, sample_posterior=True, **kwargs):
        # NOTE(review): **kwargs are accepted but not forwarded to the
        # parent — confirm this is intentional.
        return super().vit_decode(latent, img_size, sample_posterior)
# objv class | |
class RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout(
        RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn_SD):
    """SD variant whose only change is the default fusion block: the
    B-3L-C merged layout with attention rollout.  All behaviour comes
    from the parent class."""

    def __init__(
            self,
            vit_decoder: VisionTransformer,
            triplane_decoder: Triplane_fg_bg_plane,
            cls_token,
            normalize_feat=True,
            sr_ratio=2,
            use_fusion_blk=True,
            fusion_blk_depth=2,
            fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout,
            channel_multiplier=4,
            **kwargs) -> None:
        # forward everything by keyword for readability
        super().__init__(vit_decoder=vit_decoder,
                         triplane_decoder=triplane_decoder,
                         cls_token=cls_token,
                         normalize_feat=normalize_feat,
                         sr_ratio=sr_ratio,
                         use_fusion_blk=use_fusion_blk,
                         fusion_blk_depth=fusion_blk_depth,
                         fusion_blk=fusion_blk,
                         channel_multiplier=channel_multiplier,
                         **kwargs)
# final version, above + SD-Decoder | |
class RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout_withSD_D(
        RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout
    ):
    """Final version: rollout fusion + SD ``Decoder`` as the SR stage.

    ``decoder_pred`` is dropped; ViT tokens are un-flattened into plane
    maps (batched, or rolled out along width when ``D_roll_out_input``)
    and decoded by the SD decoder.
    """

    def __init__(
            self,
            vit_decoder: VisionTransformer,
            triplane_decoder: Triplane_fg_bg_plane,
            cls_token,
            normalize_feat=True,
            sr_ratio=2,
            use_fusion_blk=True,
            fusion_blk_depth=2,
            fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout,
            channel_multiplier=4,
            **kwargs) -> None:
        super().__init__(vit_decoder, triplane_decoder, cls_token,
                         normalize_feat, sr_ratio, use_fusion_blk,
                         fusion_blk_depth, fusion_blk, channel_multiplier,
                         **kwargs)

        self.decoder_pred = None  # directly un-patchembed

        # SD decoder acting as the deconv / super-resolution stage
        self.superresolution.update(
            dict(
                conv_sr=Decoder(  # serve as Deconv
                    resolution=128,
                    # resolution=256,
                    in_channels=3,
                    # ch=64,
                    ch=32,
                    # ch=16,
                    ch_mult=[1, 2, 2, 4],
                    # ch_mult=[1, 1, 2, 2],
                    # num_res_blocks=2,
                    # ch_mult=[1,2,4],
                    # num_res_blocks=0,
                    num_res_blocks=1,
                    dropout=0.0,
                    attn_resolutions=[],
                    out_ch=32,
                    # z_channels=vit_decoder.embed_dim//4,
                    z_channels=vit_decoder.embed_dim,
                    # z_channels=vit_decoder.embed_dim//2,
                ),
                # after_vit_upsampler=Upsample2D(channels=vit_decoder.embed_dim,use_conv=True, use_conv_transpose=False, out_channels=vit_decoder.embed_dim//2)
            ))

        self.D_roll_out_input = False

    # ''' # for SD Decoder
    def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict):
        """Un-flatten ViT tokens into plane maps, run the SD decoder, and
        fold the planes back into B (3*C) H W inside ``ret_dict``."""
        if self.cls_token:
            cls_token = latent_from_vit[:, :1]
        else:
            cls_token = None

        def unflatten_token(x, p=None):
            # (B, 3*N, C) tokens -> plane maps; the rolled-out layout keeps
            # the batch dim and concatenates planes along width instead.
            B, L, C = x.shape
            x = x.reshape(B, 3, L // 3, C)
            if self.cls_token:  # TODO, how to better use cls token
                x = x[:, :, 1:]  # B 3 256 C
            h = w = int((x.shape[2])**.5)
            assert h * w == x.shape[2]

            if p is None:
                x = x.reshape(shape=(B, 3, h, w, -1))
                if not self.D_roll_out_input:
                    x = rearrange(
                        x, 'b n h w c->(b n) c h w'
                    )  # merge plane into Batch and prepare for rendering
                else:
                    x = rearrange(
                        x, 'b n h w c->b c h (n w)'
                    )  # merge plane into Batch and prepare for rendering
            else:
                x = x.reshape(shape=(B, 3, h, w, p, p, -1))
                if self.D_roll_out_input:
                    x = rearrange(
                        x, 'b n h w p1 p2 c->b c (h p1) (n w p2)'
                    )  # merge plane into Batch and prepare for rendering
                else:
                    x = rearrange(
                        x, 'b n h w p1 p2 c->(b n) c (h p1) (w p2)'
                    )  # merge plane into Batch and prepare for rendering
            return x

        latent = unflatten_token(
            latent_from_vit)  # B 3 h w vit_decoder.embed_dim

        # ! x2 upsapmle, 16 -32 before sending into SD Decoder
        # latent = self.superresolution['after_vit_upsampler'](latent) # B*3 192 32 32
        # latent = unflatten_token(latent_from_vit, p=2)

        # ! SD SR
        latent = self.superresolution['conv_sr'](latent)  # still B 3C H W

        if not self.D_roll_out_input:
            latent = rearrange(latent, '(b n) c h w->b (n c) h w', n=3)
        else:
            latent = rearrange(latent, 'b c h (n w)->b (n c) h w', n=3)

        ret_dict.update(dict(cls_token=cls_token, latent_after_vit=latent))
        # NOTE(review): no sr_w_code emitted in this variant.
        return ret_dict

    # '''
# ''' | |
class RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout_withSD_D_ditDecoder(
        RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout_withSD_D
    ):
    """Rollout + SD-Decoder variant driven by a DiT-style ViT decoder.

    Fusion-block rewrapping is disabled and the per-block ``skip_linear``
    modules are deleted; the decoder is applied directly to the token
    sequence.  Adds chunked point queries and grid sampling utilities.
    """

    def __init__(
            self,
            vit_decoder: VisionTransformer,
            triplane_decoder: Triplane_fg_bg_plane,
            cls_token,
            normalize_feat=True,
            sr_ratio=2,
            use_fusion_blk=True,
            fusion_blk_depth=2,
            fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout,
            channel_multiplier=4,
            **kwargs) -> None:
        super().__init__(vit_decoder, triplane_decoder, cls_token,
                         normalize_feat, sr_ratio, use_fusion_blk,
                         fusion_blk_depth, fusion_blk, channel_multiplier,
                         patch_size=-1,
                         **kwargs)

        # del skip_lienar: DiT decoder has no long skip connections
        for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2:]:
            del blk.skip_linear

    def forward_points(self,
                       planes,
                       points: torch.Tensor,
                       chunk_size: int = 2**16):
        """Query the triplane decoder at arbitrary 3D points, chunked to
        bound peak memory.

        planes: (N, 3, D', H', W') (or (N, 3*D', H', W'), reshaped below)
        points: (N, P, 3)
        Returns a dict of per-point features, each (N, P, ...).
        """
        N, P = points.shape[:2]
        if planes.ndim == 4:
            planes = planes.reshape(
                len(planes),
                3,
                -1,  # ! support background plane
                planes.shape[-2],
                planes.shape[-1])  # BS 96 256 256

        # query triplane in chunks of chunk_size points
        outs = []
        for i in trange(0, points.shape[1], chunk_size):
            chunk_points = points[:, i:i + chunk_size]

            # query triplane
            chunk_out = self.triplane_decoder.renderer._run_model(  # type: ignore
                planes=planes,
                decoder=self.triplane_decoder.decoder,
                sample_coordinates=chunk_points,
                sample_directions=torch.zeros_like(chunk_points),
                options=self.rendering_kwargs,
            )

            outs.append(chunk_out)
            # free GPU scratch between chunks to keep peak memory low
            torch.cuda.empty_cache()

        # concatenate the chunk outputs along the point dimension
        point_features = {
            k: torch.cat([out[k] for out in outs], dim=1)
            for k in outs[0].keys()
        }

        return point_features

    def triplane_decode_grid(self,
                             vit_decode_out,
                             grid_size,
                             aabb: torch.Tensor = None,
                             **kwargs):
        """Sample the decoded triplane on a regular 3D grid.

        vit_decode_out: dict with 'latent_after_vit' planes (N, 3, D', H', W')
        grid_size: number of samples per axis
        aabb: optional (N, 2, 3) min/max bounds; defaults come from
              rendering_kwargs (sampler bbox, or eg3d-style box_warp).
        Returns a dict of features reshaped to (N, G, G, G, -1).
        """
        assert isinstance(vit_decode_out, dict)
        planes = vit_decode_out['latent_after_vit']

        # aabb: (N, 2, 3)
        if aabb is None:
            if 'sampler_bbox_min' in self.rendering_kwargs:
                aabb = torch.tensor([
                    [self.rendering_kwargs['sampler_bbox_min']] * 3,
                    [self.rendering_kwargs['sampler_bbox_max']] * 3,
                ],
                                    device=planes.device,
                                    dtype=planes.dtype).unsqueeze(0).repeat(
                                        planes.shape[0], 1, 1)
            else:  # shapenet dataset, follow eg3d
                aabb = torch.tensor(
                    [  # https://github.com/NVlabs/eg3d/blob/7cf1fd1e99e1061e8b6ba850f91c94fe56e7afe4/eg3d/gen_samples.py#L188
                        [-self.rendering_kwargs['box_warp'] / 2] * 3,
                        [self.rendering_kwargs['box_warp'] / 2] * 3,
                    ],
                    device=planes.device,
                    dtype=planes.dtype).unsqueeze(0).repeat(
                        planes.shape[0], 1, 1)

        assert planes.shape[0] == aabb.shape[
            0], "Batch size mismatch for planes and aabb"

        N = planes.shape[0]

        # create grid points for triplane query, one (G^3, 3) grid per sample
        grid_points = []
        for i in range(N):
            grid_points.append(
                torch.stack(torch.meshgrid(
                    torch.linspace(aabb[i, 0, 0],
                                   aabb[i, 1, 0],
                                   grid_size,
                                   device=planes.device),
                    torch.linspace(aabb[i, 0, 1],
                                   aabb[i, 1, 1],
                                   grid_size,
                                   device=planes.device),
                    torch.linspace(aabb[i, 0, 2],
                                   aabb[i, 1, 2],
                                   grid_size,
                                   device=planes.device),
                    indexing='ij',
                ),
                            dim=-1).reshape(-1, 3))
        cube_grid = torch.stack(grid_points, dim=0).to(
            planes.device)  # (N, grid_size**3, 3)

        features = self.forward_points(planes, cube_grid)

        # reshape flat point features back into the (G, G, G) grid
        features = {
            k: v.reshape(N, grid_size, grid_size, grid_size, -1)
            for k, v in features.items()
        }

        return features

    def create_fusion_blks(self, fusion_blk_depth, use_fusion_blk, fusion_blk):
        # no need to fuse anymore: the DiT decoder consumes tokens directly
        pass

    def forward_vit_decoder(self, x, img_size=None):
        """Run the (DiT-style) ViT decoder directly; img_size is unused."""
        return self.vit_decoder(x)

    def vit_decode_backbone(self, latent, img_size):
        # unchanged from the parent; kept as an explicit override point
        return super().vit_decode_backbone(latent, img_size)

    # ! flag2
    def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict):
        # unchanged from the parent; kept as an explicit override point
        return super().vit_decode_postprocess(latent_from_vit, ret_dict)

    def vae_reparameterization(self, latent, sample_posterior):
        # unchanged from the parent; kept as an explicit override point
        return super().vae_reparameterization(latent, sample_posterior)