Spaces:
Build error
Build error
import os | |
from pytorch_lightning import LightningModule, Trainer | |
from pytorch_lightning.callbacks import Callback, RichProgressBar, ModelCheckpoint | |
def build_callbacks(cfg, logger=None, phase='test', **kwargs):
    """Assemble the list of Lightning callbacks for a run.

    A ``progressBar`` is always included. The callbacks produced by
    ``getCheckpointCallback`` are appended only when ``phase == 'train'``.

    Args:
        cfg: Configuration object, forwarded to ``getCheckpointCallback``.
        logger: Logger object, forwarded to ``getCheckpointCallback``.
        phase: Run phase; only the value ``'train'`` enables the
            checkpoint/logging callbacks.
        **kwargs: Extra keyword arguments forwarded to
            ``getCheckpointCallback``.

    Returns:
        list: The callbacks, progress bar first.
    """
    # Rich progress bar is used in every phase
    callback_list = [progressBar()]

    # Checkpoint callbacks are only registered for training runs
    if phase == 'train':
        callback_list += getCheckpointCallback(cfg, logger=logger, **kwargs)

    return callback_list
def getCheckpointCallback(cfg, logger=None, **kwargs):
    """Build the logging and checkpointing callbacks used during training.

    The returned list contains, in this order:

    1. a ``progressLogger`` reporting the monitored metrics,
    2. a ``ModelCheckpoint`` monitoring ``"step"`` (mode ``max``) with
       ``save_top_k=8`` and ``save_last=True``,
    3. a ``ModelCheckpoint`` with a ten times longer period that keeps
       every checkpoint it writes (``save_top_k=-1``),
    4. one ``ModelCheckpoint`` with ``save_top_k=1`` for every monitored
       metric of every known metric type listed in ``cfg.METRIC.TYPE``.

    Args:
        cfg: Configuration object. Reads ``cfg.FOLDER_EXP``,
            ``cfg.LOGGER.VAL_EVERY_STEPS``, ``cfg.METRIC.TYPE`` and
            ``cfg.TRAIN.STAGE``.
        logger: Logger handed to ``progressLogger``.
        **kwargs: Accepted for call compatibility; not used here.

    Returns:
        list: The callbacks described above.
    """
    callbacks = []

    # Logging: display name -> key looked up in ``trainer.callback_metrics``
    # NOTE(review): "Accuracy" maps to "Metrics/accuracy" (lower-case) here,
    # while the checkpoint monitors below use "Metrics/Accuracy" - confirm
    # which spelling the metrics are actually logged under.
    metric_monitor = {
        "loss_total": "total/train",
        "Train_jf": "recons/text2jfeats/train",
        "Val_jf": "recons/text2jfeats/val",
        "Train_rf": "recons/text2rfeats/train",
        "Val_rf": "recons/text2rfeats/val",
        "APE root": "Metrics/APE_root",
        "APE mean pose": "Metrics/APE_mean_pose",
        "AVE root": "Metrics/AVE_root",
        "AVE mean pose": "Metrics/AVE_mean_pose",
        "R_TOP_1": "Metrics/R_precision_top_1",
        "R_TOP_2": "Metrics/R_precision_top_2",
        "R_TOP_3": "Metrics/R_precision_top_3",
        "gt_R_TOP_3": "Metrics/gt_R_precision_top_3",
        "FID": "Metrics/FID",
        "gt_FID": "Metrics/gt_FID",
        "Diversity": "Metrics/Diversity",
        "MM dist": "Metrics/Matching_score",
        "Accuracy": "Metrics/accuracy",
    }
    callbacks.append(
        progressLogger(logger,
                       metric_monitor=metric_monitor,
                       log_every_n_steps=1))

    # NOTE(review): the config field is named VAL_EVERY_STEPS but it is
    # passed as ``every_n_epochs`` - confirm the intended unit.
    save_period = cfg.LOGGER.VAL_EVERY_STEPS

    # Keep the 8 most recent checkpoints (highest "step") plus "last"
    latest_params = {
        'dirpath': os.path.join(cfg.FOLDER_EXP, "checkpoints"),
        'filename': "{epoch}",
        'monitor': "step",
        'mode': "max",
        'every_n_epochs': save_period,
        'save_top_k': 8,
        'save_last': True,
        'save_on_train_epoch_end': True,
    }
    callbacks.append(ModelCheckpoint(**latest_params))

    # Save a checkpoint every n*10 epochs and never discard any of them
    periodic_params = {
        **latest_params,
        'every_n_epochs': save_period * 10,
        'save_top_k': -1,
        'save_last': False,
    }
    callbacks.append(ModelCheckpoint(**periodic_params))

    # Monitored key, filename abbreviation and optimisation direction for
    # each supported metric type
    metric_monitor_map = {
        'TemosMetric': {
            'Metrics/APE_root': {
                'abbr': 'APEroot',
                'mode': 'min'
            },
        },
        'TM2TMetrics': {
            'Metrics/FID': {
                'abbr': 'FID',
                'mode': 'min'
            },
            'Metrics/R_precision_top_3': {
                'abbr': 'R3',
                'mode': 'max'
            }
        },
        'MRMetrics': {
            'Metrics/MPJPE': {
                'abbr': 'MPJPE',
                'mode': 'min'
            }
        },
        'HUMANACTMetrics': {
            'Metrics/Accuracy': {
                'abbr': 'Accuracy',
                'mode': 'max'
            }
        },
        'UESTCMetrics': {
            'Metrics/Accuracy': {
                'abbr': 'Accuracy',
                'mode': 'max'
            }
        },
        'UncondMetrics': {
            'Metrics/FID': {
                'abbr': 'FID',
                'mode': 'min'
            }
        }
    }

    # Best-metric checkpoints: same settings as the periodic one, but at the
    # normal period and keeping only the single best checkpoint per metric
    best_params = {
        **periodic_params,
        'every_n_epochs': save_period,
        'save_top_k': 1,
    }

    for metric in cfg.METRIC.TYPE:
        metric_monitors = metric_monitor_map.get(metric)
        if metric_monitors is None:
            # Unknown metric types get no dedicated checkpoint
            continue

        # Delete R3 if training VAE. ``pop`` with a default keeps this safe
        # when the same metric type is listed more than once.
        if cfg.TRAIN.STAGE == 'vae' and metric == 'TM2TMetrics':
            metric_monitors.pop('Metrics/R_precision_top_3', None)

        for monitor_key, spec in metric_monitors.items():
            # NOTE(review): "{ep}" is a checkpoint-filename placeholder;
            # presumably an "ep" value is logged by the model - confirm.
            callbacks.append(
                ModelCheckpoint(
                    **{
                        **best_params,
                        'filename':
                        spec['mode'] + "-" + spec['abbr'] + "{ep}",
                        'monitor': monitor_key,
                        'mode': spec['mode'],
                    }))

    return callbacks
class progressBar(RichProgressBar):
    """``RichProgressBar`` variant that hides the ``v_num`` entry."""

    def __init__(self):
        super().__init__()

    def get_metrics(self, trainer, model):
        """Return the parent's progress-bar metrics without ``v_num``."""
        displayed = super().get_metrics(trainer, model)
        # Don't show the version number
        displayed.pop("v_num", None)
        return displayed
class progressLogger(Callback):
    """Callback that reports training progress through a logger object.

    Args:
        logger: Object whose ``info`` method receives the messages.
        metric_monitor: Mapping of display name to the key looked up in
            ``trainer.callback_metrics``.
        precision: Number of digits after the decimal point in the
            scientific-notation metric values.
        log_every_n_steps: Metrics are appended to the epoch line only when
            the current epoch is a multiple of this value.
    """

    def __init__(self,
                 logger,
                 metric_monitor: dict,
                 precision: int = 3,
                 log_every_n_steps: int = 1):
        self.logger = logger
        # Metrics to monitor
        self.metric_monitor = metric_monitor
        self.precision = precision
        self.log_every_n_steps = log_every_n_steps

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule,
                       **kwargs) -> None:
        """Log that training has started."""
        self.logger.info("Training started")

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule,
                     **kwargs) -> None:
        """Log that training has finished."""
        self.logger.info("Training done")

    def on_validation_epoch_end(self, trainer: Trainer,
                                pl_module: LightningModule, **kwargs) -> None:
        """Log a confirmation while the trainer is sanity checking."""
        if not trainer.sanity_checking:
            return
        self.logger.info("Sanity checking ok.")

    def on_train_epoch_end(self,
                           trainer: Trainer,
                           pl_module: LightningModule,
                           padding=False,
                           **kwargs) -> None:
        """Log one line for the finished epoch.

        The line always starts with ``Epoch <n>``. When the epoch number is
        a multiple of ``log_every_n_steps``, every monitored metric found in
        ``trainer.callback_metrics`` is appended in scientific notation.
        """
        header = f"Epoch {trainer.current_epoch}"
        if padding:
            # Right-align the header to the width of 'Epoch xxxx'
            header = f"{header:>{len('Epoch xxxx')}}"

        if trainer.current_epoch % self.log_every_n_steps != 0:
            # Not a reporting epoch: only the header is logged
            self.logger.info(header)
            return

        value_format = f"{{:.{self.precision}e}}"
        logged_values = trainer.callback_metrics
        # Metrics that have not been logged are skipped
        entries = [
            f"{display_name} {value_format.format(logged_values[key].item())}"
            for display_name, key in self.metric_monitor.items()
            if key in logged_values
        ]
        self.logger.info(header + ": " + " ".join(entries))