import datetime
import logging
import os
import re

import torch

logger = logging.getLogger(__name__)


class CheckpointSaver:
    """Saves and loads model/optimizer checkpoints for a training run."""

    def __init__(self, save_dir, save_steps=1000, overwrite=False):
        self.save_dir = os.path.abspath(save_dir)
        self.save_steps = save_steps
        self.overwrite = overwrite
        os.makedirs(self.save_dir, exist_ok=True)
        # Scan the directory once so self.latest_checkpoint is set from the start.
        self.get_latest_checkpoint()

    def exists_checkpoint(self, checkpoint_file=None):
        """Check whether a checkpoint exists; defaults to the latest one found."""
        if checkpoint_file is None:
            return self.latest_checkpoint is not None
        return os.path.isfile(checkpoint_file)

    def save_checkpoint(
        self,
        models,
        optimizers,
        epoch,
        batch_idx,
        batch_size,
        total_step_count,
        is_best=False,
        save_by_step=False,
        interval=5,
        with_optimizer=True,
    ):
        """Save a checkpoint of the given models (and optionally optimizers)."""
        timestamp = datetime.datetime.now()
        if self.overwrite:
            # A single rolling file, overwritten on every save.
            checkpoint_filename = os.path.join(self.save_dir, 'model_latest.pt')
        elif save_by_step:
            checkpoint_filename = os.path.join(self.save_dir, f'{total_step_count:08d}.pt')
        elif epoch % interval == 0:
            checkpoint_filename = os.path.join(self.save_dir, f'model_epoch_{epoch:02d}.pt')
        else:
            # Not a save epoch; only 'model_best.pt' may still be written below.
            checkpoint_filename = None

        checkpoint = {}
        for model in models:
            model_dict = models[model].state_dict()
            # Strip '.smpl.' entries (fixed SMPL buffers that are restored from
            # the model definition, so they need not be checkpointed).
            for k in list(model_dict.keys()):
                if '.smpl.' in k:
                    del model_dict[k]
            checkpoint[model] = model_dict
        if with_optimizer:
            for optimizer in optimizers:
                checkpoint[optimizer] = optimizers[optimizer].state_dict()
        checkpoint['epoch'] = epoch
        checkpoint['batch_idx'] = batch_idx
        checkpoint['batch_size'] = batch_size
        checkpoint['total_step_count'] = total_step_count

        if checkpoint_filename is not None:
            torch.save(checkpoint, checkpoint_filename)
            logger.info('%s Epoch: %d, Iteration: %d', timestamp, epoch, batch_idx)
            logger.info('Saved checkpoint file [%s]', checkpoint_filename)
        if is_best:
            # Keep a separate copy of the best-performing model so far.
            checkpoint_filename = os.path.join(self.save_dir, 'model_best.pt')
            torch.save(checkpoint, checkpoint_filename)
            logger.info('Saved checkpoint file [%s]', checkpoint_filename)
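
    # Illustrative layout of a saved checkpoint, assuming the caller passed
    # models={'model': net} and optimizers={'optimizer': opt}; the key names
    # are whatever the caller chooses, and these values are hypothetical:
    #
    #   {
    #       'model': <net state_dict, minus '.smpl.' entries>,
    #       'optimizer': <opt state_dict>,   # only if with_optimizer=True
    #       'epoch': 12, 'batch_idx': 480,
    #       'batch_size': 32, 'total_step_count': 38400,
    #   }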

    def load_checkpoint(self, models, optimizers, checkpoint_file=None):
        """Load a checkpoint, filling each model/optimizer that has an entry."""
        if checkpoint_file is None:
            checkpoint_file = self.latest_checkpoint
            logger.info('Loading latest checkpoint [%s]', checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        for model in models:
            if model in checkpoint:
                # Load only the keys present in both the checkpoint and the
                # current model, so stripped entries (e.g. '.smpl.') and newly
                # added parameters do not cause load_state_dict to fail.
                model_dict = models[model].state_dict()
                pretrained_dict = {
                    k: v for k, v in checkpoint[model].items() if k in model_dict
                }
                model_dict.update(pretrained_dict)
                models[model].load_state_dict(model_dict)
        for optimizer in optimizers:
            if optimizer in checkpoint:
                optimizers[optimizer].load_state_dict(checkpoint[optimizer])
        return {
            'epoch': checkpoint['epoch'],
            'batch_idx': checkpoint['batch_idx'],
            'batch_size': checkpoint['batch_size'],
            'total_step_count': checkpoint['total_step_count'],
        }
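
    # Resume sketch (hedged): after constructing the same models/optimizers
    # that were saved, a training script might do something like:
    #
    #   if saver.exists_checkpoint():
    #       state = saver.load_checkpoint({'model': net}, {'optimizer': opt})
    #       start_epoch = state['epoch']
    #       total_steps = state['total_step_count']
    #
    # The dict keys ('model', 'optimizer') must match those used at save time.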

    def get_latest_checkpoint(self):
        """Record the filename of the latest checkpoint, if any exists."""
        checkpoint_list = []
        for dirpath, dirnames, filenames in os.walk(self.save_dir):
            for filename in filenames:
                if filename.endswith('.pt'):
                    checkpoint_list.append(os.path.abspath(os.path.join(dirpath, filename)))

        def atof(text):
            try:
                return float(text)
            except ValueError:
                return text

        def natural_keys(text):
            """
            alist.sort(key=natural_keys) sorts in human order
            http://nedbatchelder.com/blog/200712/human_sorting.html
            (See Toothy's implementation in the comments)
            float regex comes from https://stackoverflow.com/a/12643073/190597
            """
            return [atof(c) for c in re.split(r'[+-]?([0-9]+(?:[.][0-9]*)?|[.][0-9]+)', text)]

        # Natural (human) sort so 'model_epoch_10.pt' follows 'model_epoch_9.pt'.
        checkpoint_list.sort(key=natural_keys)
        self.latest_checkpoint = checkpoint_list[-1] if checkpoint_list else None
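

# Minimal usage sketch (not part of the original module): exercises a save and
# resume round-trip with a toy model in a throwaway directory. The model,
# optimizer, and key names here are illustrative, not assumptions about any
# particular training setup.
if __name__ == '__main__':
    import tempfile

    import torch.nn as nn
    import torch.optim as optim

    logging.basicConfig(level=logging.INFO)
    save_dir = tempfile.mkdtemp()  # temporary directory for the demo
    saver = CheckpointSaver(save_dir)

    net = nn.Linear(4, 2)
    opt = optim.SGD(net.parameters(), lr=0.1)

    # Save at epoch 5 (interval=5, so this epoch qualifies for an epoch file);
    # is_best=True additionally writes 'model_best.pt'.
    saver.save_checkpoint(
        models={'model': net},
        optimizers={'optimizer': opt},
        epoch=5, batch_idx=0, batch_size=8, total_step_count=100,
        is_best=True,
    )

    # Rescan the directory and resume into fresh instances.
    saver.get_latest_checkpoint()
    net2 = nn.Linear(4, 2)
    opt2 = optim.SGD(net2.parameters(), lr=0.1)
    state = saver.load_checkpoint({'model': net2}, {'optimizer': opt2})
    print('resumed at epoch', state['epoch'], 'step', state['total_step_count'])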