diff --git a/README.md b/README.md index e00d9e8fcbf0d4586c41615b3f50e9bbe30ef738..d9f5496b77b6e4ec5ed5ff225a6993eed88f7718 100644 --- a/README.md +++ b/README.md @@ -8,4 +8,75 @@ Repository of LASA: Instance Reconstruction from Real Scans using A Large-scale ![292080628-a4b020dc-2673-4b1b-bfa6-ec9422625624](https://github.com/GAP-LAB-CUHK-SZ/LASA/assets/40767265/7a0dfc11-5454-428f-bfba-e8cd0d0af96e) ![292080638-324bbef9-c93b-4d96-b814-120204374383](https://github.com/GAP-LAB-CUHK-SZ/LASA/assets/40767265/ee07691a-8767-4701-9a32-19a70e0e240a) -#### Codes and dataset will be released soon! +## Dataset +Complete raw data will be released soon. + +## Download preprocessed data and processing +Download the preprocessed data from +BaiduYun (code: 62ux). (These data will be updated as the cleaning process continues.) Put all the downloaded data under LASA, and unzip align_mat_all.zip manually. +You can use the script ./process_scripts/unzip_all_data.py to unzip all the data in occ_data and other_data with the following commands: +``` +cd process_scripts +python unzip_all_data.py --unzip_occ --unzip_other +``` +Run the following commands to generate augmented partial point clouds for the synthetic and LASA datasets: +``` +cd process_scripts +python augment_arkit_partial_point.py --cat arkit_chair arkit_stool ... +python augment_synthetic_partial_point.py --cat 03001627 future_chair ABO_chair ... +``` +Run the following command to extract image features: +``` +cd process_scripts +bash dist_extract_vit.sh +``` +Finally, run the following commands to generate train/val splits: +``` +cd process_scripts +python generate_split_for_arkit.py --cat arkit_chair arkit_stool ... +python generate_split_for_synthetic_data.py --cat 03001627 future_chair ABO_chair ... +``` + +## Evaluation +Download the pretrained weights for the chair category from chair_checkpoint (code: hlf9). +Put these folders under LASA/output.
The ae folder stores the VAE weights, the dm folder stores the diffusion model trained on synthetic data, +and the finetune_dm folder stores the diffusion model finetuned on the LASA dataset. +Run the following commands to evaluate and extract the meshes: +``` +cd evaluation +bash dist_eval.sh +``` +The category entries are the sub-categories from ARKitScenes; please see ./datasets/taxonomy.py for how they are defined. +For example, if you want to evaluate on LASA's chair, the category list should contain both arkit_chair and arkit_stool. +Make sure the --ae-pth and --dm-pth entries point to the correct checkpoint paths. If you are evaluating on LASA, +make sure --dm-pth points to the finetuned weights in the ./output/finetune_dm folder. The results will be saved +under ./output_result. + +## Training +Run train_VAE.sh to train the VAE model. If you aim to train on one category, just specify one category from chair, +cabinet, table, sofa, bed, or shelf. Passing all will train on all categories. Make sure to download and preprocess all +the required sub-category data. The sub-category arrangement can be found in ./datasets/taxonomy.py.
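For reference, the sub-category arrangement can be inspected from Python before launching a run. This is a minimal sketch, assuming the repository root is on the Python path; the only attribute of ./datasets/taxonomy.py confirmed by this patch is `category_map_from_synthetic` (imported by datasets/SingleView_dataset.py), so any other names it prints should be checked against the file itself.

```python
# Minimal sketch: list what ./datasets/taxonomy.py defines and print the one
# mapping confirmed by this patch (category_map_from_synthetic). Any other
# attribute names are assumptions to be checked against the file.
from datasets import taxonomy

print([name for name in dir(taxonomy) if not name.startswith("_")])
print(taxonomy.category_map_from_synthetic)
```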
+After finishing training the VAE model, run the following commands to pre-extract the VAE features for every object: +``` +cd process_scripts +bash dist_export_triplane_features.sh +``` +Then, we can start training the diffusion model on the synthetic dataset by running train_diffusion.sh.
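The exported features are the triplane posterior statistics produced by the VAE encoder; during diffusion training a latent is re-sampled from them instead of re-encoding the surface points. A minimal sketch of that step, mirroring `Triplane_Diff_MultiImgCond_EDM.prepare_data` further down in this patch, is shown below; the tensor shapes are illustrative assumptions, not values taken from the configs.

```python
import torch
from models.modules.encoder import DiagonalGaussianDistribution

# Sketch of how the pre-extracted VAE features ('triplane_mean'/'triplane_logvar')
# are re-sampled into a triplane latent for diffusion training; shapes below are
# illustrative only (the sampler uses B x C x 3*reso x reso planes).
means = torch.randn(2, 32, 192, 64)     # hypothetical exported triplane means
logvars = torch.randn(2, 32, 192, 64)   # hypothetical exported triplane log-variances
posterior = DiagonalGaussianDistribution(means, logvars)
plane_feat = posterior.sample()         # becomes the "input" tensor of the diffusion loss
```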
+Finally, finetune the diffusion model on the LASA dataset by running finetune_diffusion.sh.
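Both the synthetic pre-training and the LASA finetuning stage optimize the EDM-style denoising objective implemented by `EDMLoss_MultiImgCond` in models/Triplane_Diffusion.py. The snippet below is a stripped-down sketch of its noise-level sampling and loss weighting with all conditioning (images, projection matrices, partial points) omitted; `denoiser` stands in for the full conditional model.

```python
import torch

# Stripped-down sketch of the EDM noise schedule and loss weighting used by
# EDMLoss_MultiImgCond (P_mean=-1.2, P_std=1.2, sigma_data=0.5 in this patch);
# the conditioning inputs handled by the real loss are omitted here.
P_mean, P_std, sigma_data = -1.2, 1.2, 0.5
y = torch.randn(4, 32, 192, 64)                  # hypothetical clean triplane latents
rnd_normal = torch.randn(y.shape[0], 1, 1, 1)
sigma = (rnd_normal * P_std + P_mean).exp()      # per-sample noise level
weight = (sigma ** 2 + sigma_data ** 2) / (sigma_data * sigma) ** 2
noise = torch.randn_like(y) * sigma              # noise added to the clean latent
# loss = weight * ((denoiser(y + noise, sigma) - y) ** 2), then averaged over the batch
```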

+ +Early stopping is used by mannualy stopping the training by 150 epochs and 500 epochs for training VAE model and diffusion model respetively. +All experiments in the paper are conducted on 8 A100 GPUs with batch size = 22. +## TODO + +- [ ] Object Detection Code +- [ ] Code for Demo on both arkitscene and in the wild data + +## Citation +``` +@article{liu2023lasa, + title={LASA: Instance Reconstruction from Real Scans using A Large-scale Aligned Shape Annotation Dataset}, + author={Liu, Haolin and Ye, Chongjie and Nie, Yinyu and He, Yingfan and Han, Xiaoguang}, + journal={arXiv preprint arXiv:2312.12418}, + year={2023} +} +``` \ No newline at end of file diff --git a/configs/config_utils.py b/configs/config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2371e31a713d06c4982acb86b69310c5d28c47b7 --- /dev/null +++ b/configs/config_utils.py @@ -0,0 +1,70 @@ +import os +import yaml +import logging +from datetime import datetime + +def update_recursive(dict1, dict2): + ''' Update two config dictionaries recursively. + + Args: + dict1 (dict): first dictionary to be updated + dict2 (dict): second dictionary which entries should be used + + ''' + for k, v in dict2.items(): + if k not in dict1: + dict1[k] = dict() + if isinstance(v, dict): + update_recursive(dict1[k], v) + else: + dict1[k] = v + +class CONFIG(object): + ''' + Stores all configures + ''' + def __init__(self, input=None): + ''' + Loads config file + :param path (str): path to config file + :return: + ''' + self.config = self.read_to_dict(input) + + def read_to_dict(self, input): + if not input: + return dict() + if isinstance(input, str) and os.path.isfile(input): + if input.endswith('yaml'): + with open(input, 'r') as f: + config = yaml.load(f, Loader=yaml.FullLoader) + else: + ValueError('Config file should be with the format of *.yaml') + elif isinstance(input, dict): + config = input + else: + raise ValueError('Unrecognized input type (i.e. not *.yaml file nor dict).') + + return config + + def update_config(self, *args, **kwargs): + ''' + update config and corresponding logger setting + :param input: dict settings add to config file + :return: + ''' + cfg1 = dict() + for item in args: + cfg1.update(self.read_to_dict(item)) + + cfg2 = self.read_to_dict(kwargs) + + new_cfg = {**cfg1, **cfg2} + + update_recursive(self.config, new_cfg) + # when update config file, the corresponding logger should also be updated. 
+ self.__update_logger() + + def write_config(self,save_path): + with open(save_path, 'w') as file: + yaml.dump(self.config, file, default_flow_style = False) \ No newline at end of file diff --git a/configs/finetune_triplane_diffusion.yaml b/configs/finetune_triplane_diffusion.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7e8ae8986211f68cde8fe1f6702a18df1b871ce7 --- /dev/null +++ b/configs/finetune_triplane_diffusion.yaml @@ -0,0 +1,67 @@ +model: + ae: #ae model is loaded to + type: TriVAE + point_emb_dim: 48 + padding: 0.1 + encoder: + plane_reso: 128 + plane_latent_dim: 32 + latent_dim: 32 + unet: + depth: 4 + merge_mode: concat + start_filts: 32 + output_dim: 64 + decoder: + plane_reso: 128 + latent_dim: 32 + n_blocks: 5 + query_emb_dim: 48 + hidden_dim: 128 + unet: + depth: 4 + merge_mode: concat + start_filts: 64 + output_dim: 32 + dm: + type: triplane_diff_multiimg_cond + backbone: resunet_multiimg_direct_atten + diff_reso: 64 + input_channel: 32 + output_channel: 32 + triplane_padding: 0.1 #should be consistent with padding in ae + + use_par: True + par_channel: 32 + par_emb_dim: 48 + norm: "batch" + img_in_channels: 1280 + vit_reso: 16 + use_cat_embedding: ??? + block_type: multiview_local + par_point_encoder: + plane_reso: 64 + plane_latent_dim: 32 + n_blocks: 5 + unet: + depth: 3 + merge_mode: concat + start_filts: 32 + output_dim: 32 +criterion: + type: EDMLoss_MultiImgCond + use_par: True +dataset: + type: Occ_Par_MultiImg_Finetune + data_path: ??? + surface_size: 20000 + par_pc_size: 2048 + load_proj_mat: True + load_image: True + par_point_aug: 0.5 + par_prefix: "aug7_" + keyword: lowres #use lowres arkitscene or highres to train, lowres scene is more user accessible + jitter_partial_pretrain: 0.02 + jitter_partial_finetune: 0.02 + jitter_partial_val: 0.0 + use_pretrain_data: False diff --git a/configs/train_triplane_diffusion.yaml b/configs/train_triplane_diffusion.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6d8d5e87a974cde3fe2337d2eeabd34b681f28f6 --- /dev/null +++ b/configs/train_triplane_diffusion.yaml @@ -0,0 +1,64 @@ +model: + ae: #ae model is loaded to + type: TriVAE + point_emb_dim: 48 + padding: 0.1 + encoder: + plane_reso: 128 + plane_latent_dim: 32 + latent_dim: 32 + unet: + depth: 4 + merge_mode: concat + start_filts: 32 + output_dim: 64 + decoder: + plane_reso: 128 + latent_dim: 32 + n_blocks: 5 + query_emb_dim: 48 + hidden_dim: 128 + unet: + depth: 4 + merge_mode: concat + start_filts: 64 + output_dim: 32 + dm: + type: triplane_diff_multiimg_cond + backbone: resunet_multiimg_direct_atten + diff_reso: 64 + input_channel: 32 + output_channel: 32 + triplane_padding: 0.1 #should be consistent with padding in ae + + use_par: True + par_channel: 32 + par_emb_dim: 48 + norm: "batch" + img_in_channels: 1280 + vit_reso: 16 + use_cat_embedding: ??? + block_type: multiview_local + par_point_encoder: + plane_reso: 64 + plane_latent_dim: 32 + n_blocks: 5 + unet: + depth: 3 + merge_mode: concat + start_filts: 32 + output_dim: 32 +criterion: + type: EDMLoss_MultiImgCond + use_par: True +dataset: + type: Occ_Par_MultiImg + data_path: ??? 
+ surface_size: 20000 + par_pc_size: 2048 + load_proj_mat: True + load_image: True + par_point_aug: 0.5 + par_prefix: "aug7_" # prefix of the filenames of the partial point cloud + jitter_partial_train: 0.02 + jitter_partial_val: 0.0 diff --git a/configs/train_triplane_vae.yaml b/configs/train_triplane_vae.yaml new file mode 100644 index 0000000000000000000000000000000000000000..34b8757af33c770d8ccd2bd5837040c64404fe63 --- /dev/null +++ b/configs/train_triplane_vae.yaml @@ -0,0 +1,30 @@ +model: + type: TriVAE + point_emb_dim: 48 + padding: 0.1 + encoder: + plane_reso: 128 + plane_latent_dim: 32 + latent_dim: 32 + unet: + depth: 4 + merge_mode: concat + start_filts: 32 + output_dim: 64 + decoder: + plane_reso: 128 + latent_dim: 32 + n_blocks: 5 + query_emb_dim: 48 + hidden_dim: 128 + unet: + depth: 4 + merge_mode: concat + start_filts: 64 + output_dim: 32 +dataset: + type: Occ + category: chair + data_path: ??? + surface_size: 20000 + num_samples: 2048 \ No newline at end of file diff --git a/data/download_preprocess_data_here b/data/download_preprocess_data_here new file mode 100644 index 0000000000000000000000000000000000000000..56a6051ca2b02b04ef92d5150c9ef600403cb1de --- /dev/null +++ b/data/download_preprocess_data_here @@ -0,0 +1 @@ +1 \ No newline at end of file diff --git a/datasets/SingleView_dataset.py b/datasets/SingleView_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4f8346b7d01c90974a8be8e1b120ff96a71695a4 --- /dev/null +++ b/datasets/SingleView_dataset.py @@ -0,0 +1,453 @@ +import os +import glob +import random + +import yaml + +import torch +from torch.utils import data + +import numpy as np +import json + +from PIL import Image + +import h5py +import torch.distributed as dist +import open3d as o3d +o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Error) +import pickle as p +import time +import cv2 +from torchvision import transforms +import copy +from datasets.taxonomy import category_map_from_synthetic as category_ids +class Object_Occ(data.Dataset): + def __init__(self, dataset_folder, split, categories=['03001627', "future_chair", 'ABO_chair'], transform=None, + sampling=True, + num_samples=4096, return_surface=True, surface_sampling=True, surface_size=2048, replica=16): + + self.pc_size = surface_size + + self.transform = transform + self.num_samples = num_samples + self.sampling = sampling + self.split = split + + self.dataset_folder = dataset_folder + self.return_surface = return_surface + self.surface_sampling = surface_sampling + + self.dataset_folder = dataset_folder + self.point_folder = os.path.join(self.dataset_folder, 'occ_data') + self.mesh_folder = os.path.join(self.dataset_folder, 'other_data') + + if categories is None: + categories = os.listdir(self.point_folder) + categories = [c for c in categories if + os.path.isdir(os.path.join(self.point_folder, c)) and c.startswith('0')] + categories.sort() + + print(categories) + + self.models = [] + for c_idx, c in enumerate(categories): + subpath = os.path.join(self.point_folder, c) + print(subpath) + assert os.path.isdir(subpath) + + split_file = os.path.join(subpath, split + '.lst') + with open(split_file, 'r') as f: + models_c = f.readlines() + models_c = [item.rstrip('\n') for item in models_c] + + for m in models_c[:]: + if len(m)<=1: + continue + if m.endswith('.npz'): + model_id = m[:-4] + else: + model_id = m + self.models.append({ + 'category': c, 'model': model_id + }) + self.replica = replica + + def __getitem__(self, idx): + if self.replica >= 1: + idx = idx % 
len(self.models) + else: + random_segment = random.randint(0, int(1 / self.replica) - 1) + idx = int(random_segment * self.replica * len(self.models) + idx) + + category = self.models[idx]['category'] + model = self.models[idx]['model'] + + point_path = os.path.join(self.point_folder, category, model + '.npz') + # print(point_path) + try: + start_t = time.time() + with np.load(point_path) as data: + vol_points = data['vol_points'] + vol_label = data['vol_label'] + near_points = data['near_points'] + near_label = data['near_label'] + end_t = time.time() + # print("loading time %f"%(end_t-start_t)) + except Exception as e: + print(e) + print(point_path) + + with open(point_path.replace('.npz', '.npy'), 'rb') as f: + scale = np.load(f).item() + # scale=1.0 + + if self.return_surface: + pc_path = os.path.join(self.mesh_folder, category, '4_pointcloud', model + '.npz') + with np.load(pc_path) as data: + try: + surface = data['points'].astype(np.float32) + except: + print(pc_path,"has problems") + raise AttributeError + surface = surface * scale + if self.surface_sampling: + ind = np.random.default_rng().choice(surface.shape[0], self.pc_size, replace=False) + surface = surface[ind] + surface = torch.from_numpy(surface) + + if self.sampling: + '''need to conduct label balancing''' + vol_ind=np.random.default_rng().choice(vol_points.shape[0], self.num_samples, + replace=(vol_points.shape[0]= 1: + idx = idx % len(self.models) + else: + random_segment = random.randint(0, int(1 / self.replica) - 1) + idx = int(random_segment * self.replica * len(self.models) + idx) + category = self.models[idx]['category'] + model = self.models[idx]['model'] + #image_filenames = self.model_images_names[model] + image_filenames = self.models[idx]["image_filenames"] + if self.split=="train": + n_frames = np.random.randint(min(2,len(image_filenames)), min(len(image_filenames) + 1, self.max_img_length + 1)) + img_indexes = np.random.choice(len(image_filenames), n_frames, + replace=(n_frames > len(image_filenames))).tolist() + else: + if self.eval_multiview: + '''use all images''' + n_frames=len(image_filenames) + img_indexes=[i for i in range(n_frames)] + else: + n_frames = min(len(image_filenames),self.max_img_length) + img_indexes=np.linspace(start=0,stop=len(image_filenames)-1,num=n_frames).astype(np.int32) + + partial_filenames = self.models[idx]['partial_filenames'] + par_index = np.random.choice(len(partial_filenames), 1)[0] + partial_name = partial_filenames[par_index] + + vol_points,vol_label,near_points,near_label=None,None,None,None + points,labels=None,None + point_path = os.path.join(self.point_folder, category, model + '.npz') + if self.ret_sample: + vol_points,vol_label,near_points,near_label=self.load_samples(point_path) + points,labels = self.process_samples(vol_points, vol_label, near_points,near_label) + + with open(point_path.replace('.npz', '.npy'), 'rb') as f: + scale = np.load(f).item() + + surface=None + pc_path = os.path.join(self.mesh_folder, category, '4_pointcloud', model + '.npz') + if self.return_surface: + surface=self.load_surface(pc_path,scale) + + partial_path = os.path.join(self.mesh_folder, category, "5_partial_points", model, partial_name) + if self.par_point_aug is not None and random.random()= 0.0] = 1 + label=labels[j:j+1] + + accuracy = (pred == label).float().sum(dim=1) / label.shape[1] + accuracy = accuracy.mean() + intersection = (pred * label).sum(dim=1) + union = (pred + label).gt(0).sum(dim=1) + iou = intersection * 1.0 / union + 1e-5 + iou = iou.mean() + + 
metric_logger.update(iou=iou.item()) + metric_logger.update(accuracy=accuracy.item()) + metric_logger.update(loss=loss.item()) + metric_logger.synchronize_between_processes() + print('* iou {ious.global_avg:.3f}' + .format(ious=metric_logger.iou)) + print('* accuracy {accuracies.global_avg:.3f}' + .format(accuracies=metric_logger.accuracy)) + print('* loss {losses.global_avg:.3f}' + .format(losses=metric_logger.loss)) + + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} \ No newline at end of file diff --git a/engine/engine_triplane_vae.py b/engine/engine_triplane_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..765aad41829f145bc701e1f5c4fa09a6c32b0d63 --- /dev/null +++ b/engine/engine_triplane_vae.py @@ -0,0 +1,185 @@ +# -------------------------------------------------------- +# References: +# MAE: https://github.com/facebookresearch/mae +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import math +import sys +sys.path.append("..") +from typing import Iterable + +import torch +import torch.nn.functional as F + +import util.misc as misc +import util.lr_sched as lr_sched + + +def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, + data_loader: Iterable, optimizer: torch.optim.Optimizer, + device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, + log_writer=None, args=None): + model.train(True) + metric_logger = misc.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) + header = 'Epoch: [{}]'.format(epoch) + print_freq = 20 + + accum_iter = args.accum_iter + + optimizer.zero_grad() + + kl_weight = 25e-3 #TODO: try to modify this, it is 1e-3 originally, large kl ease the training of diffusion, but decrease in VAE results + + if log_writer is not None: + print('log_dir: {}'.format(log_writer.log_dir)) + + for data_iter_step, data_batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)): + + # we use a per iteration (instead of per epoch) lr scheduler + if data_iter_step % accum_iter == 0: + lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) + + points = data_batch['points'].to(device, non_blocking=True) + labels = data_batch['labels'].to(device, non_blocking=True) + surface = data_batch['surface'].to(device, non_blocking=True) + # print(points.shape) + with torch.cuda.amp.autocast(enabled=False): + outputs = model(surface, points) + if 'kl' in outputs: + loss_kl = outputs['kl'] + #print(loss_kl.shape) + loss_kl = torch.sum(loss_kl) / loss_kl.shape[0] + else: + loss_kl = None + + outputs = outputs['logits'] + + num_samples=outputs.shape[1]//2 + #print(num_samples) + loss_vol = criterion(outputs[:, :num_samples], labels[:, :num_samples]) + loss_near = criterion(outputs[:, num_samples:], labels[:, num_samples:]) + + if loss_kl is not None: + loss = loss_vol + 0.1 * loss_near + kl_weight * loss_kl + else: + loss = loss_vol + 0.1 * loss_near + + loss_value = loss.item() + + threshold = 0 + + pred = torch.zeros_like(outputs[:, :num_samples]) + pred[outputs[:, :num_samples] >= threshold] = 1 + + accuracy = (pred == labels[:, :num_samples]).float().sum(dim=1) / labels[:, :num_samples].shape[1] + accuracy = accuracy.mean() + intersection = (pred * labels[:, :num_samples]).sum(dim=1) + union = (pred + labels[:, :num_samples]).gt(0).sum(dim=1) + 1e-5 + iou = intersection * 1.0 / 
union + iou = iou.mean() + + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) + + loss /= accum_iter + loss_scaler(loss, optimizer, clip_grad=max_norm, + parameters=model.parameters(), create_graph=False, + update_grad=(data_iter_step + 1) % accum_iter == 0) + if (data_iter_step + 1) % accum_iter == 0: + optimizer.zero_grad() + + torch.cuda.synchronize() + + metric_logger.update(loss=loss_value) + + metric_logger.update(loss_vol=loss_vol.item()) + metric_logger.update(loss_near=loss_near.item()) + + if loss_kl is not None: + metric_logger.update(loss_kl=loss_kl.item()) + + metric_logger.update(iou=iou.item()) + + min_lr = 10. + max_lr = 0. + for group in optimizer.param_groups: + min_lr = min(min_lr, group["lr"]) + max_lr = max(max_lr, group["lr"]) + + metric_logger.update(lr=max_lr) + + loss_value_reduce = misc.all_reduce_mean(loss_value) + iou_reduce=misc.all_reduce_mean(iou) + if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: + """ We use epoch_1000x as the x-axis in tensorboard. + This calibrates different curves when batch size changes. + """ + epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) + log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) + log_writer.add_scalar('iou', iou_reduce, epoch_1000x) + log_writer.add_scalar('lr', max_lr, epoch_1000x) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def evaluate(data_loader, model, device): + criterion = torch.nn.BCEWithLogitsLoss() + + metric_logger = misc.MetricLogger(delimiter=" ") + header = 'Test:' + + # switch to evaluation mode + model.eval() + + for data_batch in metric_logger.log_every(data_loader, 50, header): + + points = data_batch['points'].to(device, non_blocking=True) + labels = data_batch['labels'].to(device, non_blocking=True) + surface = data_batch['surface'].to(device, non_blocking=True) + # compute output + with torch.cuda.amp.autocast(enabled=False): + + outputs = model(surface, points) + if 'kl' in outputs: + loss_kl = outputs['kl'] + loss_kl = torch.sum(loss_kl) / loss_kl.shape[0] + else: + loss_kl = None + + outputs = outputs['logits'] + + loss = criterion(outputs, labels) + + threshold = 0 + + pred = torch.zeros_like(outputs) + pred[outputs >= threshold] = 1 + + accuracy = (pred == labels).float().sum(dim=1) / labels.shape[1] + accuracy = accuracy.mean() + intersection = (pred * labels).sum(dim=1) + union = (pred + labels).gt(0).sum(dim=1) + iou = intersection * 1.0 / union + 1e-5 + iou = iou.mean() + + batch_size = points.shape[0] + metric_logger.update(loss=loss.item()) + metric_logger.meters['iou'].update(iou.item(), n=batch_size) + + if loss_kl is not None: + metric_logger.update(loss_kl=loss_kl.item()) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print('* iou {iou.global_avg:.3f} loss {losses.global_avg:.3f}' + .format(iou=metric_logger.iou, losses=metric_logger.loss)) + + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} \ No newline at end of file diff --git a/evaluation/dist_eval.sh b/evaluation/dist_eval.sh new file mode 100644 index 0000000000000000000000000000000000000000..43387a13d173ae55dece292652ddfe5874aa65dd --- /dev/null +++ b/evaluation/dist_eval.sh @@ -0,0 +1,16 @@ +CUDA_VISIBLE_DEVICES='0,1,2,3' torchrun --master_port 15002 
--nproc_per_node=2 \ +evaluate_object_reconstruction.py \ +--configs ../configs/finetune_triplane_diffusion.yaml \ +--category arkit_chair arkit_stool \ +--ae-pth ../output/ae/chair/best-checkpoint.pth \ +--dm-pth ../output/finetune_dm/lowres_chair/best-checkpoint.pth \ +--output_folder ../output_result/chair_result \ +--data-pth ../data \ +--eval_cd \ +--reso 256 \ +--save_mesh \ +--save_par_points \ +--save_image \ +--save_surface + +#check ./datasets/taxonomy to see how sub categories are defined diff --git a/evaluation/evaluate_object_reconstruction.py b/evaluation/evaluate_object_reconstruction.py new file mode 100644 index 0000000000000000000000000000000000000000..65d414963a218dfd92abc36a6b9bbeb982a4fcce --- /dev/null +++ b/evaluation/evaluate_object_reconstruction.py @@ -0,0 +1,239 @@ +import argparse +import sys +sys.path.append("..") +sys.path.append(".") +import numpy as np + +import mcubes +import os +import torch + +import trimesh + +from datasets.SingleView_dataset import Object_PartialPoints_MultiImg +from datasets.transforms import Scale_Shift_Rotate +from models import get_model +from pathlib import Path +import open3d as o3d +from configs.config_utils import CONFIG +import cv2 +from util.misc import MetricLogger +import scipy +from pyTorchChamferDistance.chamfer_distance import ChamferDistance +from util.projection_utils import draw_proj_image +from util import misc +import time +dist_chamfer=ChamferDistance() + + +def pc_metrics(p1, p2, space_ext=2, fscore_param=0.01, scale=.5): + """ p2: reference ponits + (B, N, 3) + """ + p1, p2, space_ext = p1 * scale, p2 * scale, space_ext * scale + f_thresh = space_ext * fscore_param + + #print(p1.shape,p2.shape) + d1, d2, _, _ = dist_chamfer(p1, p2) + #print(d1.shape,d2.shape) + d1sqrt, d2sqrt = (d1 ** .5), (d2 ** .5) + chamfer_L1 = d1sqrt.mean(axis=-1) + d2sqrt.mean(axis=-1) + chamfer_L2 = d1.mean(axis=-1) + d2.mean(axis=-1) + precision = (d1sqrt < f_thresh).sum(axis=-1).float() / p1.shape[1] + recall = (d2sqrt < f_thresh).sum(axis=-1).float() / p2.shape[1] + #print(precision,recall) + fscore = 2 * torch.div(recall * precision, recall + precision) + fscore[fscore == float("inf")] = 0 + return chamfer_L1,chamfer_L2,fscore + +if __name__ == "__main__": + + parser = argparse.ArgumentParser('this script can be used to compute iou fscore chamfer distance before icp align', add_help=False) + parser.add_argument('--configs',type=str,required=True) + parser.add_argument('--output_folder', type=str, default="../output_result/Triplane_diff_parcond_0926") + parser.add_argument('--dm-pth',type=str) + parser.add_argument('--ae-pth',type=str) + parser.add_argument('--data-pth', type=str,default="../") + parser.add_argument('--save_mesh',action="store_true",default=False) + parser.add_argument('--save_image',action="store_true",default=False) + parser.add_argument('--save_par_points', action="store_true", default=False) + parser.add_argument('--save_proj_img',action="store_true",default=False) + parser.add_argument('--save_surface',action="store_true",default=False) + parser.add_argument('--reso',default=128,type=int) + parser.add_argument('--category',nargs="+",type=str) + parser.add_argument('--eval_cd',action="store_true",default=False) + parser.add_argument('--use_augmentation',action="store_true",default=False) + + parser.add_argument('--world_size', default=1, type=int, + help='number of distributed processes') + parser.add_argument('--local_rank', default=-1, type=int) + parser.add_argument('--dist_on_itp', action='store_true') + 
parser.add_argument('--dist_url', default='env://', + help='url used to set up distributed training') + parser.add_argument('--device', default='cuda', + help='device to use for training / testing') + args = parser.parse_args() + misc.init_distributed_mode(args) + config_path=args.configs + config=CONFIG(config_path) + dataset_config=config.config['dataset'] + dataset_config['data_path']=args.data_pth + if "arkit" in args.category[0]: + split_filename=dataset_config['keyword']+'_val_par_img.json' + else: + split_filename='val_par_img.json' + + transform = None + if args.use_augmentation: + transform=Scale_Shift_Rotate(jitter_partial=False,jitter=False,use_scale=False,angle=(-10,10),shift=(-0.1,0.1)) + dataset_val = Object_PartialPoints_MultiImg(dataset_config['data_path'], split_filename=split_filename,categories=args.category,split="val", + transform=transform, sampling=False, + num_samples=1024, return_surface=True,ret_sample=True, + surface_sampling=True, par_pc_size=dataset_config['par_pc_size'],surface_size=100000, + load_proj_mat=True,load_image=True,load_org_img=True,load_triplane=None,par_point_aug=None,replica=1) + batch_size=1 + + num_tasks = misc.get_world_size() + global_rank = misc.get_rank() + val_sampler = torch.utils.data.DistributedSampler( + dataset_val, num_replicas=num_tasks, rank=global_rank, + shuffle=False) # shu + dataloader_val=torch.utils.data.DataLoader( + dataset_val, + sampler=val_sampler, + batch_size=batch_size, + num_workers=10, + shuffle=False, + ) + output_folder=args.output_folder + + device = torch.device('cuda') + + ae_config=config.config['model']['ae'] + dm_config=config.config['model']['dm'] + ae_model=get_model(ae_config).to(device) + if args.category[0] == "all": + dm_config["use_cat_embedding"]=True + else: + dm_config["use_cat_embedding"] = False + dm_model=get_model(dm_config).to(device) + ae_model.eval() + dm_model.eval() + ae_model.load_state_dict(torch.load(args.ae_pth)['model']) + dm_model.load_state_dict(torch.load(args.dm_pth)['model']) + + density = args.reso + gap = 2.2 / density + x = np.linspace(-1.1, 1.1, int(density + 1)) + y = np.linspace(-1.1, 1.1, int(density + 1)) + z = np.linspace(-1.1, 1.1, int(density + 1)) + xv, yv, zv = np.meshgrid(x, y, z,indexing='ij') + grid = torch.from_numpy(np.stack([xv, yv, zv]).astype(np.float32)).view(3, -1).transpose(0, 1)[None].to(device,non_blocking=True) + + metric_logger=MetricLogger(delimiter=" ") + header = 'Test:' + + with torch.no_grad(): + for data_batch in metric_logger.log_every(dataloader_val,10, header): + # if data_iter_step==100: + # break + partial_name = data_batch['partial_name'] + class_name = data_batch['class_name'] + model_ids=data_batch['model_id'] + surface=data_batch['surface'] + proj_matrices=data_batch['proj_mat'] + sample_points=data_batch["points"].cuda().float() + labels=data_batch["labels"].cuda().float() + sample_input=dm_model.prepare_sample_data(data_batch) + #t1 = time.time() + sampled_array = dm_model.sample(sample_input,num_steps=36).float() + #t2 = time.time() + #sample_time = t2 - t1 + #print("sampling time %f" % (sample_time)) + sampled_array = torch.nn.functional.interpolate(sampled_array, scale_factor=2, mode="bilinear") + for j in range(sampled_array.shape[0]): + if args.save_mesh | args.save_par_points | args.save_image: + object_folder = os.path.join(output_folder, class_name[j], model_ids[j]) + Path(object_folder).mkdir(parents=True, exist_ok=True) + '''calculate iou''' + sample_point=sample_points[j:j+1] + 
sample_output=ae_model.decode(sampled_array[j:j + 1],sample_point) + sample_pred=torch.zeros_like(sample_output) + sample_pred[sample_output>=0.0]=1 + label=labels[j:j+1] + intersection = (sample_pred * label).sum(dim=1) + union = (sample_pred + label).gt(0).sum(dim=1) + iou = intersection * 1.0 / union + 1e-5 + iou = iou.mean() + metric_logger.update(iou=iou.item()) + + if args.use_augmentation: + tran_mat=data_batch["tran_mat"][j].numpy() + mat_save_path='{}/tran_mat.npy'.format(object_folder) + np.save(mat_save_path,tran_mat) + + if args.eval_cd: + grid_list=torch.split(grid,128**3,dim=1) + output_list=[] + #t3=time.time() + for sub_grid in grid_list: + output_list.append(ae_model.decode(sampled_array[j:j + 1],sub_grid)) + output=torch.cat(output_list,dim=1) + #t4=time.time() + #decoding_time=t4-t3 + #print("decoding time:",decoding_time) + logits = output[j].detach() + + volume = logits.view(density + 1, density + 1, density + 1).cpu().numpy() + verts, faces = mcubes.marching_cubes(volume, 0) + + verts *= gap + verts -= 1.1 + #print("vertice max min",np.amin(verts,axis=0),np.amax(verts,axis=0)) + + + m = trimesh.Trimesh(verts, faces) + '''calculate fscore and chamfer distance''' + result_surface,_=trimesh.sample.sample_surface(m,100000) + gt_surface=surface[j] + assert gt_surface.shape[0]==result_surface.shape[0] + + result_surface_gpu = torch.from_numpy(result_surface).float().cuda().unsqueeze(0) + gt_surface_gpu = gt_surface.float().cuda().unsqueeze(0) + _,chamfer_L2,fscore=pc_metrics(result_surface_gpu,gt_surface_gpu) + metric_logger.update(chamferl2=chamfer_L2*1000.0) + metric_logger.update(fscore=fscore) + + if args.save_mesh: + m.export('{}/{}_mesh.ply'.format(object_folder, partial_name[j])) + + if args.save_par_points: + par_point_input = data_batch['par_points'][j].numpy() + #print("input max min", np.amin(par_point_input, axis=0), np.amax(par_point_input, axis=0)) + par_point_o3d = o3d.geometry.PointCloud() + par_point_o3d.points = o3d.utility.Vector3dVector(par_point_input[:, 0:3]) + o3d.io.write_point_cloud('{}/{}.ply'.format(object_folder, partial_name[j]), par_point_o3d) + if args.save_image: + image_list=data_batch["org_image"] + for idx,image in enumerate(image_list): + image=image[0].numpy().astype(np.uint8) + if args.save_proj_img: + proj_mat=proj_matrices[j,idx].numpy() + proj_image=draw_proj_image(image,proj_mat,result_surface) + proj_save_path = '{}/proj_{}.jpg'.format(object_folder, idx) + cv2.imwrite(proj_save_path,proj_image) + save_path='{}/{}.jpg'.format(object_folder, idx) + cv2.imwrite(save_path,image) + if args.save_surface: + surface=gt_surface.numpy().astype(np.float32) + surface_o3d = o3d.geometry.PointCloud() + surface_o3d.points = o3d.utility.Vector3dVector(surface[:, 0:3]) + o3d.io.write_point_cloud('{}/surface.ply'.format(object_folder), surface_o3d) + metric_logger.synchronize_between_processes() + print('* iou {ious.global_avg:.3f}' + .format(ious=metric_logger.iou)) + if args.eval_cd: + print('* chamferl2 {chamferl2s.global_avg:.3f}' + .format(chamferl2s=metric_logger.chamferl2)) + print('* fscore {fscores.global_avg:.3f}' + .format(fscores=metric_logger.fscore)) diff --git a/evaluation/pyTorchChamferDistance/.gitignore b/evaluation/pyTorchChamferDistance/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..506b29de9d26cace37122214d09d00b204a7c716 --- /dev/null +++ b/evaluation/pyTorchChamferDistance/.gitignore @@ -0,0 +1,3 @@ +.DS_Store +._* + diff --git a/evaluation/pyTorchChamferDistance/LICENSE.md 
b/evaluation/pyTorchChamferDistance/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..8aa26455d23acf904be3ed9dfb3a3efe3e49245a --- /dev/null +++ b/evaluation/pyTorchChamferDistance/LICENSE.md @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) [year] [fullname] + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/evaluation/pyTorchChamferDistance/README.md b/evaluation/pyTorchChamferDistance/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2902a61fb658c10b057c67b406d937ef9f539284 --- /dev/null +++ b/evaluation/pyTorchChamferDistance/README.md @@ -0,0 +1,23 @@ +# Chamfer Distance for pyTorch + +This is an implementation of the Chamfer Distance as a module for pyTorch. It is written as a custom C++/CUDA extension. + +As it is using pyTorch's [JIT compilation](https://pytorch.org/tutorials/advanced/cpp_extension.html), there are no additional prerequisite steps that have to be taken. Simply import the module as shown below; CUDA and C++ code will be compiled on the first run. + +### Usage +```python +from chamfer_distance import ChamferDistance +chamfer_dist = ChamferDistance() + +#... +# points and points_reconstructed are n_points x 3 matrices + +dist1, dist2 = chamfer_dist(points, points_reconstructed) +loss = (torch.mean(dist1)) + (torch.mean(dist2)) + + +#... +``` + +### Integration +This code has been integrated into the [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) library for 3D Deep Learning by NVIDIAGameWorks. 
You should probably take a look at it if you are working on anything 3D :) diff --git a/evaluation/pyTorchChamferDistance/__init__.py b/evaluation/pyTorchChamferDistance/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/evaluation/pyTorchChamferDistance/chamfer_distance/__init__.py b/evaluation/pyTorchChamferDistance/chamfer_distance/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e15be7028d12ddc55b29752ac718c5284200203 --- /dev/null +++ b/evaluation/pyTorchChamferDistance/chamfer_distance/__init__.py @@ -0,0 +1 @@ +from .chamfer_distance import ChamferDistance diff --git a/evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cpp b/evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cpp new file mode 100644 index 0000000000000000000000000000000000000000..40f3d79aee526188f2df559a34e563026933be56 --- /dev/null +++ b/evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cpp @@ -0,0 +1,185 @@ +#include + +// CUDA forward declarations +int ChamferDistanceKernelLauncher( + const int b, const int n, + const float* xyz, + const int m, + const float* xyz2, + float* result, + int* result_i, + float* result2, + int* result2_i); + +int ChamferDistanceGradKernelLauncher( + const int b, const int n, + const float* xyz1, + const int m, + const float* xyz2, + const float* grad_dist1, + const int* idx1, + const float* grad_dist2, + const int* idx2, + float* grad_xyz1, + float* grad_xyz2); + + +void chamfer_distance_forward_cuda( + const at::Tensor xyz1, + const at::Tensor xyz2, + const at::Tensor dist1, + const at::Tensor dist2, + const at::Tensor idx1, + const at::Tensor idx2) +{ + ChamferDistanceKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data(), + xyz2.size(1), xyz2.data(), + dist1.data(), idx1.data(), + dist2.data(), idx2.data()); +} + +void chamfer_distance_backward_cuda( + const at::Tensor xyz1, + const at::Tensor xyz2, + at::Tensor gradxyz1, + at::Tensor gradxyz2, + at::Tensor graddist1, + at::Tensor graddist2, + at::Tensor idx1, + at::Tensor idx2) +{ + ChamferDistanceGradKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data(), + xyz2.size(1), xyz2.data(), + graddist1.data(), idx1.data(), + graddist2.data(), idx2.data(), + gradxyz1.data(), gradxyz2.data()); +} + + +void nnsearch( + const int b, const int n, const int m, + const float* xyz1, + const float* xyz2, + float* dist, + int* idx) +{ + for (int i = 0; i < b; i++) { + for (int j = 0; j < n; j++) { + const float x1 = xyz1[(i*n+j)*3+0]; + const float y1 = xyz1[(i*n+j)*3+1]; + const float z1 = xyz1[(i*n+j)*3+2]; + double best = 0; + int besti = 0; + for (int k = 0; k < m; k++) { + const float x2 = xyz2[(i*m+k)*3+0] - x1; + const float y2 = xyz2[(i*m+k)*3+1] - y1; + const float z2 = xyz2[(i*m+k)*3+2] - z1; + const double d=x2*x2+y2*y2+z2*z2; + if (k==0 || d < best){ + best = d; + besti = k; + } + } + dist[i*n+j] = best; + idx[i*n+j] = besti; + } + } +} + + +void chamfer_distance_forward( + const at::Tensor xyz1, + const at::Tensor xyz2, + const at::Tensor dist1, + const at::Tensor dist2, + const at::Tensor idx1, + const at::Tensor idx2) +{ + const int batchsize = xyz1.size(0); + const int n = xyz1.size(1); + const int m = xyz2.size(1); + + const float* xyz1_data = xyz1.data(); + const float* xyz2_data = xyz2.data(); + float* dist1_data = dist1.data(); + float* dist2_data = dist2.data(); + int* idx1_data = idx1.data(); + int* idx2_data = idx2.data(); + + nnsearch(batchsize, n, m, 
xyz1_data, xyz2_data, dist1_data, idx1_data); + nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data); +} + + +void chamfer_distance_backward( + const at::Tensor xyz1, + const at::Tensor xyz2, + at::Tensor gradxyz1, + at::Tensor gradxyz2, + at::Tensor graddist1, + at::Tensor graddist2, + at::Tensor idx1, + at::Tensor idx2) +{ + const int b = xyz1.size(0); + const int n = xyz1.size(1); + const int m = xyz2.size(1); + + const float* xyz1_data = xyz1.data(); + const float* xyz2_data = xyz2.data(); + float* gradxyz1_data = gradxyz1.data(); + float* gradxyz2_data = gradxyz2.data(); + float* graddist1_data = graddist1.data(); + float* graddist2_data = graddist2.data(); + const int* idx1_data = idx1.data(); + const int* idx2_data = idx2.data(); + + for (int i = 0; i < b*n*3; i++) + gradxyz1_data[i] = 0; + for (int i = 0; i < b*m*3; i++) + gradxyz2_data[i] = 0; + for (int i = 0;i < b; i++) { + for (int j = 0; j < n; j++) { + const float x1 = xyz1_data[(i*n+j)*3+0]; + const float y1 = xyz1_data[(i*n+j)*3+1]; + const float z1 = xyz1_data[(i*n+j)*3+2]; + const int j2 = idx1_data[i*n+j]; + + const float x2 = xyz2_data[(i*m+j2)*3+0]; + const float y2 = xyz2_data[(i*m+j2)*3+1]; + const float z2 = xyz2_data[(i*m+j2)*3+2]; + const float g = graddist1_data[i*n+j]*2; + + gradxyz1_data[(i*n+j)*3+0] += g*(x1-x2); + gradxyz1_data[(i*n+j)*3+1] += g*(y1-y2); + gradxyz1_data[(i*n+j)*3+2] += g*(z1-z2); + gradxyz2_data[(i*m+j2)*3+0] -= (g*(x1-x2)); + gradxyz2_data[(i*m+j2)*3+1] -= (g*(y1-y2)); + gradxyz2_data[(i*m+j2)*3+2] -= (g*(z1-z2)); + } + for (int j = 0; j < m; j++) { + const float x1 = xyz2_data[(i*m+j)*3+0]; + const float y1 = xyz2_data[(i*m+j)*3+1]; + const float z1 = xyz2_data[(i*m+j)*3+2]; + const int j2 = idx2_data[i*m+j]; + const float x2 = xyz1_data[(i*n+j2)*3+0]; + const float y2 = xyz1_data[(i*n+j2)*3+1]; + const float z2 = xyz1_data[(i*n+j2)*3+2]; + const float g = graddist2_data[i*m+j]*2; + gradxyz2_data[(i*m+j)*3+0] += g*(x1-x2); + gradxyz2_data[(i*m+j)*3+1] += g*(y1-y2); + gradxyz2_data[(i*m+j)*3+2] += g*(z1-z2); + gradxyz1_data[(i*n+j2)*3+0] -= (g*(x1-x2)); + gradxyz1_data[(i*n+j2)*3+1] -= (g*(y1-y2)); + gradxyz1_data[(i*n+j2)*3+2] -= (g*(z1-z2)); + } + } +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &chamfer_distance_forward, "ChamferDistance forward"); + m.def("forward_cuda", &chamfer_distance_forward_cuda, "ChamferDistance forward (CUDA)"); + m.def("backward", &chamfer_distance_backward, "ChamferDistance backward"); + m.def("backward_cuda", &chamfer_distance_backward_cuda, "ChamferDistance backward (CUDA)"); +} diff --git a/evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cu b/evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cu new file mode 100644 index 0000000000000000000000000000000000000000..f10f2ba854883d7f590236bb69e3598e8a4ef379 --- /dev/null +++ b/evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cu @@ -0,0 +1,209 @@ +#include + +#include +#include + +__global__ +void ChamferDistanceKernel( + int b, + int n, + const float* xyz, + int m, + const float* xyz2, + float* result, + int* result_i) +{ + const int batch=512; + __shared__ float buf[batch*3]; + for (int i=blockIdx.x;ibest){ + result[(i*n+j)]=best; + result_i[(i*n+j)]=best_i; + } + } + __syncthreads(); + } + } +} + +void ChamferDistanceKernelLauncher( + const int b, const int n, + const float* xyz, + const int m, + const float* xyz2, + float* result, + int* result_i, + float* result2, + int* result2_i) +{ + 
ChamferDistanceKernel<<>>(b, n, xyz, m, xyz2, result, result_i); + ChamferDistanceKernel<<>>(b, m, xyz2, n, xyz, result2, result2_i); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err)); +} + + +__global__ +void ChamferDistanceGradKernel( + int b, int n, + const float* xyz1, + int m, + const float* xyz2, + const float* grad_dist1, + const int* idx1, + float* grad_xyz1, + float* grad_xyz2) +{ + for (int i = blockIdx.x; i>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2); + ChamferDistanceGradKernel<<>>(b, m, xyz2, n, xyz1, grad_dist2, idx2, grad_xyz2, grad_xyz1); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err)); +} diff --git a/evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.py b/evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.py new file mode 100644 index 0000000000000000000000000000000000000000..e3bea670105c2900909b85210cf10331063cb5e3 --- /dev/null +++ b/evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.py @@ -0,0 +1,58 @@ + +import torch + +from torch.utils.cpp_extension import load +cd = load(name="build", + sources=["pyTorchChamferDistance/chamfer_distance/chamfer_distance.cpp", + "pyTorchChamferDistance/chamfer_distance/chamfer_distance.cu"], + build_directory="pyTorchChamferDistance/build") + +class ChamferDistanceFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, xyz1, xyz2): + batchsize, n, _ = xyz1.size() + _, m, _ = xyz2.size() + xyz1 = xyz1.contiguous() + xyz2 = xyz2.contiguous() + dist1 = torch.zeros(batchsize, n) + dist2 = torch.zeros(batchsize, m) + + idx1 = torch.zeros(batchsize, n, dtype=torch.int) + idx2 = torch.zeros(batchsize, m, dtype=torch.int) + + if not xyz1.is_cuda: + cd.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) + else: + dist1 = dist1.cuda() + dist2 = dist2.cuda() + idx1 = idx1.cuda() + idx2 = idx2.cuda() + cd.forward_cuda(xyz1, xyz2, dist1, dist2, idx1, idx2) + + ctx.save_for_backward(xyz1, xyz2, idx1, idx2) + + return dist1, dist2, idx1, idx2 + + @staticmethod + def backward(ctx, graddist1, graddist2, *args): + xyz1, xyz2, idx1, idx2 = ctx.saved_tensors + + graddist1 = graddist1.contiguous() + graddist2 = graddist2.contiguous() + + gradxyz1 = torch.zeros(xyz1.size()) + gradxyz2 = torch.zeros(xyz2.size()) + + if not graddist1.is_cuda: + cd.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2) + else: + gradxyz1 = gradxyz1.cuda() + gradxyz2 = gradxyz2.cuda() + cd.backward_cuda(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2) + + return gradxyz1, gradxyz2 + + +class ChamferDistance(torch.nn.Module): + def forward(self, xyz1, xyz2): + return ChamferDistanceFunction.apply(xyz1, xyz2) diff --git a/finetune_diffusion.sh b/finetune_diffusion.sh new file mode 100644 index 0000000000000000000000000000000000000000..c9f2a529af4b0a936578e816a41013aaa385ede9 --- /dev/null +++ b/finetune_diffusion.sh @@ -0,0 +1,18 @@ +cd scripts +CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' torchrun --master_port 15003 --nproc_per_node=8 \ +train_triplane_diffusion.py \ +--configs ../configs/finetune_triplane_diffusion.yaml \ +--accum_iter 2 \ +--output_dir ../output/finetune_dm/lowres_chair \ +--log_dir ../output/finetune_dm/lowres_chair --num_workers 8 \ +--batch_size 22 \ +--blr 1e-4 \ +--epochs 500 \ +--dist_eval \ +--warmup_epochs 20 \ +--ae-pth 
../output/ae/chair/best-checkpoint.pth \ +--category chair \ +--finetune \ +--finetune-pth ../output/dm/chair/best-checkpoint.pth \ +--data-pth ../data \ +--replica 5 \ No newline at end of file diff --git a/models/TriplaneVAE.py b/models/TriplaneVAE.py new file mode 100644 index 0000000000000000000000000000000000000000..477da1488d66914894b461ef46b5191da8f3c839 --- /dev/null +++ b/models/TriplaneVAE.py @@ -0,0 +1,94 @@ +import torch.nn as nn +import sys,os +sys.path.append("..") +import torch +from datasets import build_dataset +from configs.config_utils import CONFIG +from torch.utils.data import DataLoader +from models.modules import PointEmbed +from models.modules import ConvPointnet_Encoder,ConvPointnet_Decoder +import numpy as np + +class TriplaneVAE(nn.Module): + def __init__(self,opt): + super().__init__() + self.point_embedder=PointEmbed(hidden_dim=opt['point_emb_dim']) + + encoder_args=opt['encoder'] + decoder_args=opt['decoder'] + self.encoder=ConvPointnet_Encoder(c_dim=encoder_args['plane_latent_dim'],dim=opt['point_emb_dim'],latent_dim=encoder_args['latent_dim'], + plane_resolution=encoder_args['plane_reso'],unet_kwargs=encoder_args['unet'],unet=True,padding=opt['padding']) + self.decoder=ConvPointnet_Decoder(latent_dim=decoder_args['latent_dim'],query_emb_dim=decoder_args['query_emb_dim'], + hidden_dim=decoder_args['hidden_dim'],unet_kwargs=decoder_args['unet'],n_blocks=decoder_args['n_blocks'], + plane_resolution=decoder_args['plane_reso'],padding=opt['padding']) + + def forward(self,p,query): + ''' + :param p: surface points cloud of shape B,N,3 + :param query: sample points of shape B,N,3 + :return: + ''' + point_emb=self.point_embedder(p) + query_emb=self.point_embedder(query) + kl,plane_feat,means,logvars=self.encoder(p,point_emb) + if self.training: + if np.random.random()<0.5: + '''randomly sacle the triplane, and conduct triplane diffusion on 64x64x64 plane, promote robustness''' + plane_feat=torch.nn.functional.interpolate(plane_feat,scale_factor=0.5,mode="bilinear") + plane_feat=torch.nn.functional.interpolate(plane_feat,scale_factor=2,mode="bilinear") + # if self.training: + # if np.random.random()<0.5: + # means = torch.nn.functional.interpolate(means, scale_factor=0.5, mode="bilinear") + # vars=torch.exp(logvars) + # vars = torch.nn.functional.interpolate(vars, scale_factor=0.5, mode="bilinear") + # new_logvars=torch.log(vars) + # posterior = DiagonalGaussianDistribution(means, new_logvars) + # plane_feat=posterior.sample() + # plane_feat=torch.nn.functional.interpolate(plane_feat,scale_factor=2,mode='bilinear') + + # mean_scale=torch.nn.functional.interpolate(means, scale_factor=0.5, mode="bilinear") + # vars = torch.exp(logvars) + # vars_scale = torch.nn.functional.interpolate(vars, scale_factor=0.5, mode="bilinear")/4 + # logvars_scale=torch.log(vars_scale) + # scale_noise=torch.randn(mean_scale.shape).to(mean_scale.device) + # plane_feat_scale2=mean_scale+torch.exp(0.5*logvars_scale)*scale_noise + # plane_feat=torch.nn.functional.interpolate(plane_feat_scale2,scale_factor=2,mode='bilinear') + o=self.decoder(plane_feat,query,query_emb) + + return {'logits':o,'kl':kl} + + + def decode(self,plane_feature,query): + query_embedding=self.point_embedder(query) + o=self.decoder(plane_feature,query,query_embedding) + + return o + + def encode(self,p): + point_emb = self.point_embedder(p) + kl, plane_feat,mean,logvar = self.encoder(p, point_emb) + '''p is point cloud of B,N,3''' + return plane_feat,kl,mean,logvar + +if __name__=="__main__": + 
configs=CONFIG("../configs/train_triplane_vae_64.yaml") + config=configs.config + dataset_config=config['datasets'] + model_config=config["model"] + dataset=build_dataset("train",dataset_config) + dataset.__getitem__(0) + dataloader=DataLoader( + dataset=dataset, + batch_size=10, + shuffle=True, + num_workers=2, + ) + net=TriplaneVAE(model_config).float().cuda() + for idx,data_batch in enumerate(dataloader): + if idx==1: + break + surface=data_batch['surface'].float().cuda() + query=data_batch['points'].float().cuda() + net(surface,query) + + diff --git a/models/Triplane_Diffusion.py b/models/Triplane_Diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..84ae3a09321368c2fe1ffe4ebbc5065d53f4b73c --- /dev/null +++ b/models/Triplane_Diffusion.py @@ -0,0 +1,190 @@ +import torch +import torch.nn as nn +from models.modules.resunet import ResUnet_DirectAttenMultiImg_Cond +from models.modules.parpoints_encoder import ParPoint_Encoder +from models.modules.PointEMB import PointEmbed +from models.modules.utils import StackedRandomGenerator +from models.modules.diffusion_sampler import edm_sampler +from models.modules.encoder import DiagonalGaussianDistribution +import numpy as np +class EDMLoss_MultiImgCond: + def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5,use_par=False): + self.P_mean = P_mean + self.P_std = P_std + self.sigma_data = sigma_data + self.use_par=use_par + + def __call__(self, net, data_batch, classifier_free=False): + inputs = data_batch['input'] + image=data_batch['image'] + proj_mat=data_batch['proj_mat'] + valid_frames=data_batch['valid_frames'] + par_points=data_batch["par_points"] + category_code=data_batch["category_code"] + rnd_normal = torch.randn([inputs.shape[0], 1, 1, 1], device=inputs.device) + + sigma = (rnd_normal * self.P_std + self.P_mean).exp() #B,1,1,1 + weight = (sigma ** 2 + self.sigma_data ** 2) / (self.sigma_data * sigma) ** 2 + y=inputs + + n = torch.randn_like(y) * sigma + + # if classifier_free and np.random.random()<0.5: + # net.par_feat=torch.zeros((inputs.shape[0],32,inputs.shape[2],inputs.shape[3])).float().to(inputs.device) + if classifier_free and np.random.random()<0.5: + image=torch.zeros_like(image).float().cuda() + net.module.extract_img_feat(image) + net.module.set_proj_matrix(proj_mat) + net.module.set_valid_frames(valid_frames) + net.module.set_category_code(category_code) + if self.use_par: + net.module.extract_point_feat(par_points) + + D_yn = net(y + n,sigma) + loss = weight * ((D_yn - y) ** 2) + return loss + +class Triplane_Diff_MultiImgCond_EDM(nn.Module): + def __init__(self,opt): + super().__init__() + self.diff_reso=opt['diff_reso'] + self.diff_dim=opt['output_channel'] + self.use_cat_embedding=opt['use_cat_embedding'] + self.use_fp16=False + self.sigma_data=0.5 + self.sigma_max=float("inf") + self.sigma_min=0 + self.use_par=opt['use_par'] + self.triplane_padding=opt['triplane_padding'] + self.block_type=opt['block_type'] + #self.use_bn=opt['use_bn'] + if opt['backbone']=="resunet_multiimg_direct_atten": + self.denoise_model=ResUnet_DirectAttenMultiImg_Cond(channel=opt['input_channel'], + output_channel=opt['output_channel'],use_par=opt['use_par'],par_channel=opt['par_channel'], + img_in_channels=opt['img_in_channels'],vit_reso=opt['vit_reso'],triplane_padding=self.triplane_padding, + norm=opt['norm'],use_cat_embedding=self.use_cat_embedding,block_type=self.block_type) + else: + raise NotImplementedError + if opt['use_par']: #use partial point cloud as inputs + par_emb_dim = opt['par_emb_dim'] + 
par_args = opt['par_point_encoder'] + self.point_embedder = PointEmbed(hidden_dim=par_emb_dim) + self.par_points_encoder = ParPoint_Encoder(c_dim=par_args['plane_latent_dim'], dim=par_emb_dim, + plane_resolution=par_args['plane_reso'], + unet_kwargs=par_args['unet']) + self.unflatten = torch.nn.Unflatten(1, (16, 16)) + def prepare_data(self,data_batch): + #par_points = data_batch['par_points'].to(device, non_blocking=True) + device=torch.device("cuda") + means, logvars = data_batch['triplane_mean'].to(device, non_blocking=True), data_batch['triplane_logvar'].to( + device, non_blocking=True) + distribution = DiagonalGaussianDistribution(means, logvars) + plane_feat = distribution.sample() + + image=data_batch["image"].to(device) + proj_mat = data_batch['proj_mat'].to(device, non_blocking=True) + valid_frames=data_batch["valid_frames"].to(device,non_blocking=True) + par_points=data_batch["par_points"].to(device,non_blocking=True) + category_code=data_batch["category_code"].to(device,non_blocking=True) + input_dict = {"input": plane_feat.float(), + "image": image.float(), + "par_points":par_points.float(), + "proj_mat":proj_mat.float(), + "category_code":category_code.float(), + "valid_frames":valid_frames.float()} # TODO: add image and proj matrix + + return input_dict + + def prepare_sample_data(self,data_batch): + device=torch.device("cuda") + image=data_batch['image'].to(device, non_blocking=True) + proj_mat = data_batch['proj_mat'].to(device, non_blocking=True) + valid_frames = data_batch["valid_frames"].to(device, non_blocking=True) + par_points = data_batch["par_points"].to(device, non_blocking=True) + category_code=data_batch["category_code"].to(device,non_blocking=True) + sample_dict={ + "image":image.float(), + "proj_mat":proj_mat.float(), + "valid_frames":valid_frames.float(), + "category_code":category_code.float(), + "par_points":par_points.float(), + } + return sample_dict + + def prepare_eval_data(self,data_batch): + device=torch.device("cuda") + samples=data_batch["points"].to(device, non_blocking=True) + labels=data_batch['labels'].to(device,non_blocking=True) + + eval_dict={ + "samples":samples, + "labels":labels, + } + return eval_dict + + def extract_point_feat(self,par_points): + par_emb=self.point_embedder(par_points) + self.par_feat=self.par_points_encoder(par_points,par_emb) + + def extract_img_feat(self,image): + self.image_emb=image + + def set_proj_matrix(self,proj_matrix): + self.proj_matrix=proj_matrix + + def set_valid_frames(self,valid_frames): + self.valid_frames=valid_frames + + def set_category_code(self,category_code): + self.category_code=category_code + + def forward(self, x, sigma,force_fp32=False): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) #B,1,1,1 + dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32 + + c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt() + c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt() + c_noise = sigma.log() / 4 #B,1,1,1, need to check how to add embedding into unet + + if self.use_par: + F_x = self.denoise_model((c_in * x).to(dtype), c_noise.flatten(), self.image_emb, self.proj_matrix, + self.valid_frames,self.category_code,self.par_feat) + else: + F_x = self.denoise_model((c_in * x).to(dtype), c_noise.flatten(),self.image_emb,self.proj_matrix, + self.valid_frames,self.category_code) + assert F_x.dtype == dtype + D_x = c_skip * x + c_out * 
F_x.to(torch.float32) + return D_x + + def round_sigma(self, sigma): + return torch.as_tensor(sigma) + + @torch.no_grad() + def sample(self, input_batch, batch_seeds=None,ret_all=False,num_steps=18): + img_cond=input_batch['image'] + proj_mat=input_batch['proj_mat'] + valid_frames=input_batch["valid_frames"] + category_code=input_batch["category_code"] + if img_cond is not None: + batch_size, device = img_cond.shape[0], img_cond.device + if batch_seeds is None: + batch_seeds = torch.arange(batch_size) + else: + device = batch_seeds.device + batch_size = batch_seeds.shape[0] + + self.extract_img_feat(img_cond) + self.set_proj_matrix(proj_mat) + self.set_valid_frames(valid_frames) + self.set_category_code(category_code) + if self.use_par: + par_points=input_batch["par_points"] + self.extract_point_feat(par_points) + rnd = StackedRandomGenerator(device, batch_seeds) + latents = rnd.randn([batch_size, self.diff_dim, self.diff_reso*3,self.diff_reso], device=device) + + return edm_sampler(self, latents, randn_like=rnd.randn_like,ret_all=ret_all,sigma_min=0.002, sigma_max=80,num_steps=num_steps) + + diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b1af4a7b4034c52ade0ec2810620ada05e2c93e1 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,20 @@ +from .TriplaneVAE import TriplaneVAE +from .Triplane_Diffusion import Triplane_Diff_MultiImgCond_EDM +from .Triplane_Diffusion import EDMLoss_MultiImgCond +#from .Point_Diffusion_EDM import PointEDM,EDMLoss_PointAug + +def get_model(model_args): + if model_args['type']=="TriVAE": + model=TriplaneVAE(model_args) + elif model_args['type']=="triplane_diff_multiimg_cond": + model=Triplane_Diff_MultiImgCond_EDM(model_args) + else: + raise NotImplementedError + return model + +def get_criterion(cri_args): + if cri_args['type']=="EDMLoss_MultiImgCond": + criterion=EDMLoss_MultiImgCond(use_par=cri_args['use_par']) + else: + raise NotImplementedError + return criterion diff --git a/models/modules/PointEMB.py b/models/modules/PointEMB.py new file mode 100644 index 0000000000000000000000000000000000000000..0b5973155b8c1c30f5500ea28d09ca9fc7a6a30a --- /dev/null +++ b/models/modules/PointEMB.py @@ -0,0 +1,34 @@ +import torch.nn as nn +import torch +import numpy as np + +class PointEmbed(nn.Module): + def __init__(self, hidden_dim=48): + super().__init__() + + assert hidden_dim % 6 == 0 + + self.embedding_dim = hidden_dim + e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi + e = torch.stack([ + torch.cat([e, torch.zeros(self.embedding_dim // 6), + torch.zeros(self.embedding_dim // 6)]), + torch.cat([torch.zeros(self.embedding_dim // 6), e, + torch.zeros(self.embedding_dim // 6)]), + torch.cat([torch.zeros(self.embedding_dim // 6), + torch.zeros(self.embedding_dim // 6), e]), + ]) + self.register_buffer('basis', e) # 3 x 24 + + + @staticmethod + def embed(input, basis): + projections = torch.einsum( + 'bnd,de->bne', input, basis) # N,24 + embeddings = torch.cat([projections.sin(), projections.cos()], dim=2) + return embeddings + + def forward(self, input): + # input: B x N x 3 + embed = self.embed(input, self.basis) + return embed diff --git a/models/modules/Positional_Embedding.py b/models/modules/Positional_Embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..0f4722a52a4825543103868a70cbd7f5f3e5155e --- /dev/null +++ b/models/modules/Positional_Embedding.py @@ -0,0 +1,15 @@ +import torch +class PositionalEmbedding(torch.nn.Module): + def 
__init__(self, num_channels, max_positions=10000, endpoint=False): + super().__init__() + self.num_channels = num_channels + self.max_positions = max_positions + self.endpoint = endpoint + + def forward(self, x): + freqs = torch.arange(start=0, end=self.num_channels//2, dtype=torch.float32, device=x.device) + freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0)) + freqs = (1 / self.max_positions) ** freqs + x = x.ger(freqs.to(x.dtype)) + x = torch.cat([x.cos(), x.sin()], dim=1) + return x diff --git a/models/modules/__init__.py b/models/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..464c972374c82b95726821f69d22bff6a7b484fe --- /dev/null +++ b/models/modules/__init__.py @@ -0,0 +1,5 @@ +from .encoder import ConvPointnet_Encoder +from .resnet_block import ResnetBlockFC +from .unet import UNet,RollOut_Conv +from .PointEMB import PointEmbed +from .decoder import ConvPointnet_Decoder \ No newline at end of file diff --git a/models/modules/decoder.py b/models/modules/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c89db21a92a51dd22fcf49ee8b5fc03830afe76e --- /dev/null +++ b/models/modules/decoder.py @@ -0,0 +1,121 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_scatter import scatter_mean, scatter_max +from .unet import UNet +from .resnet_block import ResnetBlockFC +import numpy as np + +class ConvPointnet_Decoder(nn.Module): + ''' PointNet-based encoder network with ResNet blocks for each point. + Number of input points are fixed. + + Args: + c_dim (int): dimension of latent code c + dim (int): input points dimension + hidden_dim (int): hidden dimension of the network + scatter_type (str): feature aggregation when doing local pooling + unet (bool): weather to use U-Net + unet_kwargs (str): U-Net parameters + plane_resolution (int): defined resolution for plane feature + plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume + padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] + n_blocks (int): number of blocks ResNetBlockFC layers + ''' + + def __init__(self, latent_dim=32,query_emb_dim=51,hidden_dim=128, unet_kwargs=None, + plane_resolution=None, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5): + super().__init__() + + self.latent_dim=32 + self.actvn = nn.ReLU() + + self.unet = UNet(unet_kwargs['output_dim'], in_channels=latent_dim, **unet_kwargs) + + self.fc_c=nn.ModuleList + self.reso_plane = plane_resolution + self.plane_type = plane_type + self.padding = padding + self.n_blocks=n_blocks + + self.fc_c = nn.ModuleList([ + nn.Linear(latent_dim*3, hidden_dim) for i in range(n_blocks) + ]) + self.fc_p=nn.Linear(query_emb_dim,hidden_dim) + self.fc_out=nn.Linear(hidden_dim,1) + + self.blocks = nn.ModuleList([ + ResnetBlockFC(hidden_dim) for i in range(n_blocks) + ]) + + def forward(self, plane_features,query,query_emb): # , query2): + plane_feature=self.unet(plane_features) + H,W=plane_feature.shape[2:4] + xz_feat,xy_feat,yz_feat=torch.split(plane_feature,dim=2,split_size_or_sections=H//3) + xz_sample_feat=self.sample_plane_feature(query,xz_feat,'xz') + xy_sample_feat=self.sample_plane_feature(query,xy_feat,'xy') + yz_sample_feat=self.sample_plane_feature(query,yz_feat,'yz') + + sample_feat=torch.cat([xz_sample_feat,xy_sample_feat,yz_sample_feat],dim=1) + sample_feat=sample_feat.transpose(1,2) + + net=self.fc_p(query_emb) + for i in range(self.n_blocks): + 
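+            # each ResNet-FC block first adds the tri-plane features sampled at the query locations (projected by fc_c[i]), then applies its residual MLP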
net=net+self.fc_c[i](sample_feat) + net=self.blocks[i](net) + out=self.fc_out(self.actvn(net)).squeeze(-1) + return out + + + def normalize_coordinate(self, p, padding=0.1, plane='xz'): + ''' Normalize coordinate to [0, 1] for unit cube experiments + + Args: + p (tensor): point + padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] + plane (str): plane feature type, ['xz', 'xy', 'yz'] + ''' + if plane == 'xz': + xy = p[:, :, [0, 2]] + elif plane == 'xy': + xy = p[:, :, [0, 1]] + else: + xy = p[:, :, [1, 2]] + #print("origin",torch.amin(xy), torch.amax(xy)) + xy=xy/2 #xy is originally -1 ~ 1 + xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5) + xy_new = xy_new + 0.5 # range (0, 1) + #print("scale",torch.amin(xy_new),torch.amax(xy_new)) + + # f there are outliers out of the range + if xy_new.max() >= 1: + xy_new[xy_new >= 1] = 1 - 10e-6 + if xy_new.min() < 0: + xy_new[xy_new < 0] = 0.0 + return xy_new + + def coordinate2index(self, x, reso): + ''' Normalize coordinate to [0, 1] for unit cube experiments. + Corresponds to our 3D model + + Args: + x (tensor): coordinate + reso (int): defined resolution + coord_type (str): coordinate type + ''' + x = (x * reso).long() + index = x[:, :, 0] + reso * x[:, :, 1] + index = index[:, None, :] + return index + + # uses values from plane_feature and pixel locations from vgrid to interpolate feature + def sample_plane_feature(self, query, plane_feature, plane): + xy = self.normalize_coordinate(query.clone(), plane=plane, padding=self.padding) + xy = xy[:, :, None].float() + vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1) + sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, + mode='bilinear').squeeze(-1) + return sampled_feat + + + diff --git a/models/modules/diffusion_sampler.py b/models/modules/diffusion_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..a09f4705bc2a73b51635919088b7261720c33265 --- /dev/null +++ b/models/modules/diffusion_sampler.py @@ -0,0 +1,89 @@ +import torch +import numpy as np + +def edm_sampler( + net, latents, randn_like=torch.randn_like, + num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, + # S_churn=40, S_min=0.05, S_max=50, S_noise=1.003, + S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, ret_all=False +): + # Adjust noise levels based on what's supported by the network. + sigma_min = max(sigma_min, net.sigma_min) + sigma_max = min(sigma_max, net.sigma_max) + + # Time step discretization. + step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) + t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho + t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 + + # Main sampling loop. + x_next = latents.to(torch.float64) * t_steps[0] + all_x=[] + for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + x_cur = x_next + + # Increase noise temporarily. + gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 + t_hat = net.round_sigma(t_cur + gamma * t_cur) + x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) + + # Euler step. + denoised = net(x_hat, t_hat).to(torch.float64) + d_cur = (x_hat - denoised) / t_hat + x_next = x_hat + (t_next - t_hat) * d_cur + + # Apply 2nd order correction. 
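+        # Heun step: re-evaluate the denoiser at (x_next, t_next), average the two slopes d_cur and d_prime, and redo the update from x_hat; skipped on the last step where t_next == 0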
+ if i < num_steps - 1: + denoised = net(x_next, t_next).to(torch.float64) + d_prime = (x_next - denoised) / t_next + x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + all_x.append(x_next.clone()/(t_next**2+1).sqrt()) + + if ret_all: + return x_next,all_x + + return x_next + +def edm_sampler_cond( + net, latents,cond_points, randn_like=torch.randn_like, + num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, + # S_churn=40, S_min=0.05, S_max=50, S_noise=1.003, + S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, ret_all=False +): + # Adjust noise levels based on what's supported by the network. + sigma_min = max(sigma_min, net.sigma_min) + sigma_max = min(sigma_max, net.sigma_max) + + # Time step discretization. + step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) + t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho + t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 + + # Main sampling loop. + x_next = latents.to(torch.float64) * t_steps[0] + all_x=[] + for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + x_cur = x_next + + # Increase noise temporarily. + gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 + t_hat = net.round_sigma(t_cur + gamma * t_cur) + x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) + + # Euler step. + denoised = net(x_hat, t_hat,cond_points).to(torch.float64) + d_cur = (x_hat - denoised) / t_hat + x_next = x_hat + (t_next - t_hat) * d_cur + + # Apply 2nd order correction. + if i < num_steps - 1: + denoised = net(x_next, t_next,cond_points).to(torch.float64) + d_prime = (x_next - denoised) / t_next + x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + all_x.append(x_next.clone()/(t_next**2+1).sqrt()) + + if ret_all: + return x_next,all_x + + return x_next + diff --git a/models/modules/encoder.py b/models/modules/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1e73441d56356964c8258bbcfbf72329ff19d47a --- /dev/null +++ b/models/modules/encoder.py @@ -0,0 +1,235 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_scatter import scatter_mean, scatter_max +from .unet import UNet +from .resnet_block import ResnetBlockFC +import numpy as np + +class DiagonalGaussianDistribution(object): + def __init__(self, mean, logvar, deterministic=False): + self.mean = mean + self.logvar = logvar + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.mean.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.mean.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.mean(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2,3]) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 
2) / self.var, + dim=dims) + + def mode(self): + return self.mean + +class ConvPointnet_Encoder(nn.Module): + ''' PointNet-based encoder network with ResNet blocks for each point. + Number of input points are fixed. + + Args: + c_dim (int): dimension of latent code c + dim (int): input points dimension + hidden_dim (int): hidden dimension of the network + scatter_type (str): feature aggregation when doing local pooling + unet (bool): weather to use U-Net + unet_kwargs (str): U-Net parameters + plane_resolution (int): defined resolution for plane feature + plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume + padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] + n_blocks (int): number of blocks ResNetBlockFC layers + ''' + + def __init__(self, c_dim=128, dim=3, hidden_dim=128,latent_dim=32, scatter_type='max', + unet=False, unet_kwargs=None, + plane_resolution=None, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5): + super().__init__() + self.c_dim = c_dim + + self.fc_pos = nn.Linear(dim, 2 * hidden_dim) + self.blocks = nn.ModuleList([ + ResnetBlockFC(2 * hidden_dim, hidden_dim) for i in range(n_blocks) + ]) + self.fc_c = nn.Linear(hidden_dim, c_dim) + + self.actvn = nn.ReLU() + self.hidden_dim = hidden_dim + + if unet: + self.unet = UNet(unet_kwargs['output_dim'], in_channels=c_dim, **unet_kwargs) + else: + self.unet = None + + self.reso_plane = plane_resolution + self.plane_type = plane_type + self.padding = padding + + if scatter_type == 'max': + self.scatter = scatter_max + elif scatter_type == 'mean': + self.scatter = scatter_mean + + self.mean_fc = nn.Conv2d(unet_kwargs['output_dim'], latent_dim,kernel_size=1) + self.logvar_fc = nn.Conv2d(unet_kwargs['output_dim'], latent_dim,kernel_size=1) + + # takes in "p": point cloud and "query": sdf_xyz + # sample plane features for unlabeled_query as well + def forward(self, p,point_emb): # , query2): + batch_size, T, D = p.size() + #print('origin',torch.amin(p[0],dim=0),torch.amax(p[0],dim=0)) + # acquire the index for each point + coord = {} + index = {} + if 'xz' in self.plane_type: + coord['xz'] = self.normalize_coordinate(p.clone(), plane='xz', padding=self.padding) + index['xz'] = self.coordinate2index(coord['xz'], self.reso_plane) + if 'xy' in self.plane_type: + coord['xy'] = self.normalize_coordinate(p.clone(), plane='xy', padding=self.padding) + index['xy'] = self.coordinate2index(coord['xy'], self.reso_plane) + if 'yz' in self.plane_type: + coord['yz'] = self.normalize_coordinate(p.clone(), plane='yz', padding=self.padding) + index['yz'] = self.coordinate2index(coord['yz'], self.reso_plane) + net = self.fc_pos(point_emb) + + net = self.blocks[0](net) + for block in self.blocks[1:]: + pooled = self.pool_local(coord, index, net) + net = torch.cat([net, pooled], dim=2) + net = block(net) + + c = self.fc_c(net) + #print(c.shape) + + fea = {} + plane_feat_sum = 0 + # second_sum = 0 + if 'xz' in self.plane_type: + fea['xz'] = self.generate_plane_features(p, c, + plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 
16, 256, 64, 64) + if 'xy' in self.plane_type: + fea['xy'] = self.generate_plane_features(p, c, plane='xy') + if 'yz' in self.plane_type: + fea['yz'] = self.generate_plane_features(p, c, plane='yz') + cat_feature = torch.cat([fea['xz'], fea['xy'], fea['yz']], + dim=2) # concat at row dimension + #print(cat_feature.shape) + plane_feat=self.unet(cat_feature) + + mean=self.mean_fc(plane_feat) + logvar=self.logvar_fc(plane_feat) + + posterior = DiagonalGaussianDistribution(mean, logvar) + x = posterior.sample() + kl = posterior.kl() + + return kl, x, mean, logvar + + + def normalize_coordinate(self, p, padding=0.1, plane='xz'): + ''' Normalize coordinate to [0, 1] for unit cube experiments + + Args: + p (tensor): point + padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] + plane (str): plane feature type, ['xz', 'xy', 'yz'] + ''' + if plane == 'xz': + xy = p[:, :, [0, 2]] + elif plane == 'xy': + xy = p[:, :, [0, 1]] + else: + xy = p[:, :, [1, 2]] + #print("origin",torch.amin(xy), torch.amax(xy)) + xy=xy/2 #xy is originally -1 ~ 1 + xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5) + xy_new = xy_new + 0.5 # range (0, 1) + #print("scale",torch.amin(xy_new),torch.amax(xy_new)) + + # f there are outliers out of the range + if xy_new.max() >= 1: + xy_new[xy_new >= 1] = 1 - 10e-6 + if xy_new.min() < 0: + xy_new[xy_new < 0] = 0.0 + return xy_new + + def coordinate2index(self, x, reso): + ''' Normalize coordinate to [0, 1] for unit cube experiments. + Corresponds to our 3D model + + Args: + x (tensor): coordinate + reso (int): defined resolution + coord_type (str): coordinate type + ''' + x = (x * reso).long() + index = x[:, :, 0] + reso * x[:, :, 1] + index = index[:, None, :] + return index + + # xy is the normalized coordinates of the point cloud of each plane + # I'm pretty sure the keys of xy are the same as those of index, so xy isn't needed here as input + def pool_local(self, xy, index, c): + bs, fea_dim = c.size(0), c.size(2) + keys = xy.keys() + + c_out = 0 + for key in keys: + # scatter plane features from points + fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane ** 2) + if self.scatter == scatter_max: + fea = fea[0] + # gather feature back to points + fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1)) + c_out += fea + return c_out.permute(0, 2, 1) + + def generate_plane_features(self, p, c, plane='xz'): + # acquire indices of features in plane + xy = self.normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1) + index = self.coordinate2index(xy, self.reso_plane) + + # scatter plane features from points + fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane ** 2) + c = c.permute(0, 2, 1) # B x 512 x T + fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2 + fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, + self.reso_plane) # sparce matrix (B x 512 x reso x reso) + #print(fea_plane.shape) + + return fea_plane + + # sample_plane_feature function copied from /src/conv_onet/models/decoder.py + # uses values from plane_feature and pixel locations from vgrid to interpolate feature + def sample_plane_feature(self, query, plane_feature, plane): + xy = self.normalize_coordinate(query.clone(), plane=plane, padding=self.padding) + xy = xy[:, :, None].float() + vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1) + sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, + 
mode='bilinear').squeeze(-1) + return sampled_feat + + + diff --git a/models/modules/image_sampler.py b/models/modules/image_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..681403e165aa63b4bc4ff9dcdc20a77342fd8fc0 --- /dev/null +++ b/models/modules/image_sampler.py @@ -0,0 +1,1046 @@ +import sys +sys.path.append('../..') +import torch +import torch.nn as nn +import math +from models.modules.unet import RollOut_Conv +from einops import rearrange, reduce +MB =1024.0*1024.0 +def mask_kernel(x, sigma=1): + return torch.abs(x) < sigma #if the distance is smaller than the kernel size, return True + +def mask_kernel_close_false(x, sigma=1): + return torch.abs(x) > sigma #if the distance is smaller than the kernel size, return False + +class Image_Local_Sampler(nn.Module): + def __init__(self,reso,padding=0.1,in_channels=1280,out_channels=512): + super().__init__() + self.triplane_reso=reso + self.padding=padding + self.get_triplane_coord() + self.img_proj=nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=1) + def get_triplane_coord(self): + '''xz plane firstly, z is at the ''' + x=torch.arange(self.triplane_reso) + z=torch.arange(self.triplane_reso) + X,Z=torch.meshgrid(x,z,indexing='xy') + xz_coords=torch.cat([X[:,:,None],torch.ones_like(X[:,:,None])*(self.triplane_reso-1)/2,Z[:,:,None]],dim=-1) #in xyz order + + '''xy plane''' + x = torch.arange(self.triplane_reso) + y = torch.arange(self.triplane_reso) + X, Y = torch.meshgrid(x, y, indexing='xy') + xy_coords = torch.cat([X[:, :, None], Y[:, :, None],torch.ones_like(X[:, :, None])*(self.triplane_reso-1)/2], dim=-1) # in xyz order + + '''yz plane''' + y = torch.arange(self.triplane_reso) + z = torch.arange(self.triplane_reso) + Y,Z = torch.meshgrid(y,z,indexing='xy') + yz_coords= torch.cat([torch.ones_like(Y[:, :, None])*(self.triplane_reso-1)/2,Y[:,:,None],Z[:,:,None]], dim=-1) + + triplane_coords=torch.cat([xz_coords,xy_coords,yz_coords],dim=0) + triplane_coords=triplane_coords/(self.triplane_reso-1) + triplane_coords=(triplane_coords-0.5)*2*(1 + self.padding + 10e-6) + self.triplane_coords=triplane_coords.float().cuda() + + def forward(self,image_feat,proj_mat): + image_feat=self.img_proj(image_feat) + batch_size=image_feat.shape[0] + triplane_coords=self.triplane_coords.unsqueeze(0).expand(batch_size,-1,-1,-1) #B,192,64,3 + #print(torch.amin(triplane_coords),torch.amax(triplane_coords)) + coord_homo=torch.cat([triplane_coords,torch.ones((batch_size,triplane_coords.shape[1],triplane_coords.shape[2],1)).float().cuda()],dim=-1) + coord_inimg=torch.einsum('bhwc,bck->bhwk',coord_homo,proj_mat.transpose(1,2)) + x=coord_inimg[:,:,:,0]/coord_inimg[:,:,:,2] + y=coord_inimg[:,:,:,1]/coord_inimg[:,:,:,2] + x=(x/(224.0-1.0)-0.5)*2 #-1~1 + y=(y/(224.0-1.0)-0.5)*2 #-1~1 + dist=coord_inimg[:,:,:,2] + + xy=torch.cat([x[:,:,:,None],y[:,:,:,None]],dim=-1) + #print(image_feat.shape,xy.shape) + sample_feat=torch.nn.functional.grid_sample(image_feat,xy,align_corners=True,mode='bilinear') + return sample_feat + +def position_encoding(d_model, length): + if d_model % 2 != 0: + raise ValueError("Cannot use sin/cos positional encoding with " + "odd dim (got dim={:d})".format(d_model)) + pe = torch.zeros(length, d_model) + position = torch.arange(0, length).unsqueeze(1) #length,1 + div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) * + -(math.log(10000.0) / d_model))) #d_model//2, this is the frequency + pe[:, 0::2] = torch.sin(position.float() * div_term) #length*(d_model//2) + pe[:, 1::2] = 
torch.cos(position.float() * div_term) + + return pe + +class Image_Vox_Local_Sampler(nn.Module): + def __init__(self,reso,padding=0.1,in_channels=1280,inner_channel=128,out_channels=64,n_heads=8): + super().__init__() + self.triplane_reso=reso + self.padding=padding + self.get_vox_coord() + self.out_channels=out_channels + self.img_proj=nn.Conv2d(in_channels=in_channels,out_channels=inner_channel,kernel_size=1) + + self.vox_process=nn.Sequential( + nn.Conv3d(in_channels=inner_channel,out_channels=inner_channel,kernel_size=3,padding=1,), + ) + self.k=nn.Linear(in_features=inner_channel,out_features=inner_channel) + self.q=nn.Linear(in_features=inner_channel,out_features=inner_channel) + self.v=nn.Linear(in_features=inner_channel,out_features=inner_channel) + self.attn = torch.nn.MultiheadAttention( + embed_dim=inner_channel, num_heads=n_heads, batch_first=True) + + self.proj_out=nn.Conv2d(in_channels=inner_channel,out_channels=out_channels,kernel_size=1) + self.condition_pe = position_encoding(inner_channel, self.triplane_reso).unsqueeze(0) + def get_vox_coord(self): + x = torch.arange(self.triplane_reso) + y = torch.arange(self.triplane_reso) + z = torch.arange(self.triplane_reso) + + X,Y,Z=torch.meshgrid(x,y,z,indexing='ij') + vox_coor=torch.cat([X[:,:,:,None],Y[:,:,:,None],Z[:,:,:,None]],dim=-1) + vox_coor=vox_coor/(self.triplane_reso-1) + vox_coor=(vox_coor-0.5)*2*(1+self.padding+10e-6) + self.vox_coor=vox_coor.view(-1,3).float().cuda() + + + def forward(self,triplane_feat,image_feat,proj_mat): + xz_feat,xy_feat,yz_feat=torch.split(triplane_feat,triplane_feat.shape[2]//3,dim=2) #B,C,64,64 + image_feat=self.img_proj(image_feat) + batch_size=image_feat.shape[0] + vox_coords=self.vox_coor.unsqueeze(0).expand(batch_size,-1,-1) #B,64*64*64,3 + vox_homo=torch.cat([vox_coords,torch.ones((batch_size,self.triplane_reso**3,1)).float().cuda()],dim=-1) + coord_inimg=torch.einsum('bhc,bck->bhk',vox_homo,proj_mat.transpose(1,2)) + x=coord_inimg[:,:,0]/coord_inimg[:,:,2] + y=coord_inimg[:,:,1]/coord_inimg[:,:,2] + x=(x/(224.0-1.0)-0.5)*2 #-1~1 + y=(y/(224.0-1.0)-0.5)*2 #-1~1 + + xy=torch.cat([x[:,:,None],y[:,:,None]],dim=-1).unsqueeze(1).contiguous() #B, 1,64**3,2 + #print(image_feat.shape,xy.shape) + grid_feat=torch.nn.functional.grid_sample(image_feat,xy,align_corners=True,mode='bilinear').squeeze(2).\ + view(batch_size,-1,self.triplane_reso,self.triplane_reso,self.triplane_reso) #B,C,1,64**3 + + grid_feat=self.vox_process(grid_feat) + xzy_grid=grid_feat.permute(0,4,2,3,1) + xz_as_query=xz_feat.permute(0,2,3,1).reshape(batch_size*self.triplane_reso**2,1,-1) + xz_as_key=xzy_grid.reshape(batch_size*self.triplane_reso**2,self.triplane_reso,-1) + + xyz_grid=grid_feat.permute(0,3,2,4,1) + xy_as_query=xy_feat.permute(0,2,3,1).reshape(batch_size*self.triplane_reso**2,1,-1) + xy_as_key = xyz_grid.reshape(batch_size * self.triplane_reso ** 2, self.triplane_reso, -1) + + yzx_grid = grid_feat.permute(0, 4, 3, 2, 1) + yz_as_query = yz_feat.permute(0,2,3,1).reshape(batch_size*self.triplane_reso**2,1,-1) + yz_as_key = yzx_grid.reshape(batch_size * self.triplane_reso ** 2, self.triplane_reso, -1) + + query=self.q(torch.cat([xz_as_query,xy_as_query,yz_as_query],dim=0)) + key=self.k(torch.cat([xz_as_key,xy_as_key,yz_as_key],dim=0))+self.condition_pe.to(xz_as_key.device) + value=self.v(torch.cat([xz_as_key,xy_as_key,yz_as_key],dim=0))+self.condition_pe.to(xz_as_key.device) + + attn,_=self.attn(query,key,value) + 
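+        # the xz/xy/yz queries were concatenated along the batch axis, so split the attention output back into the three planes and restore the B,C,reso,reso layout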
xz_plane,xy_plane,yz_plane=torch.split(attn,dim=0,split_size_or_sections=batch_size*self.triplane_reso**2) + xz_plane=xz_plane.reshape(batch_size,self.triplane_reso,self.triplane_reso,-1).permute(0,3,1,2) + xy_plane = xy_plane.reshape(batch_size, self.triplane_reso, self.triplane_reso, -1).permute(0, 3, 1, 2) + yz_plane = yz_plane.reshape(batch_size, self.triplane_reso, self.triplane_reso, -1).permute(0, 3, 1, 2) + + triplane_wImg=torch.cat([xz_plane,xy_plane,yz_plane],dim=2) + triplane_wImg=self.proj_out(triplane_wImg) + #print(triplane_wImg.shape) + + return triplane_wImg + +class Image_Direct_AttenwMask_Sampler(nn.Module): + def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64, + img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8): + super().__init__() + self.triplane_reso=reso + self.vit_reso=vit_reso + self.padding=padding + self.n_heads=n_heads + self.get_plane_expand_coord() + self.get_vit_coords() + self.out_channels=out_channels + self.kernel_func=mask_kernel + self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel) + self.q=nn.Linear(in_features=triplane_in_channels,out_features=inner_channel) + self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel) + self.attn = torch.nn.MultiheadAttention( + embed_dim=inner_channel, num_heads=n_heads, batch_first=True) + + self.proj_out=nn.Linear(in_features=inner_channel,out_features=out_channels) + self.image_pe = position_encoding(inner_channel, self.vit_reso**2+1).unsqueeze(0).cuda().float() #1,n_img*reso*reso,channel + self.triplane_pe = position_encoding(inner_channel, 3*self.triplane_reso**2).unsqueeze(0).cuda().float() + def get_plane_expand_coord(self): + x = torch.arange(self.triplane_reso)/(self.triplane_reso-1) + y = torch.arange(self.triplane_reso)/(self.triplane_reso-1) + z = torch.arange(self.triplane_reso)/(self.triplane_reso-1) + + first,second,third=torch.meshgrid(x,y,z,indexing='xy') + xyz_coords=torch.stack([first,second,third],dim=-1)#reso,reso,reso,3 + xyz_coords=(xyz_coords-0.5)*2*(1+self.padding+10e-6) #ordering yxz ->xyz + xzy_coords=xyz_coords.clone().permute(2,1,0,3) #ordering zxy ->xzy + yzx_coords=xyz_coords.clone().permute(2,0,1,3) #ordering zyx ->yzx + + # print(xyz_coords[0,0,0],xyz_coords[0,0,1],xyz_coords[1,0,0],xyz_coords[0,1,0]) + # print(xzy_coords[0, 0, 0], xzy_coords[0, 0, 1], xzy_coords[1, 0, 0], xzy_coords[0, 1, 0]) + # print(yzx_coords[0, 0, 0], yzx_coords[0, 0, 1], yzx_coords[1, 0, 0], yzx_coords[0, 1, 0]) + + xyz_coords=xyz_coords.reshape(self.triplane_reso**3,-1) + xzy_coords=xzy_coords.reshape(self.triplane_reso**3,-1) + yzx_coords=yzx_coords.reshape(self.triplane_reso**3,-1) + + coords=torch.cat([xzy_coords,xyz_coords,yzx_coords],dim=0) + self.plane_coords=coords.cuda().float() + # self.xzy_coords=xzy_coords.cuda().float() #reso**3,3 + # self.xyz_coords=xyz_coords.cuda().float() #reso**3,3 + # self.yzx_coords=yzx_coords.cuda().float() #reso**3,3 + + def get_vit_coords(self): + x=torch.arange(self.vit_reso) + y=torch.arange(self.vit_reso) + + X,Y=torch.meshgrid(x,y,indexing='xy') + vit_coords=torch.stack([X,Y],dim=-1) + self.vit_coords=vit_coords.view(self.vit_reso**2,2).cuda().float() + + def get_attn_mask(self,coords_proj,vit_coords,kernel_size=1.0): + ''' + :param coords_proj: B,reso**3,2, in range of 0~1 + :param vit_coords: B,vit_reso**2,2, in range of 0~vit_reso + :param kernel_size: 0.5, so that only one pixel will be select + :return: + ''' + bs=coords_proj.shape[0] + coords_proj=coords_proj*(self.vit_reso-1) + 
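+        # pairwise distances between projected triplane samples and ViT patch coords; patches within kernel_size of any sample along a cell's depth line remain unmasked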
#print(torch.amin(coords_proj[0,0:self.triplane_reso**3]),torch.amax(coords_proj[0,0:self.triplane_reso**3])) + dist=torch.cdist(coords_proj.float(),vit_coords.float()) + mask=self.kernel_func(dist,sigma=kernel_size).float() #True if valid, B,3*reso**3,vit_reso**2 + mask=mask.reshape(bs,3*self.triplane_reso**2,self.triplane_reso,self.vit_reso**2) + mask=torch.sum(mask,dim=2) + attn_mask=(mask==0) + return attn_mask + + def forward(self,triplane_feat,image_feat,proj_mat): + #xz_feat,xy_feat,yz_feat=torch.split(triplane_feat,triplane_feat.shape[2]//3,dim=2) #B,C,64,64 + batch_size=image_feat.shape[0] + #print(self.plane_coords.shape) + coords=self.plane_coords.unsqueeze(0).expand(batch_size,-1,-1) + + coords_homo=torch.cat([coords,torch.ones(batch_size,self.triplane_reso**3*3,1).float().cuda()],dim=-1) + coords_inimg=torch.einsum('bhc,bck->bhk',coords_homo,proj_mat.transpose(1,2)) + coords_x=coords_inimg[:,:,0]/coords_inimg[:,:,2]/(224.0-1) #0~1 + coords_y=coords_inimg[:,:,1]/coords_inimg[:,:,2]/(224.0-1) #0~1 + coords_x=torch.clamp(coords_x,min=0.0,max=1.0) + coords_y=torch.clamp(coords_y,min=0.0,max=1.0) + #print(torch.amin(coords_x),torch.amax(coords_x)) + coords_proj=torch.stack([coords_x,coords_y],dim=-1) + vit_coords=self.vit_coords.unsqueeze(0).expand(batch_size,-1,-1) + attn_mask=torch.repeat_interleave( + self.get_attn_mask(coords_proj,vit_coords,kernel_size=1.0),self.n_heads, 0 + ) + attn_mask = torch.cat([torch.zeros([attn_mask.shape[0], attn_mask.shape[1], 1]).cuda().bool(), attn_mask], + dim=-1) # add global token + #print(attn_mask.shape,torch.sum(attn_mask.float())) + triplane_feat=triplane_feat.permute(0,2,3,1).view(batch_size,3*self.triplane_reso**2,-1) + #print(triplane_feat.shape,self.triplane_pe.shape) + query=self.q(triplane_feat)+self.triplane_pe + key=self.k(image_feat)+self.image_pe + value=self.v(image_feat)+self.image_pe + #print(query.shape,key.shape,value.shape) + attn,_=self.attn(query,key,value,attn_mask=attn_mask) + #print(attn.shape) + output=self.proj_out(attn).transpose(1,2).reshape(batch_size,-1,3*self.triplane_reso,self.triplane_reso) + + return output + +class MultiImage_Direct_AttenwMask_Sampler(nn.Module): + def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64, + img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8,max_nimg=5): + super().__init__() + self.triplane_reso=reso + self.vit_reso=vit_reso + self.padding=padding + self.n_heads=n_heads + self.get_plane_expand_coord() + self.get_vit_coords() + self.out_channels=out_channels + self.kernel_func=mask_kernel + self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel) + self.q=nn.Linear(in_features=triplane_in_channels,out_features=inner_channel) + self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel) + self.attn = torch.nn.MultiheadAttention( + embed_dim=inner_channel, num_heads=n_heads, batch_first=True) + + self.proj_out=nn.Linear(in_features=inner_channel,out_features=out_channels) + self.image_pe = position_encoding(inner_channel, max_nimg*(self.vit_reso**2+1)).unsqueeze(0).cuda().float() + self.triplane_pe = position_encoding(inner_channel, 3*self.triplane_reso**2).unsqueeze(0).cuda().float() + def get_plane_expand_coord(self): + x = torch.arange(self.triplane_reso)/(self.triplane_reso-1) + y = torch.arange(self.triplane_reso)/(self.triplane_reso-1) + z = torch.arange(self.triplane_reso)/(self.triplane_reso-1) + + first,second,third=torch.meshgrid(x,y,z,indexing='xy') + 
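+        # stack the grid into reso**3 3D points, then permute copies so each plane (xz, xy, yz) gets its coordinates ordered along its own depth axis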
xyz_coords=torch.stack([first,second,third],dim=-1)#reso,reso,reso,3 + xyz_coords=(xyz_coords-0.5)*2*(1+self.padding+10e-6) #ordering yxz ->xyz + xzy_coords=xyz_coords.clone().permute(2,1,0,3) #ordering zxy ->xzy + yzx_coords=xyz_coords.clone().permute(2,0,1,3) #ordering zyx ->yzx + + xyz_coords=xyz_coords.reshape(self.triplane_reso**3,-1) + xzy_coords=xzy_coords.reshape(self.triplane_reso**3,-1) + yzx_coords=yzx_coords.reshape(self.triplane_reso**3,-1) + + coords=torch.cat([xzy_coords,xyz_coords,yzx_coords],dim=0) + self.plane_coords=coords.cuda().float() + # self.xzy_coords=xzy_coords.cuda().float() #reso**3,3 + # self.xyz_coords=xyz_coords.cuda().float() #reso**3,3 + # self.yzx_coords=yzx_coords.cuda().float() #reso**3,3 + + def get_vit_coords(self): + x=torch.arange(self.vit_reso) + y=torch.arange(self.vit_reso) + + X,Y=torch.meshgrid(x,y,indexing='xy') + vit_coords=torch.stack([X,Y],dim=-1) + self.vit_coords=vit_coords.view(self.vit_reso**2,2).cuda().float() + + def get_attn_mask(self,coords_proj,vit_coords,valid_frames,kernel_size=1.0): + ''' + :param coords_proj: B,n_img,3*reso**3,2, in range of 0~vit_reso + :param vit_coords: B,n_img,vit_reso**2,2, in range of 0~vit_reso + :param kernel_size: 0.5, so that only one pixel will be select + :return: + ''' + bs,n_img=coords_proj.shape[0],coords_proj.shape[1] + coords_proj_flat=coords_proj.reshape(bs*n_img,3*self.triplane_reso**3,2) + vit_coords_flat=vit_coords.reshape(bs*n_img,self.vit_reso**2,2) + dist=torch.cdist(coords_proj_flat.float(),vit_coords_flat.float()) + mask=self.kernel_func(dist,sigma=kernel_size).float() #True if valid, B*n_img,3*reso**3,vit_reso**2 + mask=mask.reshape(bs,n_img,3*self.triplane_reso**2,self.triplane_reso,self.vit_reso**2) + mask=torch.sum(mask,dim=3) #B,n_img,3*reso**2,vit_reso**2 + mask=torch.cat([torch.ones(size=mask.shape[0:3]).unsqueeze(3).float().cuda(),mask],dim=-1) #B,n_img,3*reso**2,vit_reso**2+1, add global mask + mask[valid_frames == 0, :, :] = False + mask=mask.permute(0,2,1,3).reshape(bs,3*self.triplane_reso**2,-1) #B,3*reso**2,n_img*(vit_resso**2+1) + attn_mask=(mask==0) #invert the mask, False indicates valid, True indicates invalid + return attn_mask + + def forward(self,triplane_feat,image_feat,proj_mat,valid_frames): + '''image feat is bs,n_img,length,channel''' + batch_size,n_img=image_feat.shape[0],image_feat.shape[1] + img_length=image_feat.shape[2] + image_feat_flat=image_feat.view(batch_size,n_img*img_length,-1) + coords=self.plane_coords.unsqueeze(0).unsqueeze(1).expand(batch_size,n_img,-1,-1) + + coord_homo=torch.cat([coords,torch.ones(batch_size,n_img,self.triplane_reso**3*3,1).float().cuda()],dim=-1) + #print(coord_homo.shape,proj_mat.shape) + coord_inimg = torch.einsum('bjnc,bjck->bjnk', coord_homo, proj_mat.transpose(2, 3)) + x = coord_inimg[:, :, :, 0] / coord_inimg[:, :, :, 2] + y = coord_inimg[:, :, :, 1] / coord_inimg[:, :, :, 2] + x = x/(224.0-1) + y = y/(224.0-1) + coords_x=torch.clamp(x,min=0.0,max=1.0)*(self.vit_reso-1) + coords_y=torch.clamp(y,min=0.0,max=1.0)*(self.vit_reso-1) + coords_proj=torch.stack([coords_x,coords_y],dim=-1) + vit_coords=self.vit_coords.unsqueeze(0).unsqueeze(1).expand(batch_size,n_img,-1,-1) + attn_mask=torch.repeat_interleave( + self.get_attn_mask(coords_proj,vit_coords,valid_frames,kernel_size=1.0),self.n_heads, 0 + ) + triplane_feat=triplane_feat.permute(0,2,3,1).view(batch_size,3*self.triplane_reso**2,-1) + query=self.q(triplane_feat)+self.triplane_pe + key=self.k(image_feat_flat)+self.image_pe + value=self.v(image_feat_flat)+self.image_pe + 
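+        # masked cross-attention: every triplane cell attends only to the ViT patch tokens it projects onto in each valid view (plus that view's global token)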
attn,_=self.attn(query,key,value,attn_mask=attn_mask) + output=self.proj_out(attn).transpose(1,2).reshape(batch_size,-1,3*self.triplane_reso,self.triplane_reso) + + return output + +class MultiImage_Fuse_Sampler(nn.Module): + def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64, + img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8): + super().__init__() + self.triplane_reso=reso + self.vit_reso=vit_reso + self.inner_channel=inner_channel + self.padding=padding + self.n_heads=n_heads + self.get_vox_coord() + self.get_vit_coords() + self.out_channels=out_channels + self.kernel_func=mask_kernel + self.image_unflatten=nn.Unflatten(2,(vit_reso,vit_reso)) + self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel) + self.q=nn.Linear(in_features=triplane_in_channels*3,out_features=inner_channel) + self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel) + + #self.cross_attn=CrossAttention(query_dim=inner_channel,heads=8,dim_head=inner_channel//8) + self.cross_attn = torch.nn.MultiheadAttention( + embed_dim=inner_channel, num_heads=n_heads, batch_first=True) + self.proj_out=nn.Linear(in_features=inner_channel,out_features=out_channels) + self.image_pe = position_encoding(inner_channel, self.vit_reso**2)[None,None,:,:].cuda().float() #1,1,length,channel + #self.image_pe = self.image_pe.reshape(1,max_nimg,self.vit_reso,self.vit_reso,inner_channel) + self.triplane_pe = position_encoding(inner_channel, self.triplane_reso ** 3).unsqueeze(0).cuda().float() + + def get_vit_coords(self): + x = torch.arange(self.vit_reso) + y = torch.arange(self.vit_reso) + + X, Y = torch.meshgrid(x, y, indexing='xy') + vit_coords = torch.stack([X, Y], dim=-1) + self.vit_coords = vit_coords.cuda().float() #reso,reso,2 + + def get_vox_coord(self): + x = torch.arange(self.triplane_reso) + y = torch.arange(self.triplane_reso) + z = torch.arange(self.triplane_reso) + + X, Y, Z = torch.meshgrid(x, y, z, indexing='ij') + vox_coor = torch.cat([X[:, :, :, None], Y[:, :, :, None], Z[:, :, :, None]], dim=-1) + self.vox_index = vox_coor.view(-1, 3).long().cuda() + + vox_coor = self.vox_index.float() / (self.triplane_reso - 1) + vox_coor = (vox_coor - 0.5) * 2 * (1 + self.padding + 10e-6) + self.vox_coor = vox_coor.view(-1, 3).float().cuda() + + def get_attn_mask(self,valid_frames): + ''' + :param valid_frames: of shape B,n_img + ''' + #print(valid_frames) + #bs,n_img=valid_frames.shape[0:2] + attn_mask=(valid_frames.float()==0) + #attn_mask=attn_mask.unsqueeze(1).unsqueeze(2).expand(-1,self.triplane_reso**3,-1,-1) #B,1,n_img + #attn_mask=attn_mask.reshape(bs*self.triplane_reso**3,-1,n_img).bool() + attn_mask=torch.repeat_interleave(attn_mask.unsqueeze(1),self.triplane_reso**3,0) + # print(attn_mask[self.triplane_reso**3*1+10]) + # print(attn_mask[self.triplane_reso ** 3 * 2+10]) + # print(attn_mask[self.triplane_reso ** 3 * 3+10]) + return attn_mask + + def forward(self,triplane_feat,image_feat,proj_mat,valid_frames): + '''image feat is bs,n_img,length,channel''' + batch_size,n_img=image_feat.shape[0],image_feat.shape[1] + image_feat=image_feat[:,:,1:,:] #discard global feature + + #image_feat=image_feat.permute(0,1,3,4,2) #B,n_img,h,w,c + image_k=self.k(image_feat)+self.image_pe #B,n_img,h,w,c + image_v=self.v(image_feat)+self.image_pe #B,n_img,h,w,c + image_k_v=torch.cat([image_k,image_v],dim=-1) #B,n_img,h,w,c + unflat_k_v=self.image_unflatten(image_k_v).permute(0,4,1,2,3) #Bs,channel,n_img,reso,reso + #unflat_k_v=image_k_v.permute(0,4,1,2,3) + 
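+        # the concatenated keys/values are unflattened to (B, 2*C, n_img, vit_reso, vit_reso) so grid_sample can gather one key/value per view for every voxel below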
#vit_coords=self.vit_coords[None,None].expand(batch_size,n_img,-1,-1,-1) #Bs,n_img,reso,reso,2 + + coords=self.vox_coor.unsqueeze(0).unsqueeze(1).expand(batch_size,n_img,-1,-1) + coord_homo=torch.cat([coords,torch.ones(batch_size,n_img,self.triplane_reso**3,1).float().cuda()],dim=-1) + coord_inimg = torch.einsum('bjnc,bjck->bjnk', coord_homo, proj_mat.transpose(2, 3)) + x = coord_inimg[:, :, :, 0] / coord_inimg[:, :, :, 2] + y = coord_inimg[:, :, :, 1] / coord_inimg[:, :, :, 2] + x = x/(224.0-1) #0~1 + y = y/(224.0-1) + coords_proj=torch.stack([x,y],dim=-1) + coords_proj=(coords_proj-0.5)*2 + img_index=((torch.arange(n_img)[None,:,None,None].expand( + batch_size,-1,self.triplane_reso**3,-1).float().cuda()/(n_img-1))-0.5)*2 #Bs,n_img,64**3,1 + + # img_index_feat=torch.arange(n_img)[None,:,None,None,None].expand( + # batch_size,-1,self.vit_reso,self.vit_reso,-1).float().cuda() #Bs,n_img,reso,reso,1 + #coords_feat=torch.cat([vit_coords,img_index_feat],dim=-1).permute(0,4,1,2,3)#Bs,n_img,reso,reso,3 + grid=torch.cat([coords_proj,img_index],dim=-1) #x,y,index + grid=torch.clamp(grid,min=-1.0,max=1.0) + sample_k_v = torch.nn.functional.grid_sample(unflat_k_v, grid.unsqueeze(1), align_corners=True, mode='bilinear').squeeze(2) #B,C,n_img,64**3 + xz_feat, xy_feat, yz_feat = torch.split(triplane_feat, split_size_or_sections=triplane_feat.shape[2] // 3, + dim=2) # B,C,64,64 + xz_vox_feat=xz_feat.unsqueeze(4).expand(-1,-1,-1,-1,self.triplane_reso)#.permute(0,1,3,4,2).reshape(batch_size,-1,self.triplane_reso**3).transpose(1,2) #zxy + xz_vox_feat=rearrange(xz_vox_feat, 'b c z x y -> b (x y z) c') + xy_vox_feat=xy_feat.unsqueeze(4).expand(-1,-1,-1,-1,self.triplane_reso)#.permute(0,1,3,2,4).reshape(batch_size,-1,self.triplane_reso**3).transpose(1,2) #yxz + xy_vox_feat=rearrange(xy_vox_feat, 'b c y x z -> b (x y z) c') + yz_vox_feat=yz_feat.unsqueeze(4).expand(-1,-1,-1,-1,self.triplane_reso)#.permute(0,1,4,3,2).reshape(batch_size,-1,self.triplane_reso**3).transpose(1,2) #zyx + yz_vox_feat=rearrange(yz_vox_feat, 'b c z y x -> b (x y z) c') + #xz_vox_feat = xz_feat[:, :, vox_index[:, 2], vox_index[:, 0]].transpose(1, 2) # B,C,64*64*64 + #xy_vox_feat = xy_feat[:, :, vox_index[:, 1], vox_index[:, 0]].transpose(1, 2) + #yz_vox_feat = yz_feat[:, :, vox_index[:, 2], vox_index[:, 1]].transpose(1, 2) + + triplane_expand_feat = torch.cat([xz_vox_feat, xy_vox_feat, yz_vox_feat], dim=-1) # B,64*64*64,3*C + triplane_query = self.q(triplane_expand_feat) + self.triplane_pe + k_v=rearrange(sample_k_v, 'b c n k -> (b k) n c') + #k_v=sample_k_v.permute(0,3,2,1).reshape(batch_size*self.triplane_reso**3,n_img,-1) #B*64**3,n_img,C + k=k_v[:,:,0:self.inner_channel] + v=k_v[:,:,self.inner_channel:] + q=rearrange(triplane_query,'b k c -> (b k) 1 c') + #q=triplane_query.view(batch_size*self.triplane_reso**3,1,-1) + #k,v is of shape, B*reso**3,k,channel, q is of shape B*reso**3,1,channel + #attn mask should be B*reso**3*n_heads,1,k + #attn_mask=torch.repeat_interleave(self.get_attn_mask(valid_frames),self.n_heads,0) + #print(q.shape,k.shape,v.shape) + attn_out,_=self.cross_attn(q,k,v)#attn_mask=attn_mask) #fuse multi-view feature + #volume=attn_out.view(batch_size,self.triplane_reso,self.triplane_reso,self.triplane_reso,-1) #B,reso,reso,reso,channel + #print(attn_out.shape) + volume=rearrange(attn_out,'(b x y z) 1 c -> b x y z c',x=self.triplane_reso,y=self.triplane_reso,z=self.triplane_reso) + #xz_feat = torch.mean(volume, dim=2).transpose(1,2) #B,reso,reso,C + xz_feat = reduce(volume, "b x y z c -> b z x c", 'mean') + #xy_feat = 
torch.mean(volume, dim=3).transpose(1,2) #B,reso,reso,C + xy_feat= reduce(volume, 'b x y z c -> b y x c', 'mean') + #yz_feat = torch.mean(volume, dim=1).transpose(1,2) #B,reso,reso,C + yz_feat=reduce(volume, 'b x y z c -> b z y c', 'mean') + triplane_out = torch.cat([xz_feat, xy_feat, yz_feat], dim=1) #B,reso*3,reso,C + #print(triplane_out.shape) + triplane_out = self.proj_out(triplane_out) + triplane_out = triplane_out.permute(0,3,1,2) + #print(triplane_out.shape) + return triplane_out + +class MultiImage_TriFuse_Sampler(nn.Module): + def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64, + img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8,max_nimg=5): + super().__init__() + self.triplane_reso=reso + self.vit_reso=vit_reso + self.inner_channel=inner_channel + self.padding=padding + self.n_heads=n_heads + self.get_triplane_coord() + self.get_vit_coords() + self.out_channels=out_channels + self.kernel_func=mask_kernel + self.image_unflatten=nn.Unflatten(2,(vit_reso,vit_reso)) + self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel) + self.q=nn.Linear(in_features=triplane_in_channels,out_features=inner_channel) + self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel) + + self.cross_attn = torch.nn.MultiheadAttention( + embed_dim=inner_channel, num_heads=n_heads, batch_first=True) + self.proj_out=nn.Conv2d(in_channels=inner_channel,out_channels=out_channels,kernel_size=1) + self.image_pe = position_encoding(inner_channel, self.vit_reso**2)[None,None,:,:].expand(-1,max_nimg,-1,-1).cuda().float() #B,n_img,length,channel + self.triplane_pe = position_encoding(inner_channel, self.triplane_reso ** 2*3).unsqueeze(0).cuda().float() + + def get_vit_coords(self): + x = torch.arange(self.vit_reso) + y = torch.arange(self.vit_reso) + + X, Y = torch.meshgrid(x, y, indexing='xy') + vit_coords = torch.stack([X, Y], dim=-1) + self.vit_coords = vit_coords.cuda().float() #reso,reso,2 + + def get_triplane_coord(self): + '''xz plane firstly, z is at the ''' + x = torch.arange(self.triplane_reso) + z = torch.arange(self.triplane_reso) + X, Z = torch.meshgrid(x, z, indexing='xy') + xz_coords = torch.cat( + [X[:, :, None], torch.ones_like(X[:, :, None]) * (self.triplane_reso - 1) / 2, Z[:, :, None]], + dim=-1) # in xyz order + + '''xy plane''' + x = torch.arange(self.triplane_reso) + y = torch.arange(self.triplane_reso) + X, Y = torch.meshgrid(x, y, indexing='xy') + xy_coords = torch.cat( + [X[:, :, None], Y[:, :, None], torch.ones_like(X[:, :, None]) * (self.triplane_reso - 1) / 2], + dim=-1) # in xyz order + + '''yz plane''' + y = torch.arange(self.triplane_reso) + z = torch.arange(self.triplane_reso) + Y, Z = torch.meshgrid(y, z, indexing='xy') + yz_coords = torch.cat( + [torch.ones_like(Y[:, :, None]) * (self.triplane_reso - 1) / 2, Y[:, :, None], Z[:, :, None]], dim=-1) + + triplane_coords = torch.cat([xz_coords, xy_coords, yz_coords], dim=0) + triplane_coords = triplane_coords / (self.triplane_reso - 1) + triplane_coords = (triplane_coords - 0.5) * 2 * (1 + self.padding + 10e-6) + self.triplane_coords = triplane_coords.view(-1,3).float().cuda() + def forward(self,triplane_feat,image_feat,proj_mat,valid_frames): + '''image feat is bs,n_img,length,channel''' + batch_size,n_img=image_feat.shape[0],image_feat.shape[1] + image_feat=image_feat[:,:,1:,:] #discard global feature + #print(image_feat.shape) + + #image_feat=image_feat.permute(0,1,3,4,2) #B,n_img,h,w,c + image_k=self.k(image_feat)+self.image_pe #B,n_img,h,w,c + 
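+        # per-view values get the same patch positional encoding as the keys; the class token was discarded above so only vit_reso*vit_reso patch tokens remain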
image_v=self.v(image_feat)+self.image_pe #B,n_img,h,w,c + image_k_v=torch.cat([image_k,image_v],dim=-1) #B,n_img,h,w,c + unflat_k_v=self.image_unflatten(image_k_v).permute(0,4,1,2,3) #Bs,channel,n_img,reso,reso + + coords=self.triplane_coords.unsqueeze(0).unsqueeze(1).expand(batch_size,n_img,-1,-1) + coord_homo=torch.cat([coords,torch.ones(batch_size,n_img,self.triplane_reso**2*3,1).float().cuda()],dim=-1) + coord_inimg = torch.einsum('bjnc,bjck->bjnk', coord_homo, proj_mat.transpose(2, 3)) + x = coord_inimg[:, :, :, 0] / coord_inimg[:, :, :, 2] + y = coord_inimg[:, :, :, 1] / coord_inimg[:, :, :, 2] + x = x/(224.0-1) #0~1 + y = y/(224.0-1) + coords_proj=torch.stack([x,y],dim=-1) + coords_proj=(coords_proj-0.5)*2 + img_index=((torch.arange(n_img)[None,:,None,None].expand( + batch_size,-1,self.triplane_reso**2*3,-1).float().cuda()/(n_img-1))-0.5)*2 #Bs,n_img,64**3,1 + + grid=torch.cat([coords_proj,img_index],dim=-1) #x,y,index + grid=torch.clamp(grid,min=-1.0,max=1.0) + sample_k_v = torch.nn.functional.grid_sample(unflat_k_v, grid.unsqueeze(1), align_corners=True, mode='bilinear').squeeze(2) #B,C,n_img,64**3 + + triplane_flat_feat=rearrange(triplane_feat,'b c h w -> b (h w) c') + triplane_query = self.q(triplane_flat_feat) + self.triplane_pe + + k_v=rearrange(sample_k_v, 'b c n k -> (b k) n c') + k=k_v[:,:,0:self.inner_channel] + v=k_v[:,:,self.inner_channel:] + q=rearrange(triplane_query,'b k c -> (b k) 1 c') + attn_out,_=self.cross_attn(q,k,v) + triplane_out=rearrange(attn_out,'(b h w) 1 c -> b c h w',b=batch_size,h=self.triplane_reso*3,w=self.triplane_reso) + triplane_out = self.proj_out(triplane_out) + return triplane_out + + +class MultiImage_Global_Sampler(nn.Module): + def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64, + img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8,max_nimg=5): + super().__init__() + self.triplane_reso=reso + self.vit_reso=vit_reso + self.inner_channel=inner_channel + self.padding=padding + self.n_heads=n_heads + self.out_channels=out_channels + self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel) + self.q=nn.Linear(in_features=triplane_in_channels,out_features=inner_channel) + self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel) + + self.cross_attn = torch.nn.MultiheadAttention( + embed_dim=inner_channel, num_heads=n_heads, batch_first=True) + self.proj_out=nn.Linear(in_features=inner_channel,out_features=out_channels) + self.image_pe = position_encoding(inner_channel, self.vit_reso**2)[None,None,:,:].expand(-1,max_nimg,-1,-1).cuda().float() #B,n_img,length,channel + self.triplane_pe = position_encoding(inner_channel, self.triplane_reso**2*3).unsqueeze(0).cuda().float() + def forward(self,triplane_feat,image_feat,proj_mat,valid_frames): + '''image feat is bs,n_img,length,channel + triplane feat is bs,C,H*3,W + ''' + batch_size,n_img=image_feat.shape[0],image_feat.shape[1] + L=image_feat.shape[2]-1 + image_feat=image_feat[:,:,1:,:] #discard global feature + + image_k=self.k(image_feat)+self.image_pe #B,n_img,h*w,c + image_v=self.v(image_feat)+self.image_pe #B,n_img,h*w,c + image_k=image_k.view(batch_size,n_img*L,-1) + image_v=image_v.view(batch_size,n_img*L,-1) + + triplane_flat_feat=rearrange(triplane_feat,"b c h w -> b (h w) c") + triplane_query = self.q(triplane_flat_feat) + self.triplane_pe + #print(triplane_query.shape,image_k.shape,image_v.shape) + attn_out,_=self.cross_attn(triplane_query,image_k,image_v) + triplane_flat_out = self.proj_out(attn_out) + 
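+        # fold the flattened (h*w) token axis back into the rolled-out triplane layout, i.e. the three reso x reso planes stacked along the height dimension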
triplane_out=rearrange(triplane_flat_out,"b (h w) c -> b c h w",h=self.triplane_reso*3,w=self.triplane_reso) + + return triplane_out + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + + if context_dim is None: + context_dim = query_dim + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, q,k,v): + h = self.heads + + q, k, v = map(lambda t: rearrange( + t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = torch.einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + +class Image_Vox_Local_Sampler_Pooling(nn.Module): + def __init__(self,reso,padding=0.1,in_channels=1280,inner_channel=128,out_channels=64,stride=4): + super().__init__() + self.triplane_reso=reso + self.padding=padding + self.get_vox_coord() + self.out_channels=out_channels + self.img_proj=nn.Conv2d(in_channels=in_channels,out_channels=inner_channel,kernel_size=1) + + self.vox_process=nn.Sequential( + nn.Conv3d(in_channels=inner_channel,out_channels=inner_channel,kernel_size=3,padding=1) + ) + self.xz_conv=nn.Sequential( + nn.BatchNorm3d(inner_channel), + nn.ReLU(), + nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1), + nn.AvgPool3d((1,stride,1),stride=(1,stride,1)), #8 + nn.BatchNorm3d(inner_channel), + nn.ReLU(), + nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1), + nn.AvgPool3d((1,stride,1), stride=(1,stride,1)), #2 + nn.BatchNorm3d(inner_channel), + nn.ReLU(), + nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1), + ) + self.xy_conv = nn.Sequential( + nn.BatchNorm3d(inner_channel), + nn.ReLU(), + nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1), + nn.AvgPool3d((1, 1, stride), stride=(1, 1, stride)), # 8 + nn.BatchNorm3d(inner_channel), + nn.ReLU(), + nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1), + nn.AvgPool3d((1, 1, stride), stride=(1, 1, stride)), # 2 + nn.BatchNorm3d(inner_channel), + nn.ReLU(), + nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1), + ) + self.yz_conv = nn.Sequential( + nn.BatchNorm3d(inner_channel), + nn.ReLU(), + nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1), + nn.AvgPool3d((stride, 1, 1), stride=(stride, 1, 1)), # 8 + nn.BatchNorm3d(inner_channel), + nn.ReLU(), + nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1), + nn.AvgPool3d((stride, 1, 1), stride=(stride, 1, 1)), # 2 + nn.BatchNorm3d(inner_channel), + nn.ReLU(), + nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1), + ) + self.roll_out_conv=RollOut_Conv(in_channels=inner_channel,out_channels=out_channels) + #self.proj_out=nn.Conv2d(in_channels=inner_channel,out_channels=out_channels,kernel_size=1) + def get_vox_coord(self): + x = torch.arange(self.triplane_reso) + y = torch.arange(self.triplane_reso) + z = torch.arange(self.triplane_reso) + + X,Y,Z=torch.meshgrid(x,y,z,indexing='ij') + 
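+        # normalise voxel centers to the padded cube (about -(1+padding) to 1+padding) so they match the coordinate convention of the triplane and point encoders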
vox_coor=torch.cat([X[:,:,:,None],Y[:,:,:,None],Z[:,:,:,None]],dim=-1) + vox_coor=vox_coor/(self.triplane_reso-1) + vox_coor=(vox_coor-0.5)*2*(1+self.padding+10e-6) + self.vox_coor=vox_coor.view(-1,3).float().cuda() + + + def forward(self,image_feat,proj_mat): + image_feat=self.img_proj(image_feat) + batch_size=image_feat.shape[0] + vox_coords=self.vox_coor.unsqueeze(0).expand(batch_size,-1,-1) #B,64*64*64,3 + vox_homo=torch.cat([vox_coords,torch.ones((batch_size,self.triplane_reso**3,1)).float().cuda()],dim=-1) + coord_inimg=torch.einsum('bhc,bck->bhk',vox_homo,proj_mat.transpose(1,2)) + x=coord_inimg[:,:,0]/coord_inimg[:,:,2] + y=coord_inimg[:,:,1]/coord_inimg[:,:,2] + x=(x/(224.0-1.0)-0.5)*2 #-1~1 + y=(y/(224.0-1.0)-0.5)*2 #-1~1 + + xy=torch.cat([x[:,:,None],y[:,:,None]],dim=-1).unsqueeze(1).contiguous() #B, 1,64**3,2 + #print(image_feat.shape,xy.shape) + grid_feat=torch.nn.functional.grid_sample(image_feat,xy,align_corners=True,mode='bilinear').squeeze(2).\ + view(batch_size,-1,self.triplane_reso,self.triplane_reso,self.triplane_reso) #B,C,1,64**3 + + grid_feat=self.vox_process(grid_feat) + xz_feat=torch.mean(self.xz_conv(grid_feat),dim=3).transpose(2,3) + xy_feat=torch.mean(self.xy_conv(grid_feat),dim=4).transpose(2,3) + yz_feat=torch.mean(self.yz_conv(grid_feat),dim=2).transpose(2,3) + triplane_wImg=torch.cat([xz_feat,xy_feat,yz_feat],dim=2) + #print(triplane_wImg.shape) + + return self.roll_out_conv(triplane_wImg) + +class Image_ExpandVox_attn_Sampler(nn.Module): + def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8): + super().__init__() + self.triplane_reso=reso + self.padding=padding + self.vit_reso=vit_reso + self.get_vox_coord() + self.get_vit_coords() + self.out_channels=out_channels + self.n_heads=n_heads + + self.kernel_func = mask_kernel_close_false + self.k = nn.Linear(in_features=img_in_channels, out_features=inner_channel) + # self.q_xz = nn.Conv2d(in_channels=triplane_in_channels,out_channels=inner_channel,kernel_size=1) + # self.q_xy = nn.Conv2d(in_channels=triplane_in_channels,out_channels=inner_channel,kernel_size=1) + # self.q_yz = nn.Conv2d(in_channels=triplane_in_channels,out_channels=inner_channel,kernel_size=1) + self.q=nn.Linear(in_features=triplane_in_channels*3,out_features=inner_channel) + + self.v = nn.Linear(in_features=img_in_channels, out_features=inner_channel) + self.attn = torch.nn.MultiheadAttention( + embed_dim=inner_channel, num_heads=n_heads, batch_first=True) + self.out_proj=nn.Linear(in_features=inner_channel,out_features=out_channels) + + self.triplane_pe = position_encoding(inner_channel, self.triplane_reso ** 3).unsqueeze(0).cuda().float() + self.image_pe = position_encoding(inner_channel, self.vit_reso ** 2+1).unsqueeze(0).cuda().float() + def get_vox_coord(self): + x = torch.arange(self.triplane_reso) + y = torch.arange(self.triplane_reso) + z = torch.arange(self.triplane_reso) + + X,Y,Z=torch.meshgrid(x,y,z,indexing='ij') + vox_coor=torch.cat([X[:,:,:,None],Y[:,:,:,None],Z[:,:,:,None]],dim=-1) + self.vox_index=vox_coor.view(-1,3).long().cuda() + + + vox_coor = self.vox_index.float() / (self.triplane_reso - 1) + vox_coor = (vox_coor - 0.5) * 2 * (1 + self.padding + 10e-6) + self.vox_coor = vox_coor.view(-1, 3).float().cuda() + # print(self.vox_coor[0]) + # print(self.vox_coor[self.triplane_reso**2])#x should increase + # print(self.vox_coor[self.triplane_reso]) #y should increase + # print(self.vox_coor[1])#z should increase + + def get_vit_coords(self): + 
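+        # 2D coordinates of the vit_reso x vit_reso patch grid; compute_attn_mask compares projected voxel coords against these to restrict attention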
x=torch.arange(self.vit_reso) + y=torch.arange(self.vit_reso) + + X,Y=torch.meshgrid(x,y,indexing='xy') + vit_coords=torch.stack([X,Y],dim=-1) + self.vit_coords=vit_coords.view(self.vit_reso**2,2).cuda().float() + + def compute_attn_mask(self,proj_coords,vit_coords,kernel_size=1.0): + dist = torch.cdist(proj_coords.float(), vit_coords.float()) + mask = self.kernel_func(dist, sigma=kernel_size) # True if valid, B,reso**3,vit_reso**2 + return mask + + + def forward(self,triplane_feat,image_feat,proj_mat): + xz_feat, xy_feat, yz_feat = torch.split(triplane_feat, split_size_or_sections=triplane_feat.shape[2] // 3, dim=2) # B,C,64,64 + #xz_feat=self.q_xz(xz_feat) + #xy_feat=self.q_xy(xy_feat) + #yz_feat=self.q_yz(yz_feat) + batch_size=image_feat.shape[0] + vox_index=self.vox_index #64*64*64,3 + xz_vox_feat=xz_feat[:,:,vox_index[:,2],vox_index[:,0]].transpose(1,2) #B,C,64*64*64 + xy_vox_feat=xy_feat[:,:,vox_index[:,1],vox_index[:,0]].transpose(1,2) + yz_vox_feat=yz_feat[:,:,vox_index[:,2],vox_index[:,1]].transpose(1,2) + triplane_expand_feat=torch.cat([xz_vox_feat,xy_vox_feat,yz_vox_feat],dim=-1)#B,C,64*64*64,3 + triplane_query=self.q(triplane_expand_feat)+self.triplane_pe + + '''compute projection''' + vox_coords=self.vox_coor.unsqueeze(0).expand(batch_size,-1,-1) # + vox_homo = torch.cat([vox_coords, torch.ones((batch_size, self.triplane_reso ** 3, 1)).float().cuda()], dim=-1) + coord_inimg = torch.einsum('bhc,bck->bhk', vox_homo, proj_mat.transpose(1, 2)) + x = coord_inimg[:, :, 0] / coord_inimg[:, :, 2] + y = coord_inimg[:, :, 1] / coord_inimg[:, :, 2] + # + x = x / (224.0 - 1.0) * (self.vit_reso-1) # 0~self.vit_reso-1 + y = y / (224.0 - 1.0) * (self.vit_reso-1) # 0~self.vit_reso-1 #B,N + xy=torch.stack([x,y],dim=-1) #B,64*64*64,2 + xy=torch.clamp(xy,min=0,max=self.vit_reso-1) + vit_coords=self.vit_coords.unsqueeze(0).expand(batch_size,-1,-1) #B, 16*16,2 + attn_mask=torch.repeat_interleave(self.compute_attn_mask(xy,vit_coords,kernel_size=0.5), + self.n_heads,0) #B*n_heads, reso**3, vit_reso**2 + + k=self.k(image_feat)+self.image_pe + v=self.v(image_feat)+self.image_pe + attn_mask=torch.cat([torch.zeros([attn_mask.shape[0],attn_mask.shape[1],1]).cuda().bool(),attn_mask],dim=-1) #add empty token to each key and value + vox_feat,_=self.attn(triplane_query,k,v,attn_mask=attn_mask) #B,reso**3,C + feat_volume=self.out_proj(vox_feat).transpose(1,2).reshape(batch_size,-1,self.triplane_reso, + self.triplane_reso,self.triplane_reso) + xz_feat=torch.mean(feat_volume,dim=3).transpose(2,3) + xy_feat=torch.mean(feat_volume,dim=4).transpose(2,3) + yz_feat=torch.mean(feat_volume,dim=2).transpose(2,3) + triplane_out=torch.cat([xz_feat,xy_feat,yz_feat],dim=2) + return triplane_out + +class Multi_Image_Fusion(nn.Module): + def __init__(self,reso,image_reso=16,padding=0.1,img_channels=1280,triplane_channel=64,inner_channels=128,output_channel=64,n_heads=8): + super().__init__() + self.triplane_reso=reso + self.image_reso=image_reso + self.padding=padding + self.get_triplane_coord() + self.get_vit_coords() + self.img_proj=nn.Conv3d(in_channels=img_channels,out_channels=512,kernel_size=1) + self.kernel_func=mask_kernel + + self.q = nn.Linear(in_features=triplane_channel, out_features=inner_channels, bias=False) + self.k = nn.Linear(in_features=512, out_features=inner_channels) + self.v = nn.Linear(in_features=512, out_features=inner_channels) + + self.attn = torch.nn.MultiheadAttention( + embed_dim=inner_channels, num_heads=n_heads, batch_first=True) + 
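+        # Cross-attention layout: queries are projected triplane features (self.q), while keys and
+        # values come from the multi-view image features after the 1x1 Conv3d reduction to 512
+        # channels; batch_first=True, so all tensors are handled as (B, num_tokens, channels).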
self.out_proj=nn.Linear(in_features=inner_channels,out_features=output_channel) + self.n_heads=n_heads + + def get_triplane_coord(self): + '''xz plane firstly, z is at the ''' + x=torch.arange(self.triplane_reso) + z=torch.arange(self.triplane_reso) + X,Z=torch.meshgrid(x,z,indexing='xy') + xz_coords=torch.cat([X[:,:,None],torch.ones_like(X[:,:,None])*(self.triplane_reso-1)/2,Z[:,:,None]],dim=-1) #in xyz order + + '''xy plane''' + x = torch.arange(self.triplane_reso) + y = torch.arange(self.triplane_reso) + X, Y = torch.meshgrid(x, y, indexing='xy') + xy_coords = torch.cat([X[:, :, None], Y[:, :, None],torch.ones_like(X[:, :, None])*(self.triplane_reso-1)/2], dim=-1) # in xyz order + + '''yz plane''' + y = torch.arange(self.triplane_reso) + z = torch.arange(self.triplane_reso) + Y,Z = torch.meshgrid(y,z,indexing='xy') + yz_coords= torch.cat([torch.ones_like(Y[:, :, None])*(self.triplane_reso-1)/2,Y[:,:,None],Z[:,:,None]], dim=-1) + + triplane_coords=torch.cat([xz_coords,xy_coords,yz_coords],dim=0) + triplane_coords=triplane_coords/(self.triplane_reso-1) + triplane_coords=(triplane_coords-0.5)*2*(1 + self.padding + 10e-6) + self.triplane_coords=triplane_coords.float().cuda() + + def get_vit_coords(self): + x=torch.arange(self.image_reso) + y=torch.arange(self.image_reso) + X,Y=torch.meshgrid(x,y,indexing='xy') + vit_coords=torch.cat([X[:,:,None],Y[:,:,None]],dim=-1) + self.vit_coords=vit_coords.float().cuda() #in x,y order + + def compute_attn_mask(self,proj_coord,vit_coords,valid_frames,kernel_size=2.0): + ''' + :param proj_coord: B,K,H,W,2 + :param vit_coords: H,W,2 + :return: + ''' + B,K=proj_coord.shape[0:2] + vit_coords_expand=vit_coords[None,None,:,:,:].expand(B,K,-1,-1,-1) + + proj_coord=proj_coord.view(B*K,proj_coord.shape[2]*proj_coord.shape[3],proj_coord.shape[4]) + vit_coords_expand=vit_coords_expand.view(B*K,self.image_reso*self.image_reso,2) + attn_mask=self.kernel_func(torch.cdist(proj_coord,vit_coords_expand),sigma=float(kernel_size)) + attn_mask=attn_mask.reshape(B,K,proj_coord.shape[1],vit_coords_expand.shape[1]) + valid_expand=valid_frames[:,:,None,None] + attn_mask[valid_frames>0,:,:]=True + attn_mask=attn_mask.permute(0,2,1,3) + attn_mask=attn_mask.reshape(B,proj_coord.shape[1],K*vit_coords_expand.shape[1]) + atten_index=torch.where(attn_mask[0,0]==False) + return attn_mask + + + def forward(self,triplane_feat,image_feat,proj_mat,valid_frames): + ''' + :param image_feat: B,C,K,16,16 + :param proj_mat: B,K,4,4 + :param valid_frames: B,K, true if have image, used to compute attn_mask for transformer + :return: + ''' + image_feat=self.img_proj(image_feat) + batch_size=image_feat.shape[0] #K is number of frames + K=image_feat.shape[2] + triplane_coords=self.triplane_coords.unsqueeze(0).unsqueeze(1).expand(batch_size,K,-1,-1,-1) #B,K,192,64,3 + #print(torch.amin(triplane_coords),torch.amax(triplane_coords)) + coord_homo=torch.cat([triplane_coords,torch.ones((batch_size,K,triplane_coords.shape[2],triplane_coords.shape[3],1)).float().cuda()],dim=-1) + #print(coord_homo.shape,proj_mat.shape) + coord_inimg=torch.einsum('bjhwc,bjck->bjhwk',coord_homo,proj_mat.transpose(2,3)) + x=coord_inimg[:,:,:,:,0]/coord_inimg[:,:,:,:,2] + y=coord_inimg[:,:,:,:,1]/coord_inimg[:,:,:,:,2] + x=x/(224.0-1.0)*(self.image_reso-1) + y=y/(224.0-1.0)*(self.image_reso-1) + + xy=torch.cat([x[...,None],y[...,None]],dim=-1) #B,K,H,W,2 + image_value=image_feat.view(image_feat.shape[0],image_feat.shape[1],-1).transpose(1,2) + 
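+        # image_value flattens all K views' patch tokens into (B, K*image_reso**2, C). The triplane
+        # cells below act as queries; the boolean attn_mask restricts each cell to image tokens near
+        # its projected 2D location (via the distance kernel), and frames flagged as missing are
+        # masked out entirely (True entries are blocked under nn.MultiheadAttention's convention).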
triplane_query=triplane_feat.view(triplane_feat.shape[0],triplane_feat.shape[1],-1).transpose(1,2) + valid_frames=1.0-valid_frames.float() + attn_mask=torch.repeat_interleave(self.compute_attn_mask(xy,self.vit_coords,valid_frames), + self.n_heads,dim=0) + + q=self.q(triplane_query) + k=self.k(image_value) + v=self.v(image_value) + #print(q.shape,k.shape,v.shape) + + attn,_=self.attn(q,k,v,attn_mask=attn_mask) + #print(attn.shape) + output=self.out_proj(attn).transpose(1,2).reshape(batch_size,-1,triplane_feat.shape[2],triplane_feat.shape[3]) + #print(output.shape) + return output + + +if __name__=="__main__": + # import sys + # sys.path.append("../..") + # from datasets.SingleView_dataset import Object_PartialPoints_Img + # from datasets.transforms import Aug_with_Tran + # #sampler=#Image_Vox_Local_Sampler_Pooling(reso=64,padding=0.1,out_channels=64,stride=4).cuda().float() + # sampler=Image_ExpandVox_attn_Sampler(reso=32,vit_reso=16,padding=0.1,img_in_channels=1280,triplane_in_channels=64,inner_channel=64 + # ,out_channels=64,n_heads=8).cuda().float() + # # sampler=Image_Direct_AttenwMask_Sampler(reso=32,vit_reso=16,padding=0.1,img_in_channels=1280,triplane_in_channels=128,inner_channel=128 + # # ,out_channels=64,n_heads=8).cuda().float() + # dataset_config = { + # "data_path": "/data1/haolin/datasets", + # "surface_size": 20000, + # "par_pc_size": 4096, + # "load_proj_mat": True, + # } + # transform = Aug_with_Tran() + # datasets = Object_PartialPoints_Img(dataset_config['data_path'], split_filename="val_par_img.json", split='val', + # transform=transform, sampling=False, + # num_samples=1024, return_surface=True, ret_sample=True, + # surface_sampling=True, par_pc_size=dataset_config['par_pc_size'], + # surface_size=dataset_config['surface_size'], + # load_proj_mat=dataset_config['load_proj_mat'], load_image=True, + # load_org_img=False, load_triplane=True, replica=1) + # + # dataloader = torch.utils.data.DataLoader( + # datasets=datasets, + # batch_size=64, + # shuffle=True + # ) + # iterator = dataloader.__iter__() + # data_batch = iterator.next() + # unflatten = torch.nn.Unflatten(1, (16, 16)) + # image = data_batch['image'][:,:,:].cuda().float() + # #image=unflatten(image).permute(0,3,1,2) + # proj_mat = data_batch['proj_mat'].cuda().float() + # triplane_feat=torch.randn((64,64,32*3,32)).cuda().float() + # sampler(triplane_feat,image,proj_mat) + # memory_usage=torch.cuda.max_memory_allocated() / MB + # print("memory usage %f mb"%(memory_usage)) + + + import sys + sys.path.append("../..") + from datasets.SingleView_dataset import Object_PartialPoints_MultiImg + from datasets.transforms import Aug_with_Tran + + dataset_config = { + "data_path": "/data1/haolin/datasets", + "surface_size": 20000, + "par_pc_size": 4096, + "load_proj_mat": True, + } + transform = Aug_with_Tran() + dataset = Object_PartialPoints_MultiImg(dataset_config['data_path'], split_filename="train_par_img.json", split='train', + transform=transform, sampling=False, + num_samples=1024, return_surface=True, ret_sample=True, + surface_sampling=True, par_pc_size=dataset_config['par_pc_size'], + surface_size=dataset_config['surface_size'], + load_proj_mat=dataset_config['load_proj_mat'], load_image=True, + load_org_img=False, load_triplane=True, replica=1) + + dataloader = torch.utils.data.DataLoader( + dataset=dataset, + batch_size=10, + shuffle=False + ) + iterator = dataloader.__iter__() + data_batch = iterator.next() + #unflatten = torch.nn.Unflatten(2, (16, 16)) + image = data_batch['image'][:,:,:,:].cuda().float() + 
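+    # Standalone smoke test: load one multi-view batch, build random triplane features, run the
+    # fusion module once, and report peak GPU memory. Note: `MB` is not defined in this file (it is
+    # presumably 1024 ** 2), and `iterator.next()` is Python-2 style; with Python 3 / current
+    # PyTorch this would normally be written as `next(iterator)`.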
#image=unflatten(image).permute(0,4,1,2,3) + proj_mat = data_batch['proj_mat'].cuda().float() + valid_frames = data_batch['valid_frames'].cuda().float() + triplane_feat=torch.randn((10,128,32*3,32)).cuda().float() + + # fusion_module=MultiImage_Fuse_Sampler(reso=32,vit_reso=16,padding=0.1,img_in_channels=1280,triplane_in_channels=128,inner_channel=128 + # ,out_channels=64,n_heads=8).cuda().float() + fusion_module=MultiImage_Global_Sampler(reso=32,vit_reso=16,padding=0.1,img_in_channels=1280,triplane_in_channels=128,inner_channel=128 + ,out_channels=64,n_heads=8).cuda().float() + fusion_module(triplane_feat,image,proj_mat,valid_frames) + memory_usage=torch.cuda.max_memory_allocated() / MB + print("memory usage %f mb"%(memory_usage)) diff --git a/models/modules/parpoints_encoder.py b/models/modules/parpoints_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1648d66f751d5310d0c6ac383bf343c29f2afd45 --- /dev/null +++ b/models/modules/parpoints_encoder.py @@ -0,0 +1,168 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_scatter import scatter_mean, scatter_max +from .unet import UNet +from .resnet_block import ResnetBlockFC +from .PointEMB import PointEmbed +import numpy as np + +class ParPoint_Encoder(nn.Module): + ''' PointNet-based encoder network with ResNet blocks for each point. + Number of input points are fixed. + + Args: + c_dim (int): dimension of latent code c + dim (int): input points dimension + hidden_dim (int): hidden dimension of the network + scatter_type (str): feature aggregation when doing local pooling + unet (bool): weather to use U-Net + unet_kwargs (str): U-Net parameters + plane_resolution (int): defined resolution for plane feature + plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume + padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] + n_blocks (int): number of blocks ResNetBlockFC layers + ''' + + def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max', unet_kwargs=None, + plane_resolution=None, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5): + super().__init__() + self.c_dim = c_dim + + self.fc_pos = nn.Linear(dim, 2 * hidden_dim) + self.blocks = nn.ModuleList([ + ResnetBlockFC(2 * hidden_dim, hidden_dim) for i in range(n_blocks) + ]) + self.fc_c = nn.Linear(hidden_dim, c_dim) + + self.actvn = nn.ReLU() + self.hidden_dim = hidden_dim + + self.unet = UNet(unet_kwargs['output_dim'], in_channels=c_dim, **unet_kwargs) + + self.reso_plane = plane_resolution + self.plane_type = plane_type + self.padding = padding + + if scatter_type == 'max': + self.scatter = scatter_max + elif scatter_type == 'mean': + self.scatter = scatter_mean + + # takes in "p": point cloud and "query": sdf_xyz + # sample plane features for unlabeled_query as well + def forward(self, p,point_emb): # , query2): + batch_size, T, D = p.size() + #print('origin',torch.amin(p[0],dim=0),torch.amax(p[0],dim=0)) + # acquire the index for each point + coord = {} + index = {} + if 'xz' in self.plane_type: + coord['xz'] = self.normalize_coordinate(p.clone(), plane='xz', padding=self.padding) + index['xz'] = self.coordinate2index(coord['xz'], self.reso_plane) + if 'xy' in self.plane_type: + coord['xy'] = self.normalize_coordinate(p.clone(), plane='xy', padding=self.padding) + index['xy'] = self.coordinate2index(coord['xy'], self.reso_plane) + if 'yz' in self.plane_type: + coord['yz'] = self.normalize_coordinate(p.clone(), 
plane='yz', padding=self.padding) + index['yz'] = self.coordinate2index(coord['yz'], self.reso_plane) + net = self.fc_pos(point_emb) + + net = self.blocks[0](net) + for block in self.blocks[1:]: + pooled = self.pool_local(coord, index, net) + net = torch.cat([net, pooled], dim=2) + net = block(net) + + c = self.fc_c(net) + #print(c.shape) + + fea = {} + # second_sum = 0 + if 'xz' in self.plane_type: + fea['xz'] = self.generate_plane_features(p, c, + plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64) + if 'xy' in self.plane_type: + fea['xy'] = self.generate_plane_features(p, c, plane='xy') + if 'yz' in self.plane_type: + fea['yz'] = self.generate_plane_features(p, c, plane='yz') + cat_feature = torch.cat([fea['xz'], fea['xy'], fea['yz']], + dim=2) # concat at row dimension + #print(cat_feature.shape) + plane_feat=self.unet(cat_feature) + + return plane_feat + + + def normalize_coordinate(self, p, padding=0.1, plane='xz'): + ''' Normalize coordinate to [0, 1] for unit cube experiments + + Args: + p (tensor): point + padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55] + plane (str): plane feature type, ['xz', 'xy', 'yz'] + ''' + if plane == 'xz': + xy = p[:, :, [0, 2]] + elif plane == 'xy': + xy = p[:, :, [0, 1]] + else: + xy = p[:, :, [1, 2]] + #print("origin",torch.amin(xy), torch.amax(xy)) + xy=xy/2 #xy is originally -1 ~ 1 + xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5) + xy_new = xy_new + 0.5 # range (0, 1) + #print("scale",torch.amin(xy_new),torch.amax(xy_new)) + + # f there are outliers out of the range + if xy_new.max() >= 1: + xy_new[xy_new >= 1] = 1 - 10e-6 + if xy_new.min() < 0: + xy_new[xy_new < 0] = 0.0 + return xy_new + + def coordinate2index(self, x, reso): + ''' Normalize coordinate to [0, 1] for unit cube experiments. 
+ Corresponds to our 3D model + + Args: + x (tensor): coordinate + reso (int): defined resolution + coord_type (str): coordinate type + ''' + x = (x * reso).long() + index = x[:, :, 0] + reso * x[:, :, 1] + index = index[:, None, :] + return index + + # xy is the normalized coordinates of the point cloud of each plane + # I'm pretty sure the keys of xy are the same as those of index, so xy isn't needed here as input + def pool_local(self, xy, index, c): + bs, fea_dim = c.size(0), c.size(2) + keys = xy.keys() + + c_out = 0 + for key in keys: + # scatter plane features from points + fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane ** 2) + if self.scatter == scatter_max: + fea = fea[0] + # gather feature back to points + fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1)) + c_out += fea + return c_out.permute(0, 2, 1) + + def generate_plane_features(self, p, c, plane='xz'): + # acquire indices of features in plane + xy = self.normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1) + index = self.coordinate2index(xy, self.reso_plane) + + # scatter plane features from points + fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane ** 2) + c = c.permute(0, 2, 1) # B x 512 x T + fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2 + fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, + self.reso_plane) # sparce matrix (B x 512 x reso x reso) + #print(fea_plane.shape) + + return fea_plane \ No newline at end of file diff --git a/models/modules/point_transformer.py b/models/modules/point_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..b9263a16bd1b19db9cc229355a57a44d931ca170 --- /dev/null +++ b/models/modules/point_transformer.py @@ -0,0 +1,442 @@ +from torch import nn, einsum +import torch +import torch.nn.functional as F +from einops import rearrange,repeat +from timm.models.layers import DropPath +from torch_cluster import fps +import numpy as np + +def zero_module(module): + """ + Zero out the parameters of a module and return it. 
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + +class PositionalEmbedding(torch.nn.Module): + def __init__(self, num_channels, max_positions=10000, endpoint=False): + super().__init__() + self.num_channels = num_channels + self.max_positions = max_positions + self.endpoint = endpoint + + def forward(self, x): + freqs = torch.arange(start=0, end=self.num_channels//2, dtype=torch.float32, device=x.device) + freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0)) + freqs = (1 / self.max_positions) ** freqs + x = x.ger(freqs.to(x.dtype)) + x = torch.cat([x.cos(), x.sin()], dim=1) + return x + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + + if context_dim is None: + context_dim = query_dim + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + + if context is None: + context = x + + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange( + t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = torch.einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + if dim_out is None: + dim_out = dim + + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + +class AdaLayerNorm(nn.Module): + def __init__(self, n_embd): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(n_embd, n_embd*2) + self.layernorm = nn.LayerNorm(n_embd, elementwise_affine=False) + + def forward(self, x, timestep): + emb = self.linear(timestep) + scale, shift = torch.chunk(emb, 2, dim=2) + x = self.layernorm(x) * (1 + scale) + shift + return x + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): + super().__init__() + self.attn1 = CrossAttention( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, 
dropout=dropout) # is self-attn if context is none + self.norm1 = AdaLayerNorm(dim) + self.norm2 = AdaLayerNorm(dim) + self.norm3 = AdaLayerNorm(dim) + self.checkpoint = checkpoint + + init_values = 0 + drop_path = 0.0 + + + self.ls1 = LayerScale( + dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + self.ls2 = LayerScale( + dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + self.ls3 = LayerScale( + dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path3 = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x, t, context=None): + x = self.drop_path1(self.ls1(self.attn1(self.norm1(x, t)))) + x + x = self.drop_path2(self.ls2(self.attn2(self.norm2(x, t), context=context))) + x + x = self.drop_path3(self.ls3(self.ff(self.norm3(x, t)))) + x + return x + +class LatentArrayTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + + def __init__(self, in_channels, t_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, out_channels=None, context_dim2=None, + block=BasicTransformerBlock): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + + self.t_channels = t_channels + + self.proj_in = nn.Linear(in_channels, inner_dim, bias=False) + + self.transformer_blocks = nn.ModuleList( + [block(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for _ in range(depth)] + ) + + self.norm = nn.LayerNorm(inner_dim) + + if out_channels is None: + self.proj_out = zero_module(nn.Linear(inner_dim, in_channels, bias=False)) + else: + self.num_cls = out_channels + self.proj_out = zero_module(nn.Linear(inner_dim, out_channels, bias=False)) + + self.context_dim = context_dim + + self.map_noise = PositionalEmbedding(t_channels) + + self.map_layer0 = nn.Linear(in_features=t_channels, out_features=inner_dim) + self.map_layer1 = nn.Linear(in_features=inner_dim, out_features=inner_dim) + + # ### + # self.pos_emb = nn.Embedding(512, inner_dim) + # ### + + def forward(self, x, t, cond, class_emb): + + t_emb = self.map_noise(t)[:, None] + t_emb = F.silu(self.map_layer0(t_emb)) + t_emb = F.silu(self.map_layer1(t_emb)) + + x = self.proj_in(x) + #print(class_emb.shape,t_emb.shape) + for block in self.transformer_blocks: + x = block(x, t_emb+class_emb[:,None,:], context=cond) + + x = self.norm(x) + + x = self.proj_out(x) + return x + +class PointTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. 
+ Finally, reshape to image + """ + + def __init__(self, in_channels, t_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, out_channels=None, context_dim2=None, + block=BasicTransformerBlock): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + + self.t_channels = t_channels + + self.proj_in = nn.Linear(in_channels, inner_dim, bias=False) + + self.transformer_blocks = nn.ModuleList( + [block(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for _ in range(depth)] + ) + + self.norm = nn.LayerNorm(inner_dim) + + if out_channels is None: + self.proj_out = zero_module(nn.Linear(inner_dim, in_channels, bias=False)) + else: + self.num_cls = out_channels + self.proj_out = zero_module(nn.Linear(inner_dim, out_channels, bias=False)) + + self.context_dim = context_dim + + self.map_noise = PositionalEmbedding(t_channels) + + self.map_layer0 = nn.Linear(in_features=t_channels, out_features=inner_dim) + self.map_layer1 = nn.Linear(in_features=inner_dim, out_features=inner_dim) + + # ### + # self.pos_emb = nn.Embedding(512, inner_dim) + # ### + + def forward(self, x, t, cond=None): + + t_emb = self.map_noise(t)[:, None] + t_emb = F.silu(self.map_layer0(t_emb)) + t_emb = F.silu(self.map_layer1(t_emb)) + + x = self.proj_in(x) + + for block in self.transformer_blocks: + x = block(x, t_emb, context=cond) + + x = self.norm(x) + + x = self.proj_out(x) + return x +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +def cache_fn(f): + cache = None + @wraps(f) + def cached_fn(*args, _cache = True, **kwargs): + if not _cache: + return f(*args, **kwargs) + nonlocal cache + if cache is not None: + return cache + cache = f(*args, **kwargs) + return cache + return cached_fn + +class PreNorm(nn.Module): + def __init__(self, dim, fn, context_dim = None): + super().__init__() + self.fn = fn + self.norm = nn.LayerNorm(dim) + self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None + + def forward(self, x, **kwargs): + x = self.norm(x) + + if exists(self.norm_context): + context = kwargs['context'] + normed_context = self.norm_context(context) + kwargs.update(context = normed_context) + + return self.fn(x, **kwargs) + +class Attention(nn.Module): + def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64, drop_path_rate = 0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias = False) + self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False) + self.to_out = nn.Linear(inner_dim, query_dim) + + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x, context = None, mask = None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k, v = self.to_kv(context).chunk(2, dim = -1) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... 
-> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h = h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim = -1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h = h) + return self.drop_path(self.to_out(out)) + + +class PointEmbed(nn.Module): + def __init__(self, hidden_dim=48, dim=128): + super().__init__() + + assert hidden_dim % 6 == 0 + + self.embedding_dim = hidden_dim + e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi + e = torch.stack([ + torch.cat([e, torch.zeros(self.embedding_dim // 6), + torch.zeros(self.embedding_dim // 6)]), + torch.cat([torch.zeros(self.embedding_dim // 6), e, + torch.zeros(self.embedding_dim // 6)]), + torch.cat([torch.zeros(self.embedding_dim // 6), + torch.zeros(self.embedding_dim // 6), e]), + ]) + self.register_buffer('basis', e) # 3 x 16 + + self.mlp = nn.Linear(self.embedding_dim + 3, dim) + + @staticmethod + def embed(input, basis): + projections = torch.einsum( + 'bnd,de->bne', input, basis) + embeddings = torch.cat([projections.sin(), projections.cos()], dim=2) + return embeddings + + def forward(self, input): + # input: B x N x 3 + embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2)) # B x N x C + return embed + + +class PointEncoder(nn.Module): + def __init__(self, + dim=512, + num_inputs = 2048, + num_latents = 512, + latent_dim = 512): + super().__init__() + + self.num_inputs = num_inputs + self.num_latents = num_latents + + self.cross_attend_blocks = nn.ModuleList([ + PreNorm(dim, Attention(dim, dim, heads=1, dim_head=dim), context_dim=dim), + PreNorm(dim, FeedForward(dim)) + ]) + + self.point_embed = PointEmbed(dim=dim) + self.proj=nn.Linear(dim,latent_dim) + def encode(self, pc): + # pc: B x N x 3 + B, N, D = pc.shape + assert N == self.num_inputs + + ###### fps + flattened = pc.view(B * N, D) + + batch = torch.arange(B).to(pc.device) + batch = torch.repeat_interleave(batch, N) + + pos = flattened + + ratio = 1.0 * self.num_latents / self.num_inputs + + idx = fps(pos, batch, ratio=ratio) + + sampled_pc = pos[idx] + sampled_pc = sampled_pc.view(B, -1, 3) + ###### + + sampled_pc_embeddings = self.point_embed(sampled_pc) + + pc_embeddings = self.point_embed(pc) + + cross_attn, cross_ff = self.cross_attend_blocks + + x = cross_attn(sampled_pc_embeddings, context=pc_embeddings, mask=None) + sampled_pc_embeddings + x = cross_ff(x) + x + + return self.proj(x) \ No newline at end of file diff --git a/models/modules/pointnet2_backbone.py b/models/modules/pointnet2_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..9787e9a965c914b0bb4e06db171893f2e4665d5a --- /dev/null +++ b/models/modules/pointnet2_backbone.py @@ -0,0 +1,188 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import sys +import os +from external.pointnet2.pointnet2_modules import PointnetSAModuleVotes, PointnetFPModule +from .utils import zero_module +from .Positional_Embedding import PositionalEmbedding + +class Pointnet2Encoder(nn.Module): + def __init__(self,input_feature_dim=0,npoints=[2048,1024,512,256],radius=[0.2,0.4,0.6,1.2],nsample=[64,32,16,8]): + super().__init__() + self.sa1 = PointnetSAModuleVotes( + npoint=npoints[0], + radius=radius[0], + nsample=nsample[0], + mlp=[input_feature_dim, 64, 64, 128], + use_xyz=True, + normalize_xyz=True + ) + + self.sa2 = PointnetSAModuleVotes( + 
npoint=npoints[1], + radius=radius[1], + nsample=nsample[1], + mlp=[128, 128, 128, 256], + use_xyz=True, + normalize_xyz=True + ) + + self.sa3 = PointnetSAModuleVotes( + npoint=npoints[2], + radius=radius[2], + nsample=nsample[2], + mlp=[256, 256, 256, 512], + use_xyz=True, + normalize_xyz=True + ) + + self.sa4 = PointnetSAModuleVotes( + npoint=npoints[3], + radius=radius[3], + nsample=nsample[3], + mlp=[512, 512, 512, 512], + use_xyz=True, + normalize_xyz=True + ) + def _break_up_pc(self, pc): + xyz = pc[..., 0:3].contiguous() + features = ( + pc[..., 3:].transpose(1, 2).contiguous() + if pc.size(-1) > 3 else None + ) + + return xyz, features + def forward(self,pointcloud,end_points=None): + if not end_points: end_points = {} + batch_size = pointcloud.shape[0] + + xyz, features = self._break_up_pc(pointcloud) + + end_points['org_xyz'] = xyz + # --------- 4 SET ABSTRACTION LAYERS --------- + xyz1, features1, _ = self.sa1(xyz, features) + end_points['sa1_xyz'] = xyz1 + end_points['sa1_features'] = features1 + + xyz2, features2, _ = self.sa2(xyz1, features1) # this fps_inds is just 0,1,...,1023 + end_points['sa2_xyz'] = xyz2 + end_points['sa2_features'] = features2 + + xyz3, features3, _ = self.sa3(xyz2, features2) # this fps_inds is just 0,1,...,511 + end_points['sa3_xyz'] = xyz3 + end_points['sa3_features'] = features3 + #print(xyz3.shape,features3.shape) + xyz4, features4, _ = self.sa4(xyz3, features3) # this fps_inds is just 0,1,...,255 + end_points['sa4_xyz'] = xyz4 + end_points['sa4_features'] = features4 + #print(xyz4.shape,features4.shape) + return end_points + + + +class PointUNet(nn.Module): + r""" + Backbone network for point cloud feature learning. + Based on Pointnet++ single-scale grouping network. + + Parameters + ---------- + input_feature_dim: int + Number of input channels in the feature descriptor for each point. + e.g. 3 for RGB. + """ + + def __init__(self): + super().__init__() + + self.noisy_encoder=Pointnet2Encoder() + self.cond_encoder=Pointnet2Encoder() + self.fp1_cross = PointnetFPModule(mlp=[512 + 512, 512, 512]) + self.fp1 = PointnetFPModule(mlp=[512 + 512, 512, 512]) + #self.fp1 = PointnetFPModule(mlp=[512 + 512, 512, 512]) + self.fp2_cross = PointnetFPModule(mlp=[512 + 512, 512, 256]) + self.fp2 = PointnetFPModule(mlp=[256 + 256, 512, 256]) + #self.fp2=PointnetFPModule(mlp=[512 + 256, 512, 256]) + self.fp3_cross= PointnetFPModule(mlp=[256 + 256, 256, 128]) + self.fp3 = PointnetFPModule(mlp=[128 + 128, 256, 128]) + #self.fp3 = PointnetFPModule(mlp=[256 + 128, 256, 128]) + self.fp4_cross=PointnetFPModule(mlp=[128+128, 128, 128]) + self.fp4 = PointnetFPModule(mlp=[128, 128, 128]) + #self.fp4 = PointnetFPModule(mlp=[128, 128, 128]) + + self.output_layer=nn.Sequential( + nn.LayerNorm(128), + zero_module(nn.Linear(in_features=128,out_features=3,bias=False)) + ) + self.t_emb_layer = PositionalEmbedding(256) + self.map_layer0 = nn.Linear(in_features=256, out_features=512) + self.map_layer1 = nn.Linear(in_features=512, out_features=512) + + def forward(self, noise_points, t,cond_points): + r""" + Forward pass of the network + + Parameters + ---------- + pointcloud: Variable(torch.cuda.FloatTensor) + (B, N, 3 + input_feature_dim) tensor + Point cloud to run predicts on + Each point in the point-cloud MUST + be formated as (x, y, z, features...) 
+ + Returns + ---------- + end_points: {XXX_xyz, XXX_features, XXX_inds} + XXX_xyz: float32 Tensor of shape (B,K,3) + XXX_features: float32 Tensor of shape (B,K,D) + XXX-inds: int64 Tensor of shape (B,K) values in [0,N-1] + """ + t_emb = self.t_emb_layer(t) + t_emb = F.silu(self.map_layer0(t_emb)) + t_emb = F.silu(self.map_layer1(t_emb))#B,512 + t_emb = t_emb[:, :, None] #B,512,K + noise_end_points=self.noisy_encoder(noise_points) + cond=self.cond_encoder(cond_points) + # --------- 2 FEATURE UPSAMPLING LAYERS -------- + features = self.fp1_cross(noise_end_points['sa4_xyz'],cond['sa4_xyz'],noise_end_points['sa4_features']+t_emb, + cond['sa4_features']) + features = self.fp1(noise_end_points['sa3_xyz'], noise_end_points['sa4_xyz'], noise_end_points['sa3_features'], + features) + features = self.fp2_cross(noise_end_points['sa3_xyz'],cond['sa3_xyz'],features, + cond["sa3_features"]) + features = self.fp2(noise_end_points['sa2_xyz'], noise_end_points['sa3_xyz'], noise_end_points['sa2_features'], + features) + features = self.fp3_cross(noise_end_points['sa2_xyz'],cond['sa2_xyz'],features, + cond['sa2_features']) + features = self.fp3(noise_end_points['sa1_xyz'],noise_end_points['sa2_xyz'],noise_end_points['sa1_features'],features) + features = self.fp4_cross(noise_end_points['sa1_xyz'],cond['sa1_xyz'],features, + cond['sa1_features']) + features = self.fp4(noise_end_points['org_xyz'], noise_end_points['sa1_xyz'], None, features) + features=features.transpose(1,2) + + # features = self.fp1_cross(noise_end_points['sa4_xyz'], cond_end_points['sa4_xyz'], + # noise_end_points['sa4_features']+t_emb, cond_end_points['sa4_features']) + # features = self.fp1(noise_end_points['sa3_xyz'].clone(), noise_end_points['sa4_xyz'].clone(), noise_end_points['sa3_features'], + # features) + # features = self.fp2(noise_end_points['sa2_xyz'], noise_end_points['sa3_xyz'], noise_end_points['sa2_features'], + # features) + # features = self.fp3(noise_end_points['sa1_xyz'],noise_end_points['sa2_xyz'],noise_end_points['sa1_features'],features) + # features = self.fp4(noise_end_points['org_xyz'], noise_end_points['sa1_xyz'], None, features) + # features = features.transpose(1,2) + output_points=self.output_layer(features) + + return output_points + + +if __name__ == '__main__': + net=PointUNet().cuda().float() + net=net.eval() + noise_points=torch.randn(16,4096,3).cuda().float() + cond_points=torch.randn(16,4096,3).cuda().float() + t=torch.randn(16).cuda().float() + cond_encoder=Pointnet2Encoder().cuda().float() + + out = net(noise_points,cond_points) + print(out.shape) \ No newline at end of file diff --git a/models/modules/resnet_block.py b/models/modules/resnet_block.py new file mode 100644 index 0000000000000000000000000000000000000000..b715ff52626ad8e878593082d6c4926295261c68 --- /dev/null +++ b/models/modules/resnet_block.py @@ -0,0 +1,47 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +# Resnet Blocks +class ResnetBlockFC(nn.Module): + ''' Fully connected ResNet Block class. 
+ Args: + size_in (int): input dimension + size_out (int): output dimension + size_h (int): hidden dimension + ''' + + def __init__(self, size_in, size_out=None, size_h=None): + super().__init__() + # Attributes + if size_out is None: + size_out = size_in + + if size_h is None: + size_h = min(size_in, size_out) + + self.size_in = size_in + self.size_h = size_h + self.size_out = size_out + # Submodules + self.fc_0 = nn.Linear(size_in, size_h) + self.fc_1 = nn.Linear(size_h, size_out) + self.actvn = nn.ReLU() + + if size_in == size_out: + self.shortcut = None + else: + self.shortcut = nn.Linear(size_in, size_out, bias=False) + # Initialization + nn.init.zeros_(self.fc_1.weight) + + def forward(self, x): + net = self.fc_0(self.actvn(x)) + dx = self.fc_1(self.actvn(net)) + + if self.shortcut is not None: + x_s = self.shortcut(x) + else: + x_s = x + + return x_s + dx \ No newline at end of file diff --git a/models/modules/resunet.py b/models/modules/resunet.py new file mode 100644 index 0000000000000000000000000000000000000000..dd871dd43eb3adb33cbd34783f0c275d014a9eac --- /dev/null +++ b/models/modules/resunet.py @@ -0,0 +1,440 @@ +import torch +import torch.nn as nn +from .unet import RollOut_Conv +from .Positional_Embedding import PositionalEmbedding +import torch.nn.functional as F +from .utils import zero_module +from .image_sampler import MultiImage_Fuse_Sampler, MultiImage_Global_Sampler,MultiImage_TriFuse_Sampler + +class ResidualConv_MultiImgAtten(nn.Module): + def __init__(self, input_dim, output_dim, stride, padding, reso=64, + vit_reso=16,t_input_dim=256,img_in_channels=1280,use_attn=True,triplane_padding=0.1, + norm="batch"): + super(ResidualConv_MultiImgAtten, self).__init__() + self.use_attn=use_attn + + if norm=="batch": + norm_layer=nn.BatchNorm2d + elif norm==None: + norm_layer=nn.Identity + + self.conv_block = nn.Sequential( + norm_layer(input_dim), + nn.ReLU(), + nn.Conv2d( + input_dim, output_dim, kernel_size=3, padding=padding + ) + ) + self.out_layer=nn.Sequential( + norm_layer(output_dim), + nn.ReLU(), + nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1), + ) + self.conv_skip = nn.Sequential( + nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1), + norm_layer(output_dim), + ) + self.roll_out_conv=nn.Sequential( + norm_layer(output_dim), + nn.ReLU(), + RollOut_Conv(output_dim, output_dim), + ) + if self.use_attn: + self.img_sampler = MultiImage_Fuse_Sampler(inner_channel=output_dim, triplane_in_channels=output_dim, + img_in_channels=img_in_channels,reso=reso,vit_reso=vit_reso, + out_channels=output_dim,padding=triplane_padding) + self.down_conv=nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=stride, padding=padding) + + self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim) + self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim) + def forward(self, x,t_emb,img_feat,proj_mat,valid_frames): + t_emb = F.silu(self.map_layer0(t_emb)) + t_emb = F.silu(self.map_layer1(t_emb)) + t_emb = t_emb[:,:,None,None] + + out=self.conv_block(x)+t_emb + out=self.out_layer(out) + feature=out+self.conv_skip(x) + feature = self.roll_out_conv(feature) + if self.use_attn: + feature=self.img_sampler(feature,img_feat,proj_mat,valid_frames)+feature #skip connect + feature=self.down_conv(feature) + + return feature + +class ResidualConv_TriMultiImgAtten(nn.Module): + def __init__(self, input_dim, output_dim, stride, padding, reso=64, + vit_reso=16,t_input_dim=256,img_in_channels=1280,use_attn=True,triplane_padding=0.1, + 
norm="batch"): + super(ResidualConv_TriMultiImgAtten, self).__init__() + self.use_attn=use_attn + + if norm=="batch": + norm_layer=nn.BatchNorm2d + elif norm==None: + norm_layer=nn.Identity + + self.conv_block = nn.Sequential( + norm_layer(input_dim), + nn.ReLU(), + nn.Conv2d( + input_dim, output_dim, kernel_size=3, padding=padding + ) + ) + self.out_layer=nn.Sequential( + norm_layer(output_dim), + nn.ReLU(), + nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1), + ) + self.conv_skip = nn.Sequential( + nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1), + norm_layer(output_dim), + ) + self.roll_out_conv=nn.Sequential( + norm_layer(output_dim), + nn.ReLU(), + RollOut_Conv(output_dim, output_dim), + ) + if self.use_attn: + self.img_sampler = MultiImage_TriFuse_Sampler(inner_channel=output_dim, triplane_in_channels=output_dim, + img_in_channels=img_in_channels,reso=reso,vit_reso=vit_reso, + out_channels=output_dim,max_nimg=5,padding=triplane_padding) + self.down_conv=nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=stride, padding=padding) + + self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim) + self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim) + def forward(self, x,t_emb,img_feat,proj_mat,valid_frames): + t_emb = F.silu(self.map_layer0(t_emb)) + t_emb = F.silu(self.map_layer1(t_emb)) + t_emb = t_emb[:,:,None,None] + + out=self.conv_block(x)+t_emb + out=self.out_layer(out) + feature=out+self.conv_skip(x) + feature = self.roll_out_conv(feature) + if self.use_attn: + feature=self.img_sampler(feature,img_feat,proj_mat,valid_frames)+feature #skip connect + feature=self.down_conv(feature) + + return feature + + +class ResidualConv_GlobalAtten(nn.Module): + def __init__(self, input_dim, output_dim, stride, padding, reso=64, + vit_reso=16,t_input_dim=256,img_in_channels=1280,use_attn=True,triplane_padding=0.1, + norm="batch"): + super(ResidualConv_GlobalAtten, self).__init__() + self.use_attn=use_attn + + if norm=="batch": + norm_layer=nn.BatchNorm2d + elif norm==None: + norm_layer=nn.Identity + + self.conv_block = nn.Sequential( + norm_layer(input_dim), + nn.ReLU(), + nn.Conv2d( + input_dim, output_dim, kernel_size=3, padding=padding + ) + ) + self.out_layer=nn.Sequential( + norm_layer(output_dim), + nn.ReLU(), + nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1), + ) + self.conv_skip = nn.Sequential( + nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1), + norm_layer(output_dim), + ) + self.roll_out_conv=nn.Sequential( + norm_layer(output_dim), + nn.ReLU(), + RollOut_Conv(output_dim, output_dim), + ) + if self.use_attn: + self.img_sampler = MultiImage_Global_Sampler(inner_channel=output_dim, triplane_in_channels=output_dim, + img_in_channels=img_in_channels,reso=reso,vit_reso=vit_reso, + out_channels=output_dim,max_nimg=5,padding=triplane_padding) + self.down_conv=nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=stride, padding=padding) + + self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim) + self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim) + def forward(self, x,t_emb,img_feat,proj_mat,valid_frames): + t_emb = F.silu(self.map_layer0(t_emb)) + t_emb = F.silu(self.map_layer1(t_emb)) + t_emb = t_emb[:,:,None,None] + + out=self.conv_block(x)+t_emb + out=self.out_layer(out) + feature=out+self.conv_skip(x) + feature = self.roll_out_conv(feature) + if self.use_attn: + feature=self.img_sampler(feature,img_feat,proj_mat,valid_frames)+feature #skip 
connect + feature=self.down_conv(feature) + + return feature + +class ResidualConv(nn.Module): + def __init__(self, input_dim, output_dim, stride, padding, t_input_dim=256): + super(ResidualConv, self).__init__() + + self.conv_block = nn.Sequential( + nn.BatchNorm2d(input_dim), + nn.ReLU(), + nn.Conv2d( + input_dim, output_dim, kernel_size=3, stride=stride, padding=padding + ), + nn.BatchNorm2d(output_dim), + nn.ReLU(), + RollOut_Conv(output_dim,output_dim), + ) + self.out_layer=nn.Sequential( + nn.BatchNorm2d(output_dim), + nn.ReLU(), + nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1), + ) + self.conv_skip = nn.Sequential( + nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1), + nn.BatchNorm2d(output_dim), + ) + + self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim) + self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim) + def forward(self, x,t_emb): + t_emb = F.silu(self.map_layer0(t_emb)) + t_emb = F.silu(self.map_layer1(t_emb)) + t_emb = t_emb[:,:,None,None] + + out=self.conv_block(x)+t_emb + out=self.out_layer(out) + + return out + self.conv_skip(x) + +class Upsample(nn.Module): + def __init__(self, input_dim, output_dim, kernel, stride): + super(Upsample, self).__init__() + + self.upsample = nn.ConvTranspose2d( + input_dim, output_dim, kernel_size=kernel, stride=stride + ) + + def forward(self, x): + return self.upsample(x) + + + +class ResUnet_Par_cond(nn.Module): + def __init__(self, channel, filters=[64, 128, 256, 512, 1024],output_channel=32,par_channel=32): + super(ResUnet_Par_cond, self).__init__() + + self.input_layer = nn.Sequential( + nn.Conv2d(channel, filters[0], kernel_size=3, padding=1), + nn.BatchNorm2d(filters[0]), + nn.ReLU(), + nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1), + ) + self.input_skip = nn.Sequential( + nn.Conv2d(channel, filters[0], kernel_size=3, padding=1) + ) + + self.residual_conv_1 = ResidualConv(filters[0]+par_channel, filters[1], 2, 1) + self.residual_conv_2 = ResidualConv(filters[1], filters[2], 2, 1) + self.residual_conv_3 = ResidualConv(filters[2], filters[3], 2, 1) + self.bridge = ResidualConv(filters[3],filters[4],2,1) + + + self.upsample_1 = Upsample(filters[4], filters[4], 2, 2) + self.up_residual_conv1 = ResidualConv(filters[4] + filters[3], filters[3], 1, 1) + + self.upsample_2 = Upsample(filters[3], filters[3], 2, 2) + self.up_residual_conv2 = ResidualConv(filters[3] + filters[2], filters[2], 1, 1) + + self.upsample_3 = Upsample(filters[2], filters[2], 2, 2) + self.up_residual_conv3 = ResidualConv(filters[2] + filters[1], filters[1], 1, 1) + + self.upsample_4 = Upsample(filters[1], filters[1], 2, 2) + self.up_residual_conv4 = ResidualConv(filters[1] + filters[0]+par_channel, filters[0], 1, 1) + + self.output_layer = nn.Sequential( + #nn.LayerNorm(filters[0]), + nn.LayerNorm(64),#normalize along width dimension, usually it should normalize along channel dimension, + # I don't know why, but the finetuning performance increase significantly + zero_module(nn.Conv2d(filters[0], output_channel, 1, 1,bias=False)), + ) + self.par_channel=par_channel + self.par_conv=nn.Sequential( + nn.Conv2d(par_channel, par_channel, kernel_size=3, padding=1), + ) + self.t_emb_layer=PositionalEmbedding(256) + self.cat_emb=nn.Linear( + in_features=6, + out_features=256, + ) + + def forward(self, x,t,category_code,par_point_feat): + # Encode + t_emb=self.t_emb_layer(t) + cat_emb=self.cat_emb(category_code) + t_emb=t_emb+cat_emb + #print(t_emb.shape) + x1 = 
self.input_layer(x) + self.input_skip(x) + if par_point_feat is not None: + par_point_feat=self.par_conv(par_point_feat) + else: + bs,_,H,W=x1.shape + #print(x1.shape) + par_point_feat=torch.zeros((bs,self.par_channel,H,W)).float().to(x1.device) + x1 = torch.cat([x1, par_point_feat], dim=1) + x2 = self.residual_conv_1(x1,t_emb) + x3 = self.residual_conv_2(x2,t_emb) + # Bridge + x4 = self.residual_conv_3(x3,t_emb) + x5 = self.bridge(x4,t_emb) + + x6=self.upsample_1(x5) + x6=torch.cat([x6,x4],dim=1) + x7=self.up_residual_conv1(x6,t_emb) + + x7=self.upsample_2(x7) + x7=torch.cat([x7,x3],dim=1) + x8=self.up_residual_conv2(x7,t_emb) + + x8 = self.upsample_3(x8) + x8 = torch.cat([x8, x2], dim=1) + #print(x8.shape) + x9 = self.up_residual_conv3(x8,t_emb) + + x9 = self.upsample_4(x9) + x9 = torch.cat([x9, x1], dim=1) + x10 = self.up_residual_conv4(x9,t_emb) + + output=self.output_layer(x10) + + return output + +class ResUnet_DirectAttenMultiImg_Cond(nn.Module): + def __init__(self, channel, filters=[64, 128, 256, 512, 1024], + img_in_channels=1024,vit_reso=16,output_channel=32, + use_par=False,par_channel=32,triplane_padding=0.1,norm='batch', + use_cat_embedding=False, + block_type="multiview_local"): + super(ResUnet_DirectAttenMultiImg_Cond, self).__init__() + + if block_type == "multiview_local": + block=ResidualConv_MultiImgAtten + elif block_type =="multiview_global": + block=ResidualConv_GlobalAtten + elif block_type =="multiview_tri": + block=ResidualConv_TriMultiImgAtten + else: + raise NotImplementedError + + if norm=="batch": + norm_layer=nn.BatchNorm2d + elif norm==None: + norm_layer=nn.Identity + + self.use_cat_embedding=use_cat_embedding + self.input_layer = nn.Sequential( + nn.Conv2d(channel, filters[0], kernel_size=3, padding=1), + norm_layer(filters[0]), + nn.ReLU(), + nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1), + ) + self.input_skip = nn.Sequential( + nn.Conv2d(channel, filters[0], kernel_size=3, padding=1) + ) + self.use_par=use_par + input_1_channels=filters[0] + if self.use_par: + self.par_conv = nn.Sequential( + nn.Conv2d(par_channel, par_channel, kernel_size=3, padding=1), + ) + input_1_channels=filters[0]+par_channel + self.residual_conv_1 = block(input_1_channels, filters[1], 2, 1,reso=64 + ,use_attn=False,triplane_padding=triplane_padding,norm=norm) + self.residual_conv_2 = block(filters[1], filters[2], 2, 1, reso=32, + use_attn=False,triplane_padding=triplane_padding,norm=norm) + self.residual_conv_3 = block(filters[2], filters[3], 2, 1,reso=16, + use_attn=False,triplane_padding=triplane_padding,norm=norm) + self.bridge = block(filters[3] , filters[4], 2, 1, reso=8 + ,use_attn=False,triplane_padding=triplane_padding,norm=norm) #input reso is 8, output reso is 4 + + + self.upsample_1 = Upsample(filters[4], filters[4], 2, 2) + self.up_residual_conv1 = block(filters[4] + filters[3], filters[3], 1, 1,reso=8,img_in_channels=img_in_channels,vit_reso=vit_reso, + use_attn=True,triplane_padding=triplane_padding,norm=norm) + + self.upsample_2 = Upsample(filters[3], filters[3], 2, 2) + self.up_residual_conv2 = block(filters[3] + filters[2], filters[2], 1, 1,reso=16,img_in_channels=img_in_channels,vit_reso=vit_reso, + use_attn=True,triplane_padding=triplane_padding,norm=norm) + + self.upsample_3 = Upsample(filters[2], filters[2], 2, 2) + self.up_residual_conv3 = block(filters[2] + filters[1], filters[1], 1, 1,reso=32,img_in_channels=img_in_channels,vit_reso=vit_reso, + use_attn=True,triplane_padding=triplane_padding,norm=norm) + + self.upsample_4 = 
Upsample(filters[1], filters[1], 2, 2) + self.up_residual_conv4 = block(filters[1] + input_1_channels, filters[0], 1, 1, reso=64, + use_attn=False,triplane_padding=triplane_padding,norm=norm) + + self.output_layer = nn.Sequential( + nn.LayerNorm(64), #normalize along width dimension, usually it should normalize along channel dimension, + # I don't know why, but the finetuning performance increase significantly + #nn.LayerNorm([filters[0], 192, 64]), + zero_module(nn.Conv2d(filters[0], output_channel, 1, 1,bias=False)), + ) + self.t_emb_layer=PositionalEmbedding(256) + if use_cat_embedding: + self.cat_emb = nn.Linear( + in_features=6, + out_features=256, + ) + + def forward(self, x,t,image_emb,proj_mat,valid_frames,category_code,par_point_feat=None): + # Encode + t_emb=self.t_emb_layer(t) + if self.use_cat_embedding: + cat_emb=self.cat_emb(category_code) + t_emb=t_emb+cat_emb + x1 = self.input_layer(x) + self.input_skip(x) + if self.use_par: + par_point_feat=self.par_conv(par_point_feat) + x1 = torch.cat([x1, par_point_feat], dim=1) + x2 = self.residual_conv_1(x1,t_emb,image_emb,proj_mat,valid_frames) + x3 = self.residual_conv_2(x2,t_emb,image_emb,proj_mat,valid_frames) + x4 = self.residual_conv_3(x3,t_emb,image_emb,proj_mat,valid_frames) + x5 = self.bridge(x4,t_emb,image_emb,proj_mat,valid_frames) + + x6=self.upsample_1(x5) + x6=torch.cat([x6,x4],dim=1) + x7=self.up_residual_conv1(x6,t_emb,image_emb,proj_mat,valid_frames) + + x7=self.upsample_2(x7) + x7=torch.cat([x7,x3],dim=1) + x8=self.up_residual_conv2(x7,t_emb,image_emb,proj_mat,valid_frames) + + x8 = self.upsample_3(x8) + x8 = torch.cat([x8, x2], dim=1) + #print(x8.shape) + x9 = self.up_residual_conv3(x8,t_emb,image_emb,proj_mat,valid_frames) + + x9 = self.upsample_4(x9) + x9 = torch.cat([x9, x1], dim=1) + x10 = self.up_residual_conv4(x9,t_emb,image_emb,proj_mat,valid_frames) + + output=self.output_layer(x10) + + return output + + +if __name__=="__main__": + net=ResUnet(32,output_channel=32).float().cuda() + n_parameters = sum(p.numel() for p in net.parameters() if p.requires_grad) + print("Model = %s" % str(net)) + print('number of params (M): %.2f' % (n_parameters / 1.e6)) + par_point_feat=torch.randn((10,32,64*3,64)).float().cuda() + input=torch.randn((10,32,64*3,64)).float().cuda() + t=torch.randn((10,1,1,1)).float().cuda() + output=net(input,t.flatten(),par_point_feat) + #print(output.shape) \ No newline at end of file diff --git a/models/modules/unet.py b/models/modules/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..dceeae5b5a562a20952361b7ee61b80550bee0ed --- /dev/null +++ b/models/modules/unet.py @@ -0,0 +1,304 @@ +''' +Codes are from: +https://github.com/jaxony/unet-pytorch/blob/master/model.py +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from collections import OrderedDict +from torch.nn import init +import numpy as np + + +def conv3x3(in_channels, out_channels, stride=1, + padding=1, bias=True, groups=1): + return nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=padding, + bias=bias, + groups=groups) + + +def upconv2x2(in_channels, out_channels, mode='transpose'): + if mode == 'transpose': + return nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=2, + stride=2) + else: + # out_channels is always going to be the same + # as in_channels + return nn.Sequential( + nn.Upsample(mode='bilinear', scale_factor=2), + conv1x1(in_channels, out_channels)) + + +def conv1x1(in_channels, 
out_channels, groups=1): + return nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + groups=groups, + stride=1) + +class RollOut_Conv(nn.Module): + def __init__(self,in_channels,out_channels): + super(RollOut_Conv,self).__init__() + #pass + self.in_channels=in_channels + self.out_channels=out_channels + self.conv = conv3x3(self.in_channels*3, self.out_channels) + + def forward(self,row_features): + H,W=row_features.shape[2],row_features.shape[3] + H_per=H//3 + xz_feature,xy_feature,yz_feature=torch.split(row_features,dim=2,split_size_or_sections=H_per) + xy_row_pool=torch.mean(xy_feature,dim=2,keepdim=True).expand(-1,-1,H_per,-1) + yz_col_pool=torch.mean(yz_feature,dim=3,keepdim=True).expand(-1,-1,-1,W) + cat_xz_feat=torch.cat([xz_feature,xy_row_pool,yz_col_pool],dim=1) + + xz_row_pool=torch.mean(xz_feature,dim=2,keepdim=True).expand(-1,-1,H_per,-1) + zy_feature=yz_feature.transpose(2,3) #switch z y axis, for reduced confusion + zy_col_pool=torch.mean(zy_feature,dim=3,keepdim=True).expand(-1,-1,-1,W) + cat_xy_feat=torch.cat([xy_feature,xz_row_pool,zy_col_pool],dim=1) + + xz_col_pool=torch.mean(xz_feature,dim=3,keepdim=True).expand(-1,-1,-1,W) + yx_feature=xy_feature.transpose(2,3) + yx_row_pool=torch.mean(yx_feature,dim=2,keepdim=True).expand(-1,-1,H_per,-1) + cat_yz_feat=torch.cat([yz_feature,yx_row_pool,xz_col_pool],dim=1) + + fuse_row_feat=torch.cat([cat_xz_feat,cat_xy_feat,cat_yz_feat],dim=2) #concat at row dimension + + x = self.conv(fuse_row_feat) + + return x + + +class DownConv(nn.Module): + """ + A helper Module that performs 2 convolutions and 1 MaxPool. + A ReLU activation follows each convolution. + """ + + def __init__(self, in_channels, out_channels, pooling=True): + super(DownConv, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.pooling = pooling + + self.conv1 = conv3x3(self.in_channels, self.out_channels) + self.Rollout_conv=RollOut_Conv(self.out_channels,self.out_channels) + self.conv2 = conv3x3(self.out_channels, self.out_channels) + + if self.pooling: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = F.relu(self.Rollout_conv(x)) + x = F.relu(self.conv2(x)) + before_pool = x + if self.pooling: + x = self.pool(x) + return x, before_pool + + +class UpConv(nn.Module): + """ + A helper Module that performs 2 convolutions and 1 UpConvolution. + A ReLU activation follows each convolution. 
+ """ + + def __init__(self, in_channels, out_channels, + merge_mode='concat', up_mode='transpose'): + super(UpConv, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.merge_mode = merge_mode + self.up_mode = up_mode + + self.upconv = upconv2x2(self.in_channels, self.out_channels, + mode=self.up_mode) + + if self.merge_mode == 'concat': + self.conv1 = conv3x3( + 2 * self.out_channels, self.out_channels) + else: + # num of input channels to conv2 is same + self.conv1 = conv3x3(self.out_channels, self.out_channels) + self.Rollout_conv = RollOut_Conv(self.out_channels, self.out_channels) + self.conv2 = conv3x3(self.out_channels, self.out_channels) + + def forward(self, from_down, from_up): + """ Forward pass + Arguments: + from_down: tensor from the encoder pathway + from_up: upconv'd tensor from the decoder pathway + """ + from_up = self.upconv(from_up) + if self.merge_mode == 'concat': + x = torch.cat((from_up, from_down), 1) + else: + x = from_up + from_down + x = F.relu(self.conv1(x)) + x = F.relu(self.Rollout_conv(x)) + x = F.relu(self.conv2(x)) + return x + + +class UNet(nn.Module): + """ `UNet` class is based on https://arxiv.org/abs/1505.04597 + + The U-Net is a convolutional encoder-decoder neural network. + Contextual spatial information (from the decoding, + expansive pathway) about an input tensor is merged with + information representing the localization of details + (from the encoding, compressive pathway). + + Modifications to the original paper: + (1) padding is used in 3x3 convolutions to prevent loss + of border pixels + (2) merging outputs does not require cropping due to (1) + (3) residual connections can be used by specifying + UNet(merge_mode='add') + (4) if non-parametric upsampling is used in the decoder + pathway (specified by upmode='upsample'), then an + additional 1x1 2d convolution occurs after upsampling + to reduce channel dimensionality by a factor of 2. + This channel halving happens with the convolution in + the tranpose convolution (specified by upmode='transpose') + """ + + def __init__(self, num_classes, in_channels=3, depth=5, + start_filts=64, up_mode='transpose', + merge_mode='concat', **kwargs): + """ + Arguments: + in_channels: int, number of channels in the input tensor. + Default is 3 for RGB images. + depth: int, number of MaxPools in the U-Net. + start_filts: int, number of convolutional filters for the + first conv. + up_mode: string, type of upconvolution. Choices: 'transpose' + for transpose convolution or 'upsample' for nearest neighbour + upsampling. + """ + super(UNet, self).__init__() + + if up_mode in ('transpose', 'upsample'): + self.up_mode = up_mode + else: + raise ValueError("\"{}\" is not a valid mode for " + "upsampling. Only \"transpose\" and " + "\"upsample\" are allowed.".format(up_mode)) + + if merge_mode in ('concat', 'add'): + self.merge_mode = merge_mode + else: + raise ValueError("\"{}\" is not a valid mode for" + "merging up and down paths. 
" + "Only \"concat\" and " + "\"add\" are allowed.".format(up_mode)) + + # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add' + if self.up_mode == 'upsample' and self.merge_mode == 'add': + raise ValueError("up_mode \"upsample\" is incompatible " + "with merge_mode \"add\" at the moment " + "because it doesn't make sense to use " + "nearest neighbour to reduce " + "depth channels (by half).") + + self.num_classes = num_classes + self.in_channels = in_channels + self.start_filts = start_filts + self.depth = depth + + self.down_convs = [] + self.up_convs = [] + + # create the encoder pathway and add to a list + for i in range(depth): + ins = self.in_channels if i == 0 else outs + outs = self.start_filts * (2 ** i) + pooling = True if i < depth - 1 else False + + down_conv = DownConv(ins, outs, pooling=pooling) + self.down_convs.append(down_conv) + + # create the decoder pathway and add to a list + # - careful! decoding only requires depth-1 blocks + for i in range(depth - 1): + ins = outs + outs = ins // 2 + up_conv = UpConv(ins, outs, up_mode=up_mode, + merge_mode=merge_mode) + self.up_convs.append(up_conv) + + # add the list of modules to current module + self.down_convs = nn.ModuleList(self.down_convs) + self.up_convs = nn.ModuleList(self.up_convs) + self.conv_final = conv1x1(outs, self.num_classes) + + self.reset_params() + + @staticmethod + def weight_init(m): + if isinstance(m, nn.Conv2d): + init.xavier_normal_(m.weight) + init.constant_(m.bias, 0) + + def reset_params(self): + for i, m in enumerate(self.modules()): + self.weight_init(m) + + def forward(self, feature_plane): + #cat_feature=torch.cat([feature_plane['xz'],feature_plane['xy'],feature_plane,feature_plane['yz']],dim=2) #concat at row dimension + x=feature_plane + encoder_outs = [] + # encoder pathway, save outputs for merging + for i, module in enumerate(self.down_convs): + x, before_pool = module(x) + encoder_outs.append(before_pool) + for i, module in enumerate(self.up_convs): + before_pool = encoder_outs[-(i + 2)] + x = module(before_pool, x) + + # No softmax is used. This means you need to use + # nn.CrossEntropyLoss is your training script, + # as this module includes a softmax already. + x = self.conv_final(x) + return x + + +if __name__ == "__main__": + # """ + # testing + # """ + # model = UNet(1, depth=5, merge_mode='concat', in_channels=1, start_filts=32) + # print(model) + # print(sum(p.numel() for p in model.parameters())) + # + # reso = 176 + # x = np.zeros((1, 1, reso, reso)) + # x[:, :, int(reso / 2 - 1), int(reso / 2 - 1)] = np.nan + # x = torch.FloatTensor(x) + # + # out = model(x) + # print('%f' % (torch.sum(torch.isnan(out)).detach().cpu().numpy() / (reso * reso))) + # + # # loss = torch.sum(out) + # # loss.backward() + #roll_out_conv=RollOut_Conv(in_channels=32,out_channels=32).cuda().float() + model=UNet(32, depth=5, merge_mode='concat', in_channels=32, start_filts=32).cuda().float() + row_feature=torch.randn((10,32,128*3,128)).cuda().float() + output=model(row_feature) + #output_feature=roll_out_conv(row_feature) + #print(output_feature.shape) \ No newline at end of file diff --git a/models/modules/utils.py b/models/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..552813bff9f90e2508a4922a7bf1f026d3247f6c --- /dev/null +++ b/models/modules/utils.py @@ -0,0 +1,25 @@ +import torch + +def zero_module(module): + """ + Zero out the parameters of a module and return it. 
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + +class StackedRandomGenerator: + def __init__(self, device, seeds): + super().__init__() + self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds] + + def randn(self, size, **kwargs): + assert size[0] == len(self.generators) + return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators]) + + def randn_like(self, input): + return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device) + + def randint(self, *args, size, **kwargs): + assert size[0] == len(self.generators) + return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators]) diff --git a/output/put_checkpoints_here b/output/put_checkpoints_here new file mode 100644 index 0000000000000000000000000000000000000000..56a6051ca2b02b04ef92d5150c9ef600403cb1de --- /dev/null +++ b/output/put_checkpoints_here @@ -0,0 +1 @@ +1 \ No newline at end of file diff --git a/process_scripts/augment_arkit_partial_point.py b/process_scripts/augment_arkit_partial_point.py new file mode 100644 index 0000000000000000000000000000000000000000..5d3f40f67fd93b13bf3bd9f9e1b38404375850d6 --- /dev/null +++ b/process_scripts/augment_arkit_partial_point.py @@ -0,0 +1,64 @@ +import numpy as np +import scipy +import os +import trimesh +from sklearn.cluster import KMeans +import random +import glob +import tqdm +import argparse +import multiprocessing as mp +import sys +sys.path.append("..") +from datasets.taxonomy import arkit_category + +parser=argparse.ArgumentParser() +parser.add_argument('--category',nargs="+",type=str) +parser.add_argument("--keyword",type=str,default="lowres") #augment only the low resolution points +parser.add_argument("--data_root",type=str,default="../data/other_data") +args=parser.parse_args() +category=args.category +if category[0]=="all": + category=arkit_category["all"] +kmeans=KMeans( + init="random", + n_clusters=20, + n_init=10, + max_iter=300, + random_state=42 +) + +def process_data(src_point_path,save_folder,keyword): + src_point_tri = trimesh.load(src_point_path) + src_point = np.asarray(src_point_tri.vertices) + kmeans.fit(src_point) + point_cluster_index = kmeans.labels_ + + '''choose 10~19 clusters to form the augmented new point''' + for i in range(10): + n_cluster = random.randint(14, 19) # 14,19 for lowres, 10,19 for highres + choose_cluster = np.random.choice(20, n_cluster, replace=False) + aug_point_list = [] + for cluster_index in choose_cluster: + cluster_point = src_point[point_cluster_index == cluster_index] + aug_point_list.append(cluster_point) + aug_point = np.concatenate(aug_point_list, axis=0) + save_path = os.path.join(save_folder, "%s_partial_points_%d.ply" % (keyword, i + 1)) + print("saving to %s"%(save_path)) + aug_point_tri = trimesh.PointCloud(vertices=aug_point) + aug_point_tri.export(save_path) + +pool=mp.Pool(10) +for cat in category[0:]: + keyword=args.keyword + point_dir = os.path.join(args.data_root,cat,"5_partial_points") + folder_list=os.listdir(point_dir) + for folder in tqdm.tqdm(folder_list[0:]): + folder_path=os.path.join(point_dir,folder) + src_point_path=os.path.join(point_dir,folder,"%s_partial_points_0.ply"%(keyword)) + if os.path.exists(src_point_path)==False: + continue + save_folder=folder_path + pool.apply_async(process_data,(src_point_path,save_folder,keyword)) +pool.close() +pool.join() \ No newline at end of file diff --git 
a/process_scripts/augment_synthetic_partial_points.py b/process_scripts/augment_synthetic_partial_points.py new file mode 100644 index 0000000000000000000000000000000000000000..a10f9005786729be88786f8981601db868a67e10 --- /dev/null +++ b/process_scripts/augment_synthetic_partial_points.py @@ -0,0 +1,64 @@ +import numpy as np +import scipy +import os +import trimesh +from sklearn.cluster import KMeans +import random +import glob +import tqdm +import multiprocessing as mp +import sys +sys.path.append("..") +from datasets.taxonomy import synthetic_category_combined + +import argparse +parser=argparse.ArgumentParser() +parser.add_argument("--category",nargs="+",type=str) +parser.add_argument("--root_dir",type=str,default="../data/other_data") +args=parser.parse_args() +categories=args.category +if categories[0]=="all": + categories=synthetic_category_combined["all"] + +kmeans=KMeans( + init="random", + n_clusters=7, + n_init=10, + max_iter=300, + random_state=42 +) + +def process_data(src_filepath,save_path): + #print("processing %s"%(src_filepath)) + src_point_tri = trimesh.load(src_filepath) + src_point = np.asarray(src_point_tri.vertices) + kmeans.fit(src_point) + point_cluster_index = kmeans.labels_ + + n_cluster = random.randint(3, 6) + choose_cluster = np.random.choice(7, n_cluster, replace=False) + aug_point_list = [] + for cluster_index in choose_cluster: + cluster_point = src_point[point_cluster_index == cluster_index] + aug_point_list.append(cluster_point) + aug_point = np.concatenate(aug_point_list, axis=0) + aug_point_tri = trimesh.PointCloud(vertices=aug_point) + print("saving to %s"%(save_path)) + aug_point_tri.export(save_path) + +pool=mp.Pool(10) +for cat in categories: + print("processing %s"%cat) + point_dir=os.path.join(args.root_dir,cat,"5_partial_points") + folder_list=os.listdir(point_dir) + for folder in folder_list[:]: + folder_path=os.path.join(point_dir,folder) + src_filelist=glob.glob(folder_path+"/partial_points_*.ply") + for src_filepath in src_filelist: + basename=os.path.basename(src_filepath) + save_path = os.path.join(point_dir, folder, "aug7_" + basename) + pool.apply_async(process_data,(src_filepath,save_path)) +pool.close() +pool.join() + + diff --git a/process_scripts/dist_export_triplane_features.sh b/process_scripts/dist_export_triplane_features.sh new file mode 100644 index 0000000000000000000000000000000000000000..1609d4d76bfa9363472af63c9d27f5533c3cb3d2 --- /dev/null +++ b/process_scripts/dist_export_triplane_features.sh @@ -0,0 +1,8 @@ +CUDA_VISIBLE_DEVICES='0,1,2,3' torchrun --master_port 15002 --nproc_per_node=2 \ +export_triplane_features.py \ +--configs ../configs/train_triplane_vae.yaml \ +--batch_size 10 \ +--ae-pth ../output/ae/chair/best-checkpoint.pth \ +--data-pth ../data \ +--category arkit_chair 03001627 future_chair arkit_stool future_stool ABO_chair + #sub category \ No newline at end of file diff --git a/process_scripts/dist_extract_vit.sh b/process_scripts/dist_extract_vit.sh new file mode 100644 index 0000000000000000000000000000000000000000..99e6d177bc6b10c321579e8b4f78897f08d58049 --- /dev/null +++ b/process_scripts/dist_extract_vit.sh @@ -0,0 +1,6 @@ +CUDA_VISIBLE_DEVICES='0,1,2,3' torchrun --master_port 15000 --nproc_per_node=2 \ +extract_img_vit_features.py \ +--batch_size 24 \ +--ckpt_path ../data/open_clip_pytorch_model.bin \ +--category arkit_chair 03001627 future_chair arkit_stool future_stool ABO_chair #sub category +#--category 02871439 future_shelf ABO_shelf arkit_shelf \ diff --git 
a/process_scripts/export_triplane_features.py b/process_scripts/export_triplane_features.py new file mode 100644 index 0000000000000000000000000000000000000000..e0e4599001a90d8f0ab7ca409c40005e12b4375a --- /dev/null +++ b/process_scripts/export_triplane_features.py @@ -0,0 +1,122 @@ +import argparse +import math +import sys +sys.path.append("..") +import numpy as np +import os +import torch + +import trimesh + +from datasets import Object_Occ,Scale_Shift_Rotate +from models import get_model +from pathlib import Path +import open3d as o3d +from configs.config_utils import CONFIG +import tqdm +from util import misc +from datasets.taxonomy import synthetic_arkit_category_combined + +if __name__ == "__main__": + + parser = argparse.ArgumentParser('', add_help=False) + parser.add_argument('--configs',type=str,required=True) + parser.add_argument('--ae-pth',type=str) + parser.add_argument("--category",nargs='+', type=str) + parser.add_argument('--world_size', default=1, type=int, + help='number of distributed processes') + parser.add_argument('--local_rank', default=-1, type=int) + parser.add_argument('--dist_on_itp', action='store_true') + parser.add_argument('--dist_url', default='env://', + help='url used to set up distributed training') + parser.add_argument('--device', default='cuda', + help='device to use for training / testing') + parser.add_argument("--batch_size", default=1, type=int) + parser.add_argument("--data-pth",default="../data",type=str) + + args = parser.parse_args() + misc.init_distributed_mode(args) + device = torch.device(args.device) + + config_path=args.configs + config=CONFIG(config_path) + dataset_config=config.config['dataset'] + dataset_config['data_path']=args.data_pth + #transform = AxisScaling((0.75, 1.25), True) + transform=Scale_Shift_Rotate(rot_shift_surface=True,use_scale=True) + if len(args.category)==1 and args.category[0]=="all": + category=synthetic_arkit_category_combined["all"] + else: + category=args.category + train_dataset = Object_Occ(dataset_config['data_path'], split="train", + categories=category, + transform=transform, sampling=True, + num_samples=1024, return_surface=True, + surface_sampling=True, surface_size=dataset_config['surface_size'],replica=1) + val_dataset = Object_Occ(dataset_config['data_path'], split="val", + categories=category, + transform=transform, sampling=True, + num_samples=1024, return_surface=True, + surface_sampling=True, surface_size=dataset_config['surface_size'],replica=1) + num_tasks = misc.get_world_size() + global_rank = misc.get_rank() + train_sampler = torch.utils.data.DistributedSampler( + train_dataset, num_replicas=num_tasks, rank=global_rank, + shuffle=False) # shuffle=True to reduce monitor bias + val_sampler=torch.utils.data.DistributedSampler( + val_dataset, num_replicas=num_tasks, rank=global_rank, + shuffle=False) # shu + #dataset=val_dataset + batch_size=args.batch_size + train_dataloader=torch.utils.data.DataLoader( + train_dataset,sampler=train_sampler, + batch_size=batch_size, + num_workers=10, + shuffle=False, + drop_last=False, + ) + val_dataloader = torch.utils.data.DataLoader( + val_dataset, sampler=val_sampler, + batch_size=batch_size, + num_workers=10, + shuffle=False, + drop_last=False, + ) + dataloader_list=[train_dataloader,val_dataloader] + #dataloader_list=[val_dataloader] + output_dir=os.path.join(dataset_config['data_path'],"other_data") + #output_dir="/data1/haolin/datasets/ShapeNetV2_watertight" + + model_config=config.config['model'] + model=get_model(model_config) + 
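+    # load the pretrained VAE checkpoint given by --ae-pth and keep it frozen;
+    # the loop below only calls model.encode() to cache per-object triplane
+    # latents (mean/logvar plus the per-sample transform matrix) as .npz files.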
model.load_state_dict(torch.load(args.ae_pth)['model']) + model.eval().float().to(device) + #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) + + with torch.no_grad(): + for e in range(5): + for dataloader in dataloader_list: + for data_iter_step, data_batch in tqdm.tqdm(enumerate(dataloader)): + surface = data_batch['surface'].to(device, non_blocking=True) + model_ids=data_batch['model_id'] + tran_mats=data_batch['tran_mat'] + categories=data_batch['category'] + with torch.no_grad(): + plane_feat,_,means,logvars=model.encode(surface) + plane_feat=torch.nn.functional.interpolate(plane_feat,scale_factor=0.5,mode='bilinear') + vars=torch.exp(logvars) + means=torch.nn.functional.interpolate(means,scale_factor=0.5,mode="bilinear") + vars=torch.nn.functional.interpolate(vars,scale_factor=0.5,mode="bilinear")/4 + sample_logvars=torch.log(vars) + + for j in range(means.shape[0]): + #plane_dist=plane_feat[j].float().cpu().numpy() + mean=means[j].float().cpu().numpy() + logvar=sample_logvars[j].float().cpu().numpy() + tran_mat=tran_mats[j].float().cpu().numpy() + + output_folder=os.path.join(output_dir,categories[j],'9_triplane_kl25_64',model_ids[j]) + Path(output_folder).mkdir(parents=True, exist_ok=True) + exist_len=len(os.listdir(output_folder)) + save_filepath=os.path.join(output_folder,"triplane_feat_%d.npz"%(exist_len)) + np.savez_compressed(save_filepath,mean=mean,logvar=logvar,tran_mat=tran_mat) diff --git a/process_scripts/extract_img_vit_features.py b/process_scripts/extract_img_vit_features.py new file mode 100644 index 0000000000000000000000000000000000000000..17d3b0fc89e31e2c84305afdbc56178d88ca84de --- /dev/null +++ b/process_scripts/extract_img_vit_features.py @@ -0,0 +1,73 @@ +import os,sys +sys.path.append("..") +from util.simple_image_loader import Image_dataset +from torch.utils.data import DataLoader +import timm +import torch +from tqdm import tqdm +import numpy as np +from transformers import DPTForDepthEstimation, DPTFeatureExtractor +import argparse +from util import misc +from datasets.taxonomy import synthetic_arkit_category_combined +parser=argparse.ArgumentParser() + +parser.add_argument("--category",nargs="+",type=str) +parser.add_argument("--root_dir",type=str, default="../data") +parser.add_argument("--ckpt_path",type=str,default="../open_clip_pytorch_model.bin") +parser.add_argument("--batch_size",type=int,default=24) +parser.add_argument('--world_size', default=1, type=int, + help='number of distributed processes') +parser.add_argument('--local_rank', default=-1, type=int) +parser.add_argument('--dist_on_itp', action='store_true') +parser.add_argument('--dist_url', default='env://', + help='url used to set up distributed training') +args= parser.parse_args() +misc.init_distributed_mode(args) +category=args.category + +#dataset=Image_dataset(categories=['03001627','ABO_chair','future_chair']) +if args.category[0]=="all": + category=synthetic_arkit_category_combined["all"] +print("loading dataset") +dataset=Image_dataset(dataset_folder=args.root_dir,categories=category,n_px=224) +num_tasks = misc.get_world_size() +global_rank = misc.get_rank() +sampler = torch.utils.data.DistributedSampler( + dataset, num_replicas=num_tasks, rank=global_rank, + shuffle=False) # shuffle=True to reduce monitor bias + +dataloader=DataLoader( + dataset, + sampler=sampler, + batch_size=args.batch_size, + num_workers=4, + pin_memory=True, + drop_last=False +) +print("loading model") +VIT_MODEL = 'vit_huge_patch14_224_clip_laion2b' 
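+# instantiate the CLIP ViT-H/14 backbone through timm; pretrained_cfg_overlay
+# points at the local open_clip_pytorch_model.bin passed via --ckpt_path, so the
+# weights are read from disk instead of being downloaded.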
+model=timm.create_model(VIT_MODEL, pretrained=True,pretrained_cfg_overlay=dict(file=args.ckpt_path)) +model=model.eval().float().cuda() +save_dir=os.path.join(args.root_dir,"other_data") +for idx,data_batch in enumerate(dataloader): + if idx%50==0: + print("{}/{}".format(dataloader.__len__(),idx)) + images = data_batch["images"].cuda().float() + model_id= data_batch["model_id"] + image_name=data_batch["image_name"] + category=data_batch["category"] + with torch.no_grad(): + #output=model(images,output_hidden_states=True) + output_features=model.forward_features(images) + #predict_depth=output.predicted_depth + #print(predict_depth.shape) + for j in range(output_features.shape[0]): + save_folder=os.path.join(save_dir,category[j],"7_img_features",model_id[j]) + os.makedirs(save_folder,exist_ok=True) + save_path=os.path.join(save_folder,image_name[j]+".npz") + #print("saving to",save_path) + np.savez_compressed(save_path,img_features=output_features[j].detach().cpu().numpy().astype(np.float32)) + + + diff --git a/process_scripts/generate_split_for_arkit.py b/process_scripts/generate_split_for_arkit.py new file mode 100644 index 0000000000000000000000000000000000000000..02b156392c63c8edd780c6eb124c0c5070483ba1 --- /dev/null +++ b/process_scripts/generate_split_for_arkit.py @@ -0,0 +1,102 @@ +import os +import numpy as np +import glob +import open3d as o3d +import json +import argparse +import glob + +parser=argparse.ArgumentParser() +parser.add_argument("--cat",required=True,type=str,nargs="+") +parser.add_argument("--keyword",default="lowres",type=str) +parser.add_argument("--root_dir",type=str,default="../data") +args=parser.parse_args() + +keyword=args.keyword +sdf_folder="occ_data" +other_folder="other_data" +data_dir=args.root_dir + +align_dir=os.path.join(args.root_dir,"align_mat_all") # this alignment matrix is aligned from highres scan to lowres scan +# the alignment matrix is still under cleaning, not all the data have proper alignment matrix yet. 
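+# keep only the models whose "-v" alignment file exists and parses to a full
+# 4x4 matrix; all other model ids are skipped when the split json files are
+# built below.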
+align_filelist=glob.glob(align_dir+"/*/*.txt") +valid_model_list=[] +for align_filepath in align_filelist: + if "-v" in align_filepath: + align_mat=np.loadtxt(align_filepath) + if align_mat.shape[0]!=4: + continue + model_id=os.path.basename(align_filepath).split("-")[0] + valid_model_list.append(model_id) + +print("there are %d valid lowres models"%(len(valid_model_list))) + +category_list=args.cat +for category in category_list: + train_path=os.path.join(data_dir,sdf_folder,category,"train.lst") + with open(train_path,'r') as f: + train_list=f.readlines() + train_list=[item.rstrip() for item in train_list] + if ".npz" in train_list[0]: + train_list=[item[:-4] for item in train_list] + val_path=os.path.join(data_dir,sdf_folder,category,"val.lst") + with open(val_path,'r') as f: + val_list=f.readlines() + val_list=[item.rstrip() for item in val_list] + if ".npz" in val_list[0]: + val_list=[item[:-4] for item in val_list] + + + sdf_dir=os.path.join(data_dir,sdf_folder,category) + filelist=os.listdir(sdf_dir) + model_id_list=[item[:-4] for item in filelist if ".npz" in item] + + train_par_img_list=[] + val_par_img_list=[] + for model_id in model_id_list: + if model_id not in valid_model_list: + continue + image_dir=os.path.join(data_dir,other_folder,category,"6_images",model_id) + partial_dir=os.path.join(data_dir,other_folder,category,"5_partial_points",model_id) + if os.path.exists(image_dir)==False and os.path.exists(partial_dir)==False: + continue + if os.path.exists(image_dir): + image_list=glob.glob(image_dir+"/*.jpg")+glob.glob(image_dir+"/*.png") + image_list=[os.path.basename(image_path) for image_path in image_list] + else: + image_list=[] + + if os.path.exists(partial_dir): + partial_list=glob.glob(partial_dir+"/%s_partial_points_*.ply"%(keyword)) + else: + partial_list=[] + partial_valid_list=[] + for partial_filepath in partial_list: + par_o3d=o3d.io.read_point_cloud(partial_filepath) + par_xyz=np.asarray(par_o3d.points) + if par_xyz.shape[0]>2048: + partial_valid_list.append(os.path.basename(partial_filepath)) + if model_id in val_list: + if "%s_partial_points_0.ply"%(keyword) in partial_valid_list: + partial_valid_list=["%s_partial_points_0.ply"%(keyword)] + else: + partial_valid_list=[] + if len(image_list)==0 and len(partial_valid_list)==0: + continue + ret_dict={ + "model_id":model_id, + "image_filenames":image_list[:], + "partial_filenames":partial_valid_list[:] + } + if model_id in train_list: + train_par_img_list.append(ret_dict) + elif model_id in val_list: + val_par_img_list.append(ret_dict) + + train_save_path=os.path.join(sdf_dir,"%s_train_par_img.json"%(keyword)) + with open(train_save_path,'w') as f: + json.dump(train_par_img_list,f,indent=4) + + val_save_path=os.path.join(sdf_dir,"%s_val_par_img.json"%(keyword)) + with open(val_save_path,'w') as f: + json.dump(val_par_img_list,f,indent=4) diff --git a/process_scripts/generate_split_for_synthetic_data.py b/process_scripts/generate_split_for_synthetic_data.py new file mode 100644 index 0000000000000000000000000000000000000000..6d50377c7ad4db036eadf6d12fef0d1f771f0e2a --- /dev/null +++ b/process_scripts/generate_split_for_synthetic_data.py @@ -0,0 +1,78 @@ +import os,sys +import numpy as np +import glob +import open3d as o3d +import json +import argparse + +parser=argparse.ArgumentParser() +parser.add_argument("--cat",required=True,type=str,nargs="+") +parser.add_argument("--root_dir",type=str,default="../data") +args=parser.parse_args() + +sdf_folder="occ_data" +other_folder="other_folder" +data_dir=args.root_dir 
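+# for each synthetic sub-category, pair every model id with its images under
+# 6_images and with the partial point clouds containing more than 2048 points,
+# then write the result to train_par_img.json / val_par_img.json next to the
+# occupancy data.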
+category=args.cat +train_path=os.path.join(data_dir,sdf_folder,category,"train.lst") +with open(train_path,'r') as f: + train_list=f.readlines() + train_list=[item.rstrip() for item in train_list] + if ".npz" in train_list[0]: + train_list=[item[:-4] for item in train_list] +val_path=os.path.join(data_dir,sdf_folder,category,"val.lst") +with open(val_path,'r') as f: + val_list=f.readlines() + val_list=[item.rstrip() for item in val_list] + if ".npz" in val_list[0]: + val_list=[item[:-4] for item in val_list] + +category_list=args.cat +for category in category_list: + sdf_dir=os.path.join(data_dir,sdf_folder,category) + filelist=os.listdir(sdf_dir) + model_id_list=[item[:-4] for item in filelist if ".npz" in item] + + train_par_img_list=[] + val_par_img_list=[] + for model_id in model_id_list: + image_dir=os.path.join(data_dir,other_folder,category,"6_images",model_id) + partial_dir=os.path.join(data_dir,other_folder,category,"5_partial_points",model_id) + if os.path.exists(image_dir)==False and os.path.exists(partial_dir)==False: + continue + if os.path.exists(image_dir): + image_list=glob.glob(image_dir+"/*.jpg")+glob.glob(image_dir+"/*.png") + image_list=[os.path.basename(image_path) for image_path in image_list] + else: + image_list=[] + + if os.path.exists(partial_dir): + partial_list=glob.glob(partial_dir+"/partial_points_*.ply") + else: + partial_list=[] + partial_valid_list=[] + for partial_filepath in partial_list: + par_o3d=o3d.io.read_point_cloud(partial_filepath) + par_xyz=np.asarray(par_o3d.points) + if par_xyz.shape[0]>2048: + partial_valid_list.append(os.path.basename(partial_filepath)) + if len(image_list)==0 and len(partial_valid_list)==0: + continue + ret_dict={ + "model_id":model_id, + "image_filenames":image_list[:], + "partial_filenames":partial_valid_list[:] + } + if model_id in train_list: + train_par_img_list.append(ret_dict) + elif model_id in val_list: + val_par_img_list.append(ret_dict) + + #print(train_par_img_list) + train_save_path=os.path.join(sdf_dir,"train_par_img.json") + with open(train_save_path,'w') as f: + json.dump(train_par_img_list,f,indent=4) + + val_save_path=os.path.join(sdf_dir,"val_par_img.json") + with open(val_save_path,'w') as f: + json.dump(val_par_img_list,f,indent=4) diff --git a/process_scripts/unzip_all_data.py b/process_scripts/unzip_all_data.py new file mode 100644 index 0000000000000000000000000000000000000000..52b8dc3cef87d5a184724b57d906ec45a82fc9cf --- /dev/null +++ b/process_scripts/unzip_all_data.py @@ -0,0 +1,38 @@ +import os +import glob +import argparse + +parser = argparse.ArgumentParser("unzip the prepared data") +parser.add_argument("--occ_root", type=str, default="../data/occ_data") +parser.add_argument("--other_root", type=str,default="../data/other_data") +parser.add_argument("--unzip_occ",default=False,action="store_true") +parser.add_argument("--unzip_other",default=False,action="store_true") + +args=parser.parse_args() +if args.unzip_occ: + filelist=os.listdir(args.occ_root) + for filename in filelist: + filepath=os.path.join(args.occ_root,filename) + if ".rar" in filename: + unrar_command="unrar x %s %s"%(filepath,args.occ_root) + os.system(unrar_command) + elif ".zip" in filename: + unzip_command="7z x %s -o%s"%(filepath,args.occ_root) + os.system(unzip_command) + + +if args.unzip_other: + category_list=os.listdir(args.other_root) + for category in category_list: + category_folder=os.path.join(args.other_root,category) + #print(category_folder) + rar_filelist=glob.glob(category_folder+"/*.rar") + 
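+        # .zip archives are collected as well; extraction below shells out to the
+        # external unrar and 7z binaries, so both need to be available on PATH.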
zip_filelist=glob.glob(category_folder+"/*.zip") + + for rar_filepath in rar_filelist: + unrar_command="unrar x %s %s"%(rar_filepath,category_folder) + os.system(unrar_command) + for zip_filepath in zip_filelist: + unzip_command="7z x %s -o%s"%(zip_filepath,category_folder) + os.system(unzip_command) + diff --git a/scripts/train_triplane_diffusion.py b/scripts/train_triplane_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..4bfe1cc1d97639ff2cc6deecdf31762750e02756 --- /dev/null +++ b/scripts/train_triplane_diffusion.py @@ -0,0 +1,317 @@ +import argparse +import sys +sys.path.append("..") +import datetime +import json +import numpy as np +import os +import time +from pathlib import Path + +import torch +import torch.backends.cudnn as cudnn +from torch.utils.tensorboard import SummaryWriter + +import util.misc as misc +from datasets import build_dataset +from util.misc import NativeScalerWithGradNormCount as NativeScaler +from models import get_model,get_criterion + +from engine.engine_triplane_dm import train_one_epoch,evaluate_reconstruction + +def get_args_parser(): + parser = argparse.ArgumentParser('Latent Diffusion', add_help=False) + parser.add_argument('--batch_size', default=64, type=int, + help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') + parser.add_argument('--epochs', default=800, type=int) + parser.add_argument('--accum_iter', default=1, type=int, + help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') + + parser.add_argument('--ae-pth',type=str) + # Optimizer parameters + parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', + help='Clip gradient norm (default: None, no clipping)') + parser.add_argument('--weight_decay', type=float, default=0.05, + help='weight decay (default: 0.05)') + + parser.add_argument('--lr', type=float, default=None, metavar='LR', + help='learning rate (absolute lr)') + parser.add_argument('--blr', type=float, default=1e-4, metavar='LR', # 2e-4 + help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') + parser.add_argument('--layer_decay', type=float, default=0.75, + help='layer-wise lr decay from ELECTRA/BEiT') + + parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', + help='lower lr bound for cyclic schedulers that hit 0') + + parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', + help='epochs to warmup LR') + # Dataset parameters + parser.add_argument('--data-pth', default='../data', type=str, + help='dataset path') + + parser.add_argument('--output_dir', default='./output/', + help='path where to save, empty for no saving') + parser.add_argument('--log_dir', default='./output/', + help='path where to tensorboard log') + parser.add_argument('--device', default='cuda', + help='device to use for training / testing') + parser.add_argument('--seed', default=0, type=int) + parser.add_argument('--resume', default='', + help='resume from checkpoint') + + parser.add_argument('--start_epoch', default=0, type=int, metavar='N', + help='start epoch') + parser.add_argument('--eval', action='store_true', + help='Perform evaluation only') + parser.add_argument('--dist_eval', action='store_true', default=False, + help='Enabling distributed evaluation (recommended during training for faster monitor') + parser.add_argument('--num_workers', default=60, type=int) + parser.add_argument('--pin_mem', action='store_true', + help='Pin CPU memory in DataLoader for more 
efficient (sometimes) transfer to GPU.') + parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') + parser.set_defaults(pin_mem=True) + parser.add_argument('--constant_lr', default=False, action='store_true') + + # distributed training parameters + parser.add_argument('--world_size', default=1, type=int, + help='number of distributed processes') + parser.add_argument('--local_rank', default=-1, type=int) + parser.add_argument('--dist_on_itp', action='store_true') + parser.add_argument('--dist_url', default='env://', + help='url used to set up distributed training') + parser.add_argument('--load_proj_mat',default=True,type=bool) + parser.add_argument('--num_objects',type=int,default=-1) + + parser.add_argument('--configs', type=str) + parser.add_argument('--finetune', default=False, action="store_true") + parser.add_argument('--finetune-pth', type=str) + parser.add_argument('--use_cls_free',action="store_true",default=False) + parser.add_argument('--sync_bn',action="store_true",default=False) + parser.add_argument('--category',type=str) + parser.add_argument('--stop',type=int,default=1000) + parser.add_argument('--replica', type=int, default=5) + + return parser + + +def main(args,config): + misc.init_distributed_mode(args) + + print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) + print("{}".format(args).replace(', ', ',\n')) + + device = torch.device(args.device) + + # fix the seed for reproducibility + seed = args.seed + misc.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + + cudnn.benchmark = True + + dataset_config = config.config['dataset'] + dataset_config['category']=args.category + dataset_config['replica']=args.replica + dataset_config['num_objects']=args.num_objects + dataset_config['data_path']=args.data_pth + dataset_train = build_dataset('train', dataset_config) + print("training dataset len is %d"%(len(dataset_train))) + dataset_val=build_dataset('val', dataset_config) + #dataset_val = build_dataset('val', dataset_config) + + if True: # args.distributed: + num_tasks = misc.get_world_size() + global_rank = misc.get_rank() + sampler_train = torch.utils.data.DistributedSampler( + dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + print("Sampler_train = %s" % str(sampler_train)) + if args.dist_eval: + if len(dataset_val) % num_tasks != 0: + print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. 
' + 'This will slightly alter validation results as extra duplicate entries are added to achieve ' + 'equal num of samples per-process.') + sampler_val = torch.utils.data.DistributedSampler( + dataset_val, num_replicas=num_tasks, rank=global_rank, + shuffle=True) # shuffle=True to reduce monitor bias + else: + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + else: + sampler_train = torch.utils.data.RandomSampler(dataset_train) + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + + if global_rank == 0 and args.log_dir is not None and not args.eval: + os.makedirs(args.log_dir, exist_ok=True) + log_writer = SummaryWriter(log_dir=args.log_dir) + else: + log_writer = None + + if misc.get_rank()==0: + log_dir=args.log_dir + src_folder="/data1/haolin/TriplaneDiffusion" + misc.log_codefiles(src_folder,log_dir+"/code_bak") + #cmd="cp -r %s %s"%(src_folder,log_dir+"/code_bak") + #print(cmd) + #os.system(cmd) + config_dict=vars(args) + config_save_path=os.path.join(log_dir,"config.json") + with open(config_save_path,'w') as f: + json.dump(config_dict,f,indent=4) + + model_dict=config + model_config_save_path=os.path.join(log_dir,"model.json") + config.write_config(model_config_save_path) + + data_loader_train = torch.utils.data.DataLoader( + dataset_train, sampler=sampler_train, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=True, + prefetch_factor=2, + ) + + data_loader_val = torch.utils.data.DataLoader( + dataset_val, sampler=sampler_val, + # batch_size=args.batch_size, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=False + ) + + ae_args=config.config['model']['ae'] + ae = get_model(ae_args) + ae.eval() + print("Loading autoencoder %s" % args.ae_pth) + ae.load_state_dict(torch.load(args.ae_pth, map_location='cpu')['model']) + ae.to(device) + + dm_args=config.config['model']['dm'] + if args.category[0] == "all": + dm_args["use_cat_embedding"]=True + else: + dm_args["use_cat_embedding"] = False + dm_model = get_model(dm_args) + if args.sync_bn: + dm_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(dm_model) + if args.finetune: + print("finetune the model, load from %s"%(args.finetune_pth)) + dm_model.load_state_dict(torch.load(args.finetune_pth,map_location="cpu")['model']) + dm_model.to(device) + + model_without_ddp = dm_model + n_parameters = sum(p.numel() for p in dm_model.parameters() if p.requires_grad) + + print("Model = %s" % str(model_without_ddp)) + print('number of params (M): %.2f' % (n_parameters / 1.e6)) + + eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() + + if args.lr is None: # only base_lr is specified + args.lr = args.blr * eff_batch_size / 256 + + print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) + print("actual lr: %.2e" % args.lr) + + print("accumulate grad iterations: %d" % args.accum_iter) + print("effective batch size: %d" % eff_batch_size) + + if args.distributed: + dm_model = torch.nn.parallel.DistributedDataParallel(dm_model, device_ids=[args.gpu], find_unused_parameters=False) + model_without_ddp = dm_model.module + + # # build optimizer with layer-wise lr decay (lrd) + # param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay, + # no_weight_decay_list=model_without_ddp.no_weight_decay(), + # layer_decay=args.layer_decay + # ) + optimizer = torch.optim.AdamW(model_without_ddp.parameters(), lr=args.lr) + loss_scaler = NativeScaler() + + cri_args=config.config['criterion'] + criterion 
= get_criterion(cri_args) + + print("criterion = %s" % str(criterion)) + + misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) + + if args.eval: + test_stats = evaluate(data_loader_val, dm_model, device) + print(f"loss of the network on the {len(dataset_val)} test images: {test_stats['loss']:.3f}") + exit(0) + + print(f"Start training for {args.epochs} epochs") + start_time = time.time() + min_loss = 1000.0 + max_iou=0 + + stop_epochs=min(args.stop,args.epochs) + for epoch in range(args.start_epoch, stop_epochs): + if args.distributed: + data_loader_train.sampler.set_epoch(epoch) + #test_stats = evaluate_reconstruction(data_loader_val, dm_model, ae, criterion, device) + train_stats = train_one_epoch( + dm_model, ae, criterion, data_loader_train, + optimizer, device, epoch, loss_scaler, + args.clip_grad, + log_writer=log_writer, + log_dir=args.log_dir, + args=args + ) + if args.output_dir and (epoch % 5 == 0 or epoch + 1 == args.epochs): + misc.save_model( + args=args, model=dm_model, model_without_ddp=model_without_ddp, optimizer=optimizer, + loss_scaler=loss_scaler, epoch=epoch,prefix="latest") + + if epoch % 5 == 0 or epoch + 1 == args.epochs: + test_stats = evaluate_reconstruction(data_loader_val, dm_model, ae, criterion, device) + print(f"iou of the network on the {len(dataset_val)} test images: {test_stats['iou']:.3f}") + # print(f"loss of the network on the {len(dataset_val)} test images: {test_stats['loss']:.3f}") + + if test_stats["iou"] > max_iou: + max_iou = test_stats["iou"] + misc.save_model( + args=args, model=dm_model, model_without_ddp=model_without_ddp, optimizer=optimizer, + loss_scaler=loss_scaler, epoch=epoch, prefix='best') + else: + misc.save_model( + args=args, model=dm_model, model_without_ddp=model_without_ddp, optimizer=optimizer, + loss_scaler=loss_scaler, epoch=epoch, prefix='latest') + + if log_writer is not None: + log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch) + log_writer.add_scalar('perf/test_iou', test_stats['iou'], epoch) + log_writer.add_scalar('perf/test_accuracy', test_stats['accuracy'], epoch) + + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + **{f'test_{k}': v for k, v in test_stats.items()}, + 'epoch': epoch, + 'n_parameters': n_parameters} + else: + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + 'epoch': epoch, + 'n_parameters': n_parameters} + + if args.output_dir and misc.is_main_process(): + if log_writer is not None: + log_writer.flush() + with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: + f.write(json.dumps(log_stats) + "\n") + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + +if __name__ == '__main__': + args = get_args_parser() + args = args.parse_args() + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + config_path = args.configs + from configs.config_utils import CONFIG + + config = CONFIG(config_path) + main(args,config) diff --git a/scripts/train_triplane_vae.py b/scripts/train_triplane_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..8634b75b95f49f390f123d8b8c8d90d64f463ffe --- /dev/null +++ b/scripts/train_triplane_vae.py @@ -0,0 +1,287 @@ +import argparse +import datetime +import json +import numpy as np +import os,sys +sys.path.append("..") +# os.system("taskset -p 0xff %d"%(os.getpid())) +import time +from 
pathlib import Path + +import torch +import torch.backends.cudnn as cudnn +from torch.utils.tensorboard import SummaryWriter + +torch.set_num_threads(8) + +import util.misc as misc +from datasets import build_dataset +from util.misc import NativeScalerWithGradNormCount as NativeScaler +from models import get_model + +from engine.engine_triplane_vae import train_one_epoch, evaluate + + +def get_args_parser(): + parser = argparse.ArgumentParser('Autoencoder', add_help=False) + parser.add_argument('--batch_size', default=64, type=int, + help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') + parser.add_argument('--epochs', default=800, type=int) + parser.add_argument('--accum_iter', default=1, type=int, + help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') + + # Optimizer parameters + parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', + help='Clip gradient norm (default: None, no clipping)') + parser.add_argument('--weight_decay', type=float, default=0.05, + help='weight decay (default: 0.05)') + + parser.add_argument('--lr', type=float, default=None, metavar='LR', + help='learning rate (absolute lr)') + parser.add_argument('--blr', type=float, default=1e-4, metavar='LR', + help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') + parser.add_argument('--layer_decay', type=float, default=0.75, + help='layer-wise lr decay from ELECTRA/BEiT') + + parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', + help='lower lr bound for cyclic schedulers that hit 0') + + parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', + help='epochs to warmup LR') + + parser.add_argument('--output_dir', default='./output/', + help='path where to save, empty for no saving') + parser.add_argument('--log_dir', default='./output/', + help='path where to tensorboard log') + parser.add_argument('--device', default='cuda', + help='device to use for training / testing') + parser.add_argument('--seed', default=0, type=int) + parser.add_argument('--resume', default='', + help='resume from checkpoint') + parser.add_argument('--data-pth',default="../data",type=str) + + parser.add_argument('--start_epoch', default=0, type=int, metavar='N', + help='start epoch') + parser.add_argument('--eval', action='store_true', + help='Perform evaluation only') + parser.add_argument('--dist_eval', action='store_true', default=False, + help='Enabling distributed evaluation (recommended during training for faster monitor') + parser.add_argument('--num_workers', default=60, type=int) + parser.add_argument('--pin_mem', action='store_true', + help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') + parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') + parser.set_defaults(pin_mem=False) + + # distributed training parameters + parser.add_argument('--world_size', default=1, type=int, + help='number of distributed processes') + parser.add_argument('--local_rank', default=-1, type=int) + parser.add_argument('--dist_on_itp', action='store_true') + parser.add_argument('--dist_url', default='env://', + help='url used to set up distributed training') + + parser.add_argument('--configs',type=str) + parser.add_argument('--finetune', default=False, action="store_true") + parser.add_argument('--finetune-pth', type=str) + parser.add_argument('--category',type=str) + parser.add_argument('--replica',type=int,default=8) + + return parser + + +def 
main(args,config): + misc.init_distributed_mode(args) + + print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) + print("{}".format(args).replace(', ', ',\n')) + + device = torch.device(args.device) + + # fix the seed for reproducibility + seed = args.seed + misc.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + + cudnn.benchmark = True + + dataset_config=config.config['dataset'] + dataset_config['category']=args.category + dataset_config['replica']=args.replica + dataset_config['data_path']=args.data_pth + dataset_train = build_dataset('train',dataset_config) + dataset_val = build_dataset('val', dataset_config) + + if True: # args.distributed: + num_tasks = misc.get_world_size() + global_rank = misc.get_rank() + sampler_train = torch.utils.data.DistributedSampler( + dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + print("Sampler_train = %s" % str(sampler_train)) + if args.dist_eval: + if len(dataset_val) % num_tasks != 0: + print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' + 'This will slightly alter validation results as extra duplicate entries are added to achieve ' + 'equal num of samples per-process.') + sampler_val = torch.utils.data.DistributedSampler( + dataset_val, num_replicas=num_tasks, rank=global_rank, + shuffle=True) # shuffle=True to reduce monitor bias + else: + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + else: + sampler_train = torch.utils.data.RandomSampler(dataset_train) + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + + if global_rank == 0 and args.log_dir is not None and not args.eval: + os.makedirs(args.log_dir, exist_ok=True) + log_writer = SummaryWriter(log_dir=args.log_dir) + else: + log_writer = None + + if misc.get_rank() == 0: + log_dir = args.log_dir + src_folder = "/data1/haolin/TriplaneDiffusion" + misc.log_codefiles(src_folder, log_dir + "/code_bak") + config_dict = vars(args) + config_save_path = os.path.join(log_dir, "config.json") + with open(config_save_path, 'w') as f: + json.dump(config_dict, f, indent=4) + model_config_path=os.path.join(log_dir,"setup.yaml") + config.write_config(model_config_path) + + print("dataset len", dataset_train.__len__()) + data_loader_train = torch.utils.data.DataLoader( + dataset_train, sampler=sampler_train, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=True, + prefetch_factor=2, + ) + print("dataset len", dataset_train.__len__(), "dataloader len", len(data_loader_train)) + + data_loader_val = torch.utils.data.DataLoader( + dataset_val, sampler=sampler_val, + # batch_size=args.batch_size, + batch_size=1, + # num_workers=args.num_workers, + num_workers=1, + pin_memory=args.pin_mem, + drop_last=False + ) + + #model = models_ae.__dict__[args.model](N=args.point_cloud_size) + model_config=config.config['model'] + model = get_model(model_config) + if args.finetune: + print("finetune the model, load from %s"%(args.finetune_pth)) + model.load_state_dict(torch.load(args.finetune_pth)['model']) + model.to(device) + + model_without_ddp = model + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + + print("Model = %s" % str(model_without_ddp)) + print('number of params (M): %.2f' % (n_parameters / 1.e6)) + + eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() + + if args.lr is None: # only base_lr is specified + args.lr = args.blr * eff_batch_size / 256 + + print("base lr: %.2e" % 
(args.lr * 256 / eff_batch_size)) + print("actual lr: %.2e" % args.lr) + + print("accumulate grad iterations: %d" % args.accum_iter) + print("effective batch size: %d" % eff_batch_size) + + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) + model_without_ddp = model.module + + # # build optimizer with layer-wise lr decay (lrd) + # param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay, + # no_weight_decay_list=model_without_ddp.no_weight_decay(), + # layer_decay=args.layer_decay + # ) + optimizer = torch.optim.AdamW(model_without_ddp.parameters(), lr=args.lr) + loss_scaler = NativeScaler() + + criterion = torch.nn.BCEWithLogitsLoss() + + print("criterion = %s" % str(criterion)) + + misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) + + if args.eval: + test_stats = evaluate(data_loader_val, model, device) + print(f"iou of the network on the {len(dataset_val)} test images: {test_stats['iou']:.3f}") + exit(0) + + print(f"Start training for {args.epochs} epochs") + start_time = time.time() + max_iou = 0.0 + for epoch in range(args.start_epoch, args.epochs): + # if args.distributed: + # data_loader_train.sampler.set_epoch(epoch) + #test_stats = evaluate(data_loader_val, model, device) + train_stats = train_one_epoch( + model, criterion, data_loader_train, + optimizer, device, epoch, loss_scaler, + args.clip_grad, + log_writer=log_writer, + args=args + ) + # if args.output_dir and (epoch % 10 == 0 or epoch + 1 == args.epochs): + # misc.save_model( + # args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, + # loss_scaler=loss_scaler, epoch=epoch) + + if epoch % 5 == 0 or epoch + 1 == args.epochs: + test_stats = evaluate(data_loader_val, model, device) + print(f"iou of the network on the {len(dataset_val)} test images: {test_stats['iou']:.3f}") + if test_stats["iou"] > max_iou: + max_iou = test_stats["iou"] + misc.save_model( + args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, + loss_scaler=loss_scaler, epoch=epoch, prefix='best') + else: + misc.save_model( + args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, + loss_scaler=loss_scaler, epoch=epoch, prefix='latest') + # max_iou = max(max_iou, test_stats["iou"]) + print(f'Max iou: {max_iou:.2f}%') + + if log_writer is not None: + log_writer.add_scalar('perf/test_iou', test_stats['iou'], epoch) + log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch) + + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + **{f'test_{k}': v for k, v in test_stats.items()}, + 'epoch': epoch, + 'n_parameters': n_parameters} + else: + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + 'epoch': epoch, + 'n_parameters': n_parameters} + + if args.output_dir and misc.is_main_process(): + if log_writer is not None: + log_writer.flush() + with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: + f.write(json.dumps(log_stats) + "\n") + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + +if __name__ == '__main__': + args = get_args_parser() + args = args.parse_args() + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + config_path=args.configs + from configs.config_utils import CONFIG + config=CONFIG(config_path) + 
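+    # the parsed yaml (wrapped in CONFIG) and the CLI args are handed to main()
+    # to start VAE training.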
main(args,config) \ No newline at end of file diff --git a/train_VAE.sh b/train_VAE.sh new file mode 100644 index 0000000000000000000000000000000000000000..ab4390dd959c931924d7ab9add3415b1bbdbd621 --- /dev/null +++ b/train_VAE.sh @@ -0,0 +1,15 @@ +cd scripts +CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" torchrun --master_port 15000 --nproc_per_node=8 \ +train_triplane_vae.py \ +--configs ../configs/train_triplane_vae.yaml \ +--accum_iter 2 \ +--output_dir ../output/ae/chair \ +--log_dir ../output/ae/chair --num_workers 8 \ +--batch_size 22 \ +--epochs 200 \ +--warmup_epochs 5 \ +--dist_eval \ +--clip_grad 0.35 \ +--category chair \ +--data-pth ../data \ +--replica 5 \ No newline at end of file diff --git a/train_diffusion.sh b/train_diffusion.sh new file mode 100644 index 0000000000000000000000000000000000000000..2f8da321b46b90d98bfb71ba01f0948499911235 --- /dev/null +++ b/train_diffusion.sh @@ -0,0 +1,16 @@ +cd scripts +CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' torchrun --master_port 15004 --nproc_per_node=8 \ +train_triplane_diffusion.py \ +--configs ../configs/train_triplane_diffusion.yaml \ +--accum_iter 2 \ +--output_dir ../output/dm/debug \ +--log_dir ../output/dm/debug \ +--num_workers 8 \ +--batch_size 22 \ +--epochs 1000 \ +--dist_eval \ +--warmup_epochs 40 \ +--category chair \ +--ae-pth ../output/ae/chair/best-checkpoint.pth \ +--data-pth ../data \ +--replica 5 \ No newline at end of file diff --git a/util/lr_sched.py b/util/lr_sched.py new file mode 100644 index 0000000000000000000000000000000000000000..04697fc1a34d4a736cb9c86dcaaada150be7e78c --- /dev/null +++ b/util/lr_sched.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + +def adjust_learning_rate(optimizer, epoch, args): + """Decay the learning rate with half-cycle cosine after warmup""" + if epoch < args.warmup_epochs: + lr = args.lr * epoch / args.warmup_epochs + else: + lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ + (1. 
+ math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) + for param_group in optimizer.param_groups: + if "lr_scale" in param_group: + param_group["lr"] = lr * param_group["lr_scale"] + else: + param_group["lr"] = lr + return lr \ No newline at end of file diff --git a/util/misc.py b/util/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..c7c9c1334249434d8b6e07dadbb4debee75dfde6 --- /dev/null +++ b/util/misc.py @@ -0,0 +1,358 @@ +# -------------------------------------------------------- +# References: +# MAE: https://github.com/facebookresearch/mae +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import builtins +import datetime +import os +import time +from collections import defaultdict, deque +from pathlib import Path + +import torch +import torch.distributed as dist +from torch._six import inf +import numpy as np + +def log_codefiles(data_root,save_root): + import glob + import shutil + print("saving codes to",data_root,glob.glob(data_root+"/*.py")) + all_files=glob.glob(data_root+"/*.py")+glob.glob(data_root+"/*.sh")+glob.glob(data_root+"/*/*.py")+glob.glob(data_root+"/*/*.sh")+ \ + glob.glob(data_root + "/*/*/*.py") + glob.glob(data_root + "/*/*/*.sh") + for file in all_files: + rel_path=os.path.relpath(file,data_root) + dst_path=os.path.join(save_root,rel_path) + if os.path.exists(os.path.dirname(dst_path))==False: + os.makedirs(os.path.dirname(dst_path)) + shutil.copyfile(file,dst_path) + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
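+        Only count and total (and therefore global_avg) are all-reduced across
+        ranks; median and avg are still computed from the local window.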
+ """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + #print(name,meter) + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + log_msg = [ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ] + if torch.cuda.is_available(): + log_msg.append('max mem: {memory:.0f}') + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + #print("available threads",torch.get_num_threads()) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + #print(str(self)) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + force = force or (get_world_size() > 8) + if is_master or force: + now = 
datetime.datetime.now().time() + builtin_print('[{}] '.format(now), end='') # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + +def worker_init_fn(worker_id): + random_data = os.urandom(4) + base_seed = int.from_bytes(random_data, byteorder="big") + np.random.seed(base_seed + worker_id) + #torch.random.seed(base_seed+worker_id) + + +def init_distributed_mode(args): + if args.dist_on_itp: + args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) + os.environ['LOCAL_RANK'] = str(args.gpu) + os.environ['RANK'] = str(args.rank) + os.environ['WORLD_SIZE'] = str(args.world_size) + # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] + elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + setup_for_distributed(is_master=True) # hack + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}, gpu {}'.format( + args.rank, args.dist_url, args.gpu), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +class NativeScalerWithGradNormCount: + state_dict_key = "amp_scaler" + + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): + self._scaler.scale(loss).backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place + norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + else: + self._scaler.unscale_(optimizer) + norm = get_grad_norm_(parameters) + self._scaler.step(optimizer) + self._scaler.update() + else: + norm = None + return norm + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.) 
+ device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) + return total_norm + + +def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, prefix="latest"): + output_dir = Path(args.output_dir) + epoch_name = str(epoch) + if loss_scaler is not None: + checkpoint_paths = [output_dir / ('%s-checkpoint.pth' % (prefix))] + for checkpoint_path in checkpoint_paths: + to_save = { + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'epoch': epoch, + 'scaler': loss_scaler.state_dict(), + 'args': args, + } + + save_on_master(to_save, checkpoint_path) + else: + client_state = {'epoch': epoch} + model.save_checkpoint(save_dir=args.output_dir, tag='%s-checkpoint.pth' % (prefix), client_state=client_state) + + +def load_model(args, model_without_ddp, optimizer, loss_scaler): + if args.resume: + if args.resume.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.resume, map_location='cpu') + model_without_ddp.load_state_dict(checkpoint['model']) + print("Resume checkpoint %s" % args.resume) + if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): + optimizer.load_state_dict(checkpoint['optimizer']) + args.start_epoch = checkpoint['epoch'] + 1 + if 'scaler' in checkpoint: + loss_scaler.load_state_dict(checkpoint['scaler']) + print("With optim & sched!") + + +def all_reduce_mean(x): + world_size = get_world_size() + if world_size > 1: + x_reduce = torch.tensor(x).cuda() + dist.all_reduce(x_reduce) + x_reduce /= world_size + return x_reduce.item() + else: + return x \ No newline at end of file diff --git a/util/projection_utils.py b/util/projection_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dcc8721469a1ef5ef01cc793a875552ce13c36fc --- /dev/null +++ b/util/projection_utils.py @@ -0,0 +1,12 @@ +import numpy as np +import cv2 +def draw_proj_image(image,proj_mat,points): + points_homo=np.concatenate([points,np.ones((points.shape[0],1))],axis=1) + pts_inimg=np.dot(points_homo,proj_mat.T) + image=cv2.resize(image,dsize=(224,224),interpolation=cv2.INTER_LINEAR) + x=pts_inimg[:,0]/pts_inimg[:,2] + y=pts_inimg[:,1]/pts_inimg[:,2] + x=np.clip(x,a_min=0,a_max=223).astype(np.int32) + y=np.clip(y,a_min=0,a_max=223).astype(np.int32) + image[y,x]=np.array([[0,255,0]]) + return image \ No newline at end of file diff --git a/util/simple_image_loader.py b/util/simple_image_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..92baa973b873e0a59b6435a9632da1e820211baa --- /dev/null +++ b/util/simple_image_loader.py @@ -0,0 +1,80 @@ +import torch +import torch.nn as nn +from torch.utils import data +import os +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +try: + from torchvision.transforms import InterpolationMode + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC +import glob + +def image_transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + ToTensor(), + # Normalize((123.675/255.0,116.28/255.0,103.53/255.0), + # (58.395/255.0,57.12/255.0,57.375/255.0)) + Normalize((0.48145466, 0.4578275, 0.40821073), + (0.26862954, 
0.26130258, 0.27577711)), + # Normalize((0.5, 0.5, 0.5), + # (0.5, 0.5, 0.5)), + ]) + +class Image_dataset(data.Dataset): + def __init__(self,dataset_folder="/data1/haolin/datasets",categories=['03001627'],n_px=224): + self.dataset_folder=dataset_folder + self.image_folder=os.path.join(self.dataset_folder,'other_data') + self.preprocess=image_transform(n_px) + self.image_path=[] + for cat in categories: + subpath=os.path.join(self.image_folder,cat,"6_images") + model_list=os.listdir(subpath) + for folder in model_list: + model_folder=os.path.join(subpath,folder) + image_list=os.listdir(model_folder) + for image_filename in image_list: + image_filepath=os.path.join(model_folder,image_filename) + self.image_path.append(image_filepath) + def __len__(self): + return len(self.image_path) + + def __getitem__(self,index): + path=self.image_path[index] + basename=os.path.basename(path)[:-4] + model_id=path.split(os.sep)[-2] + category=path.split(os.sep)[-4] + image=Image.open(path) + image_tensor=self.preprocess(image) + + return {"images":image_tensor,"image_name":basename,"model_id":model_id,"category":category} + +class Image_InTheWild_dataset(data.Dataset): + def __init__(self,dataset_dir="/data1/haolin/data/real_scene_process_data",scene_id="letian-310",n_px=224): + self.dataset_dir=dataset_dir + self.preprocess = image_transform(n_px) + self.image_path = [] + if scene_id=="all": + scene_list=os.listdir(self.dataset_dir) + for id in scene_list: + image_folder=os.path.join(self.dataset_dir,id,"6_images") + self.image_path+=glob.glob(image_folder+"/*/*jpg") + else: + image_folder = os.path.join(self.dataset_dir, scene_id, "6_images") + self.image_path += glob.glob(image_folder + "/*/*jpg") + def __len__(self): + return len(self.image_path) + + def __getitem__(self,index): + path=self.image_path[index] + basename=os.path.basename(path)[:-4] + model_id=path.split(os.sep)[-2] + scene_id=path.split(os.sep)[-4] + image=Image.open(path) + image_tensor=self.preprocess(image) + + return {"images":image_tensor,"image_name":basename,"model_id":model_id,"scene_id":scene_id} + diff --git a/util/train_test_utils.py b/util/train_test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
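Note (illustrative, not part of the patch): the warmup-plus-cosine schedule in util/lr_sched.py above reads only `args.lr`, `args.min_lr`, `args.warmup_epochs`, and `args.epochs`. The sketch below shows one way it could be driven; the `SimpleNamespace` values and the toy model/optimizer are placeholder assumptions, not settings taken from this repository's configs or training scripts.
```python
# Hypothetical usage of adjust_learning_rate from util/lr_sched.py:
# linear warmup for warmup_epochs, then half-cycle cosine decay towards min_lr.
from types import SimpleNamespace

import torch

from util.lr_sched import adjust_learning_rate

# Placeholder hyperparameters; the real values come from the training scripts/configs.
args = SimpleNamespace(lr=1e-4, min_lr=1e-6, warmup_epochs=5, epochs=200)

model = torch.nn.Linear(8, 8)  # stand-in for the actual VAE / diffusion model
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

for epoch in range(args.epochs):
    # Updates every param group in-place and returns the new learning rate.
    lr = adjust_learning_rate(optimizer, epoch, args)
    # ... run one training epoch here ...
```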