import os
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter
from pytorch3dunet.datasets.utils import get_train_loaders
from pytorch3dunet.unet3d.losses import get_loss_criterion
from pytorch3dunet.unet3d.metrics import get_evaluation_metric
from pytorch3dunet.unet3d.model import get_model, UNet2D
from pytorch3dunet.unet3d.utils import get_logger, get_tensorboard_formatter, create_optimizer, \
create_lr_scheduler, get_number_of_learnable_parameters
from . import utils
logger = get_logger('UNetTrainer')
def create_trainer(config):
# Create the model
model = get_model(config['model'])
    if torch.cuda.device_count() > 1 and config['device'] != 'cpu':
        model = nn.DataParallel(model)
        logger.info(f'Using {torch.cuda.device_count()} GPUs for training')
    if torch.cuda.is_available() and config['device'] != 'cpu':
        model = model.cuda()
# Log the number of learnable parameters
logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')
# Create loss criterion
loss_criterion = get_loss_criterion(config)
# Create evaluation metric
eval_criterion = get_evaluation_metric(config)
# Create data loaders
loaders = get_train_loaders(config)
# Create the optimizer
optimizer = create_optimizer(config['optimizer'], model)
# Create learning rate adjustment strategy
lr_scheduler = create_lr_scheduler(config.get('lr_scheduler', None), optimizer)
trainer_config = config['trainer']
# Create tensorboard formatter
tensorboard_formatter = get_tensorboard_formatter(trainer_config.pop('tensorboard_formatter', None))
# Create trainer
resume = trainer_config.pop('resume', None)
pre_trained = trainer_config.pop('pre_trained', None)
return UNetTrainer(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, loss_criterion=loss_criterion,
eval_criterion=eval_criterion, loaders=loaders, tensorboard_formatter=tensorboard_formatter,
resume=resume, pre_trained=pre_trained, **trainer_config)
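# A minimal config sketch (illustrative only) for the dict `create_trainer`
# consumes. Key names mirror pytorch-3dunet's YAML train configs, but the
# exact options are assumptions that depend on the package version and the
# dataset; the 'loaders' section is elided because it is dataset-specific.
#
#     config = {
#         'device': 'cuda',
#         'model': {'name': 'UNet3D', 'in_channels': 1, 'out_channels': 2,
#                   'layer_order': 'gcr', 'f_maps': 32, 'final_sigmoid': True},
#         'loss': {'name': 'BCEWithLogitsLoss'},
#         'eval_metric': {'name': 'MeanIoU'},
#         'optimizer': {'learning_rate': 2e-4, 'weight_decay': 1e-5},
#         'lr_scheduler': {'name': 'ReduceLROnPlateau', 'mode': 'max', 'factor': 0.5},
#         'trainer': {'checkpoint_dir': 'checkpoints', 'max_num_epochs': 100,
#                     'max_num_iterations': 150000, 'validate_after_iters': 200,
#                     'log_after_iters': 100},
#         'loaders': {...},  # dataset/loader options omitted here
#     }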
class UNetTrainer:
"""UNet trainer.
Args:
        model (UNet3D): UNet 3D model to be trained
        optimizer (torch.optim.Optimizer): optimizer used for training
lr_scheduler (torch.optim.lr_scheduler._LRScheduler): learning rate scheduler
            WARN: bear in mind that lr_scheduler.step() is invoked after every validation step
            (i.e. every validate_after_iters iterations), not after every epoch. So if one uses e.g. StepLR
            with step_size=30, the learning rate will be adjusted after every 30 * validate_after_iters iterations.
loss_criterion (callable): loss function
        eval_criterion (callable): used to compute the training/validation metric (such as Dice, IoU, AP or Rand score);
            saving the best checkpoint is based on the result of this function on the validation set
loaders (dict): 'train' and 'val' loaders
checkpoint_dir (string): dir for saving checkpoints and tensorboard logs
max_num_epochs (int): maximum number of epochs
max_num_iterations (int): maximum number of iterations
validate_after_iters (int): validate after that many iterations
log_after_iters (int): number of iterations before logging to tensorboard
validate_iters (int): number of validation iterations, if None validate
on the whole validation set
eval_score_higher_is_better (bool): if True higher eval scores are considered better
        best_eval_score (float): best validation score so far (interpreted according to eval_score_higher_is_better)
num_iterations (int): useful when loading the model from the checkpoint
num_epoch (int): useful when loading the model from the checkpoint
tensorboard_formatter (callable): converts a given batch of input/output/target image to a series of images
that can be displayed in tensorboard
skip_train_validation (bool): if True eval_criterion is not evaluated on the training set (used mostly when
evaluation is expensive)
"""
def __init__(self, model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders, checkpoint_dir,
max_num_epochs, max_num_iterations, validate_after_iters=200, log_after_iters=100, validate_iters=None,
num_iterations=1, num_epoch=0, eval_score_higher_is_better=True, tensorboard_formatter=None,
skip_train_validation=False, resume=None, pre_trained=None, **kwargs):
self.model = model
self.optimizer = optimizer
self.scheduler = lr_scheduler
self.loss_criterion = loss_criterion
self.eval_criterion = eval_criterion
self.loaders = loaders
self.checkpoint_dir = checkpoint_dir
self.max_num_epochs = max_num_epochs
self.max_num_iterations = max_num_iterations
self.validate_after_iters = validate_after_iters
self.log_after_iters = log_after_iters
self.validate_iters = validate_iters
self.eval_score_higher_is_better = eval_score_higher_is_better
logger.info(model)
logger.info(f'eval_score_higher_is_better: {eval_score_higher_is_better}')
# initialize the best_eval_score
if eval_score_higher_is_better:
self.best_eval_score = float('-inf')
else:
self.best_eval_score = float('+inf')
self.writer = SummaryWriter(log_dir=os.path.join(checkpoint_dir, 'logs'))
assert tensorboard_formatter is not None, 'TensorboardFormatter must be provided'
self.tensorboard_formatter = tensorboard_formatter
self.num_iterations = num_iterations
self.num_epochs = num_epoch
self.skip_train_validation = skip_train_validation
if resume is not None:
logger.info(f"Loading checkpoint '{resume}'...")
state = utils.load_checkpoint(resume, self.model, self.optimizer)
logger.info(
f"Checkpoint loaded from '{resume}'. Epoch: {state['num_epochs']}. Iteration: {state['num_iterations']}. "
f"Best val score: {state['best_eval_score']}."
)
self.best_eval_score = state['best_eval_score']
self.num_iterations = state['num_iterations']
self.num_epochs = state['num_epochs']
self.checkpoint_dir = os.path.split(resume)[0]
elif pre_trained is not None:
logger.info(f"Logging pre-trained model from '{pre_trained}'...")
utils.load_checkpoint(pre_trained, self.model, None)
if 'checkpoint_dir' not in kwargs:
self.checkpoint_dir = os.path.split(pre_trained)[0]
def fit(self):
for _ in range(self.num_epochs, self.max_num_epochs):
# train for one epoch
should_terminate = self.train()
if should_terminate:
logger.info('Stopping criterion is satisfied. Finishing training')
return
self.num_epochs += 1
logger.info(f"Reached maximum number of epochs: {self.max_num_epochs}. Finishing training...")
def train(self):
"""Trains the model for 1 epoch.
Returns:
True if the training should be terminated immediately, False otherwise
"""
train_losses = utils.RunningAverage()
train_eval_scores = utils.RunningAverage()
# sets the model in training mode
self.model.train()
for t in self.loaders['train']:
logger.info(f'Training iteration [{self.num_iterations}/{self.max_num_iterations}]. '
f'Epoch [{self.num_epochs}/{self.max_num_epochs - 1}]')
input, target, weight = self._split_training_batch(t)
output, loss = self._forward_pass(input, target, weight)
train_losses.update(loss.item(), self._batch_size(input))
# compute gradients and update parameters
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
if self.num_iterations % self.validate_after_iters == 0:
# set the model in eval mode
self.model.eval()
# evaluate on validation set
eval_score = self.validate()
# set the model back to training mode
self.model.train()
# adjust learning rate if necessary
if isinstance(self.scheduler, ReduceLROnPlateau):
self.scheduler.step(eval_score)
else:
self.scheduler.step()
# log current learning rate in tensorboard
self._log_lr()
# remember best validation metric
is_best = self._is_best_eval_score(eval_score)
# save checkpoint
self._save_checkpoint(is_best)
if self.num_iterations % self.log_after_iters == 0:
# compute eval criterion
if not self.skip_train_validation:
eval_score = self.eval_criterion(output, target)
train_eval_scores.update(eval_score.item(), self._batch_size(input))
# log stats, params and images
logger.info(
f'Training stats. Loss: {train_losses.avg}. Evaluation score: {train_eval_scores.avg}')
self._log_stats('train', train_losses.avg, train_eval_scores.avg)
# self._log_params()
self._log_images(input, target, output, 'train_')
if self.should_stop():
return True
self.num_iterations += 1
return False
def should_stop(self):
"""
Training will terminate if maximum number of iterations is exceeded or the learning rate drops below
some predefined threshold (1e-6 in our case)
"""
if self.max_num_iterations < self.num_iterations:
logger.info(f'Maximum number of iterations {self.max_num_iterations} exceeded.')
return True
min_lr = 1e-6
lr = self.optimizer.param_groups[0]['lr']
if lr < min_lr:
logger.info(f'Learning rate below the minimum {min_lr}.')
return True
return False
def validate(self):
logger.info('Validating...')
val_losses = utils.RunningAverage()
val_scores = utils.RunningAverage()
with torch.no_grad():
for i, t in enumerate(self.loaders['val']):
logger.info(f'Validation iteration {i}')
input, target, weight = self._split_training_batch(t)
output, loss = self._forward_pass(input, target, weight)
val_losses.update(loss.item(), self._batch_size(input))
if i % 100 == 0:
self._log_images(input, target, output, 'val_')
eval_score = self.eval_criterion(output, target)
val_scores.update(eval_score.item(), self._batch_size(input))
if self.validate_iters is not None and self.validate_iters <= i:
# stop validation
break
self._log_stats('val', val_losses.avg, val_scores.avg)
logger.info(f'Validation finished. Loss: {val_losses.avg}. Evaluation score: {val_scores.avg}')
return val_scores.avg
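    # NOTE: a loader batch is either a 2-tuple (input, target) or a 3-tuple
    # (input, target, weight); nested lists/tuples of tensors are moved to
    # the GPU element-wise by the recursive helper below.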
def _split_training_batch(self, t):
def _move_to_gpu(input):
            if isinstance(input, (tuple, list)):
                return tuple(_move_to_gpu(x) for x in input)
else:
if torch.cuda.is_available():
input = input.cuda(non_blocking=True)
return input
t = _move_to_gpu(t)
weight = None
if len(t) == 2:
input, target = t
else:
input, target, weight = t
return input, target, weight
def _forward_pass(self, input, target, weight=None):
if isinstance(self.model, UNet2D):
# remove the singleton z-dimension from the input
input = torch.squeeze(input, dim=-3)
# forward pass
output = self.model(input)
# add the singleton z-dimension to the output
output = torch.unsqueeze(output, dim=-3)
else:
# forward pass
output = self.model(input)
# compute the loss
if weight is None:
loss = self.loss_criterion(output, target)
else:
loss = self.loss_criterion(output, target, weight)
return output, loss
def _is_best_eval_score(self, eval_score):
if self.eval_score_higher_is_better:
is_best = eval_score > self.best_eval_score
else:
is_best = eval_score < self.best_eval_score
if is_best:
logger.info(f'Saving new best evaluation metric: {eval_score}')
self.best_eval_score = eval_score
return is_best
def _save_checkpoint(self, is_best):
# remove `module` prefix from layer names when using `nn.DataParallel`
# see: https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/20
if isinstance(self.model, nn.DataParallel):
state_dict = self.model.module.state_dict()
else:
state_dict = self.model.state_dict()
last_file_path = os.path.join(self.checkpoint_dir, 'last_checkpoint.pytorch')
logger.info(f"Saving checkpoint to '{last_file_path}'")
utils.save_checkpoint({
'num_epochs': self.num_epochs + 1,
'num_iterations': self.num_iterations,
'model_state_dict': state_dict,
'best_eval_score': self.best_eval_score,
'optimizer_state_dict': self.optimizer.state_dict(),
}, is_best, checkpoint_dir=self.checkpoint_dir)
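    # The checkpoint keys written above are the same ones the resume branch of
    # __init__ restores ('num_epochs', 'num_iterations', 'best_eval_score'),
    # with the two state dicts loaded back into the model and optimizer.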
def _log_lr(self):
lr = self.optimizer.param_groups[0]['lr']
self.writer.add_scalar('learning_rate', lr, self.num_iterations)
def _log_stats(self, phase, loss_avg, eval_score_avg):
tag_value = {
f'{phase}_loss_avg': loss_avg,
f'{phase}_eval_score_avg': eval_score_avg
}
for tag, value in tag_value.items():
self.writer.add_scalar(tag, value, self.num_iterations)
def _log_params(self):
logger.info('Logging model parameters and gradients')
for name, value in self.model.named_parameters():
self.writer.add_histogram(name, value.data.cpu().numpy(), self.num_iterations)
self.writer.add_histogram(name + '/grad', value.grad.data.cpu().numpy(), self.num_iterations)
def _log_images(self, input, target, prediction, prefix=''):
if isinstance(self.model, nn.DataParallel):
net = self.model.module
else:
net = self.model
if net.final_activation is not None:
prediction = net.final_activation(prediction)
inputs_map = {
'inputs': input,
'targets': target,
'predictions': prediction
}
img_sources = {}
for name, batch in inputs_map.items():
            if isinstance(batch, (list, tuple)):
for i, b in enumerate(batch):
img_sources[f'{name}{i}'] = b.data.cpu().numpy()
else:
img_sources[name] = batch.data.cpu().numpy()
for name, batch in img_sources.items():
for tag, image in self.tensorboard_formatter(name, batch):
self.writer.add_image(prefix + tag, image, self.num_iterations)
@staticmethod
def _batch_size(input):
        if isinstance(input, (list, tuple)):
return input[0].size(0)
else:
return input.size(0)
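# Illustrative end-to-end usage (a sketch; not part of the original module).
# The YAML path and file layout are assumptions; see the example config
# sketch after `create_trainer` above:
#
#     import yaml
#     with open('train_config.yaml') as f:
#         config = yaml.safe_load(f)
#     create_trainer(config).fit()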