import imageio
import torchvision
import random
import kornia
import einops
import numpy as np
import torch
import torch.nn as nn
from .layers import RayEncoder, Transformer, PreNorm
from pdb import set_trace as st
from pathlib import Path
import math

from ldm.modules.attention import MemoryEfficientCrossAttention
from ldm.modules.diffusionmodules.model import Encoder
from guided_diffusion import dist_util, logger

import point_cloud_utils as pcu
import pytorch3d.ops
from pytorch3d.ops.utils import masked_gather
from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
from pytorch3d.renderer import PointsRasterizationSettings, PointsRasterizer
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.structures import Pointclouds

from timm.models.vision_transformer import PatchEmbed, Mlp
from vit.vit_triplane import XYZPosEmbed
from utils.geometry import index, perspective


def approx_gelu():
    return nn.GELU(approximate="tanh")


class SRTConvBlock(nn.Module):

    def __init__(self, idim, hdim=None, odim=None):
        super().__init__()
        if hdim is None:
            hdim = idim

        if odim is None:
            odim = 2 * hdim

        conv_kwargs = {'bias': False, 'kernel_size': 3, 'padding': 1}
        self.layers = nn.Sequential(
            nn.Conv2d(idim, hdim, stride=1, **conv_kwargs), nn.ReLU(),
            nn.Conv2d(hdim, odim, stride=2, **conv_kwargs), nn.ReLU())

    def forward(self, x):
        return self.layers(x)


class SRTEncoder(nn.Module):
    """Scene Representation Transformer encoder, as presented in the SRT paper
    at CVPR 2022 (caveats below)."""

    def __init__(self,
                 num_conv_blocks=4,
                 num_att_blocks=10,
                 pos_start_octave=0,
                 scale_embeddings=False):
        super().__init__()
        self.ray_encoder = RayEncoder(pos_octaves=15,
                                      pos_start_octave=pos_start_octave,
                                      ray_octaves=15)

        conv_blocks = [SRTConvBlock(idim=183, hdim=96)]
        cur_hdim = 192
        for i in range(1, num_conv_blocks):
            conv_blocks.append(SRTConvBlock(idim=cur_hdim, odim=None))
            cur_hdim *= 2

        self.conv_blocks = nn.Sequential(*conv_blocks)

        self.per_patch_linear = nn.Conv2d(cur_hdim, 768, kernel_size=1)

        # Original SRT initializes with stddev=1/math.sqrt(d).
        # But model initialization likely also differs between torch & jax, and
        # this worked, so, eh.
        embedding_stdev = (1. / math.sqrt(768)) if scale_embeddings else 1.
        self.pixel_embedding = nn.Parameter(
            torch.randn(1, 768, 15, 20) * embedding_stdev)
        self.canonical_camera_embedding = nn.Parameter(
            torch.randn(1, 1, 768) * embedding_stdev)
        self.non_canonical_camera_embedding = nn.Parameter(
            torch.randn(1, 1, 768) * embedding_stdev)

        # SRT as in the CVPR paper does not use actual self-attention, but a
        # special type: the current features in the Nth layer don't
        # self-attend; they always attend into the initial patch embedding
        # (i.e., the output of the CNN). SRT further used post-normalization
        # rather than pre-normalization. Since then though, in OSRT, pre-norm
        # and regular self-attention were found to perform better overall, so
        # that's what we do here, though it may be less stable under some
        # circumstances.
        self.transformer = Transformer(768,
                                       depth=num_att_blocks,
                                       heads=12,
                                       dim_head=64,
                                       mlp_dim=1536,
                                       selfatt=True)

    def forward(self, images, camera_pos, rays):
        """
        Args:
            images: [batch_size, num_images, 3, height, width]. Assume the
                first image is canonical - shuffling happens in the data loader.
            camera_pos: [batch_size, num_images, 3]
            rays: [batch_size, num_images, height, width, 3]
        Returns:
            scene representation: [batch_size, num_patches, channels_per_patch]
        """
        batch_size, num_images = images.shape[:2]

        x = images.flatten(0, 1)
        camera_pos = camera_pos.flatten(0, 1)
        rays = rays.flatten(0, 1)

        canonical_idxs = torch.zeros(batch_size, num_images)
        canonical_idxs[:, 0] = 1
        canonical_idxs = canonical_idxs.flatten(
            0, 1).unsqueeze(-1).unsqueeze(-1).to(x)
        camera_id_embedding = canonical_idxs * self.canonical_camera_embedding + \
            (1. - canonical_idxs) * self.non_canonical_camera_embedding

        ray_enc = self.ray_encoder(camera_pos, rays)
        x = torch.cat((x, ray_enc), 1)
        x = self.conv_blocks(x)
        x = self.per_patch_linear(x)
        height, width = x.shape[2:]
        x = x + self.pixel_embedding[:, :, :height, :width]
        x = x.flatten(2, 3).permute(0, 2, 1)
        x = x + camera_id_embedding

        patches_per_image, channels_per_patch = x.shape[1:]
        x = x.reshape(batch_size, num_images * patches_per_image,
                      channels_per_patch)

        x = self.transformer(x)

        return x
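# Usage sketch (illustrative only, not executed here; shapes assume the default
# 183-channel input, i.e. RGB concatenated with the RayEncoder encoding, and a
# 240x320 input so the CNN's 2**4 = 16x downsampling matches the 15x20
# pixel_embedding above):
#
#   enc = SRTEncoder(num_conv_blocks=4, num_att_blocks=10)
#   images = torch.randn(2, 5, 3, 240, 320)      # B=2, V=5 views
#   camera_pos = torch.randn(2, 5, 3)
#   rays = torch.randn(2, 5, 240, 320, 3)
#   scene_repr = enc(images, camera_pos, rays)   # [2, 5 * 300, 768]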
class ImprovedSRTEncoder(nn.Module):
    """
    Scene Representation Transformer encoder with the improvements from
    Appendix A.4 of the OSRT paper.
    """

    def __init__(self,
                 num_conv_blocks=3,
                 num_att_blocks=5,
                 pos_start_octave=0):
        super().__init__()
        self.ray_encoder = RayEncoder(pos_octaves=15,
                                      pos_start_octave=pos_start_octave,
                                      ray_octaves=15)

        conv_blocks = [SRTConvBlock(idim=183, hdim=96)]
        cur_hdim = 192
        for i in range(1, num_conv_blocks):
            conv_blocks.append(SRTConvBlock(idim=cur_hdim, odim=None))
            cur_hdim *= 2

        self.conv_blocks = nn.Sequential(*conv_blocks)

        self.per_patch_linear = nn.Conv2d(cur_hdim, 768, kernel_size=1)

        self.transformer = Transformer(768,
                                       depth=num_att_blocks,
                                       heads=12,
                                       dim_head=64,
                                       mlp_dim=1536,
                                       selfatt=True)

    def forward(self, images, camera_pos, rays):
        """
        Args:
            images: [batch_size, num_images, 3, height, width]. Assume the
                first image is canonical.
            camera_pos: [batch_size, num_images, 3]
            rays: [batch_size, num_images, height, width, 3]
        Returns:
            scene representation: [batch_size, num_patches, channels_per_patch]
        """
        batch_size, num_images = images.shape[:2]

        x = images.flatten(0, 1)
        camera_pos = camera_pos.flatten(0, 1)
        rays = rays.flatten(0, 1)

        ray_enc = self.ray_encoder(camera_pos, rays)
        x = torch.cat((x, ray_enc), 1)
        x = self.conv_blocks(x)
        x = self.per_patch_linear(x)
        x = x.flatten(2, 3).permute(0, 2, 1)

        patches_per_image, channels_per_patch = x.shape[1:]
        x = x.reshape(batch_size, num_images * patches_per_image,
                      channels_per_patch)

        x = self.transformer(x)

        return x
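# Shape sketch (illustrative; assumes 128x128 renders): with num_conv_blocks=3
# the CNN downsamples by 2**3 = 8, so each view contributes (128 / 8)**2 = 256
# patch tokens of width 768, and the transformer attends over all
# num_images * 256 tokens jointly.
#
#   enc = ImprovedSRTEncoder()
#   out = enc(torch.randn(1, 4, 3, 128, 128),
#             torch.randn(1, 4, 3),
#             torch.randn(1, 4, 128, 128, 3))    # [1, 4 * 256, 768]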
class ImprovedSRTEncoderVAE(nn.Module):
    """
    Modified from ImprovedSRTEncoder:
    1. replace conv_blocks with a timm PatchEmbed embedder
    2. replace the ray PE with Plucker coordinates
    3. add xformers/flash attention for the transformer
    """

    def __init__(
            self,
            *,
            ch,
            out_ch,
            ch_mult=(1, 2, 4, 8),
            num_res_blocks,
            attn_resolutions,
            dropout=0.0,
            resamp_with_conv=True,
            in_channels,
            resolution,
            z_channels,
            double_z=True,
            num_frames=4,
            num_att_blocks=5,
            tx_dim=768,
            num_heads=12,
            mlp_ratio=2,  # denoted by srt
            patch_size=16,
            decomposed=False,
            **kwargs):
        super().__init__()

        # self.ray_encoder = RayEncoder(pos_octaves=15, pos_start_octave=pos_start_octave,
        #                               ray_octaves=15)

        # conv_blocks = [SRTConvBlock(idim=183, hdim=96)]
        # cur_hdim = 192
        # for i in range(1, num_conv_blocks):
        #     conv_blocks.append(SRTConvBlock(idim=cur_hdim, odim=None))
        #     cur_hdim *= 2

        self.num_frames = num_frames
        self.embed_dim = tx_dim

        self.embedder = PatchEmbed(
            img_size=256,
            patch_size=patch_size,
            # patch_size=8,  # compare the performance
            in_chans=in_channels,
            embed_dim=self.embed_dim,
            norm_layer=None,
            flatten=True,
            bias=True,
        )  # downsample f=16 here.

        # same configuration as vit-B
        if not decomposed:
            self.transformer = Transformer(
                self.embed_dim,  # 12 * 64 = 768
                depth=num_att_blocks,
                heads=num_heads,
                mlp_dim=mlp_ratio * self.embed_dim,  # 1536 by default
            )
        else:
            self.transformer_selfattn = Transformer(
                self.embed_dim,  # 12 * 64 = 768
                depth=1,
                heads=num_heads,
                mlp_dim=mlp_ratio * self.embed_dim,  # 1536 by default
            )
            self.transformer = Transformer(
                self.embed_dim,  # 12 * 64 = 768
                # depth=num_att_blocks - 1,
                depth=num_att_blocks,
                heads=num_heads,
                mlp_dim=mlp_ratio * self.embed_dim,  # 1536 by default
            )

        # to a compact latent, with CA
        # query_dim = 4 * (1 + double_z)
        query_dim = 12 * (
            1 + double_z)  # for high-quality 3D encoding, follow direct3D
        self.latent_embedding = nn.Parameter(
            torch.randn(1, 32 * 32 * 3, query_dim))

        self.readout_ca = MemoryEfficientCrossAttention(
            query_dim,
            self.embed_dim,
        )

    def forward_tx(self, x):
        x = self.transformer(x)  # B VL C

        # ? 3DPE
        x = self.readout_ca(self.latent_embedding.repeat(x.shape[0], 1, 1), x)

        # ! reshape to 3D latent here. how to make the latent 3D-aware? Later. Performance first.
        x = einops.rearrange(x, 'B (N H W) C -> B C (N H) W', H=32, W=32, N=3)

        return x

    def forward(self, x, **kwargs):
        """
        Args:
            x: [batch_size * num_frames, in_channels, height, width] multi-view
                input (views flattened into the batch dimension).
        Returns:
            3D-aware latent: [batch_size, C, 3 * 32, 32]
        """
        x = self.embedder(x)  # B L C
        x = einops.rearrange(x, '(B V) L C -> B (V L) C', V=self.num_frames)
        x = self.forward_tx(x)
        return x


# ! ablate the srt design
class ImprovedSRTEncoderVAE_K8(ImprovedSRTEncoderVAE):

    def __init__(self, **kwargs):
        super().__init__(patch_size=8, **kwargs)


class ImprovedSRTEncoderVAE_L6(ImprovedSRTEncoderVAE):

    def __init__(self, **kwargs):
        super().__init__(num_att_blocks=6, **kwargs)


class ImprovedSRTEncoderVAE_L5_vitl(ImprovedSRTEncoderVAE):

    def __init__(self, **kwargs):
        super().__init__(num_att_blocks=5, tx_dim=1024, num_heads=16, **kwargs)


class ImprovedSRTEncoderVAE_mlp_ratio4(ImprovedSRTEncoderVAE):  # ! by default now

    def __init__(self, **kwargs):
        super().__init__(mlp_ratio=4, **kwargs)
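# Latent-shape sketch (illustrative; assumes the defaults above: 256x256 input,
# patch_size=16, num_frames=4, double_z=True so query_dim = 12 * 2 = 24):
#
#   each view -> (256 / 16)**2 = 256 patch tokens of width tx_dim
#   4 views   -> 1024 tokens fed to self.transformer
#   readout_ca cross-attends 32 * 32 * 3 = 3072 latent queries into them, and
#   forward_tx reshapes the result to a triplane-like [B, 24, 96, 32] latent.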
class ImprovedSRTEncoderVAE_mlp_ratio4_decomposed(
        ImprovedSRTEncoderVAE_mlp_ratio4):

    def __init__(self, **kwargs):
        super().__init__(decomposed=True, **kwargs)  # just decompose tx

    def forward(self, x, **kwargs):
        """
        Args:
            x: [batch_size * num_frames, in_channels, height, width] multi-view
                input (views flattened into the batch dimension).
        Returns:
            3D-aware latent: [batch_size, C, 3 * 32, 32]
        """
        x = self.embedder(x)  # B L C

        # per-view self-attention first, then cross-view attention in forward_tx
        # x = einops.rearrange(x, '(B V) L C -> B (V L) C', V=self.num_frames)
        x = self.transformer_selfattn(x)

        x = einops.rearrange(x, '(B V) L C -> B (V L) C', V=self.num_frames)
        x = self.forward_tx(x)
        return x


class ImprovedSRTEncoderVAE_mlp_ratio4_f8(ImprovedSRTEncoderVAE):

    def __init__(self, **kwargs):
        super().__init__(mlp_ratio=4, patch_size=8, **kwargs)


class ImprovedSRTEncoderVAE_mlp_ratio4_f8_L6(ImprovedSRTEncoderVAE):

    def __init__(self, **kwargs):
        super().__init__(mlp_ratio=4, patch_size=8, num_att_blocks=6, **kwargs)


class ImprovedSRTEncoderVAE_mlp_ratio4_L6(ImprovedSRTEncoderVAE):

    def __init__(self, **kwargs):
        super().__init__(mlp_ratio=4, num_att_blocks=6, **kwargs)


# ! an SD VAE with one SRT attention + one CA attention for KL
class HybridEncoder(Encoder):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # st()
        self.srt = ImprovedSRTEncoderVAE(
            **kwargs,
            # num_frames=4,
            num_att_blocks=1,  # only one layer required
            tx_dim=self.conv_out.weight.shape[1],
            num_heads=8,  # 256 / 64
            mlp_ratio=4,  # denoted by srt
            # patch_size=16,
        )
        del self.srt.embedder  # use the original SD-VAE downsampling stack
        self.conv_out = nn.Identity()

    def forward(self, x, **kwargs):
        x = super().forward(x)
        x = einops.rearrange(x,
                             '(B V) C H W -> B (V H W) C',
                             V=self.srt.num_frames)
        x = self.srt.forward_tx(x)
        return x


class ImprovedSRTEncoderVAE_mlp_ratio4_heavyPatchify(ImprovedSRTEncoderVAE):

    def __init__(self, **kwargs):
        super().__init__(mlp_ratio=4, **kwargs)
        del self.embedder

        # match the ViT-B dim
        conv_blocks = [SRTConvBlock(idim=10, hdim=48)]
        cur_hdim = 48 * 2
        for i in range(1, 4):  # f=16 still. could reduce attention layers by one?
            conv_blocks.append(SRTConvBlock(idim=cur_hdim, odim=None))
            cur_hdim *= 2

        self.embedder = nn.Sequential(*conv_blocks)

    def forward(self, x, **kwargs):
        """
        Args:
            x: [batch_size * num_frames, in_channels, height, width] multi-view
                input (views flattened into the batch dimension).
        Returns:
            3D-aware latent: [batch_size, C, 3 * 32, 32]
        """
        x = self.embedder(x)  # B C H W
        x = einops.rearrange(x,
                             '(B V) C H W -> B (V H W) C',
                             V=self.num_frames)

        x = self.transformer(x)  # B VL C

        # ? 3DPE
        x = self.readout_ca(self.latent_embedding.repeat(x.shape[0], 1, 1), x)

        # ! reshape to 3D latent here. how to make the latent 3D-aware? Later. Performance first.
        x = einops.rearrange(x, 'B (N H W) C -> B C (N H) W', H=32, W=32, N=3)

        return x
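# Channel/stride sketch for the heavyPatchify embedder above (illustrative):
# four SRTConvBlocks each stride by 2, so f = 2**4 = 16 as with PatchEmbed,
# while the channel width grows 10 -> 96 -> 192 -> 384 -> 768, landing on the
# ViT-B token width so the rest of ImprovedSRTEncoderVAE is unchanged.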
class HybridEncoderPCDStructuredLatent(Encoder):

    def __init__(self, num_frames, latent_num=768, **kwargs):
        super().__init__(**kwargs)
        # st()
        self.num_frames = num_frames
        tx_dim = self.conv_out.weight.shape[1]  # after encoder mid_layers

        self.srt = ImprovedSRTEncoderVAE(
            **kwargs,
            # num_frames=4,
            num_att_blocks=3,  # only one layer required
            tx_dim=tx_dim,
            num_heads=8,  # 256 / 64
            mlp_ratio=4,  # denoted by srt
        )
        del self.srt.embedder, self.srt.readout_ca, self.srt.latent_embedding  # use original

        self.box_pool2d = kornia.filters.BlurPool2D(kernel_size=(8, 8),
                                                    stride=8)
        # self.pool2d = kornia.filters.MedianBlur(kernel_size=(8, 8), stride=8)

        self.agg_ca = MemoryEfficientCrossAttention(
            tx_dim,
            tx_dim,
            qk_norm=True,  # as in vit-22B
        )

        self.spatial_token_reshape = lambda x: einops.rearrange(
            x, '(B V) C H W -> B (V H W) C', V=self.num_frames)

        self.latent_num = latent_num  # 768 * 3 by default
        self.xyz_pos_embed = XYZPosEmbed(tx_dim)

        # ! VAE part
        self.conv_out = nn.Identity()
        self.Mlp_out = PreNorm(
            tx_dim,  # ! add PreNorm before VAE reduction, stabilize training.
            Mlp(
                in_features=tx_dim,  # reduce dim
                hidden_features=tx_dim,
                out_features=self.z_channels * 2,  # double_z
                act_layer=approx_gelu,
                drop=0))

        self.ca_no_pcd = False
        self.pixel_aligned_query = False
        self.pc2 = True

        if self.pc2:
            # https://github.com/lukemelas/projection-conditioned-point-cloud-diffusion/blob/64fd55a0d00b52735cf02e11c5112374c7104ece/experiments/model/projection_model.py#L87
            # Save rasterization settings
            raster_point_radius: float = 0.0075  # point size
            image_size = 512  # ? hard coded
            raster_points_per_pixel: int = 1
            bin_size: int = 0

            self.raster_settings = PointsRasterizationSettings(
                image_size=(image_size, image_size),
                radius=raster_point_radius,
                points_per_pixel=raster_points_per_pixel,
                bin_size=bin_size,
            )
            self.scale_factor = 1

    # def _process_token_xyz(self, token_xyz, h):
    #     # pad zero xyz points to reasonable value.
    #     nonzero_mask = (token_xyz != 0).all(dim=2)  # Shape: (B, N)
    #     non_zero_token_xyz = token_xyz[nonzero_mask]
    #     non_zero_token_h = h[nonzero_mask]

    #     # for loop to get foreground points of each instance
    #     # TODO, accelerate with vmap
    #     # No, directly use sparse pcd as input as surface points? fps sampling 768 from 4096 points.
    #     # All points here should not have 0 xyz.
    #     # fg_token_xyz = []
    #     # for idx in range(token_xyz.shape[1]):

    #     fps_xyz, fps_idx = pytorch3d.ops.sample_farthest_points(
    #         non_zero_token_xyz, K=self.latent_num)  # B self.latent_num

    #     # pcu.save_mesh_v(f'xyz.ply', xyz[0].float().detach().permute(1,2,0).reshape(-1,3).cpu().numpy(),)  # check result first, before fps sampling
    #     # pcu.save_mesh_v(f'fps_xyz.ply', fps_xyz[0].float().detach().reshape(-1,3).cpu().numpy(),)  # check result first, before fps sampling
    #     # pcu.save_mesh_v(f'token_xyz3.ply', token_xyz[0].float().detach().reshape(-1,3).cpu().numpy(),)
    #     # xyz = self.spatial_token_reshape(xyz)
    #     # pcu.save_mesh_v(f'xyz_new.ply', xyz[0].float().detach().reshape(-1,3).cpu().numpy(),)
    #     # st()

    #     query_h = masked_gather(non_zero_token_h, fps_idx)  # torch.gather with dim expansion
    #     return query_h, fps_xyz
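    # Head sketch (illustrative): the pcd-structured latent consists of
    # `latent_num` tokens (768 by default), one per FPS-sampled surface point.
    # After aggregation and the small SRT transformer, Mlp_out maps each token
    # from tx_dim to 2 * z_channels (mean / logvar), so the returned 'h' is
    # expected to be [B, latent_num, 2 * z_channels] for the KL VAE.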
    def _process_token_xyz(self, pcd, pcd_h):
        # ! 16x uniform downsample before FPS.
        # rand_start_pt = random.randint(0, 16)
        # query_pcd_xyz, fps_idx = pytorch3d.ops.sample_farthest_points(
        #     pcd[:, rand_start_pt::16], K=self.latent_num, random_start_point=True)  # B self.latent_num
        # query_pcd_h = masked_gather(pcd_h[:, rand_start_pt::16], fps_idx)  # torch.gather with dim expansion

        # ! fps very slow on high-res pcd
        query_pcd_xyz, fps_idx = pytorch3d.ops.sample_farthest_points(
            pcd,
            K=self.latent_num,
            # random_start_point=False)  # B self.latent_num
            random_start_point=True)  # B self.latent_num

        query_pcd_h = masked_gather(pcd_h,
                                    fps_idx)  # torch.gather with dim expansion

        # pcu.save_mesh_v(f'xyz.ply', xyz[0].float().detach().permute(1,2,0).reshape(-1,3).cpu().numpy(),)  # check result first, before fps sampling
        # pcu.save_mesh_v(f'fps_xyz.ply', fps_xyz[0].float().detach().reshape(-1,3).cpu().numpy(),)  # check result first, before fps sampling
        # pcu.save_mesh_v(f'query_pcd_xyz.ply', query_pcd_xyz[0].float().detach().reshape(-1,3).cpu().numpy(),)
        # pcu.save_mesh_v(f'pcd_xyz.ply', pcd[0].float().detach().reshape(-1,3).cpu().numpy(),)
        # xyz = self.spatial_token_reshape(xyz)
        # pcu.save_mesh_v(f'xyz_new.ply', xyz[0].float().detach().reshape(-1,3).cpu().numpy(),)

        return query_pcd_h, query_pcd_xyz
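    # FPS sketch (illustrative; K = latent_num = 768):
    #
    #   pcd   = torch.rand(2, 4096, 3)                    # B N 3 surface points
    #   pcd_h = torch.rand(2, 4096, 256)                  # B N C point features
    #   xyz, idx = pytorch3d.ops.sample_farthest_points(
    #       pcd, K=768, random_start_point=True)          # [2, 768, 3], [2, 768]
    #   feat = masked_gather(pcd_h, idx)                  # [2, 768, 256]
    #
    # random_start_point=True makes the sampled subset (and hence the latent
    # anchors) vary across epochs for the same instance, acting as light
    # augmentation.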
    def forward(self, x, pcd, **kwargs):
        # def forward(self, x, num_frames=None):
        assert x.shape[1] == 15  # rgb(3), normal(3), plucker_ray(6), xyz(3)
        xyz = x[:, -3:, ...]  # for fps downsampling

        # 0. retrieve VAE tokens
        h = super().forward(
            x, num_frames=self.num_frames
        )  # ! support data augmentation: different FPS -> different latent for the same instance?
        # st()
        # pcu.save_mesh_v(f'{Path(logger.get_dir())}/anchor_all.ply', pcd[0].float().detach().cpu().numpy())

        # ! add 3D PE.
        # 1. unproject 2D tokens to 3D
        token_xyz = xyz[..., 4::8, 4::8]

        if self.pixel_aligned_query:
            # h = self.spatial_token_reshape(h)  # V frames merge to a single latent here.
            # h = h + self.xyz_pos_embed(token_xyz)  # directly add PE to h here.
            # # ! PE over surface fps-pcd
            # pcd_h = self.xyz_pos_embed(pcd)  # directly add PE to h here.

            # 2. fps sampling surface as pcd-structured latent.
            h, query_pcd_xyz = self._process_token_xyz(
                pcd, token_xyz, h, c=kwargs.get('c'),
                x=x)  # aggregate with pixel-aligned operation.

        elif self.pc2:
            # rasterize the point cloud to multi-view feature maps
            # https://github.com/lukemelas/projection-conditioned-point-cloud-diffusion/blob/64fd55a0d00b52735cf02e11c5112374c7104ece/experiments/model/projection_model.py#L128

            # ! prepare the features before projection
            token_xyz = self.spatial_token_reshape(token_xyz)
            h = self.spatial_token_reshape(
                h)  # V frames merge to a single latent here.
            h = h + self.xyz_pos_embed(token_xyz)  # directly add PE to h here. h: B L C

            # ! prepare pytorch3d camera
            c = kwargs['c']  # gs_format dict
            focal_length = c['orig_pose'][..., 16:17]  # B V 1
            img_h, img_w = x.shape[-2:]
            R, T = c['R'], c['T']  # B V 3 3, B V 3

            # ! bs=1 test. will merge B, V later for parallel compute.
            V = focal_length.shape[1]
            principal_point = torch.zeros(V, 2)
            img_size = torch.Tensor([img_h,
                                     img_w]).unsqueeze(0).repeat_interleave(
                                         V, 0).to(focal_length)
            camera = PerspectiveCameras(focal_length=focal_length[0],
                                        principal_point=principal_point,
                                        R=R[0],
                                        T=T[0],
                                        image_size=img_size)
            # camera = PerspectiveCameras(focal_length=focal_length, R=R, T=T, image_size=(img_h, img_w))

            # ! create rasterizer
            rasterizer = PointsRasterizer(
                cameras=camera.to(pcd.device),
                raster_settings=self.raster_settings)
            fragments = rasterizer(
                Pointclouds(pcd[0:1].repeat_interleave(V, 0)))  # (B, H, W, R)
            fragments_idx: torch.Tensor = fragments.idx.long()
            visible_pixels = (fragments_idx > -1)  # (B, H, W, R)
            view_idx = 0  # Index of the viewpoint

            # (Pdb) fragments.zbuf.shape
            # torch.Size([8, 512, 512, 1])
            # depth_image = fragments.zbuf[0, ..., 0].cpu().numpy()  # Take the nearest point's depth
            # depth_image = (depth_image - depth_image.min()) / (depth_image.max() - depth_image.min())
            # imageio.imwrite('tmp/depth.jpg', (depth_image * 255.0).astype(np.uint8))
            # st()

            points_to_visible_pixels = fragments_idx[visible_pixels]

            # ! visualize the results, for debug
            normal = x[:, 3:6, ...]
            normal_map = (normal * 127.5 + 127.5).float().to(
                torch.uint8)  # BV 3 H W

            st()
            pass

        else:
            token_xyz = self.spatial_token_reshape(token_xyz)
            h = self.spatial_token_reshape(
                h)  # V frames merge to a single latent here.
            h = h + self.xyz_pos_embed(token_xyz)  # directly add PE to h here.

            # ! PE over surface fps-pcd
            pcd_h = self.xyz_pos_embed(pcd)  # directly add PE to h here.

            # 2. fps sampling surface as pcd-structured latent.
            query_pcd_h, query_pcd_xyz = self._process_token_xyz(pcd, pcd_h)

            # 2.5 cross attention to aggregate from all tokens.
            if self.ca_no_pcd:
                h = self.agg_ca(query_pcd_h, h)
            else:
                h = self.agg_ca(
                    query_pcd_h, torch.cat([h, pcd_h], dim=1)
                )  # cross attend to aggregate info from both vae-h and pcd-h

        # 3. add vit TX (5 layers, concat xyz-PE)
        # h = h + self.xyz_pos_embed(fps_xyz)  # TODO, add PE of query pts. directly add to h here.
        h = self.srt.transformer(h)  # B L C
        h = self.Mlp_out(h)  # equivalent to conv_out, 256 -> 8 in sd-VAE
        # h = einops.rearrange(h, 'B L C -> B C L')  # for VAE compat

        return {
            'h': h,
            'query_pcd_xyz': query_pcd_xyz
        }  # h_0, point cloud-structured latent space. For VAE later.
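# Forward-pass sketch for HybridEncoderPCDStructuredLatent (illustrative only;
# `vae_kwargs` stands in for the SD-VAE Encoder config (ch, ch_mult,
# z_channels, ...) and `camera_dict` for the gs_format dict consumed by the
# pc2 / pixel-aligned branches; both names are placeholders):
#
#   enc = HybridEncoderPCDStructuredLatent(num_frames=4, latent_num=768,
#                                          **vae_kwargs)
#   x   = torch.randn(B * 4, 15, H, W)     # per-view [rgb|normal|plucker|xyz]
#   pcd = torch.rand(B, N_points, 3)       # per-instance surface points
#   out = enc(x, pcd, c=camera_dict)
#   out['h']              # [B, latent_num, 2 * z_channels] KL parameters
#   out['query_pcd_xyz']  # [B, latent_num, 3] latent anchor positions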
class HybridEncoderPCDStructuredLatentUniformFPS(
        HybridEncoderPCDStructuredLatent):

    def __init__(self, num_frames, latent_num=768, **kwargs):
        super().__init__(num_frames, latent_num, **kwargs)
        self.ca_no_pcd = True  # check speed up ratio

    def _process_token_xyz(self, pcd, pcd_h):
        # ! 16x uniform downsample before FPS.
        rand_start_pt = random.randint(0, 15)  # random phase for the 16x stride
        # rand_start_pt = 0

        query_pcd_xyz, fps_idx = pytorch3d.ops.sample_farthest_points(
            # pcd[:, rand_start_pt::16], K=self.latent_num, random_start_point=False)  # B self.latent_num
            pcd[:, rand_start_pt::16],
            K=self.latent_num,
            random_start_point=True)  # B self.latent_num

        query_pcd_h = masked_gather(pcd_h[:, rand_start_pt::16],
                                    fps_idx)  # torch.gather with dim expansion
        # st()

        # ! fps very slow on high-res pcd
        # query_pcd_xyz, fps_idx = pytorch3d.ops.sample_farthest_points(
        #     pcd, K=self.latent_num, random_start_point=True)  # B self.latent_num
        # query_pcd_h = masked_gather(pcd_h, fps_idx)  # torch.gather with dim expansion

        # pcu.save_mesh_v(f'xyz.ply', xyz[0].float().detach().permute(1,2,0).reshape(-1,3).cpu().numpy(),)  # check result first, before fps sampling
        # pcu.save_mesh_v(f'fps_xyz.ply', fps_xyz[0].float().detach().reshape(-1,3).cpu().numpy(),)  # check result first, before fps sampling
        # pcu.save_mesh_v(f'query_pcd_xyz.ply', query_pcd_xyz[0].float().detach().reshape(-1,3).cpu().numpy(),)
        # pcu.save_mesh_v(f'pcd_xyz.ply', pcd[0].float().detach().reshape(-1,3).cpu().numpy(),)
        # xyz = self.spatial_token_reshape(xyz)
        # pcu.save_mesh_v(f'xyz_new.ply', xyz[0].float().detach().reshape(-1,3).cpu().numpy(),)

        return query_pcd_h, query_pcd_xyz


class HybridEncoderPCDStructuredLatentSNoPCD(HybridEncoderPCDStructuredLatent):

    def __init__(self, num_frames, latent_num=768, **kwargs):
        super().__init__(num_frames, latent_num, **kwargs)
        self.ca_no_pcd = True


class HybridEncoderPCDStructuredLatentSNoPCD_PC2(
        HybridEncoderPCDStructuredLatentSNoPCD):

    def __init__(self, num_frames, latent_num=768, **kwargs):
        super().__init__(num_frames, latent_num, **kwargs)
        self.pc2 = True
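# Cost sketch for the UniformFPS variant above (illustrative): iterative
# farthest-point sampling is roughly O(N * K), so pre-striding a dense cloud by
# 16 before sampling K = latent_num anchors cuts the FPS work by about 16x; the
# random phase keeps the strided subset (and thus the anchors) stochastic
# across iterations.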
class HybridEncoderPCDStructuredLatentSNoPCD_PixelAlignedQuery(
        HybridEncoderPCDStructuredLatent):

    def __init__(self, num_frames, latent_num=768, **kwargs):
        super().__init__(num_frames, latent_num, **kwargs)
        self.ca_no_pcd = True
        self.pixel_aligned_query = True
        self.F = 4  # pixel-aligned query from nearest F views
        del self.agg_ca  # for average pooling now.

    def _pcd_to_homo(self, pcd):
        return torch.cat([pcd, torch.ones_like(pcd[..., 0:1])], -1)

    # ! FPS sampling
    def _process_token_xyz(self, pcd, token_xyz, h, c, x=None):
        V = c['cam_pos'].shape[1]
        # (Pdb) p c.keys()
        # dict_keys(['source_cv2wT_quat', 'cam_view', 'cam_view_proj', 'cam_pos', 'tanfov', 'orig_pose', 'orig_c2w', 'orig_w2c'])
        # (Pdb) p c['cam_view'].shape
        # torch.Size([8, 9, 4, 4])
        # (Pdb) p c['cam_pos'].shape
        # torch.Size([8, 9, 3])

        # ! 16x uniform downsample before FPS.
        # rand_start_pt = random.randint(0, 16)
        # query_pcd_xyz, fps_idx = pytorch3d.ops.sample_farthest_points(
        #     pcd[:, rand_start_pt::16], K=self.latent_num, random_start_point=True)  # B self.latent_num
        # query_pcd_h = masked_gather(pcd_h[:, rand_start_pt::16], fps_idx)  # torch.gather with dim expansion

        # ! fps very slow on high-res pcd, but better.
        query_pcd_xyz, fps_idx = pytorch3d.ops.sample_farthest_points(
            pcd, K=self.latent_num,
            random_start_point=True)  # B self.latent_num
        # query_pcd_h = masked_gather(pcd_h, fps_idx)  # torch.gather with dim expansion

        # ! use unprojected xyz for pixel-aligned projection check
        # query_pcd_xyz = self.spatial_token_reshape(token_xyz)

        B, N = query_pcd_xyz.shape[:2]

        normal = x[:, 3:6, ...]
        normal_map = (normal * 127.5 + 127.5).float().to(
            torch.uint8)  # BV 3 H W
        normal_map = einops.rearrange(normal_map,
                                      '(B V) C H W -> B V C H W',
                                      B=B,
                                      V=V).detach().cpu()  # B V C H W
        img_size = normal_map.shape[-1]

        # ! ====== single-view debug here
        for b in range(c['orig_w2c'].shape[0]):
            for v in range(c['orig_w2c'].shape[1]):
                selected_normal = normal_map[b, v]

                proj_point = c['orig_w2c'][b, v] @ self._pcd_to_homo(
                    query_pcd_xyz[b]).permute(1, 0)
                proj_point[:2, ...] /= proj_point[2, ...]
                proj_point[2, ...] = 1  # homo

                intrin = c['orig_intrin'][b, v]
                proj_point = intrin @ proj_point[:3]
                proj_point = proj_point.permute(1, 0)[..., :2]  # 768 2
                # st()
                # proj_point = c['cam_view_proj'][b, v] @ self._pcd_to_homo(query_pcd_xyz[b]).permute(1, 0)

                # plot proj_point and save
                for uv_idx in range(proj_point.shape[0]):
                    # uv = proj_point[uv_idx] * 127.5 + 127.5
                    uv = proj_point[uv_idx] * img_size
                    px, py = int(uv[0].clip(0, img_size)), int(
                        uv[1].clip(0, img_size))
                    selected_normal[:,
                                    max(py - 1, 0):min(py + 1, img_size),
                                    max(px - 1, 0):min(px + 1, img_size)] = \
                        torch.Tensor([255, 0, 0]).reshape(3, 1, 1).to(
                            selected_normal)  # set to red

                torchvision.utils.save_image(selected_normal.float(),
                                             f'tmp/pifu_normal_{b}_{v}.jpg',
                                             normalize=True,
                                             value_range=(0, 255))
                st()
                pass
        st()
        # ! ====== single-view debug done

        # ! project pcd to each view
        batched_query_pcd = einops.repeat(self._pcd_to_homo(query_pcd_xyz),
                                          'B N C -> (B V N) C 1',
                                          V=V)
        batched_cam_view_proj = einops.repeat(c['cam_view_proj'],
                                              'B V H W -> (B V N) H W',
                                              N=N)
        batched_proj_uv = einops.rearrange(
            (batched_cam_view_proj @ batched_query_pcd),
            '(B V N) L 1 -> (B V) L N',
            B=B,
            V=V,
            N=N)  # BV 4 N
        batched_proj_uv = batched_proj_uv[..., :2, :]  # BV 2 N

        # draw projected UV coordinates on the 2d normal map
        # idx_to_vis = 15 * 32 + 16  # middle of the img
        # idx_to_vis = 16 * 6 + 15 * 32 + 16  # middle of the img
        idx_to_vis = 0  # use fps points here
        # st()
        selected_proj_uv = einops.rearrange(batched_proj_uv,
                                            '(B V) C N -> B V C N',
                                            B=B,
                                            V=V,
                                            N=N)[0, ..., idx_to_vis]  # V 2 N -> V 2
        # selected_normal = einops.rearrange(normal_map,
        #                                    '(B V) C H W -> B V C H W',
        #                                    B=B,
        #                                    V=V)[0].detach().cpu()  # V C H W

        for uv_idx in range(selected_proj_uv.shape[0]):
            uv = selected_proj_uv[uv_idx] * 127.5 + 127.5
            px, py = int(uv[0].clip(0, 255)), int(uv[1].clip(0, 255))
            selected_normal[uv_idx, :,
                            max(py - 5, 0):min(py + 5, 255),
                            max(px - 5, 0):min(px + 5, 255)] = torch.Tensor(
                                [255, 0, 0]).reshape(3, 1, 1).to(
                                    selected_normal)  # set to red
            # selected_normal[uv_idx, :, max(py-5, 0):min(py+5, 255), max(px-5, 0):min(px+5, 255)] = torch.Tensor([255, 0, 0]).to(selected_normal)  # set to red
        # st()

        torchvision.utils.save_image(selected_normal.float(),
                                     'pifu_normal.jpg',
                                     normalize=True,
                                     value_range=(0, 255))
        st()
        pass

        # ! grid sample
        query_pcd_h = index(
            h, batched_proj_uv)  # h: (B V) C H W, uv: (B V) 2 N -> BV C N
        query_pcd_h_to_gather = einops.rearrange(query_pcd_h,
                                                 '(B V) C N -> B N V C',
                                                 B=B,
                                                 V=V,
                                                 N=N)

        # ! find nearest F views
        _, knn_idx, _ = pytorch3d.ops.knn_points(
            query_pcd_xyz, c['cam_pos'], K=self.F,
            return_nn=False)  # knn_idx: B N F
        knn_idx_expanded = knn_idx[..., None].expand(
            -1, -1, -1, query_pcd_h_to_gather.shape[-1])  # B N F -> B N F C
        knn_pcd_h = torch.gather(
            query_pcd_h_to_gather, dim=2,
            index=knn_idx_expanded)  # torch.Size([8, 768, 4, 256])

        # average pooling of the knn features.
        query_pcd_h = knn_pcd_h.mean(dim=2)

        # add PE
        pcd_h = self.xyz_pos_embed(query_pcd_xyz)  # pcd_h as PE feature.
        query_pcd_h = query_pcd_h + pcd_h

        # TODO: QKV aggregation with pcd_h as q, query_pcd_h as kv. Requires gather?
        '''not used; binary mask for aggregation.
        # * mask idx not used anymore. torch.gather() instead, more flexible.
        # knn_idx_mask = torch.zeros((B, N, V), device=knn_idx.device)
        # knn_idx_mask.scatter_(dim=2, index=knn_idx, src=torch.ones_like(knn_idx_mask))  # ! B N V

        # try gather
        # gather_idx = einops.rearrange(knn_idx_mask, 'B N V -> B N V 1').bool()
        # query_pcd_h = einops.rearrange(query_pcd_h, "(B V) C N -> B N V C", B=pcd_h.shape[0], N=self.latent_num, V=V)  # torch.Size([8, 768, 4, 256])

        # ! apply KNN mask and average the feature.
        # query_pcd_h = einops.reduce(query_pcd_h * knn_idx_mask.unsqueeze(-1), 'B N V C -> B N C', 'sum') / self.F  # B 768 256. average pooling aggregated feature, like in pifu.
        '''

        '''pixel-aligned projection, not efficient enough.
        knn_cam_view_proj = pytorch3d.ops.knn_gather(einops.rearrange(c['cam_view_proj'], 'B V H W -> B V (H W)'), knn_idx)  # get corresponding cam_view_projection matrix (P matrix)
        knn_cam_view_proj = einops.rearrange(knn_cam_view_proj, 'B N F (H W) -> (B N F) H W', H=4, W=4)  # for matmul. H=W=4 here, P matrix.
        batched_query_pcd = einops.repeat(self._pcd_to_homo(query_pcd_xyz), 'B N C -> (B N F) C 1', F=self.F)
        xyz = knn_cam_view_proj @ batched_query_pcd  # BNF 4 1
        # st()

        knn_spatial_feat = pytorch3d.ops.knn_gather(einops.rearrange(h, '(B V) C H W -> B V (C H W)', V=self.num_frames), knn_idx)  # get corresponding feat for grid_sample
        knn_spatial_feat = einops.rearrange(knn_spatial_feat, 'B N F (C H W) -> (B N F) C H W', C=h.shape[-3], H=h.shape[-2], W=h.shape[-1])
        '''

        # grid_sample
        # https://github.com/shunsukesaito/PIFu/blob/f0a9c99ef887e1eb360e865a87aa5f166231980e/lib/geometry.py#L15
        # average pooling of multi-view extracted information

        return query_pcd_h, query_pcd_xyz
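# Pixel-aligned sampling sketch (illustrative; mirrors the PIFu-style `index`
# helper referenced above, assuming uv is already in [-1, 1] NDC with shape
# [B, 2, N]):
#
#   def _index(feat, uv):                      # feat: [B, C, H, W]
#       uv = uv.transpose(1, 2).unsqueeze(2)   # -> [B, N, 1, 2]
#       samples = torch.nn.functional.grid_sample(
#           feat, uv, align_corners=True)      # -> [B, C, N, 1]
#       return samples[..., 0]                 # -> [B, C, N]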