import sys

import torch
from tqdm import tqdm

from .meter import AverageValueMeter


class Epoch:
    """Runs a single pass over a dataloader, averaging the loss and each metric."""

    def __init__(self, model, loss, metrics, stage_name, device="cpu", verbose=True):
        self.model = model
        self.loss = loss
        self.metrics = metrics
        self.stage_name = stage_name
        self.verbose = verbose
        self.device = device

        self._to_device()

    def _to_device(self):
        # move the model, loss, and all metric modules to the target device
        self.model.to(self.device)
        self.loss.to(self.device)
        for metric in self.metrics:
            metric.to(self.device)

    def _format_logs(self, logs):
        str_logs = ["{} - {:.4}".format(k, v) for k, v in logs.items()]
        s = ", ".join(str_logs)
        return s

    def batch_update(self, x, y):
        # processes a single batch; implemented by TrainEpoch / ValidEpoch
        raise NotImplementedError

    def on_epoch_start(self):
        # hook called once before iterating; subclasses switch train/eval mode here
        pass

    def run(self, dataloader):
        self.on_epoch_start()

        logs = {}
        loss_meter = AverageValueMeter()
        metrics_meters = {
            metric.__name__: AverageValueMeter() for metric in self.metrics
        }

        with tqdm(
            dataloader,
            desc=self.stage_name,
            file=sys.stdout,
            disable=not self.verbose,
        ) as iterator:
            for x, y in iterator:
                x, y = x.to(self.device), y.to(self.device)
                loss, y_pred = self.batch_update(x, y)

                # update loss logs
                loss_value = loss.cpu().detach().numpy()
                loss_meter.add(loss_value)
                loss_logs = {self.loss.__name__: loss_meter.mean}
                logs.update(loss_logs)

                # update metrics logs
                for metric_fn in self.metrics:
                    metric_value = metric_fn(y_pred, y).cpu().detach().numpy()
                    metrics_meters[metric_fn.__name__].add(metric_value)
                metrics_logs = {k: v.mean for k, v in metrics_meters.items()}
                logs.update(metrics_logs)

                if self.verbose:
                    s = self._format_logs(logs)
                    iterator.set_postfix_str(s)

        return logs


class TrainEpoch(Epoch):
    """Training epoch: forward pass, loss, backward pass, and optimizer step per batch."""

    def __init__(self, model, loss, metrics, optimizer, device="cpu", verbose=True):
        super().__init__(
            model=model,
            loss=loss,
            metrics=metrics,
            stage_name="train",
            device=device,
            verbose=verbose,
        )
        self.optimizer = optimizer

    def on_epoch_start(self):
        self.model.train()

    def batch_update(self, x, y):
        self.optimizer.zero_grad()
        prediction = self.model.forward(x)
        loss = self.loss(prediction, y)
        loss.backward()
        self.optimizer.step()
        return loss, prediction


class ValidEpoch(Epoch):
    """Validation epoch: forward pass and loss only, with gradients disabled."""

    def __init__(self, model, loss, metrics, device="cpu", verbose=True):
        super().__init__(
            model=model,
            loss=loss,
            metrics=metrics,
            stage_name="valid",
            device=device,
            verbose=verbose,
        )

    def on_epoch_start(self):
        self.model.eval()

    def batch_update(self, x, y):
        with torch.no_grad():
            prediction = self.model.forward(x)
            loss = self.loss(prediction, y)
        return loss, prediction
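
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, executed solely when this module is
# run as a script, e.g. via `python -m <package>.train` so the relative import
# above resolves). The tiny model, loss wrapper, metric, and random data below
# are assumptions made for demonstration and are not part of this module. Note
# that the loss and every metric must be an `nn.Module` exposing a `__name__`
# attribute: both are moved to `device`, and `__name__` is used as the log key.
if __name__ == "__main__":
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    class NamedMSE(nn.MSELoss):
        __name__ = "mse_loss"  # used by Epoch.run() as the loss log key

    class MAE(nn.Module):
        __name__ = "mae"  # used by Epoch.run() as the metric log key

        def forward(self, y_pred, y):
            return (y_pred - y).abs().mean()

    model = nn.Linear(4, 1)
    loader = DataLoader(
        TensorDataset(torch.randn(64, 4), torch.randn(64, 1)), batch_size=8
    )

    train_epoch = TrainEpoch(
        model,
        loss=NamedMSE(),
        metrics=[MAE()],
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    )
    valid_epoch = ValidEpoch(model, loss=NamedMSE(), metrics=[MAE()])

    for _ in range(2):
        train_logs = train_epoch.run(loader)  # e.g. {"mse_loss": ..., "mae": ...}
        valid_logs = valid_epoch.run(loader)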