|
import sys |
|
import torch |
|
from tqdm import tqdm as tqdm |
|
from .meter import AverageValueMeter |
|
|
|
|
|
class Epoch:
    """Run a model over a dataloader for one epoch and track running averages.

    Subclasses implement :meth:`batch_update` (the per-batch step) and may
    override :meth:`on_epoch_start` (e.g. to switch train/eval mode).
    """

    def __init__(self, model, loss, metrics, stage_name, device="cpu", verbose=True):
        """
        Args:
            model: module to run; moved to ``device`` on construction.
            loss: loss callable; must provide ``.to(device)`` and a
                ``__name__`` attribute, which is used as its key in the logs.
            metrics: sequence of metric callables ``metric(y_pred, y)``; each
                must provide ``.to(device)`` and ``__name__`` (its log key).
                It is iterated more than once, so it should not be a one-shot
                iterator.
            stage_name: label used as the progress-bar description.
            device: device that the model, loss, metrics and every batch are
                moved to.
            verbose: when False the progress bar is disabled and no postfix
                is written.
        """
        self.model = model
        self.loss = loss
        self.metrics = metrics
        self.stage_name = stage_name
        self.verbose = verbose
        self.device = device

        self._to_device()

    def _to_device(self):
        """Move the model, the loss and every metric to ``self.device``."""
        self.model.to(self.device)
        self.loss.to(self.device)
        for metric in self.metrics:
            metric.to(self.device)

    def _format_logs(self, logs):
        """Render ``logs`` as ``"name - value, ..."`` with 4 significant digits."""
        str_logs = ["{} - {:.4}".format(k, v) for k, v in logs.items()]
        return ", ".join(str_logs)

    def batch_update(self, x, y):
        """Process one batch and return ``(loss, prediction)``.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def on_epoch_start(self):
        """Hook called once at the start of :meth:`run`; a no-op by default."""

    def run(self, dataloader):
        """Iterate over ``dataloader`` once and return the averaged logs.

        Each item yielded by ``dataloader`` must unpack into ``(x, y)``. Both
        are moved to ``self.device`` and passed to :meth:`batch_update`.

        Returns:
            dict mapping the loss's and each metric's ``__name__`` to the
            ``mean`` of its ``AverageValueMeter`` after the last batch. The
            dict is empty if ``dataloader`` yields nothing.
        """
        self.on_epoch_start()

        logs = {}
        loss_meter = AverageValueMeter()
        metrics_meters = {metric.__name__: AverageValueMeter() for metric in self.metrics}

        with tqdm(
            dataloader,
            desc=self.stage_name,
            file=sys.stdout,
            disable=not self.verbose,
        ) as iterator:
            for x, y in iterator:
                x, y = x.to(self.device), y.to(self.device)
                loss, y_pred = self.batch_update(x, y)

                # Detach first so the host copy is not tracked by autograd.
                loss_value = loss.detach().cpu().numpy()
                loss_meter.add(loss_value)
                logs[self.loss.__name__] = loss_meter.mean

                # Metrics are for reporting only. During training y_pred may
                # require grad, so compute them without recording a graph.
                with torch.no_grad():
                    for metric_fn in self.metrics:
                        metric_value = metric_fn(y_pred, y).detach().cpu().numpy()
                        metrics_meters[metric_fn.__name__].add(metric_value)
                logs.update({k: v.mean for k, v in metrics_meters.items()})

                if self.verbose:
                    iterator.set_postfix_str(self._format_logs(logs))

        return logs
|
|
|
|
|
class TrainEpoch(Epoch):
    """Training epoch: the model is in train mode and the optimizer steps once per batch."""

    def __init__(self, model, loss, metrics, optimizer, device="cpu", verbose=True):
        """Take the same arguments as :class:`Epoch` plus ``optimizer``.

        ``optimizer`` is used to update the model after each batch. The stage
        name is fixed to ``"train"``.
        """
        super().__init__(
            model=model,
            loss=loss,
            metrics=metrics,
            stage_name="train",
            device=device,
            verbose=verbose,
        )
        self.optimizer = optimizer

    def on_epoch_start(self):
        """Put the model in training mode."""
        self.model.train()

    def batch_update(self, x, y):
        """Run forward, backward and one optimizer step on a single batch.

        Returns:
            ``(loss, prediction)``, where ``prediction`` is the raw model output.
        """
        self.optimizer.zero_grad()
        # Call the module itself (not ``.forward``) so registered hooks run.
        prediction = self.model(x)
        loss = self.loss(prediction, y)
        loss.backward()
        self.optimizer.step()
        return loss, prediction
|
|
|
|
|
class ValidEpoch(Epoch):
    """Validation epoch: the model is in eval mode and no gradients are computed."""

    def __init__(self, model, loss, metrics, device="cpu", verbose=True):
        """Take the same arguments as :class:`Epoch`.

        The stage name is fixed to ``"valid"``.
        """
        super().__init__(
            model=model,
            loss=loss,
            metrics=metrics,
            stage_name="valid",
            device=device,
            verbose=verbose,
        )

    def on_epoch_start(self):
        """Put the model in evaluation mode."""
        self.model.eval()

    def batch_update(self, x, y):
        """Compute the prediction and loss for one batch without tracking gradients.

        Returns:
            ``(loss, prediction)``, where ``prediction`` is the raw model output.
        """
        with torch.no_grad():
            # Call the module itself (not ``.forward``) so registered hooks run.
            prediction = self.model(x)
            loss = self.loss(prediction, y)
        return loss, prediction
|
|