import logging


class Trainer:
    """Run a training loop with optional periodic validation and a final test pass.

    Args:
        model: Callable network. It is called as ``model(inputs, return_backbone=True)``
            and must return a ``(prediction, backbone_features)`` pair. Its
            ``train()`` method is called at the start of every epoch.
        optimizer: Object exposing ``zero_grad()`` and ``step()``.
        criterion: Loss callable invoked as ``criterion(x_backbone, prediction, target)``.
            Its result must support ``backward()`` and ``item()``.
        scheduler: Optional LR scheduler, stepped once per *batch* (not per epoch).
            Pass ``None`` to disable.
        train_loader: Iterable of dict batches with at least the keys
            ``'input_image'`` and ``'target_image'``.
        valid_evaluator: Optional callable ``valid_evaluator(model)``. It runs every
            ``config_train.valid_every`` epochs.
        test_evaluator: Optional callable ``test_evaluator(model, save_results=...)``.
            It runs once after the epoch loop.
        config_train: Object with an ``epochs`` attribute. It also needs
            ``valid_every`` when ``valid_evaluator`` is given.
        config_eval: Stored on the instance; not read by this class.
        device: Device that input/target batches are moved to. The default
            ``"cuda"`` matches the former hard-coded ``.cuda()`` call.
    """

    def __init__(self, model, optimizer, criterion, scheduler, train_loader,
                 valid_evaluator, test_evaluator, config_train, config_eval,
                 device="cuda"):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.scheduler = scheduler
        self.train_loader = train_loader
        self.valid_evaluator = valid_evaluator
        self.test_evaluator = test_evaluator
        self.config_train = config_train
        self.config_eval = config_eval
        self.device = device

    def _train_step(self, input_image, target_image):
        """Run one optimization step on a single batch and return its scalar loss."""
        self.optimizer.zero_grad()
        prediction, x_backbone = self.model(input_image.to(self.device), return_backbone=True)
        loss = self.criterion(x_backbone, prediction, target_image.to(self.device))
        loss.backward()
        self.optimizer.step()
        # The scheduler is advanced per batch, not per epoch.
        if self.scheduler is not None:
            self.scheduler.step()
        return loss.item()

    def _train_epoch(self):
        """Train for one pass over ``train_loader`` and return the mean batch loss.

        Raises:
            ValueError: if ``train_loader`` yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        # Count the batches actually seen. This avoids relying on len(), which
        # some iterable loaders do not implement.
        for batch in self.train_loader:
            total_loss += self._train_step(batch['input_image'], batch['target_image'])
            num_batches += 1
        if num_batches == 0:
            raise ValueError("train_loader yielded no batches; cannot compute epoch loss")
        return total_loss / num_batches

    def train(self):
        """Train for ``config_train.epochs`` epochs, validating periodically, then test."""
        epochs = self.config_train.epochs
        for epoch in range(1, epochs + 1):
            epoch_loss = self._train_epoch()
            logging.info("Epoch %s/%s | Loss: %s", epoch, epochs, epoch_loss)
            if self.valid_evaluator is not None and epoch % self.config_train.valid_every == 0:
                self.valid_evaluator(self.model)
        # NOTE(review): the original file was flattened onto one line, so the nesting
        # of the test call was lost. It is placed after the epoch loop here — confirm.
        if self.test_evaluator is not None:
            # Test results are saved only when no validation evaluator was supplied.
            self.test_evaluator(self.model, save_results=self.valid_evaluator is None)
        logging.info("Training finished.")