from __future__ import division

import datetime
import logging
import os
import re

import torch

logger = logging.getLogger(__name__)

# Splits a string around (optionally signed) decimal numbers. The single
# capturing group makes ``split`` return alternating text / number tokens.
# Float regex comes from https://stackoverflow.com/a/12643073/190597
_NUMBER_SPLIT_RE = re.compile(r'[+-]?([0-9]+(?:[.][0-9]*)?|[.][0-9]+)')


def _natural_sort_key(text):
    """Return a sort key that orders embedded numbers numerically.

    ``alist.sort(key=_natural_sort_key)`` sorts in human order, see
    http://nedbatchelder.com/blog/200712/human_sorting.html

    Tokens at odd positions are the regex-captured numbers and are converted
    to ``float``; tokens at even positions are the surrounding text and are
    kept as ``str``. Text is never passed through ``float()``, so fragments
    such as ``nan`` or ``inf`` cannot turn into floats and break the
    comparison between two keys.
    """
    tokens = _NUMBER_SPLIT_RE.split(text)
    return [float(tok) if idx % 2 else tok for idx, tok in enumerate(tokens)]


class CheckpointSaver:
    """Class that handles saving and loading checkpoints during training."""

    def __init__(self, save_dir, save_steps=1000, overwrite=False):
        """Create the checkpoint directory if needed and locate the latest checkpoint.

        Args:
            save_dir: Directory where checkpoints are written; stored as an
                absolute path.
            save_steps: Stored on the instance as ``save_steps``; not used by
                this class itself.
            overwrite: If true, ``save_checkpoint`` always writes to
                ``model_latest.pt`` instead of a per-epoch / per-step file.
        """
        self.save_dir = os.path.abspath(save_dir)
        self.save_steps = save_steps
        self.overwrite = overwrite
        # exist_ok avoids the check-then-create race of a separate existence test.
        os.makedirs(self.save_dir, exist_ok=True)
        self.get_latest_checkpoint()

    def exists_checkpoint(self, checkpoint_file=None):
        """Check if a checkpoint exists.

        Args:
            checkpoint_file: Optional path to a specific checkpoint file.

        Returns:
            If ``checkpoint_file`` is None, whether a latest checkpoint was
            found in ``save_dir`` by the last ``get_latest_checkpoint`` call;
            otherwise whether ``checkpoint_file`` is an existing file.
        """
        if checkpoint_file is None:
            return self.latest_checkpoint is not None
        return os.path.isfile(checkpoint_file)

    def _checkpoint_filename(self, epoch, total_step_count, save_by_step, interval):
        """Pick the regular checkpoint path, or None when this epoch is skipped."""
        if self.overwrite:
            basename = 'model_latest.pt'
        elif save_by_step:
            basename = '{:08d}.pt'.format(total_step_count)
        elif epoch % interval == 0:
            basename = f'model_epoch_{epoch:02d}.pt'
        else:
            return None
        return os.path.abspath(os.path.join(self.save_dir, basename))

    def save_checkpoint(self, models, optimizers, epoch, batch_idx, batch_size,
                        total_step_count, is_best=False, save_by_step=False,
                        interval=5, with_optimizer=True):
        """Save checkpoint.

        Args:
            models: Mapping from name to an object exposing ``state_dict()``.
            optimizers: Mapping from name to an object exposing ``state_dict()``.
            epoch: Current epoch; stored in the checkpoint and used for the
                per-epoch file name.
            batch_idx: Current batch index; stored in the checkpoint.
            batch_size: Batch size; stored in the checkpoint.
            total_step_count: Global step count; stored in the checkpoint and
                used for the per-step file name.
            is_best: If true, the checkpoint is also written to ``model_best.pt``.
            save_by_step: If true (and ``overwrite`` is false), the file is
                named after ``total_step_count``.
            interval: In per-epoch mode a file is only written when
                ``epoch % interval == 0``.
            with_optimizer: If true, optimizer state dicts are included.
        """
        timestamp = datetime.datetime.now()
        checkpoint_filename = self._checkpoint_filename(
            epoch, total_step_count, save_by_step, interval)

        checkpoint = {}
        for name, model in models.items():
            model_dict = model.state_dict()
            # Entries whose key contains '.smpl.' are left out of the saved
            # file. Keys are deleted in place (rather than rebuilding the
            # mapping) so the object returned by state_dict() keeps its type.
            for key in list(model_dict.keys()):
                if '.smpl.' in key:
                    del model_dict[key]
            checkpoint[name] = model_dict
        if with_optimizer:
            for name, optimizer in optimizers.items():
                checkpoint[name] = optimizer.state_dict()
        checkpoint['epoch'] = epoch
        checkpoint['batch_idx'] = batch_idx
        checkpoint['batch_size'] = batch_size
        checkpoint['total_step_count'] = total_step_count
        print(timestamp, 'Epoch:', epoch, 'Iteration:', batch_idx)

        if checkpoint_filename is not None:
            torch.save(checkpoint, checkpoint_filename)
            print('Saving checkpoint file [' + checkpoint_filename + ']')

        if is_best:
            # The best checkpoint is written once, in addition to the regular
            # one (previously the same file was written twice in a row).
            best_filename = os.path.abspath(os.path.join(self.save_dir, 'model_best.pt'))
            print(timestamp, 'Epoch:', epoch, 'Iteration:', batch_idx)
            print('Saving checkpoint file [' + best_filename + ']')
            torch.save(checkpoint, best_filename)
            print('Saved checkpoint file [' + best_filename + ']')

    def load_checkpoint(self, models, optimizers, checkpoint_file=None):
        """Load a checkpoint.

        Args:
            models: Mapping from name to an object exposing ``state_dict()``
                and ``load_state_dict()``; only names present in the
                checkpoint are restored.
            optimizers: Mapping from name to an object exposing
                ``load_state_dict()``; only names present in the checkpoint
                are restored.
            checkpoint_file: Path of the checkpoint to load. If None, the
                latest checkpoint found in ``save_dir`` is used.

        Returns:
            Dict with the stored ``epoch``, ``batch_idx``, ``batch_size`` and
            ``total_step_count``.

        Raises:
            FileNotFoundError: If ``checkpoint_file`` is None and no latest
                checkpoint is known.
        """
        if checkpoint_file is None:
            if self.latest_checkpoint is None:
                raise FileNotFoundError(
                    'No checkpoint found in [' + self.save_dir + ']')
            logger.info('Loading latest checkpoint [%s]', self.latest_checkpoint)
            checkpoint_file = self.latest_checkpoint

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(checkpoint_file)

        for name, model in models.items():
            if name not in checkpoint:
                continue
            # Start from the model's current state and overwrite only the
            # keys the checkpoint provides, so keys missing from the file
            # (e.g. the '.smpl.' entries dropped on save) keep their values.
            model_dict = model.state_dict()
            model_dict.update(
                {k: v for k, v in checkpoint[name].items() if k in model_dict})
            model.load_state_dict(model_dict)

        for name, optimizer in optimizers.items():
            if name in checkpoint:
                optimizer.load_state_dict(checkpoint[name])

        return {
            'epoch': checkpoint['epoch'],
            'batch_idx': checkpoint['batch_idx'],
            'batch_size': checkpoint['batch_size'],
            'total_step_count': checkpoint['total_step_count'],
        }

    def get_latest_checkpoint(self):
        """Get filename of latest checkpoint if it exists.

        Walks ``save_dir`` recursively, collects every ``*.pt`` file, sorts
        the absolute paths in natural (human) order and stores the last one
        in ``self.latest_checkpoint`` (None when no file was found).
        """
        checkpoint_list = [
            os.path.abspath(os.path.join(dirpath, filename))
            for dirpath, _dirnames, filenames in os.walk(self.save_dir)
            for filename in filenames
            if filename.endswith('.pt')
        ]
        checkpoint_list.sort(key=_natural_sort_key)
        self.latest_checkpoint = checkpoint_list[-1] if checkpoint_list else None