from __future__ import division

import datetime
import logging
import os
import re

import torch

logger = logging.getLogger(__name__)


class CheckpointSaver:
    """Class that handles saving and loading checkpoints during training."""

    def __init__(self, save_dir, save_steps=1000, overwrite=False):
        self.save_dir = os.path.abspath(save_dir)
        self.save_steps = save_steps
        self.overwrite = overwrite
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        self.get_latest_checkpoint()

    def exists_checkpoint(self, checkpoint_file=None):
        """Check whether a checkpoint exists in the save directory."""
        if checkpoint_file is None:
            return self.latest_checkpoint is not None
        return os.path.isfile(checkpoint_file)

    def save_checkpoint(
        self,
        models,
        optimizers,
        epoch,
        batch_idx,
        batch_size,
        total_step_count,
        is_best=False,
        save_by_step=False,
        interval=5,
        with_optimizer=True
    ):
        """Save a checkpoint.

        Depending on the mode, the file is either overwritten in place
        (``overwrite``), named after the global step count
        (``save_by_step``), or written every ``interval`` epochs; a copy is
        additionally saved as ``model_best.pt`` when ``is_best`` is set.
        """
        timestamp = datetime.datetime.now()
        if self.overwrite:
            checkpoint_filename = os.path.abspath(os.path.join(self.save_dir, 'model_latest.pt'))
        elif save_by_step:
            checkpoint_filename = os.path.abspath(
                os.path.join(self.save_dir, f'{total_step_count:08d}.pt')
            )
        elif epoch % interval == 0:
            checkpoint_filename = os.path.abspath(
                os.path.join(self.save_dir, f'model_epoch_{epoch:02d}.pt')
            )
        else:
            checkpoint_filename = None
        checkpoint = {}
        for model in models:
            model_dict = models[model].state_dict()
            # Strip keys belonging to the '.smpl.' sub-module so its (fixed)
            # parameters and buffers are not duplicated in every checkpoint.
            for k in list(model_dict.keys()):
                if '.smpl.' in k:
                    del model_dict[k]
            checkpoint[model] = model_dict
        if with_optimizer:
            for optimizer in optimizers:
                checkpoint[optimizer] = optimizers[optimizer].state_dict()
        # Bookkeeping needed to resume training from this exact point.
        checkpoint['epoch'] = epoch
        checkpoint['batch_idx'] = batch_idx
        checkpoint['batch_size'] = batch_size
        checkpoint['total_step_count'] = total_step_count
        logger.info(f'{timestamp} Epoch: {epoch} Iteration: {batch_idx}')
        if checkpoint_filename is not None:
            torch.save(checkpoint, checkpoint_filename)
            logger.info(f'Saved checkpoint file [{checkpoint_filename}]')
        if is_best:  # additionally save a copy as the best model so far
            checkpoint_filename = os.path.abspath(os.path.join(self.save_dir, 'model_best.pt'))
            torch.save(checkpoint, checkpoint_filename)
            logger.info(f'Saved checkpoint file [{checkpoint_filename}]')
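
    # Illustrative filename behaviour (hypothetical values, assuming
    # save_dir='checkpoints'; not part of the original module):
    #   save_checkpoint(..., save_by_step=True, total_step_count=1500)
    #       -> checkpoints/00001500.pt
    #   save_checkpoint(..., epoch=10, interval=5)
    #       -> checkpoints/model_epoch_10.pt
    #   save_checkpoint(..., epoch=3, interval=5)
    #       -> no periodic file (3 % 5 != 0); 'model_best.pt' is still
    #          written when is_best=True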

    def load_checkpoint(self, models, optimizers, checkpoint_file=None):
        """Load a checkpoint."""
        if checkpoint_file is None:
            logger.info(f'Loading latest checkpoint [{self.latest_checkpoint}]')
            checkpoint_file = self.latest_checkpoint
        checkpoint = torch.load(checkpoint_file)
        for model in models:
            if model in checkpoint:
                # Keep only keys that the current model actually has, so
                # checkpoints with extra or missing entries (e.g. the
                # stripped '.smpl.' keys) can still be restored.
                model_dict = models[model].state_dict()
                pretrained_dict = {
                    k: v
                    for k, v in checkpoint[model].items() if k in model_dict
                }
                model_dict.update(pretrained_dict)
                models[model].load_state_dict(model_dict)
        for optimizer in optimizers:
            if optimizer in checkpoint:
                optimizers[optimizer].load_state_dict(checkpoint[optimizer])
        return {
            'epoch': checkpoint['epoch'],
            'batch_idx': checkpoint['batch_idx'],
            'batch_size': checkpoint['batch_size'],
            'total_step_count': checkpoint['total_step_count']
        }
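
    # Resume sketch (illustrative; assumes the 'models'/'optimizers' dicts
    # use the same keys as when the checkpoint was written):
    #   if saver.exists_checkpoint():
    #       state = saver.load_checkpoint(models, optimizers)
    #       start_epoch = state['epoch'] + 1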

    def get_latest_checkpoint(self):
        """Get the filename of the latest checkpoint, if one exists."""
        checkpoint_list = []
        for dirpath, dirnames, filenames in os.walk(self.save_dir):
            for filename in filenames:
                if filename.endswith('.pt'):
                    checkpoint_list.append(os.path.abspath(os.path.join(dirpath, filename)))

        # Sort in natural ("human") order so that numeric parts compare as
        # numbers, e.g. '00000100.pt' sorts after '00000099.pt'.
        def atof(text):
            try:
                return float(text)
            except ValueError:
                return text

        def natural_keys(text):
            """
            alist.sort(key=natural_keys) sorts in human order
            http://nedbatchelder.com/blog/200712/human_sorting.html
            (See Toothy's implementation in the comments)
            float regex comes from https://stackoverflow.com/a/12643073/190597
            """
            return [atof(c) for c in re.split(r'[+-]?([0-9]+(?:[.][0-9]*)?|[.][0-9]+)', text)]

        checkpoint_list.sort(key=natural_keys)
        self.latest_checkpoint = None if len(checkpoint_list) == 0 else checkpoint_list[-1]
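

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The 'checkpoints'
# directory, the dict keys 'model'/'optimizer', and the toy nn.Linear model
# are illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn as nn

    logging.basicConfig(level=logging.INFO)

    saver = CheckpointSaver(save_dir='checkpoints', save_steps=1000)
    models = {'model': nn.Linear(10, 2)}
    optimizers = {'optimizer': torch.optim.Adam(models['model'].parameters())}

    # Resume from the newest *.pt file if one is present.
    start_epoch = 0
    if saver.exists_checkpoint():
        state = saver.load_checkpoint(models, optimizers)
        start_epoch = state['epoch'] + 1

    # One (empty) epoch of "training", followed by a periodic save.
    for epoch in range(start_epoch, start_epoch + 1):
        saver.save_checkpoint(
            models, optimizers,
            epoch=epoch, batch_idx=0, batch_size=32,
            total_step_count=epoch * 100, interval=1,
        )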