import pytorch_lightning as pl
import torch
from termcolor import colored

from ..dataset.mesh_util import *
from ..net.geometry import orthogonal
|
|
class Format:
    """ANSI escape sequences for underlined terminal output."""
    end = '\033[0m'
    start = '\033[4m'
|
|
def init_loss():
    """Build the loss registry: each term holds a fixed weight and a running value."""
    losses = {
        "cloth": {"weight": 1e3, "value": 0.0},
        "stiff": {"weight": 1e5, "value": 0.0},
        "rigid": {"weight": 1e5, "value": 0.0},
        "edge": {"weight": 0, "value": 0.0},
        "nc": {"weight": 0, "value": 0.0},
        "lapla": {"weight": 1e2, "value": 0.0},
        "normal": {"weight": 1e0, "value": 0.0},
        "silhouette": {"weight": 1e0, "value": 0.0},
        "joint": {"weight": 5e0, "value": 0.0},
    }
    return losses
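
# Usage sketch (assumption, not part of the original module): the caller is
# expected to fill each running "value" during optimization and combine the
# terms as weight * value, e.g.
#
#   losses = init_loss()
#   losses["cloth"]["value"] = cloth_term    # hypothetical per-iteration tensor
#   total_loss = sum(v["weight"] * v["value"] for v in losses.values())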
|
|
class SubTrainer(pl.Trainer):
    def save_checkpoint(self, filepath, weights_only=False):
        """Dump the training state and write it to a checkpoint file.

        Args:
            filepath: path of the checkpoint file to write
            weights_only: if True, save only the model weights
        """
        _checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)

        # Drop state-dict entries belonging to the auxiliary sub-modules
        # (normal_filter, voxelization, reconEngine) so they are not serialized.
        del_keys = [
            key for key in _checkpoint["state_dict"]
            if any(ignore_key in key for ignore_key in ("normal_filter", "voxelization", "reconEngine"))
        ]
        for key in del_keys:
            del _checkpoint["state_dict"][key]

        pl.utilities.cloud_io.atomic_save(_checkpoint, filepath)
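
# Usage sketch (hypothetical setup; flags are version-dependent and assume a
# pre-2.0 Lightning API, consistent with pl.utilities.cloud_io above):
#
#   trainer = SubTrainer(gpus=1, max_epochs=10)
#   trainer.fit(model)                        # `model` defined elsewhere
#   trainer.save_checkpoint("last.ckpt")      # filtered buffers are omitted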
|
|
def query_func(opt, netG, features, points, proj_matrix=None):
    """Query the implicit network netG at the given points.

    Args:
        points: (bz, N, 3) query locations; only bz == 1 is supported
        proj_matrix: optional (bz, 4, 4) projection applied to the points
    Returns:
        predictions of size (bz, 1, N)
    """
    assert len(points) == 1, "only batch size 1 is supported"
    samples = points.repeat(opt.num_views, 1, 1)
    samples = samples.permute(0, 2, 1)    # (num_views, 3, N)

    if proj_matrix is not None:
        samples = orthogonal(samples, proj_matrix)

    # identity calibration: netG.query applies no additional transform
    calib_tensor = torch.eye(4).float().unsqueeze(0).type_as(samples)

    preds = netG.query(
        features=features,
        points=samples,
        calibs=calib_tensor,
        regressor=netG.if_regressor,
    )

    # netG.query may return a list of predictions; keep the first element
    if isinstance(preds, list):
        preds = preds[0]

    return preds
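
# Usage sketch (assumption, not from this module): query_func is shaped so it
# can be closed over and handed to any consumer that expects f(points) -> preds:
#
#   from functools import partial
#   implicit_fn = partial(query_func, opt, netG, features)    # names assumed in scope
#   preds = implicit_fn(grid_points)    # grid_points: (1, N, 3) -> (1, 1, N)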
|
|
def query_func_IF(batch, netG, points):
    """Query an implicit-function network that consumes a full batch dict.

    Args:
        points: (bz, N, 3) query locations
    Returns:
        predictions of size (bz, 1, N)
    """
    batch["samples_geo"] = points
    batch["calib"] = torch.eye(4).float().unsqueeze(0).type_as(points)

    preds = netG(batch)

    return preds.unsqueeze(1)
|
|
def batch_mean(res, key):
    """Average res[i][key] over a list of per-batch result dicts."""
    return torch.stack([
        x[key] if torch.is_tensor(x[key]) else torch.as_tensor(x[key]) for x in res
    ]).mean()
|
|
def accumulate(outputs, rot_num, split):
    """Average per-sample metrics into per-dataset means.

    Args:
        outputs: list of metric dicts, one entry per evaluated sample
        rot_num: number of rotations evaluated per subject
        split: {dataset: (start, end)} subject-index ranges
    """
    hparam_log_dict = {}

    metrics = outputs[0].keys()
    datasets = split.keys()

    for dataset in datasets:
        for metric in metrics:
            keyword = f"{dataset}/{metric}"
            if keyword not in hparam_log_dict:
                hparam_log_dict[keyword] = 0
            # each subject occupies rot_num consecutive entries in `outputs`
            for idx in range(split[dataset][0] * rot_num, split[dataset][1] * rot_num):
                hparam_log_dict[keyword] += outputs[idx][metric].item()
            hparam_log_dict[keyword] /= (split[dataset][1] - split[dataset][0]) * rot_num

    print(colored(hparam_log_dict, "green"))

    return hparam_log_dict
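
# Example (hypothetical values): two datasets, two subjects each, three
# rotations per subject, so `outputs` holds 2 * 2 * 3 = 12 contiguous entries:
#
#   outputs = [{"chamfer": torch.tensor(float(i))} for i in range(12)]
#   split = {"set-a": (0, 2), "set-b": (2, 4)}
#   accumulate(outputs, rot_num=3, split=split)
#   # -> {"set-a/chamfer": mean of entries 0..5 = 2.5,
#   #     "set-b/chamfer": mean of entries 6..11 = 8.5}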
|
|