MotionGPT / mGPT /callback.py
bill-jiang's picture
Init
4409449
raw
history blame
No virus
6.48 kB
import os
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback, RichProgressBar, ModelCheckpoint
def build_callbacks(cfg, logger=None, phase='test', **kwargs):
    """Assemble the Lightning callbacks for a run.

    A rich progress bar is always included. When ``phase`` is ``'train'``,
    the callbacks produced by ``getCheckpointCallback`` (epoch logger and
    checkpointing) are appended after it.

    Args:
        cfg: Experiment config, forwarded to ``getCheckpointCallback``.
        logger: Logger forwarded to ``getCheckpointCallback``.
        phase: Run phase; only ``'train'`` adds the checkpoint callbacks.
        **kwargs: Extra keyword arguments forwarded to
            ``getCheckpointCallback``.

    Returns:
        list: Callback instances, with the progress bar first.
    """
    # Rich progress bar is used in every phase.
    callbacks = [progressBar()]

    # Checkpointing (and the epoch-level metric logger) only apply to training.
    if phase == 'train':
        callbacks.extend(getCheckpointCallback(cfg, logger=logger, **kwargs))

    return callbacks
def getCheckpointCallback(cfg, logger=None, **kwargs):
    """Build the epoch logger and ``ModelCheckpoint`` callbacks for training.

    The returned list contains, in order:

    1. a ``progressLogger`` reporting selected entries of
       ``trainer.callback_metrics`` at the end of training epochs;
    2. a ``ModelCheckpoint`` keeping the 8 checkpoints with the highest
       ``step`` value (i.e. the most recent ones) plus a "last" checkpoint,
       written every ``cfg.LOGGER.VAL_EVERY_STEPS`` epochs;
    3. a ``ModelCheckpoint`` keeping every checkpoint written each
       ``10 * cfg.LOGGER.VAL_EVERY_STEPS`` epochs;
    4. one best-only (``save_top_k=1``) ``ModelCheckpoint`` per monitored
       metric of each known metric class listed in ``cfg.METRIC.TYPE``.

    Args:
        cfg: Experiment config. Reads ``FOLDER_EXP``,
            ``LOGGER.VAL_EVERY_STEPS``, ``METRIC.TYPE`` and ``TRAIN.STAGE``.
        logger: Logger handed to the ``progressLogger``.
        **kwargs: Accepted for call compatibility; unused.

    Returns:
        list: The callbacks described above.
    """
    callbacks = []

    # Epoch-level text logging: display name -> key in trainer.callback_metrics.
    metric_monitor = {
        "loss_total": "total/train",
        "Train_jf": "recons/text2jfeats/train",
        "Val_jf": "recons/text2jfeats/val",
        "Train_rf": "recons/text2rfeats/train",
        "Val_rf": "recons/text2rfeats/val",
        "APE root": "Metrics/APE_root",
        "APE mean pose": "Metrics/APE_mean_pose",
        "AVE root": "Metrics/AVE_root",
        "AVE mean pose": "Metrics/AVE_mean_pose",
        "R_TOP_1": "Metrics/R_precision_top_1",
        "R_TOP_2": "Metrics/R_precision_top_2",
        "R_TOP_3": "Metrics/R_precision_top_3",
        "gt_R_TOP_3": "Metrics/gt_R_precision_top_3",
        "FID": "Metrics/FID",
        "gt_FID": "Metrics/gt_FID",
        "Diversity": "Metrics/Diversity",
        "MM dist": "Metrics/Matching_score",
        "Accuracy": "Metrics/accuracy",
    }
    callbacks.append(
        progressLogger(logger,
                       metric_monitor=metric_monitor,
                       log_every_n_steps=1))

    # Save the 8 latest checkpoints: ranking by "step" in "max" mode keeps the
    # newest files (assumes the monitored step value grows during training).
    checkpointParams = {
        'dirpath': os.path.join(cfg.FOLDER_EXP, "checkpoints"),
        'filename': "{epoch}",
        'monitor': "step",
        'mode': "max",
        'every_n_epochs': cfg.LOGGER.VAL_EVERY_STEPS,
        'save_top_k': 8,
        'save_last': True,
        'save_on_train_epoch_end': True,
    }
    callbacks.append(ModelCheckpoint(**checkpointParams))

    # Save (and never delete) a checkpoint every n*10 epochs.
    checkpointParams.update({
        'every_n_epochs': cfg.LOGGER.VAL_EVERY_STEPS * 10,
        'save_top_k': -1,
        'save_last': False,
    })
    callbacks.append(ModelCheckpoint(**checkpointParams))

    # Per metric class: monitored key -> file-name abbreviation and whether a
    # lower ('min') or higher ('max') value is better.
    metric_monitor_map = {
        'TemosMetric': {
            'Metrics/APE_root': {
                'abbr': 'APEroot',
                'mode': 'min'
            },
        },
        'TM2TMetrics': {
            'Metrics/FID': {
                'abbr': 'FID',
                'mode': 'min'
            },
            'Metrics/R_precision_top_3': {
                'abbr': 'R3',
                'mode': 'max'
            }
        },
        'MRMetrics': {
            'Metrics/MPJPE': {
                'abbr': 'MPJPE',
                'mode': 'min'
            }
        },
        'HUMANACTMetrics': {
            'Metrics/Accuracy': {
                'abbr': 'Accuracy',
                'mode': 'max'
            }
        },
        'UESTCMetrics': {
            'Metrics/Accuracy': {
                'abbr': 'Accuracy',
                'mode': 'max'
            }
        },
        'UncondMetrics': {
            'Metrics/FID': {
                'abbr': 'FID',
                'mode': 'min'
            }
        }
    }

    # Best-metric checkpoints keep only the single best file per metric.
    checkpointParams.update({
        'every_n_epochs': cfg.LOGGER.VAL_EVERY_STEPS,
        'save_top_k': 1,
    })
    # dict.fromkeys drops repeated metric classes while preserving order, so a
    # class listed twice neither re-deletes R3 (KeyError) nor registers
    # duplicate checkpoint callbacks.
    for metric in dict.fromkeys(cfg.METRIC.TYPE):
        monitors = metric_monitor_map.get(metric)
        if monitors is None:
            continue

        # Delete R3 if training VAE.
        if cfg.TRAIN.STAGE == 'vae' and metric == 'TM2TMetrics':
            monitors.pop('Metrics/R_precision_top_3', None)

        for monitor_key, spec in monitors.items():
            checkpointParams.update({
                # NOTE(review): "{ep}" is not an epoch/step key, so Lightning
                # presumably fills it as "ep=0"; "{epoch}" may have been
                # intended. Kept to leave existing checkpoint names unchanged.
                'filename': spec['mode'] + "-" + spec['abbr'] + "{ep}",
                'monitor': monitor_key,
                'mode': spec['mode'],
            })
            callbacks.append(ModelCheckpoint(**checkpointParams))
    return callbacks
class progressBar(RichProgressBar):
    """``RichProgressBar`` that hides the ``v_num`` (version number) entry."""

    def __init__(self, *args, **kwargs):
        # Forward any RichProgressBar constructor options; calling with no
        # arguments behaves exactly as the previous no-argument constructor.
        super().__init__(*args, **kwargs)

    def get_metrics(self, trainer, model):
        """Return the parent's progress-bar metrics without ``v_num``."""
        items = super().get_metrics(trainer, model)
        # Don't show the version number.
        items.pop("v_num", None)
        return items
class progressLogger(Callback):
    """Callback that writes training progress and selected metrics to a logger.

    Args:
        logger: Object whose ``info(str)`` method receives the messages.
        metric_monitor: Mapping of display name -> key in
            ``trainer.callback_metrics``; only keys present at epoch end are
            reported, in this mapping's order.
        precision: Digits after the decimal point in the scientific-notation
            formatting of each metric value.
        log_every_n_steps: Metrics are appended to the epoch line only when
            ``trainer.current_epoch`` is a multiple of this value (despite
            the name, it is counted in epochs).
    """

    def __init__(self,
                 logger,
                 metric_monitor: dict,
                 precision: int = 3,
                 log_every_n_steps: int = 1):
        super().__init__()
        # Metric to monitor
        self.logger = logger
        self.metric_monitor = metric_monitor
        self.precision = precision
        self.log_every_n_steps = log_every_n_steps

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule,
                       **kwargs) -> None:
        self.logger.info("Training started")

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule,
                     **kwargs) -> None:
        self.logger.info("Training done")

    def on_validation_epoch_end(self, trainer: Trainer,
                                pl_module: LightningModule, **kwargs) -> None:
        # Only the sanity-check validation pass is reported here.
        if trainer.sanity_checking:
            self.logger.info("Sanity checking ok.")

    def _format_metrics(self, callback_metrics) -> str:
        """Return space-separated "name value" pairs for monitored metrics.

        Entries whose key is missing from ``callback_metrics`` are skipped.
        Values are assumed to expose ``.item()`` (e.g. scalar tensors).
        """
        metric_format = f"{{:.{self.precision}e}}"
        parts = []
        for metric_name, dico_name in self.metric_monitor.items():
            if dico_name in callback_metrics:
                value = callback_metrics[dico_name].item()
                parts.append(f"{metric_name} {metric_format.format(value)}")
        return " ".join(parts)

    def on_train_epoch_end(self,
                           trainer: Trainer,
                           pl_module: LightningModule,
                           padding=False,
                           **kwargs) -> None:
        line = f"Epoch {trainer.current_epoch}"
        if padding:
            # Right-align "Epoch N" to the width of "Epoch xxxx".
            line = f"{line:>{len('Epoch xxxx')}}"

        # Metrics are only appended on logging epochs; the epoch line itself
        # is logged every time. The helper always returns a string, so no
        # partially-built state is read on non-logging epochs.
        if trainer.current_epoch % self.log_every_n_steps == 0:
            line = line + ": " + self._format_metrics(trainer.callback_metrics)

        self.logger.info(line)