import math
import random
from pathlib import Path
from pdb import set_trace as st

import einops
import imageio
import kornia
import numpy as np
import point_cloud_utils as pcu
import torch
import torch.nn as nn
import torchvision
from torch import Tensor

import pytorch3d.ops
from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
from pytorch3d.ops.utils import masked_gather
from pytorch3d.renderer import PointsRasterizationSettings, PointsRasterizer
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.structures import Pointclouds

from timm.models.vision_transformer import Mlp, PatchEmbed

from guided_diffusion import dist_util, logger
from ldm.modules.attention import MemoryEfficientCrossAttention
from ldm.modules.diffusionmodules.model import Encoder
from utils.geometry import index, perspective
from vit.vit_triplane import XYZPosEmbed

from .layers import PreNorm, RayEncoder, Transformer
def approx_gelu():
    return nn.GELU(approximate="tanh")


class SRTConvBlock(nn.Module):

    def __init__(self, idim, hdim=None, odim=None):
        super().__init__()
        if hdim is None:
            hdim = idim

        if odim is None:
            odim = 2 * hdim

        conv_kwargs = {'bias': False, 'kernel_size': 3, 'padding': 1}
        self.layers = nn.Sequential(
            nn.Conv2d(idim, hdim, stride=1, **conv_kwargs), nn.ReLU(),
            nn.Conv2d(hdim, odim, stride=2, **conv_kwargs), nn.ReLU())

    def forward(self, x):
        return self.layers(x)
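

# Illustrative sanity check (not part of the model definitions; the 183-channel
# input size is an assumption matching the idim used by SRTEncoder below). Each
# SRTConvBlock keeps the spatial size at the stride-1 conv, halves it at the
# stride-2 conv, and (with odim=None) doubles the channel count.
def _srt_conv_block_shape_check():
    block = SRTConvBlock(idim=183, hdim=96)  # odim defaults to 2 * hdim = 192
    dummy = torch.randn(1, 183, 64, 64)
    out = block(dummy)
    assert out.shape == (1, 192, 32, 32)
    return out.shape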
class SRTEncoder(nn.Module):
    """Scene Representation Transformer Encoder, as presented in the SRT paper at CVPR 2022 (caveats below)."""

    def __init__(self,
                 num_conv_blocks=4,
                 num_att_blocks=10,
                 pos_start_octave=0,
                 scale_embeddings=False):
        super().__init__()
        self.ray_encoder = RayEncoder(pos_octaves=15,
                                      pos_start_octave=pos_start_octave,
                                      ray_octaves=15)

        conv_blocks = [SRTConvBlock(idim=183, hdim=96)]
        cur_hdim = 192
        for i in range(1, num_conv_blocks):
            conv_blocks.append(SRTConvBlock(idim=cur_hdim, odim=None))
            cur_hdim *= 2

        self.conv_blocks = nn.Sequential(*conv_blocks)

        self.per_patch_linear = nn.Conv2d(cur_hdim, 768, kernel_size=1)

        # Original SRT initializes with stddev=1/math.sqrt(d).
        # But model initialization likely also differs between torch & jax, and this worked, so, eh.
        embedding_stdev = (1. / math.sqrt(768)) if scale_embeddings else 1.
        self.pixel_embedding = nn.Parameter(
            torch.randn(1, 768, 15, 20) * embedding_stdev)
        self.canonical_camera_embedding = nn.Parameter(
            torch.randn(1, 1, 768) * embedding_stdev)
        self.non_canonical_camera_embedding = nn.Parameter(
            torch.randn(1, 1, 768) * embedding_stdev)

        # SRT as in the CVPR paper does not use actual self-attention, but a special type:
        # the current features in the Nth layer don't self-attend; they
        # always attend into the initial patch embedding (i.e., the output of
        # the CNN). SRT further used post-normalization rather than
        # pre-normalization. Since then, in OSRT, pre-norm and regular
        # self-attention were found to perform better overall. So that's what
        # we do here, though it may be less stable under some circumstances.
        self.transformer = Transformer(768,
                                       depth=num_att_blocks,
                                       heads=12,
                                       dim_head=64,
                                       mlp_dim=1536,
                                       selfatt=True)

    def forward(self, images, camera_pos, rays):
        """
        Args:
            images: [batch_size, num_images, 3, height, width].
                Assume the first image is canonical - shuffling happens in the data loader.
            camera_pos: [batch_size, num_images, 3]
            rays: [batch_size, num_images, height, width, 3]
        Returns:
            scene representation: [batch_size, num_patches, channels_per_patch]
        """
        batch_size, num_images = images.shape[:2]

        x = images.flatten(0, 1)
        camera_pos = camera_pos.flatten(0, 1)
        rays = rays.flatten(0, 1)

        canonical_idxs = torch.zeros(batch_size, num_images)
        canonical_idxs[:, 0] = 1
        canonical_idxs = canonical_idxs.flatten(
            0, 1).unsqueeze(-1).unsqueeze(-1).to(x)
        camera_id_embedding = canonical_idxs * self.canonical_camera_embedding + \
            (1. - canonical_idxs) * self.non_canonical_camera_embedding

        ray_enc = self.ray_encoder(camera_pos, rays)
        x = torch.cat((x, ray_enc), 1)
        x = self.conv_blocks(x)
        x = self.per_patch_linear(x)
        height, width = x.shape[2:]
        x = x + self.pixel_embedding[:, :, :height, :width]
        x = x.flatten(2, 3).permute(0, 2, 1)
        x = x + camera_id_embedding

        patches_per_image, channels_per_patch = x.shape[1:]
        x = x.reshape(batch_size, num_images * patches_per_image,
                      channels_per_patch)

        x = self.transformer(x)

        return x
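

# Illustrative usage sketch (the input resolution below is an assumption, not a
# value taken from the original repo): with num_conv_blocks=4 the CNN
# downsamples by 16x, so a 240x320 input yields a 15x20 patch grid that matches
# the 15x20 pixel_embedding. The RayEncoder (15 pos + 15 ray octaves) is
# expected to contribute the extra 180 channels that, together with RGB, make
# up the 183-channel conv input.
def _srt_encoder_usage_sketch():
    encoder = SRTEncoder(num_conv_blocks=4, num_att_blocks=10)
    batch_size, num_images, height, width = 2, 5, 240, 320
    images = torch.randn(batch_size, num_images, 3, height, width)
    camera_pos = torch.randn(batch_size, num_images, 3)
    rays = torch.randn(batch_size, num_images, height, width, 3)
    scene_repr = encoder(images, camera_pos, rays)
    return scene_repr.shape  # [batch_size, num_images * 15 * 20, 768]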
class ImprovedSRTEncoder(nn.Module):
    """
    Scene Representation Transformer Encoder with the improvements from Appendix A.4 in the OSRT paper.
    """

    def __init__(self,
                 num_conv_blocks=3,
                 num_att_blocks=5,
                 pos_start_octave=0):
        super().__init__()
        self.ray_encoder = RayEncoder(pos_octaves=15,
                                      pos_start_octave=pos_start_octave,
                                      ray_octaves=15)

        conv_blocks = [SRTConvBlock(idim=183, hdim=96)]
        cur_hdim = 192
        for i in range(1, num_conv_blocks):
            conv_blocks.append(SRTConvBlock(idim=cur_hdim, odim=None))
            cur_hdim *= 2

        self.conv_blocks = nn.Sequential(*conv_blocks)

        self.per_patch_linear = nn.Conv2d(cur_hdim, 768, kernel_size=1)

        self.transformer = Transformer(768,
                                       depth=num_att_blocks,
                                       heads=12,
                                       dim_head=64,
                                       mlp_dim=1536,
                                       selfatt=True)

    def forward(self, images, camera_pos, rays):
        """
        Args:
            images: [batch_size, num_images, 3, height, width]. Assume the first image is canonical.
            camera_pos: [batch_size, num_images, 3]
            rays: [batch_size, num_images, height, width, 3]
        Returns:
            scene representation: [batch_size, num_patches, channels_per_patch]
        """
        batch_size, num_images = images.shape[:2]

        x = images.flatten(0, 1)
        camera_pos = camera_pos.flatten(0, 1)
        rays = rays.flatten(0, 1)

        ray_enc = self.ray_encoder(camera_pos, rays)
        x = torch.cat((x, ray_enc), 1)
        x = self.conv_blocks(x)
        x = self.per_patch_linear(x)
        x = x.flatten(2, 3).permute(0, 2, 1)

        patches_per_image, channels_per_patch = x.shape[1:]
        x = x.reshape(batch_size, num_images * patches_per_image,
                      channels_per_patch)

        x = self.transformer(x)

        return x
class ImprovedSRTEncoderVAE(nn.Module):
    """
    Modified from ImprovedSRTEncoder:
    1. replace conv_blocks with a timm PatchEmbed embedder
    2. replace the ray PE with Plucker coordinates
    3. add xformers/flash attention for the transformer
    """

    def __init__(
            self,
            *,
            ch,
            out_ch,
            ch_mult=(1, 2, 4, 8),
            num_res_blocks,
            attn_resolutions,
            dropout=0.0,
            resamp_with_conv=True,
            in_channels,
            resolution,
            z_channels,
            double_z=True,
            num_frames=4,
            num_att_blocks=5,
            tx_dim=768,
            num_heads=12,
            mlp_ratio=2,  # as in SRT
            patch_size=16,
            decomposed=False,
            **kwargs):
        super().__init__()
        # self.ray_encoder = RayEncoder(pos_octaves=15, pos_start_octave=pos_start_octave,
        #                               ray_octaves=15)

        # conv_blocks = [SRTConvBlock(idim=183, hdim=96)]
        # cur_hdim = 192
        # for i in range(1, num_conv_blocks):
        #     conv_blocks.append(SRTConvBlock(idim=cur_hdim, odim=None))
        #     cur_hdim *= 2

        self.num_frames = num_frames
        self.embed_dim = tx_dim

        self.embedder = PatchEmbed(
            img_size=256,
            patch_size=patch_size,
            # patch_size=8,  # compare the performance
            in_chans=in_channels,
            embed_dim=self.embed_dim,
            norm_layer=None,
            flatten=True,
            bias=True,
        )  # downsample f=16 here.

        # same configuration as ViT-B
        if not decomposed:
            self.transformer = Transformer(
                self.embed_dim,  # 12 * 64 = 768
                depth=num_att_blocks,
                heads=num_heads,
                mlp_dim=mlp_ratio * self.embed_dim,  # 1536 by default
            )
        else:
            self.transformer_selfattn = Transformer(
                self.embed_dim,  # 12 * 64 = 768
                depth=1,
                heads=num_heads,
                mlp_dim=mlp_ratio * self.embed_dim,  # 1536 by default
            )
            self.transformer = Transformer(
                self.embed_dim,  # 12 * 64 = 768
                # depth=num_att_blocks-1,
                depth=num_att_blocks,
                heads=num_heads,
                mlp_dim=mlp_ratio * self.embed_dim,  # 1536 by default
            )

        # to a compact latent, with CA
        # query_dim = 4*(1+double_z)
        query_dim = 12 * (1 + double_z
                          )  # for high-quality 3D encoding, follow Direct3D
        self.latent_embedding = nn.Parameter(
            torch.randn(1, 32 * 32 * 3, query_dim))
        self.readout_ca = MemoryEfficientCrossAttention(
            query_dim,
            self.embed_dim,
        )

    def forward_tx(self, x):
        x = self.transformer(x)  # B VL C

        # ? 3DPE
        x = self.readout_ca(self.latent_embedding.repeat(x.shape[0], 1, 1), x)

        # ! reshape to 3D latent here. how to make the latent 3D-aware? Later. Performance first.
        x = einops.rearrange(x, 'B (N H W) C -> B C (N H) W', H=32, W=32, N=3)

        return x

    def forward(self, x, **kwargs):
        """
        Args:
            x: [batch_size * num_frames, in_channels, height, width] multi-view
                inputs, stacked along the batch dimension.
        Returns:
            triplane-shaped latent: [batch_size, query_dim, 3 * 32, 32]
        """
        x = self.embedder(x)  # B L C
        x = einops.rearrange(x, '(B V) L C -> B (V L) C', V=self.num_frames)
        x = self.forward_tx(x)

        return x
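

# Illustrative usage sketch (the config values below are assumptions for the
# sketch, not values from the original training configs). With patch_size=16 on
# 256x256 inputs each view contributes (256/16)^2 = 256 tokens; num_frames views
# are concatenated along the token axis, and the cross-attention readout maps
# them onto 3 * 32 * 32 learned latent queries, reshaped to a triplane-like map.
def _improved_srt_vae_usage_sketch():
    enc = ImprovedSRTEncoderVAE(
        ch=64, out_ch=3, num_res_blocks=2, attn_resolutions=[],
        in_channels=10, resolution=256, z_channels=12,
        double_z=True, num_frames=4)
    x = torch.randn(4 * 2, 10, 256, 256)  # (B*V, C, H, W) with B=2, V=4
    latent = enc(x)
    # [2, 24, 96, 32]: query_dim = 12 * (1 + double_z) = 24, (N H) = 3 * 32
    return latent.shape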
# ! ablations of the SRT design
class ImprovedSRTEncoderVAE_K8(ImprovedSRTEncoderVAE):

    def __init__(self, **kwargs):
        super().__init__(patch_size=8, **kwargs)


class ImprovedSRTEncoderVAE_L6(ImprovedSRTEncoderVAE):

    def __init__(self, **kwargs):
        super().__init__(num_att_blocks=6, **kwargs)


class ImprovedSRTEncoderVAE_L5_vitl(ImprovedSRTEncoderVAE):

    def __init__(self, **kwargs):
        super().__init__(num_att_blocks=5, tx_dim=1024, num_heads=16, **kwargs)


class ImprovedSRTEncoderVAE_mlp_ratio4(ImprovedSRTEncoderVAE):  # ! by default now

    def __init__(self, **kwargs):
        super().__init__(mlp_ratio=4, **kwargs)


class ImprovedSRTEncoderVAE_mlp_ratio4_decomposed(
        ImprovedSRTEncoderVAE_mlp_ratio4):

    def __init__(self, **kwargs):
        super().__init__(decomposed=True, **kwargs)  # just decompose the transformer

    def forward(self, x, **kwargs):
        """
        Args:
            x: [batch_size * num_frames, in_channels, height, width] multi-view inputs.
        Returns:
            triplane-shaped latent: [batch_size, query_dim, 3 * 32, 32]
        """
        x = self.embedder(x)  # B L C
        # per-view self-attention first, then merge views for the joint transformer
        x = self.transformer_selfattn(x)
        x = einops.rearrange(x, '(B V) L C -> B (V L) C', V=self.num_frames)
        x = self.forward_tx(x)

        return x


class ImprovedSRTEncoderVAE_mlp_ratio4_f8(ImprovedSRTEncoderVAE):

    def __init__(self, **kwargs):
        super().__init__(mlp_ratio=4, patch_size=8, **kwargs)


class ImprovedSRTEncoderVAE_mlp_ratio4_f8_L6(ImprovedSRTEncoderVAE):

    def __init__(self, **kwargs):
        super().__init__(mlp_ratio=4, patch_size=8, num_att_blocks=6, **kwargs)


class ImprovedSRTEncoderVAE_mlp_ratio4_L6(ImprovedSRTEncoderVAE):

    def __init__(self, **kwargs):
        super().__init__(mlp_ratio=4, num_att_blocks=6, **kwargs)
# ! an SD VAE with one SRT attention + one CA attention for KL
class HybridEncoder(Encoder):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # st()
        self.srt = ImprovedSRTEncoderVAE(
            **kwargs,
            # num_frames=4,
            num_att_blocks=1,  # only one layer required
            tx_dim=self.conv_out.weight.shape[1],
            num_heads=8,  # 256 / 64
            mlp_ratio=4,  # as in SRT
            # patch_size=16,
        )
        del self.srt.embedder  # use the original SD-VAE conv trunk instead
        self.conv_out = nn.Identity()

    def forward(self, x, **kwargs):
        x = super().forward(x)
        x = einops.rearrange(x,
                             '(B V) C H W -> B (V H W) C',
                             V=self.srt.num_frames)
        x = self.srt.forward_tx(x)
        return x
class ImprovedSRTEncoderVAE_mlp_ratio4_heavyPatchify(ImprovedSRTEncoderVAE):

    def __init__(self, **kwargs):
        super().__init__(mlp_ratio=4, **kwargs)
        del self.embedder

        conv_blocks = [SRTConvBlock(idim=10, hdim=48)]  # match the ViT-B dim
        cur_hdim = 48 * 2
        for i in range(1, 4):  # f=16 still. could reduce attention layers by one?
            conv_blocks.append(SRTConvBlock(idim=cur_hdim, odim=None))
            cur_hdim *= 2

        self.embedder = nn.Sequential(*conv_blocks)

    def forward(self, x, **kwargs):
        """
        Args:
            x: [batch_size * num_frames, 10, height, width] multi-view inputs.
        Returns:
            triplane-shaped latent: [batch_size, query_dim, 3 * 32, 32]
        """
        x = self.embedder(x)  # B C H W
        x = einops.rearrange(x,
                             '(B V) C H W -> B (V H W) C',
                             V=self.num_frames)

        x = self.transformer(x)  # B VL C

        # ? 3DPE
        x = self.readout_ca(self.latent_embedding.repeat(x.shape[0], 1, 1), x)

        # ! reshape to 3D latent here. how to make the latent 3D-aware? Later. Performance first.
        x = einops.rearrange(x, 'B (N H W) C -> B C (N H) W', H=32, W=32, N=3)

        return x
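

# Illustrative shape check for the heavy-patchify stack above (the 256x256 input
# size is an assumption): four SRTConvBlocks give a 16x downsample and
# 48 * 2^4 = 768 output channels, matching the ViT-B token width used by the
# transformer in place of the PatchEmbed tokens.
def _heavy_patchify_shape_check():
    blocks = [SRTConvBlock(idim=10, hdim=48)]
    cur_hdim = 96
    for _ in range(1, 4):
        blocks.append(SRTConvBlock(idim=cur_hdim, odim=None))
        cur_hdim *= 2
    embedder = nn.Sequential(*blocks)
    out = embedder(torch.randn(1, 10, 256, 256))
    assert out.shape == (1, 768, 16, 16)
    return out.shape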
class HybridEncoderPCDStructuredLatent(Encoder):

    def __init__(self, num_frames, latent_num=768, **kwargs):
        super().__init__(**kwargs)
        # st()
        self.num_frames = num_frames
        tx_dim = self.conv_out.weight.shape[1]  # after encoder mid_layers

        self.srt = ImprovedSRTEncoderVAE(
            **kwargs,
            # num_frames=4,
            num_att_blocks=3,
            tx_dim=tx_dim,
            num_heads=8,  # 256 / 64
            mlp_ratio=4,  # as in SRT
        )
        del self.srt.embedder, self.srt.readout_ca, self.srt.latent_embedding  # use the original encoder trunk

        # self.box_pool2d = kornia.filters.BlurPool2D(kernel_size=(8,8), stride=8)
        self.box_pool2d = kornia.filters.BlurPool2D(kernel_size=(8, 8),
                                                    stride=8)
        # self.pool2d = kornia.filters.MedianBlur(kernel_size=(8,8), stride=8)

        self.agg_ca = MemoryEfficientCrossAttention(
            tx_dim,
            tx_dim,
            qk_norm=True,  # as in ViT-22B
        )

        self.spatial_token_reshape = lambda x: einops.rearrange(
            x, '(B V) C H W -> B (V H W) C', V=self.num_frames)

        self.latent_num = latent_num  # 768 * 3 by default
        self.xyz_pos_embed = XYZPosEmbed(tx_dim)

        # ! VAE part
        self.conv_out = nn.Identity()
        self.Mlp_out = PreNorm(
            tx_dim,  # ! add PreNorm before VAE reduction, stabilizes training.
            Mlp(
                in_features=tx_dim,  # reduce dim
                hidden_features=tx_dim,
                out_features=self.z_channels * 2,  # double_z
                act_layer=approx_gelu,
                drop=0))

        self.ca_no_pcd = False
        self.pixel_aligned_query = False
        self.pc2 = True

        if self.pc2:
            # https://github.com/lukemelas/projection-conditioned-point-cloud-diffusion/blob/64fd55a0d00b52735cf02e11c5112374c7104ece/experiments/model/projection_model.py#L87
            # Save rasterization settings
            raster_point_radius: float = 0.0075  # point size
            image_size = 512  # ? hard coded
            raster_points_per_pixel: int = 1
            bin_size: int = 0

            self.raster_settings = PointsRasterizationSettings(
                image_size=(image_size, image_size),
                radius=raster_point_radius,
                points_per_pixel=raster_points_per_pixel,
                bin_size=bin_size,
            )
            self.scale_factor = 1
    # def _process_token_xyz(self, token_xyz, h):
    #     # pad zero xyz points to reasonable value.
    #     nonzero_mask = (token_xyz != 0).all(dim=2)  # Shape: (B, N)
    #     non_zero_token_xyz = token_xyz[nonzero_mask]
    #     non_zero_token_h = h[nonzero_mask]
    #     # for loop to get foreground points of each instance
    #     # TODO, accelerate with vmap
    #     # No, directly use sparse pcd as input as surface points? fps sampling 768 from 4096 points.
    #     # All points here should not have 0 xyz.
    #     # fg_token_xyz = []
    #     # for idx in range(token_xyz.shape[1]):
    #     fps_xyz, fps_idx = pytorch3d.ops.sample_farthest_points(
    #         non_zero_token_xyz, K=self.latent_num)  # B self.latent_num
    #     # pcu.save_mesh_v(f'xyz.ply', xyz[0].float().detach().permute(1,2,0).reshape(-1,3).cpu().numpy(),)  # check result first, before fps sampling
    #     # pcu.save_mesh_v(f'fps_xyz.ply', fps_xyz[0].float().detach().reshape(-1,3).cpu().numpy(),)  # check result first, before fps sampling
    #     pcu.save_mesh_v(f'token_xyz3.ply', token_xyz[0].float().detach().reshape(-1,3).cpu().numpy(),)
    #     # xyz = self.spatial_token_reshape(xyz)
    #     # pcu.save_mesh_v(f'xyz_new.ply', xyz[0].float().detach().reshape(-1,3).cpu().numpy(),)
    #     st()
    #     query_h = masked_gather(non_zero_token_h, fps_idx)  # torch.gather with dim expansion
    #     return query_h, fps_xyz
    def _process_token_xyz(self, pcd, pcd_h):
        # ! 16x uniform downsample before FPS.
        # rand_start_pt = random.randint(0,16)
        # query_pcd_xyz, fps_idx = pytorch3d.ops.sample_farthest_points(
        #     pcd[:, rand_start_pt::16], K=self.latent_num, random_start_point=True)  # B self.latent_num
        # query_pcd_h = masked_gather(pcd_h[:, rand_start_pt::16], fps_idx)  # torch.gather with dim expansion

        # ! fps very slow on high-res pcd
        query_pcd_xyz, fps_idx = pytorch3d.ops.sample_farthest_points(
            pcd, K=self.latent_num,
            # random_start_point=False)  # B self.latent_num
            random_start_point=True)  # B self.latent_num
        query_pcd_h = masked_gather(pcd_h,
                                    fps_idx)  # torch.gather with dim expansion

        # pcu.save_mesh_v(f'xyz.ply', xyz[0].float().detach().permute(1,2,0).reshape(-1,3).cpu().numpy(),)  # check result first, before fps sampling
        # pcu.save_mesh_v(f'fps_xyz.ply', fps_xyz[0].float().detach().reshape(-1,3).cpu().numpy(),)  # check result first, before fps sampling
        # pcu.save_mesh_v(f'query_pcd_xyz.ply', query_pcd_xyz[0].float().detach().reshape(-1,3).cpu().numpy(),)
        # pcu.save_mesh_v(f'pcd_xyz.ply', pcd[0].float().detach().reshape(-1,3).cpu().numpy(),)
        # xyz = self.spatial_token_reshape(xyz)
        # pcu.save_mesh_v(f'xyz_new.ply', xyz[0].float().detach().reshape(-1,3).cpu().numpy(),)

        return query_pcd_h, query_pcd_xyz
    def forward(self, x, pcd, **kwargs):
        # def forward(self, x, num_frames=None):
        assert x.shape[1] == 15  # rgb(3), normal(3), plucker_ray(6), xyz(3)
        xyz = x[:, -3:, ...]  # for fps downsampling

        # 0. retrieve VAE tokens
        h = super().forward(
            x, num_frames=self.num_frames
        )  # ! supports data augmentation; different FPS -> different latent for the same instance?

        # st()
        # pcu.save_mesh_v(f'{Path(logger.get_dir())}/anchor_all.ply', pcd[0].float().detach().cpu().numpy())

        # ! add 3D PE.
        # 1. unproject 2D tokens to 3D
        token_xyz = xyz[..., 4::8, 4::8]

        if self.pixel_aligned_query:
            # h = self.spatial_token_reshape(h)  # V frames merge to a single latent here.
            # h = h + self.xyz_pos_embed(token_xyz)  # directly add PE to h here.
            # # ! PE over surface fps-pcd
            # pcd_h = self.xyz_pos_embed(pcd)  # directly add PE to h here.

            # 2. fps sampling surface as pcd-structured latent.
            h, query_pcd_xyz = self._process_token_xyz(
                pcd, token_xyz, h, c=kwargs.get('c'),
                x=x)  # aggregate with pixel-aligned operation.

        elif self.pc2:  # rasterize the point cloud to multi-view feature maps
            # https://github.com/lukemelas/projection-conditioned-point-cloud-diffusion/blob/64fd55a0d00b52735cf02e11c5112374c7104ece/experiments/model/projection_model.py#L128

            # ! prepare the features before projection
            token_xyz = self.spatial_token_reshape(token_xyz)
            h = self.spatial_token_reshape(
                h)  # V frames merge to a single latent here.
            # directly add PE to h here.
            h = h + self.xyz_pos_embed(token_xyz)  # h: B L C

            # ! prepare the pytorch3d camera
            c = kwargs['c']  # gs_format dict
            focal_length = c['orig_pose'][..., 16:17]  # B V 1
            img_h, img_w = x.shape[-2:]
            R, T = c['R'], c['T']  # B V 3 3, B V 3

            # ! bs=1 test. will merge B, V later for parallel compute.
            V = focal_length.shape[1]
            principal_point = torch.zeros(V, 2)
            img_size = torch.Tensor([img_h, img_w]).unsqueeze(0).repeat_interleave(V, 0).to(focal_length)
            camera = PerspectiveCameras(focal_length=focal_length[0],
                                        principal_point=principal_point,
                                        R=R[0],
                                        T=T[0],
                                        image_size=img_size)
            # camera = PerspectiveCameras(focal_length=focal_length, R=R, T=T, image_size=(img_h, img_w))

            # ! create the rasterizer
            rasterizer = PointsRasterizer(
                cameras=camera.to(pcd.device),
                raster_settings=self.raster_settings)
            fragments = rasterizer(
                Pointclouds(pcd[0:1].repeat_interleave(V, 0)))  # (B, H, W, R)
            fragments_idx: Tensor = fragments.idx.long()
            visible_pixels = (fragments_idx > -1)  # (B, H, W, R)
            view_idx = 0  # index of the viewpoint

            # (Pdb) fragments.zbuf.shape
            # torch.Size([8, 512, 512, 1])
            # depth_image = fragments.zbuf[0, ..., 0].cpu().numpy()  # take the nearest point's depth
            # depth_image = (depth_image - depth_image.min()) / (depth_image.max() - depth_image.min())
            # imageio.imwrite('tmp/depth.jpg', (depth_image*255.0).astype(np.uint8))
            # st()

            points_to_visible_pixels = fragments_idx[visible_pixels]

            # ! visualize the results (for debug)
            normal = x[:, 3:6, ...]
            normal_map = (normal * 127.5 + 127.5).float().to(
                torch.uint8)  # BV 3 H W

            st()
            pass

        else:
            token_xyz = self.spatial_token_reshape(token_xyz)
            h = self.spatial_token_reshape(
                h)  # V frames merge to a single latent here.
            h = h + self.xyz_pos_embed(token_xyz)  # directly add PE to h here.

            # ! PE over surface fps-pcd
            pcd_h = self.xyz_pos_embed(pcd)  # directly add PE to h here.

            # 2. fps sampling surface as pcd-structured latent.
            query_pcd_h, query_pcd_xyz = self._process_token_xyz(pcd, pcd_h)

            # 2.5 Cross attention to aggregate from all tokens.
            if self.ca_no_pcd:
                h = self.agg_ca(query_pcd_h, h)
            else:
                h = self.agg_ca(
                    query_pcd_h, torch.cat([h, pcd_h], dim=1)
                )  # cross attend to aggregate info from both vae-h and pcd-h

        # 3. add vit TX (3 layers here), xyz-PE added above
        # h = h + self.xyz_pos_embed(fps_xyz)  # TODO, add PE of query pts. directly add to h here.
        h = self.srt.transformer(h)  # B L C
        h = self.Mlp_out(h)  # equivalent to conv_out, 256 -> 8 in sd-VAE

        # h = einops.rearrange(h, 'B L C -> B C L')  # for VAE compat
        return {
            'h': h,
            'query_pcd_xyz': query_pcd_xyz
        }  # h_0, point cloud-structured latent space. For VAE later.
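

# Illustrative sketch of the FPS-based latent query construction used above
# (toy sizes; K and the feature width are assumptions): sample_farthest_points
# returns the K selected points plus their indices, and masked_gather picks the
# matching per-point features so query features and query positions stay aligned.
def _fps_query_sketch(pcd, pcd_h, k=768):
    # pcd: (B, N, 3) surface points, pcd_h: (B, N, C) per-point features
    query_xyz, fps_idx = pytorch3d.ops.sample_farthest_points(
        pcd, K=k, random_start_point=True)
    query_h = masked_gather(pcd_h, fps_idx)  # (B, K, C)
    return query_h, query_xyz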
class HybridEncoderPCDStructuredLatentUniformFPS(
        HybridEncoderPCDStructuredLatent):

    def __init__(self, num_frames, latent_num=768, **kwargs):
        super().__init__(num_frames, latent_num, **kwargs)
        self.ca_no_pcd = True  # check speed up ratio

    def _process_token_xyz(self, pcd, pcd_h):
        # ! 16x uniform downsample before FPS.
        rand_start_pt = random.randint(0, 16)
        # rand_start_pt = 0

        query_pcd_xyz, fps_idx = pytorch3d.ops.sample_farthest_points(
            # pcd[:, rand_start_pt::16], K=self.latent_num, random_start_point=False)  # B self.latent_num
            pcd[:, rand_start_pt::16],
            K=self.latent_num,
            random_start_point=True)  # B self.latent_num
        query_pcd_h = masked_gather(pcd_h[:, rand_start_pt::16],
                                    fps_idx)  # torch.gather with dim expansion
        # st()

        # ! fps very slow on high-res pcd
        # query_pcd_xyz, fps_idx = pytorch3d.ops.sample_farthest_points(
        #     pcd, K=self.latent_num, random_start_point=True)  # B self.latent_num
        # query_pcd_h = masked_gather(pcd_h, fps_idx)  # torch.gather with dim expansion

        # pcu.save_mesh_v(f'xyz.ply', xyz[0].float().detach().permute(1,2,0).reshape(-1,3).cpu().numpy(),)  # check result first, before fps sampling
        # pcu.save_mesh_v(f'fps_xyz.ply', fps_xyz[0].float().detach().reshape(-1,3).cpu().numpy(),)  # check result first, before fps sampling
        # pcu.save_mesh_v(f'query_pcd_xyz.ply', query_pcd_xyz[0].float().detach().reshape(-1,3).cpu().numpy(),)
        # pcu.save_mesh_v(f'pcd_xyz.ply', pcd[0].float().detach().reshape(-1,3).cpu().numpy(),)
        # xyz = self.spatial_token_reshape(xyz)
        # pcu.save_mesh_v(f'xyz_new.ply', xyz[0].float().detach().reshape(-1,3).cpu().numpy(),)

        return query_pcd_h, query_pcd_xyz
class HybridEncoderPCDStructuredLatentSNoPCD(HybridEncoderPCDStructuredLatent):

    def __init__(self, num_frames, latent_num=768, **kwargs):
        super().__init__(num_frames, latent_num, **kwargs)
        self.ca_no_pcd = True


class HybridEncoderPCDStructuredLatentSNoPCD_PC2(
        HybridEncoderPCDStructuredLatentSNoPCD):

    def __init__(self, num_frames, latent_num=768, **kwargs):
        super().__init__(num_frames, latent_num, **kwargs)
        self.pc2 = True
class HybridEncoderPCDStructuredLatentSNoPCD_PixelAlignedQuery(
        HybridEncoderPCDStructuredLatent):

    def __init__(self, num_frames, latent_num=768, **kwargs):
        super().__init__(num_frames, latent_num, **kwargs)
        self.ca_no_pcd = True
        self.pixel_aligned_query = True
        self.F = 4  # pixel-aligned query from nearest F views
        del self.agg_ca  # use average pooling now.

    def _pcd_to_homo(self, pcd):
        return torch.cat([pcd, torch.ones_like(pcd[..., 0:1])], -1)

    # ! FPS sampling
    def _process_token_xyz(self, pcd, token_xyz, h, c, x=None):
        V = c['cam_pos'].shape[1]
        # (Pdb) p c.keys()
        # dict_keys(['source_cv2wT_quat', 'cam_view', 'cam_view_proj', 'cam_pos', 'tanfov', 'orig_pose', 'orig_c2w', 'orig_w2c'])
        # (Pdb) p c['cam_view'].shape
        # torch.Size([8, 9, 4, 4])
        # (Pdb) p c['cam_pos'].shape
        # torch.Size([8, 9, 3])

        # ! 16x uniform downsample before FPS.
        # rand_start_pt = random.randint(0,16)
        # query_pcd_xyz, fps_idx = pytorch3d.ops.sample_farthest_points(
        #     pcd[:, rand_start_pt::16], K=self.latent_num, random_start_point=True)  # B self.latent_num
        # query_pcd_h = masked_gather(pcd_h[:, rand_start_pt::16], fps_idx)  # torch.gather with dim expansion

        # ! fps very slow on high-res pcd, but better.
        # '''
        query_pcd_xyz, fps_idx = pytorch3d.ops.sample_farthest_points(
            pcd, K=self.latent_num,
            random_start_point=True)  # B self.latent_num
        # query_pcd_h = masked_gather(pcd_h, fps_idx)  # torch.gather with dim expansion
        # '''

        # ! use unprojected xyz for pixel-aligned projection check
        # query_pcd_xyz = self.spatial_token_reshape(token_xyz)

        B, N = query_pcd_xyz.shape[:2]

        normal = x[:, 3:6, ...]
        normal_map = (normal * 127.5 + 127.5).float().to(
            torch.uint8)  # BV 3 H W
        normal_map = einops.rearrange(normal_map,
                                      '(B V) C H W -> B V C H W',
                                      B=B,
                                      V=V).detach().cpu()  # B V C H W
        img_size = normal_map.shape[-1]

        # ! ====== single-view debug here
        for b in range(c['orig_w2c'].shape[0]):
            for v in range(c['orig_w2c'].shape[1]):  # v, not V: keep the view count intact
                selected_normal = normal_map[b, v]
                proj_point = c['orig_w2c'][b, v] @ self._pcd_to_homo(
                    query_pcd_xyz[b]).permute(1, 0)
                proj_point[:2, ...] /= proj_point[2, ...]
                proj_point[2, ...] = 1  # homo

                intrin = c['orig_intrin'][b, v]
                proj_point = intrin @ proj_point[:3]
                proj_point = proj_point.permute(1, 0)[..., :2]  # N 2
                # st()

                # proj_point = c['cam_view_proj'][b, v] @ self._pcd_to_homo(query_pcd_xyz[b]).permute(1, 0)

                # plot proj_point and save
                for uv_idx in range(proj_point.shape[0]):
                    # uv = proj_point[uv_idx] * 127.5 + 127.5
                    uv = proj_point[uv_idx] * img_size
                    x, y = int(uv[0].clip(0, img_size)), int(uv[1].clip(0, img_size))
                    selected_normal[:, max(y - 1, 0):min(y + 1, img_size),
                                    max(x - 1, 0):min(x + 1, img_size)] = torch.Tensor([
                                        255, 0, 0
                                    ]).reshape(3, 1, 1).to(selected_normal)  # set to red

                torchvision.utils.save_image(selected_normal.float(),
                                             f'tmp/pifu_normal_{b}_{v}.jpg',
                                             normalize=True,
                                             value_range=(0, 255))
                st()
                pass
        st()
        # ! ====== single-view debug done

        # ! project pcd to each view
        batched_query_pcd = einops.repeat(self._pcd_to_homo(query_pcd_xyz),
                                          'B N C -> (B V N) C 1',
                                          V=V)
        batched_cam_view_proj = einops.repeat(c['cam_view_proj'],
                                              'B V H W -> (B V N) H W',
                                              N=N)
        batched_proj_uv = einops.rearrange(
            (batched_cam_view_proj @ batched_query_pcd),
            '(B V N) L 1 -> (B V) L N',
            B=B,
            V=V,
            N=N)  # BV 4 N
        batched_proj_uv = batched_proj_uv[..., :2, :]  # BV 2 N

        # draw projected UV coordinates on the 2D normal map
        # idx_to_vis = 15 * 32 + 16  # middle of the img
        # idx_to_vis = 16 * 6 + 15 * 32 + 16  # middle of the img
        idx_to_vis = 0  # use fps points here
        # st()
        selected_proj_uv = einops.rearrange(batched_proj_uv,
                                            '(B V) C N -> B V C N',
                                            B=B,
                                            V=V,
                                            N=N)[0, ...,
                                                 idx_to_vis]  # V 2 N -> V 2
        # selected_normal = einops.rearrange(normal_map,
        #                                    '(B V) C H W -> B V C H W',
        #                                    B=B,
        #                                    V=V)[0].detach().cpu()  # V C H W

        for uv_idx in range(selected_proj_uv.shape[0]):
            uv = selected_proj_uv[uv_idx] * 127.5 + 127.5
            x, y = int(uv[0].clip(0, 255)), int(uv[1].clip(0, 255))
            selected_normal[uv_idx, :,
                            max(y - 5, 0):min(y + 5, 255),
                            max(x - 5, 0):min(x + 5, 255)] = torch.Tensor([
                                255, 0, 0
                            ]).reshape(3, 1,
                                       1).to(selected_normal)  # set to red
            # selected_normal[uv_idx, :, max(y-5, 0):min(y+5, 255), max(x-5,0):min(x+5,255)] = torch.Tensor([255,0,0]).to(selected_normal)  # set to red
            # st()

        torchvision.utils.save_image(selected_normal.float(),
                                     'pifu_normal.jpg',
                                     normalize=True,
                                     value_range=(0, 255))
        st()
        pass

        # ! grid sample
        query_pcd_h = index(
            h, batched_proj_uv)  # h: (B V) C H W, uv: (B V) 2 N -> BV C N
        query_pcd_h_to_gather = einops.rearrange(query_pcd_h,
                                                 '(B V) C N -> B N V C',
                                                 B=B,
                                                 V=V,
                                                 N=N)

        # ! find nearest F views
        _, knn_idx, _ = pytorch3d.ops.knn_points(
            query_pcd_xyz, c['cam_pos'], K=self.F,
            return_nn=False)  # knn_idx: B N F
        knn_idx_expanded = knn_idx[..., None].expand(
            -1, -1, -1, query_pcd_h_to_gather.shape[-1])  # B N F -> B N F C
        knn_pcd_h = torch.gather(
            query_pcd_h_to_gather, dim=2,
            index=knn_idx_expanded)  # torch.Size([8, 768, 4, 256])

        # average pooling of the knn features.
        query_pcd_h = knn_pcd_h.mean(dim=2)

        # add PE
        pcd_h = self.xyz_pos_embed(query_pcd_xyz)  # pcd_h as PE feature.
        query_pcd_h = query_pcd_h + pcd_h

        # TODO: QKV aggregation with pcd_h as q, query_pcd_h as kv. Requires gather?
        '''not used; binary mask for aggregation.
        # * mask idx not used anymore. torch.gather() instead, more flexible.
        # knn_idx_mask = torch.zeros((B,N,V), device=knn_idx.device)
        # knn_idx_mask.scatter_(dim=2, index=knn_idx, src=torch.ones_like(knn_idx_mask))  # ! B N V

        # try gather
        # gather_idx = einops.rearrange(knn_idx_mask, 'B N V -> B N V 1').bool()
        # query_pcd_h = einops.rearrange(query_pcd_h, "(B V) C N -> B N V C", B=pcd_h.shape[0], N=self.latent_num, V=V)  # torch.Size([8, 768, 4, 256])

        # ! apply KNN mask and average the feature.
        # query_pcd_h = einops.reduce(query_pcd_h * knn_idx_mask.unsqueeze(-1), 'B N V C -> B N C', 'sum') / self.F  # B 768 256. average pooling aggregated feature, like in pifu.
        '''
        '''
        # pixel-aligned projection, not efficient enough.
        knn_cam_view_proj = pytorch3d.ops.knn_gather(einops.rearrange(c['cam_view_proj'], 'B V H W -> B V (H W)'), knn_idx)  # get corresponding cam_view_projection matrix (P matrix)
        knn_cam_view_proj = einops.rearrange(knn_cam_view_proj, 'B N F (H W) -> (B N F) H W', H=4, W=4)  # for matmul. H=W=4 here, P matrix.
        batched_query_pcd = einops.repeat(self._pcd_to_homo(query_pcd_xyz), 'B N C -> (B N F) C 1', F=self.F)
        xyz = knn_cam_view_proj @ batched_query_pcd  # BNF 4 1
        # st()
        knn_spatial_feat = pytorch3d.ops.knn_gather(einops.rearrange(h, '(B V) C H W -> B V (C H W)', V=self.num_frames), knn_idx)  # get corresponding feat for grid_sample
        knn_spatial_feat = einops.rearrange(knn_spatial_feat, 'B N F (C H W) -> (B N F) C H W', C=h.shape[-3], H=h.shape[-2], W=h.shape[-1])
        '''

        # grid_sample, as in PIFu:
        # https://github.com/shunsukesaito/PIFu/blob/f0a9c99ef887e1eb360e865a87aa5f166231980e/lib/geometry.py#L15
        # average pooling of multi-view extracted information
        return query_pcd_h, query_pcd_xyz
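

# Illustrative sketch of the pixel-aligned aggregation used in the class above
# (toy shapes; F=4 and the feature layout are assumptions): for every FPS query
# point we look up its F nearest camera positions with knn_points, gather the
# per-view sampled features with torch.gather, and average them, in the spirit
# of PIFu-style multi-view pooling.
def _pixel_aligned_pooling_sketch(query_xyz, cam_pos, per_view_feat, F=4):
    # query_xyz: (B, N, 3), cam_pos: (B, V, 3), per_view_feat: (B, N, V, C)
    _, knn_idx, _ = pytorch3d.ops.knn_points(query_xyz, cam_pos, K=F,
                                             return_nn=False)  # (B, N, F)
    knn_idx_expanded = knn_idx[..., None].expand(-1, -1, -1,
                                                 per_view_feat.shape[-1])
    knn_feat = torch.gather(per_view_feat, dim=2, index=knn_idx_expanded)
    return knn_feat.mean(dim=2)  # (B, N, C)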