diff --git "a/3DTopia/model/triplane_vae.py" "b/3DTopia/model/triplane_vae.py" new file mode 100644--- /dev/null +++ "b/3DTopia/model/triplane_vae.py" @@ -0,0 +1,2656 @@ +import os +import imageio +import torch +import wandb +import numpy as np +import pytorch_lightning as pl +import torch.nn.functional as F + +from module.model_2d import Encoder, Decoder, DiagonalGaussianDistribution, Encoder_GroupConv, Decoder_GroupConv, Encoder_GroupConv_LateFusion, Decoder_GroupConv_LateFusion +from utility.initialize import instantiate_from_config +from utility.triplane_renderer.renderer import get_embedder, NeRF, run_network, render_path1, to8b, img2mse, mse2psnr +from utility.triplane_renderer.eg3d_renderer import Renderer_TriPlane + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + learning_rate=1e-3, + ckpt_path=None, + ignore_keys=[], + colorize_nlabels=None, + monitor=None, + decoder_ckpt=None, + norm=False, + renderer_type='nerf', + renderer_config=dict( + rgbnet_dim=18, + rgbnet_width=128, + viewpe=0, + feape=0 + ), + ): + super().__init__() + self.save_hyperparameters() + self.norm = norm + self.renderer_config = renderer_config + self.learning_rate = learning_rate + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + # self.loss = instantiate_from_config(lossconfig) + self.lossconfig = lossconfig + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + self.decoder_ckpt = decoder_ckpt + self.renderer_type = renderer_type + # if decoder_ckpt is not None: + assert self.renderer_type in ['nerf', 'eg3d'] + if self.renderer_type == 'nerf': + self.triplane_decoder, self.triplane_render_kwargs = self.create_nerf(decoder_ckpt) + elif self.renderer_type == 'eg3d': + self.triplane_decoder, self.triplane_render_kwargs = self.create_eg3d_decoder(decoder_ckpt) + else: + raise NotImplementedError + + self.psum = torch.zeros([1]) + self.psum_sq = torch.zeros([1]) + self.psum_min = torch.zeros([1]) + self.psum_max = torch.zeros([1]) + self.count = 0 + self.len_dset = 0 + self.latent_list = [] + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x, rollout=False): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, unrollout=False): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def unrollout(self, *args, **kwargs): + pass + + def loss(self, inputs, reconstructions, posteriors, prefix, batch=None): + reconstructions = reconstructions.contiguous() + rec_loss = torch.abs(inputs.contiguous() - reconstructions) + rec_loss = torch.sum(rec_loss) / rec_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + loss = self.lossconfig.rec_weight * rec_loss + self.lossconfig.kl_weight * kl_loss + + ret_dict = { + prefix+'mean_rec_loss': torch.abs(inputs.contiguous() - reconstructions.contiguous()).mean().detach(), + prefix+'rec_loss': rec_loss, + prefix+'kl_loss': kl_loss, + prefix+'loss': loss, + prefix+'mean': posteriors.mean.mean(), + prefix+'logvar': posteriors.logvar.mean(), + } + + render_weight = self.lossconfig.get("render_weight", 0) + tv_weight = self.lossconfig.get("tv_weight", 0) + l1_weight = self.lossconfig.get("l1_weight", 0) + latent_tv_weight = self.lossconfig.get("latent_tv_weight", 0) + latent_l1_weight = self.lossconfig.get("latent_l1_weight", 0) + + triplane_rec = self.unrollout(reconstructions) + if render_weight > 0 and batch is not None: + rgb_rendered, target = self.render_triplane_eg3d_decoder_sample_pixel(triplane_rec, batch['batch_rays'], batch['img']) + render_loss = ((rgb_rendered - target) ** 2).sum() / rgb_rendered.shape[0] * 256 + loss += render_weight * render_loss + ret_dict[prefix + 'render_loss'] = render_loss + if tv_weight > 0: + tvloss_y = torch.abs(triplane_rec[:, :, :-1] - triplane_rec[:, :, 1:]).sum() / triplane_rec.shape[0] + tvloss_x = torch.abs(triplane_rec[:, :, :, :-1] - triplane_rec[:, :, :, 1:]).sum() / triplane_rec.shape[0] + tvloss = tvloss_y + tvloss_x + loss += tv_weight * tvloss + ret_dict[prefix + 'tv_loss'] = tvloss + if l1_weight > 0: + l1 = (triplane_rec ** 2).sum() / triplane_rec.shape[0] + loss += l1_weight * l1 + ret_dict[prefix + 'l1_loss'] = l1 + if latent_tv_weight > 0: + latent = posteriors.mean + latent_tv_y = torch.abs(latent[:, :, :-1] - latent[:, :, 1:]).sum() / latent.shape[0] + latent_tv_x = torch.abs(latent[:, :, :, :-1] - latent[:, :, :, 1:]).sum() / latent.shape[0] + latent_tv_loss = latent_tv_y + latent_tv_x + loss += latent_tv_loss * latent_tv_weight + ret_dict[prefix + 'latent_tv_loss'] = latent_tv_loss + ret_dict[prefix + 'latent_max'] = latent.max() + ret_dict[prefix + 'latent_min'] = latent.min() + if latent_l1_weight > 0: + latent = posteriors.mean + latent_l1_loss = (latent ** 2).sum() / latent.shape[0] + loss += latent_l1_loss * latent_l1_weight + ret_dict[prefix + 'latent_l1_loss'] = latent_l1_loss + + return loss, ret_dict + + def training_step(self, batch, batch_idx): + # inputs = self.get_input(batch, self.image_key) + inputs = batch['triplane'] + reconstructions, posterior = self(inputs) + + # if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/') + # self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + # if optimizer_idx == 1: + # # train the discriminator + # discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + # last_layer=self.get_last_layer(), split="train") + + # self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + # self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + # return discloss + + def validation_step(self, batch, batch_idx): + # # inputs = self.get_input(batch, self.image_key) + # inputs = batch['triplane'] + # reconstructions, posterior = self(inputs, sample_posterior=False) + # aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/') + + # # discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + # # last_layer=self.get_last_layer(), split="val") + + # # self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + # self.log_dict(log_dict_ae) + # # self.log_dict(log_dict_disc) + # return self.log_dict + + inputs = batch['triplane'] + reconstructions, posterior = self(inputs, sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/') + self.log_dict(log_dict_ae) + + assert not self.norm + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + batch_size = inputs.shape[0] + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if b % 4 == 0 and batch_idx < 10: + rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1) + self.logger.experiment.log({ + "val/vis": [wandb.Image(rgb_all)] + }) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + return self.log_dict + + def create_eg3d_decoder(self, decoder_ckpt): + triplane_decoder = Renderer_TriPlane(**self.renderer_config) + if decoder_ckpt is not None: + pretrain_pth = torch.load(decoder_ckpt, map_location='cpu') + pretrain_pth = { + '.'.join(k.split('.')[1:]): v for k, v in pretrain_pth.items() + } + triplane_decoder.load_state_dict(pretrain_pth) + render_kwargs = { + 'depth_resolution': 128, + 'disparity_space_sampling': False, + 'box_warp': 2.4, + 'depth_resolution_importance': 128, + 'clamp_mode': 'softplus', + 'white_back': True, + 'det': True + } + return triplane_decoder, render_kwargs + + def render_triplane_eg3d_decoder(self, triplane, batch_rays, target): + ray_o = batch_rays[:, 0] + ray_d = batch_rays[:, 1] + psnr_list = [] + rec_img_list = [] + res = triplane.shape[-2] + for i in range(ray_o.shape[0]): + with torch.no_grad(): + render_out = self.triplane_decoder(triplane.reshape(1, 3, -1, res, res), + ray_o[i:i+1], ray_d[i:i+1], self.triplane_render_kwargs, whole_img=True, tvloss=False) + rec_img = render_out['rgb_marched'].permute(0, 2, 3, 1) + psnr = mse2psnr(img2mse(rec_img[0], target[i])) + psnr_list.append(psnr) + rec_img_list.append(rec_img) + return torch.cat(rec_img_list, 0), psnr_list + + def render_triplane_eg3d_decoder_sample_pixel(self, triplane, batch_rays, target, sample_num=1024): + assert batch_rays.shape[1] == 1 + sel = torch.randint(batch_rays.shape[-2], [sample_num]) + ray_o = batch_rays[:, 0, 0, sel] + ray_d = batch_rays[:, 0, 1, sel] + res = triplane.shape[-2] + render_out = self.triplane_decoder(triplane.reshape(triplane.shape[0], 3, -1, res, res), + ray_o, ray_d, self.triplane_render_kwargs, whole_img=False, tvloss=False) + rec_img = render_out['rgb_marched'] + target = target.reshape(triplane.shape[0], -1, 3)[:, sel, :] + return rec_img, target + + def create_nerf(self, decoder_ckpt): + # decoder_ckpt = '/mnt/petrelfs/share_data/caoziang/shapenet_triplane_car/003000.tar' + + multires = 10 + netchunk = 1024*64 + i_embed = 0 + perturb = 0 + raw_noise_std = 0 + + triplanechannel=18 + triplanesize=256 + chunk=4096 + num_instance=1 + batch_size=1 + use_viewdirs = True + white_bkgd = False + lrate_decay = 6 + netdepth=1 + netwidth=64 + N_samples = 512 + N_importance = 0 + N_rand = 8192 + multires_views=10 + precrop_iters = 0 + precrop_frac = 0.5 + i_weights=3000 + + embed_fn, input_ch = get_embedder(multires, i_embed) + embeddirs_fn, input_ch_views = get_embedder(multires_views, i_embed) + output_ch = 4 + skips = [4] + model = NeRF(D=netdepth, W=netwidth, + input_ch=triplanechannel, size=triplanesize,output_ch=output_ch, skips=skips, + input_ch_views=input_ch_views, use_viewdirs=use_viewdirs, num_instance=num_instance) + + network_query_fn = lambda inputs, viewdirs, label,network_fn : \ + run_network(inputs, viewdirs, network_fn, + embed_fn=embed_fn, + embeddirs_fn=embeddirs_fn,label=label, + netchunk=netchunk) + + ckpt = torch.load(decoder_ckpt) + model.load_state_dict(ckpt['network_fn_state_dict']) + + render_kwargs_test = { + 'network_query_fn' : network_query_fn, + 'perturb' : perturb, + 'N_samples' : N_samples, + # 'network_fn' : model, + 'use_viewdirs' : use_viewdirs, + 'white_bkgd' : white_bkgd, + 'raw_noise_std' : raw_noise_std, + } + render_kwargs_test['ndc'] = False + render_kwargs_test['lindisp'] = False + render_kwargs_test['perturb'] = False + render_kwargs_test['raw_noise_std'] = 0. + + return model, render_kwargs_test + + def render_triplane(self, triplane, batch_rays, target, near, far, chunk=4096): + self.triplane_decoder.tri_planes.copy_(triplane.detach()) + self.triplane_render_kwargs['network_fn'] = self.triplane_decoder + # print(triplane.device) + # print(batch_rays.device) + # print(target.device) + # print(near.device) + # print(far.device) + with torch.no_grad(): + rgb, _, _, psnr_list = \ + render_path1(batch_rays, chunk, self.triplane_render_kwargs, gt_imgs=target, + near=near, far=far, label=torch.Tensor([0]).long().to(triplane.device)) + return rgb, psnr_list + + def to_rgb(self, plane): + x = plane.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_triplane(self, plane): + x = plane.float() + if not hasattr(self, "colorize_triplane"): + self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def test_step(self, batch, batch_idx): + # inputs = batch['triplane'] + # reconstructions, posterior = self(inputs, sample_posterior=False) + # aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/') + # self.log_dict(log_dict_ae) + + # batch_size = inputs.shape[0] + # psnr_list = [] # between rec and gt + # psnr_input_list = [] # between input and gt + # psnr_rec_list = [] # between input and rec + + # mean = torch.Tensor([ + # 0.2820, 0.4103, -0.2988, 0.1491, 0.4429, -0.3117, 0.2830, 0.4115, + # -0.3032, 0.1530, 0.4466, -0.3165, 0.2617, 0.3837, -0.2692, 0.1098, + # 0.4101, -0.2922 + # ]).reshape(1, 18, 1, 1).to(inputs.device) + # std = torch.Tensor([ + # 1.1696, 1.1287, 1.1733, 1.1583, 1.1238, 1.1675, 1.1978, 1.1585, 1.1949, + # 1.1660, 1.1576, 1.1998, 1.1987, 1.1546, 1.1930, 1.1724, 1.1450, 1.2027 + # ]).reshape(1, 18, 1, 1).to(inputs.device) + + # if self.norm: + # reconstructions_unnormalize = reconstructions * std + mean + # else: + # reconstructions_unnormalize = reconstructions + + # for b in range(batch_size): + # # rgb_input, cur_psnr_list_input = self.render_triplane( + # # batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + # # batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + # # ) + # # rgb, cur_psnr_list = self.render_triplane( + # # reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + # # batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + # # ) + + # if self.renderer_type == 'nerf': + # rgb_input, cur_psnr_list_input = self.render_triplane( + # batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + # batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + # ) + # rgb, cur_psnr_list = self.render_triplane( + # reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + # batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + # ) + # elif self.renderer_type == 'eg3d': + # rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + # batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + # ) + # rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + # reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], + # ) + # else: + # raise NotImplementedError + + # cur_psnr_list_rec = [] + # for i in range(rgb.shape[0]): + # cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + # rgb_input = to8b(rgb_input.detach().cpu().numpy()) + # rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + # rgb = to8b(rgb.detach().cpu().numpy()) + + # if batch_idx < 1: + # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) + # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) + # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) + + # psnr_list += cur_psnr_list + # psnr_input_list += cur_psnr_list_input + # psnr_rec_list += cur_psnr_list_rec + + # self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + # self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + # self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + inputs = batch['triplane'] + reconstructions, posterior = self(inputs, sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None) + self.log_dict(log_dict_ae) + + batch_size = inputs.shape[0] + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + + z = posterior.mode() + colorize_z = self.to_rgb(z)[0] + colorize_triplane_input = self.to_rgb_triplane(inputs)[0] + colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] + # colorize_triplane_rollout_3daware = self.to_rgb_3daware(self.to3daware(inputs))[0] + # res = inputs.shape[1] + # colorize_triplane_rollout_3daware_1 = self.to_rgb_triplane(self.to3daware(inputs)[:,res:2*res])[0] + # colorize_triplane_rollout_3daware_2 = self.to_rgb_triplane(self.to3daware(inputs)[:,2*res:3*res])[0] + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}.png".format(batch_idx)), colorize_triplane_rollout_3daware) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_1.png".format(batch_idx)), colorize_triplane_rollout_3daware_1) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_2.png".format(batch_idx)), colorize_triplane_rollout_3daware_2) + + np_z = z.detach().cpu().numpy() + # with open(os.path.join(self.logger.log_dir, "latent_{}.npz".format(batch_idx)), 'wb') as f: + # np.save(f, np_z) + + self.latent_list.append(np_z) + + if self.psum.device != z.device: + self.psum = self.psum.to(z.device) + self.psum_sq = self.psum_sq.to(z.device) + self.psum_min = self.psum_min.to(z.device) + self.psum_max = self.psum_max.to(z.device) + self.psum += z.sum() + self.psum_sq += (z ** 2).sum() + self.psum_min += z.reshape(-1).min(-1)[0] + self.psum_max += z.reshape(-1).max(-1)[0] + assert len(z.shape) == 4 + self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] + self.len_dset += 1 + + if self.norm: + assert NotImplementedError + else: + reconstructions_unnormalize = reconstructions + + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + # opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + # lr=lr, betas=(0.5, 0.9)) + # return [opt_ae, opt_disc], [] + return opt_ae + + def on_test_epoch_end(self): + mean = self.psum / self.count + mean_min = self.psum_min / self.len_dset + mean_max = self.psum_max / self.len_dset + var = (self.psum_sq / self.count) - (mean ** 2) + std = torch.sqrt(var) + + print("mean min: {}".format(mean_min)) + print("mean max: {}".format(mean_max)) + print("mean: {}".format(mean)) + print("std: {}".format(std)) + + latent = np.concatenate(self.latent_list) + q75, q25 = np.percentile(latent.reshape(-1), [75 ,25]) + median = np.median(latent.reshape(-1)) + iqr = q75 - q25 + norm_iqr = iqr * 0.7413 + print("Norm IQR: {}".format(norm_iqr)) + print("Inverse Norm IQR: {}".format(1/norm_iqr)) + print("Median: {}".format(median)) + + +class AutoencoderKLRollOut(AutoencoderKL): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.psum = torch.zeros([1]) + self.psum_sq = torch.zeros([1]) + self.psum_min = torch.zeros([1]) + self.psum_max = torch.zeros([1]) + self.count = 0 + self.len_dset = 0 + + def rollout(self, triplane): + res = triplane.shape[-1] + ch = triplane.shape[1] + triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) + return triplane + + def unrollout(self, triplane): + res = triplane.shape[-2] + ch = 3 * triplane.shape[1] + triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) + return triplane + + def encode(self, x, rollout=False): + if rollout: + x = self.rollout(x) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, unrollout=False): + z = self.post_quant_conv(z) + dec = self.decoder(z) + if unrollout: + dec = self.unrollout(dec) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def training_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/') + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + def validation_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs, sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/') + self.log_dict(log_dict_ae) + + assert not self.norm + reconstructions = self.unrollout(reconstructions) + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + batch_size = inputs.shape[0] + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if b % 4 == 0 and batch_idx < 10: + rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1) + self.logger.experiment.log({ + "val/vis": [wandb.Image(rgb_all)] + }) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + return self.log_dict + + def to_rgb(self, plane): + x = plane.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_triplane(self, plane): + x = plane.float() + if not hasattr(self, "colorize_triplane"): + self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def test_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs, sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/') + self.log_dict(log_dict_ae) + + batch_size = inputs.shape[0] + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + + z = posterior.mode() + colorize_z = self.to_rgb(z)[0] + colorize_triplane_input = self.to_rgb_triplane(inputs)[0] + colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] + # if batch_idx < 1: + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output) + + reconstructions = self.unrollout(reconstructions) + + if self.psum.device != z.device: + self.psum = self.psum.to(z.device) + self.psum_sq = self.psum_sq.to(z.device) + self.psum_min = self.psum_min.to(z.device) + self.psum_max = self.psum_max.to(z.device) + self.psum += z.sum() + self.psum_sq += (z ** 2).sum() + self.psum_min += z.reshape(-1).min(-1)[0] + self.psum_max += z.reshape(-1).max(-1)[0] + assert len(z.shape) == 4 + self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] + self.len_dset += 1 + + # mean = torch.Tensor([ + # 0.2820, 0.4103, -0.2988, 0.1491, 0.4429, -0.3117, 0.2830, 0.4115, + # -0.3032, 0.1530, 0.4466, -0.3165, 0.2617, 0.3837, -0.2692, 0.1098, + # 0.4101, -0.2922 + # ]).reshape(1, 18, 1, 1).to(inputs.device) + # std = torch.Tensor([ + # 1.1696, 1.1287, 1.1733, 1.1583, 1.1238, 1.1675, 1.1978, 1.1585, 1.1949, + # 1.1660, 1.1576, 1.1998, 1.1987, 1.1546, 1.1930, 1.1724, 1.1450, 1.2027 + # ]).reshape(1, 18, 1, 1).to(inputs.device) + + mean = torch.Tensor([ + -1.8449, -1.8242, 0.9667, -1.0187, 1.0647, -0.5422, -1.8632, -1.8435, + 0.9314, -1.0261, 1.0356, -0.5484, -1.8543, -1.8348, 0.9109, -1.0169, + 1.0160, -0.5467 + ]).reshape(1, 18, 1, 1).to(inputs.device) + std = torch.Tensor([ + 1.7593, 1.6127, 2.7132, 1.5500, 2.7893, 0.7707, 2.1114, 1.9198, 2.6586, + 1.8021, 2.5473, 1.0305, 1.7042, 1.7507, 2.4270, 1.4365, 2.2511, 0.8792 + ]).reshape(1, 18, 1, 1).to(inputs.device) + + if self.norm: + reconstructions_unnormalize = reconstructions * std + mean + else: + reconstructions_unnormalize = reconstructions + + # for b in range(batch_size): + # if self.renderer_type == 'nerf': + # rgb_input, cur_psnr_list_input = self.render_triplane( + # batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + # batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + # ) + # rgb, cur_psnr_list = self.render_triplane( + # reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + # batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + # ) + # elif self.renderer_type == 'eg3d': + # rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + # batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + # ) + # rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + # reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], + # ) + # else: + # raise NotImplementedError + + # cur_psnr_list_rec = [] + # for i in range(rgb.shape[0]): + # cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + # rgb_input = to8b(rgb_input.detach().cpu().numpy()) + # rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + # rgb = to8b(rgb.detach().cpu().numpy()) + + # # if batch_idx < 1: + # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) + # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) + # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) + + # psnr_list += cur_psnr_list + # psnr_input_list += cur_psnr_list_input + # psnr_rec_list += cur_psnr_list_rec + + # self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + # self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + # self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + def on_test_epoch_end(self): + mean = self.psum / self.count + mean_min = self.psum_min / self.len_dset + mean_max = self.psum_max / self.len_dset + var = (self.psum_sq / self.count) - (mean ** 2) + std = torch.sqrt(var) + + print("mean min: {}".format(mean_min)) + print("mean max: {}".format(mean_max)) + print("mean: {}".format(mean)) + print("std: {}".format(std)) + + +class AutoencoderKLRollOut3DAware(AutoencoderKL): + def __init__(self, *args, **kwargs): + try: + ckpt_path = kwargs['ckpt_path'] + kwargs['ckpt_path'] = None + except: + ckpt_path = None + + super().__init__(*args, **kwargs) + self.psum = torch.zeros([1]) + self.psum_sq = torch.zeros([1]) + self.psum_min = torch.zeros([1]) + self.psum_max = torch.zeros([1]) + self.count = 0 + self.len_dset = 0 + + ddconfig = kwargs['ddconfig'] + ddconfig['z_channels'] *= 3 + del self.decoder + del self.post_quant_conv + self.decoder = Decoder(**ddconfig) + self.post_quant_conv = torch.nn.Conv2d(kwargs['embed_dim'] * 3, ddconfig["z_channels"], 1) + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path) + + def rollout(self, triplane): + res = triplane.shape[-1] + ch = triplane.shape[1] + triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) + return triplane + + def to3daware(self, triplane): + res = triplane.shape[-2] + plane1 = triplane[..., :res] + plane2 = triplane[..., res:2*res] + plane3 = triplane[..., 2*res:3*res] + + x_mp = torch.nn.MaxPool2d((res, 1)) + y_mp = torch.nn.MaxPool2d((1, res)) + x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2) + y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2) + # for plane1 + plane21 = x_mp_rep(plane2) + plane31 = torch.flip(y_mp_rep(plane3), (3,)) + new_plane1 = torch.cat([plane1, plane21, plane31], 1) + # for plane2 + plane12 = y_mp_rep(plane1) + plane32 = x_mp_rep(plane3) + new_plane2 = torch.cat([plane2, plane12, plane32], 1) + # for plane3 + plane13 = torch.flip(x_mp_rep(plane1), (2,)) + plane23 = y_mp_rep(plane2) + new_plane3 = torch.cat([plane3, plane13, plane23], 1) + + new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous() + return new_plane + + def unrollout(self, triplane): + res = triplane.shape[-2] + ch = 3 * triplane.shape[1] + triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) + return triplane + + def encode(self, x, rollout=False): + if rollout: + x = self.to3daware(self.rollout(x)) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, unrollout=False): + z = self.to3daware(z) + z = self.post_quant_conv(z) + dec = self.decoder(z) + if unrollout: + dec = self.unrollout(dec) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def training_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(self.to3daware(inputs)) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/', batch=batch) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + def validation_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/', batch=None) + self.log_dict(log_dict_ae) + + assert not self.norm + reconstructions = self.unrollout(reconstructions) + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + batch_size = inputs.shape[0] + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if b % 4 == 0 and batch_idx < 10: + rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1) + self.logger.experiment.log({ + "val/vis": [wandb.Image(rgb_all)] + }) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + return self.log_dict + + def to_rgb(self, plane): + x = plane.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_triplane(self, plane): + x = plane.float() + if not hasattr(self, "colorize_triplane"): + self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_3daware(self, plane): + x = plane.float() + if not hasattr(self, "colorize_3daware"): + self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def test_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None) + self.log_dict(log_dict_ae) + + batch_size = inputs.shape[0] + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + + z = posterior.mode() + colorize_z = self.to_rgb(z)[0] + colorize_triplane_input = self.to_rgb_triplane(inputs)[0] + colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] + colorize_triplane_rollout_3daware = self.to_rgb_3daware(self.to3daware(inputs))[0] + res = inputs.shape[1] + colorize_triplane_rollout_3daware_1 = self.to_rgb_triplane(self.to3daware(inputs)[:,res:2*res])[0] + colorize_triplane_rollout_3daware_2 = self.to_rgb_triplane(self.to3daware(inputs)[:,2*res:3*res])[0] + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}.png".format(batch_idx)), colorize_triplane_rollout_3daware) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_1.png".format(batch_idx)), colorize_triplane_rollout_3daware_1) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_2.png".format(batch_idx)), colorize_triplane_rollout_3daware_2) + + reconstructions = self.unrollout(reconstructions) + + if self.psum.device != z.device: + self.psum = self.psum.to(z.device) + self.psum_sq = self.psum_sq.to(z.device) + self.psum_min = self.psum_min.to(z.device) + self.psum_max = self.psum_max.to(z.device) + self.psum += z.sum() + self.psum_sq += (z ** 2).sum() + self.psum_min += z.reshape(-1).min(-1)[0] + self.psum_max += z.reshape(-1).max(-1)[0] + assert len(z.shape) == 4 + self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] + self.len_dset += 1 + + if self.norm: + assert NotImplementedError + else: + reconstructions_unnormalize = reconstructions + + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + def on_test_epoch_end(self): + mean = self.psum / self.count + mean_min = self.psum_min / self.len_dset + mean_max = self.psum_max / self.len_dset + var = (self.psum_sq / self.count) - (mean ** 2) + std = torch.sqrt(var) + + print("mean min: {}".format(mean_min)) + print("mean max: {}".format(mean_max)) + print("mean: {}".format(mean)) + print("std: {}".format(std)) + + +class AutoencoderKLRollOut3DAwareOnlyInput(AutoencoderKL): + def __init__(self, *args, **kwargs): + try: + ckpt_path = kwargs['ckpt_path'] + kwargs['ckpt_path'] = None + except: + ckpt_path = None + + super().__init__(*args, **kwargs) + self.psum = torch.zeros([1]) + self.psum_sq = torch.zeros([1]) + self.psum_min = torch.zeros([1]) + self.psum_max = torch.zeros([1]) + self.count = 0 + self.len_dset = 0 + + # ddconfig = kwargs['ddconfig'] + # ddconfig['z_channels'] *= 3 + # self.decoder = Decoder(**ddconfig) + # self.post_quant_conv = torch.nn.Conv2d(kwargs['embed_dim'] * 3, ddconfig["z_channels"], 1) + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path) + + def rollout(self, triplane): + res = triplane.shape[-1] + ch = triplane.shape[1] + triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) + return triplane + + def to3daware(self, triplane): + res = triplane.shape[-2] + plane1 = triplane[..., :res] + plane2 = triplane[..., res:2*res] + plane3 = triplane[..., 2*res:3*res] + + x_mp = torch.nn.MaxPool2d((res, 1)) + y_mp = torch.nn.MaxPool2d((1, res)) + x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2) + y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2) + # for plane1 + plane21 = x_mp_rep(plane2) + plane31 = torch.flip(y_mp_rep(plane3), (3,)) + new_plane1 = torch.cat([plane1, plane21, plane31], 1) + # for plane2 + plane12 = y_mp_rep(plane1) + plane32 = x_mp_rep(plane3) + new_plane2 = torch.cat([plane2, plane12, plane32], 1) + # for plane3 + plane13 = torch.flip(x_mp_rep(plane1), (2,)) + plane23 = y_mp_rep(plane2) + new_plane3 = torch.cat([plane3, plane13, plane23], 1) + + new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous() + return new_plane + + def unrollout(self, triplane): + res = triplane.shape[-2] + ch = 3 * triplane.shape[1] + triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) + return triplane + + def encode(self, x, rollout=False): + if rollout: + x = self.to3daware(self.rollout(x)) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, unrollout=False): + # z = self.to3daware(z) + z = self.post_quant_conv(z) + dec = self.decoder(z) + if unrollout: + dec = self.unrollout(dec) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def training_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(self.to3daware(inputs)) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/') + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + def validation_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/') + self.log_dict(log_dict_ae) + + assert not self.norm + reconstructions = self.unrollout(reconstructions) + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + batch_size = inputs.shape[0] + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if b % 4 == 0 and batch_idx < 10: + rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1) + self.logger.experiment.log({ + "val/vis": [wandb.Image(rgb_all)] + }) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + return self.log_dict + + def to_rgb(self, plane): + x = plane.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_triplane(self, plane): + x = plane.float() + if not hasattr(self, "colorize_triplane"): + self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def test_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/') + self.log_dict(log_dict_ae) + + batch_size = inputs.shape[0] + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + + z = posterior.mode() + colorize_z = self.to_rgb(z)[0] + colorize_triplane_input = self.to_rgb_triplane(inputs)[0] + colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output) + + reconstructions = self.unrollout(reconstructions) + + if self.psum.device != z.device: + self.psum = self.psum.to(z.device) + self.psum_sq = self.psum_sq.to(z.device) + self.psum_min = self.psum_min.to(z.device) + self.psum_max = self.psum_max.to(z.device) + self.psum += z.sum() + self.psum_sq += (z ** 2).sum() + self.psum_min += z.reshape(-1).min(-1)[0] + self.psum_max += z.reshape(-1).max(-1)[0] + assert len(z.shape) == 4 + self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] + self.len_dset += 1 + + if self.norm: + assert NotImplementedError + else: + reconstructions_unnormalize = reconstructions + + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + def on_test_epoch_end(self): + mean = self.psum / self.count + mean_min = self.psum_min / self.len_dset + mean_max = self.psum_max / self.len_dset + var = (self.psum_sq / self.count) - (mean ** 2) + std = torch.sqrt(var) + + print("mean min: {}".format(mean_min)) + print("mean max: {}".format(mean_max)) + print("mean: {}".format(mean)) + print("std: {}".format(std)) + + +class AutoencoderKLRollOut3DAwareMeanPool(AutoencoderKL): + def __init__(self, *args, **kwargs): + try: + ckpt_path = kwargs['ckpt_path'] + kwargs['ckpt_path'] = None + except: + ckpt_path = None + + super().__init__(*args, **kwargs) + self.psum = torch.zeros([1]) + self.psum_sq = torch.zeros([1]) + self.psum_min = torch.zeros([1]) + self.psum_max = torch.zeros([1]) + self.count = 0 + self.len_dset = 0 + + ddconfig = kwargs['ddconfig'] + ddconfig['z_channels'] *= 3 + self.decoder = Decoder(**ddconfig) + self.post_quant_conv = torch.nn.Conv2d(kwargs['embed_dim'] * 3, ddconfig["z_channels"], 1) + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path) + + def rollout(self, triplane): + res = triplane.shape[-1] + ch = triplane.shape[1] + triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) + return triplane + + def to3daware(self, triplane): + res = triplane.shape[-2] + plane1 = triplane[..., :res] + plane2 = triplane[..., res:2*res] + plane3 = triplane[..., 2*res:3*res] + + x_mp = torch.nn.AvgPool2d((res, 1)) + y_mp = torch.nn.AvgPool2d((1, res)) + x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2) + y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2) + # for plane1 + plane21 = x_mp_rep(plane2) + plane31 = torch.flip(y_mp_rep(plane3), (3,)) + new_plane1 = torch.cat([plane1, plane21, plane31], 1) + # for plane2 + plane12 = y_mp_rep(plane1) + plane32 = x_mp_rep(plane3) + new_plane2 = torch.cat([plane2, plane12, plane32], 1) + # for plane3 + plane13 = torch.flip(x_mp_rep(plane1), (2,)) + plane23 = y_mp_rep(plane2) + new_plane3 = torch.cat([plane3, plane13, plane23], 1) + + new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous() + return new_plane + + def unrollout(self, triplane): + res = triplane.shape[-2] + ch = 3 * triplane.shape[1] + triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) + return triplane + + def encode(self, x, rollout=False): + if rollout: + x = self.to3daware(self.rollout(x)) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, unrollout=False): + z = self.to3daware(z) + z = self.post_quant_conv(z) + dec = self.decoder(z) + if unrollout: + dec = self.unrollout(dec) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def training_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(self.to3daware(inputs)) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/') + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + def validation_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/') + self.log_dict(log_dict_ae) + + assert not self.norm + reconstructions = self.unrollout(reconstructions) + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + batch_size = inputs.shape[0] + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if b % 4 == 0 and batch_idx < 10: + rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1) + self.logger.experiment.log({ + "val/vis": [wandb.Image(rgb_all)] + }) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + return self.log_dict + + def to_rgb(self, plane): + x = plane.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_triplane(self, plane): + x = plane.float() + if not hasattr(self, "colorize_triplane"): + self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_3daware(self, plane): + x = plane.float() + if not hasattr(self, "colorize_3daware"): + self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def test_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(self.to3daware(inputs), sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/') + self.log_dict(log_dict_ae) + + batch_size = inputs.shape[0] + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + + z = posterior.mode() + colorize_z = self.to_rgb(z)[0] + colorize_triplane_input = self.to_rgb_triplane(inputs)[0] + colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] + colorize_triplane_rollout_3daware = self.to_rgb_3daware(self.to3daware(inputs))[0] + res = inputs.shape[1] + colorize_triplane_rollout_3daware_1 = self.to_rgb_triplane(self.to3daware(inputs)[:,res:2*res])[0] + colorize_triplane_rollout_3daware_2 = self.to_rgb_triplane(self.to3daware(inputs)[:,2*res:3*res])[0] + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}.png".format(batch_idx)), colorize_triplane_rollout_3daware) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_1.png".format(batch_idx)), colorize_triplane_rollout_3daware_1) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_2.png".format(batch_idx)), colorize_triplane_rollout_3daware_2) + + reconstructions = self.unrollout(reconstructions) + + if self.psum.device != z.device: + self.psum = self.psum.to(z.device) + self.psum_sq = self.psum_sq.to(z.device) + self.psum_min = self.psum_min.to(z.device) + self.psum_max = self.psum_max.to(z.device) + self.psum += z.sum() + self.psum_sq += (z ** 2).sum() + self.psum_min += z.reshape(-1).min(-1)[0] + self.psum_max += z.reshape(-1).max(-1)[0] + assert len(z.shape) == 4 + self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] + self.len_dset += 1 + + if self.norm: + assert NotImplementedError + else: + reconstructions_unnormalize = reconstructions + + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + def on_test_epoch_end(self): + mean = self.psum / self.count + mean_min = self.psum_min / self.len_dset + mean_max = self.psum_max / self.len_dset + var = (self.psum_sq / self.count) - (mean ** 2) + std = torch.sqrt(var) + + print("mean min: {}".format(mean_min)) + print("mean max: {}".format(mean_max)) + print("mean: {}".format(mean)) + print("std: {}".format(std)) + + +class AutoencoderKLGroupConv(AutoencoderKL): + def __init__(self, *args, **kwargs): + try: + ckpt_path = kwargs['ckpt_path'] + kwargs['ckpt_path'] = None + except: + ckpt_path = None + + super().__init__(*args, **kwargs) + self.latent_list = [] + self.psum = torch.zeros([1]) + self.psum_sq = torch.zeros([1]) + self.psum_min = torch.zeros([1]) + self.psum_max = torch.zeros([1]) + self.count = 0 + self.len_dset = 0 + + ddconfig = kwargs['ddconfig'] + # ddconfig['z_channels'] *= 3 + del self.decoder + del self.encoder + self.encoder = Encoder_GroupConv(**ddconfig) + self.decoder = Decoder_GroupConv(**ddconfig) + + if "mean" in ddconfig: + print("Using mean std!!") + self.triplane_mean = torch.Tensor(ddconfig['mean']).reshape(-1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).float() + self.triplane_std = torch.Tensor(ddconfig['std']).reshape(-1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).float() + else: + self.triplane_mean = None + self.triplane_std = None + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path) + + def rollout(self, triplane): + res = triplane.shape[-1] + ch = triplane.shape[1] + triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) + return triplane + + def to3daware(self, triplane): + res = triplane.shape[-2] + plane1 = triplane[..., :res] + plane2 = triplane[..., res:2*res] + plane3 = triplane[..., 2*res:3*res] + + x_mp = torch.nn.MaxPool2d((res, 1)) + y_mp = torch.nn.MaxPool2d((1, res)) + x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2) + y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2) + # for plane1 + plane21 = x_mp_rep(plane2) + plane31 = torch.flip(y_mp_rep(plane3), (3,)) + new_plane1 = torch.cat([plane1, plane21, plane31], 1) + # for plane2 + plane12 = y_mp_rep(plane1) + plane32 = x_mp_rep(plane3) + new_plane2 = torch.cat([plane2, plane12, plane32], 1) + # for plane3 + plane13 = torch.flip(x_mp_rep(plane1), (2,)) + plane23 = y_mp_rep(plane2) + new_plane3 = torch.cat([plane3, plane13, plane23], 1) + + new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous() + return new_plane + + def unrollout(self, triplane): + res = triplane.shape[-2] + ch = 3 * triplane.shape[1] + triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) + return triplane + + def encode(self, x, rollout=False): + if rollout: + # x = self.to3daware(self.rollout(x)) + x = self.rollout(x) + if self.triplane_mean is not None: + x = (x - self.triplane_mean.to(x.device)) / self.triplane_std.to(x.device) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, unrollout=False): + # z = self.to3daware(z) + z = self.post_quant_conv(z) + dec = self.decoder(z) + if self.triplane_mean is not None: + dec = dec * self.triplane_std.to(dec.device) + self.triplane_mean.to(dec.device) + if unrollout: + dec = self.unrollout(dec) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def training_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/', batch=batch) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + def validation_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs, sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/', batch=None) + self.log_dict(log_dict_ae) + + z = posterior.mode() + colorize_z = self.to_rgb(z)[0] + assert not self.norm + reconstructions = self.unrollout(reconstructions) + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + batch_size = inputs.shape[0] + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + rgb_input = np.stack([rgb_input[..., 2], rgb_input[..., 1], rgb_input[..., 0]], -1) + rgb = np.stack([rgb[..., 2], rgb[..., 1], rgb[..., 0]], -1) + + if b % 2 == 0 and batch_idx < 10: + rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1) + self.logger.experiment.log({ + "val/vis": [wandb.Image(rgb_all)], + "val/latent_vis": [wandb.Image(colorize_z)] + }) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + return self.log_dict + + def to_rgb(self, plane): + x = plane.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_triplane(self, plane): + x = plane.float() + if not hasattr(self, "colorize_triplane"): + self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_3daware(self, plane): + x = plane.float() + if not hasattr(self, "colorize_3daware"): + self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def test_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs, sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None) + self.log_dict(log_dict_ae) + + batch_size = inputs.shape[0] + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + + z = posterior.mode() + colorize_z = self.to_rgb(z)[0] + colorize_triplane_input = self.to_rgb_triplane(inputs)[0] + colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] + + import os + import random + import string + # z_np = z.detach().cpu().numpy() + z_np = inputs.detach().cpu().numpy() + fname = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8)) + '.npy' + with open(os.path.join('/mnt/lustre/hongfangzhou.p/AE3D/tmp', fname), 'wb') as f: + np.save(f, z_np) + + # colorize_triplane_rollout_3daware = self.to_rgb_3daware(self.to3daware(inputs))[0] + # res = inputs.shape[1] + # colorize_triplane_rollout_3daware_1 = self.to_rgb_triplane(self.to3daware(inputs)[:,res:2*res])[0] + # colorize_triplane_rollout_3daware_2 = self.to_rgb_triplane(self.to3daware(inputs)[:,2*res:3*res])[0] + if batch_idx < 0: + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}.png".format(batch_idx)), colorize_triplane_rollout_3daware) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_1.png".format(batch_idx)), colorize_triplane_rollout_3daware_1) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_2.png".format(batch_idx)), colorize_triplane_rollout_3daware_2) + + np_z = z.detach().cpu().numpy() + # with open(os.path.join(self.logger.log_dir, "latent_{}.npz".format(batch_idx)), 'wb') as f: + # np.save(f, np_z) + + self.latent_list.append(np_z) + + reconstructions = self.unrollout(reconstructions) + + if self.psum.device != z.device: + self.psum = self.psum.to(z.device) + self.psum_sq = self.psum_sq.to(z.device) + self.psum_min = self.psum_min.to(z.device) + self.psum_max = self.psum_max.to(z.device) + self.psum += z.sum() + self.psum_sq += (z ** 2).sum() + self.psum_min += z.reshape(-1).min(-1)[0] + self.psum_max += z.reshape(-1).max(-1)[0] + assert len(z.shape) == 4 + self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] + self.len_dset += 1 + + if self.norm: + assert NotImplementedError + else: + reconstructions_unnormalize = reconstructions + + if True: + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + def on_test_epoch_end(self): + mean = self.psum / self.count + mean_min = self.psum_min / self.len_dset + mean_max = self.psum_max / self.len_dset + var = (self.psum_sq / self.count) - (mean ** 2) + std = torch.sqrt(var) + + print("mean min: {}".format(mean_min)) + print("mean max: {}".format(mean_max)) + print("mean: {}".format(mean)) + print("std: {}".format(std)) + + latent = np.concatenate(self.latent_list) + q75, q25 = np.percentile(latent.reshape(-1), [75 ,25]) + median = np.median(latent.reshape(-1)) + iqr = q75 - q25 + norm_iqr = iqr * 0.7413 + print("Norm IQR: {}".format(norm_iqr)) + print("Inverse Norm IQR: {}".format(1/norm_iqr)) + print("Median: {}".format(median)) + + def loss(self, inputs, reconstructions, posteriors, prefix, batch=None): + reconstructions = reconstructions.contiguous() + # rec_loss = torch.abs(inputs.contiguous() - reconstructions) + # rec_loss = torch.sum(rec_loss) / rec_loss.shape[0] + rec_loss = F.mse_loss(inputs.contiguous(), reconstructions) + kl_loss = posteriors.kl() + # kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + kl_loss = kl_loss.mean() + loss = self.lossconfig.rec_weight * rec_loss + self.lossconfig.kl_weight * kl_loss + + ret_dict = { + prefix+'mean_rec_loss': torch.abs(inputs.contiguous() - reconstructions.contiguous()).mean().detach(), + prefix+'rec_loss': rec_loss, + prefix+'kl_loss': kl_loss, + prefix+'loss': loss, + prefix+'mean': posteriors.mean.mean(), + prefix+'logvar': posteriors.logvar.mean(), + } + + + latent = posteriors.mean + ret_dict[prefix + 'latent_max'] = latent.max() + ret_dict[prefix + 'latent_min'] = latent.min() + + render_weight = self.lossconfig.get("render_weight", 0) + tv_weight = self.lossconfig.get("tv_weight", 0) + l1_weight = self.lossconfig.get("l1_weight", 0) + latent_tv_weight = self.lossconfig.get("latent_tv_weight", 0) + latent_l1_weight = self.lossconfig.get("latent_l1_weight", 0) + + triplane_rec = self.unrollout(reconstructions) + if render_weight > 0 and batch is not None: + rgb_rendered, target = self.render_triplane_eg3d_decoder_sample_pixel(triplane_rec, batch['batch_rays'], batch['img']) + # render_loss = ((rgb_rendered - target) ** 2).sum() / rgb_rendered.shape[0] * 256 + render_loss = F.mse_loss(rgb_rendered, target) + loss += render_weight * render_loss + ret_dict[prefix + 'render_loss'] = render_loss + if tv_weight > 0: + tvloss_y = F.mse_loss(triplane_rec[:, :, :-1], triplane_rec[:, :, 1:]) + tvloss_x = F.mse_loss(triplane_rec[:, :, :, :-1], triplane_rec[:, :, :, 1:]) + tvloss = tvloss_y + tvloss_x + loss += tv_weight * tvloss + ret_dict[prefix + 'tv_loss'] = tvloss + if l1_weight > 0: + l1 = (triplane_rec ** 2).mean() + loss += l1_weight * l1 + ret_dict[prefix + 'l1_loss'] = l1 + if latent_tv_weight > 0: + latent = posteriors.mean + latent_tv_y = F.mse_loss(latent[:, :, :-1], latent[:, :, 1:]) + latent_tv_x = F.mse_loss(latent[:, :, :, :-1], latent[:, :, :, 1:]) + latent_tv_loss = latent_tv_y + latent_tv_x + loss += latent_tv_loss * latent_tv_weight + ret_dict[prefix + 'latent_tv_loss'] = latent_tv_loss + if latent_l1_weight > 0: + latent = posteriors.mean + latent_l1_loss = (latent ** 2).mean() + loss += latent_l1_loss * latent_l1_weight + ret_dict[prefix + 'latent_l1_loss'] = latent_l1_loss + + return loss, ret_dict + + +class AutoencoderKLGroupConvLateFusion(AutoencoderKL): + def __init__(self, *args, **kwargs): + try: + ckpt_path = kwargs['ckpt_path'] + kwargs['ckpt_path'] = None + except: + ckpt_path = None + + super().__init__(*args, **kwargs) + self.latent_list = [] + self.psum = torch.zeros([1]) + self.psum_sq = torch.zeros([1]) + self.psum_min = torch.zeros([1]) + self.psum_max = torch.zeros([1]) + self.count = 0 + self.len_dset = 0 + + ddconfig = kwargs['ddconfig'] + del self.decoder + del self.encoder + self.encoder = Encoder_GroupConv_LateFusion(**ddconfig) + self.decoder = Decoder_GroupConv_LateFusion(**ddconfig) + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path) + + def rollout(self, triplane): + res = triplane.shape[-1] + ch = triplane.shape[1] + triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) + return triplane + + def to3daware(self, triplane): + res = triplane.shape[-2] + plane1 = triplane[..., :res] + plane2 = triplane[..., res:2*res] + plane3 = triplane[..., 2*res:3*res] + + x_mp = torch.nn.MaxPool2d((res, 1)) + y_mp = torch.nn.MaxPool2d((1, res)) + x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2) + y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2) + # for plane1 + plane21 = x_mp_rep(plane2) + plane31 = torch.flip(y_mp_rep(plane3), (3,)) + new_plane1 = torch.cat([plane1, plane21, plane31], 1) + # for plane2 + plane12 = y_mp_rep(plane1) + plane32 = x_mp_rep(plane3) + new_plane2 = torch.cat([plane2, plane12, plane32], 1) + # for plane3 + plane13 = torch.flip(x_mp_rep(plane1), (2,)) + plane23 = y_mp_rep(plane2) + new_plane3 = torch.cat([plane3, plane13, plane23], 1) + + new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous() + return new_plane + + def unrollout(self, triplane): + res = triplane.shape[-2] + ch = 3 * triplane.shape[1] + triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) + return triplane + + def encode(self, x, rollout=False): + if rollout: + x = self.rollout(x) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, unrollout=False): + z = self.post_quant_conv(z) + dec = self.decoder(z) + if unrollout: + dec = self.unrollout(dec) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def training_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/', batch=batch) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + def validation_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs, sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/', batch=None) + self.log_dict(log_dict_ae) + + assert not self.norm + reconstructions = self.unrollout(reconstructions) + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + batch_size = inputs.shape[0] + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if b % 4 == 0 and batch_idx < 10: + rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1) + self.logger.experiment.log({ + "val/vis": [wandb.Image(rgb_all)] + }) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + return self.log_dict + + def to_rgb(self, plane): + x = plane.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_triplane(self, plane): + x = plane.float() + if not hasattr(self, "colorize_triplane"): + self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_3daware(self, plane): + x = plane.float() + if not hasattr(self, "colorize_3daware"): + self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def test_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs, sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None) + self.log_dict(log_dict_ae) + + batch_size = inputs.shape[0] + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + + z = posterior.mode() + colorize_z = self.to_rgb(z)[0] + colorize_triplane_input = self.to_rgb_triplane(inputs)[0] + colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] + # colorize_triplane_rollout_3daware = self.to_rgb_3daware(self.to3daware(inputs))[0] + # res = inputs.shape[1] + # colorize_triplane_rollout_3daware_1 = self.to_rgb_triplane(self.to3daware(inputs)[:,res:2*res])[0] + # colorize_triplane_rollout_3daware_2 = self.to_rgb_triplane(self.to3daware(inputs)[:,2*res:3*res])[0] + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input) + imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}.png".format(batch_idx)), colorize_triplane_rollout_3daware) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_1.png".format(batch_idx)), colorize_triplane_rollout_3daware_1) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_2.png".format(batch_idx)), colorize_triplane_rollout_3daware_2) + + np_z = z.detach().cpu().numpy() + # with open(os.path.join(self.logger.log_dir, "latent_{}.npz".format(batch_idx)), 'wb') as f: + # np.save(f, np_z) + + self.latent_list.append(np_z) + + reconstructions = self.unrollout(reconstructions) + + if self.psum.device != z.device: + self.psum = self.psum.to(z.device) + self.psum_sq = self.psum_sq.to(z.device) + self.psum_min = self.psum_min.to(z.device) + self.psum_max = self.psum_max.to(z.device) + self.psum += z.sum() + self.psum_sq += (z ** 2).sum() + self.psum_min += z.reshape(-1).min(-1)[0] + self.psum_max += z.reshape(-1).max(-1)[0] + assert len(z.shape) == 4 + self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] + self.len_dset += 1 + + if self.norm: + assert NotImplementedError + else: + reconstructions_unnormalize = reconstructions + + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if batch_idx < 10: + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) + imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + def on_test_epoch_end(self): + mean = self.psum / self.count + mean_min = self.psum_min / self.len_dset + mean_max = self.psum_max / self.len_dset + var = (self.psum_sq / self.count) - (mean ** 2) + std = torch.sqrt(var) + + print("mean min: {}".format(mean_min)) + print("mean max: {}".format(mean_max)) + print("mean: {}".format(mean)) + print("std: {}".format(std)) + + latent = np.concatenate(self.latent_list) + q75, q25 = np.percentile(latent.reshape(-1), [75 ,25]) + median = np.median(latent.reshape(-1)) + iqr = q75 - q25 + norm_iqr = iqr * 0.7413 + print("Norm IQR: {}".format(norm_iqr)) + print("Inverse Norm IQR: {}".format(1/norm_iqr)) + print("Median: {}".format(median)) + + +from module.model_2d import ViTEncoder, ViTDecoder + +class AutoencoderVIT(AutoencoderKL): + def __init__(self, *args, **kwargs): + try: + ckpt_path = kwargs['ckpt_path'] + kwargs['ckpt_path'] = None + except: + ckpt_path = None + + super().__init__(*args, **kwargs) + self.latent_list = [] + self.psum = torch.zeros([1]) + self.psum_sq = torch.zeros([1]) + self.psum_min = torch.zeros([1]) + self.psum_max = torch.zeros([1]) + self.count = 0 + self.len_dset = 0 + + ddconfig = kwargs['ddconfig'] + # ddconfig['z_channels'] *= 3 + del self.decoder + del self.encoder + del self.quant_conv + del self.post_quant_conv + + assert ddconfig["z_channels"] == 256 + self.encoder = ViTEncoder( + image_size=(256, 256*3), + patch_size=(256//32, 256//32), + dim=768, + depth=12, + heads=12, + mlp_dim=3072, + channels=8) + self.decoder = ViTDecoder( + image_size=(256, 256*3), + patch_size=(256//32, 256//32), + dim=768, + depth=12, + heads=12, + mlp_dim=3072, + channels=8) + + self.quant_conv = torch.nn.Conv2d(768, 2*self.embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, 768, 1) + + if "mean" in ddconfig: + print("Using mean std!!") + self.triplane_mean = torch.Tensor(ddconfig['mean']).reshape(-1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).float() + self.triplane_std = torch.Tensor(ddconfig['std']).reshape(-1).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).float() + else: + self.triplane_mean = None + self.triplane_std = None + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path) + + def rollout(self, triplane): + res = triplane.shape[-1] + ch = triplane.shape[1] + triplane = triplane.reshape(-1, 3, ch//3, res, res).permute(0, 2, 3, 1, 4).reshape(-1, ch//3, res, 3 * res) + return triplane + + def to3daware(self, triplane): + res = triplane.shape[-2] + plane1 = triplane[..., :res] + plane2 = triplane[..., res:2*res] + plane3 = triplane[..., 2*res:3*res] + + x_mp = torch.nn.MaxPool2d((res, 1)) + y_mp = torch.nn.MaxPool2d((1, res)) + x_mp_rep = lambda i: x_mp(i).repeat(1, 1, res, 1).permute(0, 1, 3, 2) + y_mp_rep = lambda i: y_mp(i).repeat(1, 1, 1, res).permute(0, 1, 3, 2) + # for plane1 + plane21 = x_mp_rep(plane2) + plane31 = torch.flip(y_mp_rep(plane3), (3,)) + new_plane1 = torch.cat([plane1, plane21, plane31], 1) + # for plane2 + plane12 = y_mp_rep(plane1) + plane32 = x_mp_rep(plane3) + new_plane2 = torch.cat([plane2, plane12, plane32], 1) + # for plane3 + plane13 = torch.flip(x_mp_rep(plane1), (2,)) + plane23 = y_mp_rep(plane2) + new_plane3 = torch.cat([plane3, plane13, plane23], 1) + + new_plane = torch.cat([new_plane1, new_plane2, new_plane3], -1).contiguous() + return new_plane + + def unrollout(self, triplane): + res = triplane.shape[-2] + ch = 3 * triplane.shape[1] + triplane = triplane.reshape(-1, ch//3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, ch, res, res) + return triplane + + def encode(self, x, rollout=False): + if rollout: + # x = self.to3daware(self.rollout(x)) + x = self.rollout(x) + if self.triplane_mean is not None: + x = (x - self.triplane_mean.to(x.device)) / self.triplane_std.to(x.device) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z, unrollout=False): + # z = self.to3daware(z) + z = self.post_quant_conv(z) + dec = self.decoder(z) + if self.triplane_mean is not None: + dec = dec * self.triplane_std.to(dec.device) + self.triplane_mean.to(dec.device) + if unrollout: + dec = self.unrollout(dec) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def training_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='train/', batch=batch) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + def validation_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs, sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='val/', batch=None) + self.log_dict(log_dict_ae) + + assert not self.norm + reconstructions = self.unrollout(reconstructions) + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + batch_size = inputs.shape[0] + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + if b % 4 == 0 and batch_idx < 10: + rgb_all = np.concatenate([rgb_gt[1], rgb_input[1], rgb[1]], 1) + self.logger.experiment.log({ + "val/vis": [wandb.Image(rgb_all)] + }) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("val/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("val/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("val/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + return self.log_dict + + def to_rgb(self, plane): + x = plane.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_triplane(self, plane): + x = plane.float() + if not hasattr(self, "colorize_triplane"): + self.colorize_triplane = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_triplane) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def to_rgb_3daware(self, plane): + x = plane.float() + if not hasattr(self, "colorize_3daware"): + self.colorize_3daware = torch.randn(3, x.shape[1], 1, 1).to(x) + x = torch.nn.functional.conv2d(x, weight=self.colorize_3daware) + x = ((x - x.min()) / (x.max() - x.min()) * 255.).permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + return x + + def test_step(self, batch, batch_idx): + inputs = self.rollout(batch['triplane']) + reconstructions, posterior = self(inputs, sample_posterior=False) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, prefix='test/', batch=None) + self.log_dict(log_dict_ae) + + batch_size = inputs.shape[0] + psnr_list = [] # between rec and gt + psnr_input_list = [] # between input and gt + psnr_rec_list = [] # between input and rec + + z = posterior.mode() + colorize_z = self.to_rgb(z)[0] + colorize_triplane_input = self.to_rgb_triplane(inputs)[0] + colorize_triplane_output = self.to_rgb_triplane(reconstructions)[0] + + import os + import random + import string + # z_np = z.detach().cpu().numpy() + z_np = inputs.detach().cpu().numpy() + fname = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8)) + '.npy' + with open(os.path.join('/mnt/lustre/hongfangzhou.p/AE3D/tmp', fname), 'wb') as f: + np.save(f, z_np) + + # colorize_triplane_rollout_3daware = self.to_rgb_3daware(self.to3daware(inputs))[0] + # res = inputs.shape[1] + # colorize_triplane_rollout_3daware_1 = self.to_rgb_triplane(self.to3daware(inputs)[:,res:2*res])[0] + # colorize_triplane_rollout_3daware_2 = self.to_rgb_triplane(self.to3daware(inputs)[:,2*res:3*res])[0] + # if batch_idx < 10: + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_z_{}.png".format(batch_idx)), colorize_z) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_{}.png".format(batch_idx)), colorize_triplane_input) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_output_{}.png".format(batch_idx)), colorize_triplane_output) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}.png".format(batch_idx)), colorize_triplane_rollout_3daware) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_1.png".format(batch_idx)), colorize_triplane_rollout_3daware_1) + # imageio.imwrite(os.path.join(self.logger.log_dir, "colorize_input_3daware_{}_2.png".format(batch_idx)), colorize_triplane_rollout_3daware_2) + + np_z = z.detach().cpu().numpy() + # with open(os.path.join(self.logger.log_dir, "latent_{}.npz".format(batch_idx)), 'wb') as f: + # np.save(f, np_z) + + self.latent_list.append(np_z) + + reconstructions = self.unrollout(reconstructions) + + if self.psum.device != z.device: + self.psum = self.psum.to(z.device) + self.psum_sq = self.psum_sq.to(z.device) + self.psum_min = self.psum_min.to(z.device) + self.psum_max = self.psum_max.to(z.device) + self.psum += z.sum() + self.psum_sq += (z ** 2).sum() + self.psum_min += z.reshape(-1).min(-1)[0] + self.psum_max += z.reshape(-1).max(-1)[0] + assert len(z.shape) == 4 + self.count += z.shape[0] * z.shape[1] * z.shape[2] * z.shape[3] + self.len_dset += 1 + + if self.norm: + assert NotImplementedError + else: + reconstructions_unnormalize = reconstructions + + if True: + for b in range(batch_size): + if self.renderer_type == 'nerf': + rgb_input, cur_psnr_list_input = self.render_triplane( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + rgb, cur_psnr_list = self.render_triplane( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img_flat'][b], + batch['near'][b].unsqueeze(-1), batch['far'][b].unsqueeze(-1) + ) + elif self.renderer_type == 'eg3d': + rgb_input, cur_psnr_list_input = self.render_triplane_eg3d_decoder( + batch['triplane_ori'][b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + rgb, cur_psnr_list = self.render_triplane_eg3d_decoder( + reconstructions_unnormalize[b:b+1], batch['batch_rays'][b], batch['img'][b], + ) + else: + raise NotImplementedError + + cur_psnr_list_rec = [] + for i in range(rgb.shape[0]): + cur_psnr_list_rec.append(mse2psnr(img2mse(rgb_input[i], rgb[i]))) + + rgb_input = to8b(rgb_input.detach().cpu().numpy()) + rgb_gt = to8b(batch['img'][b].detach().cpu().numpy()) + rgb = to8b(rgb.detach().cpu().numpy()) + + # if batch_idx < 10: + # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_input.png".format(batch_idx, b)), rgb_input[1]) + # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_rec.png".format(batch_idx, b)), rgb[1]) + # imageio.imwrite(os.path.join(self.logger.log_dir, "{}_{}_gt.png".format(batch_idx, b)), rgb_gt[1]) + + psnr_list += cur_psnr_list + psnr_input_list += cur_psnr_list_input + psnr_rec_list += cur_psnr_list_rec + + self.log("test/psnr_input_gt", torch.Tensor(psnr_input_list).mean(), prog_bar=True) + self.log("test/psnr_input_rec", torch.Tensor(psnr_rec_list).mean(), prog_bar=True) + self.log("test/psnr_rec_gt", torch.Tensor(psnr_list).mean(), prog_bar=True) + + def on_test_epoch_end(self): + mean = self.psum / self.count + mean_min = self.psum_min / self.len_dset + mean_max = self.psum_max / self.len_dset + var = (self.psum_sq / self.count) - (mean ** 2) + std = torch.sqrt(var) + + print("mean min: {}".format(mean_min)) + print("mean max: {}".format(mean_max)) + print("mean: {}".format(mean)) + print("std: {}".format(std)) + + latent = np.concatenate(self.latent_list) + q75, q25 = np.percentile(latent.reshape(-1), [75 ,25]) + median = np.median(latent.reshape(-1)) + iqr = q75 - q25 + norm_iqr = iqr * 0.7413 + print("Norm IQR: {}".format(norm_iqr)) + print("Inverse Norm IQR: {}".format(1/norm_iqr)) + print("Median: {}".format(median))