import numpy as np
import pytorch_lightning as pl
import torch

from lib.common.seg3d_lossless import Seg3dLossless
from lib.common.train_util import *

torch.backends.cudnn.benchmark = True


class IFGeo(pl.LightningModule):
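    """Lightning module that trains IFGeoNet for implicit geometry completion.

    The network is supervised with precomputed occupancy labels (``labels_geo``);
    a Seg3dLossless engine is kept for extracting meshes from the implicit field.
    """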
    def __init__(self, cfg):
        super(IFGeo, self).__init__()

        self.cfg = cfg
        self.batch_size = self.cfg.batch_size
        self.lr_G = self.cfg.lr_G

        self.use_sdf = cfg.sdf
        self.mcube_res = cfg.mcube_res
        self.clean_mesh_flag = cfg.clean_mesh
        self.overfit = cfg.overfit
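
        # pick the SMPL-conditioned network when a body prior is available,
        # otherwise fall back to the prior-free variant (same interface)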
        if cfg.dataset.prior_type == "SMPL":
            from lib.net.IFGeoNet import IFGeoNet
            self.netG = IFGeoNet(cfg)
        else:
            from lib.net.IFGeoNet_nobody import IFGeoNet
            self.netG = IFGeoNet(cfg)
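
        # coarse-to-fine grid resolutions for Seg3dLossless: powers of two from
        # 2^5 up to mcube_res, each plus one (e.g. mcube_res = 256 -> [33, 65, 129, 257])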
        self.resolutions = (
            np.logspace(
                start=5,
                stop=np.log2(self.mcube_res),
                base=2,
                num=int(np.log2(self.mcube_res) - 4),
                endpoint=True,
            ) + 1.0
        )
        self.resolutions = self.resolutions.astype(np.int16).tolist()
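
        # coarse-to-fine occupancy evaluator over the [-1, 1] cube at the 0.5 iso-level;
        # not called in the training/validation steps below, kept for mesh extraction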
        self.reconEngine = Seg3dLossless(
            query_func=query_func_IF,
            b_min=[[-1.0, 1.0, -1.0]],
            b_max=[[1.0, -1.0, 1.0]],
            resolutions=self.resolutions,
            align_corners=True,
            balance_value=0.50,
            visualize=False,
            debug=False,
            use_cuda_impl=False,
            faster=True,
        )

        self.export_dir = None
        self.result_eval = {}

    def configure_optimizers(self):

        weight_decay = self.cfg.weight_decay
        momentum = self.cfg.momentum

        optim_params_G = [{"params": self.netG.parameters(), "lr": self.lr_G}]
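
        # instantiate the optimizer named in cfg.optim; anything else is rejected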
        if self.cfg.optim == "Adadelta":
            optimizer_G = torch.optim.Adadelta(
                optim_params_G, lr=self.lr_G, weight_decay=weight_decay
            )
        elif self.cfg.optim == "Adam":
            optimizer_G = torch.optim.Adam(optim_params_G, lr=self.lr_G, weight_decay=weight_decay)
        elif self.cfg.optim == "RMSprop":
            optimizer_G = torch.optim.RMSprop(
                optim_params_G,
                lr=self.lr_G,
                weight_decay=weight_decay,
                momentum=momentum,
            )
        else:
            raise NotImplementedError
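
        # multiply the learning rate by cfg.gamma at each epoch listed in cfg.schedule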
        scheduler_G = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_G, milestones=self.cfg.schedule, gamma=self.cfg.gamma
        )

        return [optimizer_G], [scheduler_G]

    def training_step(self, batch, batch_idx):

        self.netG.train()
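
        # forward pass and loss against the precomputed geometry labels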
        preds_G = self.netG(batch)
        error_G = self.netG.compute_loss(preds_G, batch["labels_geo"])

        metrics_log = {
            "loss": error_G,
        }

        self.log_dict(
            metrics_log, prog_bar=True, logger=True, on_step=True, on_epoch=False, sync_dist=True
        )

        return metrics_log

    def training_epoch_end(self, outputs):
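
        # average the per-step losses collected over the epoch
        # (batch_mean comes from lib.common.train_util)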
        metrics_log = {
            "train/avgloss": batch_mean(outputs, "loss"),
        }

        self.log_dict(
            metrics_log,
            prog_bar=False,
            logger=True,
            on_step=False,
            on_epoch=True,
            rank_zero_only=True,
        )

    def validation_step(self, batch, batch_idx):

        self.netG.eval()
        self.netG.training = False
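
        # same forward/loss as training, logged under the val/ prefix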
        preds_G = self.netG(batch)
        error_G = self.netG.compute_loss(preds_G, batch["labels_geo"])

        metrics_log = {
            "val/loss": error_G,
        }

        self.log_dict(
            metrics_log, prog_bar=True, logger=False, on_step=True, on_epoch=False, sync_dist=True
        )

        return metrics_log

    def validation_epoch_end(self, outputs):

        metrics_log = {
            "val/avgloss": batch_mean(outputs, "val/loss"),
        }

        self.log_dict(
            metrics_log,
            prog_bar=False,
            logger=True,
            on_step=False,
            on_epoch=True,
            rank_zero_only=True,
        )
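

# Minimal usage sketch (assumptions: a YACS-style config `cfg` exposing the fields read
# above, plus `cfg.num_epoch` and a DataModule called `IFDataModule` that yields batches
# containing "labels_geo" -- the last two names are placeholders, not part of this file):
#
#   model = IFGeo(cfg)
#   trainer = pl.Trainer(max_epochs=cfg.num_epoch, accelerator="gpu", devices=1)
#   trainer.fit(model, datamodule=IFDataModule(cfg))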