diff --git a/README.md b/README.md
index e00d9e8fcbf0d4586c41615b3f50e9bbe30ef738..d9f5496b77b6e4ec5ed5ff225a6993eed88f7718 100644
--- a/README.md
+++ b/README.md
@@ -8,4 +8,75 @@ Repository of LASA: Instance Reconstruction from Real Scans using A Large-scale
![292080628-a4b020dc-2673-4b1b-bfa6-ec9422625624](https://github.com/GAP-LAB-CUHK-SZ/LASA/assets/40767265/7a0dfc11-5454-428f-bfba-e8cd0d0af96e)
![292080638-324bbef9-c93b-4d96-b814-120204374383](https://github.com/GAP-LAB-CUHK-SZ/LASA/assets/40767265/ee07691a-8767-4701-9a32-19a70e0e240a)
-#### Codes and dataset will be released soon!
+## Dataset
+Complete raw data will be released soon.
+
+## Download and process the preprocessed data
+Download the preprocessed data from
+BaiduYun (code: 62ux). (These data will be updated as the cleaning process continues.) Put all the downloaded data under LASA, then unzip align_mat_all.zip manually.
+You can use the script ./process_scripts/unzip_all_data.py to unzip all the data in occ_data and other_data with the following commands:
+```bash
+cd process_scripts
+python unzip_all_data.py --unzip_occ --unzip_other
+```
+Run the following commands to generate augmented partial point clouds for the synthetic and LASA datasets:
+```bash
+cd process_scripts
+python augment_arkit_partial_point.py --cat arkit_chair arkit_stool ...
+python augment_synthetic_partial_point.py --cat 03001627 future_chair ABO_chair ...
+```
+Run the following command to extract image features:
+```bash
+cd process_scripts
+bash dist_extract_vit.sh
+```
+Finally, run the following command to generate train/val splits:
+```bash
+cd process_scripts
+python generate_split_for_arkit.py --cat arkit_chair arkit_stool ...
+python generate_split_for_synthetic_data.py --cat 03001627 future_chair ABO_chair ...
+```
+
+## Evaluation
+Download the pretrained weights for the chair category from chair_checkpoint (code: hlf9).
+Put these folders under LASA/output.
+The ae folder stores the VAE weights, the dm folder stores the diffusion model trained on synthetic data, and the
+finetune_dm folder stores the diffusion model finetuned on the LASA dataset.
+Run the following commands to evaluate and extract the mesh:
+```bash
+cd evaluation
+bash dist_eval.sh
+```
+The category entries are the sub-categories of ARKitScenes; see ./datasets/taxonomy.py for how they are defined.
+For example, to evaluate on LASA's chair category, --category should contain both arkit_chair and arkit_stool.
+Make sure the --ae-pth and --dm-pth entries point to the correct checkpoint paths. If you are evaluating on LASA,
+make sure --dm-pth points to the finetuned weights in the ./output/finetune_dm folder. The results will be saved
+under ./output_result.
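+For reference, a chair evaluation command in the spirit of evaluation/dist_eval.sh looks like this (GPU count and checkpoint paths are examples, adjust them to your setup):
+```bash
+cd evaluation
+CUDA_VISIBLE_DEVICES='0,1' torchrun --master_port 15002 --nproc_per_node=2 \
+evaluate_object_reconstruction.py \
+--configs ../configs/finetune_triplane_diffusion.yaml \
+--category arkit_chair arkit_stool \
+--ae-pth ../output/ae/chair/best-checkpoint.pth \
+--dm-pth ../output/finetune_dm/lowres_chair/best-checkpoint.pth \
+--output_folder ../output_result/chair_result \
+--data-pth ../data \
+--eval_cd --reso 256 --save_mesh --save_par_points
+```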
+
+## Training
+Run train_VAE.sh to train the VAE model. To train on a single category, specify one of chair,
+cabinet, table, sofa, bed, or shelf; passing all trains on all categories. Make sure to download and preprocess all
+the required sub-category data. The sub-category arrangement can be found in ./datasets/taxonomy.py.
+After the VAE model finishes training, run the following commands to pre-extract the VAE features for every object:
+```bash
+cd process_scripts
+bash dist_export_triplane_features.sh
+```
+Then, train the diffusion model on the synthetic dataset by running train_diffusion.sh.
+Finally, finetune the diffusion model on the LASA dataset by running finetune_diffusion.sh (see the example below).
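+For reference, finetune_diffusion.sh launches the finetuning stage roughly as follows; the synthetic-stage diffusion weights are passed via --finetune-pth and the VAE weights via --ae-pth (paths are examples):
+```bash
+cd scripts
+CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' torchrun --master_port 15003 --nproc_per_node=8 \
+train_triplane_diffusion.py \
+--configs ../configs/finetune_triplane_diffusion.yaml \
+--category chair \
+--ae-pth ../output/ae/chair/best-checkpoint.pth \
+--finetune --finetune-pth ../output/dm/chair/best-checkpoint.pth \
+--output_dir ../output/finetune_dm/lowres_chair \
+--data-pth ../data --batch_size 22 --epochs 500 --replica 5
+```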
+
+Early stopping is used: training is stopped manually at around 150 epochs for the VAE model and 500 epochs for the diffusion model.
+All experiments in the paper are conducted on 8 A100 GPUs with a batch size of 22.
+
+## TODO
+
+- [ ] Object Detection Code
+- [ ] Code for demo on both ARKitScenes and in-the-wild data
+
+## Citation
+```
+@article{liu2023lasa,
+ title={LASA: Instance Reconstruction from Real Scans using A Large-scale Aligned Shape Annotation Dataset},
+ author={Liu, Haolin and Ye, Chongjie and Nie, Yinyu and He, Yingfan and Han, Xiaoguang},
+ journal={arXiv preprint arXiv:2312.12418},
+ year={2023}
+}
+```
\ No newline at end of file
diff --git a/configs/config_utils.py b/configs/config_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2371e31a713d06c4982acb86b69310c5d28c47b7
--- /dev/null
+++ b/configs/config_utils.py
@@ -0,0 +1,70 @@
+import os
+import yaml
+import logging
+from datetime import datetime
+
+def update_recursive(dict1, dict2):
+ ''' Update two config dictionaries recursively.
+
+ Args:
+ dict1 (dict): first dictionary to be updated
+ dict2 (dict): second dictionary which entries should be used
+
+ '''
+ for k, v in dict2.items():
+ if k not in dict1:
+ dict1[k] = dict()
+ if isinstance(v, dict):
+ update_recursive(dict1[k], v)
+ else:
+ dict1[k] = v
+
+class CONFIG(object):
+ '''
+ Stores all configures
+ '''
+ def __init__(self, input=None):
+ '''
+ Loads config file
+ :param path (str): path to config file
+ :return:
+ '''
+ self.config = self.read_to_dict(input)
+
+ def read_to_dict(self, input):
+ if not input:
+ return dict()
+ if isinstance(input, str) and os.path.isfile(input):
+ if input.endswith('yaml'):
+ with open(input, 'r') as f:
+ config = yaml.load(f, Loader=yaml.FullLoader)
+            else:
+                raise ValueError('Config file should be with the format of *.yaml')
+ elif isinstance(input, dict):
+ config = input
+ else:
+ raise ValueError('Unrecognized input type (i.e. not *.yaml file nor dict).')
+
+ return config
+
+ def update_config(self, *args, **kwargs):
+ '''
+ update config and corresponding logger setting
+ :param input: dict settings add to config file
+ :return:
+ '''
+ cfg1 = dict()
+ for item in args:
+ cfg1.update(self.read_to_dict(item))
+
+ cfg2 = self.read_to_dict(kwargs)
+
+ new_cfg = {**cfg1, **cfg2}
+
+ update_recursive(self.config, new_cfg)
+ # when update config file, the corresponding logger should also be updated.
+ self.__update_logger()
+
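+    def __update_logger(self):
+        '''
+        minimal placeholder so that update_config() has a logger hook to call;
+        it assumes an optional logging/level entry in the config and can be replaced
+        by the project's real logger setup.
+        '''
+        level = str(self.config.get('logging', {}).get('level', 'INFO')).upper()
+        logging.basicConfig(level=getattr(logging, level, logging.INFO))
+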
+ def write_config(self,save_path):
+ with open(save_path, 'w') as file:
+ yaml.dump(self.config, file, default_flow_style = False)
\ No newline at end of file
diff --git a/configs/finetune_triplane_diffusion.yaml b/configs/finetune_triplane_diffusion.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7e8ae8986211f68cde8fe1f6702a18df1b871ce7
--- /dev/null
+++ b/configs/finetune_triplane_diffusion.yaml
@@ -0,0 +1,67 @@
+model:
+ ae: #ae model is loaded to
+ type: TriVAE
+ point_emb_dim: 48
+ padding: 0.1
+ encoder:
+ plane_reso: 128
+ plane_latent_dim: 32
+ latent_dim: 32
+ unet:
+ depth: 4
+ merge_mode: concat
+ start_filts: 32
+ output_dim: 64
+ decoder:
+ plane_reso: 128
+ latent_dim: 32
+ n_blocks: 5
+ query_emb_dim: 48
+ hidden_dim: 128
+ unet:
+ depth: 4
+ merge_mode: concat
+ start_filts: 64
+ output_dim: 32
+ dm:
+ type: triplane_diff_multiimg_cond
+ backbone: resunet_multiimg_direct_atten
+ diff_reso: 64
+ input_channel: 32
+ output_channel: 32
+ triplane_padding: 0.1 #should be consistent with padding in ae
+
+ use_par: True
+ par_channel: 32
+ par_emb_dim: 48
+ norm: "batch"
+ img_in_channels: 1280
+ vit_reso: 16
+ use_cat_embedding: ???
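+    # use_cat_embedding is filled in at runtime by the training/evaluation scripts (True only when the category is "all")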
+ block_type: multiview_local
+ par_point_encoder:
+ plane_reso: 64
+ plane_latent_dim: 32
+ n_blocks: 5
+ unet:
+ depth: 3
+ merge_mode: concat
+ start_filts: 32
+ output_dim: 32
+criterion:
+ type: EDMLoss_MultiImgCond
+ use_par: True
+dataset:
+ type: Occ_Par_MultiImg_Finetune
+ data_path: ???
+ surface_size: 20000
+ par_pc_size: 2048
+ load_proj_mat: True
+ load_image: True
+ par_point_aug: 0.5
+ par_prefix: "aug7_"
+ keyword: lowres #use lowres arkitscene or highres to train, lowres scene is more user accessible
+ jitter_partial_pretrain: 0.02
+ jitter_partial_finetune: 0.02
+ jitter_partial_val: 0.0
+ use_pretrain_data: False
diff --git a/configs/train_triplane_diffusion.yaml b/configs/train_triplane_diffusion.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6d8d5e87a974cde3fe2337d2eeabd34b681f28f6
--- /dev/null
+++ b/configs/train_triplane_diffusion.yaml
@@ -0,0 +1,64 @@
+model:
+ ae: #ae model is loaded to
+ type: TriVAE
+ point_emb_dim: 48
+ padding: 0.1
+ encoder:
+ plane_reso: 128
+ plane_latent_dim: 32
+ latent_dim: 32
+ unet:
+ depth: 4
+ merge_mode: concat
+ start_filts: 32
+ output_dim: 64
+ decoder:
+ plane_reso: 128
+ latent_dim: 32
+ n_blocks: 5
+ query_emb_dim: 48
+ hidden_dim: 128
+ unet:
+ depth: 4
+ merge_mode: concat
+ start_filts: 64
+ output_dim: 32
+ dm:
+ type: triplane_diff_multiimg_cond
+ backbone: resunet_multiimg_direct_atten
+ diff_reso: 64
+ input_channel: 32
+ output_channel: 32
+ triplane_padding: 0.1 #should be consistent with padding in ae
+
+ use_par: True
+ par_channel: 32
+ par_emb_dim: 48
+ norm: "batch"
+ img_in_channels: 1280
+ vit_reso: 16
+ use_cat_embedding: ???
+ block_type: multiview_local
+ par_point_encoder:
+ plane_reso: 64
+ plane_latent_dim: 32
+ n_blocks: 5
+ unet:
+ depth: 3
+ merge_mode: concat
+ start_filts: 32
+ output_dim: 32
+criterion:
+ type: EDMLoss_MultiImgCond
+ use_par: True
+dataset:
+ type: Occ_Par_MultiImg
+ data_path: ???
+ surface_size: 20000
+ par_pc_size: 2048
+ load_proj_mat: True
+ load_image: True
+ par_point_aug: 0.5
+ par_prefix: "aug7_" # prefix of the filenames of the partial point cloud
+ jitter_partial_train: 0.02
+ jitter_partial_val: 0.0
diff --git a/configs/train_triplane_vae.yaml b/configs/train_triplane_vae.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..34b8757af33c770d8ccd2bd5837040c64404fe63
--- /dev/null
+++ b/configs/train_triplane_vae.yaml
@@ -0,0 +1,30 @@
+model:
+ type: TriVAE
+ point_emb_dim: 48
+ padding: 0.1
+ encoder:
+ plane_reso: 128
+ plane_latent_dim: 32
+ latent_dim: 32
+ unet:
+ depth: 4
+ merge_mode: concat
+ start_filts: 32
+ output_dim: 64
+ decoder:
+ plane_reso: 128
+ latent_dim: 32
+ n_blocks: 5
+ query_emb_dim: 48
+ hidden_dim: 128
+ unet:
+ depth: 4
+ merge_mode: concat
+ start_filts: 64
+ output_dim: 32
+dataset:
+ type: Occ
+ category: chair
+ data_path: ???
+ surface_size: 20000
+ num_samples: 2048
\ No newline at end of file
diff --git a/data/download_preprocess_data_here b/data/download_preprocess_data_here
new file mode 100644
index 0000000000000000000000000000000000000000..56a6051ca2b02b04ef92d5150c9ef600403cb1de
--- /dev/null
+++ b/data/download_preprocess_data_here
@@ -0,0 +1 @@
+1
\ No newline at end of file
diff --git a/datasets/SingleView_dataset.py b/datasets/SingleView_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f8346b7d01c90974a8be8e1b120ff96a71695a4
--- /dev/null
+++ b/datasets/SingleView_dataset.py
@@ -0,0 +1,453 @@
+import os
+import glob
+import random
+
+import yaml
+
+import torch
+from torch.utils import data
+
+import numpy as np
+import json
+
+from PIL import Image
+
+import h5py
+import torch.distributed as dist
+import open3d as o3d
+o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Error)
+import pickle as p
+import time
+import cv2
+from torchvision import transforms
+import copy
+from datasets.taxonomy import category_map_from_synthetic as category_ids
+class Object_Occ(data.Dataset):
+ def __init__(self, dataset_folder, split, categories=['03001627', "future_chair", 'ABO_chair'], transform=None,
+ sampling=True,
+ num_samples=4096, return_surface=True, surface_sampling=True, surface_size=2048, replica=16):
+
+ self.pc_size = surface_size
+
+ self.transform = transform
+ self.num_samples = num_samples
+ self.sampling = sampling
+ self.split = split
+
+ self.dataset_folder = dataset_folder
+ self.return_surface = return_surface
+ self.surface_sampling = surface_sampling
+
+ self.dataset_folder = dataset_folder
+ self.point_folder = os.path.join(self.dataset_folder, 'occ_data')
+ self.mesh_folder = os.path.join(self.dataset_folder, 'other_data')
+
+ if categories is None:
+ categories = os.listdir(self.point_folder)
+ categories = [c for c in categories if
+ os.path.isdir(os.path.join(self.point_folder, c)) and c.startswith('0')]
+ categories.sort()
+
+ print(categories)
+
+ self.models = []
+ for c_idx, c in enumerate(categories):
+ subpath = os.path.join(self.point_folder, c)
+ print(subpath)
+ assert os.path.isdir(subpath)
+
+ split_file = os.path.join(subpath, split + '.lst')
+ with open(split_file, 'r') as f:
+ models_c = f.readlines()
+ models_c = [item.rstrip('\n') for item in models_c]
+
+ for m in models_c[:]:
+ if len(m)<=1:
+ continue
+ if m.endswith('.npz'):
+ model_id = m[:-4]
+ else:
+ model_id = m
+ self.models.append({
+ 'category': c, 'model': model_id
+ })
+ self.replica = replica
+
+ def __getitem__(self, idx):
+ if self.replica >= 1:
+ idx = idx % len(self.models)
+ else:
+ random_segment = random.randint(0, int(1 / self.replica) - 1)
+ idx = int(random_segment * self.replica * len(self.models) + idx)
+
+ category = self.models[idx]['category']
+ model = self.models[idx]['model']
+
+ point_path = os.path.join(self.point_folder, category, model + '.npz')
+ # print(point_path)
+ try:
+ start_t = time.time()
+ with np.load(point_path) as data:
+ vol_points = data['vol_points']
+ vol_label = data['vol_label']
+ near_points = data['near_points']
+ near_label = data['near_label']
+ end_t = time.time()
+ # print("loading time %f"%(end_t-start_t))
+ except Exception as e:
+ print(e)
+ print(point_path)
+
+ with open(point_path.replace('.npz', '.npy'), 'rb') as f:
+ scale = np.load(f).item()
+ # scale=1.0
+
+ if self.return_surface:
+ pc_path = os.path.join(self.mesh_folder, category, '4_pointcloud', model + '.npz')
+ with np.load(pc_path) as data:
+ try:
+ surface = data['points'].astype(np.float32)
+ except:
+ print(pc_path,"has problems")
+ raise AttributeError
+ surface = surface * scale
+ if self.surface_sampling:
+ ind = np.random.default_rng().choice(surface.shape[0], self.pc_size, replace=False)
+ surface = surface[ind]
+ surface = torch.from_numpy(surface)
+
+ if self.sampling:
+ '''need to conduct label balancing'''
+ vol_ind=np.random.default_rng().choice(vol_points.shape[0], self.num_samples,
+                                                    replace=(vol_points.shape[0]<self.num_samples))
+        if self.replica >= 1:
+ idx = idx % len(self.models)
+ else:
+ random_segment = random.randint(0, int(1 / self.replica) - 1)
+ idx = int(random_segment * self.replica * len(self.models) + idx)
+ category = self.models[idx]['category']
+ model = self.models[idx]['model']
+ #image_filenames = self.model_images_names[model]
+ image_filenames = self.models[idx]["image_filenames"]
+ if self.split=="train":
+ n_frames = np.random.randint(min(2,len(image_filenames)), min(len(image_filenames) + 1, self.max_img_length + 1))
+ img_indexes = np.random.choice(len(image_filenames), n_frames,
+ replace=(n_frames > len(image_filenames))).tolist()
+ else:
+ if self.eval_multiview:
+ '''use all images'''
+ n_frames=len(image_filenames)
+ img_indexes=[i for i in range(n_frames)]
+ else:
+ n_frames = min(len(image_filenames),self.max_img_length)
+ img_indexes=np.linspace(start=0,stop=len(image_filenames)-1,num=n_frames).astype(np.int32)
+
+ partial_filenames = self.models[idx]['partial_filenames']
+ par_index = np.random.choice(len(partial_filenames), 1)[0]
+ partial_name = partial_filenames[par_index]
+
+ vol_points,vol_label,near_points,near_label=None,None,None,None
+ points,labels=None,None
+ point_path = os.path.join(self.point_folder, category, model + '.npz')
+ if self.ret_sample:
+ vol_points,vol_label,near_points,near_label=self.load_samples(point_path)
+ points,labels = self.process_samples(vol_points, vol_label, near_points,near_label)
+
+ with open(point_path.replace('.npz', '.npy'), 'rb') as f:
+ scale = np.load(f).item()
+
+ surface=None
+ pc_path = os.path.join(self.mesh_folder, category, '4_pointcloud', model + '.npz')
+ if self.return_surface:
+ surface=self.load_surface(pc_path,scale)
+
+ partial_path = os.path.join(self.mesh_folder, category, "5_partial_points", model, partial_name)
+        if self.par_point_aug is not None and random.random() < self.par_point_aug:
+            pred[output >= 0.0] = 1
+ label=labels[j:j+1]
+
+ accuracy = (pred == label).float().sum(dim=1) / label.shape[1]
+ accuracy = accuracy.mean()
+ intersection = (pred * label).sum(dim=1)
+ union = (pred + label).gt(0).sum(dim=1)
+        iou = intersection * 1.0 / (union + 1e-5)
+ iou = iou.mean()
+
+ metric_logger.update(iou=iou.item())
+ metric_logger.update(accuracy=accuracy.item())
+ metric_logger.update(loss=loss.item())
+ metric_logger.synchronize_between_processes()
+ print('* iou {ious.global_avg:.3f}'
+ .format(ious=metric_logger.iou))
+ print('* accuracy {accuracies.global_avg:.3f}'
+ .format(accuracies=metric_logger.accuracy))
+ print('* loss {losses.global_avg:.3f}'
+ .format(losses=metric_logger.loss))
+
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
\ No newline at end of file
diff --git a/engine/engine_triplane_vae.py b/engine/engine_triplane_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..765aad41829f145bc701e1f5c4fa09a6c32b0d63
--- /dev/null
+++ b/engine/engine_triplane_vae.py
@@ -0,0 +1,185 @@
+# --------------------------------------------------------
+# References:
+# MAE: https://github.com/facebookresearch/mae
+# DeiT: https://github.com/facebookresearch/deit
+# BEiT: https://github.com/microsoft/unilm/tree/master/beit
+# --------------------------------------------------------
+
+import math
+import sys
+sys.path.append("..")
+from typing import Iterable
+
+import torch
+import torch.nn.functional as F
+
+import util.misc as misc
+import util.lr_sched as lr_sched
+
+
+def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
+ data_loader: Iterable, optimizer: torch.optim.Optimizer,
+ device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
+ log_writer=None, args=None):
+ model.train(True)
+ metric_logger = misc.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ header = 'Epoch: [{}]'.format(epoch)
+ print_freq = 20
+
+ accum_iter = args.accum_iter
+
+ optimizer.zero_grad()
+
+    kl_weight = 25e-3  # TODO: tune this; originally 1e-3. A larger KL weight eases diffusion training but degrades VAE reconstruction quality.
+
+ if log_writer is not None:
+ print('log_dir: {}'.format(log_writer.log_dir))
+
+ for data_iter_step, data_batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+
+ # we use a per iteration (instead of per epoch) lr scheduler
+ if data_iter_step % accum_iter == 0:
+ lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
+
+ points = data_batch['points'].to(device, non_blocking=True)
+ labels = data_batch['labels'].to(device, non_blocking=True)
+ surface = data_batch['surface'].to(device, non_blocking=True)
+ # print(points.shape)
+ with torch.cuda.amp.autocast(enabled=False):
+ outputs = model(surface, points)
+ if 'kl' in outputs:
+ loss_kl = outputs['kl']
+ #print(loss_kl.shape)
+ loss_kl = torch.sum(loss_kl) / loss_kl.shape[0]
+ else:
+ loss_kl = None
+
+ outputs = outputs['logits']
+
+ num_samples=outputs.shape[1]//2
+ #print(num_samples)
+ loss_vol = criterion(outputs[:, :num_samples], labels[:, :num_samples])
+ loss_near = criterion(outputs[:, num_samples:], labels[:, num_samples:])
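+            # the first half of the queries are uniform volume samples, the second half near-surface samples;
+            # the near-surface term is down-weighted by 0.1 and the KL term (when present) by kl_weight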
+
+ if loss_kl is not None:
+ loss = loss_vol + 0.1 * loss_near + kl_weight * loss_kl
+ else:
+ loss = loss_vol + 0.1 * loss_near
+
+ loss_value = loss.item()
+
+ threshold = 0
+
+ pred = torch.zeros_like(outputs[:, :num_samples])
+ pred[outputs[:, :num_samples] >= threshold] = 1
+
+ accuracy = (pred == labels[:, :num_samples]).float().sum(dim=1) / labels[:, :num_samples].shape[1]
+ accuracy = accuracy.mean()
+ intersection = (pred * labels[:, :num_samples]).sum(dim=1)
+ union = (pred + labels[:, :num_samples]).gt(0).sum(dim=1) + 1e-5
+ iou = intersection * 1.0 / union
+ iou = iou.mean()
+
+ if not math.isfinite(loss_value):
+ print("Loss is {}, stopping training".format(loss_value))
+ sys.exit(1)
+
+ loss /= accum_iter
+ loss_scaler(loss, optimizer, clip_grad=max_norm,
+ parameters=model.parameters(), create_graph=False,
+ update_grad=(data_iter_step + 1) % accum_iter == 0)
+ if (data_iter_step + 1) % accum_iter == 0:
+ optimizer.zero_grad()
+
+ torch.cuda.synchronize()
+
+ metric_logger.update(loss=loss_value)
+
+ metric_logger.update(loss_vol=loss_vol.item())
+ metric_logger.update(loss_near=loss_near.item())
+
+ if loss_kl is not None:
+ metric_logger.update(loss_kl=loss_kl.item())
+
+ metric_logger.update(iou=iou.item())
+
+ min_lr = 10.
+ max_lr = 0.
+ for group in optimizer.param_groups:
+ min_lr = min(min_lr, group["lr"])
+ max_lr = max(max_lr, group["lr"])
+
+ metric_logger.update(lr=max_lr)
+
+ loss_value_reduce = misc.all_reduce_mean(loss_value)
+ iou_reduce=misc.all_reduce_mean(iou)
+ if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
+ """ We use epoch_1000x as the x-axis in tensorboard.
+ This calibrates different curves when batch size changes.
+ """
+ epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
+ log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
+ log_writer.add_scalar('iou', iou_reduce, epoch_1000x)
+ log_writer.add_scalar('lr', max_lr, epoch_1000x)
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger)
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+@torch.no_grad()
+def evaluate(data_loader, model, device):
+ criterion = torch.nn.BCEWithLogitsLoss()
+
+ metric_logger = misc.MetricLogger(delimiter=" ")
+ header = 'Test:'
+
+ # switch to evaluation mode
+ model.eval()
+
+ for data_batch in metric_logger.log_every(data_loader, 50, header):
+
+ points = data_batch['points'].to(device, non_blocking=True)
+ labels = data_batch['labels'].to(device, non_blocking=True)
+ surface = data_batch['surface'].to(device, non_blocking=True)
+ # compute output
+ with torch.cuda.amp.autocast(enabled=False):
+
+ outputs = model(surface, points)
+ if 'kl' in outputs:
+ loss_kl = outputs['kl']
+ loss_kl = torch.sum(loss_kl) / loss_kl.shape[0]
+ else:
+ loss_kl = None
+
+ outputs = outputs['logits']
+
+ loss = criterion(outputs, labels)
+
+ threshold = 0
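+        # with BCEWithLogitsLoss, a logit >= 0 corresponds to a predicted occupancy probability >= 0.5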
+
+ pred = torch.zeros_like(outputs)
+ pred[outputs >= threshold] = 1
+
+ accuracy = (pred == labels).float().sum(dim=1) / labels.shape[1]
+ accuracy = accuracy.mean()
+ intersection = (pred * labels).sum(dim=1)
+ union = (pred + labels).gt(0).sum(dim=1)
+        iou = intersection * 1.0 / (union + 1e-5)
+ iou = iou.mean()
+
+ batch_size = points.shape[0]
+ metric_logger.update(loss=loss.item())
+ metric_logger.meters['iou'].update(iou.item(), n=batch_size)
+
+ if loss_kl is not None:
+ metric_logger.update(loss_kl=loss_kl.item())
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print('* iou {iou.global_avg:.3f} loss {losses.global_avg:.3f}'
+ .format(iou=metric_logger.iou, losses=metric_logger.loss))
+
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
\ No newline at end of file
diff --git a/evaluation/dist_eval.sh b/evaluation/dist_eval.sh
new file mode 100644
index 0000000000000000000000000000000000000000..43387a13d173ae55dece292652ddfe5874aa65dd
--- /dev/null
+++ b/evaluation/dist_eval.sh
@@ -0,0 +1,16 @@
+CUDA_VISIBLE_DEVICES='0,1,2,3' torchrun --master_port 15002 --nproc_per_node=2 \
+evaluate_object_reconstruction.py \
+--configs ../configs/finetune_triplane_diffusion.yaml \
+--category arkit_chair arkit_stool \
+--ae-pth ../output/ae/chair/best-checkpoint.pth \
+--dm-pth ../output/finetune_dm/lowres_chair/best-checkpoint.pth \
+--output_folder ../output_result/chair_result \
+--data-pth ../data \
+--eval_cd \
+--reso 256 \
+--save_mesh \
+--save_par_points \
+--save_image \
+--save_surface
+
+#check ./datasets/taxonomy to see how sub categories are defined
diff --git a/evaluation/evaluate_object_reconstruction.py b/evaluation/evaluate_object_reconstruction.py
new file mode 100644
index 0000000000000000000000000000000000000000..65d414963a218dfd92abc36a6b9bbeb982a4fcce
--- /dev/null
+++ b/evaluation/evaluate_object_reconstruction.py
@@ -0,0 +1,239 @@
+import argparse
+import sys
+sys.path.append("..")
+sys.path.append(".")
+import numpy as np
+
+import mcubes
+import os
+import torch
+
+import trimesh
+
+from datasets.SingleView_dataset import Object_PartialPoints_MultiImg
+from datasets.transforms import Scale_Shift_Rotate
+from models import get_model
+from pathlib import Path
+import open3d as o3d
+from configs.config_utils import CONFIG
+import cv2
+from util.misc import MetricLogger
+import scipy
+from pyTorchChamferDistance.chamfer_distance import ChamferDistance
+from util.projection_utils import draw_proj_image
+from util import misc
+import time
+dist_chamfer=ChamferDistance()
+
+
+def pc_metrics(p1, p2, space_ext=2, fscore_param=0.01, scale=.5):
+ """ p2: reference ponits
+ (B, N, 3)
+ """
+ p1, p2, space_ext = p1 * scale, p2 * scale, space_ext * scale
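+    # F-score distance threshold: fscore_param (1% by default) of the scaled spatial extent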
+ f_thresh = space_ext * fscore_param
+
+ #print(p1.shape,p2.shape)
+ d1, d2, _, _ = dist_chamfer(p1, p2)
+ #print(d1.shape,d2.shape)
+ d1sqrt, d2sqrt = (d1 ** .5), (d2 ** .5)
+ chamfer_L1 = d1sqrt.mean(axis=-1) + d2sqrt.mean(axis=-1)
+ chamfer_L2 = d1.mean(axis=-1) + d2.mean(axis=-1)
+ precision = (d1sqrt < f_thresh).sum(axis=-1).float() / p1.shape[1]
+ recall = (d2sqrt < f_thresh).sum(axis=-1).float() / p2.shape[1]
+ #print(precision,recall)
+ fscore = 2 * torch.div(recall * precision, recall + precision)
+ fscore[fscore == float("inf")] = 0
+ return chamfer_L1,chamfer_L2,fscore
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser('this script computes IoU, F-score, and Chamfer distance before ICP alignment', add_help=False)
+ parser.add_argument('--configs',type=str,required=True)
+ parser.add_argument('--output_folder', type=str, default="../output_result/Triplane_diff_parcond_0926")
+ parser.add_argument('--dm-pth',type=str)
+ parser.add_argument('--ae-pth',type=str)
+ parser.add_argument('--data-pth', type=str,default="../")
+ parser.add_argument('--save_mesh',action="store_true",default=False)
+ parser.add_argument('--save_image',action="store_true",default=False)
+ parser.add_argument('--save_par_points', action="store_true", default=False)
+ parser.add_argument('--save_proj_img',action="store_true",default=False)
+ parser.add_argument('--save_surface',action="store_true",default=False)
+ parser.add_argument('--reso',default=128,type=int)
+ parser.add_argument('--category',nargs="+",type=str)
+ parser.add_argument('--eval_cd',action="store_true",default=False)
+ parser.add_argument('--use_augmentation',action="store_true",default=False)
+
+ parser.add_argument('--world_size', default=1, type=int,
+ help='number of distributed processes')
+ parser.add_argument('--local_rank', default=-1, type=int)
+ parser.add_argument('--dist_on_itp', action='store_true')
+ parser.add_argument('--dist_url', default='env://',
+ help='url used to set up distributed training')
+ parser.add_argument('--device', default='cuda',
+ help='device to use for training / testing')
+ args = parser.parse_args()
+ misc.init_distributed_mode(args)
+ config_path=args.configs
+ config=CONFIG(config_path)
+ dataset_config=config.config['dataset']
+ dataset_config['data_path']=args.data_pth
+ if "arkit" in args.category[0]:
+ split_filename=dataset_config['keyword']+'_val_par_img.json'
+ else:
+ split_filename='val_par_img.json'
+
+ transform = None
+ if args.use_augmentation:
+ transform=Scale_Shift_Rotate(jitter_partial=False,jitter=False,use_scale=False,angle=(-10,10),shift=(-0.1,0.1))
+ dataset_val = Object_PartialPoints_MultiImg(dataset_config['data_path'], split_filename=split_filename,categories=args.category,split="val",
+ transform=transform, sampling=False,
+ num_samples=1024, return_surface=True,ret_sample=True,
+ surface_sampling=True, par_pc_size=dataset_config['par_pc_size'],surface_size=100000,
+ load_proj_mat=True,load_image=True,load_org_img=True,load_triplane=None,par_point_aug=None,replica=1)
+ batch_size=1
+
+ num_tasks = misc.get_world_size()
+ global_rank = misc.get_rank()
+ val_sampler = torch.utils.data.DistributedSampler(
+ dataset_val, num_replicas=num_tasks, rank=global_rank,
+        shuffle=False)  # keep a deterministic order for evaluation
+ dataloader_val=torch.utils.data.DataLoader(
+ dataset_val,
+ sampler=val_sampler,
+ batch_size=batch_size,
+ num_workers=10,
+ shuffle=False,
+ )
+ output_folder=args.output_folder
+
+ device = torch.device('cuda')
+
+ ae_config=config.config['model']['ae']
+ dm_config=config.config['model']['dm']
+ ae_model=get_model(ae_config).to(device)
+ if args.category[0] == "all":
+ dm_config["use_cat_embedding"]=True
+ else:
+ dm_config["use_cat_embedding"] = False
+ dm_model=get_model(dm_config).to(device)
+ ae_model.eval()
+ dm_model.eval()
+ ae_model.load_state_dict(torch.load(args.ae_pth)['model'])
+ dm_model.load_state_dict(torch.load(args.dm_pth)['model'])
+
+ density = args.reso
+ gap = 2.2 / density
+ x = np.linspace(-1.1, 1.1, int(density + 1))
+ y = np.linspace(-1.1, 1.1, int(density + 1))
+ z = np.linspace(-1.1, 1.1, int(density + 1))
+ xv, yv, zv = np.meshgrid(x, y, z,indexing='ij')
+ grid = torch.from_numpy(np.stack([xv, yv, zv]).astype(np.float32)).view(3, -1).transpose(0, 1)[None].to(device,non_blocking=True)
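+    # dense (reso+1)^3 query grid over [-1.1, 1.1]^3, flattened to (1, N, 3) so it can be decoded in chunks below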
+
+ metric_logger=MetricLogger(delimiter=" ")
+ header = 'Test:'
+
+ with torch.no_grad():
+ for data_batch in metric_logger.log_every(dataloader_val,10, header):
+ # if data_iter_step==100:
+ # break
+ partial_name = data_batch['partial_name']
+ class_name = data_batch['class_name']
+ model_ids=data_batch['model_id']
+ surface=data_batch['surface']
+ proj_matrices=data_batch['proj_mat']
+ sample_points=data_batch["points"].cuda().float()
+ labels=data_batch["labels"].cuda().float()
+ sample_input=dm_model.prepare_sample_data(data_batch)
+ #t1 = time.time()
+ sampled_array = dm_model.sample(sample_input,num_steps=36).float()
+ #t2 = time.time()
+ #sample_time = t2 - t1
+ #print("sampling time %f" % (sample_time))
+ sampled_array = torch.nn.functional.interpolate(sampled_array, scale_factor=2, mode="bilinear")
+ for j in range(sampled_array.shape[0]):
+ if args.save_mesh | args.save_par_points | args.save_image:
+ object_folder = os.path.join(output_folder, class_name[j], model_ids[j])
+ Path(object_folder).mkdir(parents=True, exist_ok=True)
+ '''calculate iou'''
+ sample_point=sample_points[j:j+1]
+ sample_output=ae_model.decode(sampled_array[j:j + 1],sample_point)
+ sample_pred=torch.zeros_like(sample_output)
+ sample_pred[sample_output>=0.0]=1
+ label=labels[j:j+1]
+ intersection = (sample_pred * label).sum(dim=1)
+ union = (sample_pred + label).gt(0).sum(dim=1)
+                iou = intersection * 1.0 / (union + 1e-5)
+ iou = iou.mean()
+ metric_logger.update(iou=iou.item())
+
+ if args.use_augmentation:
+ tran_mat=data_batch["tran_mat"][j].numpy()
+ mat_save_path='{}/tran_mat.npy'.format(object_folder)
+ np.save(mat_save_path,tran_mat)
+
+ if args.eval_cd:
+ grid_list=torch.split(grid,128**3,dim=1)
+ output_list=[]
+ #t3=time.time()
+ for sub_grid in grid_list:
+ output_list.append(ae_model.decode(sampled_array[j:j + 1],sub_grid))
+ output=torch.cat(output_list,dim=1)
+ #t4=time.time()
+ #decoding_time=t4-t3
+ #print("decoding time:",decoding_time)
+ logits = output[j].detach()
+
+ volume = logits.view(density + 1, density + 1, density + 1).cpu().numpy()
+ verts, faces = mcubes.marching_cubes(volume, 0)
+
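+                    # map marching-cubes vertex indices back to world coordinates of the [-1.1, 1.1] sampling grid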
+ verts *= gap
+ verts -= 1.1
+ #print("vertice max min",np.amin(verts,axis=0),np.amax(verts,axis=0))
+
+
+ m = trimesh.Trimesh(verts, faces)
+ '''calculate fscore and chamfer distance'''
+ result_surface,_=trimesh.sample.sample_surface(m,100000)
+ gt_surface=surface[j]
+ assert gt_surface.shape[0]==result_surface.shape[0]
+
+ result_surface_gpu = torch.from_numpy(result_surface).float().cuda().unsqueeze(0)
+ gt_surface_gpu = gt_surface.float().cuda().unsqueeze(0)
+ _,chamfer_L2,fscore=pc_metrics(result_surface_gpu,gt_surface_gpu)
+ metric_logger.update(chamferl2=chamfer_L2*1000.0)
+ metric_logger.update(fscore=fscore)
+
+ if args.save_mesh:
+ m.export('{}/{}_mesh.ply'.format(object_folder, partial_name[j]))
+
+ if args.save_par_points:
+ par_point_input = data_batch['par_points'][j].numpy()
+ #print("input max min", np.amin(par_point_input, axis=0), np.amax(par_point_input, axis=0))
+ par_point_o3d = o3d.geometry.PointCloud()
+ par_point_o3d.points = o3d.utility.Vector3dVector(par_point_input[:, 0:3])
+ o3d.io.write_point_cloud('{}/{}.ply'.format(object_folder, partial_name[j]), par_point_o3d)
+ if args.save_image:
+ image_list=data_batch["org_image"]
+ for idx,image in enumerate(image_list):
+ image=image[0].numpy().astype(np.uint8)
+ if args.save_proj_img:
+ proj_mat=proj_matrices[j,idx].numpy()
+ proj_image=draw_proj_image(image,proj_mat,result_surface)
+ proj_save_path = '{}/proj_{}.jpg'.format(object_folder, idx)
+ cv2.imwrite(proj_save_path,proj_image)
+ save_path='{}/{}.jpg'.format(object_folder, idx)
+ cv2.imwrite(save_path,image)
+ if args.save_surface:
+ surface=gt_surface.numpy().astype(np.float32)
+ surface_o3d = o3d.geometry.PointCloud()
+ surface_o3d.points = o3d.utility.Vector3dVector(surface[:, 0:3])
+ o3d.io.write_point_cloud('{}/surface.ply'.format(object_folder), surface_o3d)
+ metric_logger.synchronize_between_processes()
+ print('* iou {ious.global_avg:.3f}'
+ .format(ious=metric_logger.iou))
+ if args.eval_cd:
+ print('* chamferl2 {chamferl2s.global_avg:.3f}'
+ .format(chamferl2s=metric_logger.chamferl2))
+ print('* fscore {fscores.global_avg:.3f}'
+ .format(fscores=metric_logger.fscore))
diff --git a/evaluation/pyTorchChamferDistance/.gitignore b/evaluation/pyTorchChamferDistance/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..506b29de9d26cace37122214d09d00b204a7c716
--- /dev/null
+++ b/evaluation/pyTorchChamferDistance/.gitignore
@@ -0,0 +1,3 @@
+.DS_Store
+._*
+
diff --git a/evaluation/pyTorchChamferDistance/LICENSE.md b/evaluation/pyTorchChamferDistance/LICENSE.md
new file mode 100644
index 0000000000000000000000000000000000000000..8aa26455d23acf904be3ed9dfb3a3efe3e49245a
--- /dev/null
+++ b/evaluation/pyTorchChamferDistance/LICENSE.md
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) [year] [fullname]
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/evaluation/pyTorchChamferDistance/README.md b/evaluation/pyTorchChamferDistance/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2902a61fb658c10b057c67b406d937ef9f539284
--- /dev/null
+++ b/evaluation/pyTorchChamferDistance/README.md
@@ -0,0 +1,23 @@
+# Chamfer Distance for pyTorch
+
+This is an implementation of the Chamfer Distance as a module for pyTorch. It is written as a custom C++/CUDA extension.
+
+As it is using pyTorch's [JIT compilation](https://pytorch.org/tutorials/advanced/cpp_extension.html), there are no additional prerequisite steps that have to be taken. Simply import the module as shown below; CUDA and C++ code will be compiled on the first run.
+
+### Usage
+```python
+from chamfer_distance import ChamferDistance
+chamfer_dist = ChamferDistance()
+
+#...
+# points and points_reconstructed are n_points x 3 matrices
+
+dist1, dist2 = chamfer_dist(points, points_reconstructed)
+loss = (torch.mean(dist1)) + (torch.mean(dist2))
+
+
+#...
+```
+
+### Integration
+This code has been integrated into the [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) library for 3D Deep Learning by NVIDIAGameWorks. You should probably take a look at it if you are working on anything 3D :)
diff --git a/evaluation/pyTorchChamferDistance/__init__.py b/evaluation/pyTorchChamferDistance/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/evaluation/pyTorchChamferDistance/chamfer_distance/__init__.py b/evaluation/pyTorchChamferDistance/chamfer_distance/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e15be7028d12ddc55b29752ac718c5284200203
--- /dev/null
+++ b/evaluation/pyTorchChamferDistance/chamfer_distance/__init__.py
@@ -0,0 +1 @@
+from .chamfer_distance import ChamferDistance
diff --git a/evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cpp b/evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..40f3d79aee526188f2df559a34e563026933be56
--- /dev/null
+++ b/evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cpp
@@ -0,0 +1,185 @@
+#include <torch/extension.h>
+
+// CUDA forward declarations
+int ChamferDistanceKernelLauncher(
+ const int b, const int n,
+ const float* xyz,
+ const int m,
+ const float* xyz2,
+ float* result,
+ int* result_i,
+ float* result2,
+ int* result2_i);
+
+int ChamferDistanceGradKernelLauncher(
+ const int b, const int n,
+ const float* xyz1,
+ const int m,
+ const float* xyz2,
+ const float* grad_dist1,
+ const int* idx1,
+ const float* grad_dist2,
+ const int* idx2,
+ float* grad_xyz1,
+ float* grad_xyz2);
+
+
+void chamfer_distance_forward_cuda(
+ const at::Tensor xyz1,
+ const at::Tensor xyz2,
+ const at::Tensor dist1,
+ const at::Tensor dist2,
+ const at::Tensor idx1,
+ const at::Tensor idx2)
+{
+    ChamferDistanceKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data<float>(),
+                                     xyz2.size(1), xyz2.data<float>(),
+                                     dist1.data<float>(), idx1.data<int>(),
+                                     dist2.data<float>(), idx2.data<int>());
+}
+
+void chamfer_distance_backward_cuda(
+ const at::Tensor xyz1,
+ const at::Tensor xyz2,
+ at::Tensor gradxyz1,
+ at::Tensor gradxyz2,
+ at::Tensor graddist1,
+ at::Tensor graddist2,
+ at::Tensor idx1,
+ at::Tensor idx2)
+{
+    ChamferDistanceGradKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data<float>(),
+                                     xyz2.size(1), xyz2.data<float>(),
+                                     graddist1.data<float>(), idx1.data<int>(),
+                                     graddist2.data<float>(), idx2.data<int>(),
+                                     gradxyz1.data<float>(), gradxyz2.data<float>());
+}
+
+
+void nnsearch(
+ const int b, const int n, const int m,
+ const float* xyz1,
+ const float* xyz2,
+ float* dist,
+ int* idx)
+{
+ for (int i = 0; i < b; i++) {
+ for (int j = 0; j < n; j++) {
+ const float x1 = xyz1[(i*n+j)*3+0];
+ const float y1 = xyz1[(i*n+j)*3+1];
+ const float z1 = xyz1[(i*n+j)*3+2];
+ double best = 0;
+ int besti = 0;
+ for (int k = 0; k < m; k++) {
+ const float x2 = xyz2[(i*m+k)*3+0] - x1;
+ const float y2 = xyz2[(i*m+k)*3+1] - y1;
+ const float z2 = xyz2[(i*m+k)*3+2] - z1;
+ const double d=x2*x2+y2*y2+z2*z2;
+ if (k==0 || d < best){
+ best = d;
+ besti = k;
+ }
+ }
+ dist[i*n+j] = best;
+ idx[i*n+j] = besti;
+ }
+ }
+}
+
+
+void chamfer_distance_forward(
+ const at::Tensor xyz1,
+ const at::Tensor xyz2,
+ const at::Tensor dist1,
+ const at::Tensor dist2,
+ const at::Tensor idx1,
+ const at::Tensor idx2)
+{
+ const int batchsize = xyz1.size(0);
+ const int n = xyz1.size(1);
+ const int m = xyz2.size(1);
+
+    const float* xyz1_data = xyz1.data<float>();
+    const float* xyz2_data = xyz2.data<float>();
+    float* dist1_data = dist1.data<float>();
+    float* dist2_data = dist2.data<float>();
+    int* idx1_data = idx1.data<int>();
+    int* idx2_data = idx2.data<int>();
+
+ nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data);
+ nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data);
+}
+
+
+void chamfer_distance_backward(
+ const at::Tensor xyz1,
+ const at::Tensor xyz2,
+ at::Tensor gradxyz1,
+ at::Tensor gradxyz2,
+ at::Tensor graddist1,
+ at::Tensor graddist2,
+ at::Tensor idx1,
+ at::Tensor idx2)
+{
+ const int b = xyz1.size(0);
+ const int n = xyz1.size(1);
+ const int m = xyz2.size(1);
+
+    const float* xyz1_data = xyz1.data<float>();
+    const float* xyz2_data = xyz2.data<float>();
+    float* gradxyz1_data = gradxyz1.data<float>();
+    float* gradxyz2_data = gradxyz2.data<float>();
+    float* graddist1_data = graddist1.data<float>();
+    float* graddist2_data = graddist2.data<float>();
+    const int* idx1_data = idx1.data<int>();
+    const int* idx2_data = idx2.data<int>();
+
+ for (int i = 0; i < b*n*3; i++)
+ gradxyz1_data[i] = 0;
+ for (int i = 0; i < b*m*3; i++)
+ gradxyz2_data[i] = 0;
+ for (int i = 0;i < b; i++) {
+ for (int j = 0; j < n; j++) {
+ const float x1 = xyz1_data[(i*n+j)*3+0];
+ const float y1 = xyz1_data[(i*n+j)*3+1];
+ const float z1 = xyz1_data[(i*n+j)*3+2];
+ const int j2 = idx1_data[i*n+j];
+
+ const float x2 = xyz2_data[(i*m+j2)*3+0];
+ const float y2 = xyz2_data[(i*m+j2)*3+1];
+ const float z2 = xyz2_data[(i*m+j2)*3+2];
+ const float g = graddist1_data[i*n+j]*2;
+
+ gradxyz1_data[(i*n+j)*3+0] += g*(x1-x2);
+ gradxyz1_data[(i*n+j)*3+1] += g*(y1-y2);
+ gradxyz1_data[(i*n+j)*3+2] += g*(z1-z2);
+ gradxyz2_data[(i*m+j2)*3+0] -= (g*(x1-x2));
+ gradxyz2_data[(i*m+j2)*3+1] -= (g*(y1-y2));
+ gradxyz2_data[(i*m+j2)*3+2] -= (g*(z1-z2));
+ }
+ for (int j = 0; j < m; j++) {
+ const float x1 = xyz2_data[(i*m+j)*3+0];
+ const float y1 = xyz2_data[(i*m+j)*3+1];
+ const float z1 = xyz2_data[(i*m+j)*3+2];
+ const int j2 = idx2_data[i*m+j];
+ const float x2 = xyz1_data[(i*n+j2)*3+0];
+ const float y2 = xyz1_data[(i*n+j2)*3+1];
+ const float z2 = xyz1_data[(i*n+j2)*3+2];
+ const float g = graddist2_data[i*m+j]*2;
+ gradxyz2_data[(i*m+j)*3+0] += g*(x1-x2);
+ gradxyz2_data[(i*m+j)*3+1] += g*(y1-y2);
+ gradxyz2_data[(i*m+j)*3+2] += g*(z1-z2);
+ gradxyz1_data[(i*n+j2)*3+0] -= (g*(x1-x2));
+ gradxyz1_data[(i*n+j2)*3+1] -= (g*(y1-y2));
+ gradxyz1_data[(i*n+j2)*3+2] -= (g*(z1-z2));
+ }
+ }
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("forward", &chamfer_distance_forward, "ChamferDistance forward");
+ m.def("forward_cuda", &chamfer_distance_forward_cuda, "ChamferDistance forward (CUDA)");
+ m.def("backward", &chamfer_distance_backward, "ChamferDistance backward");
+ m.def("backward_cuda", &chamfer_distance_backward_cuda, "ChamferDistance backward (CUDA)");
+}
diff --git a/evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cu b/evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cu
new file mode 100644
index 0000000000000000000000000000000000000000..f10f2ba854883d7f590236bb69e3598e8a4ef379
--- /dev/null
+++ b/evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cu
@@ -0,0 +1,209 @@
+#include <ATen/ATen.h>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+__global__
+void ChamferDistanceKernel(
+ int b,
+ int n,
+ const float* xyz,
+ int m,
+ const float* xyz2,
+ float* result,
+ int* result_i)
+{
+ const int batch=512;
+ __shared__ float buf[batch*3];
+ for (int i=blockIdx.x;ibest){
+ result[(i*n+j)]=best;
+ result_i[(i*n+j)]=best_i;
+ }
+ }
+ __syncthreads();
+ }
+ }
+}
+
+void ChamferDistanceKernelLauncher(
+ const int b, const int n,
+ const float* xyz,
+ const int m,
+ const float* xyz2,
+ float* result,
+ int* result_i,
+ float* result2,
+ int* result2_i)
+{
+    ChamferDistanceKernel<<<dim3(32, 16, 1), 512>>>(b, n, xyz, m, xyz2, result, result_i);
+    ChamferDistanceKernel<<<dim3(32, 16, 1), 512>>>(b, m, xyz2, n, xyz, result2, result2_i);
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err));
+}
+
+
+__global__
+void ChamferDistanceGradKernel(
+ int b, int n,
+ const float* xyz1,
+ int m,
+ const float* xyz2,
+ const float* grad_dist1,
+ const int* idx1,
+ float* grad_xyz1,
+ float* grad_xyz2)
+{
+ for (int i = blockIdx.x; i>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2);
+ ChamferDistanceGradKernel<<>>(b, m, xyz2, n, xyz1, grad_dist2, idx2, grad_xyz2, grad_xyz1);
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err));
+}
diff --git a/evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.py b/evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3bea670105c2900909b85210cf10331063cb5e3
--- /dev/null
+++ b/evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.py
@@ -0,0 +1,58 @@
+
+import torch
+
+from torch.utils.cpp_extension import load
+cd = load(name="build",
+ sources=["pyTorchChamferDistance/chamfer_distance/chamfer_distance.cpp",
+ "pyTorchChamferDistance/chamfer_distance/chamfer_distance.cu"],
+ build_directory="pyTorchChamferDistance/build")
+
+class ChamferDistanceFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, xyz1, xyz2):
+ batchsize, n, _ = xyz1.size()
+ _, m, _ = xyz2.size()
+ xyz1 = xyz1.contiguous()
+ xyz2 = xyz2.contiguous()
+ dist1 = torch.zeros(batchsize, n)
+ dist2 = torch.zeros(batchsize, m)
+
+ idx1 = torch.zeros(batchsize, n, dtype=torch.int)
+ idx2 = torch.zeros(batchsize, m, dtype=torch.int)
+
+ if not xyz1.is_cuda:
+ cd.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
+ else:
+ dist1 = dist1.cuda()
+ dist2 = dist2.cuda()
+ idx1 = idx1.cuda()
+ idx2 = idx2.cuda()
+ cd.forward_cuda(xyz1, xyz2, dist1, dist2, idx1, idx2)
+
+ ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
+
+ return dist1, dist2, idx1, idx2
+
+ @staticmethod
+ def backward(ctx, graddist1, graddist2, *args):
+ xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
+
+ graddist1 = graddist1.contiguous()
+ graddist2 = graddist2.contiguous()
+
+ gradxyz1 = torch.zeros(xyz1.size())
+ gradxyz2 = torch.zeros(xyz2.size())
+
+ if not graddist1.is_cuda:
+ cd.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
+ else:
+ gradxyz1 = gradxyz1.cuda()
+ gradxyz2 = gradxyz2.cuda()
+ cd.backward_cuda(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
+
+ return gradxyz1, gradxyz2
+
+
+class ChamferDistance(torch.nn.Module):
+ def forward(self, xyz1, xyz2):
+ return ChamferDistanceFunction.apply(xyz1, xyz2)
diff --git a/finetune_diffusion.sh b/finetune_diffusion.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c9f2a529af4b0a936578e816a41013aaa385ede9
--- /dev/null
+++ b/finetune_diffusion.sh
@@ -0,0 +1,18 @@
+cd scripts
+CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' torchrun --master_port 15003 --nproc_per_node=8 \
+train_triplane_diffusion.py \
+--configs ../configs/finetune_triplane_diffusion.yaml \
+--accum_iter 2 \
+--output_dir ../output/finetune_dm/lowres_chair \
+--log_dir ../output/finetune_dm/lowres_chair --num_workers 8 \
+--batch_size 22 \
+--blr 1e-4 \
+--epochs 500 \
+--dist_eval \
+--warmup_epochs 20 \
+--ae-pth ../output/ae/chair/best-checkpoint.pth \
+--category chair \
+--finetune \
+--finetune-pth ../output/dm/chair/best-checkpoint.pth \
+--data-pth ../data \
+--replica 5
\ No newline at end of file
diff --git a/models/TriplaneVAE.py b/models/TriplaneVAE.py
new file mode 100644
index 0000000000000000000000000000000000000000..477da1488d66914894b461ef46b5191da8f3c839
--- /dev/null
+++ b/models/TriplaneVAE.py
@@ -0,0 +1,94 @@
+import torch.nn as nn
+import sys,os
+sys.path.append("..")
+import torch
+from datasets import build_dataset
+from configs.config_utils import CONFIG
+from torch.utils.data import DataLoader
+from models.modules import PointEmbed
+from models.modules import ConvPointnet_Encoder,ConvPointnet_Decoder
+import numpy as np
+
+class TriplaneVAE(nn.Module):
+ def __init__(self,opt):
+ super().__init__()
+ self.point_embedder=PointEmbed(hidden_dim=opt['point_emb_dim'])
+
+ encoder_args=opt['encoder']
+ decoder_args=opt['decoder']
+ self.encoder=ConvPointnet_Encoder(c_dim=encoder_args['plane_latent_dim'],dim=opt['point_emb_dim'],latent_dim=encoder_args['latent_dim'],
+ plane_resolution=encoder_args['plane_reso'],unet_kwargs=encoder_args['unet'],unet=True,padding=opt['padding'])
+ self.decoder=ConvPointnet_Decoder(latent_dim=decoder_args['latent_dim'],query_emb_dim=decoder_args['query_emb_dim'],
+ hidden_dim=decoder_args['hidden_dim'],unet_kwargs=decoder_args['unet'],n_blocks=decoder_args['n_blocks'],
+ plane_resolution=decoder_args['plane_reso'],padding=opt['padding'])
+
+ def forward(self,p,query):
+ '''
+ :param p: surface points cloud of shape B,N,3
+ :param query: sample points of shape B,N,3
+ :return:
+ '''
+ point_emb=self.point_embedder(p)
+ query_emb=self.point_embedder(query)
+ kl,plane_feat,means,logvars=self.encoder(p,point_emb)
+ if self.training:
+ if np.random.random()<0.5:
+                '''randomly scale the triplane so that triplane diffusion can also be conducted on a 64x64 plane, promoting robustness'''
+ plane_feat=torch.nn.functional.interpolate(plane_feat,scale_factor=0.5,mode="bilinear")
+ plane_feat=torch.nn.functional.interpolate(plane_feat,scale_factor=2,mode="bilinear")
+ # if self.training:
+ # if np.random.random()<0.5:
+ # means = torch.nn.functional.interpolate(means, scale_factor=0.5, mode="bilinear")
+ # vars=torch.exp(logvars)
+ # vars = torch.nn.functional.interpolate(vars, scale_factor=0.5, mode="bilinear")
+ # new_logvars=torch.log(vars)
+ # posterior = DiagonalGaussianDistribution(means, new_logvars)
+ # plane_feat=posterior.sample()
+ # plane_feat=torch.nn.functional.interpolate(plane_feat,scale_factor=2,mode='bilinear')
+
+ # mean_scale=torch.nn.functional.interpolate(means, scale_factor=0.5, mode="bilinear")
+ # vars = torch.exp(logvars)
+ # vars_scale = torch.nn.functional.interpolate(vars, scale_factor=0.5, mode="bilinear")/4
+ # logvars_scale=torch.log(vars_scale)
+ # scale_noise=torch.randn(mean_scale.shape).to(mean_scale.device)
+ # plane_feat_scale2=mean_scale+torch.exp(0.5*logvars_scale)*scale_noise
+ # plane_feat=torch.nn.functional.interpolate(plane_feat_scale2,scale_factor=2,mode='bilinear')
+ o=self.decoder(plane_feat,query,query_emb)
+
+ return {'logits':o,'kl':kl}
+
+
+ def decode(self,plane_feature,query):
+ query_embedding=self.point_embedder(query)
+ o=self.decoder(plane_feature,query,query_embedding)
+
+ return o
+
+ def encode(self,p):
+ point_emb = self.point_embedder(p)
+ kl, plane_feat,mean,logvar = self.encoder(p, point_emb)
+ '''p is point cloud of B,N,3'''
+ return plane_feat,kl,mean,logvar
+
+if __name__=="__main__":
+ configs=CONFIG("../configs/train_triplane_vae_64.yaml")
+ config=configs.config
+ dataset_config=config['datasets']
+ model_config=config["model"]
+ dataset=build_dataset("train",dataset_config)
+ dataset.__getitem__(0)
+ dataloader=DataLoader(
+ dataset=dataset,
+ batch_size=10,
+ shuffle=True,
+ num_workers=2,
+ )
+ net=TriplaneVAE(model_config).float().cuda()
+ for idx,data_batch in enumerate(dataloader):
+ if idx==1:
+ break
+ surface=data_batch['surface'].float().cuda()
+ query=data_batch['points'].float().cuda()
+ net(surface,query)
+
+
diff --git a/models/Triplane_Diffusion.py b/models/Triplane_Diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..84ae3a09321368c2fe1ffe4ebbc5065d53f4b73c
--- /dev/null
+++ b/models/Triplane_Diffusion.py
@@ -0,0 +1,190 @@
+import torch
+import torch.nn as nn
+from models.modules.resunet import ResUnet_DirectAttenMultiImg_Cond
+from models.modules.parpoints_encoder import ParPoint_Encoder
+from models.modules.PointEMB import PointEmbed
+from models.modules.utils import StackedRandomGenerator
+from models.modules.diffusion_sampler import edm_sampler
+from models.modules.encoder import DiagonalGaussianDistribution
+import numpy as np
+class EDMLoss_MultiImgCond:
+ def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5,use_par=False):
+ self.P_mean = P_mean
+ self.P_std = P_std
+ self.sigma_data = sigma_data
+ self.use_par=use_par
+
+ def __call__(self, net, data_batch, classifier_free=False):
+ inputs = data_batch['input']
+ image=data_batch['image']
+ proj_mat=data_batch['proj_mat']
+ valid_frames=data_batch['valid_frames']
+ par_points=data_batch["par_points"]
+ category_code=data_batch["category_code"]
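+        # EDM loss: draw the noise level sigma from a log-normal distribution and weight the denoising MSE by (sigma^2 + sigma_data^2) / (sigma_data * sigma)^2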
+ rnd_normal = torch.randn([inputs.shape[0], 1, 1, 1], device=inputs.device)
+
+ sigma = (rnd_normal * self.P_std + self.P_mean).exp() #B,1,1,1
+ weight = (sigma ** 2 + self.sigma_data ** 2) / (self.sigma_data * sigma) ** 2
+ y=inputs
+
+ n = torch.randn_like(y) * sigma
+
+ # if classifier_free and np.random.random()<0.5:
+ # net.par_feat=torch.zeros((inputs.shape[0],32,inputs.shape[2],inputs.shape[3])).float().to(inputs.device)
+ if classifier_free and np.random.random()<0.5:
+ image=torch.zeros_like(image).float().cuda()
+ net.module.extract_img_feat(image)
+ net.module.set_proj_matrix(proj_mat)
+ net.module.set_valid_frames(valid_frames)
+ net.module.set_category_code(category_code)
+ if self.use_par:
+ net.module.extract_point_feat(par_points)
+
+ D_yn = net(y + n,sigma)
+ loss = weight * ((D_yn - y) ** 2)
+ return loss
+
+class Triplane_Diff_MultiImgCond_EDM(nn.Module):
+ def __init__(self,opt):
+ super().__init__()
+ self.diff_reso=opt['diff_reso']
+ self.diff_dim=opt['output_channel']
+ self.use_cat_embedding=opt['use_cat_embedding']
+ self.use_fp16=False
+ self.sigma_data=0.5
+ self.sigma_max=float("inf")
+ self.sigma_min=0
+ self.use_par=opt['use_par']
+ self.triplane_padding=opt['triplane_padding']
+ self.block_type=opt['block_type']
+ #self.use_bn=opt['use_bn']
+ if opt['backbone']=="resunet_multiimg_direct_atten":
+ self.denoise_model=ResUnet_DirectAttenMultiImg_Cond(channel=opt['input_channel'],
+ output_channel=opt['output_channel'],use_par=opt['use_par'],par_channel=opt['par_channel'],
+ img_in_channels=opt['img_in_channels'],vit_reso=opt['vit_reso'],triplane_padding=self.triplane_padding,
+ norm=opt['norm'],use_cat_embedding=self.use_cat_embedding,block_type=self.block_type)
+ else:
+ raise NotImplementedError
+ if opt['use_par']: #use partial point cloud as inputs
+ par_emb_dim = opt['par_emb_dim']
+ par_args = opt['par_point_encoder']
+ self.point_embedder = PointEmbed(hidden_dim=par_emb_dim)
+ self.par_points_encoder = ParPoint_Encoder(c_dim=par_args['plane_latent_dim'], dim=par_emb_dim,
+ plane_resolution=par_args['plane_reso'],
+ unet_kwargs=par_args['unet'])
+ self.unflatten = torch.nn.Unflatten(1, (16, 16))
+ def prepare_data(self,data_batch):
+ #par_points = data_batch['par_points'].to(device, non_blocking=True)
+ device=torch.device("cuda")
+ means, logvars = data_batch['triplane_mean'].to(device, non_blocking=True), data_batch['triplane_logvar'].to(
+ device, non_blocking=True)
+ distribution = DiagonalGaussianDistribution(means, logvars)
+ plane_feat = distribution.sample()
+
+ image=data_batch["image"].to(device)
+ proj_mat = data_batch['proj_mat'].to(device, non_blocking=True)
+ valid_frames=data_batch["valid_frames"].to(device,non_blocking=True)
+ par_points=data_batch["par_points"].to(device,non_blocking=True)
+ category_code=data_batch["category_code"].to(device,non_blocking=True)
+ input_dict = {"input": plane_feat.float(),
+ "image": image.float(),
+ "par_points":par_points.float(),
+ "proj_mat":proj_mat.float(),
+ "category_code":category_code.float(),
+ "valid_frames":valid_frames.float()} # TODO: add image and proj matrix
+
+ return input_dict
+
+ def prepare_sample_data(self,data_batch):
+ device=torch.device("cuda")
+ image=data_batch['image'].to(device, non_blocking=True)
+ proj_mat = data_batch['proj_mat'].to(device, non_blocking=True)
+ valid_frames = data_batch["valid_frames"].to(device, non_blocking=True)
+ par_points = data_batch["par_points"].to(device, non_blocking=True)
+ category_code=data_batch["category_code"].to(device,non_blocking=True)
+ sample_dict={
+ "image":image.float(),
+ "proj_mat":proj_mat.float(),
+ "valid_frames":valid_frames.float(),
+ "category_code":category_code.float(),
+ "par_points":par_points.float(),
+ }
+ return sample_dict
+
+ def prepare_eval_data(self,data_batch):
+ device=torch.device("cuda")
+ samples=data_batch["points"].to(device, non_blocking=True)
+ labels=data_batch['labels'].to(device,non_blocking=True)
+
+ eval_dict={
+ "samples":samples,
+ "labels":labels,
+ }
+ return eval_dict
+
+ def extract_point_feat(self,par_points):
+ par_emb=self.point_embedder(par_points)
+ self.par_feat=self.par_points_encoder(par_points,par_emb)
+
+ def extract_img_feat(self,image):
+ self.image_emb=image
+
+ def set_proj_matrix(self,proj_matrix):
+ self.proj_matrix=proj_matrix
+
+ def set_valid_frames(self,valid_frames):
+ self.valid_frames=valid_frames
+
+ def set_category_code(self,category_code):
+ self.category_code=category_code
+
+ def forward(self, x, sigma,force_fp32=False):
+ x = x.to(torch.float32)
+ sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) #B,1,1,1
+ dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
+
+ c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
+ c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
+ c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
+ c_noise = sigma.log() / 4 #B,1,1,1, need to check how to add embedding into unet
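+        # Note: the coefficients above follow the EDM preconditioning (Karras et al., 2022):
+        #   D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise), with
+        #   c_skip = sigma_data^2 / (sigma^2 + sigma_data^2),
+        #   c_out  = sigma * sigma_data / sqrt(sigma^2 + sigma_data^2),
+        #   c_in   = 1 / sqrt(sigma_data^2 + sigma^2),
+        #   c_noise = ln(sigma) / 4.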
+
+ if self.use_par:
+ F_x = self.denoise_model((c_in * x).to(dtype), c_noise.flatten(), self.image_emb, self.proj_matrix,
+ self.valid_frames,self.category_code,self.par_feat)
+ else:
+ F_x = self.denoise_model((c_in * x).to(dtype), c_noise.flatten(),self.image_emb,self.proj_matrix,
+ self.valid_frames,self.category_code)
+ assert F_x.dtype == dtype
+ D_x = c_skip * x + c_out * F_x.to(torch.float32)
+ return D_x
+
+ def round_sigma(self, sigma):
+ return torch.as_tensor(sigma)
+
+ @torch.no_grad()
+ def sample(self, input_batch, batch_seeds=None,ret_all=False,num_steps=18):
+ img_cond=input_batch['image']
+ proj_mat=input_batch['proj_mat']
+ valid_frames=input_batch["valid_frames"]
+ category_code=input_batch["category_code"]
+ if img_cond is not None:
+ batch_size, device = img_cond.shape[0], img_cond.device
+ if batch_seeds is None:
+ batch_seeds = torch.arange(batch_size)
+ else:
+ device = batch_seeds.device
+ batch_size = batch_seeds.shape[0]
+
+ self.extract_img_feat(img_cond)
+ self.set_proj_matrix(proj_mat)
+ self.set_valid_frames(valid_frames)
+ self.set_category_code(category_code)
+ if self.use_par:
+ par_points=input_batch["par_points"]
+ self.extract_point_feat(par_points)
+ rnd = StackedRandomGenerator(device, batch_seeds)
+ latents = rnd.randn([batch_size, self.diff_dim, self.diff_reso*3,self.diff_reso], device=device)
+
+ return edm_sampler(self, latents, randn_like=rnd.randn_like,ret_all=ret_all,sigma_min=0.002, sigma_max=80,num_steps=num_steps)
+
+
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1af4a7b4034c52ade0ec2810620ada05e2c93e1
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1,20 @@
+from .TriplaneVAE import TriplaneVAE
+from .Triplane_Diffusion import Triplane_Diff_MultiImgCond_EDM
+from .Triplane_Diffusion import EDMLoss_MultiImgCond
+#from .Point_Diffusion_EDM import PointEDM,EDMLoss_PointAug
+
+def get_model(model_args):
+ if model_args['type']=="TriVAE":
+ model=TriplaneVAE(model_args)
+ elif model_args['type']=="triplane_diff_multiimg_cond":
+ model=Triplane_Diff_MultiImgCond_EDM(model_args)
+ else:
+ raise NotImplementedError
+ return model
+
+def get_criterion(cri_args):
+ if cri_args['type']=="EDMLoss_MultiImgCond":
+ criterion=EDMLoss_MultiImgCond(use_par=cri_args['use_par'])
+ else:
+ raise NotImplementedError
+ return criterion
diff --git a/models/modules/PointEMB.py b/models/modules/PointEMB.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b5973155b8c1c30f5500ea28d09ca9fc7a6a30a
--- /dev/null
+++ b/models/modules/PointEMB.py
@@ -0,0 +1,34 @@
+import torch.nn as nn
+import torch
+import numpy as np
+
+class PointEmbed(nn.Module):
+ def __init__(self, hidden_dim=48):
+ super().__init__()
+
+ assert hidden_dim % 6 == 0
+
+ self.embedding_dim = hidden_dim
+ e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi
+ e = torch.stack([
+ torch.cat([e, torch.zeros(self.embedding_dim // 6),
+ torch.zeros(self.embedding_dim // 6)]),
+ torch.cat([torch.zeros(self.embedding_dim // 6), e,
+ torch.zeros(self.embedding_dim // 6)]),
+ torch.cat([torch.zeros(self.embedding_dim // 6),
+ torch.zeros(self.embedding_dim // 6), e]),
+ ])
+ self.register_buffer('basis', e) # 3 x 24
+
+
+ @staticmethod
+ def embed(input, basis):
+ projections = torch.einsum(
+ 'bnd,de->bne', input, basis) # N,24
+ embeddings = torch.cat([projections.sin(), projections.cos()], dim=2)
+ return embeddings
+
+ def forward(self, input):
+ # input: B x N x 3
+ embed = self.embed(input, self.basis)
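+        # Shape note: for the default hidden_dim=48 the basis is 3 x 24, so the
+        # projections are B x N x 24 and the concatenated sin/cos embedding
+        # returned below is B x N x 48 (= B x N x hidden_dim).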
+ return embed
diff --git a/models/modules/Positional_Embedding.py b/models/modules/Positional_Embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f4722a52a4825543103868a70cbd7f5f3e5155e
--- /dev/null
+++ b/models/modules/Positional_Embedding.py
@@ -0,0 +1,15 @@
+import torch
+class PositionalEmbedding(torch.nn.Module):
+ def __init__(self, num_channels, max_positions=10000, endpoint=False):
+ super().__init__()
+ self.num_channels = num_channels
+ self.max_positions = max_positions
+ self.endpoint = endpoint
+
+ def forward(self, x):
+ freqs = torch.arange(start=0, end=self.num_channels//2, dtype=torch.float32, device=x.device)
+ freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0))
+ freqs = (1 / self.max_positions) ** freqs
+ x = x.ger(freqs.to(x.dtype))
+ x = torch.cat([x.cos(), x.sin()], dim=1)
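+        # The frequencies form a geometric progression from 1 down to roughly
+        # 1/max_positions; concatenating cos and sin gives num_channels features
+        # per scalar input, i.e. the usual sinusoidal embedding used for noise levels.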
+ return x
diff --git a/models/modules/__init__.py b/models/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..464c972374c82b95726821f69d22bff6a7b484fe
--- /dev/null
+++ b/models/modules/__init__.py
@@ -0,0 +1,5 @@
+from .encoder import ConvPointnet_Encoder
+from .resnet_block import ResnetBlockFC
+from .unet import UNet,RollOut_Conv
+from .PointEMB import PointEmbed
+from .decoder import ConvPointnet_Decoder
\ No newline at end of file
diff --git a/models/modules/decoder.py b/models/modules/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c89db21a92a51dd22fcf49ee8b5fc03830afe76e
--- /dev/null
+++ b/models/modules/decoder.py
@@ -0,0 +1,121 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_scatter import scatter_mean, scatter_max
+from .unet import UNet
+from .resnet_block import ResnetBlockFC
+import numpy as np
+
+class ConvPointnet_Decoder(nn.Module):
+    ''' Triplane decoder that predicts a scalar value (e.g. occupancy) for each query point.
+    Plane features are refined by a U-Net, bilinearly sampled at the query locations,
+    and fused with the query embedding through a stack of ResNet blocks.
+
+    Args:
+        latent_dim (int): channel dimension of the input triplane latent
+        query_emb_dim (int): dimension of the query point embedding
+        hidden_dim (int): hidden dimension of the network
+        unet_kwargs (dict): U-Net parameters
+        plane_resolution (int): defined resolution for plane feature
+        plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
+        padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
+        n_blocks (int): number of ResnetBlockFC layers
+    '''
+
+ def __init__(self, latent_dim=32,query_emb_dim=51,hidden_dim=128, unet_kwargs=None,
+ plane_resolution=None, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5):
+ super().__init__()
+
+        self.latent_dim=latent_dim
+ self.actvn = nn.ReLU()
+
+ self.unet = UNet(unet_kwargs['output_dim'], in_channels=latent_dim, **unet_kwargs)
+
+ self.reso_plane = plane_resolution
+ self.plane_type = plane_type
+ self.padding = padding
+ self.n_blocks=n_blocks
+
+ self.fc_c = nn.ModuleList([
+ nn.Linear(latent_dim*3, hidden_dim) for i in range(n_blocks)
+ ])
+ self.fc_p=nn.Linear(query_emb_dim,hidden_dim)
+ self.fc_out=nn.Linear(hidden_dim,1)
+
+ self.blocks = nn.ModuleList([
+ ResnetBlockFC(hidden_dim) for i in range(n_blocks)
+ ])
+
+ def forward(self, plane_features,query,query_emb): # , query2):
+ plane_feature=self.unet(plane_features)
+ H,W=plane_feature.shape[2:4]
+ xz_feat,xy_feat,yz_feat=torch.split(plane_feature,dim=2,split_size_or_sections=H//3)
+ xz_sample_feat=self.sample_plane_feature(query,xz_feat,'xz')
+ xy_sample_feat=self.sample_plane_feature(query,xy_feat,'xy')
+ yz_sample_feat=self.sample_plane_feature(query,yz_feat,'yz')
+
+ sample_feat=torch.cat([xz_sample_feat,xy_sample_feat,yz_sample_feat],dim=1)
+ sample_feat=sample_feat.transpose(1,2)
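+        # sample_feat is B x N_query x (3 * plane channels): one bilinearly sampled
+        # feature per plane (xz, xy, yz), concatenated; each ResNet block below adds
+        # a linear projection of it (fc_c[i]) to the running query feature.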
+
+ net=self.fc_p(query_emb)
+ for i in range(self.n_blocks):
+ net=net+self.fc_c[i](sample_feat)
+ net=self.blocks[i](net)
+ out=self.fc_out(self.actvn(net)).squeeze(-1)
+ return out
+
+
+ def normalize_coordinate(self, p, padding=0.1, plane='xz'):
+ ''' Normalize coordinate to [0, 1] for unit cube experiments
+
+ Args:
+ p (tensor): point
+            padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
+ plane (str): plane feature type, ['xz', 'xy', 'yz']
+ '''
+ if plane == 'xz':
+ xy = p[:, :, [0, 2]]
+ elif plane == 'xy':
+ xy = p[:, :, [0, 1]]
+ else:
+ xy = p[:, :, [1, 2]]
+ #print("origin",torch.amin(xy), torch.amax(xy))
+ xy=xy/2 #xy is originally -1 ~ 1
+ xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
+ xy_new = xy_new + 0.5 # range (0, 1)
+ #print("scale",torch.amin(xy_new),torch.amax(xy_new))
+
+        # if there are outliers out of the range
+ if xy_new.max() >= 1:
+ xy_new[xy_new >= 1] = 1 - 10e-6
+ if xy_new.min() < 0:
+ xy_new[xy_new < 0] = 0.0
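+        # Worked example with the default padding=0.1: an input coordinate of 1.0
+        # maps to 1.0/2 / (1 + 0.1 + 1e-5) + 0.5 ~= 0.954, so the padded cube fits
+        # strictly inside (0, 1).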
+ return xy_new
+
+ def coordinate2index(self, x, reso):
+        ''' Convert normalized plane coordinates in [0, 1] to flattened indices
+        into the plane feature grid of the given resolution
+
+ Args:
+ x (tensor): coordinate
+ reso (int): defined resolution
+ coord_type (str): coordinate type
+ '''
+ x = (x * reso).long()
+ index = x[:, :, 0] + reso * x[:, :, 1]
+ index = index[:, None, :]
+ return index
+
+ # uses values from plane_feature and pixel locations from vgrid to interpolate feature
+ def sample_plane_feature(self, query, plane_feature, plane):
+ xy = self.normalize_coordinate(query.clone(), plane=plane, padding=self.padding)
+ xy = xy[:, :, None].float()
+ vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
+ sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True,
+ mode='bilinear').squeeze(-1)
+ return sampled_feat
+
+
+
diff --git a/models/modules/diffusion_sampler.py b/models/modules/diffusion_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..a09f4705bc2a73b51635919088b7261720c33265
--- /dev/null
+++ b/models/modules/diffusion_sampler.py
@@ -0,0 +1,89 @@
+import torch
+import numpy as np
+
+def edm_sampler(
+ net, latents, randn_like=torch.randn_like,
+ num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
+ # S_churn=40, S_min=0.05, S_max=50, S_noise=1.003,
+ S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, ret_all=False
+):
+ # Adjust noise levels based on what's supported by the network.
+ sigma_min = max(sigma_min, net.sigma_min)
+ sigma_max = min(sigma_max, net.sigma_max)
+
+ # Time step discretization.
+ step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
+ t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
+ t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
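+    # Noise schedule from Karras et al. (2022):
+    #   sigma_i = (sigma_max^(1/rho) + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
+    # for i = 0..N-1, with sigma_N = 0 appended so the final step lands on the data.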
+
+ # Main sampling loop.
+ x_next = latents.to(torch.float64) * t_steps[0]
+ all_x=[]
+ for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
+ x_cur = x_next
+
+ # Increase noise temporarily.
+ gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
+ t_hat = net.round_sigma(t_cur + gamma * t_cur)
+ x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
+
+ # Euler step.
+ denoised = net(x_hat, t_hat).to(torch.float64)
+ d_cur = (x_hat - denoised) / t_hat
+ x_next = x_hat + (t_next - t_hat) * d_cur
+
+ # Apply 2nd order correction.
+ if i < num_steps - 1:
+ denoised = net(x_next, t_next).to(torch.float64)
+ d_prime = (x_next - denoised) / t_next
+ x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
+ all_x.append(x_next.clone()/(t_next**2+1).sqrt())
+
+ if ret_all:
+ return x_next,all_x
+
+ return x_next
+
+def edm_sampler_cond(
+ net, latents,cond_points, randn_like=torch.randn_like,
+ num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
+ # S_churn=40, S_min=0.05, S_max=50, S_noise=1.003,
+ S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, ret_all=False
+):
+ # Adjust noise levels based on what's supported by the network.
+ sigma_min = max(sigma_min, net.sigma_min)
+ sigma_max = min(sigma_max, net.sigma_max)
+
+ # Time step discretization.
+ step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
+ t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
+ t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
+
+ # Main sampling loop.
+ x_next = latents.to(torch.float64) * t_steps[0]
+ all_x=[]
+ for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
+ x_cur = x_next
+
+ # Increase noise temporarily.
+ gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
+ t_hat = net.round_sigma(t_cur + gamma * t_cur)
+ x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
+
+ # Euler step.
+ denoised = net(x_hat, t_hat,cond_points).to(torch.float64)
+ d_cur = (x_hat - denoised) / t_hat
+ x_next = x_hat + (t_next - t_hat) * d_cur
+
+ # Apply 2nd order correction.
+ if i < num_steps - 1:
+ denoised = net(x_next, t_next,cond_points).to(torch.float64)
+ d_prime = (x_next - denoised) / t_next
+ x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
+ all_x.append(x_next.clone()/(t_next**2+1).sqrt())
+
+ if ret_all:
+ return x_next,all_x
+
+ return x_next
+
diff --git a/models/modules/encoder.py b/models/modules/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e73441d56356964c8258bbcfbf72329ff19d47a
--- /dev/null
+++ b/models/modules/encoder.py
@@ -0,0 +1,235 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_scatter import scatter_mean, scatter_max
+from .unet import UNet
+from .resnet_block import ResnetBlockFC
+import numpy as np
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, mean, logvar, deterministic=False):
+ self.mean = mean
+ self.logvar = logvar
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.mean.device)
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.mean.device)
+ return x
+
+ def kl(self, other=None):
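+        # Closed-form KL of a diagonal Gaussian: against N(0, I) this is
+        # 0.5 * (mu^2 + var - 1 - log var), here averaged (not summed) over
+        # dims [1, 2, 3].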
+ if self.deterministic:
+ return torch.Tensor([0.])
+ else:
+ if other is None:
+ return 0.5 * torch.mean(torch.pow(self.mean, 2)
+ + self.var - 1.0 - self.logvar,
+ dim=[1, 2,3])
+ else:
+ return 0.5 * torch.mean(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=[1, 2, 3])
+
+ def nll(self, sample, dims=[1,2,3]):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims)
+
+ def mode(self):
+ return self.mean
+
+class ConvPointnet_Encoder(nn.Module):
+ ''' PointNet-based encoder network with ResNet blocks for each point.
+    The number of input points is fixed.
+
+ Args:
+ c_dim (int): dimension of latent code c
+ dim (int): input points dimension
+ hidden_dim (int): hidden dimension of the network
+ scatter_type (str): feature aggregation when doing local pooling
+        unet (bool): whether to use U-Net
+ unet_kwargs (str): U-Net parameters
+ plane_resolution (int): defined resolution for plane feature
+ plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
+        padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
+        n_blocks (int): number of ResnetBlockFC layers
+ '''
+
+ def __init__(self, c_dim=128, dim=3, hidden_dim=128,latent_dim=32, scatter_type='max',
+ unet=False, unet_kwargs=None,
+ plane_resolution=None, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5):
+ super().__init__()
+ self.c_dim = c_dim
+
+ self.fc_pos = nn.Linear(dim, 2 * hidden_dim)
+ self.blocks = nn.ModuleList([
+ ResnetBlockFC(2 * hidden_dim, hidden_dim) for i in range(n_blocks)
+ ])
+ self.fc_c = nn.Linear(hidden_dim, c_dim)
+
+ self.actvn = nn.ReLU()
+ self.hidden_dim = hidden_dim
+
+ if unet:
+ self.unet = UNet(unet_kwargs['output_dim'], in_channels=c_dim, **unet_kwargs)
+ else:
+ self.unet = None
+
+ self.reso_plane = plane_resolution
+ self.plane_type = plane_type
+ self.padding = padding
+
+ if scatter_type == 'max':
+ self.scatter = scatter_max
+ elif scatter_type == 'mean':
+ self.scatter = scatter_mean
+
+ self.mean_fc = nn.Conv2d(unet_kwargs['output_dim'], latent_dim,kernel_size=1)
+ self.logvar_fc = nn.Conv2d(unet_kwargs['output_dim'], latent_dim,kernel_size=1)
+
+ # takes in "p": point cloud and "query": sdf_xyz
+ # sample plane features for unlabeled_query as well
+ def forward(self, p,point_emb): # , query2):
+ batch_size, T, D = p.size()
+ #print('origin',torch.amin(p[0],dim=0),torch.amax(p[0],dim=0))
+ # acquire the index for each point
+ coord = {}
+ index = {}
+ if 'xz' in self.plane_type:
+ coord['xz'] = self.normalize_coordinate(p.clone(), plane='xz', padding=self.padding)
+ index['xz'] = self.coordinate2index(coord['xz'], self.reso_plane)
+ if 'xy' in self.plane_type:
+ coord['xy'] = self.normalize_coordinate(p.clone(), plane='xy', padding=self.padding)
+ index['xy'] = self.coordinate2index(coord['xy'], self.reso_plane)
+ if 'yz' in self.plane_type:
+ coord['yz'] = self.normalize_coordinate(p.clone(), plane='yz', padding=self.padding)
+ index['yz'] = self.coordinate2index(coord['yz'], self.reso_plane)
+ net = self.fc_pos(point_emb)
+
+ net = self.blocks[0](net)
+ for block in self.blocks[1:]:
+ pooled = self.pool_local(coord, index, net)
+ net = torch.cat([net, pooled], dim=2)
+ net = block(net)
+
+ c = self.fc_c(net)
+ #print(c.shape)
+
+ fea = {}
+ plane_feat_sum = 0
+ # second_sum = 0
+ if 'xz' in self.plane_type:
+ fea['xz'] = self.generate_plane_features(p, c,
+ plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
+ if 'xy' in self.plane_type:
+ fea['xy'] = self.generate_plane_features(p, c, plane='xy')
+ if 'yz' in self.plane_type:
+ fea['yz'] = self.generate_plane_features(p, c, plane='yz')
+ cat_feature = torch.cat([fea['xz'], fea['xy'], fea['yz']],
+ dim=2) # concat at row dimension
+ #print(cat_feature.shape)
+ plane_feat=self.unet(cat_feature)
+
+ mean=self.mean_fc(plane_feat)
+ logvar=self.logvar_fc(plane_feat)
+
+ posterior = DiagonalGaussianDistribution(mean, logvar)
+ x = posterior.sample()
+ kl = posterior.kl()
+
+ return kl, x, mean, logvar
+
+
+ def normalize_coordinate(self, p, padding=0.1, plane='xz'):
+ ''' Normalize coordinate to [0, 1] for unit cube experiments
+
+ Args:
+ p (tensor): point
+            padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
+ plane (str): plane feature type, ['xz', 'xy', 'yz']
+ '''
+ if plane == 'xz':
+ xy = p[:, :, [0, 2]]
+ elif plane == 'xy':
+ xy = p[:, :, [0, 1]]
+ else:
+ xy = p[:, :, [1, 2]]
+ #print("origin",torch.amin(xy), torch.amax(xy))
+ xy=xy/2 #xy is originally -1 ~ 1
+ xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
+ xy_new = xy_new + 0.5 # range (0, 1)
+ #print("scale",torch.amin(xy_new),torch.amax(xy_new))
+
+        # if there are outliers out of the range
+ if xy_new.max() >= 1:
+ xy_new[xy_new >= 1] = 1 - 10e-6
+ if xy_new.min() < 0:
+ xy_new[xy_new < 0] = 0.0
+ return xy_new
+
+ def coordinate2index(self, x, reso):
+        ''' Convert normalized plane coordinates in [0, 1] to flattened indices
+        into the plane feature grid of the given resolution
+
+ Args:
+ x (tensor): coordinate
+ reso (int): defined resolution
+ coord_type (str): coordinate type
+ '''
+ x = (x * reso).long()
+ index = x[:, :, 0] + reso * x[:, :, 1]
+ index = index[:, None, :]
+ return index
+
+ # xy is the normalized coordinates of the point cloud of each plane
+    # xy and index share the same keys, so only the keys of xy are used here
+ def pool_local(self, xy, index, c):
+ bs, fea_dim = c.size(0), c.size(2)
+ keys = xy.keys()
+
+ c_out = 0
+ for key in keys:
+ # scatter plane features from points
+ fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane ** 2)
+ if self.scatter == scatter_max:
+ fea = fea[0]
+ # gather feature back to points
+ fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
+ c_out += fea
+ return c_out.permute(0, 2, 1)
+
+ def generate_plane_features(self, p, c, plane='xz'):
+ # acquire indices of features in plane
+ xy = self.normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)
+ index = self.coordinate2index(xy, self.reso_plane)
+
+ # scatter plane features from points
+ fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane ** 2)
+ c = c.permute(0, 2, 1) # B x 512 x T
+ fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
+ fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane,
+                                      self.reso_plane)  # sparse matrix (B x 512 x reso x reso)
+ #print(fea_plane.shape)
+
+ return fea_plane
+
+ # sample_plane_feature function copied from /src/conv_onet/models/decoder.py
+ # uses values from plane_feature and pixel locations from vgrid to interpolate feature
+ def sample_plane_feature(self, query, plane_feature, plane):
+ xy = self.normalize_coordinate(query.clone(), plane=plane, padding=self.padding)
+ xy = xy[:, :, None].float()
+ vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
+ sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True,
+ mode='bilinear').squeeze(-1)
+ return sampled_feat
+
+
+
diff --git a/models/modules/image_sampler.py b/models/modules/image_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..681403e165aa63b4bc4ff9dcdc20a77342fd8fc0
--- /dev/null
+++ b/models/modules/image_sampler.py
@@ -0,0 +1,1046 @@
+import sys
+sys.path.append('../..')
+import torch
+import torch.nn as nn
+import math
+from models.modules.unet import RollOut_Conv
+from einops import rearrange, reduce
+MB =1024.0*1024.0
+def mask_kernel(x, sigma=1):
+ return torch.abs(x) < sigma #if the distance is smaller than the kernel size, return True
+
+def mask_kernel_close_false(x, sigma=1):
+ return torch.abs(x) > sigma #if the distance is smaller than the kernel size, return False
+
+class Image_Local_Sampler(nn.Module):
+ def __init__(self,reso,padding=0.1,in_channels=1280,out_channels=512):
+ super().__init__()
+ self.triplane_reso=reso
+ self.padding=padding
+ self.get_triplane_coord()
+ self.img_proj=nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=1)
+ def get_triplane_coord(self):
+        '''xz plane first; the unused axis is fixed at the cube center'''
+ x=torch.arange(self.triplane_reso)
+ z=torch.arange(self.triplane_reso)
+ X,Z=torch.meshgrid(x,z,indexing='xy')
+ xz_coords=torch.cat([X[:,:,None],torch.ones_like(X[:,:,None])*(self.triplane_reso-1)/2,Z[:,:,None]],dim=-1) #in xyz order
+
+ '''xy plane'''
+ x = torch.arange(self.triplane_reso)
+ y = torch.arange(self.triplane_reso)
+ X, Y = torch.meshgrid(x, y, indexing='xy')
+ xy_coords = torch.cat([X[:, :, None], Y[:, :, None],torch.ones_like(X[:, :, None])*(self.triplane_reso-1)/2], dim=-1) # in xyz order
+
+ '''yz plane'''
+ y = torch.arange(self.triplane_reso)
+ z = torch.arange(self.triplane_reso)
+ Y,Z = torch.meshgrid(y,z,indexing='xy')
+ yz_coords= torch.cat([torch.ones_like(Y[:, :, None])*(self.triplane_reso-1)/2,Y[:,:,None],Z[:,:,None]], dim=-1)
+
+ triplane_coords=torch.cat([xz_coords,xy_coords,yz_coords],dim=0)
+ triplane_coords=triplane_coords/(self.triplane_reso-1)
+ triplane_coords=(triplane_coords-0.5)*2*(1 + self.padding + 10e-6)
+ self.triplane_coords=triplane_coords.float().cuda()
+
+ def forward(self,image_feat,proj_mat):
+ image_feat=self.img_proj(image_feat)
+ batch_size=image_feat.shape[0]
+ triplane_coords=self.triplane_coords.unsqueeze(0).expand(batch_size,-1,-1,-1) #B,192,64,3
+ #print(torch.amin(triplane_coords),torch.amax(triplane_coords))
+ coord_homo=torch.cat([triplane_coords,torch.ones((batch_size,triplane_coords.shape[1],triplane_coords.shape[2],1)).float().cuda()],dim=-1)
+ coord_inimg=torch.einsum('bhwc,bck->bhwk',coord_homo,proj_mat.transpose(1,2))
+ x=coord_inimg[:,:,:,0]/coord_inimg[:,:,:,2]
+ y=coord_inimg[:,:,:,1]/coord_inimg[:,:,:,2]
+ x=(x/(224.0-1.0)-0.5)*2 #-1~1
+ y=(y/(224.0-1.0)-0.5)*2 #-1~1
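+        # Each triplane texel is projected into the 224x224 feature image by proj_mat
+        # (assumed here to be the full camera projection, i.e. intrinsics @ extrinsics),
+        # and x/y are rescaled to [-1, 1] so grid_sample can fetch the image feature.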
+ dist=coord_inimg[:,:,:,2]
+
+ xy=torch.cat([x[:,:,:,None],y[:,:,:,None]],dim=-1)
+ #print(image_feat.shape,xy.shape)
+ sample_feat=torch.nn.functional.grid_sample(image_feat,xy,align_corners=True,mode='bilinear')
+ return sample_feat
+
+def position_encoding(d_model, length):
+ if d_model % 2 != 0:
+ raise ValueError("Cannot use sin/cos positional encoding with "
+ "odd dim (got dim={:d})".format(d_model))
+ pe = torch.zeros(length, d_model)
+ position = torch.arange(0, length).unsqueeze(1) #length,1
+ div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
+ -(math.log(10000.0) / d_model))) #d_model//2, this is the frequency
+ pe[:, 0::2] = torch.sin(position.float() * div_term) #length*(d_model//2)
+ pe[:, 1::2] = torch.cos(position.float() * div_term)
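+    # Standard transformer-style sinusoidal table: even channels hold
+    # sin(pos / 10000^(2i/d_model)) and odd channels the matching cos; the
+    # (length, d_model) result is added to queries/keys/values as a positional code.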
+
+ return pe
+
+class Image_Vox_Local_Sampler(nn.Module):
+ def __init__(self,reso,padding=0.1,in_channels=1280,inner_channel=128,out_channels=64,n_heads=8):
+ super().__init__()
+ self.triplane_reso=reso
+ self.padding=padding
+ self.get_vox_coord()
+ self.out_channels=out_channels
+ self.img_proj=nn.Conv2d(in_channels=in_channels,out_channels=inner_channel,kernel_size=1)
+
+ self.vox_process=nn.Sequential(
+ nn.Conv3d(in_channels=inner_channel,out_channels=inner_channel,kernel_size=3,padding=1,),
+ )
+ self.k=nn.Linear(in_features=inner_channel,out_features=inner_channel)
+ self.q=nn.Linear(in_features=inner_channel,out_features=inner_channel)
+ self.v=nn.Linear(in_features=inner_channel,out_features=inner_channel)
+ self.attn = torch.nn.MultiheadAttention(
+ embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
+
+ self.proj_out=nn.Conv2d(in_channels=inner_channel,out_channels=out_channels,kernel_size=1)
+ self.condition_pe = position_encoding(inner_channel, self.triplane_reso).unsqueeze(0)
+ def get_vox_coord(self):
+ x = torch.arange(self.triplane_reso)
+ y = torch.arange(self.triplane_reso)
+ z = torch.arange(self.triplane_reso)
+
+ X,Y,Z=torch.meshgrid(x,y,z,indexing='ij')
+ vox_coor=torch.cat([X[:,:,:,None],Y[:,:,:,None],Z[:,:,:,None]],dim=-1)
+ vox_coor=vox_coor/(self.triplane_reso-1)
+ vox_coor=(vox_coor-0.5)*2*(1+self.padding+10e-6)
+ self.vox_coor=vox_coor.view(-1,3).float().cuda()
+
+
+ def forward(self,triplane_feat,image_feat,proj_mat):
+ xz_feat,xy_feat,yz_feat=torch.split(triplane_feat,triplane_feat.shape[2]//3,dim=2) #B,C,64,64
+ image_feat=self.img_proj(image_feat)
+ batch_size=image_feat.shape[0]
+ vox_coords=self.vox_coor.unsqueeze(0).expand(batch_size,-1,-1) #B,64*64*64,3
+ vox_homo=torch.cat([vox_coords,torch.ones((batch_size,self.triplane_reso**3,1)).float().cuda()],dim=-1)
+ coord_inimg=torch.einsum('bhc,bck->bhk',vox_homo,proj_mat.transpose(1,2))
+ x=coord_inimg[:,:,0]/coord_inimg[:,:,2]
+ y=coord_inimg[:,:,1]/coord_inimg[:,:,2]
+ x=(x/(224.0-1.0)-0.5)*2 #-1~1
+ y=(y/(224.0-1.0)-0.5)*2 #-1~1
+
+ xy=torch.cat([x[:,:,None],y[:,:,None]],dim=-1).unsqueeze(1).contiguous() #B, 1,64**3,2
+ #print(image_feat.shape,xy.shape)
+ grid_feat=torch.nn.functional.grid_sample(image_feat,xy,align_corners=True,mode='bilinear').squeeze(2).\
+ view(batch_size,-1,self.triplane_reso,self.triplane_reso,self.triplane_reso) #B,C,1,64**3
+
+ grid_feat=self.vox_process(grid_feat)
+ xzy_grid=grid_feat.permute(0,4,2,3,1)
+ xz_as_query=xz_feat.permute(0,2,3,1).reshape(batch_size*self.triplane_reso**2,1,-1)
+ xz_as_key=xzy_grid.reshape(batch_size*self.triplane_reso**2,self.triplane_reso,-1)
+
+ xyz_grid=grid_feat.permute(0,3,2,4,1)
+ xy_as_query=xy_feat.permute(0,2,3,1).reshape(batch_size*self.triplane_reso**2,1,-1)
+ xy_as_key = xyz_grid.reshape(batch_size * self.triplane_reso ** 2, self.triplane_reso, -1)
+
+ yzx_grid = grid_feat.permute(0, 4, 3, 2, 1)
+ yz_as_query = yz_feat.permute(0,2,3,1).reshape(batch_size*self.triplane_reso**2,1,-1)
+ yz_as_key = yzx_grid.reshape(batch_size * self.triplane_reso ** 2, self.triplane_reso, -1)
+
+ query=self.q(torch.cat([xz_as_query,xy_as_query,yz_as_query],dim=0))
+ key=self.k(torch.cat([xz_as_key,xy_as_key,yz_as_key],dim=0))+self.condition_pe.to(xz_as_key.device)
+ value=self.v(torch.cat([xz_as_key,xy_as_key,yz_as_key],dim=0))+self.condition_pe.to(xz_as_key.device)
+
+ attn,_=self.attn(query,key,value)
+ xz_plane,xy_plane,yz_plane=torch.split(attn,dim=0,split_size_or_sections=batch_size*self.triplane_reso**2)
+ xz_plane=xz_plane.reshape(batch_size,self.triplane_reso,self.triplane_reso,-1).permute(0,3,1,2)
+ xy_plane = xy_plane.reshape(batch_size, self.triplane_reso, self.triplane_reso, -1).permute(0, 3, 1, 2)
+ yz_plane = yz_plane.reshape(batch_size, self.triplane_reso, self.triplane_reso, -1).permute(0, 3, 1, 2)
+
+ triplane_wImg=torch.cat([xz_plane,xy_plane,yz_plane],dim=2)
+ triplane_wImg=self.proj_out(triplane_wImg)
+ #print(triplane_wImg.shape)
+
+ return triplane_wImg
+
+class Image_Direct_AttenwMask_Sampler(nn.Module):
+ def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,
+ img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8):
+ super().__init__()
+ self.triplane_reso=reso
+ self.vit_reso=vit_reso
+ self.padding=padding
+ self.n_heads=n_heads
+ self.get_plane_expand_coord()
+ self.get_vit_coords()
+ self.out_channels=out_channels
+ self.kernel_func=mask_kernel
+ self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
+ self.q=nn.Linear(in_features=triplane_in_channels,out_features=inner_channel)
+ self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
+ self.attn = torch.nn.MultiheadAttention(
+ embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
+
+ self.proj_out=nn.Linear(in_features=inner_channel,out_features=out_channels)
+ self.image_pe = position_encoding(inner_channel, self.vit_reso**2+1).unsqueeze(0).cuda().float() #1,n_img*reso*reso,channel
+ self.triplane_pe = position_encoding(inner_channel, 3*self.triplane_reso**2).unsqueeze(0).cuda().float()
+ def get_plane_expand_coord(self):
+ x = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
+ y = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
+ z = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
+
+ first,second,third=torch.meshgrid(x,y,z,indexing='xy')
+ xyz_coords=torch.stack([first,second,third],dim=-1)#reso,reso,reso,3
+ xyz_coords=(xyz_coords-0.5)*2*(1+self.padding+10e-6) #ordering yxz ->xyz
+ xzy_coords=xyz_coords.clone().permute(2,1,0,3) #ordering zxy ->xzy
+ yzx_coords=xyz_coords.clone().permute(2,0,1,3) #ordering zyx ->yzx
+
+ # print(xyz_coords[0,0,0],xyz_coords[0,0,1],xyz_coords[1,0,0],xyz_coords[0,1,0])
+ # print(xzy_coords[0, 0, 0], xzy_coords[0, 0, 1], xzy_coords[1, 0, 0], xzy_coords[0, 1, 0])
+ # print(yzx_coords[0, 0, 0], yzx_coords[0, 0, 1], yzx_coords[1, 0, 0], yzx_coords[0, 1, 0])
+
+ xyz_coords=xyz_coords.reshape(self.triplane_reso**3,-1)
+ xzy_coords=xzy_coords.reshape(self.triplane_reso**3,-1)
+ yzx_coords=yzx_coords.reshape(self.triplane_reso**3,-1)
+
+ coords=torch.cat([xzy_coords,xyz_coords,yzx_coords],dim=0)
+ self.plane_coords=coords.cuda().float()
+ # self.xzy_coords=xzy_coords.cuda().float() #reso**3,3
+ # self.xyz_coords=xyz_coords.cuda().float() #reso**3,3
+ # self.yzx_coords=yzx_coords.cuda().float() #reso**3,3
+
+ def get_vit_coords(self):
+ x=torch.arange(self.vit_reso)
+ y=torch.arange(self.vit_reso)
+
+ X,Y=torch.meshgrid(x,y,indexing='xy')
+ vit_coords=torch.stack([X,Y],dim=-1)
+ self.vit_coords=vit_coords.view(self.vit_reso**2,2).cuda().float()
+
+ def get_attn_mask(self,coords_proj,vit_coords,kernel_size=1.0):
+ '''
+        :param coords_proj: B,3*reso**3,2, in range of 0~1
+        :param vit_coords: B,vit_reso**2,2, in range of 0~vit_reso
+        :param kernel_size: 0.5, so that only one pixel will be selected
+ :return:
+ '''
+ bs=coords_proj.shape[0]
+ coords_proj=coords_proj*(self.vit_reso-1)
+ #print(torch.amin(coords_proj[0,0:self.triplane_reso**3]),torch.amax(coords_proj[0,0:self.triplane_reso**3]))
+ dist=torch.cdist(coords_proj.float(),vit_coords.float())
+ mask=self.kernel_func(dist,sigma=kernel_size).float() #True if valid, B,3*reso**3,vit_reso**2
+ mask=mask.reshape(bs,3*self.triplane_reso**2,self.triplane_reso,self.vit_reso**2)
+ mask=torch.sum(mask,dim=2)
+ attn_mask=(mask==0)
+ return attn_mask
+
+ def forward(self,triplane_feat,image_feat,proj_mat):
+ #xz_feat,xy_feat,yz_feat=torch.split(triplane_feat,triplane_feat.shape[2]//3,dim=2) #B,C,64,64
+ batch_size=image_feat.shape[0]
+ #print(self.plane_coords.shape)
+ coords=self.plane_coords.unsqueeze(0).expand(batch_size,-1,-1)
+
+ coords_homo=torch.cat([coords,torch.ones(batch_size,self.triplane_reso**3*3,1).float().cuda()],dim=-1)
+ coords_inimg=torch.einsum('bhc,bck->bhk',coords_homo,proj_mat.transpose(1,2))
+ coords_x=coords_inimg[:,:,0]/coords_inimg[:,:,2]/(224.0-1) #0~1
+ coords_y=coords_inimg[:,:,1]/coords_inimg[:,:,2]/(224.0-1) #0~1
+ coords_x=torch.clamp(coords_x,min=0.0,max=1.0)
+ coords_y=torch.clamp(coords_y,min=0.0,max=1.0)
+ #print(torch.amin(coords_x),torch.amax(coords_x))
+ coords_proj=torch.stack([coords_x,coords_y],dim=-1)
+ vit_coords=self.vit_coords.unsqueeze(0).expand(batch_size,-1,-1)
+ attn_mask=torch.repeat_interleave(
+ self.get_attn_mask(coords_proj,vit_coords,kernel_size=1.0),self.n_heads, 0
+ )
+ attn_mask = torch.cat([torch.zeros([attn_mask.shape[0], attn_mask.shape[1], 1]).cuda().bool(), attn_mask],
+ dim=-1) # add global token
+ #print(attn_mask.shape,torch.sum(attn_mask.float()))
+ triplane_feat=triplane_feat.permute(0,2,3,1).view(batch_size,3*self.triplane_reso**2,-1)
+ #print(triplane_feat.shape,self.triplane_pe.shape)
+ query=self.q(triplane_feat)+self.triplane_pe
+ key=self.k(image_feat)+self.image_pe
+ value=self.v(image_feat)+self.image_pe
+ #print(query.shape,key.shape,value.shape)
+ attn,_=self.attn(query,key,value,attn_mask=attn_mask)
+ #print(attn.shape)
+ output=self.proj_out(attn).transpose(1,2).reshape(batch_size,-1,3*self.triplane_reso,self.triplane_reso)
+
+ return output
+
+class MultiImage_Direct_AttenwMask_Sampler(nn.Module):
+ def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,
+ img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8,max_nimg=5):
+ super().__init__()
+ self.triplane_reso=reso
+ self.vit_reso=vit_reso
+ self.padding=padding
+ self.n_heads=n_heads
+ self.get_plane_expand_coord()
+ self.get_vit_coords()
+ self.out_channels=out_channels
+ self.kernel_func=mask_kernel
+ self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
+ self.q=nn.Linear(in_features=triplane_in_channels,out_features=inner_channel)
+ self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
+ self.attn = torch.nn.MultiheadAttention(
+ embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
+
+ self.proj_out=nn.Linear(in_features=inner_channel,out_features=out_channels)
+ self.image_pe = position_encoding(inner_channel, max_nimg*(self.vit_reso**2+1)).unsqueeze(0).cuda().float()
+ self.triplane_pe = position_encoding(inner_channel, 3*self.triplane_reso**2).unsqueeze(0).cuda().float()
+ def get_plane_expand_coord(self):
+ x = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
+ y = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
+ z = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
+
+ first,second,third=torch.meshgrid(x,y,z,indexing='xy')
+ xyz_coords=torch.stack([first,second,third],dim=-1)#reso,reso,reso,3
+ xyz_coords=(xyz_coords-0.5)*2*(1+self.padding+10e-6) #ordering yxz ->xyz
+ xzy_coords=xyz_coords.clone().permute(2,1,0,3) #ordering zxy ->xzy
+ yzx_coords=xyz_coords.clone().permute(2,0,1,3) #ordering zyx ->yzx
+
+ xyz_coords=xyz_coords.reshape(self.triplane_reso**3,-1)
+ xzy_coords=xzy_coords.reshape(self.triplane_reso**3,-1)
+ yzx_coords=yzx_coords.reshape(self.triplane_reso**3,-1)
+
+ coords=torch.cat([xzy_coords,xyz_coords,yzx_coords],dim=0)
+ self.plane_coords=coords.cuda().float()
+ # self.xzy_coords=xzy_coords.cuda().float() #reso**3,3
+ # self.xyz_coords=xyz_coords.cuda().float() #reso**3,3
+ # self.yzx_coords=yzx_coords.cuda().float() #reso**3,3
+
+ def get_vit_coords(self):
+ x=torch.arange(self.vit_reso)
+ y=torch.arange(self.vit_reso)
+
+ X,Y=torch.meshgrid(x,y,indexing='xy')
+ vit_coords=torch.stack([X,Y],dim=-1)
+ self.vit_coords=vit_coords.view(self.vit_reso**2,2).cuda().float()
+
+ def get_attn_mask(self,coords_proj,vit_coords,valid_frames,kernel_size=1.0):
+ '''
+ :param coords_proj: B,n_img,3*reso**3,2, in range of 0~vit_reso
+ :param vit_coords: B,n_img,vit_reso**2,2, in range of 0~vit_reso
+        :param kernel_size: 0.5, so that only one pixel will be selected
+ :return:
+ '''
+ bs,n_img=coords_proj.shape[0],coords_proj.shape[1]
+ coords_proj_flat=coords_proj.reshape(bs*n_img,3*self.triplane_reso**3,2)
+ vit_coords_flat=vit_coords.reshape(bs*n_img,self.vit_reso**2,2)
+ dist=torch.cdist(coords_proj_flat.float(),vit_coords_flat.float())
+ mask=self.kernel_func(dist,sigma=kernel_size).float() #True if valid, B*n_img,3*reso**3,vit_reso**2
+ mask=mask.reshape(bs,n_img,3*self.triplane_reso**2,self.triplane_reso,self.vit_reso**2)
+ mask=torch.sum(mask,dim=3) #B,n_img,3*reso**2,vit_reso**2
+ mask=torch.cat([torch.ones(size=mask.shape[0:3]).unsqueeze(3).float().cuda(),mask],dim=-1) #B,n_img,3*reso**2,vit_reso**2+1, add global mask
+ mask[valid_frames == 0, :, :] = False
+ mask=mask.permute(0,2,1,3).reshape(bs,3*self.triplane_reso**2,-1) #B,3*reso**2,n_img*(vit_resso**2+1)
+ attn_mask=(mask==0) #invert the mask, False indicates valid, True indicates invalid
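+        # nn.MultiheadAttention treats True entries of a boolean attn_mask as positions
+        # that must not be attended to, so each triplane query only sees the image
+        # tokens its voxel column projects onto, plus the global token of valid frames.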
+ return attn_mask
+
+ def forward(self,triplane_feat,image_feat,proj_mat,valid_frames):
+ '''image feat is bs,n_img,length,channel'''
+ batch_size,n_img=image_feat.shape[0],image_feat.shape[1]
+ img_length=image_feat.shape[2]
+ image_feat_flat=image_feat.view(batch_size,n_img*img_length,-1)
+ coords=self.plane_coords.unsqueeze(0).unsqueeze(1).expand(batch_size,n_img,-1,-1)
+
+ coord_homo=torch.cat([coords,torch.ones(batch_size,n_img,self.triplane_reso**3*3,1).float().cuda()],dim=-1)
+ #print(coord_homo.shape,proj_mat.shape)
+ coord_inimg = torch.einsum('bjnc,bjck->bjnk', coord_homo, proj_mat.transpose(2, 3))
+ x = coord_inimg[:, :, :, 0] / coord_inimg[:, :, :, 2]
+ y = coord_inimg[:, :, :, 1] / coord_inimg[:, :, :, 2]
+ x = x/(224.0-1)
+ y = y/(224.0-1)
+ coords_x=torch.clamp(x,min=0.0,max=1.0)*(self.vit_reso-1)
+ coords_y=torch.clamp(y,min=0.0,max=1.0)*(self.vit_reso-1)
+ coords_proj=torch.stack([coords_x,coords_y],dim=-1)
+ vit_coords=self.vit_coords.unsqueeze(0).unsqueeze(1).expand(batch_size,n_img,-1,-1)
+ attn_mask=torch.repeat_interleave(
+ self.get_attn_mask(coords_proj,vit_coords,valid_frames,kernel_size=1.0),self.n_heads, 0
+ )
+ triplane_feat=triplane_feat.permute(0,2,3,1).view(batch_size,3*self.triplane_reso**2,-1)
+ query=self.q(triplane_feat)+self.triplane_pe
+ key=self.k(image_feat_flat)+self.image_pe
+ value=self.v(image_feat_flat)+self.image_pe
+ attn,_=self.attn(query,key,value,attn_mask=attn_mask)
+ output=self.proj_out(attn).transpose(1,2).reshape(batch_size,-1,3*self.triplane_reso,self.triplane_reso)
+
+ return output
+
+class MultiImage_Fuse_Sampler(nn.Module):
+ def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,
+ img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8):
+ super().__init__()
+ self.triplane_reso=reso
+ self.vit_reso=vit_reso
+ self.inner_channel=inner_channel
+ self.padding=padding
+ self.n_heads=n_heads
+ self.get_vox_coord()
+ self.get_vit_coords()
+ self.out_channels=out_channels
+ self.kernel_func=mask_kernel
+ self.image_unflatten=nn.Unflatten(2,(vit_reso,vit_reso))
+ self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
+ self.q=nn.Linear(in_features=triplane_in_channels*3,out_features=inner_channel)
+ self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
+
+ #self.cross_attn=CrossAttention(query_dim=inner_channel,heads=8,dim_head=inner_channel//8)
+ self.cross_attn = torch.nn.MultiheadAttention(
+ embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
+ self.proj_out=nn.Linear(in_features=inner_channel,out_features=out_channels)
+ self.image_pe = position_encoding(inner_channel, self.vit_reso**2)[None,None,:,:].cuda().float() #1,1,length,channel
+ #self.image_pe = self.image_pe.reshape(1,max_nimg,self.vit_reso,self.vit_reso,inner_channel)
+ self.triplane_pe = position_encoding(inner_channel, self.triplane_reso ** 3).unsqueeze(0).cuda().float()
+
+ def get_vit_coords(self):
+ x = torch.arange(self.vit_reso)
+ y = torch.arange(self.vit_reso)
+
+ X, Y = torch.meshgrid(x, y, indexing='xy')
+ vit_coords = torch.stack([X, Y], dim=-1)
+ self.vit_coords = vit_coords.cuda().float() #reso,reso,2
+
+ def get_vox_coord(self):
+ x = torch.arange(self.triplane_reso)
+ y = torch.arange(self.triplane_reso)
+ z = torch.arange(self.triplane_reso)
+
+ X, Y, Z = torch.meshgrid(x, y, z, indexing='ij')
+ vox_coor = torch.cat([X[:, :, :, None], Y[:, :, :, None], Z[:, :, :, None]], dim=-1)
+ self.vox_index = vox_coor.view(-1, 3).long().cuda()
+
+ vox_coor = self.vox_index.float() / (self.triplane_reso - 1)
+ vox_coor = (vox_coor - 0.5) * 2 * (1 + self.padding + 10e-6)
+ self.vox_coor = vox_coor.view(-1, 3).float().cuda()
+
+ def get_attn_mask(self,valid_frames):
+ '''
+ :param valid_frames: of shape B,n_img
+ '''
+ #print(valid_frames)
+ #bs,n_img=valid_frames.shape[0:2]
+ attn_mask=(valid_frames.float()==0)
+ #attn_mask=attn_mask.unsqueeze(1).unsqueeze(2).expand(-1,self.triplane_reso**3,-1,-1) #B,1,n_img
+ #attn_mask=attn_mask.reshape(bs*self.triplane_reso**3,-1,n_img).bool()
+ attn_mask=torch.repeat_interleave(attn_mask.unsqueeze(1),self.triplane_reso**3,0)
+ # print(attn_mask[self.triplane_reso**3*1+10])
+ # print(attn_mask[self.triplane_reso ** 3 * 2+10])
+ # print(attn_mask[self.triplane_reso ** 3 * 3+10])
+ return attn_mask
+
+ def forward(self,triplane_feat,image_feat,proj_mat,valid_frames):
+ '''image feat is bs,n_img,length,channel'''
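+        # Fusion strategy: expand each plane to a full voxel grid, project every voxel
+        # into every view to sample per-view keys/values, run one cross-attention query
+        # per voxel to fuse the views, then mean-pool the fused volume back onto the
+        # three planes (xz, xy, yz).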
+ batch_size,n_img=image_feat.shape[0],image_feat.shape[1]
+ image_feat=image_feat[:,:,1:,:] #discard global feature
+
+ #image_feat=image_feat.permute(0,1,3,4,2) #B,n_img,h,w,c
+ image_k=self.k(image_feat)+self.image_pe #B,n_img,h,w,c
+ image_v=self.v(image_feat)+self.image_pe #B,n_img,h,w,c
+ image_k_v=torch.cat([image_k,image_v],dim=-1) #B,n_img,h,w,c
+ unflat_k_v=self.image_unflatten(image_k_v).permute(0,4,1,2,3) #Bs,channel,n_img,reso,reso
+ #unflat_k_v=image_k_v.permute(0,4,1,2,3)
+ #vit_coords=self.vit_coords[None,None].expand(batch_size,n_img,-1,-1,-1) #Bs,n_img,reso,reso,2
+
+ coords=self.vox_coor.unsqueeze(0).unsqueeze(1).expand(batch_size,n_img,-1,-1)
+ coord_homo=torch.cat([coords,torch.ones(batch_size,n_img,self.triplane_reso**3,1).float().cuda()],dim=-1)
+ coord_inimg = torch.einsum('bjnc,bjck->bjnk', coord_homo, proj_mat.transpose(2, 3))
+ x = coord_inimg[:, :, :, 0] / coord_inimg[:, :, :, 2]
+ y = coord_inimg[:, :, :, 1] / coord_inimg[:, :, :, 2]
+ x = x/(224.0-1) #0~1
+ y = y/(224.0-1)
+ coords_proj=torch.stack([x,y],dim=-1)
+ coords_proj=(coords_proj-0.5)*2
+ img_index=((torch.arange(n_img)[None,:,None,None].expand(
+ batch_size,-1,self.triplane_reso**3,-1).float().cuda()/(n_img-1))-0.5)*2 #Bs,n_img,64**3,1
+
+ # img_index_feat=torch.arange(n_img)[None,:,None,None,None].expand(
+ # batch_size,-1,self.vit_reso,self.vit_reso,-1).float().cuda() #Bs,n_img,reso,reso,1
+ #coords_feat=torch.cat([vit_coords,img_index_feat],dim=-1).permute(0,4,1,2,3)#Bs,n_img,reso,reso,3
+ grid=torch.cat([coords_proj,img_index],dim=-1) #x,y,index
+ grid=torch.clamp(grid,min=-1.0,max=1.0)
+ sample_k_v = torch.nn.functional.grid_sample(unflat_k_v, grid.unsqueeze(1), align_corners=True, mode='bilinear').squeeze(2) #B,C,n_img,64**3
+ xz_feat, xy_feat, yz_feat = torch.split(triplane_feat, split_size_or_sections=triplane_feat.shape[2] // 3,
+ dim=2) # B,C,64,64
+ xz_vox_feat=xz_feat.unsqueeze(4).expand(-1,-1,-1,-1,self.triplane_reso)#.permute(0,1,3,4,2).reshape(batch_size,-1,self.triplane_reso**3).transpose(1,2) #zxy
+ xz_vox_feat=rearrange(xz_vox_feat, 'b c z x y -> b (x y z) c')
+ xy_vox_feat=xy_feat.unsqueeze(4).expand(-1,-1,-1,-1,self.triplane_reso)#.permute(0,1,3,2,4).reshape(batch_size,-1,self.triplane_reso**3).transpose(1,2) #yxz
+ xy_vox_feat=rearrange(xy_vox_feat, 'b c y x z -> b (x y z) c')
+ yz_vox_feat=yz_feat.unsqueeze(4).expand(-1,-1,-1,-1,self.triplane_reso)#.permute(0,1,4,3,2).reshape(batch_size,-1,self.triplane_reso**3).transpose(1,2) #zyx
+ yz_vox_feat=rearrange(yz_vox_feat, 'b c z y x -> b (x y z) c')
+ #xz_vox_feat = xz_feat[:, :, vox_index[:, 2], vox_index[:, 0]].transpose(1, 2) # B,C,64*64*64
+ #xy_vox_feat = xy_feat[:, :, vox_index[:, 1], vox_index[:, 0]].transpose(1, 2)
+ #yz_vox_feat = yz_feat[:, :, vox_index[:, 2], vox_index[:, 1]].transpose(1, 2)
+
+ triplane_expand_feat = torch.cat([xz_vox_feat, xy_vox_feat, yz_vox_feat], dim=-1) # B,64*64*64,3*C
+ triplane_query = self.q(triplane_expand_feat) + self.triplane_pe
+ k_v=rearrange(sample_k_v, 'b c n k -> (b k) n c')
+ #k_v=sample_k_v.permute(0,3,2,1).reshape(batch_size*self.triplane_reso**3,n_img,-1) #B*64**3,n_img,C
+ k=k_v[:,:,0:self.inner_channel]
+ v=k_v[:,:,self.inner_channel:]
+ q=rearrange(triplane_query,'b k c -> (b k) 1 c')
+ #q=triplane_query.view(batch_size*self.triplane_reso**3,1,-1)
+ #k,v is of shape, B*reso**3,k,channel, q is of shape B*reso**3,1,channel
+ #attn mask should be B*reso**3*n_heads,1,k
+ #attn_mask=torch.repeat_interleave(self.get_attn_mask(valid_frames),self.n_heads,0)
+ #print(q.shape,k.shape,v.shape)
+ attn_out,_=self.cross_attn(q,k,v)#attn_mask=attn_mask) #fuse multi-view feature
+ #volume=attn_out.view(batch_size,self.triplane_reso,self.triplane_reso,self.triplane_reso,-1) #B,reso,reso,reso,channel
+ #print(attn_out.shape)
+ volume=rearrange(attn_out,'(b x y z) 1 c -> b x y z c',x=self.triplane_reso,y=self.triplane_reso,z=self.triplane_reso)
+ #xz_feat = torch.mean(volume, dim=2).transpose(1,2) #B,reso,reso,C
+ xz_feat = reduce(volume, "b x y z c -> b z x c", 'mean')
+ #xy_feat = torch.mean(volume, dim=3).transpose(1,2) #B,reso,reso,C
+ xy_feat= reduce(volume, 'b x y z c -> b y x c', 'mean')
+ #yz_feat = torch.mean(volume, dim=1).transpose(1,2) #B,reso,reso,C
+ yz_feat=reduce(volume, 'b x y z c -> b z y c', 'mean')
+ triplane_out = torch.cat([xz_feat, xy_feat, yz_feat], dim=1) #B,reso*3,reso,C
+ #print(triplane_out.shape)
+ triplane_out = self.proj_out(triplane_out)
+ triplane_out = triplane_out.permute(0,3,1,2)
+ #print(triplane_out.shape)
+ return triplane_out
+
+class MultiImage_TriFuse_Sampler(nn.Module):
+ def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,
+ img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8,max_nimg=5):
+ super().__init__()
+ self.triplane_reso=reso
+ self.vit_reso=vit_reso
+ self.inner_channel=inner_channel
+ self.padding=padding
+ self.n_heads=n_heads
+ self.get_triplane_coord()
+ self.get_vit_coords()
+ self.out_channels=out_channels
+ self.kernel_func=mask_kernel
+ self.image_unflatten=nn.Unflatten(2,(vit_reso,vit_reso))
+ self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
+ self.q=nn.Linear(in_features=triplane_in_channels,out_features=inner_channel)
+ self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
+
+ self.cross_attn = torch.nn.MultiheadAttention(
+ embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
+ self.proj_out=nn.Conv2d(in_channels=inner_channel,out_channels=out_channels,kernel_size=1)
+ self.image_pe = position_encoding(inner_channel, self.vit_reso**2)[None,None,:,:].expand(-1,max_nimg,-1,-1).cuda().float() #B,n_img,length,channel
+ self.triplane_pe = position_encoding(inner_channel, self.triplane_reso ** 2*3).unsqueeze(0).cuda().float()
+
+ def get_vit_coords(self):
+ x = torch.arange(self.vit_reso)
+ y = torch.arange(self.vit_reso)
+
+ X, Y = torch.meshgrid(x, y, indexing='xy')
+ vit_coords = torch.stack([X, Y], dim=-1)
+ self.vit_coords = vit_coords.cuda().float() #reso,reso,2
+
+ def get_triplane_coord(self):
+        '''xz plane first; the unused axis is fixed at the cube center'''
+ x = torch.arange(self.triplane_reso)
+ z = torch.arange(self.triplane_reso)
+ X, Z = torch.meshgrid(x, z, indexing='xy')
+ xz_coords = torch.cat(
+ [X[:, :, None], torch.ones_like(X[:, :, None]) * (self.triplane_reso - 1) / 2, Z[:, :, None]],
+ dim=-1) # in xyz order
+
+ '''xy plane'''
+ x = torch.arange(self.triplane_reso)
+ y = torch.arange(self.triplane_reso)
+ X, Y = torch.meshgrid(x, y, indexing='xy')
+ xy_coords = torch.cat(
+ [X[:, :, None], Y[:, :, None], torch.ones_like(X[:, :, None]) * (self.triplane_reso - 1) / 2],
+ dim=-1) # in xyz order
+
+ '''yz plane'''
+ y = torch.arange(self.triplane_reso)
+ z = torch.arange(self.triplane_reso)
+ Y, Z = torch.meshgrid(y, z, indexing='xy')
+ yz_coords = torch.cat(
+ [torch.ones_like(Y[:, :, None]) * (self.triplane_reso - 1) / 2, Y[:, :, None], Z[:, :, None]], dim=-1)
+
+ triplane_coords = torch.cat([xz_coords, xy_coords, yz_coords], dim=0)
+ triplane_coords = triplane_coords / (self.triplane_reso - 1)
+ triplane_coords = (triplane_coords - 0.5) * 2 * (1 + self.padding + 10e-6)
+ self.triplane_coords = triplane_coords.view(-1,3).float().cuda()
+ def forward(self,triplane_feat,image_feat,proj_mat,valid_frames):
+ '''image feat is bs,n_img,length,channel'''
+ batch_size,n_img=image_feat.shape[0],image_feat.shape[1]
+ image_feat=image_feat[:,:,1:,:] #discard global feature
+ #print(image_feat.shape)
+
+ #image_feat=image_feat.permute(0,1,3,4,2) #B,n_img,h,w,c
+ image_k=self.k(image_feat)+self.image_pe #B,n_img,h,w,c
+ image_v=self.v(image_feat)+self.image_pe #B,n_img,h,w,c
+ image_k_v=torch.cat([image_k,image_v],dim=-1) #B,n_img,h,w,c
+ unflat_k_v=self.image_unflatten(image_k_v).permute(0,4,1,2,3) #Bs,channel,n_img,reso,reso
+
+ coords=self.triplane_coords.unsqueeze(0).unsqueeze(1).expand(batch_size,n_img,-1,-1)
+ coord_homo=torch.cat([coords,torch.ones(batch_size,n_img,self.triplane_reso**2*3,1).float().cuda()],dim=-1)
+ coord_inimg = torch.einsum('bjnc,bjck->bjnk', coord_homo, proj_mat.transpose(2, 3))
+ x = coord_inimg[:, :, :, 0] / coord_inimg[:, :, :, 2]
+ y = coord_inimg[:, :, :, 1] / coord_inimg[:, :, :, 2]
+ x = x/(224.0-1) #0~1
+ y = y/(224.0-1)
+ coords_proj=torch.stack([x,y],dim=-1)
+ coords_proj=(coords_proj-0.5)*2
+ img_index=((torch.arange(n_img)[None,:,None,None].expand(
+ batch_size,-1,self.triplane_reso**2*3,-1).float().cuda()/(n_img-1))-0.5)*2 #Bs,n_img,64**3,1
+
+ grid=torch.cat([coords_proj,img_index],dim=-1) #x,y,index
+ grid=torch.clamp(grid,min=-1.0,max=1.0)
+ sample_k_v = torch.nn.functional.grid_sample(unflat_k_v, grid.unsqueeze(1), align_corners=True, mode='bilinear').squeeze(2) #B,C,n_img,64**3
+
+ triplane_flat_feat=rearrange(triplane_feat,'b c h w -> b (h w) c')
+ triplane_query = self.q(triplane_flat_feat) + self.triplane_pe
+
+ k_v=rearrange(sample_k_v, 'b c n k -> (b k) n c')
+ k=k_v[:,:,0:self.inner_channel]
+ v=k_v[:,:,self.inner_channel:]
+ q=rearrange(triplane_query,'b k c -> (b k) 1 c')
+ attn_out,_=self.cross_attn(q,k,v)
+ triplane_out=rearrange(attn_out,'(b h w) 1 c -> b c h w',b=batch_size,h=self.triplane_reso*3,w=self.triplane_reso)
+ triplane_out = self.proj_out(triplane_out)
+ return triplane_out
+
+
+class MultiImage_Global_Sampler(nn.Module):
+ def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,
+ img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8,max_nimg=5):
+ super().__init__()
+ self.triplane_reso=reso
+ self.vit_reso=vit_reso
+ self.inner_channel=inner_channel
+ self.padding=padding
+ self.n_heads=n_heads
+ self.out_channels=out_channels
+ self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
+ self.q=nn.Linear(in_features=triplane_in_channels,out_features=inner_channel)
+ self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
+
+ self.cross_attn = torch.nn.MultiheadAttention(
+ embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
+ self.proj_out=nn.Linear(in_features=inner_channel,out_features=out_channels)
+ self.image_pe = position_encoding(inner_channel, self.vit_reso**2)[None,None,:,:].expand(-1,max_nimg,-1,-1).cuda().float() #B,n_img,length,channel
+ self.triplane_pe = position_encoding(inner_channel, self.triplane_reso**2*3).unsqueeze(0).cuda().float()
+ def forward(self,triplane_feat,image_feat,proj_mat,valid_frames):
+ '''image feat is bs,n_img,length,channel
+ triplane feat is bs,C,H*3,W
+ '''
+ batch_size,n_img=image_feat.shape[0],image_feat.shape[1]
+ L=image_feat.shape[2]-1
+ image_feat=image_feat[:,:,1:,:] #discard global feature
+
+ image_k=self.k(image_feat)+self.image_pe #B,n_img,h*w,c
+ image_v=self.v(image_feat)+self.image_pe #B,n_img,h*w,c
+ image_k=image_k.view(batch_size,n_img*L,-1)
+ image_v=image_v.view(batch_size,n_img*L,-1)
+
+ triplane_flat_feat=rearrange(triplane_feat,"b c h w -> b (h w) c")
+ triplane_query = self.q(triplane_flat_feat) + self.triplane_pe
+ #print(triplane_query.shape,image_k.shape,image_v.shape)
+ attn_out,_=self.cross_attn(triplane_query,image_k,image_v)
+ triplane_flat_out = self.proj_out(attn_out)
+ triplane_out=rearrange(triplane_flat_out,"b (h w) c -> b c h w",h=self.triplane_reso*3,w=self.triplane_reso)
+
+ return triplane_out
+
+class CrossAttention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+
+ if context_dim is None:
+ context_dim = query_dim
+
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, q,k,v):
+ h = self.heads
+
+ q, k, v = map(lambda t: rearrange(
+ t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = torch.einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+class Image_Vox_Local_Sampler_Pooling(nn.Module):
+ def __init__(self,reso,padding=0.1,in_channels=1280,inner_channel=128,out_channels=64,stride=4):
+ super().__init__()
+ self.triplane_reso=reso
+ self.padding=padding
+ self.get_vox_coord()
+ self.out_channels=out_channels
+ self.img_proj=nn.Conv2d(in_channels=in_channels,out_channels=inner_channel,kernel_size=1)
+
+ self.vox_process=nn.Sequential(
+ nn.Conv3d(in_channels=inner_channel,out_channels=inner_channel,kernel_size=3,padding=1)
+ )
+ self.xz_conv=nn.Sequential(
+ nn.BatchNorm3d(inner_channel),
+ nn.ReLU(),
+ nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
+ nn.AvgPool3d((1,stride,1),stride=(1,stride,1)), #8
+ nn.BatchNorm3d(inner_channel),
+ nn.ReLU(),
+ nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
+ nn.AvgPool3d((1,stride,1), stride=(1,stride,1)), #2
+ nn.BatchNorm3d(inner_channel),
+ nn.ReLU(),
+ nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
+ )
+ self.xy_conv = nn.Sequential(
+ nn.BatchNorm3d(inner_channel),
+ nn.ReLU(),
+ nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
+ nn.AvgPool3d((1, 1, stride), stride=(1, 1, stride)), # 8
+ nn.BatchNorm3d(inner_channel),
+ nn.ReLU(),
+ nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
+ nn.AvgPool3d((1, 1, stride), stride=(1, 1, stride)), # 2
+ nn.BatchNorm3d(inner_channel),
+ nn.ReLU(),
+ nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
+ )
+ self.yz_conv = nn.Sequential(
+ nn.BatchNorm3d(inner_channel),
+ nn.ReLU(),
+ nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
+ nn.AvgPool3d((stride, 1, 1), stride=(stride, 1, 1)), # 8
+ nn.BatchNorm3d(inner_channel),
+ nn.ReLU(),
+ nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
+ nn.AvgPool3d((stride, 1, 1), stride=(stride, 1, 1)), # 2
+ nn.BatchNorm3d(inner_channel),
+ nn.ReLU(),
+ nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
+ )
+ self.roll_out_conv=RollOut_Conv(in_channels=inner_channel,out_channels=out_channels)
+ #self.proj_out=nn.Conv2d(in_channels=inner_channel,out_channels=out_channels,kernel_size=1)
+ def get_vox_coord(self):
+ x = torch.arange(self.triplane_reso)
+ y = torch.arange(self.triplane_reso)
+ z = torch.arange(self.triplane_reso)
+
+ X,Y,Z=torch.meshgrid(x,y,z,indexing='ij')
+ vox_coor=torch.cat([X[:,:,:,None],Y[:,:,:,None],Z[:,:,:,None]],dim=-1)
+ vox_coor=vox_coor/(self.triplane_reso-1)
+ vox_coor=(vox_coor-0.5)*2*(1+self.padding+10e-6)
+ self.vox_coor=vox_coor.view(-1,3).float().cuda()
+
+
+ def forward(self,image_feat,proj_mat):
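+ # Project every voxel centre of the reso^3 grid into the 224x224 image with proj_mat, bilinearly
+ # sample image features at those pixels, reshape them into a dense feature volume, and average-pool
+ # the volume along one axis per plane to obtain the rolled-out triplane.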
+ image_feat=self.img_proj(image_feat)
+ batch_size=image_feat.shape[0]
+ vox_coords=self.vox_coor.unsqueeze(0).expand(batch_size,-1,-1) #B,64*64*64,3
+ vox_homo=torch.cat([vox_coords,torch.ones((batch_size,self.triplane_reso**3,1)).float().cuda()],dim=-1)
+ coord_inimg=torch.einsum('bhc,bck->bhk',vox_homo,proj_mat.transpose(1,2))
+ x=coord_inimg[:,:,0]/coord_inimg[:,:,2]
+ y=coord_inimg[:,:,1]/coord_inimg[:,:,2]
+ x=(x/(224.0-1.0)-0.5)*2 #-1~1
+ y=(y/(224.0-1.0)-0.5)*2 #-1~1
+
+ xy=torch.cat([x[:,:,None],y[:,:,None]],dim=-1).unsqueeze(1).contiguous() #B, 1,64**3,2
+ #print(image_feat.shape,xy.shape)
+ grid_feat=torch.nn.functional.grid_sample(image_feat,xy,align_corners=True,mode='bilinear').squeeze(2).\
+ view(batch_size,-1,self.triplane_reso,self.triplane_reso,self.triplane_reso) #B,C,1,64**3
+
+ grid_feat=self.vox_process(grid_feat)
+ xz_feat=torch.mean(self.xz_conv(grid_feat),dim=3).transpose(2,3)
+ xy_feat=torch.mean(self.xy_conv(grid_feat),dim=4).transpose(2,3)
+ yz_feat=torch.mean(self.yz_conv(grid_feat),dim=2).transpose(2,3)
+ triplane_wImg=torch.cat([xz_feat,xy_feat,yz_feat],dim=2)
+ #print(triplane_wImg.shape)
+
+ return self.roll_out_conv(triplane_wImg)
+
+class Image_ExpandVox_attn_Sampler(nn.Module):
+ def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8):
+ super().__init__()
+ self.triplane_reso=reso
+ self.padding=padding
+ self.vit_reso=vit_reso
+ self.get_vox_coord()
+ self.get_vit_coords()
+ self.out_channels=out_channels
+ self.n_heads=n_heads
+
+ self.kernel_func = mask_kernel_close_false
+ self.k = nn.Linear(in_features=img_in_channels, out_features=inner_channel)
+ # self.q_xz = nn.Conv2d(in_channels=triplane_in_channels,out_channels=inner_channel,kernel_size=1)
+ # self.q_xy = nn.Conv2d(in_channels=triplane_in_channels,out_channels=inner_channel,kernel_size=1)
+ # self.q_yz = nn.Conv2d(in_channels=triplane_in_channels,out_channels=inner_channel,kernel_size=1)
+ self.q=nn.Linear(in_features=triplane_in_channels*3,out_features=inner_channel)
+
+ self.v = nn.Linear(in_features=img_in_channels, out_features=inner_channel)
+ self.attn = torch.nn.MultiheadAttention(
+ embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
+ self.out_proj=nn.Linear(in_features=inner_channel,out_features=out_channels)
+
+ self.triplane_pe = position_encoding(inner_channel, self.triplane_reso ** 3).unsqueeze(0).cuda().float()
+ self.image_pe = position_encoding(inner_channel, self.vit_reso ** 2+1).unsqueeze(0).cuda().float()
+ def get_vox_coord(self):
+ x = torch.arange(self.triplane_reso)
+ y = torch.arange(self.triplane_reso)
+ z = torch.arange(self.triplane_reso)
+
+ X,Y,Z=torch.meshgrid(x,y,z,indexing='ij')
+ vox_coor=torch.cat([X[:,:,:,None],Y[:,:,:,None],Z[:,:,:,None]],dim=-1)
+ self.vox_index=vox_coor.view(-1,3).long().cuda()
+
+
+ vox_coor = self.vox_index.float() / (self.triplane_reso - 1)
+ vox_coor = (vox_coor - 0.5) * 2 * (1 + self.padding + 10e-6)
+ self.vox_coor = vox_coor.view(-1, 3).float().cuda()
+ # print(self.vox_coor[0])
+ # print(self.vox_coor[self.triplane_reso**2])#x should increase
+ # print(self.vox_coor[self.triplane_reso]) #y should increase
+ # print(self.vox_coor[1])#z should increase
+
+ def get_vit_coords(self):
+ x=torch.arange(self.vit_reso)
+ y=torch.arange(self.vit_reso)
+
+ X,Y=torch.meshgrid(x,y,indexing='xy')
+ vit_coords=torch.stack([X,Y],dim=-1)
+ self.vit_coords=vit_coords.view(self.vit_reso**2,2).cuda().float()
+
+ def compute_attn_mask(self,proj_coords,vit_coords,kernel_size=1.0):
+ dist = torch.cdist(proj_coords.float(), vit_coords.float())
+ mask = self.kernel_func(dist, sigma=kernel_size) # True if valid, B,reso**3,vit_reso**2
+ return mask
+
+
+ def forward(self,triplane_feat,image_feat,proj_mat):
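+ # Build one query per voxel by concatenating the three triplane features the voxel projects to,
+ # project the voxel centres into the ViT token grid, mask attention to nearby image tokens
+ # (the extra first token is always left visible), cross-attend, and average the resulting feature
+ # volume back onto the three planes.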
+ xz_feat, xy_feat, yz_feat = torch.split(triplane_feat, split_size_or_sections=triplane_feat.shape[2] // 3, dim=2) # B,C,64,64
+ #xz_feat=self.q_xz(xz_feat)
+ #xy_feat=self.q_xy(xy_feat)
+ #yz_feat=self.q_yz(yz_feat)
+ batch_size=image_feat.shape[0]
+ vox_index=self.vox_index #64*64*64,3
+ xz_vox_feat=xz_feat[:,:,vox_index[:,2],vox_index[:,0]].transpose(1,2) #B,64*64*64,C
+ xy_vox_feat=xy_feat[:,:,vox_index[:,1],vox_index[:,0]].transpose(1,2)
+ yz_vox_feat=yz_feat[:,:,vox_index[:,2],vox_index[:,1]].transpose(1,2)
+ triplane_expand_feat=torch.cat([xz_vox_feat,xy_vox_feat,yz_vox_feat],dim=-1) #B,64*64*64,3C
+ triplane_query=self.q(triplane_expand_feat)+self.triplane_pe
+
+ '''compute projection'''
+ vox_coords=self.vox_coor.unsqueeze(0).expand(batch_size,-1,-1) #
+ vox_homo = torch.cat([vox_coords, torch.ones((batch_size, self.triplane_reso ** 3, 1)).float().cuda()], dim=-1)
+ coord_inimg = torch.einsum('bhc,bck->bhk', vox_homo, proj_mat.transpose(1, 2))
+ x = coord_inimg[:, :, 0] / coord_inimg[:, :, 2]
+ y = coord_inimg[:, :, 1] / coord_inimg[:, :, 2]
+ #
+ x = x / (224.0 - 1.0) * (self.vit_reso-1) # 0~self.vit_reso-1
+ y = y / (224.0 - 1.0) * (self.vit_reso-1) # 0~self.vit_reso-1 #B,N
+ xy=torch.stack([x,y],dim=-1) #B,64*64*64,2
+ xy=torch.clamp(xy,min=0,max=self.vit_reso-1)
+ vit_coords=self.vit_coords.unsqueeze(0).expand(batch_size,-1,-1) #B, 16*16,2
+ attn_mask=torch.repeat_interleave(self.compute_attn_mask(xy,vit_coords,kernel_size=0.5),
+ self.n_heads,0) #B*n_heads, reso**3, vit_reso**2
+
+ k=self.k(image_feat)+self.image_pe
+ v=self.v(image_feat)+self.image_pe
+ attn_mask=torch.cat([torch.zeros([attn_mask.shape[0],attn_mask.shape[1],1]).cuda().bool(),attn_mask],dim=-1) #add empty token to each key and value
+ vox_feat,_=self.attn(triplane_query,k,v,attn_mask=attn_mask) #B,reso**3,C
+ feat_volume=self.out_proj(vox_feat).transpose(1,2).reshape(batch_size,-1,self.triplane_reso,
+ self.triplane_reso,self.triplane_reso)
+ xz_feat=torch.mean(feat_volume,dim=3).transpose(2,3)
+ xy_feat=torch.mean(feat_volume,dim=4).transpose(2,3)
+ yz_feat=torch.mean(feat_volume,dim=2).transpose(2,3)
+ triplane_out=torch.cat([xz_feat,xy_feat,yz_feat],dim=2)
+ return triplane_out
+
+class Multi_Image_Fusion(nn.Module):
+ def __init__(self,reso,image_reso=16,padding=0.1,img_channels=1280,triplane_channel=64,inner_channels=128,output_channel=64,n_heads=8):
+ super().__init__()
+ self.triplane_reso=reso
+ self.image_reso=image_reso
+ self.padding=padding
+ self.get_triplane_coord()
+ self.get_vit_coords()
+ self.img_proj=nn.Conv3d(in_channels=img_channels,out_channels=512,kernel_size=1)
+ self.kernel_func=mask_kernel
+
+ self.q = nn.Linear(in_features=triplane_channel, out_features=inner_channels, bias=False)
+ self.k = nn.Linear(in_features=512, out_features=inner_channels)
+ self.v = nn.Linear(in_features=512, out_features=inner_channels)
+
+ self.attn = torch.nn.MultiheadAttention(
+ embed_dim=inner_channels, num_heads=n_heads, batch_first=True)
+ self.out_proj=nn.Linear(in_features=inner_channels,out_features=output_channel)
+ self.n_heads=n_heads
+
+ def get_triplane_coord(self):
+ '''xz plane first; the y coordinate is fixed at the plane center'''
+ x=torch.arange(self.triplane_reso)
+ z=torch.arange(self.triplane_reso)
+ X,Z=torch.meshgrid(x,z,indexing='xy')
+ xz_coords=torch.cat([X[:,:,None],torch.ones_like(X[:,:,None])*(self.triplane_reso-1)/2,Z[:,:,None]],dim=-1) #in xyz order
+
+ '''xy plane'''
+ x = torch.arange(self.triplane_reso)
+ y = torch.arange(self.triplane_reso)
+ X, Y = torch.meshgrid(x, y, indexing='xy')
+ xy_coords = torch.cat([X[:, :, None], Y[:, :, None],torch.ones_like(X[:, :, None])*(self.triplane_reso-1)/2], dim=-1) # in xyz order
+
+ '''yz plane'''
+ y = torch.arange(self.triplane_reso)
+ z = torch.arange(self.triplane_reso)
+ Y,Z = torch.meshgrid(y,z,indexing='xy')
+ yz_coords= torch.cat([torch.ones_like(Y[:, :, None])*(self.triplane_reso-1)/2,Y[:,:,None],Z[:,:,None]], dim=-1)
+
+ triplane_coords=torch.cat([xz_coords,xy_coords,yz_coords],dim=0)
+ triplane_coords=triplane_coords/(self.triplane_reso-1)
+ triplane_coords=(triplane_coords-0.5)*2*(1 + self.padding + 10e-6)
+ self.triplane_coords=triplane_coords.float().cuda()
+
+ def get_vit_coords(self):
+ x=torch.arange(self.image_reso)
+ y=torch.arange(self.image_reso)
+ X,Y=torch.meshgrid(x,y,indexing='xy')
+ vit_coords=torch.cat([X[:,:,None],Y[:,:,None]],dim=-1)
+ self.vit_coords=vit_coords.float().cuda() #in x,y order
+
+ def compute_attn_mask(self,proj_coord,vit_coords,valid_frames,kernel_size=2.0):
+ '''
+ :param proj_coord: B,K,H,W,2
+ :param vit_coords: H,W,2
+ :return:
+ '''
+ B,K=proj_coord.shape[0:2]
+ vit_coords_expand=vit_coords[None,None,:,:,:].expand(B,K,-1,-1,-1)
+
+ proj_coord=proj_coord.view(B*K,proj_coord.shape[2]*proj_coord.shape[3],proj_coord.shape[4])
+ vit_coords_expand=vit_coords_expand.view(B*K,self.image_reso*self.image_reso,2)
+ attn_mask=self.kernel_func(torch.cdist(proj_coord,vit_coords_expand),sigma=float(kernel_size))
+ attn_mask=attn_mask.reshape(B,K,proj_coord.shape[1],vit_coords_expand.shape[1])
+ valid_expand=valid_frames[:,:,None,None]
+ attn_mask[valid_frames>0,:,:]=True
+ attn_mask=attn_mask.permute(0,2,1,3)
+ attn_mask=attn_mask.reshape(B,proj_coord.shape[1],K*vit_coords_expand.shape[1])
+ atten_index=torch.where(attn_mask[0,0]==False)
+ return attn_mask
+
+
+ def forward(self,triplane_feat,image_feat,proj_mat,valid_frames):
+ '''
+ :param triplane_feat: B,C,3*reso,reso rolled-out triplane features (used as queries)
+ :param image_feat: B,C,K,16,16 multi-view image features
+ :param proj_mat: B,K,4,4 projection matrices
+ :param valid_frames: B,K, true if the frame has an image, used to compute attn_mask for the transformer
+ :return: fused triplane features with the same spatial layout as triplane_feat
+ '''
+ image_feat=self.img_proj(image_feat)
+ batch_size=image_feat.shape[0] #K is number of frames
+ K=image_feat.shape[2]
+ triplane_coords=self.triplane_coords.unsqueeze(0).unsqueeze(1).expand(batch_size,K,-1,-1,-1) #B,K,192,64,3
+ #print(torch.amin(triplane_coords),torch.amax(triplane_coords))
+ coord_homo=torch.cat([triplane_coords,torch.ones((batch_size,K,triplane_coords.shape[2],triplane_coords.shape[3],1)).float().cuda()],dim=-1)
+ #print(coord_homo.shape,proj_mat.shape)
+ coord_inimg=torch.einsum('bjhwc,bjck->bjhwk',coord_homo,proj_mat.transpose(2,3))
+ x=coord_inimg[:,:,:,:,0]/coord_inimg[:,:,:,:,2]
+ y=coord_inimg[:,:,:,:,1]/coord_inimg[:,:,:,:,2]
+ x=x/(224.0-1.0)*(self.image_reso-1)
+ y=y/(224.0-1.0)*(self.image_reso-1)
+
+ xy=torch.cat([x[...,None],y[...,None]],dim=-1) #B,K,H,W,2
+ image_value=image_feat.view(image_feat.shape[0],image_feat.shape[1],-1).transpose(1,2)
+ triplane_query=triplane_feat.view(triplane_feat.shape[0],triplane_feat.shape[1],-1).transpose(1,2)
+ valid_frames=1.0-valid_frames.float()
+ attn_mask=torch.repeat_interleave(self.compute_attn_mask(xy,self.vit_coords,valid_frames),
+ self.n_heads,dim=0)
+
+ q=self.q(triplane_query)
+ k=self.k(image_value)
+ v=self.v(image_value)
+ #print(q.shape,k.shape,v.shape)
+
+ attn,_=self.attn(q,k,v,attn_mask=attn_mask)
+ #print(attn.shape)
+ output=self.out_proj(attn).transpose(1,2).reshape(batch_size,-1,triplane_feat.shape[2],triplane_feat.shape[3])
+ #print(output.shape)
+ return output
+
+
+if __name__=="__main__":
+ # import sys
+ # sys.path.append("../..")
+ # from datasets.SingleView_dataset import Object_PartialPoints_Img
+ # from datasets.transforms import Aug_with_Tran
+ # #sampler=#Image_Vox_Local_Sampler_Pooling(reso=64,padding=0.1,out_channels=64,stride=4).cuda().float()
+ # sampler=Image_ExpandVox_attn_Sampler(reso=32,vit_reso=16,padding=0.1,img_in_channels=1280,triplane_in_channels=64,inner_channel=64
+ # ,out_channels=64,n_heads=8).cuda().float()
+ # # sampler=Image_Direct_AttenwMask_Sampler(reso=32,vit_reso=16,padding=0.1,img_in_channels=1280,triplane_in_channels=128,inner_channel=128
+ # # ,out_channels=64,n_heads=8).cuda().float()
+ # dataset_config = {
+ # "data_path": "/data1/haolin/datasets",
+ # "surface_size": 20000,
+ # "par_pc_size": 4096,
+ # "load_proj_mat": True,
+ # }
+ # transform = Aug_with_Tran()
+ # datasets = Object_PartialPoints_Img(dataset_config['data_path'], split_filename="val_par_img.json", split='val',
+ # transform=transform, sampling=False,
+ # num_samples=1024, return_surface=True, ret_sample=True,
+ # surface_sampling=True, par_pc_size=dataset_config['par_pc_size'],
+ # surface_size=dataset_config['surface_size'],
+ # load_proj_mat=dataset_config['load_proj_mat'], load_image=True,
+ # load_org_img=False, load_triplane=True, replica=1)
+ #
+ # dataloader = torch.utils.data.DataLoader(
+ # datasets=datasets,
+ # batch_size=64,
+ # shuffle=True
+ # )
+ # iterator = dataloader.__iter__()
+ # data_batch = iterator.next()
+ # unflatten = torch.nn.Unflatten(1, (16, 16))
+ # image = data_batch['image'][:,:,:].cuda().float()
+ # #image=unflatten(image).permute(0,3,1,2)
+ # proj_mat = data_batch['proj_mat'].cuda().float()
+ # triplane_feat=torch.randn((64,64,32*3,32)).cuda().float()
+ # sampler(triplane_feat,image,proj_mat)
+ # memory_usage=torch.cuda.max_memory_allocated() / MB
+ # print("memory usage %f mb"%(memory_usage))
+
+
+ import sys
+ sys.path.append("../..")
+ from datasets.SingleView_dataset import Object_PartialPoints_MultiImg
+ from datasets.transforms import Aug_with_Tran
+
+ dataset_config = {
+ "data_path": "/data1/haolin/datasets",
+ "surface_size": 20000,
+ "par_pc_size": 4096,
+ "load_proj_mat": True,
+ }
+ transform = Aug_with_Tran()
+ dataset = Object_PartialPoints_MultiImg(dataset_config['data_path'], split_filename="train_par_img.json", split='train',
+ transform=transform, sampling=False,
+ num_samples=1024, return_surface=True, ret_sample=True,
+ surface_sampling=True, par_pc_size=dataset_config['par_pc_size'],
+ surface_size=dataset_config['surface_size'],
+ load_proj_mat=dataset_config['load_proj_mat'], load_image=True,
+ load_org_img=False, load_triplane=True, replica=1)
+
+ dataloader = torch.utils.data.DataLoader(
+ dataset=dataset,
+ batch_size=10,
+ shuffle=False
+ )
+ data_batch = next(iter(dataloader)) # DataLoader iterators have no .next() method in Python 3
+ #unflatten = torch.nn.Unflatten(2, (16, 16))
+ image = data_batch['image'][:,:,:,:].cuda().float()
+ #image=unflatten(image).permute(0,4,1,2,3)
+ proj_mat = data_batch['proj_mat'].cuda().float()
+ valid_frames = data_batch['valid_frames'].cuda().float()
+ triplane_feat=torch.randn((10,128,32*3,32)).cuda().float()
+
+ # fusion_module=MultiImage_Fuse_Sampler(reso=32,vit_reso=16,padding=0.1,img_in_channels=1280,triplane_in_channels=128,inner_channel=128
+ # ,out_channels=64,n_heads=8).cuda().float()
+ fusion_module=MultiImage_Global_Sampler(reso=32,vit_reso=16,padding=0.1,img_in_channels=1280,triplane_in_channels=128,inner_channel=128
+ ,out_channels=64,n_heads=8).cuda().float()
+ fusion_module(triplane_feat,image,proj_mat,valid_frames)
+ memory_usage=torch.cuda.max_memory_allocated() / MB
+ print("memory usage %f mb"%(memory_usage))
diff --git a/models/modules/parpoints_encoder.py b/models/modules/parpoints_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1648d66f751d5310d0c6ac383bf343c29f2afd45
--- /dev/null
+++ b/models/modules/parpoints_encoder.py
@@ -0,0 +1,168 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_scatter import scatter_mean, scatter_max
+from .unet import UNet
+from .resnet_block import ResnetBlockFC
+from .PointEMB import PointEmbed
+import numpy as np
+
+class ParPoint_Encoder(nn.Module):
+ ''' PointNet-based encoder network with ResNet blocks for each point.
+ The number of input points is fixed.
+
+ Args:
+ c_dim (int): dimension of latent code c
+ dim (int): input points dimension
+ hidden_dim (int): hidden dimension of the network
+ scatter_type (str): feature aggregation when doing local pooling
+ unet (bool): whether to use a U-Net
+ unet_kwargs (str): U-Net parameters
+ plane_resolution (int): defined resolution for plane feature
+ plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
+ padding (float): conventional padding parameter of ONet for the unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
+ n_blocks (int): number of ResnetBlockFC blocks
+ '''
+
+ def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max', unet_kwargs=None,
+ plane_resolution=None, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5):
+ super().__init__()
+ self.c_dim = c_dim
+
+ self.fc_pos = nn.Linear(dim, 2 * hidden_dim)
+ self.blocks = nn.ModuleList([
+ ResnetBlockFC(2 * hidden_dim, hidden_dim) for i in range(n_blocks)
+ ])
+ self.fc_c = nn.Linear(hidden_dim, c_dim)
+
+ self.actvn = nn.ReLU()
+ self.hidden_dim = hidden_dim
+
+ self.unet = UNet(unet_kwargs['output_dim'], in_channels=c_dim, **unet_kwargs)
+
+ self.reso_plane = plane_resolution
+ self.plane_type = plane_type
+ self.padding = padding
+
+ if scatter_type == 'max':
+ self.scatter = scatter_max
+ elif scatter_type == 'mean':
+ self.scatter = scatter_mean
+
+ # takes in "p": point cloud and "query": sdf_xyz
+ # sample plane features for unlabeled_query as well
+ def forward(self, p,point_emb): # , query2):
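+ # Per-point features are refined by ResNet blocks interleaved with local pooling onto the three
+ # planes (scatter by plane index, gather back to the points), then scattered once more into plane
+ # feature maps and processed by a 2D U-Net over the stacked planes.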
+ batch_size, T, D = p.size()
+ #print('origin',torch.amin(p[0],dim=0),torch.amax(p[0],dim=0))
+ # acquire the index for each point
+ coord = {}
+ index = {}
+ if 'xz' in self.plane_type:
+ coord['xz'] = self.normalize_coordinate(p.clone(), plane='xz', padding=self.padding)
+ index['xz'] = self.coordinate2index(coord['xz'], self.reso_plane)
+ if 'xy' in self.plane_type:
+ coord['xy'] = self.normalize_coordinate(p.clone(), plane='xy', padding=self.padding)
+ index['xy'] = self.coordinate2index(coord['xy'], self.reso_plane)
+ if 'yz' in self.plane_type:
+ coord['yz'] = self.normalize_coordinate(p.clone(), plane='yz', padding=self.padding)
+ index['yz'] = self.coordinate2index(coord['yz'], self.reso_plane)
+ net = self.fc_pos(point_emb)
+
+ net = self.blocks[0](net)
+ for block in self.blocks[1:]:
+ pooled = self.pool_local(coord, index, net)
+ net = torch.cat([net, pooled], dim=2)
+ net = block(net)
+
+ c = self.fc_c(net)
+ #print(c.shape)
+
+ fea = {}
+ # second_sum = 0
+ if 'xz' in self.plane_type:
+ fea['xz'] = self.generate_plane_features(p, c,
+ plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
+ if 'xy' in self.plane_type:
+ fea['xy'] = self.generate_plane_features(p, c, plane='xy')
+ if 'yz' in self.plane_type:
+ fea['yz'] = self.generate_plane_features(p, c, plane='yz')
+ cat_feature = torch.cat([fea['xz'], fea['xy'], fea['yz']],
+ dim=2) # concat at row dimension
+ #print(cat_feature.shape)
+ plane_feat=self.unet(cat_feature)
+
+ return plane_feat
+
+
+ def normalize_coordinate(self, p, padding=0.1, plane='xz'):
+ ''' Normalize coordinate to [0, 1] for unit cube experiments
+
+ Args:
+ p (tensor): point
+ padding (float): conventional padding parameter of ONet for the unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
+ plane (str): plane feature type, ['xz', 'xy', 'yz']
+ '''
+ if plane == 'xz':
+ xy = p[:, :, [0, 2]]
+ elif plane == 'xy':
+ xy = p[:, :, [0, 1]]
+ else:
+ xy = p[:, :, [1, 2]]
+ #print("origin",torch.amin(xy), torch.amax(xy))
+ xy=xy/2 #xy is originally -1 ~ 1
+ xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
+ xy_new = xy_new + 0.5 # range (0, 1)
+ #print("scale",torch.amin(xy_new),torch.amax(xy_new))
+
+ # if there are outliers out of the range
+ if xy_new.max() >= 1:
+ xy_new[xy_new >= 1] = 1 - 10e-6
+ if xy_new.min() < 0:
+ xy_new[xy_new < 0] = 0.0
+ return xy_new
+
+ def coordinate2index(self, x, reso):
+ ''' Convert normalized plane coordinates in [0, 1] into flattened indices on the plane grid.
+
+ Args:
+ x (tensor): coordinate
+ reso (int): defined resolution
+ coord_type (str): coordinate type
+ '''
+ x = (x * reso).long()
+ index = x[:, :, 0] + reso * x[:, :, 1]
+ index = index[:, None, :]
+ return index
+
+ # xy is the normalized coordinates of the point cloud of each plane
+ # I'm pretty sure the keys of xy are the same as those of index, so xy isn't needed here as input
+ def pool_local(self, xy, index, c):
+ bs, fea_dim = c.size(0), c.size(2)
+ keys = xy.keys()
+
+ c_out = 0
+ for key in keys:
+ # scatter plane features from points
+ fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane ** 2)
+ if self.scatter == scatter_max:
+ fea = fea[0]
+ # gather feature back to points
+ fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
+ c_out += fea
+ return c_out.permute(0, 2, 1)
+
+ def generate_plane_features(self, p, c, plane='xz'):
+ # acquire indices of features in plane
+ xy = self.normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)
+ index = self.coordinate2index(xy, self.reso_plane)
+
+ # scatter plane features from points
+ fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane ** 2)
+ c = c.permute(0, 2, 1) # B x 512 x T
+ fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
+ fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane,
+ self.reso_plane) # sparse matrix (B x 512 x reso x reso)
+ #print(fea_plane.shape)
+
+ return fea_plane
\ No newline at end of file
diff --git a/models/modules/point_transformer.py b/models/modules/point_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9263a16bd1b19db9cc229355a57a44d931ca170
--- /dev/null
+++ b/models/modules/point_transformer.py
@@ -0,0 +1,442 @@
+from torch import nn, einsum
+import torch
+import torch.nn.functional as F
+from einops import rearrange,repeat
+from timm.models.layers import DropPath
+from torch_cluster import fps
+import numpy as np
+from functools import wraps
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+class PositionalEmbedding(torch.nn.Module):
+ def __init__(self, num_channels, max_positions=10000, endpoint=False):
+ super().__init__()
+ self.num_channels = num_channels
+ self.max_positions = max_positions
+ self.endpoint = endpoint
+
+ def forward(self, x):
+ freqs = torch.arange(start=0, end=self.num_channels//2, dtype=torch.float32, device=x.device)
+ freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0))
+ freqs = (1 / self.max_positions) ** freqs
+ x = x.ger(freqs.to(x.dtype))
+ x = torch.cat([x.cos(), x.sin()], dim=1)
+ return x
+
+class CrossAttention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+
+ if context_dim is None:
+ context_dim = query_dim
+
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+
+ if context is None:
+ context = x
+
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(
+ t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = torch.einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+
+class LayerScale(nn.Module):
+ def __init__(self, dim, init_values=1e-5, inplace=False):
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x):
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ if dim_out is None:
+ dim_out = dim
+
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+class AdaLayerNorm(nn.Module):
+ def __init__(self, n_embd):
+ super().__init__()
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(n_embd, n_embd*2)
+ self.layernorm = nn.LayerNorm(n_embd, elementwise_affine=False)
+
+ def forward(self, x, timestep):
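+ # Adaptive LayerNorm: the timestep embedding is mapped to a per-channel scale and shift that
+ # modulate the normalized activations (the LayerNorm's own affine parameters are disabled).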
+ emb = self.linear(timestep)
+ scale, shift = torch.chunk(emb, 2, dim=2)
+ x = self.layernorm(x) * (1 + scale) + shift
+ return x
+
+class BasicTransformerBlock(nn.Module):
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
+ super().__init__()
+ self.attn1 = CrossAttention(
+ query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
+ self.norm1 = AdaLayerNorm(dim)
+ self.norm2 = AdaLayerNorm(dim)
+ self.norm3 = AdaLayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ init_values = 0
+ drop_path = 0.0
+
+
+ self.ls1 = LayerScale(
+ dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(
+ drop_path) if drop_path > 0. else nn.Identity()
+
+ self.ls2 = LayerScale(
+ dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(
+ drop_path) if drop_path > 0. else nn.Identity()
+
+ self.ls3 = LayerScale(
+ dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path3 = DropPath(
+ drop_path) if drop_path > 0. else nn.Identity()
+
+ def forward(self, x, t, context=None):
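+ # Self-attention, cross-attention to the optional context, then a feed-forward block; each is
+ # preceded by an AdaLayerNorm conditioned on t and wrapped in a residual with LayerScale/DropPath
+ # (both are identity with the defaults above).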
+ x = self.drop_path1(self.ls1(self.attn1(self.norm1(x, t)))) + x
+ x = self.drop_path2(self.ls2(self.attn2(self.norm2(x, t), context=context))) + x
+ x = self.drop_path3(self.ls3(self.ff(self.norm3(x, t)))) + x
+ return x
+
+class LatentArrayTransformer(nn.Module):
+ """
+ Transformer over a latent token array (b, n, c).
+ The input tokens are projected to the inner dimension, processed by a stack of
+ AdaLayerNorm transformer blocks conditioned on the timestep and class embedding
+ (with optional cross-attention to cond), and projected back to the output channels.
+ """
+
+ def __init__(self, in_channels, t_channels, n_heads, d_head,
+ depth=1, dropout=0., context_dim=None, out_channels=None, context_dim2=None,
+ block=BasicTransformerBlock):
+ super().__init__()
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+
+ self.t_channels = t_channels
+
+ self.proj_in = nn.Linear(in_channels, inner_dim, bias=False)
+
+ self.transformer_blocks = nn.ModuleList(
+ [block(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
+ for _ in range(depth)]
+ )
+
+ self.norm = nn.LayerNorm(inner_dim)
+
+ if out_channels is None:
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels, bias=False))
+ else:
+ self.num_cls = out_channels
+ self.proj_out = zero_module(nn.Linear(inner_dim, out_channels, bias=False))
+
+ self.context_dim = context_dim
+
+ self.map_noise = PositionalEmbedding(t_channels)
+
+ self.map_layer0 = nn.Linear(in_features=t_channels, out_features=inner_dim)
+ self.map_layer1 = nn.Linear(in_features=inner_dim, out_features=inner_dim)
+
+ # ###
+ # self.pos_emb = nn.Embedding(512, inner_dim)
+ # ###
+
+ def forward(self, x, t, cond, class_emb):
+
+ t_emb = self.map_noise(t)[:, None]
+ t_emb = F.silu(self.map_layer0(t_emb))
+ t_emb = F.silu(self.map_layer1(t_emb))
+
+ x = self.proj_in(x)
+ #print(class_emb.shape,t_emb.shape)
+ for block in self.transformer_blocks:
+ x = block(x, t_emb+class_emb[:,None,:], context=cond)
+
+ x = self.norm(x)
+
+ x = self.proj_out(x)
+ return x
+
+class PointTransformer(nn.Module):
+ """
+ Transformer over a latent token array (b, n, c).
+ The input tokens are projected to the inner dimension, processed by a stack of
+ AdaLayerNorm transformer blocks conditioned on the timestep (with optional
+ cross-attention to cond), and projected back to the output channels.
+ """
+
+ def __init__(self, in_channels, t_channels, n_heads, d_head,
+ depth=1, dropout=0., context_dim=None, out_channels=None, context_dim2=None,
+ block=BasicTransformerBlock):
+ super().__init__()
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+
+ self.t_channels = t_channels
+
+ self.proj_in = nn.Linear(in_channels, inner_dim, bias=False)
+
+ self.transformer_blocks = nn.ModuleList(
+ [block(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
+ for _ in range(depth)]
+ )
+
+ self.norm = nn.LayerNorm(inner_dim)
+
+ if out_channels is None:
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels, bias=False))
+ else:
+ self.num_cls = out_channels
+ self.proj_out = zero_module(nn.Linear(inner_dim, out_channels, bias=False))
+
+ self.context_dim = context_dim
+
+ self.map_noise = PositionalEmbedding(t_channels)
+
+ self.map_layer0 = nn.Linear(in_features=t_channels, out_features=inner_dim)
+ self.map_layer1 = nn.Linear(in_features=inner_dim, out_features=inner_dim)
+
+ # ###
+ # self.pos_emb = nn.Embedding(512, inner_dim)
+ # ###
+
+ def forward(self, x, t, cond=None):
+
+ t_emb = self.map_noise(t)[:, None]
+ t_emb = F.silu(self.map_layer0(t_emb))
+ t_emb = F.silu(self.map_layer1(t_emb))
+
+ x = self.proj_in(x)
+
+ for block in self.transformer_blocks:
+ x = block(x, t_emb, context=cond)
+
+ x = self.norm(x)
+
+ x = self.proj_out(x)
+ return x
+def exists(val):
+ return val is not None
+
+def default(val, d):
+ return val if exists(val) else d
+
+def cache_fn(f):
+ cache = None
+ @wraps(f)
+ def cached_fn(*args, _cache = True, **kwargs):
+ if not _cache:
+ return f(*args, **kwargs)
+ nonlocal cache
+ if cache is not None:
+ return cache
+ cache = f(*args, **kwargs)
+ return cache
+ return cached_fn
+
+class PreNorm(nn.Module):
+ def __init__(self, dim, fn, context_dim = None):
+ super().__init__()
+ self.fn = fn
+ self.norm = nn.LayerNorm(dim)
+ self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None
+
+ def forward(self, x, **kwargs):
+ x = self.norm(x)
+
+ if exists(self.norm_context):
+ context = kwargs['context']
+ normed_context = self.norm_context(context)
+ kwargs.update(context = normed_context)
+
+ return self.fn(x, **kwargs)
+
+class Attention(nn.Module):
+ def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64, drop_path_rate = 0.0):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
+ self.to_out = nn.Linear(inner_dim, query_dim)
+
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+ def forward(self, x, context = None, mask = None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k, v = self.to_kv(context).chunk(2, dim = -1)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
+
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h = h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim = -1)
+
+ out = einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
+ return self.drop_path(self.to_out(out))
+
+
+class PointEmbed(nn.Module):
+ def __init__(self, hidden_dim=48, dim=128):
+ super().__init__()
+
+ assert hidden_dim % 6 == 0
+
+ self.embedding_dim = hidden_dim
+ e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi
+ e = torch.stack([
+ torch.cat([e, torch.zeros(self.embedding_dim // 6),
+ torch.zeros(self.embedding_dim // 6)]),
+ torch.cat([torch.zeros(self.embedding_dim // 6), e,
+ torch.zeros(self.embedding_dim // 6)]),
+ torch.cat([torch.zeros(self.embedding_dim // 6),
+ torch.zeros(self.embedding_dim // 6), e]),
+ ])
+ self.register_buffer('basis', e) # 3 x (hidden_dim // 2)
+
+ self.mlp = nn.Linear(self.embedding_dim + 3, dim)
+
+ @staticmethod
+ def embed(input, basis):
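+ # input: B x N x 3 points; basis: 3 x (hidden_dim // 2) fixed frequency matrix.
+ # Returns B x N x hidden_dim Fourier features (sin and cos of the projections).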
+ projections = torch.einsum(
+ 'bnd,de->bne', input, basis)
+ embeddings = torch.cat([projections.sin(), projections.cos()], dim=2)
+ return embeddings
+
+ def forward(self, input):
+ # input: B x N x 3
+ embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2)) # B x N x C
+ return embed
+
+
+class PointEncoder(nn.Module):
+ def __init__(self,
+ dim=512,
+ num_inputs = 2048,
+ num_latents = 512,
+ latent_dim = 512):
+ super().__init__()
+
+ self.num_inputs = num_inputs
+ self.num_latents = num_latents
+
+ self.cross_attend_blocks = nn.ModuleList([
+ PreNorm(dim, Attention(dim, dim, heads=1, dim_head=dim), context_dim=dim),
+ PreNorm(dim, FeedForward(dim))
+ ])
+
+ self.point_embed = PointEmbed(dim=dim)
+ self.proj=nn.Linear(dim,latent_dim)
+ def encode(self, pc):
+ # pc: B x N x 3
+ B, N, D = pc.shape
+ assert N == self.num_inputs
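+ # Farthest point sampling picks num_latents anchor points; each anchor then cross-attends to the
+ # embedding of the full cloud to pool the geometry into a compact latent set.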
+
+ ###### fps
+ flattened = pc.view(B * N, D)
+
+ batch = torch.arange(B).to(pc.device)
+ batch = torch.repeat_interleave(batch, N)
+
+ pos = flattened
+
+ ratio = 1.0 * self.num_latents / self.num_inputs
+
+ idx = fps(pos, batch, ratio=ratio)
+
+ sampled_pc = pos[idx]
+ sampled_pc = sampled_pc.view(B, -1, 3)
+ ######
+
+ sampled_pc_embeddings = self.point_embed(sampled_pc)
+
+ pc_embeddings = self.point_embed(pc)
+
+ cross_attn, cross_ff = self.cross_attend_blocks
+
+ x = cross_attn(sampled_pc_embeddings, context=pc_embeddings, mask=None) + sampled_pc_embeddings
+ x = cross_ff(x) + x
+
+ return self.proj(x)
\ No newline at end of file
diff --git a/models/modules/pointnet2_backbone.py b/models/modules/pointnet2_backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..9787e9a965c914b0bb4e06db171893f2e4665d5a
--- /dev/null
+++ b/models/modules/pointnet2_backbone.py
@@ -0,0 +1,188 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import sys
+import os
+from external.pointnet2.pointnet2_modules import PointnetSAModuleVotes, PointnetFPModule
+from .utils import zero_module
+from .Positional_Embedding import PositionalEmbedding
+
+class Pointnet2Encoder(nn.Module):
+ def __init__(self,input_feature_dim=0,npoints=[2048,1024,512,256],radius=[0.2,0.4,0.6,1.2],nsample=[64,32,16,8]):
+ super().__init__()
+ self.sa1 = PointnetSAModuleVotes(
+ npoint=npoints[0],
+ radius=radius[0],
+ nsample=nsample[0],
+ mlp=[input_feature_dim, 64, 64, 128],
+ use_xyz=True,
+ normalize_xyz=True
+ )
+
+ self.sa2 = PointnetSAModuleVotes(
+ npoint=npoints[1],
+ radius=radius[1],
+ nsample=nsample[1],
+ mlp=[128, 128, 128, 256],
+ use_xyz=True,
+ normalize_xyz=True
+ )
+
+ self.sa3 = PointnetSAModuleVotes(
+ npoint=npoints[2],
+ radius=radius[2],
+ nsample=nsample[2],
+ mlp=[256, 256, 256, 512],
+ use_xyz=True,
+ normalize_xyz=True
+ )
+
+ self.sa4 = PointnetSAModuleVotes(
+ npoint=npoints[3],
+ radius=radius[3],
+ nsample=nsample[3],
+ mlp=[512, 512, 512, 512],
+ use_xyz=True,
+ normalize_xyz=True
+ )
+ def _break_up_pc(self, pc):
+ xyz = pc[..., 0:3].contiguous()
+ features = (
+ pc[..., 3:].transpose(1, 2).contiguous()
+ if pc.size(-1) > 3 else None
+ )
+
+ return xyz, features
+ def forward(self,pointcloud,end_points=None):
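+ # Four PointNet++ set-abstraction stages progressively downsample the cloud (2048 -> 1024 -> 512 -> 256
+ # points with the default settings) while widening the per-point features; every intermediate
+ # xyz/feature pair is stored in end_points for later feature propagation.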
+ if not end_points: end_points = {}
+ batch_size = pointcloud.shape[0]
+
+ xyz, features = self._break_up_pc(pointcloud)
+
+ end_points['org_xyz'] = xyz
+ # --------- 4 SET ABSTRACTION LAYERS ---------
+ xyz1, features1, _ = self.sa1(xyz, features)
+ end_points['sa1_xyz'] = xyz1
+ end_points['sa1_features'] = features1
+
+ xyz2, features2, _ = self.sa2(xyz1, features1) # this fps_inds is just 0,1,...,1023
+ end_points['sa2_xyz'] = xyz2
+ end_points['sa2_features'] = features2
+
+ xyz3, features3, _ = self.sa3(xyz2, features2) # this fps_inds is just 0,1,...,511
+ end_points['sa3_xyz'] = xyz3
+ end_points['sa3_features'] = features3
+ #print(xyz3.shape,features3.shape)
+ xyz4, features4, _ = self.sa4(xyz3, features3) # this fps_inds is just 0,1,...,255
+ end_points['sa4_xyz'] = xyz4
+ end_points['sa4_features'] = features4
+ #print(xyz4.shape,features4.shape)
+ return end_points
+
+
+
+class PointUNet(nn.Module):
+ r"""
+ Backbone network for point cloud feature learning.
+ Based on Pointnet++ single-scale grouping network.
+
+ Parameters
+ ----------
+ input_feature_dim: int
+ Number of input channels in the feature descriptor for each point.
+ e.g. 3 for RGB.
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ self.noisy_encoder=Pointnet2Encoder()
+ self.cond_encoder=Pointnet2Encoder()
+ self.fp1_cross = PointnetFPModule(mlp=[512 + 512, 512, 512])
+ self.fp1 = PointnetFPModule(mlp=[512 + 512, 512, 512])
+ #self.fp1 = PointnetFPModule(mlp=[512 + 512, 512, 512])
+ self.fp2_cross = PointnetFPModule(mlp=[512 + 512, 512, 256])
+ self.fp2 = PointnetFPModule(mlp=[256 + 256, 512, 256])
+ #self.fp2=PointnetFPModule(mlp=[512 + 256, 512, 256])
+ self.fp3_cross= PointnetFPModule(mlp=[256 + 256, 256, 128])
+ self.fp3 = PointnetFPModule(mlp=[128 + 128, 256, 128])
+ #self.fp3 = PointnetFPModule(mlp=[256 + 128, 256, 128])
+ self.fp4_cross=PointnetFPModule(mlp=[128+128, 128, 128])
+ self.fp4 = PointnetFPModule(mlp=[128, 128, 128])
+ #self.fp4 = PointnetFPModule(mlp=[128, 128, 128])
+
+ self.output_layer=nn.Sequential(
+ nn.LayerNorm(128),
+ zero_module(nn.Linear(in_features=128,out_features=3,bias=False))
+ )
+ self.t_emb_layer = PositionalEmbedding(256)
+ self.map_layer0 = nn.Linear(in_features=256, out_features=512)
+ self.map_layer1 = nn.Linear(in_features=512, out_features=512)
+
+ def forward(self, noise_points, t,cond_points):
+ r"""
+ Forward pass of the network
+
+ Parameters
+ ----------
+ noise_points: (B, N, 3) float tensor of noisy points at diffusion step t
+ t: (B,) float tensor of diffusion timesteps
+ cond_points: (B, N, 3) float tensor of conditioning (partial) points
+
+ Returns
+ ----------
+ output_points: (B, N, 3) float tensor, the per-point prediction of the denoising network
+ """
+ t_emb = self.t_emb_layer(t)
+ t_emb = F.silu(self.map_layer0(t_emb))
+ t_emb = F.silu(self.map_layer1(t_emb))#B,512
+ t_emb = t_emb[:, :, None] #B,512,K
+ noise_end_points=self.noisy_encoder(noise_points)
+ cond=self.cond_encoder(cond_points)
+ # --------- FEATURE UPSAMPLING: alternate cross-set FP (condition -> noisy) and in-set FP layers ---------
+ features = self.fp1_cross(noise_end_points['sa4_xyz'],cond['sa4_xyz'],noise_end_points['sa4_features']+t_emb,
+ cond['sa4_features'])
+ features = self.fp1(noise_end_points['sa3_xyz'], noise_end_points['sa4_xyz'], noise_end_points['sa3_features'],
+ features)
+ features = self.fp2_cross(noise_end_points['sa3_xyz'],cond['sa3_xyz'],features,
+ cond["sa3_features"])
+ features = self.fp2(noise_end_points['sa2_xyz'], noise_end_points['sa3_xyz'], noise_end_points['sa2_features'],
+ features)
+ features = self.fp3_cross(noise_end_points['sa2_xyz'],cond['sa2_xyz'],features,
+ cond['sa2_features'])
+ features = self.fp3(noise_end_points['sa1_xyz'],noise_end_points['sa2_xyz'],noise_end_points['sa1_features'],features)
+ features = self.fp4_cross(noise_end_points['sa1_xyz'],cond['sa1_xyz'],features,
+ cond['sa1_features'])
+ features = self.fp4(noise_end_points['org_xyz'], noise_end_points['sa1_xyz'], None, features)
+ features=features.transpose(1,2)
+
+ # features = self.fp1_cross(noise_end_points['sa4_xyz'], cond_end_points['sa4_xyz'],
+ # noise_end_points['sa4_features']+t_emb, cond_end_points['sa4_features'])
+ # features = self.fp1(noise_end_points['sa3_xyz'].clone(), noise_end_points['sa4_xyz'].clone(), noise_end_points['sa3_features'],
+ # features)
+ # features = self.fp2(noise_end_points['sa2_xyz'], noise_end_points['sa3_xyz'], noise_end_points['sa2_features'],
+ # features)
+ # features = self.fp3(noise_end_points['sa1_xyz'],noise_end_points['sa2_xyz'],noise_end_points['sa1_features'],features)
+ # features = self.fp4(noise_end_points['org_xyz'], noise_end_points['sa1_xyz'], None, features)
+ # features = features.transpose(1,2)
+ output_points=self.output_layer(features)
+
+ return output_points
+
+
+if __name__ == '__main__':
+ net=PointUNet().cuda().float()
+ net=net.eval()
+ noise_points=torch.randn(16,4096,3).cuda().float()
+ cond_points=torch.randn(16,4096,3).cuda().float()
+ t=torch.randn(16).cuda().float()
+ cond_encoder=Pointnet2Encoder().cuda().float()
+
+ out = net(noise_points,t,cond_points)
+ print(out.shape)
\ No newline at end of file
diff --git a/models/modules/resnet_block.py b/models/modules/resnet_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..b715ff52626ad8e878593082d6c4926295261c68
--- /dev/null
+++ b/models/modules/resnet_block.py
@@ -0,0 +1,47 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# Resnet Blocks
+class ResnetBlockFC(nn.Module):
+ ''' Fully connected ResNet Block class.
+ Args:
+ size_in (int): input dimension
+ size_out (int): output dimension
+ size_h (int): hidden dimension
+ '''
+
+ def __init__(self, size_in, size_out=None, size_h=None):
+ super().__init__()
+ # Attributes
+ if size_out is None:
+ size_out = size_in
+
+ if size_h is None:
+ size_h = min(size_in, size_out)
+
+ self.size_in = size_in
+ self.size_h = size_h
+ self.size_out = size_out
+ # Submodules
+ self.fc_0 = nn.Linear(size_in, size_h)
+ self.fc_1 = nn.Linear(size_h, size_out)
+ self.actvn = nn.ReLU()
+
+ if size_in == size_out:
+ self.shortcut = None
+ else:
+ self.shortcut = nn.Linear(size_in, size_out, bias=False)
+ # Initialization
+ nn.init.zeros_(self.fc_1.weight)
+
+ def forward(self, x):
+ net = self.fc_0(self.actvn(x))
+ dx = self.fc_1(self.actvn(net))
+
+ if self.shortcut is not None:
+ x_s = self.shortcut(x)
+ else:
+ x_s = x
+
+ return x_s + dx
\ No newline at end of file
diff --git a/models/modules/resunet.py b/models/modules/resunet.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd871dd43eb3adb33cbd34783f0c275d014a9eac
--- /dev/null
+++ b/models/modules/resunet.py
@@ -0,0 +1,440 @@
+import torch
+import torch.nn as nn
+from .unet import RollOut_Conv
+from .Positional_Embedding import PositionalEmbedding
+import torch.nn.functional as F
+from .utils import zero_module
+from .image_sampler import MultiImage_Fuse_Sampler, MultiImage_Global_Sampler,MultiImage_TriFuse_Sampler
+
+class ResidualConv_MultiImgAtten(nn.Module):
+ def __init__(self, input_dim, output_dim, stride, padding, reso=64,
+ vit_reso=16,t_input_dim=256,img_in_channels=1280,use_attn=True,triplane_padding=0.1,
+ norm="batch"):
+ super(ResidualConv_MultiImgAtten, self).__init__()
+ self.use_attn=use_attn
+
+ if norm=="batch":
+ norm_layer=nn.BatchNorm2d
+ elif norm is None:
+ norm_layer=nn.Identity
+
+ self.conv_block = nn.Sequential(
+ norm_layer(input_dim),
+ nn.ReLU(),
+ nn.Conv2d(
+ input_dim, output_dim, kernel_size=3, padding=padding
+ )
+ )
+ self.out_layer=nn.Sequential(
+ norm_layer(output_dim),
+ nn.ReLU(),
+ nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
+ )
+ self.conv_skip = nn.Sequential(
+ nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1),
+ norm_layer(output_dim),
+ )
+ self.roll_out_conv=nn.Sequential(
+ norm_layer(output_dim),
+ nn.ReLU(),
+ RollOut_Conv(output_dim, output_dim),
+ )
+ if self.use_attn:
+ self.img_sampler = MultiImage_Fuse_Sampler(inner_channel=output_dim, triplane_in_channels=output_dim,
+ img_in_channels=img_in_channels,reso=reso,vit_reso=vit_reso,
+ out_channels=output_dim,padding=triplane_padding)
+ self.down_conv=nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=stride, padding=padding)
+
+ self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim)
+ self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim)
+ def forward(self, x,t_emb,img_feat,proj_mat,valid_frames):
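+ # Residual conv block with timestep conditioning: t_emb is mapped to a per-channel bias added after
+ # the first conv, a RollOut_Conv exchanges information across the three planes, an optional
+ # multi-view image cross-attention sampler injects image features (with a skip connection), and a
+ # final convolution (stride given by `stride`) produces the output.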
+ t_emb = F.silu(self.map_layer0(t_emb))
+ t_emb = F.silu(self.map_layer1(t_emb))
+ t_emb = t_emb[:,:,None,None]
+
+ out=self.conv_block(x)+t_emb
+ out=self.out_layer(out)
+ feature=out+self.conv_skip(x)
+ feature = self.roll_out_conv(feature)
+ if self.use_attn:
+ feature=self.img_sampler(feature,img_feat,proj_mat,valid_frames)+feature #skip connect
+ feature=self.down_conv(feature)
+
+ return feature
+
+class ResidualConv_TriMultiImgAtten(nn.Module):
+ def __init__(self, input_dim, output_dim, stride, padding, reso=64,
+ vit_reso=16,t_input_dim=256,img_in_channels=1280,use_attn=True,triplane_padding=0.1,
+ norm="batch"):
+ super(ResidualConv_TriMultiImgAtten, self).__init__()
+ self.use_attn=use_attn
+
+ if norm=="batch":
+ norm_layer=nn.BatchNorm2d
+ elif norm is None:
+ norm_layer=nn.Identity
+
+ self.conv_block = nn.Sequential(
+ norm_layer(input_dim),
+ nn.ReLU(),
+ nn.Conv2d(
+ input_dim, output_dim, kernel_size=3, padding=padding
+ )
+ )
+ self.out_layer=nn.Sequential(
+ norm_layer(output_dim),
+ nn.ReLU(),
+ nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
+ )
+ self.conv_skip = nn.Sequential(
+ nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1),
+ norm_layer(output_dim),
+ )
+ self.roll_out_conv=nn.Sequential(
+ norm_layer(output_dim),
+ nn.ReLU(),
+ RollOut_Conv(output_dim, output_dim),
+ )
+ if self.use_attn:
+ self.img_sampler = MultiImage_TriFuse_Sampler(inner_channel=output_dim, triplane_in_channels=output_dim,
+ img_in_channels=img_in_channels,reso=reso,vit_reso=vit_reso,
+ out_channels=output_dim,max_nimg=5,padding=triplane_padding)
+ self.down_conv=nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=stride, padding=padding)
+
+ self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim)
+ self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim)
+ def forward(self, x,t_emb,img_feat,proj_mat,valid_frames):
+ t_emb = F.silu(self.map_layer0(t_emb))
+ t_emb = F.silu(self.map_layer1(t_emb))
+ t_emb = t_emb[:,:,None,None]
+
+ out=self.conv_block(x)+t_emb
+ out=self.out_layer(out)
+ feature=out+self.conv_skip(x)
+ feature = self.roll_out_conv(feature)
+ if self.use_attn:
+ feature=self.img_sampler(feature,img_feat,proj_mat,valid_frames)+feature #skip connect
+ feature=self.down_conv(feature)
+
+ return feature
+
+
+class ResidualConv_GlobalAtten(nn.Module):
+ def __init__(self, input_dim, output_dim, stride, padding, reso=64,
+ vit_reso=16,t_input_dim=256,img_in_channels=1280,use_attn=True,triplane_padding=0.1,
+ norm="batch"):
+ super(ResidualConv_GlobalAtten, self).__init__()
+ self.use_attn=use_attn
+
+ if norm=="batch":
+ norm_layer=nn.BatchNorm2d
+ elif norm is None:
+ norm_layer=nn.Identity
+
+ self.conv_block = nn.Sequential(
+ norm_layer(input_dim),
+ nn.ReLU(),
+ nn.Conv2d(
+ input_dim, output_dim, kernel_size=3, padding=padding
+ )
+ )
+ self.out_layer=nn.Sequential(
+ norm_layer(output_dim),
+ nn.ReLU(),
+ nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
+ )
+ self.conv_skip = nn.Sequential(
+ nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1),
+ norm_layer(output_dim),
+ )
+ self.roll_out_conv=nn.Sequential(
+ norm_layer(output_dim),
+ nn.ReLU(),
+ RollOut_Conv(output_dim, output_dim),
+ )
+ if self.use_attn:
+ self.img_sampler = MultiImage_Global_Sampler(inner_channel=output_dim, triplane_in_channels=output_dim,
+ img_in_channels=img_in_channels,reso=reso,vit_reso=vit_reso,
+ out_channels=output_dim,max_nimg=5,padding=triplane_padding)
+ self.down_conv=nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=stride, padding=padding)
+
+ self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim)
+ self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim)
+ def forward(self, x,t_emb,img_feat,proj_mat,valid_frames):
+ t_emb = F.silu(self.map_layer0(t_emb))
+ t_emb = F.silu(self.map_layer1(t_emb))
+ t_emb = t_emb[:,:,None,None]
+
+ out=self.conv_block(x)+t_emb
+ out=self.out_layer(out)
+ feature=out+self.conv_skip(x)
+ feature = self.roll_out_conv(feature)
+ if self.use_attn:
+ feature=self.img_sampler(feature,img_feat,proj_mat,valid_frames)+feature #skip connect
+ feature=self.down_conv(feature)
+
+ return feature
+
+class ResidualConv(nn.Module):
+ def __init__(self, input_dim, output_dim, stride, padding, t_input_dim=256):
+ super(ResidualConv, self).__init__()
+
+ self.conv_block = nn.Sequential(
+ nn.BatchNorm2d(input_dim),
+ nn.ReLU(),
+ nn.Conv2d(
+ input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
+ ),
+ nn.BatchNorm2d(output_dim),
+ nn.ReLU(),
+ RollOut_Conv(output_dim,output_dim),
+ )
+ self.out_layer=nn.Sequential(
+ nn.BatchNorm2d(output_dim),
+ nn.ReLU(),
+ nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
+ )
+ self.conv_skip = nn.Sequential(
+ nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
+ nn.BatchNorm2d(output_dim),
+ )
+
+ self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim)
+ self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim)
+ def forward(self, x,t_emb):
+ t_emb = F.silu(self.map_layer0(t_emb))
+ t_emb = F.silu(self.map_layer1(t_emb))
+ t_emb = t_emb[:,:,None,None]
+
+ out=self.conv_block(x)+t_emb
+ out=self.out_layer(out)
+
+ return out + self.conv_skip(x)
+
+class Upsample(nn.Module):
+ def __init__(self, input_dim, output_dim, kernel, stride):
+ super(Upsample, self).__init__()
+
+ self.upsample = nn.ConvTranspose2d(
+ input_dim, output_dim, kernel_size=kernel, stride=stride
+ )
+
+ def forward(self, x):
+ return self.upsample(x)
+
+
+
+class ResUnet_Par_cond(nn.Module):
+ def __init__(self, channel, filters=[64, 128, 256, 512, 1024],output_channel=32,par_channel=32):
+ super(ResUnet_Par_cond, self).__init__()
+
+ self.input_layer = nn.Sequential(
+ nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
+ nn.BatchNorm2d(filters[0]),
+ nn.ReLU(),
+ nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
+ )
+ self.input_skip = nn.Sequential(
+ nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
+ )
+
+ self.residual_conv_1 = ResidualConv(filters[0]+par_channel, filters[1], 2, 1)
+ self.residual_conv_2 = ResidualConv(filters[1], filters[2], 2, 1)
+ self.residual_conv_3 = ResidualConv(filters[2], filters[3], 2, 1)
+ self.bridge = ResidualConv(filters[3],filters[4],2,1)
+
+
+ self.upsample_1 = Upsample(filters[4], filters[4], 2, 2)
+ self.up_residual_conv1 = ResidualConv(filters[4] + filters[3], filters[3], 1, 1)
+
+ self.upsample_2 = Upsample(filters[3], filters[3], 2, 2)
+ self.up_residual_conv2 = ResidualConv(filters[3] + filters[2], filters[2], 1, 1)
+
+ self.upsample_3 = Upsample(filters[2], filters[2], 2, 2)
+ self.up_residual_conv3 = ResidualConv(filters[2] + filters[1], filters[1], 1, 1)
+
+ self.upsample_4 = Upsample(filters[1], filters[1], 2, 2)
+ self.up_residual_conv4 = ResidualConv(filters[1] + filters[0]+par_channel, filters[0], 1, 1)
+
+ self.output_layer = nn.Sequential(
+ #nn.LayerNorm(filters[0]),
+ nn.LayerNorm(64), # normalizes along the width dimension; usually one would normalize along the channel dimension,
+ # but for reasons we do not fully understand this improves finetuning performance significantly
+ zero_module(nn.Conv2d(filters[0], output_channel, 1, 1,bias=False)),
+ )
+ self.par_channel=par_channel
+ self.par_conv=nn.Sequential(
+ nn.Conv2d(par_channel, par_channel, kernel_size=3, padding=1),
+ )
+ self.t_emb_layer=PositionalEmbedding(256)
+ self.cat_emb=nn.Linear(
+ in_features=6,
+ out_features=256,
+ )
+
+ def forward(self, x,t,category_code,par_point_feat):
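+ # Triplane diffusion U-Net: the timestep embedding is summed with a category embedding and injected
+ # into every residual block; an optional partial-point feature map is concatenated to the first
+ # encoder feature (zeros are used when it is absent), and a zero-initialized 1x1 conv produces the output.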
+ # Encode
+ t_emb=self.t_emb_layer(t)
+ cat_emb=self.cat_emb(category_code)
+ t_emb=t_emb+cat_emb
+ #print(t_emb.shape)
+ x1 = self.input_layer(x) + self.input_skip(x)
+ if par_point_feat is not None:
+ par_point_feat=self.par_conv(par_point_feat)
+ else:
+ bs,_,H,W=x1.shape
+ #print(x1.shape)
+ par_point_feat=torch.zeros((bs,self.par_channel,H,W)).float().to(x1.device)
+ x1 = torch.cat([x1, par_point_feat], dim=1)
+ x2 = self.residual_conv_1(x1,t_emb)
+ x3 = self.residual_conv_2(x2,t_emb)
+ # Bridge
+ x4 = self.residual_conv_3(x3,t_emb)
+ x5 = self.bridge(x4,t_emb)
+
+ x6=self.upsample_1(x5)
+ x6=torch.cat([x6,x4],dim=1)
+ x7=self.up_residual_conv1(x6,t_emb)
+
+ x7=self.upsample_2(x7)
+ x7=torch.cat([x7,x3],dim=1)
+ x8=self.up_residual_conv2(x7,t_emb)
+
+ x8 = self.upsample_3(x8)
+ x8 = torch.cat([x8, x2], dim=1)
+ #print(x8.shape)
+ x9 = self.up_residual_conv3(x8,t_emb)
+
+ x9 = self.upsample_4(x9)
+ x9 = torch.cat([x9, x1], dim=1)
+ x10 = self.up_residual_conv4(x9,t_emb)
+
+ output=self.output_layer(x10)
+
+ return output
+
+class ResUnet_DirectAttenMultiImg_Cond(nn.Module):
+ def __init__(self, channel, filters=[64, 128, 256, 512, 1024],
+ img_in_channels=1024,vit_reso=16,output_channel=32,
+ use_par=False,par_channel=32,triplane_padding=0.1,norm='batch',
+ use_cat_embedding=False,
+ block_type="multiview_local"):
+ super(ResUnet_DirectAttenMultiImg_Cond, self).__init__()
+
+ if block_type == "multiview_local":
+ block=ResidualConv_MultiImgAtten
+ elif block_type =="multiview_global":
+ block=ResidualConv_GlobalAtten
+ elif block_type =="multiview_tri":
+ block=ResidualConv_TriMultiImgAtten
+ else:
+ raise NotImplementedError
+
+ if norm=="batch":
+ norm_layer=nn.BatchNorm2d
+ elif norm is None:
+ norm_layer=nn.Identity
+
+ self.use_cat_embedding=use_cat_embedding
+ self.input_layer = nn.Sequential(
+ nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
+ norm_layer(filters[0]),
+ nn.ReLU(),
+ nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
+ )
+ self.input_skip = nn.Sequential(
+ nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
+ )
+ self.use_par=use_par
+ input_1_channels=filters[0]
+ if self.use_par:
+ self.par_conv = nn.Sequential(
+ nn.Conv2d(par_channel, par_channel, kernel_size=3, padding=1),
+ )
+ input_1_channels=filters[0]+par_channel
+ self.residual_conv_1 = block(input_1_channels, filters[1], 2, 1,reso=64
+ ,use_attn=False,triplane_padding=triplane_padding,norm=norm)
+ self.residual_conv_2 = block(filters[1], filters[2], 2, 1, reso=32,
+ use_attn=False,triplane_padding=triplane_padding,norm=norm)
+ self.residual_conv_3 = block(filters[2], filters[3], 2, 1,reso=16,
+ use_attn=False,triplane_padding=triplane_padding,norm=norm)
+ self.bridge = block(filters[3] , filters[4], 2, 1, reso=8
+ ,use_attn=False,triplane_padding=triplane_padding,norm=norm) #input reso is 8, output reso is 4
+
+
+ self.upsample_1 = Upsample(filters[4], filters[4], 2, 2)
+ self.up_residual_conv1 = block(filters[4] + filters[3], filters[3], 1, 1,reso=8,img_in_channels=img_in_channels,vit_reso=vit_reso,
+ use_attn=True,triplane_padding=triplane_padding,norm=norm)
+
+ self.upsample_2 = Upsample(filters[3], filters[3], 2, 2)
+ self.up_residual_conv2 = block(filters[3] + filters[2], filters[2], 1, 1,reso=16,img_in_channels=img_in_channels,vit_reso=vit_reso,
+ use_attn=True,triplane_padding=triplane_padding,norm=norm)
+
+ self.upsample_3 = Upsample(filters[2], filters[2], 2, 2)
+ self.up_residual_conv3 = block(filters[2] + filters[1], filters[1], 1, 1,reso=32,img_in_channels=img_in_channels,vit_reso=vit_reso,
+ use_attn=True,triplane_padding=triplane_padding,norm=norm)
+
+ self.upsample_4 = Upsample(filters[1], filters[1], 2, 2)
+ self.up_residual_conv4 = block(filters[1] + input_1_channels, filters[0], 1, 1, reso=64,
+ use_attn=False,triplane_padding=triplane_padding,norm=norm)
+
+ self.output_layer = nn.Sequential(
+ nn.LayerNorm(64), # normalizes along the width dimension; usually one would normalize along the channel dimension,
+ # but for reasons we do not fully understand this improves finetuning performance significantly
+ #nn.LayerNorm([filters[0], 192, 64]),
+ zero_module(nn.Conv2d(filters[0], output_channel, 1, 1,bias=False)),
+ )
+ self.t_emb_layer=PositionalEmbedding(256)
+ if use_cat_embedding:
+ self.cat_emb = nn.Linear(
+ in_features=6,
+ out_features=256,
+ )
+
+ def forward(self, x,t,image_emb,proj_mat,valid_frames,category_code,par_point_feat=None):
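+ # Same U-Net layout as ResUnet_Par_cond, but the decoder blocks cross-attend to multi-view image
+ # features (image_emb, with proj_mat/valid_frames defining which ViT tokens each triplane location
+ # can see); the encoder blocks and the last decoder block run without attention.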
+ # Encode
+ t_emb=self.t_emb_layer(t)
+ if self.use_cat_embedding:
+ cat_emb=self.cat_emb(category_code)
+ t_emb=t_emb+cat_emb
+ x1 = self.input_layer(x) + self.input_skip(x)
+ if self.use_par:
+ par_point_feat=self.par_conv(par_point_feat)
+ x1 = torch.cat([x1, par_point_feat], dim=1)
+ x2 = self.residual_conv_1(x1,t_emb,image_emb,proj_mat,valid_frames)
+ x3 = self.residual_conv_2(x2,t_emb,image_emb,proj_mat,valid_frames)
+ x4 = self.residual_conv_3(x3,t_emb,image_emb,proj_mat,valid_frames)
+ x5 = self.bridge(x4,t_emb,image_emb,proj_mat,valid_frames)
+
+ x6=self.upsample_1(x5)
+ x6=torch.cat([x6,x4],dim=1)
+ x7=self.up_residual_conv1(x6,t_emb,image_emb,proj_mat,valid_frames)
+
+ x7=self.upsample_2(x7)
+ x7=torch.cat([x7,x3],dim=1)
+ x8=self.up_residual_conv2(x7,t_emb,image_emb,proj_mat,valid_frames)
+
+ x8 = self.upsample_3(x8)
+ x8 = torch.cat([x8, x2], dim=1)
+ #print(x8.shape)
+ x9 = self.up_residual_conv3(x8,t_emb,image_emb,proj_mat,valid_frames)
+
+ x9 = self.upsample_4(x9)
+ x9 = torch.cat([x9, x1], dim=1)
+ x10 = self.up_residual_conv4(x9,t_emb,image_emb,proj_mat,valid_frames)
+
+ output=self.output_layer(x10)
+
+ return output
+
+
+if __name__=="__main__":
+ net=ResUnet(32,output_channel=32).float().cuda()
+ n_parameters = sum(p.numel() for p in net.parameters() if p.requires_grad)
+ print("Model = %s" % str(net))
+ print('number of params (M): %.2f' % (n_parameters / 1.e6))
+ par_point_feat=torch.randn((10,32,64*3,64)).float().cuda()
+ input=torch.randn((10,32,64*3,64)).float().cuda()
+ t=torch.randn((10,1,1,1)).float().cuda()
+ output=net(input,t.flatten(),par_point_feat)
+ #print(output.shape)
\ No newline at end of file
diff --git a/models/modules/unet.py b/models/modules/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..dceeae5b5a562a20952361b7ee61b80550bee0ed
--- /dev/null
+++ b/models/modules/unet.py
@@ -0,0 +1,304 @@
+'''
+Code adapted from:
+https://github.com/jaxony/unet-pytorch/blob/master/model.py
+'''
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+from collections import OrderedDict
+from torch.nn import init
+import numpy as np
+
+
+def conv3x3(in_channels, out_channels, stride=1,
+ padding=1, bias=True, groups=1):
+ return nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=padding,
+ bias=bias,
+ groups=groups)
+
+
+def upconv2x2(in_channels, out_channels, mode='transpose'):
+ if mode == 'transpose':
+ return nn.ConvTranspose2d(
+ in_channels,
+ out_channels,
+ kernel_size=2,
+ stride=2)
+ else:
+ # out_channels is always going to be the same
+ # as in_channels
+ return nn.Sequential(
+ nn.Upsample(mode='bilinear', scale_factor=2),
+ conv1x1(in_channels, out_channels))
+
+
+def conv1x1(in_channels, out_channels, groups=1):
+ return nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ groups=groups,
+ stride=1)
+
+class RollOut_Conv(nn.Module):
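+    """
+    Fuses the three triplane feature maps (xz, xy, yz), which are stacked along the
+    height dimension, by average-pooling each plane along one of its spatial axes and
+    concatenating the pooled summaries onto the other two planes before a single
+    3x3 convolution, letting information flow between the planes ("roll-out").
+    """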
+ def __init__(self,in_channels,out_channels):
+ super(RollOut_Conv,self).__init__()
+ #pass
+ self.in_channels=in_channels
+ self.out_channels=out_channels
+ self.conv = conv3x3(self.in_channels*3, self.out_channels)
+
+ def forward(self,row_features):
+ H,W=row_features.shape[2],row_features.shape[3]
+ H_per=H//3
+ xz_feature,xy_feature,yz_feature=torch.split(row_features,dim=2,split_size_or_sections=H_per)
+ xy_row_pool=torch.mean(xy_feature,dim=2,keepdim=True).expand(-1,-1,H_per,-1)
+ yz_col_pool=torch.mean(yz_feature,dim=3,keepdim=True).expand(-1,-1,-1,W)
+ cat_xz_feat=torch.cat([xz_feature,xy_row_pool,yz_col_pool],dim=1)
+
+ xz_row_pool=torch.mean(xz_feature,dim=2,keepdim=True).expand(-1,-1,H_per,-1)
+        zy_feature=yz_feature.transpose(2,3)  # swap the z and y axes to reduce confusion
+ zy_col_pool=torch.mean(zy_feature,dim=3,keepdim=True).expand(-1,-1,-1,W)
+ cat_xy_feat=torch.cat([xy_feature,xz_row_pool,zy_col_pool],dim=1)
+
+ xz_col_pool=torch.mean(xz_feature,dim=3,keepdim=True).expand(-1,-1,-1,W)
+ yx_feature=xy_feature.transpose(2,3)
+ yx_row_pool=torch.mean(yx_feature,dim=2,keepdim=True).expand(-1,-1,H_per,-1)
+ cat_yz_feat=torch.cat([yz_feature,yx_row_pool,xz_col_pool],dim=1)
+
+ fuse_row_feat=torch.cat([cat_xz_feat,cat_xy_feat,cat_yz_feat],dim=2) #concat at row dimension
+
+ x = self.conv(fuse_row_feat)
+
+ return x
+
+
+class DownConv(nn.Module):
+ """
+ A helper Module that performs 2 convolutions and 1 MaxPool.
+ A ReLU activation follows each convolution.
+ """
+
+ def __init__(self, in_channels, out_channels, pooling=True):
+ super(DownConv, self).__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.pooling = pooling
+
+ self.conv1 = conv3x3(self.in_channels, self.out_channels)
+ self.Rollout_conv=RollOut_Conv(self.out_channels,self.out_channels)
+ self.conv2 = conv3x3(self.out_channels, self.out_channels)
+
+ if self.pooling:
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+
+ def forward(self, x):
+ x = F.relu(self.conv1(x))
+ x = F.relu(self.Rollout_conv(x))
+ x = F.relu(self.conv2(x))
+ before_pool = x
+ if self.pooling:
+ x = self.pool(x)
+ return x, before_pool
+
+
+class UpConv(nn.Module):
+ """
+ A helper Module that performs 2 convolutions and 1 UpConvolution.
+ A ReLU activation follows each convolution.
+ """
+
+ def __init__(self, in_channels, out_channels,
+ merge_mode='concat', up_mode='transpose'):
+ super(UpConv, self).__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.merge_mode = merge_mode
+ self.up_mode = up_mode
+
+ self.upconv = upconv2x2(self.in_channels, self.out_channels,
+ mode=self.up_mode)
+
+ if self.merge_mode == 'concat':
+ self.conv1 = conv3x3(
+ 2 * self.out_channels, self.out_channels)
+ else:
+ # num of input channels to conv2 is same
+ self.conv1 = conv3x3(self.out_channels, self.out_channels)
+ self.Rollout_conv = RollOut_Conv(self.out_channels, self.out_channels)
+ self.conv2 = conv3x3(self.out_channels, self.out_channels)
+
+ def forward(self, from_down, from_up):
+ """ Forward pass
+ Arguments:
+ from_down: tensor from the encoder pathway
+ from_up: upconv'd tensor from the decoder pathway
+ """
+ from_up = self.upconv(from_up)
+ if self.merge_mode == 'concat':
+ x = torch.cat((from_up, from_down), 1)
+ else:
+ x = from_up + from_down
+ x = F.relu(self.conv1(x))
+ x = F.relu(self.Rollout_conv(x))
+ x = F.relu(self.conv2(x))
+ return x
+
+
+class UNet(nn.Module):
+ """ `UNet` class is based on https://arxiv.org/abs/1505.04597
+
+ The U-Net is a convolutional encoder-decoder neural network.
+ Contextual spatial information (from the decoding,
+ expansive pathway) about an input tensor is merged with
+ information representing the localization of details
+ (from the encoding, compressive pathway).
+
+ Modifications to the original paper:
+ (1) padding is used in 3x3 convolutions to prevent loss
+ of border pixels
+ (2) merging outputs does not require cropping due to (1)
+ (3) residual connections can be used by specifying
+ UNet(merge_mode='add')
+ (4) if non-parametric upsampling is used in the decoder
+ pathway (specified by upmode='upsample'), then an
+ additional 1x1 2d convolution occurs after upsampling
+ to reduce channel dimensionality by a factor of 2.
+        When transpose convolutions are used instead (specified by
+        upmode='transpose'), this channel halving is performed by the
+        transpose convolution itself.
+ """
+
+ def __init__(self, num_classes, in_channels=3, depth=5,
+ start_filts=64, up_mode='transpose',
+ merge_mode='concat', **kwargs):
+ """
+ Arguments:
+ in_channels: int, number of channels in the input tensor.
+ Default is 3 for RGB images.
+ depth: int, number of MaxPools in the U-Net.
+ start_filts: int, number of convolutional filters for the
+ first conv.
+ up_mode: string, type of upconvolution. Choices: 'transpose'
+ for transpose convolution or 'upsample' for nearest neighbour
+ upsampling.
+ """
+ super(UNet, self).__init__()
+
+ if up_mode in ('transpose', 'upsample'):
+ self.up_mode = up_mode
+ else:
+ raise ValueError("\"{}\" is not a valid mode for "
+ "upsampling. Only \"transpose\" and "
+ "\"upsample\" are allowed.".format(up_mode))
+
+ if merge_mode in ('concat', 'add'):
+ self.merge_mode = merge_mode
+ else:
+            raise ValueError("\"{}\" is not a valid mode for "
+                             "merging up and down paths. "
+                             "Only \"concat\" and "
+                             "\"add\" are allowed.".format(merge_mode))
+
+ # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
+ if self.up_mode == 'upsample' and self.merge_mode == 'add':
+ raise ValueError("up_mode \"upsample\" is incompatible "
+ "with merge_mode \"add\" at the moment "
+ "because it doesn't make sense to use "
+ "nearest neighbour to reduce "
+ "depth channels (by half).")
+
+ self.num_classes = num_classes
+ self.in_channels = in_channels
+ self.start_filts = start_filts
+ self.depth = depth
+
+ self.down_convs = []
+ self.up_convs = []
+
+ # create the encoder pathway and add to a list
+ for i in range(depth):
+ ins = self.in_channels if i == 0 else outs
+ outs = self.start_filts * (2 ** i)
+ pooling = True if i < depth - 1 else False
+
+ down_conv = DownConv(ins, outs, pooling=pooling)
+ self.down_convs.append(down_conv)
+
+ # create the decoder pathway and add to a list
+ # - careful! decoding only requires depth-1 blocks
+ for i in range(depth - 1):
+ ins = outs
+ outs = ins // 2
+ up_conv = UpConv(ins, outs, up_mode=up_mode,
+ merge_mode=merge_mode)
+ self.up_convs.append(up_conv)
+
+ # add the list of modules to current module
+ self.down_convs = nn.ModuleList(self.down_convs)
+ self.up_convs = nn.ModuleList(self.up_convs)
+ self.conv_final = conv1x1(outs, self.num_classes)
+
+ self.reset_params()
+
+ @staticmethod
+ def weight_init(m):
+ if isinstance(m, nn.Conv2d):
+ init.xavier_normal_(m.weight)
+ init.constant_(m.bias, 0)
+
+ def reset_params(self):
+ for i, m in enumerate(self.modules()):
+ self.weight_init(m)
+
+ def forward(self, feature_plane):
+ #cat_feature=torch.cat([feature_plane['xz'],feature_plane['xy'],feature_plane,feature_plane['yz']],dim=2) #concat at row dimension
+ x=feature_plane
+ encoder_outs = []
+ # encoder pathway, save outputs for merging
+ for i, module in enumerate(self.down_convs):
+ x, before_pool = module(x)
+ encoder_outs.append(before_pool)
+ for i, module in enumerate(self.up_convs):
+ before_pool = encoder_outs[-(i + 2)]
+ x = module(before_pool, x)
+
+ # No softmax is used. This means you need to use
+        # nn.CrossEntropyLoss in your training script,
+ # as this module includes a softmax already.
+ x = self.conv_final(x)
+ return x
+
+
+if __name__ == "__main__":
+ # """
+ # testing
+ # """
+ # model = UNet(1, depth=5, merge_mode='concat', in_channels=1, start_filts=32)
+ # print(model)
+ # print(sum(p.numel() for p in model.parameters()))
+ #
+ # reso = 176
+ # x = np.zeros((1, 1, reso, reso))
+ # x[:, :, int(reso / 2 - 1), int(reso / 2 - 1)] = np.nan
+ # x = torch.FloatTensor(x)
+ #
+ # out = model(x)
+ # print('%f' % (torch.sum(torch.isnan(out)).detach().cpu().numpy() / (reso * reso)))
+ #
+ # # loss = torch.sum(out)
+ # # loss.backward()
+ #roll_out_conv=RollOut_Conv(in_channels=32,out_channels=32).cuda().float()
+ model=UNet(32, depth=5, merge_mode='concat', in_channels=32, start_filts=32).cuda().float()
+ row_feature=torch.randn((10,32,128*3,128)).cuda().float()
+ output=model(row_feature)
+ #output_feature=roll_out_conv(row_feature)
+ #print(output_feature.shape)
\ No newline at end of file
diff --git a/models/modules/utils.py b/models/modules/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..552813bff9f90e2508a4922a7bf1f026d3247f6c
--- /dev/null
+++ b/models/modules/utils.py
@@ -0,0 +1,25 @@
+import torch
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
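+    Used here to zero-initialize the final output convolution of the diffusion U-Net.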
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+class StackedRandomGenerator:
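+    """
+    Keeps one torch.Generator per seed so that every sample in a batch draws its
+    noise from its own independent, reproducible random stream.
+    """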
+ def __init__(self, device, seeds):
+ super().__init__()
+ self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds]
+
+ def randn(self, size, **kwargs):
+ assert size[0] == len(self.generators)
+ return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators])
+
+ def randn_like(self, input):
+ return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)
+
+ def randint(self, *args, size, **kwargs):
+ assert size[0] == len(self.generators)
+ return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])
diff --git a/output/put_checkpoints_here b/output/put_checkpoints_here
new file mode 100644
index 0000000000000000000000000000000000000000..56a6051ca2b02b04ef92d5150c9ef600403cb1de
--- /dev/null
+++ b/output/put_checkpoints_here
@@ -0,0 +1 @@
+1
\ No newline at end of file
diff --git a/process_scripts/augment_arkit_partial_point.py b/process_scripts/augment_arkit_partial_point.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d3f40f67fd93b13bf3bd9f9e1b38404375850d6
--- /dev/null
+++ b/process_scripts/augment_arkit_partial_point.py
@@ -0,0 +1,64 @@
+import numpy as np
+import scipy
+import os
+import trimesh
+from sklearn.cluster import KMeans
+import random
+import glob
+import tqdm
+import argparse
+import multiprocessing as mp
+import sys
+sys.path.append("..")
+from datasets.taxonomy import arkit_category
+
+parser=argparse.ArgumentParser()
+parser.add_argument('--category',nargs="+",type=str)
+parser.add_argument("--keyword",type=str,default="lowres") #augment only the low resolution points
+parser.add_argument("--data_root",type=str,default="../data/other_data")
+args=parser.parse_args()
+category=args.category
+if category[0]=="all":
+ category=arkit_category["all"]
+kmeans=KMeans(
+ init="random",
+ n_clusters=20,
+ n_init=10,
+ max_iter=300,
+ random_state=42
+)
+
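+# Augmentation: cluster each source partial point cloud into 20 groups with KMeans,
+# then build 10 augmented partial clouds per object by keeping a random subset of the clusters.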
+def process_data(src_point_path,save_folder,keyword):
+ src_point_tri = trimesh.load(src_point_path)
+ src_point = np.asarray(src_point_tri.vertices)
+ kmeans.fit(src_point)
+ point_cluster_index = kmeans.labels_
+
+    '''randomly choose a subset of the 20 clusters to form each augmented point cloud'''
+ for i in range(10):
+ n_cluster = random.randint(14, 19) # 14,19 for lowres, 10,19 for highres
+ choose_cluster = np.random.choice(20, n_cluster, replace=False)
+ aug_point_list = []
+ for cluster_index in choose_cluster:
+ cluster_point = src_point[point_cluster_index == cluster_index]
+ aug_point_list.append(cluster_point)
+ aug_point = np.concatenate(aug_point_list, axis=0)
+ save_path = os.path.join(save_folder, "%s_partial_points_%d.ply" % (keyword, i + 1))
+ print("saving to %s"%(save_path))
+ aug_point_tri = trimesh.PointCloud(vertices=aug_point)
+ aug_point_tri.export(save_path)
+
+pool=mp.Pool(10)
+for cat in category[0:]:
+ keyword=args.keyword
+ point_dir = os.path.join(args.data_root,cat,"5_partial_points")
+ folder_list=os.listdir(point_dir)
+ for folder in tqdm.tqdm(folder_list[0:]):
+ folder_path=os.path.join(point_dir,folder)
+ src_point_path=os.path.join(point_dir,folder,"%s_partial_points_0.ply"%(keyword))
+ if os.path.exists(src_point_path)==False:
+ continue
+ save_folder=folder_path
+ pool.apply_async(process_data,(src_point_path,save_folder,keyword))
+pool.close()
+pool.join()
\ No newline at end of file
diff --git a/process_scripts/augment_synthetic_partial_points.py b/process_scripts/augment_synthetic_partial_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..a10f9005786729be88786f8981601db868a67e10
--- /dev/null
+++ b/process_scripts/augment_synthetic_partial_points.py
@@ -0,0 +1,64 @@
+import numpy as np
+import scipy
+import os
+import trimesh
+from sklearn.cluster import KMeans
+import random
+import glob
+import tqdm
+import multiprocessing as mp
+import sys
+sys.path.append("..")
+from datasets.taxonomy import synthetic_category_combined
+
+import argparse
+parser=argparse.ArgumentParser()
+parser.add_argument("--category",nargs="+",type=str)
+parser.add_argument("--root_dir",type=str,default="../data/other_data")
+args=parser.parse_args()
+categories=args.category
+if categories[0]=="all":
+ categories=synthetic_category_combined["all"]
+
+kmeans=KMeans(
+ init="random",
+ n_clusters=7,
+ n_init=10,
+ max_iter=300,
+ random_state=42
+)
+
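+# Augmentation: cluster each synthetic partial point cloud into 7 groups with KMeans
+# and keep a random subset of 3~6 clusters to form one augmented partial cloud per source file.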
+def process_data(src_filepath,save_path):
+ #print("processing %s"%(src_filepath))
+ src_point_tri = trimesh.load(src_filepath)
+ src_point = np.asarray(src_point_tri.vertices)
+ kmeans.fit(src_point)
+ point_cluster_index = kmeans.labels_
+
+ n_cluster = random.randint(3, 6)
+ choose_cluster = np.random.choice(7, n_cluster, replace=False)
+ aug_point_list = []
+ for cluster_index in choose_cluster:
+ cluster_point = src_point[point_cluster_index == cluster_index]
+ aug_point_list.append(cluster_point)
+ aug_point = np.concatenate(aug_point_list, axis=0)
+ aug_point_tri = trimesh.PointCloud(vertices=aug_point)
+ print("saving to %s"%(save_path))
+ aug_point_tri.export(save_path)
+
+pool=mp.Pool(10)
+for cat in categories:
+ print("processing %s"%cat)
+ point_dir=os.path.join(args.root_dir,cat,"5_partial_points")
+ folder_list=os.listdir(point_dir)
+ for folder in folder_list[:]:
+ folder_path=os.path.join(point_dir,folder)
+ src_filelist=glob.glob(folder_path+"/partial_points_*.ply")
+ for src_filepath in src_filelist:
+ basename=os.path.basename(src_filepath)
+ save_path = os.path.join(point_dir, folder, "aug7_" + basename)
+ pool.apply_async(process_data,(src_filepath,save_path))
+pool.close()
+pool.join()
+
+
diff --git a/process_scripts/dist_export_triplane_features.sh b/process_scripts/dist_export_triplane_features.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1609d4d76bfa9363472af63c9d27f5533c3cb3d2
--- /dev/null
+++ b/process_scripts/dist_export_triplane_features.sh
@@ -0,0 +1,8 @@
+CUDA_VISIBLE_DEVICES='0,1,2,3' torchrun --master_port 15002 --nproc_per_node=2 \
+export_triplane_features.py \
+--configs ../configs/train_triplane_vae.yaml \
+--batch_size 10 \
+--ae-pth ../output/ae/chair/best-checkpoint.pth \
+--data-pth ../data \
+--category arkit_chair 03001627 future_chair arkit_stool future_stool ABO_chair
+ #sub category
\ No newline at end of file
diff --git a/process_scripts/dist_extract_vit.sh b/process_scripts/dist_extract_vit.sh
new file mode 100644
index 0000000000000000000000000000000000000000..99e6d177bc6b10c321579e8b4f78897f08d58049
--- /dev/null
+++ b/process_scripts/dist_extract_vit.sh
@@ -0,0 +1,6 @@
+CUDA_VISIBLE_DEVICES='0,1,2,3' torchrun --master_port 15000 --nproc_per_node=2 \
+extract_img_vit_features.py \
+--batch_size 24 \
+--ckpt_path ../data/open_clip_pytorch_model.bin \
+--category arkit_chair 03001627 future_chair arkit_stool future_stool ABO_chair #sub category
+#--category 02871439 future_shelf ABO_shelf arkit_shelf \
diff --git a/process_scripts/export_triplane_features.py b/process_scripts/export_triplane_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0e4599001a90d8f0ab7ca409c40005e12b4375a
--- /dev/null
+++ b/process_scripts/export_triplane_features.py
@@ -0,0 +1,122 @@
+import argparse
+import math
+import sys
+sys.path.append("..")
+import numpy as np
+import os
+import torch
+
+import trimesh
+
+from datasets import Object_Occ,Scale_Shift_Rotate
+from models import get_model
+from pathlib import Path
+import open3d as o3d
+from configs.config_utils import CONFIG
+import tqdm
+from util import misc
+from datasets.taxonomy import synthetic_arkit_category_combined
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser('', add_help=False)
+ parser.add_argument('--configs',type=str,required=True)
+ parser.add_argument('--ae-pth',type=str)
+ parser.add_argument("--category",nargs='+', type=str)
+ parser.add_argument('--world_size', default=1, type=int,
+ help='number of distributed processes')
+ parser.add_argument('--local_rank', default=-1, type=int)
+ parser.add_argument('--dist_on_itp', action='store_true')
+ parser.add_argument('--dist_url', default='env://',
+ help='url used to set up distributed training')
+ parser.add_argument('--device', default='cuda',
+ help='device to use for training / testing')
+ parser.add_argument("--batch_size", default=1, type=int)
+ parser.add_argument("--data-pth",default="../data",type=str)
+
+ args = parser.parse_args()
+ misc.init_distributed_mode(args)
+ device = torch.device(args.device)
+
+ config_path=args.configs
+ config=CONFIG(config_path)
+ dataset_config=config.config['dataset']
+ dataset_config['data_path']=args.data_pth
+ #transform = AxisScaling((0.75, 1.25), True)
+ transform=Scale_Shift_Rotate(rot_shift_surface=True,use_scale=True)
+ if len(args.category)==1 and args.category[0]=="all":
+ category=synthetic_arkit_category_combined["all"]
+ else:
+ category=args.category
+ train_dataset = Object_Occ(dataset_config['data_path'], split="train",
+ categories=category,
+ transform=transform, sampling=True,
+ num_samples=1024, return_surface=True,
+ surface_sampling=True, surface_size=dataset_config['surface_size'],replica=1)
+ val_dataset = Object_Occ(dataset_config['data_path'], split="val",
+ categories=category,
+ transform=transform, sampling=True,
+ num_samples=1024, return_surface=True,
+ surface_sampling=True, surface_size=dataset_config['surface_size'],replica=1)
+ num_tasks = misc.get_world_size()
+ global_rank = misc.get_rank()
+ train_sampler = torch.utils.data.DistributedSampler(
+ train_dataset, num_replicas=num_tasks, rank=global_rank,
+        shuffle=False)  # keep a fixed order so each process exports a deterministic shard
+ val_sampler=torch.utils.data.DistributedSampler(
+ val_dataset, num_replicas=num_tasks, rank=global_rank,
+        shuffle=False)  # keep a fixed order so each process exports a deterministic shard
+ #dataset=val_dataset
+ batch_size=args.batch_size
+ train_dataloader=torch.utils.data.DataLoader(
+ train_dataset,sampler=train_sampler,
+ batch_size=batch_size,
+ num_workers=10,
+ shuffle=False,
+ drop_last=False,
+ )
+ val_dataloader = torch.utils.data.DataLoader(
+ val_dataset, sampler=val_sampler,
+ batch_size=batch_size,
+ num_workers=10,
+ shuffle=False,
+ drop_last=False,
+ )
+ dataloader_list=[train_dataloader,val_dataloader]
+ #dataloader_list=[val_dataloader]
+ output_dir=os.path.join(dataset_config['data_path'],"other_data")
+ #output_dir="/data1/haolin/datasets/ShapeNetV2_watertight"
+
+ model_config=config.config['model']
+ model=get_model(model_config)
+ model.load_state_dict(torch.load(args.ae_pth)['model'])
+ model.eval().float().to(device)
+ #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
+
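+    # Run the frozen VAE encoder over the train and val sets 5 times (each pass applies a
+    # different random augmentation) and cache the 2x-downsampled triplane latent statistics
+    # (mean / logvar) together with the corresponding transformation matrix for every object.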
+ with torch.no_grad():
+ for e in range(5):
+ for dataloader in dataloader_list:
+ for data_iter_step, data_batch in tqdm.tqdm(enumerate(dataloader)):
+ surface = data_batch['surface'].to(device, non_blocking=True)
+ model_ids=data_batch['model_id']
+ tran_mats=data_batch['tran_mat']
+ categories=data_batch['category']
+ with torch.no_grad():
+ plane_feat,_,means,logvars=model.encode(surface)
+ plane_feat=torch.nn.functional.interpolate(plane_feat,scale_factor=0.5,mode='bilinear')
+ vars=torch.exp(logvars)
+ means=torch.nn.functional.interpolate(means,scale_factor=0.5,mode="bilinear")
+ vars=torch.nn.functional.interpolate(vars,scale_factor=0.5,mode="bilinear")/4
+ sample_logvars=torch.log(vars)
+
+ for j in range(means.shape[0]):
+ #plane_dist=plane_feat[j].float().cpu().numpy()
+ mean=means[j].float().cpu().numpy()
+ logvar=sample_logvars[j].float().cpu().numpy()
+ tran_mat=tran_mats[j].float().cpu().numpy()
+
+ output_folder=os.path.join(output_dir,categories[j],'9_triplane_kl25_64',model_ids[j])
+ Path(output_folder).mkdir(parents=True, exist_ok=True)
+ exist_len=len(os.listdir(output_folder))
+ save_filepath=os.path.join(output_folder,"triplane_feat_%d.npz"%(exist_len))
+ np.savez_compressed(save_filepath,mean=mean,logvar=logvar,tran_mat=tran_mat)
diff --git a/process_scripts/extract_img_vit_features.py b/process_scripts/extract_img_vit_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..17d3b0fc89e31e2c84305afdbc56178d88ca84de
--- /dev/null
+++ b/process_scripts/extract_img_vit_features.py
@@ -0,0 +1,73 @@
+import os,sys
+sys.path.append("..")
+from util.simple_image_loader import Image_dataset
+from torch.utils.data import DataLoader
+import timm
+import torch
+from tqdm import tqdm
+import numpy as np
+from transformers import DPTForDepthEstimation, DPTFeatureExtractor
+import argparse
+from util import misc
+from datasets.taxonomy import synthetic_arkit_category_combined
+parser=argparse.ArgumentParser()
+
+parser.add_argument("--category",nargs="+",type=str)
+parser.add_argument("--root_dir",type=str, default="../data")
+parser.add_argument("--ckpt_path",type=str,default="../open_clip_pytorch_model.bin")
+parser.add_argument("--batch_size",type=int,default=24)
+parser.add_argument('--world_size', default=1, type=int,
+ help='number of distributed processes')
+parser.add_argument('--local_rank', default=-1, type=int)
+parser.add_argument('--dist_on_itp', action='store_true')
+parser.add_argument('--dist_url', default='env://',
+ help='url used to set up distributed training')
+args= parser.parse_args()
+misc.init_distributed_mode(args)
+category=args.category
+
+#dataset=Image_dataset(categories=['03001627','ABO_chair','future_chair'])
+if args.category[0]=="all":
+ category=synthetic_arkit_category_combined["all"]
+print("loading dataset")
+dataset=Image_dataset(dataset_folder=args.root_dir,categories=category,n_px=224)
+num_tasks = misc.get_world_size()
+global_rank = misc.get_rank()
+sampler = torch.utils.data.DistributedSampler(
+ dataset, num_replicas=num_tasks, rank=global_rank,
+    shuffle=False)  # keep a fixed order so each process extracts a deterministic shard
+
+dataloader=DataLoader(
+ dataset,
+ sampler=sampler,
+ batch_size=args.batch_size,
+ num_workers=4,
+ pin_memory=True,
+ drop_last=False
+)
+print("loading model")
+VIT_MODEL = 'vit_huge_patch14_224_clip_laion2b'
+model=timm.create_model(VIT_MODEL, pretrained=True,pretrained_cfg_overlay=dict(file=args.ckpt_path))
+model=model.eval().float().cuda()
+save_dir=os.path.join(args.root_dir,"other_data")
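+# Extract frozen CLIP ViT-H/14 patch features for every rendered image and cache them as
+# compressed .npz files under other_data/<category>/7_img_features/<model_id>/.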
+for idx,data_batch in enumerate(dataloader):
+    if idx%50==0:
+        print("{}/{}".format(idx,len(dataloader)))
+ images = data_batch["images"].cuda().float()
+ model_id= data_batch["model_id"]
+ image_name=data_batch["image_name"]
+ category=data_batch["category"]
+ with torch.no_grad():
+ #output=model(images,output_hidden_states=True)
+ output_features=model.forward_features(images)
+ #predict_depth=output.predicted_depth
+ #print(predict_depth.shape)
+ for j in range(output_features.shape[0]):
+ save_folder=os.path.join(save_dir,category[j],"7_img_features",model_id[j])
+ os.makedirs(save_folder,exist_ok=True)
+ save_path=os.path.join(save_folder,image_name[j]+".npz")
+ #print("saving to",save_path)
+ np.savez_compressed(save_path,img_features=output_features[j].detach().cpu().numpy().astype(np.float32))
+
+
+
diff --git a/process_scripts/generate_split_for_arkit.py b/process_scripts/generate_split_for_arkit.py
new file mode 100644
index 0000000000000000000000000000000000000000..02b156392c63c8edd780c6eb124c0c5070483ba1
--- /dev/null
+++ b/process_scripts/generate_split_for_arkit.py
@@ -0,0 +1,102 @@
+import os
+import numpy as np
+import glob
+import open3d as o3d
+import json
+import argparse
+import glob
+
+parser=argparse.ArgumentParser()
+parser.add_argument("--cat",required=True,type=str,nargs="+")
+parser.add_argument("--keyword",default="lowres",type=str)
+parser.add_argument("--root_dir",type=str,default="../data")
+args=parser.parse_args()
+
+keyword=args.keyword
+sdf_folder="occ_data"
+other_folder="other_data"
+data_dir=args.root_dir
+
+align_dir=os.path.join(args.root_dir,"align_mat_all") # these matrices align the highres scan to the lowres scan
+# the alignment matrices are still being cleaned; not every model has a proper alignment matrix yet.
+align_filelist=glob.glob(align_dir+"/*/*.txt")
+valid_model_list=[]
+for align_filepath in align_filelist:
+ if "-v" in align_filepath:
+ align_mat=np.loadtxt(align_filepath)
+ if align_mat.shape[0]!=4:
+ continue
+ model_id=os.path.basename(align_filepath).split("-")[0]
+ valid_model_list.append(model_id)
+
+print("there are %d valid lowres models"%(len(valid_model_list)))
+
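+# For each model id that has a valid alignment matrix, gather its rendered image filenames and
+# the partial point clouds containing more than 2048 points, then write the train/val splits to
+# <keyword>_train_par_img.json / <keyword>_val_par_img.json next to the occupancy data.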
+category_list=args.cat
+for category in category_list:
+ train_path=os.path.join(data_dir,sdf_folder,category,"train.lst")
+ with open(train_path,'r') as f:
+ train_list=f.readlines()
+ train_list=[item.rstrip() for item in train_list]
+ if ".npz" in train_list[0]:
+ train_list=[item[:-4] for item in train_list]
+ val_path=os.path.join(data_dir,sdf_folder,category,"val.lst")
+ with open(val_path,'r') as f:
+ val_list=f.readlines()
+ val_list=[item.rstrip() for item in val_list]
+ if ".npz" in val_list[0]:
+ val_list=[item[:-4] for item in val_list]
+
+
+ sdf_dir=os.path.join(data_dir,sdf_folder,category)
+ filelist=os.listdir(sdf_dir)
+ model_id_list=[item[:-4] for item in filelist if ".npz" in item]
+
+ train_par_img_list=[]
+ val_par_img_list=[]
+ for model_id in model_id_list:
+ if model_id not in valid_model_list:
+ continue
+ image_dir=os.path.join(data_dir,other_folder,category,"6_images",model_id)
+ partial_dir=os.path.join(data_dir,other_folder,category,"5_partial_points",model_id)
+ if os.path.exists(image_dir)==False and os.path.exists(partial_dir)==False:
+ continue
+ if os.path.exists(image_dir):
+ image_list=glob.glob(image_dir+"/*.jpg")+glob.glob(image_dir+"/*.png")
+ image_list=[os.path.basename(image_path) for image_path in image_list]
+ else:
+ image_list=[]
+
+ if os.path.exists(partial_dir):
+ partial_list=glob.glob(partial_dir+"/%s_partial_points_*.ply"%(keyword))
+ else:
+ partial_list=[]
+ partial_valid_list=[]
+ for partial_filepath in partial_list:
+ par_o3d=o3d.io.read_point_cloud(partial_filepath)
+ par_xyz=np.asarray(par_o3d.points)
+ if par_xyz.shape[0]>2048:
+ partial_valid_list.append(os.path.basename(partial_filepath))
+ if model_id in val_list:
+ if "%s_partial_points_0.ply"%(keyword) in partial_valid_list:
+ partial_valid_list=["%s_partial_points_0.ply"%(keyword)]
+ else:
+ partial_valid_list=[]
+ if len(image_list)==0 and len(partial_valid_list)==0:
+ continue
+ ret_dict={
+ "model_id":model_id,
+ "image_filenames":image_list[:],
+ "partial_filenames":partial_valid_list[:]
+ }
+ if model_id in train_list:
+ train_par_img_list.append(ret_dict)
+ elif model_id in val_list:
+ val_par_img_list.append(ret_dict)
+
+ train_save_path=os.path.join(sdf_dir,"%s_train_par_img.json"%(keyword))
+ with open(train_save_path,'w') as f:
+ json.dump(train_par_img_list,f,indent=4)
+
+ val_save_path=os.path.join(sdf_dir,"%s_val_par_img.json"%(keyword))
+ with open(val_save_path,'w') as f:
+ json.dump(val_par_img_list,f,indent=4)
diff --git a/process_scripts/generate_split_for_synthetic_data.py b/process_scripts/generate_split_for_synthetic_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d50377c7ad4db036eadf6d12fef0d1f771f0e2a
--- /dev/null
+++ b/process_scripts/generate_split_for_synthetic_data.py
@@ -0,0 +1,78 @@
+import os,sys
+import numpy as np
+import glob
+import open3d as o3d
+import json
+import argparse
+
+parser=argparse.ArgumentParser()
+parser.add_argument("--cat",required=True,type=str,nargs="+")
+parser.add_argument("--root_dir",type=str,default="../data")
+args=parser.parse_args()
+
+sdf_folder="occ_data"
+other_folder="other_data"
+data_dir=args.root_dir
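+# For each synthetic category, gather every model's rendered image filenames and the partial
+# point clouds containing more than 2048 points, then write train_par_img.json / val_par_img.json
+# next to the occupancy data.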
+category_list=args.cat
+for category in category_list:
+    # read the pre-defined train/val model lists for this category
+    train_path=os.path.join(data_dir,sdf_folder,category,"train.lst")
+    with open(train_path,'r') as f:
+        train_list=f.readlines()
+        train_list=[item.rstrip() for item in train_list]
+        if ".npz" in train_list[0]:
+            train_list=[item[:-4] for item in train_list]
+    val_path=os.path.join(data_dir,sdf_folder,category,"val.lst")
+    with open(val_path,'r') as f:
+        val_list=f.readlines()
+        val_list=[item.rstrip() for item in val_list]
+        if ".npz" in val_list[0]:
+            val_list=[item[:-4] for item in val_list]
+
+    sdf_dir=os.path.join(data_dir,sdf_folder,category)
+ filelist=os.listdir(sdf_dir)
+ model_id_list=[item[:-4] for item in filelist if ".npz" in item]
+
+ train_par_img_list=[]
+ val_par_img_list=[]
+ for model_id in model_id_list:
+ image_dir=os.path.join(data_dir,other_folder,category,"6_images",model_id)
+ partial_dir=os.path.join(data_dir,other_folder,category,"5_partial_points",model_id)
+ if os.path.exists(image_dir)==False and os.path.exists(partial_dir)==False:
+ continue
+ if os.path.exists(image_dir):
+ image_list=glob.glob(image_dir+"/*.jpg")+glob.glob(image_dir+"/*.png")
+ image_list=[os.path.basename(image_path) for image_path in image_list]
+ else:
+ image_list=[]
+
+ if os.path.exists(partial_dir):
+ partial_list=glob.glob(partial_dir+"/partial_points_*.ply")
+ else:
+ partial_list=[]
+ partial_valid_list=[]
+ for partial_filepath in partial_list:
+ par_o3d=o3d.io.read_point_cloud(partial_filepath)
+ par_xyz=np.asarray(par_o3d.points)
+ if par_xyz.shape[0]>2048:
+ partial_valid_list.append(os.path.basename(partial_filepath))
+ if len(image_list)==0 and len(partial_valid_list)==0:
+ continue
+ ret_dict={
+ "model_id":model_id,
+ "image_filenames":image_list[:],
+ "partial_filenames":partial_valid_list[:]
+ }
+ if model_id in train_list:
+ train_par_img_list.append(ret_dict)
+ elif model_id in val_list:
+ val_par_img_list.append(ret_dict)
+
+ #print(train_par_img_list)
+ train_save_path=os.path.join(sdf_dir,"train_par_img.json")
+ with open(train_save_path,'w') as f:
+ json.dump(train_par_img_list,f,indent=4)
+
+ val_save_path=os.path.join(sdf_dir,"val_par_img.json")
+ with open(val_save_path,'w') as f:
+ json.dump(val_par_img_list,f,indent=4)
diff --git a/process_scripts/unzip_all_data.py b/process_scripts/unzip_all_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..52b8dc3cef87d5a184724b57d906ec45a82fc9cf
--- /dev/null
+++ b/process_scripts/unzip_all_data.py
@@ -0,0 +1,38 @@
+import os
+import glob
+import argparse
+
+parser = argparse.ArgumentParser("unzip the prepared data")
+parser.add_argument("--occ_root", type=str, default="../data/occ_data")
+parser.add_argument("--other_root", type=str,default="../data/other_data")
+parser.add_argument("--unzip_occ",default=False,action="store_true")
+parser.add_argument("--unzip_other",default=False,action="store_true")
+
+args=parser.parse_args()
+if args.unzip_occ:
+ filelist=os.listdir(args.occ_root)
+ for filename in filelist:
+ filepath=os.path.join(args.occ_root,filename)
+ if ".rar" in filename:
+ unrar_command="unrar x %s %s"%(filepath,args.occ_root)
+ os.system(unrar_command)
+ elif ".zip" in filename:
+ unzip_command="7z x %s -o%s"%(filepath,args.occ_root)
+ os.system(unzip_command)
+
+
+if args.unzip_other:
+ category_list=os.listdir(args.other_root)
+ for category in category_list:
+ category_folder=os.path.join(args.other_root,category)
+ #print(category_folder)
+ rar_filelist=glob.glob(category_folder+"/*.rar")
+ zip_filelist=glob.glob(category_folder+"/*.zip")
+
+ for rar_filepath in rar_filelist:
+ unrar_command="unrar x %s %s"%(rar_filepath,category_folder)
+ os.system(unrar_command)
+ for zip_filepath in zip_filelist:
+ unzip_command="7z x %s -o%s"%(zip_filepath,category_folder)
+ os.system(unzip_command)
+
diff --git a/scripts/train_triplane_diffusion.py b/scripts/train_triplane_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bfe1cc1d97639ff2cc6deecdf31762750e02756
--- /dev/null
+++ b/scripts/train_triplane_diffusion.py
@@ -0,0 +1,317 @@
+import argparse
+import sys
+sys.path.append("..")
+import datetime
+import json
+import numpy as np
+import os
+import time
+from pathlib import Path
+
+import torch
+import torch.backends.cudnn as cudnn
+from torch.utils.tensorboard import SummaryWriter
+
+import util.misc as misc
+from datasets import build_dataset
+from util.misc import NativeScalerWithGradNormCount as NativeScaler
+from models import get_model,get_criterion
+
+from engine.engine_triplane_dm import train_one_epoch,evaluate_reconstruction
+
+def get_args_parser():
+ parser = argparse.ArgumentParser('Latent Diffusion', add_help=False)
+ parser.add_argument('--batch_size', default=64, type=int,
+                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
+ parser.add_argument('--epochs', default=800, type=int)
+ parser.add_argument('--accum_iter', default=1, type=int,
+ help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
+
+ parser.add_argument('--ae-pth',type=str)
+ # Optimizer parameters
+ parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
+ help='Clip gradient norm (default: None, no clipping)')
+ parser.add_argument('--weight_decay', type=float, default=0.05,
+ help='weight decay (default: 0.05)')
+
+ parser.add_argument('--lr', type=float, default=None, metavar='LR',
+ help='learning rate (absolute lr)')
+ parser.add_argument('--blr', type=float, default=1e-4, metavar='LR', # 2e-4
+ help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
+ parser.add_argument('--layer_decay', type=float, default=0.75,
+ help='layer-wise lr decay from ELECTRA/BEiT')
+
+ parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
+ help='lower lr bound for cyclic schedulers that hit 0')
+
+ parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
+ help='epochs to warmup LR')
+ # Dataset parameters
+ parser.add_argument('--data-pth', default='../data', type=str,
+ help='dataset path')
+
+ parser.add_argument('--output_dir', default='./output/',
+ help='path where to save, empty for no saving')
+ parser.add_argument('--log_dir', default='./output/',
+ help='path where to tensorboard log')
+ parser.add_argument('--device', default='cuda',
+ help='device to use for training / testing')
+ parser.add_argument('--seed', default=0, type=int)
+ parser.add_argument('--resume', default='',
+ help='resume from checkpoint')
+
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+ help='start epoch')
+ parser.add_argument('--eval', action='store_true',
+ help='Perform evaluation only')
+ parser.add_argument('--dist_eval', action='store_true', default=False,
+                        help='Enabling distributed evaluation (recommended during training for faster monitoring)')
+ parser.add_argument('--num_workers', default=60, type=int)
+ parser.add_argument('--pin_mem', action='store_true',
+ help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
+ parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
+ parser.set_defaults(pin_mem=True)
+ parser.add_argument('--constant_lr', default=False, action='store_true')
+
+ # distributed training parameters
+ parser.add_argument('--world_size', default=1, type=int,
+ help='number of distributed processes')
+ parser.add_argument('--local_rank', default=-1, type=int)
+ parser.add_argument('--dist_on_itp', action='store_true')
+ parser.add_argument('--dist_url', default='env://',
+ help='url used to set up distributed training')
+ parser.add_argument('--load_proj_mat',default=True,type=bool)
+ parser.add_argument('--num_objects',type=int,default=-1)
+
+ parser.add_argument('--configs', type=str)
+ parser.add_argument('--finetune', default=False, action="store_true")
+ parser.add_argument('--finetune-pth', type=str)
+ parser.add_argument('--use_cls_free',action="store_true",default=False)
+ parser.add_argument('--sync_bn',action="store_true",default=False)
+ parser.add_argument('--category',type=str)
+ parser.add_argument('--stop',type=int,default=1000)
+ parser.add_argument('--replica', type=int, default=5)
+
+ return parser
+
+
+def main(args,config):
+ misc.init_distributed_mode(args)
+
+ print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
+ print("{}".format(args).replace(', ', ',\n'))
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + misc.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+ cudnn.benchmark = True
+
+ dataset_config = config.config['dataset']
+ dataset_config['category']=args.category
+ dataset_config['replica']=args.replica
+ dataset_config['num_objects']=args.num_objects
+ dataset_config['data_path']=args.data_pth
+ dataset_train = build_dataset('train', dataset_config)
+ print("training dataset len is %d"%(len(dataset_train)))
+ dataset_val=build_dataset('val', dataset_config)
+ #dataset_val = build_dataset('val', dataset_config)
+
+ if True: # args.distributed:
+ num_tasks = misc.get_world_size()
+ global_rank = misc.get_rank()
+ sampler_train = torch.utils.data.DistributedSampler(
+ dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+ )
+ print("Sampler_train = %s" % str(sampler_train))
+ if args.dist_eval:
+ if len(dataset_val) % num_tasks != 0:
+ print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
+ 'This will slightly alter validation results as extra duplicate entries are added to achieve '
+ 'equal num of samples per-process.')
+ sampler_val = torch.utils.data.DistributedSampler(
+ dataset_val, num_replicas=num_tasks, rank=global_rank,
+ shuffle=True) # shuffle=True to reduce monitor bias
+ else:
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+ else:
+ sampler_train = torch.utils.data.RandomSampler(dataset_train)
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+
+ if global_rank == 0 and args.log_dir is not None and not args.eval:
+ os.makedirs(args.log_dir, exist_ok=True)
+ log_writer = SummaryWriter(log_dir=args.log_dir)
+ else:
+ log_writer = None
+
+ if misc.get_rank()==0:
+ log_dir=args.log_dir
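+        # NOTE: hard-coded source folder used only to back up the code files into the log dir; change it to your local repository root.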
+ src_folder="/data1/haolin/TriplaneDiffusion"
+ misc.log_codefiles(src_folder,log_dir+"/code_bak")
+ #cmd="cp -r %s %s"%(src_folder,log_dir+"/code_bak")
+ #print(cmd)
+ #os.system(cmd)
+ config_dict=vars(args)
+ config_save_path=os.path.join(log_dir,"config.json")
+ with open(config_save_path,'w') as f:
+ json.dump(config_dict,f,indent=4)
+
+ model_dict=config
+ model_config_save_path=os.path.join(log_dir,"model.json")
+ config.write_config(model_config_save_path)
+
+ data_loader_train = torch.utils.data.DataLoader(
+ dataset_train, sampler=sampler_train,
+ batch_size=args.batch_size,
+ num_workers=args.num_workers,
+ pin_memory=args.pin_mem,
+ drop_last=True,
+ prefetch_factor=2,
+ )
+
+ data_loader_val = torch.utils.data.DataLoader(
+ dataset_val, sampler=sampler_val,
+ # batch_size=args.batch_size,
+ batch_size=args.batch_size,
+ num_workers=args.num_workers,
+ pin_memory=args.pin_mem,
+ drop_last=False
+ )
+
+ ae_args=config.config['model']['ae']
+ ae = get_model(ae_args)
+ ae.eval()
+ print("Loading autoencoder %s" % args.ae_pth)
+ ae.load_state_dict(torch.load(args.ae_pth, map_location='cpu')['model'])
+ ae.to(device)
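+    # The pretrained triplane VAE stays frozen; only the diffusion model below is optimized.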
+
+ dm_args=config.config['model']['dm']
+    if args.category == "all":
+ dm_args["use_cat_embedding"]=True
+ else:
+ dm_args["use_cat_embedding"] = False
+ dm_model = get_model(dm_args)
+ if args.sync_bn:
+ dm_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(dm_model)
+ if args.finetune:
+ print("finetune the model, load from %s"%(args.finetune_pth))
+ dm_model.load_state_dict(torch.load(args.finetune_pth,map_location="cpu")['model'])
+ dm_model.to(device)
+
+ model_without_ddp = dm_model
+ n_parameters = sum(p.numel() for p in dm_model.parameters() if p.requires_grad)
+
+ print("Model = %s" % str(model_without_ddp))
+ print('number of params (M): %.2f' % (n_parameters / 1.e6))
+
+ eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
+
+ if args.lr is None: # only base_lr is specified
+ args.lr = args.blr * eff_batch_size / 256
+
+ print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
+ print("actual lr: %.2e" % args.lr)
+
+ print("accumulate grad iterations: %d" % args.accum_iter)
+ print("effective batch size: %d" % eff_batch_size)
+
+ if args.distributed:
+ dm_model = torch.nn.parallel.DistributedDataParallel(dm_model, device_ids=[args.gpu], find_unused_parameters=False)
+ model_without_ddp = dm_model.module
+
+ # # build optimizer with layer-wise lr decay (lrd)
+ # param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
+ # no_weight_decay_list=model_without_ddp.no_weight_decay(),
+ # layer_decay=args.layer_decay
+ # )
+ optimizer = torch.optim.AdamW(model_without_ddp.parameters(), lr=args.lr)
+ loss_scaler = NativeScaler()
+
+ cri_args=config.config['criterion']
+ criterion = get_criterion(cri_args)
+
+ print("criterion = %s" % str(criterion))
+
+ misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
+
+ if args.eval:
+        test_stats = evaluate_reconstruction(data_loader_val, dm_model, ae, criterion, device)
+ print(f"loss of the network on the {len(dataset_val)} test images: {test_stats['loss']:.3f}")
+ exit(0)
+
+ print(f"Start training for {args.epochs} epochs")
+ start_time = time.time()
+ min_loss = 1000.0
+ max_iou=0
+
+ stop_epochs=min(args.stop,args.epochs)
+ for epoch in range(args.start_epoch, stop_epochs):
+ if args.distributed:
+ data_loader_train.sampler.set_epoch(epoch)
+ #test_stats = evaluate_reconstruction(data_loader_val, dm_model, ae, criterion, device)
+ train_stats = train_one_epoch(
+ dm_model, ae, criterion, data_loader_train,
+ optimizer, device, epoch, loss_scaler,
+ args.clip_grad,
+ log_writer=log_writer,
+ log_dir=args.log_dir,
+ args=args
+ )
+ if args.output_dir and (epoch % 5 == 0 or epoch + 1 == args.epochs):
+ misc.save_model(
+ args=args, model=dm_model, model_without_ddp=model_without_ddp, optimizer=optimizer,
+ loss_scaler=loss_scaler, epoch=epoch,prefix="latest")
+
+ if epoch % 5 == 0 or epoch + 1 == args.epochs:
+ test_stats = evaluate_reconstruction(data_loader_val, dm_model, ae, criterion, device)
+ print(f"iou of the network on the {len(dataset_val)} test images: {test_stats['iou']:.3f}")
+ # print(f"loss of the network on the {len(dataset_val)} test images: {test_stats['loss']:.3f}")
+
+ if test_stats["iou"] > max_iou:
+ max_iou = test_stats["iou"]
+ misc.save_model(
+ args=args, model=dm_model, model_without_ddp=model_without_ddp, optimizer=optimizer,
+ loss_scaler=loss_scaler, epoch=epoch, prefix='best')
+ else:
+ misc.save_model(
+ args=args, model=dm_model, model_without_ddp=model_without_ddp, optimizer=optimizer,
+ loss_scaler=loss_scaler, epoch=epoch, prefix='latest')
+
+ if log_writer is not None:
+ log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch)
+ log_writer.add_scalar('perf/test_iou', test_stats['iou'], epoch)
+ log_writer.add_scalar('perf/test_accuracy', test_stats['accuracy'], epoch)
+
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ **{f'test_{k}': v for k, v in test_stats.items()},
+ 'epoch': epoch,
+ 'n_parameters': n_parameters}
+ else:
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ 'epoch': epoch,
+ 'n_parameters': n_parameters}
+
+ if args.output_dir and misc.is_main_process():
+ if log_writer is not None:
+ log_writer.flush()
+ with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+ args = get_args_parser()
+ args = args.parse_args()
+ if args.output_dir:
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+ config_path = args.configs
+ from configs.config_utils import CONFIG
+
+ config = CONFIG(config_path)
+ main(args,config)
diff --git a/scripts/train_triplane_vae.py b/scripts/train_triplane_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..8634b75b95f49f390f123d8b8c8d90d64f463ffe
--- /dev/null
+++ b/scripts/train_triplane_vae.py
@@ -0,0 +1,287 @@
+import argparse
+import datetime
+import json
+import numpy as np
+import os,sys
+sys.path.append("..")
+# os.system("taskset -p 0xff %d"%(os.getpid()))
+import time
+from pathlib import Path
+
+import torch
+import torch.backends.cudnn as cudnn
+from torch.utils.tensorboard import SummaryWriter
+
+torch.set_num_threads(8)
+
+import util.misc as misc
+from datasets import build_dataset
+from util.misc import NativeScalerWithGradNormCount as NativeScaler
+from models import get_model
+
+from engine.engine_triplane_vae import train_one_epoch, evaluate
+
+
+def get_args_parser():
+ parser = argparse.ArgumentParser('Autoencoder', add_help=False)
+ parser.add_argument('--batch_size', default=64, type=int,
+                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
+ parser.add_argument('--epochs', default=800, type=int)
+ parser.add_argument('--accum_iter', default=1, type=int,
+ help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
+
+ # Optimizer parameters
+ parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
+ help='Clip gradient norm (default: None, no clipping)')
+ parser.add_argument('--weight_decay', type=float, default=0.05,
+ help='weight decay (default: 0.05)')
+
+ parser.add_argument('--lr', type=float, default=None, metavar='LR',
+ help='learning rate (absolute lr)')
+ parser.add_argument('--blr', type=float, default=1e-4, metavar='LR',
+ help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
+ parser.add_argument('--layer_decay', type=float, default=0.75,
+ help='layer-wise lr decay from ELECTRA/BEiT')
+
+ parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
+ help='lower lr bound for cyclic schedulers that hit 0')
+
+ parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
+ help='epochs to warmup LR')
+
+ parser.add_argument('--output_dir', default='./output/',
+ help='path where to save, empty for no saving')
+ parser.add_argument('--log_dir', default='./output/',
+ help='path where to tensorboard log')
+ parser.add_argument('--device', default='cuda',
+ help='device to use for training / testing')
+ parser.add_argument('--seed', default=0, type=int)
+ parser.add_argument('--resume', default='',
+ help='resume from checkpoint')
+ parser.add_argument('--data-pth',default="../data",type=str)
+
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+ help='start epoch')
+ parser.add_argument('--eval', action='store_true',
+ help='Perform evaluation only')
+ parser.add_argument('--dist_eval', action='store_true', default=False,
+                        help='Enabling distributed evaluation (recommended during training for faster monitoring)')
+ parser.add_argument('--num_workers', default=60, type=int)
+ parser.add_argument('--pin_mem', action='store_true',
+ help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
+ parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
+ parser.set_defaults(pin_mem=False)
+
+ # distributed training parameters
+ parser.add_argument('--world_size', default=1, type=int,
+ help='number of distributed processes')
+ parser.add_argument('--local_rank', default=-1, type=int)
+ parser.add_argument('--dist_on_itp', action='store_true')
+ parser.add_argument('--dist_url', default='env://',
+ help='url used to set up distributed training')
+
+ parser.add_argument('--configs',type=str)
+ parser.add_argument('--finetune', default=False, action="store_true")
+ parser.add_argument('--finetune-pth', type=str)
+ parser.add_argument('--category',type=str)
+ parser.add_argument('--replica',type=int,default=8)
+
+ return parser
+
+
+def main(args,config):
+ misc.init_distributed_mode(args)
+
+ print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
+ print("{}".format(args).replace(', ', ',\n'))
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + misc.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+ cudnn.benchmark = True
+
+ dataset_config=config.config['dataset']
+ dataset_config['category']=args.category
+ dataset_config['replica']=args.replica
+ dataset_config['data_path']=args.data_pth
+ dataset_train = build_dataset('train',dataset_config)
+ dataset_val = build_dataset('val', dataset_config)
+
+ if True: # args.distributed:
+ num_tasks = misc.get_world_size()
+ global_rank = misc.get_rank()
+ sampler_train = torch.utils.data.DistributedSampler(
+ dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+ )
+ print("Sampler_train = %s" % str(sampler_train))
+ if args.dist_eval:
+ if len(dataset_val) % num_tasks != 0:
+ print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
+ 'This will slightly alter validation results as extra duplicate entries are added to achieve '
+ 'equal num of samples per-process.')
+ sampler_val = torch.utils.data.DistributedSampler(
+ dataset_val, num_replicas=num_tasks, rank=global_rank,
+ shuffle=True) # shuffle=True to reduce monitor bias
+ else:
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+ else:
+ sampler_train = torch.utils.data.RandomSampler(dataset_train)
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+
+ if global_rank == 0 and args.log_dir is not None and not args.eval:
+ os.makedirs(args.log_dir, exist_ok=True)
+ log_writer = SummaryWriter(log_dir=args.log_dir)
+ else:
+ log_writer = None
+
+ if misc.get_rank() == 0:
+ log_dir = args.log_dir
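+        # NOTE: hard-coded source folder used only to back up the code files into the log dir; change it to your local repository root.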
+ src_folder = "/data1/haolin/TriplaneDiffusion"
+ misc.log_codefiles(src_folder, log_dir + "/code_bak")
+ config_dict = vars(args)
+ config_save_path = os.path.join(log_dir, "config.json")
+ with open(config_save_path, 'w') as f:
+ json.dump(config_dict, f, indent=4)
+ model_config_path=os.path.join(log_dir,"setup.yaml")
+ config.write_config(model_config_path)
+
+ print("dataset len", dataset_train.__len__())
+ data_loader_train = torch.utils.data.DataLoader(
+ dataset_train, sampler=sampler_train,
+ batch_size=args.batch_size,
+ num_workers=args.num_workers,
+ pin_memory=args.pin_mem,
+ drop_last=True,
+ prefetch_factor=2,
+ )
+ print("dataset len", dataset_train.__len__(), "dataloader len", len(data_loader_train))
+
+ data_loader_val = torch.utils.data.DataLoader(
+ dataset_val, sampler=sampler_val,
+ # batch_size=args.batch_size,
+ batch_size=1,
+ # num_workers=args.num_workers,
+ num_workers=1,
+ pin_memory=args.pin_mem,
+ drop_last=False
+ )
+
+ #model = models_ae.__dict__[args.model](N=args.point_cloud_size)
+ model_config=config.config['model']
+ model = get_model(model_config)
+ if args.finetune:
+ print("finetune the model, load from %s"%(args.finetune_pth))
+ model.load_state_dict(torch.load(args.finetune_pth)['model'])
+ model.to(device)
+
+ model_without_ddp = model
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ print("Model = %s" % str(model_without_ddp))
+ print('number of params (M): %.2f' % (n_parameters / 1.e6))
+
+ eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
+
+ if args.lr is None: # only base_lr is specified
+ args.lr = args.blr * eff_batch_size / 256
+
+ print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
+ print("actual lr: %.2e" % args.lr)
+
+ print("accumulate grad iterations: %d" % args.accum_iter)
+ print("effective batch size: %d" % eff_batch_size)
+
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
+ model_without_ddp = model.module
+
+ # # build optimizer with layer-wise lr decay (lrd)
+ # param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
+ # no_weight_decay_list=model_without_ddp.no_weight_decay(),
+ # layer_decay=args.layer_decay
+ # )
+ optimizer = torch.optim.AdamW(model_without_ddp.parameters(), lr=args.lr)
+ loss_scaler = NativeScaler()
+
+ criterion = torch.nn.BCEWithLogitsLoss()
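+    # occupancy prediction is supervised as per-query-point binary classification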
+
+ print("criterion = %s" % str(criterion))
+
+ misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
+
+ if args.eval:
+ test_stats = evaluate(data_loader_val, model, device)
+ print(f"iou of the network on the {len(dataset_val)} test images: {test_stats['iou']:.3f}")
+ exit(0)
+
+ print(f"Start training for {args.epochs} epochs")
+ start_time = time.time()
+ max_iou = 0.0
+ for epoch in range(args.start_epoch, args.epochs):
+ # if args.distributed:
+ # data_loader_train.sampler.set_epoch(epoch)
+ #test_stats = evaluate(data_loader_val, model, device)
+ train_stats = train_one_epoch(
+ model, criterion, data_loader_train,
+ optimizer, device, epoch, loss_scaler,
+ args.clip_grad,
+ log_writer=log_writer,
+ args=args
+ )
+ # if args.output_dir and (epoch % 10 == 0 or epoch + 1 == args.epochs):
+ # misc.save_model(
+ # args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
+ # loss_scaler=loss_scaler, epoch=epoch)
+
+ if epoch % 5 == 0 or epoch + 1 == args.epochs:
+ test_stats = evaluate(data_loader_val, model, device)
+ print(f"iou of the network on the {len(dataset_val)} test images: {test_stats['iou']:.3f}")
+ if test_stats["iou"] > max_iou:
+ max_iou = test_stats["iou"]
+ misc.save_model(
+ args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
+ loss_scaler=loss_scaler, epoch=epoch, prefix='best')
+ else:
+ misc.save_model(
+ args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
+ loss_scaler=loss_scaler, epoch=epoch, prefix='latest')
+ # max_iou = max(max_iou, test_stats["iou"])
+ print(f'Max iou: {max_iou:.2f}%')
+
+ if log_writer is not None:
+ log_writer.add_scalar('perf/test_iou', test_stats['iou'], epoch)
+ log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch)
+
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ **{f'test_{k}': v for k, v in test_stats.items()},
+ 'epoch': epoch,
+ 'n_parameters': n_parameters}
+ else:
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ 'epoch': epoch,
+ 'n_parameters': n_parameters}
+
+ if args.output_dir and misc.is_main_process():
+ if log_writer is not None:
+ log_writer.flush()
+ with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+ args = get_args_parser()
+ args = args.parse_args()
+ if args.output_dir:
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+ config_path=args.configs
+ from configs.config_utils import CONFIG
+ config=CONFIG(config_path)
+ main(args,config)
\ No newline at end of file
diff --git a/train_VAE.sh b/train_VAE.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ab4390dd959c931924d7ab9add3415b1bbdbd621
--- /dev/null
+++ b/train_VAE.sh
@@ -0,0 +1,15 @@
+cd scripts
+CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" torchrun --master_port 15000 --nproc_per_node=8 \
+train_triplane_vae.py \
+--configs ../configs/train_triplane_vae.yaml \
+--accum_iter 2 \
+--output_dir ../output/ae/chair \
+--log_dir ../output/ae/chair --num_workers 8 \
+--batch_size 22 \
+--epochs 200 \
+--warmup_epochs 5 \
+--dist_eval \
+--clip_grad 0.35 \
+--category chair \
+--data-pth ../data \
+--replica 5
\ No newline at end of file
diff --git a/train_diffusion.sh b/train_diffusion.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2f8da321b46b90d98bfb71ba01f0948499911235
--- /dev/null
+++ b/train_diffusion.sh
@@ -0,0 +1,16 @@
+cd scripts
+CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' torchrun --master_port 15004 --nproc_per_node=8 \
+train_triplane_diffusion.py \
+--configs ../configs/train_triplane_diffusion.yaml \
+--accum_iter 2 \
+--output_dir ../output/dm/debug \
+--log_dir ../output/dm/debug \
+--num_workers 8 \
+--batch_size 22 \
+--epochs 1000 \
+--dist_eval \
+--warmup_epochs 40 \
+--category chair \
+--ae-pth ../output/ae/chair/best-checkpoint.pth \
+--data-pth ../data \
+--replica 5
\ No newline at end of file
diff --git a/util/lr_sched.py b/util/lr_sched.py
new file mode 100644
index 0000000000000000000000000000000000000000..04697fc1a34d4a736cb9c86dcaaada150be7e78c
--- /dev/null
+++ b/util/lr_sched.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+def adjust_learning_rate(optimizer, epoch, args):
+ """Decay the learning rate with half-cycle cosine after warmup"""
+ if epoch < args.warmup_epochs:
+ lr = args.lr * epoch / args.warmup_epochs
+ else:
+ lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
+ (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
+ for param_group in optimizer.param_groups:
+ if "lr_scale" in param_group:
+ param_group["lr"] = lr * param_group["lr_scale"]
+ else:
+ param_group["lr"] = lr
+ return lr
\ No newline at end of file
diff --git a/util/misc.py b/util/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7c9c1334249434d8b6e07dadbb4debee75dfde6
--- /dev/null
+++ b/util/misc.py
@@ -0,0 +1,358 @@
+# --------------------------------------------------------
+# References:
+# MAE: https://github.com/facebookresearch/mae
+# DeiT: https://github.com/facebookresearch/deit
+# BEiT: https://github.com/microsoft/unilm/tree/master/beit
+# --------------------------------------------------------
+
+import builtins
+import datetime
+import os
+import time
+from collections import defaultdict, deque
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+from math import inf  # torch._six has been removed from recent PyTorch; math.inf is equivalent here
+import numpy as np
+
+def log_codefiles(data_root,save_root):
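+    # snapshot all .py and .sh files up to three directory levels below data_root into save_root, preserving relative paths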
+ import glob
+ import shutil
+ print("saving codes to",data_root,glob.glob(data_root+"/*.py"))
+ all_files=glob.glob(data_root+"/*.py")+glob.glob(data_root+"/*.sh")+glob.glob(data_root+"/*/*.py")+glob.glob(data_root+"/*/*.sh")+ \
+ glob.glob(data_root + "/*/*/*.py") + glob.glob(data_root + "/*/*/*.sh")
+ for file in all_files:
+ rel_path=os.path.relpath(file,data_root)
+ dst_path=os.path.join(save_root,rel_path)
+        os.makedirs(os.path.dirname(dst_path), exist_ok=True)
+        shutil.copyfile(file, dst_path)
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value)
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if v is None:
+ continue
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ #print(name,meter)
+ loss_str.append(
+ "{}: {}".format(name, str(meter))
+ )
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None):
+ i = 0
+ if not header:
+ header = ''
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
+ data_time = SmoothedValue(fmt='{avg:.4f}')
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+ log_msg = [
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}'
+ ]
+ if torch.cuda.is_available():
+ log_msg.append('max mem: {memory:.0f}')
+ log_msg = self.delimiter.join(log_msg)
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ #print("available threads",torch.get_num_threads())
+ if i % print_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ #print(str(self))
+ if torch.cuda.is_available():
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB))
+ else:
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time)))
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('{} Total time: {} ({:.4f} s / it)'.format(
+ header, total_time_str, total_time / len(iterable)))
+
+
+def setup_for_distributed(is_master):
+ """
+    Disable printing in non-master processes (pass force=True to print anyway).
+ """
+ builtin_print = builtins.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ force = force or (get_world_size() > 8)
+ if is_master or force:
+ now = datetime.datetime.now().time()
+ builtin_print('[{}] '.format(now), end='') # print with time stamp
+ builtin_print(*args, **kwargs)
+
+ builtins.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+def worker_init_fn(worker_id):
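+    # give each dataloader worker its own numpy seed derived from os.urandom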
+ random_data = os.urandom(4)
+ base_seed = int.from_bytes(random_data, byteorder="big")
+ np.random.seed(base_seed + worker_id)
+ #torch.random.seed(base_seed+worker_id)
+
+
+def init_distributed_mode(args):
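+    # resolve rank / world size / local gpu from OMPI env vars (dist_on_itp), torchrun-style env vars, or SLURM, in that order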
+ if args.dist_on_itp:
+ args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+ args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+ args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+ args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
+ os.environ['LOCAL_RANK'] = str(args.gpu)
+ os.environ['RANK'] = str(args.rank)
+ os.environ['WORLD_SIZE'] = str(args.world_size)
+ # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
+ elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ['WORLD_SIZE'])
+ args.gpu = int(os.environ['LOCAL_RANK'])
+ elif 'SLURM_PROCID' in os.environ:
+ args.rank = int(os.environ['SLURM_PROCID'])
+ args.gpu = args.rank % torch.cuda.device_count()
+ else:
+ print('Not using distributed mode')
+ setup_for_distributed(is_master=True) # hack
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = 'nccl'
+ print('| distributed init (rank {}): {}, gpu {}'.format(
+ args.rank, args.dist_url, args.gpu), flush=True)
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank)
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
+
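+# AMP loss scaler that also performs optional gradient clipping and reports the gradient norm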
+class NativeScalerWithGradNormCount:
+ state_dict_key = "amp_scaler"
+
+ def __init__(self):
+ self._scaler = torch.cuda.amp.GradScaler()
+
+ def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
+ self._scaler.scale(loss).backward(create_graph=create_graph)
+ if update_grad:
+ if clip_grad is not None:
+ assert parameters is not None
+ self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+ else:
+ self._scaler.unscale_(optimizer)
+ norm = get_grad_norm_(parameters)
+ self._scaler.step(optimizer)
+ self._scaler.update()
+ else:
+ norm = None
+ return norm
+
+ def state_dict(self):
+ return self._scaler.state_dict()
+
+ def load_state_dict(self, state_dict):
+ self._scaler.load_state_dict(state_dict)
+
+
+def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = [p for p in parameters if p.grad is not None]
+ norm_type = float(norm_type)
+ if len(parameters) == 0:
+ return torch.tensor(0.)
+ device = parameters[0].grad.device
+ if norm_type == inf:
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+ else:
+ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+ return total_norm
+
+
+def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, prefix="latest"):
+ output_dir = Path(args.output_dir)
+ epoch_name = str(epoch)
+ if loss_scaler is not None:
+ checkpoint_paths = [output_dir / ('%s-checkpoint.pth' % (prefix))]
+ for checkpoint_path in checkpoint_paths:
+ to_save = {
+ 'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'epoch': epoch,
+ 'scaler': loss_scaler.state_dict(),
+ 'args': args,
+ }
+
+ save_on_master(to_save, checkpoint_path)
+ else:
+ client_state = {'epoch': epoch}
+ model.save_checkpoint(save_dir=args.output_dir, tag='%s-checkpoint.pth' % (prefix), client_state=client_state)
+
+
+def load_model(args, model_without_ddp, optimizer, loss_scaler):
+ if args.resume:
+ if args.resume.startswith('https'):
+ checkpoint = torch.hub.load_state_dict_from_url(
+ args.resume, map_location='cpu', check_hash=True)
+ else:
+ checkpoint = torch.load(args.resume, map_location='cpu')
+ model_without_ddp.load_state_dict(checkpoint['model'])
+ print("Resume checkpoint %s" % args.resume)
+ if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ args.start_epoch = checkpoint['epoch'] + 1
+ if 'scaler' in checkpoint:
+ loss_scaler.load_state_dict(checkpoint['scaler'])
+ print("With optim & sched!")
+
+
+def all_reduce_mean(x):
+ world_size = get_world_size()
+ if world_size > 1:
+ x_reduce = torch.tensor(x).cuda()
+ dist.all_reduce(x_reduce)
+ x_reduce /= world_size
+ return x_reduce.item()
+ else:
+ return x
\ No newline at end of file
diff --git a/util/projection_utils.py b/util/projection_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcc8721469a1ef5ef01cc793a875552ce13c36fc
--- /dev/null
+++ b/util/projection_utils.py
@@ -0,0 +1,12 @@
+import numpy as np
+import cv2
+def draw_proj_image(image,proj_mat,points):
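+    """Project 3D points onto a 224x224 resized copy of image via proj_mat and mark the projections in green."""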
+ points_homo=np.concatenate([points,np.ones((points.shape[0],1))],axis=1)
+ pts_inimg=np.dot(points_homo,proj_mat.T)
+ image=cv2.resize(image,dsize=(224,224),interpolation=cv2.INTER_LINEAR)
+ x=pts_inimg[:,0]/pts_inimg[:,2]
+ y=pts_inimg[:,1]/pts_inimg[:,2]
+ x=np.clip(x,a_min=0,a_max=223).astype(np.int32)
+ y=np.clip(y,a_min=0,a_max=223).astype(np.int32)
+ image[y,x]=np.array([[0,255,0]])
+ return image
\ No newline at end of file
diff --git a/util/simple_image_loader.py b/util/simple_image_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..92baa973b873e0a59b6435a9632da1e820211baa
--- /dev/null
+++ b/util/simple_image_loader.py
@@ -0,0 +1,80 @@
+import torch
+import torch.nn as nn
+from torch.utils import data
+import os
+from PIL import Image
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+try:
+ from torchvision.transforms import InterpolationMode
+ BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+ BICUBIC = Image.BICUBIC
+import glob
+
+def image_transform(n_px):
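+    # resize + center-crop to n_px, then normalize with CLIP image statistics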
+ return Compose([
+ Resize(n_px, interpolation=BICUBIC),
+ CenterCrop(n_px),
+ ToTensor(),
+ # Normalize((123.675/255.0,116.28/255.0,103.53/255.0),
+ # (58.395/255.0,57.12/255.0,57.375/255.0))
+ Normalize((0.48145466, 0.4578275, 0.40821073),
+ (0.26862954, 0.26130258, 0.27577711)),
+ # Normalize((0.5, 0.5, 0.5),
+ # (0.5, 0.5, 0.5)),
+ ])
+
+class Image_dataset(data.Dataset):
+ def __init__(self,dataset_folder="/data1/haolin/datasets",categories=['03001627'],n_px=224):
+ self.dataset_folder=dataset_folder
+ self.image_folder=os.path.join(self.dataset_folder,'other_data')
+ self.preprocess=image_transform(n_px)
+ self.image_path=[]
+ for cat in categories:
+ subpath=os.path.join(self.image_folder,cat,"6_images")
+ model_list=os.listdir(subpath)
+ for folder in model_list:
+ model_folder=os.path.join(subpath,folder)
+ image_list=os.listdir(model_folder)
+ for image_filename in image_list:
+ image_filepath=os.path.join(model_folder,image_filename)
+ self.image_path.append(image_filepath)
+ def __len__(self):
+ return len(self.image_path)
+
+ def __getitem__(self,index):
+ path=self.image_path[index]
+ basename=os.path.basename(path)[:-4]
+ model_id=path.split(os.sep)[-2]
+ category=path.split(os.sep)[-4]
+        image=Image.open(path).convert("RGB")  # force 3 channels so the normalization transform always applies
+ image_tensor=self.preprocess(image)
+
+ return {"images":image_tensor,"image_name":basename,"model_id":model_id,"category":category}
+
+class Image_InTheWild_dataset(data.Dataset):
+ def __init__(self,dataset_dir="/data1/haolin/data/real_scene_process_data",scene_id="letian-310",n_px=224):
+ self.dataset_dir=dataset_dir
+ self.preprocess = image_transform(n_px)
+ self.image_path = []
+ if scene_id=="all":
+ scene_list=os.listdir(self.dataset_dir)
+ for id in scene_list:
+ image_folder=os.path.join(self.dataset_dir,id,"6_images")
+ self.image_path+=glob.glob(image_folder+"/*/*jpg")
+ else:
+ image_folder = os.path.join(self.dataset_dir, scene_id, "6_images")
+ self.image_path += glob.glob(image_folder + "/*/*jpg")
+ def __len__(self):
+ return len(self.image_path)
+
+ def __getitem__(self,index):
+ path=self.image_path[index]
+ basename=os.path.basename(path)[:-4]
+ model_id=path.split(os.sep)[-2]
+ scene_id=path.split(os.sep)[-4]
+        image=Image.open(path).convert("RGB")  # force 3 channels so the normalization transform always applies
+ image_tensor=self.preprocess(image)
+
+ return {"images":image_tensor,"image_name":basename,"model_id":model_id,"scene_id":scene_id}
+
diff --git a/util/train_test_utils.py b/util/train_test_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391