# Spaces: Running on Zero
import copy | |
import matplotlib.pyplot as plt | |
import mcubes | |
import trimesh | |
import functools | |
import json | |
import os | |
from pathlib import Path | |
from pdb import set_trace as st | |
import traceback | |
# from nsr.gs import GaussianRenderer | |
from nsr.gs_surfel import GaussianRenderer2DGS | |
import blobfile as bf | |
import imageio | |
import numpy as np | |
# from sympy import O | |
import torch as th | |
import torch.distributed as dist | |
import torchvision | |
from PIL import Image | |
from torch.nn.parallel.distributed import DistributedDataParallel as DDP | |
from torch.optim import AdamW | |
from torch.utils.tensorboard import SummaryWriter | |
from tqdm import tqdm, trange | |
from guided_diffusion import dist_util, logger | |
from guided_diffusion.fp16_util import MixedPrecisionTrainer | |
from guided_diffusion.nn import update_ema | |
from guided_diffusion.resample import LossAwareSampler, UniformSampler | |
from guided_diffusion.train_util import (calc_average_loss, | |
find_ema_checkpoint, | |
find_resume_checkpoint, | |
get_blob_logdir, log_rec3d_loss_dict, | |
parse_resume_step_from_filename) | |
from .camera_utils import LookAtPoseSampler, FOV_to_intrinsics | |
# from ..guided_diffusion.train_util import TrainLoop | |
def flip_yaw(pose_matrix):
    """Return a left-right mirrored copy of a batch of pose matrices.

    The entries ``[0, 1]``, ``[0, 2]``, ``[1, 0]``, ``[2, 0]`` (rotation terms
    coupling the first axis with the other two) and ``[0, 3]`` (first
    translation component) of every matrix in the batch are negated. The
    input tensor is not modified.

    Args:
        pose_matrix: tensor indexed as ``[batch, row, col]`` (presumably
            ``(B, 4, 4)`` camera poses — confirm against callers).

    Returns:
        A new tensor with the same shape as ``pose_matrix``.
    """
    mirrored = pose_matrix.clone()
    sign_flipped_entries = ((0, 1), (0, 2), (1, 0), (2, 0), (0, 3))
    for row, col in sign_flipped_entries:
        mirrored[:, row, col] *= -1
    return mirrored
# basic reconstruction model | |
class TrainLoopBasic:
    """Base training loop for a reconstruction model (``rec_model``).

    Handles resuming model / EMA / optimizer state from
    ``resume_checkpoint``, mixed precision through ``MixedPrecisionTrainer``,
    EMA bookkeeping, optional ``torch.compile`` and DDP wrapping, periodic
    logging / evaluation / checkpointing, and linear learning-rate annealing.

    Subclasses must implement ``_init_optim_groups``, ``forward_backward``
    and ``eval_loop``.
    """

    def __init__(
            self,
            *,
            rec_model,
            loss_class,
            data,
            eval_data,
            batch_size,
            microbatch,
            lr,
            ema_rate,
            log_interval,
            eval_interval,
            save_interval,
            resume_checkpoint,
            use_fp16=False,
            fp16_scale_growth=1e-3,
            weight_decay=0.0,
            lr_anneal_steps=0,
            iterations=10001,
            load_submodule_name='',
            ignore_resume_opt=False,
            model_name='rec',
            use_amp=False,
            compile=False,
            **kwargs):
        """Configure the loop, resume state if requested and wrap the model.

        Args:
            rec_model: the reconstruction model.
            loss_class: loss callable; subclasses call it as
                ``loss_class(pred, target, test_mode=...)`` and unpack
                ``(loss, loss_dict)``.
            data: iterator of training batches (consumed with ``next``).
            eval_data: iterable of evaluation batches.
            batch_size: per-process batch size.
            microbatch: micro-batch size; ``<= 0`` means ``batch_size``.
            lr: base learning rate (start value of linear annealing).
            ema_rate: a float or a comma-separated string of EMA rates.
            log_interval: steps between metric dumps.
            eval_interval: steps between ``eval_loop`` calls.
            save_interval: steps between checkpoints.
            resume_checkpoint: model checkpoint path, or '' to start fresh.
            use_fp16, fp16_scale_growth, model_name, use_amp: forwarded to
                ``MixedPrecisionTrainer``; ``use_amp`` also picks
                ``self.dtype``.
            weight_decay: stored for subclasses' optimizer groups.
            lr_anneal_steps: total steps of linear lr decay; 0 disables
                annealing. When non-zero it also bounds ``run_loop``.
            iterations: ``run_loop`` exits once ``self.step`` exceeds this.
            load_submodule_name: accepted but not used here.
            ignore_resume_opt: skip restoring optimizer state on resume.
            compile: wrap ``rec_model`` with ``torch.compile``.
            **kwargs: passed to ``_init_optim_groups``.
        """
        # Pools that bring renderings to a common resolution before they are
        # concatenated side by side for visualization.
        self.pool_512 = th.nn.AdaptiveAvgPool2d((512, 512))
        self.pool_256 = th.nn.AdaptiveAvgPool2d((256, 256))
        self.pool_128 = th.nn.AdaptiveAvgPool2d((128, 128))
        self.pool_64 = th.nn.AdaptiveAvgPool2d((64, 64))

        self.rec_model = rec_model
        self.loss_class = loss_class
        self.data = data
        self.eval_data = eval_data
        self.batch_size = batch_size
        self.microbatch = microbatch if microbatch > 0 else batch_size
        self.lr = lr
        self.ema_rate = ([ema_rate] if isinstance(ema_rate, float) else
                         [float(x) for x in ema_rate.split(",")])
        self.log_interval = log_interval
        self.eval_interval = eval_interval
        self.save_interval = save_interval
        self.iterations = iterations
        self.resume_checkpoint = resume_checkpoint
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth
        self.weight_decay = weight_decay
        self.lr_anneal_steps = lr_anneal_steps

        self.step = 0
        self.resume_step = 0
        self.global_batch = self.batch_size * dist_util.get_world_size()
        self.sync_cuda = th.cuda.is_available()

        # Loads weights and sets self.resume_step when resuming.
        self._load_and_sync_parameters()

        # AMP dtype: fp16 on pre-Ampere GPUs (compute capability < 8),
        # bf16 otherwise. https://zhuanlan.zhihu.com/p/671165275
        self.dtype = th.float32
        if use_amp:
            if th.cuda.get_device_capability(0)[0] < 8:
                self.dtype = th.float16  # e.g., V100
            else:
                self.dtype = th.bfloat16  # e.g., A100 / A6000

        self.mp_trainer_rec = MixedPrecisionTrainer(
            model=self.rec_model,
            use_fp16=self.use_fp16,
            fp16_scale_growth=fp16_scale_growth,
            model_name=model_name,
            use_amp=use_amp)
        self.writer = SummaryWriter(log_dir=f'{logger.get_dir()}/runs')
        self.opt = AdamW(self._init_optim_groups(kwargs))
        if dist_util.get_rank() == 0:
            logger.log(self.opt)

        if self.resume_step:
            # Resumed (restart or explicit checkpoint): restore optimizer/EMA.
            if not ignore_resume_opt:
                self._load_optimizer_state()
            else:
                logger.warn("Ignoring optimizer state from checkpoint.")
            self.ema_params = [
                self._load_ema_parameters(
                    rate,
                    self.rec_model,
                    self.mp_trainer_rec,
                    model_name=self.mp_trainer_rec.model_name)
                for rate in self.ema_rate
            ]
        else:
            self.ema_params = [
                copy.deepcopy(self.mp_trainer_rec.master_params)
                for _ in range(len(self.ema_rate))
            ]

        self.compile = compile
        if compile:
            logger.log('compiling... ignore vit_decoder')
            # Bug fix: this used to compile `self.model`, which this class
            # never defines, so compile=True raised AttributeError.
            self.rec_model = th.compile(self.rec_model)

        if th.cuda.is_available():
            self.use_ddp = True
            self.rec_model = th.nn.SyncBatchNorm.convert_sync_batchnorm(
                self.rec_model)
            self.rec_model = DDP(
                self.rec_model,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            if dist_util.get_world_size() > 1:
                logger.warn("Distributed training requires CUDA. "
                            "Gradients will not be synchronized properly!")
            self.use_ddp = False

        self.novel_view_poses = None
        th.cuda.empty_cache()

    def _init_optim_groups(self, kwargs):
        """Return the AdamW parameter groups; must be overridden."""
        raise NotImplementedError('')

    def _load_and_sync_parameters(self, submodule_name=''):
        """Load model weights from ``self.resume_checkpoint`` and sync ranks.

        Checkpoint keys have any ``'._orig_mod'`` (torch.compile) fragment
        removed. Keys that are missing from the model or whose tensor size
        differs are logged and skipped. With ``submodule_name`` only that
        attribute of ``rec_model`` is loaded, strictly; otherwise the whole
        model is loaded with ``strict=False``. Also sets
        ``self.resume_step`` from the checkpoint file name.
        """
        resume_checkpoint = self.resume_checkpoint
        if resume_checkpoint:
            self.resume_step = parse_resume_step_from_filename(
                resume_checkpoint)
            if dist_util.get_rank() == 0:
                logger.log(
                    f"loading model from checkpoint: {resume_checkpoint}...")
            # Tensors saved from cuda:0 are mapped onto this rank's device.
            map_location = {
                'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
            }
            resume_state_dict = dist_util.load_state_dict(
                resume_checkpoint, map_location=map_location)

            if submodule_name != '':
                target = getattr(self.rec_model, submodule_name)
                if dist_util.get_rank() == 0:
                    logger.log('loading submodule: ', submodule_name)
            else:
                target = self.rec_model
            model_state_dict = target.state_dict()

            for k, v in resume_state_dict.items():
                if '._orig_mod' in k:  # prefix in torch.compile
                    k = k.replace('._orig_mod', '')
                if k not in model_state_dict:
                    logger.log(
                        '!!!! ignore key, not in the model_state_dict: ',
                        k, ": ", v.size())
                elif v.size() == model_state_dict[k].size():
                    model_state_dict[k] = v
                else:
                    logger.log('!!!! size mismatch, ignore: ', k, ": ",
                               v.size(), "state_dict: ",
                               model_state_dict[k].size())
            logger.log('model loading finished')

            if submodule_name != '':
                target.load_state_dict(model_state_dict, strict=True)
            else:
                self.rec_model.load_state_dict(model_state_dict,
                                               strict=False)

        if dist_util.get_world_size() > 1:
            dist_util.sync_params(self.rec_model.parameters())
            logger.log('synced params')

    def _load_ema_parameters(self,
                             rate,
                             model=None,
                             mp_trainer=None,
                             model_name='ddpm'):
        """Return EMA parameters for ``rate``, loaded from disk when possible.

        Starts from a copy of ``mp_trainer.master_params``. Only when an EMA
        checkpoint is found and ``model_name == 'ddpm'`` does rank 0 load it
        (size-mismatched keys are skipped; ``IN`` keys are remapped to
        ``IN.IN`` for models with ``decomposed_IN``). The result is then
        broadcast to all ranks.

        Bug fix: a missing EMA file used to make rank 0 ``return`` None
        early. That put None into ``self.ema_params`` and left the other
        ranks waiting in ``sync_params``. Now the copy of the master params
        is kept and every rank reaches the sync.
        """
        if mp_trainer is None:
            mp_trainer = self.mp_trainer_rec
        if model is None:
            model = self.rec_model
        ema_params = copy.deepcopy(mp_trainer.master_params)

        main_checkpoint = self.resume_checkpoint
        ema_checkpoint = find_ema_checkpoint(main_checkpoint,
                                             self.resume_step, rate,
                                             model_name)
        if ema_checkpoint and model_name == 'ddpm':
            if dist_util.get_rank() == 0:
                if not Path(ema_checkpoint).exists():
                    logger.log(
                        f"failed to load EMA from checkpoint: {ema_checkpoint}, not exist"
                    )
                else:
                    logger.log(
                        f"loading EMA from checkpoint: {ema_checkpoint}...")
                    map_location = {
                        'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
                    }
                    state_dict = dist_util.load_state_dict(
                        ema_checkpoint, map_location=map_location)
                    model_ema_state_dict = model.state_dict()
                    for k, v in state_dict.items():
                        if k in model_ema_state_dict and v.size(
                        ) == model_ema_state_dict[k].size():
                            model_ema_state_dict[k] = v
                        elif 'IN' in k and getattr(model, 'decomposed_IN',
                                                   False):
                            model_ema_state_dict[k.replace(
                                'IN', 'IN.IN')] = v  # decomposed IN
                        else:
                            logger.log('ignore key: ', k, ": ", v.size())
                    ema_params = mp_trainer.state_dict_to_master_params(
                        model_ema_state_dict)
                    del state_dict

        if dist_util.get_world_size() > 1:
            dist_util.sync_params(ema_params)
        return ema_params

    def _load_optimizer_state(self):
        """Restore optimizer state from ``opt{resume_step:07}.pt``.

        The file is looked up next to the resume checkpoint. A missing file
        or a load error is logged, not raised.
        """
        # Bug fix: an extra unconditional `find_resume_checkpoint()` call
        # used to run here before this branch and was always overwritten.
        if self.resume_checkpoint == '':
            main_checkpoint, _ = find_resume_checkpoint()
        else:
            main_checkpoint = self.resume_checkpoint
        opt_checkpoint = bf.join(bf.dirname(main_checkpoint),
                                 f"opt{self.resume_step:07}.pt")
        if not bf.exists(opt_checkpoint):
            logger.log('optimizer state load fail: {}'.format(opt_checkpoint))
            return

        logger.log(
            f"loading optimizer state from checkpoint: {opt_checkpoint}")
        map_location = {
            'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
        }
        try:
            state_dict = dist_util.load_state_dict(opt_checkpoint,
                                                   map_location=map_location)
            self.opt.load_state_dict(state_dict)
            # Inside the try: previously `del state_dict` ran after the
            # except and raised NameError whenever the load itself failed.
            del state_dict
        except Exception as e:  # best effort: training continues
            logger.log(e)

    def run_loop(self, batch=None):
        """Run training steps until annealing ends or ``iterations`` is hit.

        The ``batch`` argument is ignored: each step pulls a fresh batch from
        ``self.data``. When ``self.step`` exceeds ``self.iterations`` a final
        checkpoint is written (if not just saved) and the process exits.
        """
        while (not self.lr_anneal_steps
               or self.step + self.resume_step < self.lr_anneal_steps):
            # Let all processes sync up before starting a new step.
            dist_util.synchronize()
            batch = next(self.data)
            self.run_step(batch)

            if self.step % 1000 == 0:
                dist_util.synchronize()
                th.cuda.empty_cache()  # avoid memory leak

            if self.step % self.log_interval == 0 and dist_util.get_rank(
            ) == 0:
                try:
                    out = logger.dumpkvs()
                    for k, v in out.items():  # mirror to tensorboard
                        self.writer.add_scalar(f'Loss/{k}', v,
                                               self.step + self.resume_step)
                except Exception as e:  # e.g. disk quota exceeded
                    logger.log(e)

            if self.step % self.eval_interval == 0 and self.step != 0:
                if dist_util.get_rank() == 0:
                    try:
                        self.eval_loop()
                    except Exception as e:
                        logger.log(e)
                dist_util.synchronize()

            if self.step % self.save_interval == 0 and self.step != 0:
                self.save()
                dist_util.synchronize()
                # Run for a finite amount of time in integration tests.
                if os.environ.get("DIFFUSION_TRAINING_TEST",
                                  "") and self.step > 0:
                    return

            self.step += 1

            if self.step > self.iterations:
                logger.log('reached maximum iterations, exiting')
                # Save the last checkpoint if it wasn't already saved.
                if (self.step -
                        1) % self.save_interval != 0 and self.step != 1:
                    self.save()
                # `exit()` is an interactive-shell helper injected by `site`;
                # raising SystemExit is the program-safe equivalent.
                raise SystemExit()

        # Save the last checkpoint if it wasn't already saved.
        if (self.step - 1) % self.save_interval != 0 and self.step != 1:
            self.save()

    def eval_loop(self):
        """Evaluate the model; must be overridden."""
        raise NotImplementedError('')

    def run_step(self, batch, *args):
        """One optimization step: forward/backward, update, EMA, lr, logging."""
        self.forward_backward(batch)
        took_step = self.mp_trainer_rec.optimize(self.opt)
        if took_step:
            self._update_ema()
        self._anneal_lr()
        self.log_step()

    def forward_backward(self, batch, *args, **kwargs):
        """Compute losses and gradients for ``batch``; must be overridden."""
        raise NotImplementedError('')

    def _update_ema(self):
        """Blend the current master params into each EMA copy."""
        for rate, params in zip(self.ema_rate, self.ema_params):
            update_ema(params, self.mp_trainer_rec.master_params, rate=rate)

    def _anneal_lr(self):
        """Linearly decay every group's lr from ``self.lr`` to 0.

        Decay runs over ``lr_anneal_steps`` and is a no-op when that is 0.
        """
        if not self.lr_anneal_steps:
            return
        frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
        lr = self.lr * (1 - frac_done)
        for param_group in self.opt.param_groups:
            param_group["lr"] = lr

    def log_step(self):
        """Log the global step and the number of samples seen so far."""
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples",
                     (self.step + self.resume_step + 1) * self.global_batch)

    def save(self):
        """Write model, EMA and optimizer checkpoints to ``get_blob_logdir()``.

        Files are ``model_rec{step:07d}.pt``, ``ema_rec_{rate}_{step:07d}.pt``
        and ``opt{step:07d}.pt`` with ``step = self.step + self.resume_step``.
        Only rank 0 writes. Failures (e.g. disk quota) are logged, not
        raised. All ranks synchronize before returning.
        """
        step = self.step + self.resume_step

        def save_checkpoint(rate, params):
            # Called on every rank; only rank 0 writes the file.
            state_dict = self.mp_trainer_rec.master_params_to_state_dict(
                params)
            if dist_util.get_rank() == 0:
                logger.log(f"saving model {rate}...")
                if not rate:
                    filename = f"model_rec{step:07d}.pt"
                else:
                    filename = f"ema_rec_{rate}_{step:07d}.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename),
                                 "wb") as f:
                    th.save(state_dict, f)

        try:
            save_checkpoint(0, self.mp_trainer_rec.master_params)
            for rate, params in zip(self.ema_rate, self.ema_params):
                save_checkpoint(rate, params)
            # Use the dist_util wrappers like the rest of the class instead
            # of raw torch.distributed calls.
            if dist_util.get_rank() == 0:
                with bf.BlobFile(
                        bf.join(get_blob_logdir(), f"opt{step:07d}.pt"),
                        "wb",
                ) as f:
                    th.save(self.opt.state_dict(), f)
        except Exception as e:  # e.g. disk quota exceeded
            logger.log(e)

        th.cuda.empty_cache()
        dist_util.synchronize()
class TrainLoop3DRec(TrainLoopBasic): | |
def __init__(
        self,
        *,
        rec_model,
        loss_class,
        data,
        eval_data,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        eval_interval,
        save_interval,
        resume_checkpoint,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        weight_decay=0.0,
        lr_anneal_steps=0,
        iterations=10001,
        load_submodule_name='',
        ignore_resume_opt=False,
        model_name='rec',
        use_amp=False,
        compile=False,
        **kwargs):
    """Initialise the 3D reconstruction training loop.

    Every argument (extra ``kwargs`` included) is handed unchanged to
    ``TrainLoopBasic.__init__``. Afterwards the CUDA cache is emptied and
    the latent-related attributes are set. ``triplane_scaling_divider`` and
    ``latent_name`` are read by ``render_video_given_triplane``.
    """
    base_args = dict(
        rec_model=rec_model,
        loss_class=loss_class,
        data=data,
        eval_data=eval_data,
        batch_size=batch_size,
        microbatch=microbatch,
        lr=lr,
        ema_rate=ema_rate,
        log_interval=log_interval,
        eval_interval=eval_interval,
        save_interval=save_interval,
        resume_checkpoint=resume_checkpoint,
        use_fp16=use_fp16,
        fp16_scale_growth=fp16_scale_growth,
        weight_decay=weight_decay,
        lr_anneal_steps=lr_anneal_steps,
        iterations=iterations,
        load_submodule_name=load_submodule_name,
        ignore_resume_opt=ignore_resume_opt,
        model_name=model_name,
        use_amp=use_amp,
        compile=compile,
    )
    super().__init__(**base_args, **kwargs)

    th.cuda.empty_cache()

    # Multiplier applied to triplane latents before rendering (1.0 = none).
    self.triplane_scaling_divider = 1.0
    # Dict key under which the normalized triplane latent is passed to the
    # reconstruction model.
    self.latent_name = 'latent_normalized_2Ddiffusion'
    # Render behaviour tag; not read anywhere in the visible code.
    self.render_latent_behaviour = 'decode_after_vae'
def render_video_given_triplane(self,
                                planes,
                                rec_model,
                                name_prefix='0',
                                save_img=False,
                                render_reference=None,
                                save_mesh=False,
                                render_reference_length=40,
                                return_gen_imgs=False):
    """Render a video (optionally also a mesh / single frames) from triplanes.

    Args:
        planes: triplane latent tensor with channels on dim 1. A 16-channel
            latent is split into 12 foreground channels (under
            ``self.latent_name``) and a 4-channel ``'bg_plane'``. It is NOT
            modified in place.
        rec_model: reconstruction model. It is called with
            ``behaviour='decode_after_vae_no_render'``,
            ``'triplane_decode_grid'`` and ``'triplane_dec'``.
        name_prefix: used in output file names. Must be int-convertible
            when ``save_img`` is True.
        save_img: also write every generated frame as a PNG to
            ``<logdir>/FID_Cals/`` (non-512 frames are resized to 128x128).
        render_reference: optional dict of per-view camera data (needs
            ``'c'``). ``'ins'``, ``'bbox'`` and ``'caption'`` are ignored.
            The first ``render_reference_length`` views are rendered one at
            a time. Defaults to ``self.eval_data``. The caller's dict is
            NOT modified.
        save_mesh: also extract a marching-cubes mesh into
            ``<logdir>/mesh/<name_prefix>.ply``.
        render_reference_length: number of views taken from
            ``render_reference`` when it is given.
        return_gen_imgs: return the generated frames.

    Returns:
        ``th.cat`` of all generated frames (SR output when present, else the
        raw rendering) if ``return_gen_imgs``; otherwise None.
    """
    # Out of place: the former `planes *= ...` silently changed the caller's
    # tensor (and fails on leaf tensors that require grad).
    planes = planes * self.triplane_scaling_divider
    batch_size = planes.shape[0]

    def decode_latent():
        # Split foreground / background planes (16-channel ffhq/car layout)
        # and run the VAE decoder without rendering.
        if planes.shape[1] == 16:
            latent = {
                self.latent_name: planes[:, :12],
                'bg_plane': planes[:, 12:16],
            }
        else:
            latent = {self.latent_name: planes}
        latent.update(
            rec_model(latent=latent, behaviour='decode_after_vae_no_render'))
        return latent

    if save_mesh:
        mesh_size = 192
        mesh_thres = 10  # TODO, requires tuning
        dump_path = f'{logger.get_dir()}/mesh/'
        os.makedirs(dump_path, exist_ok=True)
        grid_out = rec_model(
            latent=decode_latent(),
            grid_size=mesh_size,
            behaviour='triplane_decode_grid',
        )
        vtx, faces = mcubes.marching_cubes(
            grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(),
            mesh_thres)
        # Map voxel indices [0, mesh_size - 1] to [-1, 1].
        vtx = vtx / (mesh_size - 1) * 2 - 1
        mesh = trimesh.Trimesh(vertices=vtx, faces=faces)
        mesh_dump_path = os.path.join(dump_path, f'{name_prefix}.ply')
        mesh.export(mesh_dump_path, 'ply')
        print(f"Mesh dumped to {mesh_dump_path}")
        del grid_out, mesh
        th.cuda.empty_cache()

    if render_reference is None:
        render_reference = self.eval_data  # compat
    else:
        # Drop non-camera metadata without popping from the caller's dict,
        # then split into single-view batches for enumeration.
        reference = {
            k: v
            for k, v in render_reference.items()
            if k not in ('ins', 'bbox', 'caption')
        }
        render_reference = [{
            k: v[idx:idx + 1]
            for k, v in reference.items()
        } for idx in range(render_reference_length)]

    if save_img:
        # PNGs are written here below; the directory was never created.
        os.makedirs(os.path.join(logger.get_dir(), 'FID_Cals'), exist_ok=True)

    ddpm_latent = decode_latent()
    gen_imgs = []
    video_path = f'{logger.get_dir()}/triplane_{name_prefix}.mp4'
    video_out = imageio.get_writer(video_path,
                                   mode='I',
                                   fps=15,
                                   codec='libx264')
    try:
        for i, batch in enumerate(tqdm(render_reference)):
            micro = {
                k: v.to(dist_util.dev()) if isinstance(v, th.Tensor) else v
                for k, v in batch.items()
            }
            pred = rec_model(img=None,
                             c=micro['c'],
                             latent=ddpm_latent,
                             behaviour='triplane_dec')

            # Depth: min-max normalize, colorize the first sample with
            # viridis and map to [-1, 1] like the RGB renderings. The result
            # is (1, 3, H, W), i.e. already 3-channel.
            pred_depth = pred['image_depth']
            pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
                                                            pred_depth.min())
            pred_depth = pred_depth.cpu()[0].permute(1, 2, 0).numpy()
            pred_depth = plt.cm.viridis(pred_depth[..., 0])[..., :3] * 2 - 1
            pred_depth = th.from_numpy(pred_depth).to(
                pred['image_raw'].device).permute(2, 0, 1).unsqueeze(0)

            if 'image_sr' in pred:
                gen_img = pred['image_sr']
                # Pool to the SR resolution. Pooling to the same size is the
                # identity, so 512 / 128 outputs are unchanged. 256 and other
                # sizes previously left `pred_vis` unbound.
                pool = {
                    512: self.pool_512,
                    256: self.pool_256,
                }.get(gen_img.shape[-1], self.pool_128)
                # Bug fix: the colorized depth is already 3-channel; the old
                # `.repeat_interleave(3, dim=1)` made it 9 channels and the
                # concatenation below failed.
                pred_vis = th.cat([
                    micro['img_sr'],
                    pool(pred['image_raw']),
                    pool(gen_img),
                    pool(pred_depth),
                ],
                                  dim=-1)
            else:
                gen_img = pred['image_raw']
                pred_vis = th.cat(
                    [self.pool_128(gen_img),
                     self.pool_128(pred_depth)],
                    dim=-1)  # B, 3, H, W

            if return_gen_imgs:
                # Collected for SR and non-SR models alike (SR frames used to
                # be dropped, making th.cat([]) fail).
                gen_imgs.append(gen_img)

            if save_img:
                for batch_idx in range(gen_img.shape[0]):
                    pixels = (gen_img[batch_idx].permute(1, 2, 0).cpu().numpy()
                              * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
                    sampled_img = Image.fromarray(pixels)
                    if sampled_img.size != (512, 512):
                        sampled_img = sampled_img.resize(
                            (128, 128), Image.HAMMING)  # for shapenet
                    sampled_img.save(logger.get_dir() +
                                     '/FID_Cals/{}_{}.png'.format(
                                         int(name_prefix) * batch_size +
                                         batch_idx, i))

            # [-1, 1] floats -> uint8 frames, one frame per batch element.
            vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
            vis = (vis * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
            for frame in vis:
                video_out.append_data(frame)
    finally:
        # Always finalize the mp4 so the writer is not leaked on errors.
        video_out.close()

    print('logged video to: ', video_path)
    if return_gen_imgs:
        return th.cat(gen_imgs)
    return None
def _init_optim_groups(self, kwargs):
    """Build the AdamW parameter groups for the reconstruction model.

    If ``kwargs['decomposed']`` is falsy, a single group over the
    mixed-precision master params is used (``self.lr`` and
    ``self.weight_decay``). Otherwise each sub-network of
    ``self.mp_trainer_rec.model`` gets its own named group. Learning rates
    and weight decays come from ``kwargs``; optional sub-networks that are
    ``None`` are skipped. Groups without a ``'weight_decay'`` key use the
    optimizer's default.
    """
    if not kwargs.get('decomposed', False):
        groups = [
            dict(name='mp_trainer.master_params',
                 params=self.mp_trainer_rec.master_params,
                 lr=self.lr,
                 weight_decay=self.weight_decay)
        ]
        logger.log(groups)
        return groups

    model = self.mp_trainer_rec.model
    decoder = model.decoder

    def vit_decoder_group(name, params):
        # The ViT decoder and the 3D-aware operators share lr / wd settings.
        return {
            'name': name,
            'params': params,
            'lr': kwargs['vit_decoder_lr'],
            'weight_decay': kwargs['vit_decoder_wd'],
        }

    groups = []
    if not kwargs.get('ignore_encoder'):
        groups.append({
            'name': 'encoder',
            'params': model.encoder.parameters(),
            'lr': kwargs['encoder_lr'],
            'weight_decay': kwargs['encoder_weight_decay'],
        })
    groups.append(
        vit_decoder_group('decoder.vit_decoder',
                          decoder.vit_decoder.parameters()))

    # A 2DGS Gaussian renderer has no MLP to optimize.
    if not isinstance(decoder.triplane_decoder, GaussianRenderer2DGS):
        groups.append({
            'name': 'decoder.triplane_decoder',
            'params': decoder.triplane_decoder.parameters(),
            'lr': kwargs['triplane_decoder_lr'],
        })
    if decoder.superresolution is not None:
        groups.append({
            'name': 'decoder.superresolution',
            'params': decoder.superresolution.parameters(),
            'lr': kwargs['super_resolution_lr'],
        })
    if model.dim_up_mlp is not None:
        groups.append({
            'name': 'dim_up_mlp',
            'params': model.dim_up_mlp.parameters(),
            'lr': kwargs['encoder_lr'],
        })
    if decoder.decoder_pred_3d is not None:
        groups.append(
            vit_decoder_group('decoder_pred_3d',
                              decoder.decoder_pred_3d.parameters()))
    if decoder.transformer_3D_blk is not None:
        groups.append(
            vit_decoder_group('decoder_transformer_3D_blk',
                              decoder.transformer_3D_blk.parameters()))
    if decoder.logvar is not None:
        # logvar is handed over directly rather than via .parameters().
        groups.append(vit_decoder_group('decoder_logvar', decoder.logvar))
    if decoder.decoder_pred is not None:
        groups.append(
            vit_decoder_group('decoder.decoder_pred',
                              decoder.decoder_pred.parameters()))
    if model.confnet is not None:
        groups.append({
            'name': 'confnet',
            'params': model.confnet.parameters(),
            'lr': 1e-5,  # as in unsup3d
        })

    if dist_util.get_rank() == 0:
        logger.log('using independent optimizer for each components')
    logger.log(groups)
    return groups
# def eval_loop(self, c_list:list): | |
def eval_novelview_loop(self):
    """Render novel views of one eval sample along the eval camera trajectory.

    The first item of the first eval batch is the encoder input for every
    batch, repeated to that batch's size. Each batch supplies the target
    cameras ``c``. Per batch, test-mode losses are collected and a
    visualization (GT | raw | [SR] | depth) is appended to
    ``<logdir>/video_novelview_<step>.mp4``. The averaged losses are
    appended to ``scores_novelview.json`` as one JSON object per line and
    logged to TensorBoard under ``Eval/NovelView/``.
    """
    video_out = imageio.get_writer(
        f'{logger.get_dir()}/video_novelview_{self.step+self.resume_step}.mp4',
        mode='I',
        fps=60,
        codec='libx264')
    all_loss_dict = []
    novel_view_micro = {}
    try:
        for i, batch in enumerate(tqdm(self.eval_data)):
            # Non-tensor entries are passed through; the old unconditional
            # `.to()` crashed on them.
            micro = {
                k: v.to(dist_util.dev()) if isinstance(v, th.Tensor) else v
                for k, v in batch.items()
            }
            n_views = micro['img'].shape[0]
            # Source sample: item 0 of the first batch, re-repeated to the
            # current batch size. Non-tensor entries are only sliced (the
            # i > 0 path used to call `.to()` on them as well).
            source = batch if i == 0 else novel_view_micro
            novel_view_micro = {
                k: v[0:1].to(dist_util.dev()).repeat_interleave(n_views, 0)
                if isinstance(v, th.Tensor) else v[0:1]
                for k, v in source.items()
            }

            pred = self.rec_model(img=novel_view_micro['img_to_encoder'],
                                  c=micro['c'])
            _, loss_dict = self.loss_class(pred, micro, test_mode=True)
            all_loss_dict.append(loss_dict)

            # Min-max normalize depth to [0, 1] for visualization.
            pred_depth = pred['image_depth']
            pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
                                                            pred_depth.min())

            if 'image_sr' in pred:
                sr_res = pred['image_sr'].shape[-1]
                if sr_res == 512:
                    pred_vis = th.cat([
                        micro['img_sr'],
                        self.pool_512(pred['image_raw']), pred['image_sr'],
                        self.pool_512(pred_depth).repeat_interleave(3, dim=1)
                    ],
                                      dim=-1)
                elif sr_res == 256:
                    pred_vis = th.cat([
                        micro['img_sr'],
                        self.pool_256(pred['image_raw']), pred['image_sr'],
                        self.pool_256(pred_depth).repeat_interleave(3, dim=1)
                    ],
                                      dim=-1)
                else:
                    pred_vis = th.cat([
                        micro['img_sr'],
                        self.pool_128(pred['image_raw']),
                        self.pool_128(pred['image_sr']),
                        self.pool_128(pred_depth).repeat_interleave(3, dim=1)
                    ],
                                      dim=-1)
            else:
                pred_vis = th.cat([
                    self.pool_128(micro['img']),
                    self.pool_128(pred['image_raw']),
                    self.pool_128(pred_depth).repeat_interleave(3, dim=1)
                ],
                                  dim=-1)  # B, 3, H, W

            # [-1, 1] floats -> uint8 frames, one frame per batch element.
            vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
            vis = (vis * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
            for frame in vis:
                video_out.append_data(frame)
    finally:
        # Finalize the mp4 even if an eval batch fails.
        video_out.close()

    val_scores_for_logging = calc_average_loss(all_loss_dict)
    with open(os.path.join(logger.get_dir(), 'scores_novelview.json'),
              'a') as f:
        json.dump({'step': self.step, **val_scores_for_logging}, f)
        # Newline-separate records; back-to-back objects were unparseable.
        f.write('\n')

    for k, v in val_scores_for_logging.items():
        self.writer.add_scalar(f'Eval/NovelView/{k}', v,
                               self.step + self.resume_step)
    th.cuda.empty_cache()
def eval_loop(self):
    """Evaluate reconstruction quality on ``self.eval_data``.

    For every evaluation batch the model reconstructs the input views from
    ``micro['img_to_encoder']`` under cameras ``micro['c']``, and the
    test-mode losses from ``self.loss_class`` are collected. One
    visualization row per sample (reference | raw render
    [| super-resolved render] | min-max normalized depth) is appended as a
    frame to ``video_<step>.mp4`` in the logger directory.

    The averaged scores are appended to ``scores.json`` as one JSON object
    per line (JSON Lines) and logged to TensorBoard under ``Eval/Rec/``.
    ``self.eval_novelview_loop()`` then runs. The model is always switched
    back to training mode, even if evaluation raises.
    """
    global_step = self.step + self.resume_step
    all_loss_dict = []
    self.rec_model.eval()
    try:
        # The context manager closes the writer even if a batch fails.
        # no_grad is required because ``.numpy()`` refuses tensors that
        # still carry autograd history.
        with imageio.get_writer(f'{logger.get_dir()}/video_{global_step}.mp4',
                                mode='I',
                                fps=60,
                                codec='libx264') as video_out, th.no_grad():
            for batch in tqdm(self.eval_data):
                micro = {
                    k: v.to(dist_util.dev()) if isinstance(v, th.Tensor) else v
                    for k, v in batch.items()
                }
                pred = self.rec_model(img=micro['img_to_encoder'],
                                      c=micro['c'])
                _, loss_dict = self.loss_class(pred, micro, test_mode=True)
                all_loss_dict.append(loss_dict)

                # Min-max normalize depth to [0, 1]. The clamp avoids 0/0
                # (a NaN frame) when the rendered depth is constant.
                pred_depth = pred['image_depth']
                depth_min = pred_depth.min()
                pred_depth = (pred_depth - depth_min) / (
                    pred_depth.max() - depth_min).clamp(min=1e-6)

                # Pick the pooling that brings every panel to the
                # super-resolution width (128 when there is no SR branch).
                if 'image_sr' in pred:
                    sr_res = pred['image_sr'].shape[-1]
                    if sr_res == 512:
                        pool, pred_sr = self.pool_512, pred['image_sr']
                    elif sr_res == 256:
                        pool, pred_sr = self.pool_256, pred['image_sr']
                    else:
                        pool = self.pool_128
                        pred_sr = self.pool_128(pred['image_sr'])
                    panels = [micro['img_sr'], pool(pred['image_raw']), pred_sr]
                else:
                    pool = self.pool_128
                    panels = [
                        self.pool_128(micro['img']),
                        self.pool_128(pred['image_raw'])
                    ]
                panels.append(pool(pred_depth).repeat_interleave(3, dim=1))
                pred_vis = th.cat(panels, dim=-1)  # B, 3, H, W_total

                # [-1, 1] -> uint8 frames, one frame per sample.
                frames = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
                frames = (frames * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
                for frame in frames:
                    video_out.append_data(frame)

        val_scores_for_logging = calc_average_loss(all_loss_dict)
        # One JSON object per line, so the file stays parseable across
        # repeated evaluations.
        with open(os.path.join(logger.get_dir(), 'scores.json'), 'a') as f:
            f.write(json.dumps({'step': self.step, **val_scores_for_logging}))
            f.write('\n')
        for k, v in val_scores_for_logging.items():
            self.writer.add_scalar(f'Eval/Rec/{k}', v, global_step)
        th.cuda.empty_cache()

        self.eval_novelview_loop()
    finally:
        self.rec_model.train()
def eval_novelview_loop_trajectory(self):
    """Render every eval batch along the cameras in ``self.all_nvs_params``.

    One video ``video_novelview_<step>_batch_<i>.mp4`` is written per batch
    of ``self.eval_data``. Each frame is rendered from one camera of
    ``self.all_nvs_params``, repeated over the batch. The samples of the
    batch are stacked vertically, and each row shows: reference | raw render
    [| super-resolved render] | min-max normalized depth.
    """
    global_step = self.step + self.resume_step
    for i, batch in enumerate(tqdm(self.eval_data)):
        # Non-tensor entries (if any) are passed through instead of crashing,
        # the same way ``eval_loop`` builds its micro-batch.
        micro = {
            k: v.to(dist_util.dev()) if isinstance(v, th.Tensor) else v
            for k, v in batch.items()
        }
        num_samples = micro['img'].shape[0]

        # The writer is closed even if rendering fails. no_grad is needed
        # because ``.numpy()`` refuses tensors that carry autograd history.
        with imageio.get_writer(
                f'{logger.get_dir()}/video_novelview_{global_step}_batch_{i}.mp4',
                mode='I',
                fps=60,
                codec='libx264') as video_out, th.no_grad():
            for c in self.all_nvs_params:
                pred = self.rec_model(
                    img=micro['img_to_encoder'],
                    c=c.unsqueeze(0).repeat_interleave(num_samples, 0))

                # Min-max normalize depth to [0, 1]. The clamp guards
                # against 0/0 when the depth is constant.
                pred_depth = pred['image_depth']
                depth_min = pred_depth.min()
                pred_depth = (pred_depth - depth_min) / (
                    pred_depth.max() - depth_min).clamp(min=1e-6)

                if 'image_sr' in pred:
                    sr_res = pred['image_sr'].shape[-1]
                    if sr_res == 512:
                        pool, pred_sr = self.pool_512, pred['image_sr']
                    elif sr_res == 256:
                        pool, pred_sr = self.pool_256, pred['image_sr']
                    else:
                        pool = self.pool_128
                        pred_sr = self.pool_128(pred['image_sr'])
                    panels = [micro['img_sr'], pool(pred['image_raw']), pred_sr]
                else:
                    pool = self.pool_128
                    panels = [
                        self.pool_128(micro['img']),
                        self.pool_128(pred['image_raw'])
                    ]
                panels.append(pool(pred_depth).repeat_interleave(3, dim=1))
                pred_vis = th.cat(panels, dim=-1)  # B, 3, H, W_total

                # Stack the batch along H: (B, 3, H, W) -> (B * H, W, 3).
                frame = pred_vis.permute(0, 2, 3, 1).flatten(0, 1)
                frame = frame.cpu().numpy() * 127.5 + 127.5
                video_out.append_data(frame.clip(0, 255).astype(np.uint8))
        th.cuda.empty_cache()
def _prepare_nvs_pose(self):
    """Precompute the cameras used for novel-view visualization.

    Samples ``num_keyframes`` look-at cameras whose horizontal/vertical
    angles follow a sine/cosine orbit around
    ``self.rendering_kwargs['avg_camera_pivot']`` at distance
    ``self.rendering_kwargs['avg_camera_radius']``. The result is stored in
    ``self.all_nvs_params``, one row per camera: 16 flattened cam2world
    entries followed by 9 flattened intrinsics entries.
    """
    device = dist_util.dev()
    # Field of view used for ffhq/afhq.
    flat_intrinsics = FOV_to_intrinsics(18.837, device=device).reshape(-1, 9)
    pivot = th.Tensor(self.rendering_kwargs.get('avg_camera_pivot')).to(device)
    radius = self.rendering_kwargs.get('avg_camera_radius')

    num_keyframes = 10  # how many novel-view poses to sample
    w_frames = 1
    yaw_amplitude, pitch_amplitude = 0.35, 0.25
    period = num_keyframes * w_frames

    cameras = []
    for t in range(num_keyframes):
        phase = 2 * 3.14 * t / period
        cam2world = LookAtPoseSampler.sample(
            3.14 / 2 + yaw_amplitude * np.sin(phase),
            3.14 / 2 - 0.05 + pitch_amplitude * np.cos(phase),
            pivot,
            radius=radius,
            device=device)
        cameras.append(th.cat([cam2world.reshape(-1, 16), flat_intrinsics], 1))
    self.all_nvs_params = th.cat(cameras, 0)
def forward_backward(self, batch, *args, **kwargs):
    """Accumulate reconstruction-loss gradients for one training step.

    ``batch`` is split into micro-batches of ``self.microbatch`` samples.
    For each micro-batch, the model reconstructs ``micro['img_to_encoder']``
    under cameras ``micro['c']``. This runs inside ``th.autocast`` when AMP
    is enabled. The loss is the sum of:

    * the reconstruction loss returned by ``self.loss_class``;
    * the mirrored-view symmetry loss, when ``opt.symmetry_loss`` is set;
    * the density TV regularizer, when ``opt.density_reg > 0`` and the
      current step is a multiple of ``opt.density_reg_every``.

    The loss is back-propagated through ``self.mp_trainer_rec.backward`` and
    logged. Every 500 steps, rank 0 also saves a visualization grid.

    Args:
        batch: dict of batched tensors containing at least
            ``img_to_encoder``, ``c`` and ``img``. ``depth`` (and ``img_sr``
            when the model predicts ``image_sr``) is read on visualization
            steps.
        *args, **kwargs: accepted for interface compatibility; unused.

    Returns:
        The prediction dict of the last micro-batch.
    """
    self.mp_trainer_rec.zero_grad()
    batch_size = batch['img_to_encoder'].shape[0]

    for i in range(0, batch_size, self.microbatch):
        micro = {
            k: v[i:i + self.microbatch].to(dist_util.dev())
            for k, v in batch.items()
        }
        nvs_pred = None  # mirrored-view render; only set by the symmetry loss

        with th.autocast(device_type='cuda',
                         dtype=th.float16,
                         enabled=self.mp_trainer_rec.use_amp):
            pred = self.rec_model(img=micro['img_to_encoder'], c=micro['c'])
            target = micro

            # Confidence maps are only predicted by some models (e.g. ffhq).
            # They are resized to the render resolution. The first channel
            # weights the input-view loss; the remaining channel(s) weight
            # the mirrored-view loss.
            conf_sigma_percl = None
            conf_sigma_percl_flip = None
            if 'conf_sigma' in pred:
                all_conf_sigma_l1 = th.nn.functional.interpolate(
                    pred['conf_sigma'],
                    size=pred['image_raw'].shape[-2:],
                    mode='bilinear')
                conf_sigma_l1 = all_conf_sigma_l1[:, :1]
                conf_sigma_l1_flip = all_conf_sigma_l1[:, 1:]
            else:
                conf_sigma_l1 = None
                conf_sigma_l1_flip = None

            # NOTE(review): no_sync wraps only the loss computation, not the
            # DDP forward/backward, so it presumably does not suppress the
            # gradient all-reduce. Confirm this is intended.
            with self.rec_model.no_sync():  # type: ignore
                loss, loss_dict, fg_mask = self.loss_class(
                    pred,
                    target,
                    step=self.step + self.resume_step,
                    test_mode=False,
                    return_fg_mask=True,
                    conf_sigma_l1=conf_sigma_l1,
                    conf_sigma_percl=conf_sigma_percl)

            if self.loss_class.opt.symmetry_loss:
                # Log the confidence only when the model predicts one;
                # calling .log() on None would crash here.
                if conf_sigma_l1 is not None:
                    loss_dict['conf_sigma_log'] = conf_sigma_l1.log()
                loss_symm, loss_dict_symm, nvs_pred = self._symmetry_loss(
                    pred, micro, fg_mask, conf_sigma_l1_flip,
                    conf_sigma_percl_flip)
                loss += (loss_symm * 1.0)  # weight as in unsup3d
                for k, v in loss_dict_symm.items():
                    loss_dict[f'{k}_symm'] = v
                if conf_sigma_l1_flip is not None:
                    loss_dict['flip_conf_sigma_log'] = conf_sigma_l1_flip.log()

            # EG3D-style density regularization (TV loss on sigma).
            opt = self.loss_class.opt
            if opt.density_reg > 0 and self.step % opt.density_reg_every == 0:
                # Sample one point set per *micro-batch* sample, matching
                # the batch dimension of pred['latent'].
                TVloss = self._density_reg_loss(
                    pred, micro['img_to_encoder'].shape[0])
                loss_dict.update(dict(tv_loss=TVloss))
                loss += TVloss

        self.mp_trainer_rec.backward(loss)
        log_rec3d_loss_dict(loss_dict)

        if dist_util.get_rank() == 0 and self.step % 500 == 0:
            with th.no_grad():
                self._save_train_vis(pred, micro, nvs_pred, conf_sigma_l1,
                                     conf_sigma_l1_flip)

    return pred


def _symmetry_loss(self, pred, micro, fg_mask, conf_sigma_l1_flip,
                   conf_sigma_percl_flip):
    """Render the yaw-mirrored camera and compare it with the flipped input.

    The latents in ``pred`` (keys containing ``'latent'``) are decoded under
    ``flip_yaw`` of the micro-batch poses. The raw render is compared with
    ``micro['img']`` and ``fg_mask``, both flipped along the last (width)
    axis. Only the spatial axis is flipped, never the channels.

    Returns:
        ``(loss_symm, loss_dict_symm, nvs_pred)``.
    """
    pose = micro['c'][:, :16].reshape(-1, 4, 4)
    intrinsics = micro['c'][:, 16:]
    mirror_c = th.cat([flip_yaw(pose).reshape(-1, 16), intrinsics], -1)

    nvs_pred = self.rec_model(
        latent={k: v for k, v in pred.items() if 'latent' in k},
        c=mirror_c,
        behaviour='triplane_dec',
        return_raw_only=True)

    flipped_img = th.flip(micro['img'], [-1])
    flipped_fg_mask = th.flip(fg_mask, [-1])
    with self.rec_model.no_sync():  # type: ignore
        loss_symm, loss_dict_symm = self.loss_class.calc_2d_rec_loss(
            nvs_pred['image_raw'],
            flipped_img,
            flipped_fg_mask,
            test_mode=False,
            step=self.step + self.resume_step,
            conf_sigma_l1=conf_sigma_l1_flip,
            conf_sigma_percl=conf_sigma_percl_flip)
    return loss_symm, loss_dict_symm, nvs_pred


def _density_reg_loss(self, pred, num_samples):
    """Return the density TV penalty for ``pred['latent']``.

    For each of ``num_samples`` samples, 1000 points are drawn uniformly in
    [-1, 1]^3 and perturbed by Gaussian noise scaled by
    ``opt.density_reg_p_dist``. The penalty is the L1 difference between
    the ``sigma`` predicted at the two point sets, scaled by
    ``opt.density_reg``.
    """
    opt = self.loss_class.opt
    initial_coordinates = th.rand(
        (num_samples, 1000, 3), device=dist_util.dev()) * 2 - 1  # [-1, 1]
    perturbed_coordinates = initial_coordinates + th.randn_like(
        initial_coordinates) * opt.density_reg_p_dist
    all_coordinates = th.cat([initial_coordinates, perturbed_coordinates],
                             dim=1)
    sigma = self.rec_model(
        latent=pred['latent'],
        coordinates=all_coordinates,
        directions=th.randn_like(all_coordinates),
        behaviour='triplane_renderer',
    )['sigma']
    half = sigma.shape[1] // 2
    return th.nn.functional.l1_loss(sigma[:, :half],
                                    sigma[:, half:]) * opt.density_reg


def _save_train_vis(self, pred, micro, nvs_pred, conf_sigma_l1,
                    conf_sigma_l1_flip):
    """Save a training visualization grid to ``<logdir>/<step>.jpg``.

    Three rows are stacked along H:
    * ground truth: image | depth | blank [| predicted mask];
    * reconstruction: image | depth | mask [| 1 / confidence];
    * novel view: image | depth | mask [| 1 / confidence].

    The novel view reuses ``nvs_pred`` when the symmetry loss is on;
    otherwise it is rendered at ``self.novel_view_poses``. Panels are
    treated as [-1, 1] when saving. Call this under ``th.no_grad()``.
    """

    def norm_depth(depth):
        # Min-max normalize, then map to [-1, 1] with the sign inverted
        # (smallest depth -> +1). The clamp guards against 0/0.
        depth_min = depth.min()
        depth = (depth - depth_min) / (depth.max() - depth_min).clamp(min=1e-6)
        return -(depth * 2 - 1)

    pred_img = pred['image_raw']
    gt_img = micro['img']

    if self.loss_class.opt.symmetry_loss:
        pred_nv_img = nvs_pred
    else:
        pred_nv_img = self.rec_model(img=micro['img_to_encoder'],
                                     c=self.novel_view_poses)

    gt_depth = micro['depth']
    if gt_depth.ndim == 3:
        gt_depth = gt_depth.unsqueeze(1)
    gt_depth = norm_depth(gt_depth)

    fg_mask = pred['image_mask'] * 2 - 1  # [0, 1] -> [-1, 1]
    nv_fg_mask = pred_nv_img['image_mask'] * 2 - 1
    if 'image_depth' in pred:
        pred_depth = norm_depth(pred['image_depth'])
        pred_nv_depth = norm_depth(pred_nv_img['image_depth'])
    else:
        pred_depth = th.zeros_like(gt_depth)
        pred_nv_depth = th.zeros_like(gt_depth)

    if 'image_sr' in pred:
        sr_res = pred['image_sr'].shape[-1]
        if sr_res == 512:
            pool = self.pool_512
        elif sr_res == 256:
            pool = self.pool_256
        else:
            pool = self.pool_128
        pred_img = th.cat([pool(pred_img), pred['image_sr']], dim=-1)
        gt_img = th.cat([pool(micro['img']), micro['img_sr']], dim=-1)
        pred_depth = pool(pred_depth)
        gt_depth = pool(gt_depth)
    else:
        gt_img = self.pool_128(gt_img)
        gt_depth = self.pool_128(gt_depth)

    pred_vis = th.cat([
        pred_img,
        pred_depth.repeat_interleave(3, dim=1),
        fg_mask.repeat_interleave(3, dim=1),
    ], dim=-1)  # B, 3, H, W
    pred_vis_nv = th.cat([
        pred_nv_img['image_raw'],
        pred_nv_depth.repeat_interleave(3, dim=1),
        nv_fg_mask.repeat_interleave(3, dim=1),
    ], dim=-1)  # B, 3, H, W
    gt_vis = th.cat([
        gt_img,
        gt_depth.repeat_interleave(3, dim=1),
        th.zeros_like(gt_img),
    ], dim=-1)

    if 'conf_sigma' in pred:
        # Inverse confidence, shifted into the display range.
        pred_vis = th.cat([
            pred_vis,
            (1 / (conf_sigma_l1 + 1e-7)).repeat_interleave(3, dim=1) * 2 - 1,
        ], dim=-1)
        pred_vis_nv = th.cat([
            pred_vis_nv,
            (1 / (conf_sigma_l1_flip + 1e-7)).repeat_interleave(3, dim=1) * 2
            - 1,
        ], dim=-1)
        # Extra panel keeps the GT row as wide as the prediction rows. The
        # mask is single-channel, so it must be repeated to 3 channels like
        # the other panels; otherwise th.cat fails on the channel dim.
        gt_vis = th.cat([gt_vis, fg_mask.repeat_interleave(3, dim=1)], dim=-1)

    vis = th.cat([gt_vis, pred_vis, pred_vis_nv], dim=-2)  # stack rows in H
    vis_tensor = torchvision.utils.make_grid(vis, nrow=vis.shape[-1] // 64)
    save_path = f'{logger.get_dir()}/{self.step + self.resume_step}.jpg'
    torchvision.utils.save_image(vis_tensor,
                                 save_path,
                                 value_range=(-1, 1),
                                 normalize=True)
    logger.log('log vis to: ', save_path)
class TrainLoop3DTriplaneRec(TrainLoop3DRec):
    """``TrainLoop3DRec`` variant whose evaluation renders a camera sweep.

    During evaluation the model is queried with camera parameters only
    (``self.rec_model(c=...)``); no input image is passed.
    """

    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 compile=False,
                 **kwargs):
        """Forward every argument unchanged to ``TrainLoop3DRec.__init__``."""
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         compile=compile,
                         **kwargs)

    def eval_loop(self):
        """Render a 72-frame camera sweep around the origin into an mp4.

        Each frame shows the raw render next to its min-max normalized depth,
        both passed through ``self.pool_128``. Intrinsics are copied from the
        first sample of the next batch of ``self.data``, so this consumes one
        training batch. The video goes to ``video_<step>.mp4`` in the logger
        directory. The model is always switched back to training mode, even
        if rendering raises.
        """
        device = dist_util.dev()
        global_step = self.step + self.resume_step
        cam_pivot = th.tensor([0., 0., 0.], device=device)
        cam_radius = 1.8
        num_frames = 72
        # NOTE(review): angle_y stays fixed at 0.2 while angle_p sweeps
        # [-3.14, 3.14], i.e. the second (vertical) sampler angle is swept.
        # Confirm this is intended rather than a horizontal orbit.
        trajectory = zip([0.2] * num_frames,
                         np.linspace(-3.14, 3.14, num_frames))

        self.rec_model.eval()
        try:
            demo_pose = next(self.data)
            intrinsics = demo_pose['c'][0][16:25].to(device).reshape(-1, 9)

            # A single writer, closed even on failure. no_grad is required
            # because ``.numpy()`` refuses tensors with autograd history.
            with imageio.get_writer(
                    f'{logger.get_dir()}/video_{global_step}.mp4',
                    mode='I',
                    fps=24,
                    bitrate='10M',
                    codec='libx264') as video_out, th.no_grad():
                for angle_y, angle_p in trajectory:
                    cam2world_pose = LookAtPoseSampler.sample(
                        np.pi / 2 + angle_y,
                        np.pi / 2 + angle_p,
                        cam_pivot,
                        radius=cam_radius,
                        device=device)
                    camera_params = th.cat(
                        [cam2world_pose.reshape(-1, 16), intrinsics],
                        1).to(device)
                    pred = self.rec_model(c=camera_params)

                    # Min-max normalize depth to [0, 1]. The clamp guards
                    # against 0/0 when the depth is constant.
                    pred_depth = pred['image_depth']
                    depth_min = pred_depth.min()
                    pred_depth = (pred_depth - depth_min) / (
                        pred_depth.max() - depth_min).clamp(min=1e-6)

                    pred_vis = th.cat([
                        self.pool_128(pred['image_raw']),
                        self.pool_128(pred_depth).repeat_interleave(3, dim=1)
                    ], dim=-1)  # B, 3, H, W

                    frames = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
                    frames = (frames * 127.5 + 127.5).clip(0, 255).astype(
                        np.uint8)
                    for frame in frames:
                        video_out.append_data(frame)
        finally:
            self.rec_model.train()
class TrainLoop3DRecTrajVis(TrainLoop3DRec): | |
def __init__(self,
             *,
             rec_model,
             loss_class,
             data,
             eval_data,
             batch_size,
             microbatch,
             lr,
             ema_rate,
             log_interval,
             eval_interval,
             save_interval,
             resume_checkpoint,
             use_fp16=False,
             fp16_scale_growth=0.001,
             weight_decay=0,
             lr_anneal_steps=0,
             iterations=10001,
             load_submodule_name='',
             ignore_resume_opt=False,
             model_name='rec',
             use_amp=False,
             **kwargs):
    """Set up the loop and precompute the novel-view camera trajectory.

    All arguments are forwarded unchanged to ``TrainLoop3DRec.__init__``.
    The triplane decoder's ``rendering_kwargs`` are then cached on ``self``,
    because ``_prepare_nvs_pose`` reads the average camera pivot and radius
    from them.
    """
    super().__init__(
        rec_model=rec_model, loss_class=loss_class,
        data=data, eval_data=eval_data,
        batch_size=batch_size, microbatch=microbatch,
        lr=lr, ema_rate=ema_rate,
        log_interval=log_interval, eval_interval=eval_interval,
        save_interval=save_interval, resume_checkpoint=resume_checkpoint,
        use_fp16=use_fp16, fp16_scale_growth=fp16_scale_growth,
        weight_decay=weight_decay, lr_anneal_steps=lr_anneal_steps,
        iterations=iterations, load_submodule_name=load_submodule_name,
        ignore_resume_opt=ignore_resume_opt, model_name=model_name,
        use_amp=use_amp, **kwargs)

    triplane_decoder = self.rec_model.module.decoder.triplane_decoder  # type: ignore
    self.rendering_kwargs = triplane_decoder.rendering_kwargs
    # Cameras used by the eval novel-view visualization.
    self._prepare_nvs_pose()
def eval_novelview_loop(self):
    """Render every eval batch along the cameras in ``self.all_nvs_params``.

    One video ``video_novelview_<step>_batch_<i>.mp4`` is written per batch
    of ``self.eval_data``. Each frame is rendered from one camera of
    ``self.all_nvs_params``, repeated over the batch. The batch samples are
    stacked vertically, and each row shows: reference | raw render
    [| super-resolved render] | min-max normalized depth.
    """
    global_step = self.step + self.resume_step
    for i, batch in enumerate(tqdm(self.eval_data)):
        # Non-tensor entries (if any) are passed through instead of crashing.
        micro = {
            k: v.to(dist_util.dev()) if isinstance(v, th.Tensor) else v
            for k, v in batch.items()
        }
        num_samples = micro['img'].shape[0]

        # The writer is closed even if rendering fails. no_grad is needed
        # because ``.numpy()`` refuses tensors that carry autograd history.
        with imageio.get_writer(
                f'{logger.get_dir()}/video_novelview_{global_step}_batch_{i}.mp4',
                mode='I',
                fps=60,
                codec='libx264') as video_out, th.no_grad():
            for c in self.all_nvs_params:
                pred = self.rec_model(
                    img=micro['img_to_encoder'],
                    c=c.unsqueeze(0).repeat_interleave(num_samples, 0))

                # Min-max normalize depth to [0, 1]. The clamp guards
                # against 0/0 when the depth is constant.
                pred_depth = pred['image_depth']
                depth_min = pred_depth.min()
                pred_depth = (pred_depth - depth_min) / (
                    pred_depth.max() - depth_min).clamp(min=1e-6)

                if 'image_sr' in pred:
                    sr_res = pred['image_sr'].shape[-1]
                    if sr_res == 512:
                        pool, pred_sr = self.pool_512, pred['image_sr']
                    elif sr_res == 256:
                        pool, pred_sr = self.pool_256, pred['image_sr']
                    else:
                        pool = self.pool_128
                        pred_sr = self.pool_128(pred['image_sr'])
                    panels = [micro['img_sr'], pool(pred['image_raw']), pred_sr]
                else:
                    pool = self.pool_128
                    panels = [
                        self.pool_128(micro['img']),
                        self.pool_128(pred['image_raw'])
                    ]
                panels.append(pool(pred_depth).repeat_interleave(3, dim=1))
                pred_vis = th.cat(panels, dim=-1)  # B, 3, H, W_total

                # Stack the batch along H: (B, 3, H, W) -> (B * H, W, 3).
                frame = pred_vis.permute(0, 2, 3, 1).flatten(0, 1)
                frame = frame.cpu().numpy() * 127.5 + 127.5
                video_out.append_data(frame.clip(0, 255).astype(np.uint8))
        th.cuda.empty_cache()
def _prepare_nvs_pose(self):
    """Build ``self.all_nvs_params``, the novel-view cameras for evaluation.

    Each of the ``num_keyframes`` cameras is a look-at pose around
    ``self.rendering_kwargs['avg_camera_pivot']`` at distance
    ``self.rendering_kwargs['avg_camera_radius']``. Its horizontal and
    vertical angles oscillate with sine and cosine of the keyframe phase.
    Rows hold 16 flattened cam2world values followed by 9 flattened
    intrinsics values.
    """
    device = dist_util.dev()
    fov_deg = 18.837  # ffhq/afhq field of view
    intrinsics = FOV_to_intrinsics(fov_deg, device=device)
    pivot = th.Tensor(self.rendering_kwargs.get('avg_camera_pivot')).to(device)
    radius = self.rendering_kwargs.get('avg_camera_radius')
    num_keyframes = 10  # how many novel-view poses to sample
    w_frames = 1
    yaw_range, pitch_range = 0.35, 0.25

    def camera_at(frame_idx):
        # One (1, 25) row: flattened cam2world followed by intrinsics.
        angle = 2 * 3.14 * frame_idx / (num_keyframes * w_frames)
        pose = LookAtPoseSampler.sample(
            3.14 / 2 + yaw_range * np.sin(angle),
            3.14 / 2 - 0.05 + pitch_range * np.cos(angle),
            pivot,
            radius=radius,
            device=device)
        return th.cat([pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)

    self.all_nvs_params = th.cat(
        [camera_at(idx) for idx in range(num_keyframes)], 0)