# NamedCurves/utils/evaluator.py
import logging

import torch
from torchmetrics import PeakSignalNoiseRatio as PSNR
from torchmetrics import StructuralSimilarityIndexMeasure as SSIM
from lpips import LPIPS

from utils.deltaE import deltaEab, deltaE00


class Evaluator:
    """Computes image-quality metrics over a dataloader and tracks the best checkpoint."""

    def __init__(self, dataloader, metrics, split_name, log_dirpath, best_metric):
        self.dataloader = dataloader
        self._create_metrics(metrics)
        self.split_name = split_name
        self.log_dirpath = log_dirpath
        self.best_metric = best_metric
        # Starting at 0 assumes the tracked metric is higher-is-better (e.g. PSNR/SSIM).
        self.best_value = 0

    def _create_metrics(self, metrics):
        """Instantiate the requested metric modules and zero their running sums."""
        self.metrics = {}
        self.cumulative_values = {}
        for metric in metrics:
            if metric.type == 'PSNR':
                self.metrics['PSNR'] = PSNR(**metric.params).cuda()
                self.cumulative_values['PSNR'] = 0
            elif metric.type == 'SSIM':
                self.metrics['SSIM'] = SSIM(**metric.params).cuda()
                self.cumulative_values['SSIM'] = 0
            elif metric.type == 'LPIPS':
                self.metrics['LPIPS'] = LPIPS(**metric.params).cuda()
                self.cumulative_values['LPIPS'] = 0
            elif metric.type == 'deltaEab':
                self.metrics['deltaEab'] = deltaEab()
                self.cumulative_values['deltaEab'] = 0
            elif metric.type == 'deltaE00':
                self.metrics['deltaE00'] = deltaE00()
                self.cumulative_values['deltaE00'] = 0
            else:
                raise NotImplementedError(f"Metric {metric.type} not implemented")

    def _compute_metrics(self, input_image, target_image):
        """Evaluate every metric on the current batch and accumulate the results."""
        for name, metric in self.metrics.items():
            self.cumulative_values[name] += metric(input_image, target_image)

    def _compute_average_metrics(self):
        """Return the per-metric average over all batches in the dataloader."""
        avg_metrics = {}
        for name, value in self.cumulative_values.items():
            avg_metrics[name] = float(value / len(self.dataloader))
        return avg_metrics

    def _reset_metrics(self):
        """Zero the running sums before a new evaluation pass."""
        for metric in self.metrics:
            self.cumulative_values[metric] = 0

    def __call__(self, model, save_results=True):
        """Run a full evaluation pass and checkpoint the model if the tracked metric improves."""
        model.eval()
        self._reset_metrics()
        with torch.no_grad():
            for data in self.dataloader:
                input_image, target_image, name = data['input_image'], data['target_image'], data['name']
                self._compute_metrics(input_image.cuda(), target_image.cuda())
        avg_metrics = self._compute_average_metrics()
        logging.info(f"{self.split_name} metrics: " +
                     ", ".join([f'{key}: {value:.4f}' for key, value in avg_metrics.items()]))
        # Save when the tracked (higher-is-better) metric improves on the best value so far.
        if (avg_metrics[self.best_metric] > self.best_value) and save_results:
            self.best_value = avg_metrics[self.best_metric]
            torch.save({'model_state_dict': model.state_dict(), **avg_metrics},
                       f"{self.log_dirpath}/{self.split_name}_best_model.pth")
            logging.info(f"New best model saved at {self.log_dirpath}/{self.split_name}_best_model.pth")