import os

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter

from pytorch3dunet.datasets.utils import get_train_loaders
from pytorch3dunet.unet3d.losses import get_loss_criterion
from pytorch3dunet.unet3d.metrics import get_evaluation_metric
from pytorch3dunet.unet3d.model import get_model, UNet2D
from pytorch3dunet.unet3d.utils import get_logger, get_tensorboard_formatter, create_optimizer, \
    create_lr_scheduler, get_number_of_learnable_parameters
from . import utils

logger = get_logger('UNetTrainer')


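# `create_trainer` expects a parsed config dict (typically loaded from a YAML file).
# Illustrative layout only, inferred from the keys read below; the entries inside each
# sub-dict depend on the concrete model, optimizer and scheduler being configured:
#
#   config = {
#       'model': {...},          # forwarded to get_model()
#       'device': 'cuda',        # or 'cpu' to force CPU training
#       'optimizer': {...},      # forwarded to create_optimizer()
#       'lr_scheduler': {...},   # optional, forwarded to create_lr_scheduler()
#       'trainer': {'checkpoint_dir': ..., 'max_num_epochs': ..., 'max_num_iterations': ...,
#                   'resume': None, 'pre_trained': None, 'tensorboard_formatter': None},
#       # plus the sections consumed by get_loss_criterion(), get_evaluation_metric()
#       # and get_train_loaders(), which all receive the full config
#   }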
def create_trainer(config):
    # Create the model
    model = get_model(config['model'])

    # Use DataParallel if more than one GPU is available
    if torch.cuda.device_count() > 1 and not config['device'] == 'cpu':
        model = nn.DataParallel(model)
        logger.info(f'Using {torch.cuda.device_count()} GPUs for training')

    # Move the model to GPU unless CPU training was requested
    if torch.cuda.is_available() and not config['device'] == 'cpu':
        model = model.cuda()

    # Log the number of learnable parameters
    logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')

    # Create loss criterion
    loss_criterion = get_loss_criterion(config)
    # Create evaluation metric
    eval_criterion = get_evaluation_metric(config)

    # Create data loaders
    loaders = get_train_loaders(config)

    # Create the optimizer
    optimizer = create_optimizer(config['optimizer'], model)

    # Create learning rate adjustment strategy
    lr_scheduler = create_lr_scheduler(config.get('lr_scheduler', None), optimizer)

    trainer_config = config['trainer']
    # Create tensorboard formatter
    tensorboard_formatter = get_tensorboard_formatter(trainer_config.pop('tensorboard_formatter', None))
    # Pop the checkpoint/weights paths so they are not forwarded as generic trainer kwargs
    resume = trainer_config.pop('resume', None)
    pre_trained = trainer_config.pop('pre_trained', None)

    return UNetTrainer(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, loss_criterion=loss_criterion,
                       eval_criterion=eval_criterion, loaders=loaders, tensorboard_formatter=tensorboard_formatter,
                       resume=resume, pre_trained=pre_trained, **trainer_config)


class UNetTrainer:
    """UNet trainer.

    Args:
        model (Unet3D): UNet 3D model to be trained
        optimizer (torch.optim.Optimizer): optimizer used for training
        lr_scheduler (torch.optim.lr_scheduler._LRScheduler): learning rate scheduler
            WARN: bear in mind that lr_scheduler.step() is invoked after every validation step
            (i.e. validate_after_iters) not after every epoch. So e.g. if one uses StepLR with step_size=30
            the learning rate will be adjusted after every 30 * validate_after_iters iterations.
        loss_criterion (callable): loss function
        eval_criterion (callable): used to compute training/validation metric (such as Dice, IoU, AP or Rand score)
            saving the best checkpoint is based on the result of this function on the validation set
        loaders (dict): 'train' and 'val' loaders
        checkpoint_dir (string): dir for saving checkpoints and tensorboard logs
        max_num_epochs (int): maximum number of epochs
        max_num_iterations (int): maximum number of iterations
        validate_after_iters (int): validate after that many iterations
        log_after_iters (int): number of iterations before logging to tensorboard
        validate_iters (int): number of validation iterations; if None, validate
            on the whole validation set
        eval_score_higher_is_better (bool): if True, higher eval scores are considered better
        best_eval_score (float): best validation score so far (higher is better)
        num_iterations (int): useful when loading the model from the checkpoint
        num_epoch (int): useful when loading the model from the checkpoint
        tensorboard_formatter (callable): converts a given batch of input/output/target images to a series of images
            that can be displayed in tensorboard
        skip_train_validation (bool): if True, eval_criterion is not evaluated on the training set (used mostly when
            evaluation is expensive)
    """

    def __init__(self, model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders, checkpoint_dir,
                 max_num_epochs, max_num_iterations, validate_after_iters=200, log_after_iters=100,
                 validate_iters=None, num_iterations=1, num_epoch=0, eval_score_higher_is_better=True,
                 tensorboard_formatter=None, skip_train_validation=False, resume=None, pre_trained=None, **kwargs):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = lr_scheduler
        self.loss_criterion = loss_criterion
        self.eval_criterion = eval_criterion
        self.loaders = loaders
        self.checkpoint_dir = checkpoint_dir
        self.max_num_epochs = max_num_epochs
        self.max_num_iterations = max_num_iterations
        self.validate_after_iters = validate_after_iters
        self.log_after_iters = log_after_iters
        self.validate_iters = validate_iters
        self.eval_score_higher_is_better = eval_score_higher_is_better

        logger.info(model)
        logger.info(f'eval_score_higher_is_better: {eval_score_higher_is_better}')

        # Initialize the best score with the worst possible value so that the first
        # validation result always becomes the new best
        if eval_score_higher_is_better:
            self.best_eval_score = float('-inf')
        else:
            self.best_eval_score = float('+inf')

        self.writer = SummaryWriter(log_dir=os.path.join(checkpoint_dir, 'logs'))

        assert tensorboard_formatter is not None, 'TensorboardFormatter must be provided'
        self.tensorboard_formatter = tensorboard_formatter

        self.num_iterations = num_iterations
        self.num_epochs = num_epoch
        self.skip_train_validation = skip_train_validation

        if resume is not None:
            # Resume training: restore model and optimizer state together with the training progress
            logger.info(f"Loading checkpoint '{resume}'...")
            state = utils.load_checkpoint(resume, self.model, self.optimizer)
            logger.info(
                f"Checkpoint loaded from '{resume}'. Epoch: {state['num_epochs']}. Iteration: {state['num_iterations']}. "
                f"Best val score: {state['best_eval_score']}."
            )
            self.best_eval_score = state['best_eval_score']
            self.num_iterations = state['num_iterations']
            self.num_epochs = state['num_epochs']
            self.checkpoint_dir = os.path.split(resume)[0]
        elif pre_trained is not None:
            # Use pre-trained weights: load only the model state, keep a fresh optimizer and counters
            logger.info(f"Loading pre-trained model from '{pre_trained}'...")
            utils.load_checkpoint(pre_trained, self.model, None)
            if 'checkpoint_dir' not in kwargs:
                self.checkpoint_dir = os.path.split(pre_trained)[0]

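    # fit() resumes the epoch loop from self.num_epochs, so training continues where it
    # left off when the trainer was restored from a checkpoint.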
    def fit(self):
        for _ in range(self.num_epochs, self.max_num_epochs):
            should_terminate = self.train()

            if should_terminate:
                logger.info('Stopping criterion is satisfied. Finishing training')
                return

            self.num_epochs += 1
        logger.info(f"Reached maximum number of epochs: {self.max_num_epochs}. Finishing training...")

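    # One "iteration" below is a single optimizer step on one batch. Every validate_after_iters
    # iterations the model is evaluated on the validation set, the scheduler is stepped and a
    # checkpoint is written; every log_after_iters iterations training stats and images are
    # sent to tensorboard.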
    def train(self):
        """Trains the model for 1 epoch.

        Returns:
            True if the training should be terminated immediately, False otherwise
        """
        train_losses = utils.RunningAverage()
        train_eval_scores = utils.RunningAverage()

        # set the model in training mode
        self.model.train()

        for t in self.loaders['train']:
            logger.info(f'Training iteration [{self.num_iterations}/{self.max_num_iterations}]. '
                        f'Epoch [{self.num_epochs}/{self.max_num_epochs - 1}]')

            input, target, weight = self._split_training_batch(t)

            output, loss = self._forward_pass(input, target, weight)

            train_losses.update(loss.item(), self._batch_size(input))

            # compute gradients and update parameters
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if self.num_iterations % self.validate_after_iters == 0:
                # set the model in eval mode
                self.model.eval()
                # evaluate on validation set
                eval_score = self.validate()
                # set the model back to training mode
                self.model.train()

                # adjust learning rate if necessary
                if isinstance(self.scheduler, ReduceLROnPlateau):
                    self.scheduler.step(eval_score)
                else:
                    self.scheduler.step()

                # log current learning rate in tensorboard
                self._log_lr()

                # remember best validation metric
                is_best = self._is_best_eval_score(eval_score)

                # save checkpoint
                self._save_checkpoint(is_best)

            if self.num_iterations % self.log_after_iters == 0:
                # compute eval criterion on the training batch
                if not self.skip_train_validation:
                    eval_score = self.eval_criterion(output, target)
                    train_eval_scores.update(eval_score.item(), self._batch_size(input))

                # log stats and images
                logger.info(
                    f'Training stats. Loss: {train_losses.avg}. Evaluation score: {train_eval_scores.avg}')
                self._log_stats('train', train_losses.avg, train_eval_scores.avg)
                self._log_images(input, target, output, 'train_')

            if self.should_stop():
                return True

            self.num_iterations += 1

        return False

    def should_stop(self):
        """
        Training will terminate if maximum number of iterations is exceeded or the learning rate drops below
        some predefined threshold (1e-6 in our case)
        """
        if self.max_num_iterations < self.num_iterations:
            logger.info(f'Maximum number of iterations {self.max_num_iterations} exceeded.')
            return True

        min_lr = 1e-6
        lr = self.optimizer.param_groups[0]['lr']
        if lr < min_lr:
            logger.info(f'Learning rate below the minimum {min_lr}.')
            return True

        return False

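    # Validation runs under torch.no_grad(); the caller (train()) is responsible for
    # switching the model to eval mode before calling validate() and back afterwards.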
    def validate(self):
        logger.info('Validating...')

        val_losses = utils.RunningAverage()
        val_scores = utils.RunningAverage()

        with torch.no_grad():
            for i, t in enumerate(self.loaders['val']):
                logger.info(f'Validation iteration {i}')

                input, target, weight = self._split_training_batch(t)

                output, loss = self._forward_pass(input, target, weight)
                val_losses.update(loss.item(), self._batch_size(input))

                if i % 100 == 0:
                    self._log_images(input, target, output, 'val_')

                eval_score = self.eval_criterion(output, target)
                val_scores.update(eval_score.item(), self._batch_size(input))

                if self.validate_iters is not None and self.validate_iters <= i:
                    break

            self._log_stats('val', val_losses.avg, val_scores.avg)
            logger.info(f'Validation finished. Loss: {val_losses.avg}. Evaluation score: {val_scores.avg}')
            return val_scores.avg

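    # Loaders are expected to yield either (input, target) or (input, target, weight) tuples;
    # every tensor in the batch is moved to the GPU (when available) before the forward pass.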
    def _split_training_batch(self, t):
        def _move_to_gpu(input):
            if isinstance(input, tuple) or isinstance(input, list):
                return tuple([_move_to_gpu(x) for x in input])
            else:
                if torch.cuda.is_available():
                    input = input.cuda(non_blocking=True)
                return input

        t = _move_to_gpu(t)
        weight = None
        if len(t) == 2:
            input, target = t
        else:
            input, target, weight = t
        return input, target, weight

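    # For 2D U-Nets the singleton z-dimension is squeezed from the input before the forward
    # pass and added back to the output, so 2D and 3D models can share the same training loop.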
    def _forward_pass(self, input, target, weight=None):
        if isinstance(self.model, UNet2D):
            # remove the singleton z-dimension from the input
            input = torch.squeeze(input, dim=-3)
            # forward pass
            output = self.model(input)
            # add the singleton z-dimension to the output
            output = torch.unsqueeze(output, dim=-3)
        else:
            # forward pass
            output = self.model(input)

        # compute the loss, optionally using the weight provided by the loader
        if weight is None:
            loss = self.loss_criterion(output, target)
        else:
            loss = self.loss_criterion(output, target, weight)

        return output, loss

    def _is_best_eval_score(self, eval_score):
        if self.eval_score_higher_is_better:
            is_best = eval_score > self.best_eval_score
        else:
            is_best = eval_score < self.best_eval_score

        if is_best:
            logger.info(f'Saving new best evaluation metric: {eval_score}')
            self.best_eval_score = eval_score

        return is_best

    def _save_checkpoint(self, is_best):
        # save the state_dict of the wrapped module when using nn.DataParallel, so the
        # checkpoint does not contain the `module.` prefix in layer names
        if isinstance(self.model, nn.DataParallel):
            state_dict = self.model.module.state_dict()
        else:
            state_dict = self.model.state_dict()

        last_file_path = os.path.join(self.checkpoint_dir, 'last_checkpoint.pytorch')
        logger.info(f"Saving checkpoint to '{last_file_path}'")

        utils.save_checkpoint({
            'num_epochs': self.num_epochs + 1,
            'num_iterations': self.num_iterations,
            'model_state_dict': state_dict,
            'best_eval_score': self.best_eval_score,
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, is_best, checkpoint_dir=self.checkpoint_dir)

    def _log_lr(self):
        lr = self.optimizer.param_groups[0]['lr']
        self.writer.add_scalar('learning_rate', lr, self.num_iterations)

    def _log_stats(self, phase, loss_avg, eval_score_avg):
        tag_value = {
            f'{phase}_loss_avg': loss_avg,
            f'{phase}_eval_score_avg': eval_score_avg
        }

        for tag, value in tag_value.items():
            self.writer.add_scalar(tag, value, self.num_iterations)

    def _log_params(self):
        logger.info('Logging model parameters and gradients')
        for name, value in self.model.named_parameters():
            self.writer.add_histogram(name, value.data.cpu().numpy(), self.num_iterations)
            self.writer.add_histogram(name + '/grad', value.grad.data.cpu().numpy(), self.num_iterations)

    def _log_images(self, input, target, prediction, prefix=''):
        # unwrap nn.DataParallel to access the underlying network's attributes
        if isinstance(self.model, nn.DataParallel):
            net = self.model.module
        else:
            net = self.model

        # apply the network's final activation (if any) so predictions are logged after activation
        if net.final_activation is not None:
            prediction = net.final_activation(prediction)

        inputs_map = {
            'inputs': input,
            'targets': target,
            'predictions': prediction
        }
        img_sources = {}
        for name, batch in inputs_map.items():
            if isinstance(batch, list) or isinstance(batch, tuple):
                for i, b in enumerate(batch):
                    img_sources[f'{name}{i}'] = b.data.cpu().numpy()
            else:
                img_sources[name] = batch.data.cpu().numpy()

        for name, batch in img_sources.items():
            for tag, image in self.tensorboard_formatter(name, batch):
                self.writer.add_image(prefix + tag, image, self.num_iterations)

    @staticmethod
    def _batch_size(input):
        if isinstance(input, list) or isinstance(input, tuple):
            return input[0].size(0)
        else:
            return input.size(0)