# BlinkBlur / blurriness_estimation/calculate_metrics.py
# Uploaded by ArnoBen ("Upload 31 files", commit e99cf64), 5.03 kB
import os
from typing import Tuple, List
import cv2
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
from sklearn import metrics
from tqdm import tqdm
from blurriness_estimation.blurriness_estimator import BlurrinessEstimator
from config import BlurrinessConfig
def get_paths_and_labels(cfg: "BlurrinessConfig", data_dir: str = 'data/blur/') -> Tuple[List[str], np.ndarray]:
    """
    Returns the image paths and corresponding labels for the blur dataset.

    The dataset root must contain two sub-directories, ``sharp`` and ``blurry``.
    Sharp images come first in the returned list, followed by blurry ones.
    Within each class, file names are sorted so the ordering is deterministic
    (``os.listdir`` returns entries in arbitrary order). This matters because
    predictions are cached to disk and later re-aligned with these paths by
    position. NOTE: a cache written with the previous unsorted order must be
    regenerated.

    Args:
        cfg: blurriness config. Its ``data_dir`` attribute is set to ``data_dir``.
        data_dir: root directory of the blur dataset

    Returns:
        Tuple[List[str], np.ndarray]: image paths and corresponding labels
        (0: sharp, 1: blurry)
    """
    # Kept for backward compatibility: this function has always written this field.
    cfg.data_dir = data_dir

    def _list_files(directory: str) -> List[str]:
        # Regular files only (sub-directories cannot be read as images), in a stable order.
        candidates = (os.path.join(directory, name) for name in sorted(os.listdir(directory)))
        return [path for path in candidates if os.path.isfile(path)]

    sharp_paths = _list_files(os.path.join(data_dir, "sharp"))
    blurry_paths = _list_files(os.path.join(data_dir, "blurry"))

    files = sharp_paths + blurry_paths
    labels = np.hstack([np.zeros(len(sharp_paths)), np.ones(len(blurry_paths))])
    return files, labels
def predict(files: List[str], cfg: BlurrinessConfig) -> np.ndarray:
    """
    Calculates the variance of the laplacian for the given files.

    Each image is loaded with OpenCV and converted from BGR to RGB. It is then
    downscaled by a factor of 4 in each dimension using area interpolation and
    passed to ``BlurrinessEstimator``. The resulting values are saved to
    ``blurriness_estimation/preds.npy`` so later runs can reuse them.

    NOTE(review): earlier versions swapped width/height and did not actually
    apply INTER_AREA. Values computed now can differ from an older cached
    preds.npy, and possibly from a threshold tuned on it. Confirm
    ``cfg.threshold`` still fits.

    Args:
        files: paths of images to process
        cfg: blurriness config

    Returns:
        np.ndarray: one variance-of-laplacian value per file, in the order of ``files``

    Raises:
        ValueError: if OpenCV cannot decode one of the files as an image
    """
    estimator = BlurrinessEstimator(cfg)
    pred = []
    for file in tqdm(files):
        bgr = cv2.imread(file)
        if bgr is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise ValueError(f"Could not read image: {file}")
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]
        # cv2.resize takes dsize as (width, height). The interpolation flag must be
        # passed by keyword: the third positional parameter is `dst`.
        # max(1, ...) avoids a zero-sized target for images smaller than 4 px.
        new_size = (max(1, w // 4), max(1, h // 4))
        img = cv2.resize(img, new_size, interpolation=cv2.INTER_AREA)
        _, vol = estimator(img)  # vol: variance of laplacian
        pred.append(vol)
    pred = np.array(pred)
    np.save('blurriness_estimation/preds.npy', pred)
    return pred
def display_statistics(true: np.ndarray, pred: np.ndarray, cfg: BlurrinessConfig):
"""
Displays the distribution of the variance of the laplacian in out test dataset
Args:
true: true labels (0: sharp, 1: blurry)
pred: variance of laplacian
cfg: blurriness config (to get the threshold selected)
Returns:
"""
pred = np.copy(pred).clip(0, 1000)
data = pd.DataFrame.from_dict({'pred': pred, 'true': true})
priors = (1 - 350/1050, 1 - 700/1050)
weights = np.hstack([priors[0] * np.ones(350), priors[1] * np.ones(700)])
sns.kdeplot(data, x='pred', hue='true', fill=True, weights=weights)
plt.vlines([cfg.threshold], 0, plt.ylim()[1], colors=['r'], linestyles=['--'])
plt.legend(labels=['sharp', 'blurry'])
plt.xlabel('Variance of laplacian')
plt.title("Distribution of Variance of Laplacian\nfor blurry vs sharp images")
def display_metrics(true: np.ndarray, pred: np.ndarray):
"""
Displays Roc curve, precision-recall and confusion matrix
Args:
true: true labels (0: sharp, 1: blurry)
pred: variance of laplacian
"""
# We need to somehow convert the laplacian variance value to a range between 0 and 1 to easily use sklearn tools.
# Let's clip this value to 1000 then divide by 1000.
# Then do 1 - pred, because we want 1 to be "blur" since it's a blurriness detector, not a sharpness detector :)
pred = 1 - pred.clip(0, 1000) / 1000
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
fig.suptitle("Blurriness metrics\n0: Sharp, 1: Blurry")
# Roc Curve
metrics.RocCurveDisplay.from_predictions(true, pred, ax=axes[0])
# Precision Recall
precision, recall, thresholds = metrics.precision_recall_curve(true, pred)
average_precision = metrics.average_precision_score(true, pred)
metrics.PrecisionRecallDisplay(precision=precision, recall=recall, average_precision=average_precision).plot(ax=axes[1])
threshold_idx = int(len(precision) * 0.3)
threshold = thresholds[threshold_idx]
variance_threshold = (1 - threshold) * 1000
# Confusion matrix
metrics.ConfusionMatrixDisplay.from_predictions(true, pred > threshold, cmap='Blues', ax=axes[2])
# Plots details
axes[1].scatter(recall[threshold_idx], precision[threshold_idx], 50, 'r', 'X')
title = f'Threshold: {threshold:.2f}\n(variance: {variance_threshold:.0f})'
axes[2].set_title(title)
plt.tight_layout()
plt.show()
# Classification report
print("\n", metrics.classification_report(true, pred > threshold))
def calculate_metrics():
    """
    Calculates and plots various metrics about the task
    (ROC curve, precision/recall, distribution, confusion matrix).

    Predictions are read from ``blurriness_estimation/preds.npy`` when that file
    exists and holds exactly one value per image. Otherwise they are recomputed
    with ``predict``.
    """
    cache_path = 'blurriness_estimation/preds.npy'
    cfg = BlurrinessConfig()
    files, true = get_paths_and_labels(cfg)

    pred = None
    if os.path.exists(cache_path):
        # It takes a while to compute all the predictions, so they are cached on disk.
        cached = np.load(cache_path)
        # A stale cache (images added/removed since it was written) would be
        # misaligned with `files`; only reuse it when the sizes agree.
        if cached.shape == (len(files),):
            pred = cached
    if pred is None:
        pred = predict(files, cfg)

    display_statistics(true, pred, cfg)
    display_metrics(true, pred)
if __name__ == "__main__":
    # Script entry point: compute and plot all blurriness metrics.
    calculate_metrics()