|
import os |
|
from typing import Tuple, List |
|
|
|
import cv2 |
|
import matplotlib.pyplot as plt |
|
import pandas as pd |
|
import numpy as np |
|
import seaborn as sns |
|
from sklearn import metrics |
|
from tqdm import tqdm |
|
|
|
from blurriness_estimation.blurriness_estimator import BlurrinessEstimator |
|
from config import BlurrinessConfig |
|
|
|
|
|
def _list_image_files(directory: str) -> List[str]:
    """Returns the sorted paths of the regular files directly inside `directory`."""
    paths = (os.path.join(directory, filename) for filename in sorted(os.listdir(directory)))
    # Sub-directories are skipped: they cannot be decoded as images later on.
    return [path for path in paths if os.path.isfile(path)]


def get_paths_and_labels(cfg: BlurrinessConfig) -> Tuple[List[str], np.ndarray]:
    """
    Returns the images paths and corresponding labels for the blur dataset.

    The dataset directory must contain a "sharp" and a "blurry" sub-directory.
    Sharp images come first in the result (label 0), blurry images after (label 1).
    Within each class the paths are sorted, so the order is reproducible between runs.

    Args:
        cfg: blurriness config; its `data_dir` attribute is overwritten with 'data/blur/'

    Returns:
        Tuple[List[str], np.ndarray]: images paths and corresponding labels
        (0: sharp, 1: blurry), aligned index by index

    Raises:
        OSError: if one of the class directories cannot be listed
    """
    # NOTE: the dataset location is forced here, whatever value the config held before.
    cfg.data_dir = 'data/blur/'
    sharp_paths = _list_image_files(os.path.join(cfg.data_dir, "sharp"))
    blurry_paths = _list_image_files(os.path.join(cfg.data_dir, "blurry"))
    files = sharp_paths + blurry_paths
    labels = np.hstack([np.zeros(len(sharp_paths)), np.ones(len(blurry_paths))])
    return files, labels
|
|
|
|
|
def predict(files: List[str], cfg: BlurrinessConfig) -> np.ndarray:
    """
    Calculates the variance of the laplacian for the given files.

    Each image is read, converted from BGR to RGB, downscaled by a factor of 4 on
    both axes and passed to a BlurrinessEstimator. The collected values are also
    saved to 'blurriness_estimation/preds.npy'.

    Args:
        files: paths of images to process
        cfg: blurriness config

    Returns:
        np.ndarray: predictions as variance of laplacian values, one per input file

    Raises:
        ValueError: if an image cannot be read
    """
    estimator = BlurrinessEstimator(cfg)
    pred = []
    for file in tqdm(files):
        bgr = cv2.imread(file)
        if bgr is None:
            # cv2.imread reports a missing/undecodable file by returning None.
            # Skipping it would misalign predictions and labels, so fail loudly.
            raise ValueError(f"Could not read image: {file}")
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]
        # cv2.resize takes dsize as (width, height). The interpolation flag must be
        # passed by keyword: the third positional argument is `dst`.
        dsize = (max(1, w // 4), max(1, h // 4))
        img = cv2.resize(img, dsize, interpolation=cv2.INTER_AREA)
        # The estimator returns a pair; only the second element (the value) is kept.
        _, vol = estimator(img)
        pred.append(vol)
    pred = np.array(pred)
    np.save('blurriness_estimation/preds.npy', pred)
    return pred
|
|
|
|
|
def display_statistics(true: np.ndarray, pred: np.ndarray, cfg: BlurrinessConfig):
    """
    Displays the distribution of the variance of the laplacian in our test dataset.

    One density is drawn per class. Each sample is weighted by
    `1 - class_count / total_count`, so that both classes carry the same total
    weight even when the dataset is imbalanced. The threshold from the config is
    drawn as a dashed red vertical line. The figure is not shown here.

    Args:
        true: true labels (0: sharp, 1: blurry)
        pred: variance of laplacian
        cfg: blurriness config (to get the threshold selected)

    Returns:
        None

    Raises:
        ValueError: if `true` is empty
    """
    true = np.asarray(true)
    n_total = len(true)
    if n_total == 0:
        raise ValueError("Cannot display statistics for an empty dataset")

    # Class sizes are derived from the labels instead of being hard-coded.
    is_sharp = true == 0
    n_sharp = int(is_sharp.sum())
    n_blurry = n_total - n_sharp
    weights = np.where(is_sharp, 1 - n_sharp / n_total, 1 - n_blurry / n_total)

    data = pd.DataFrame({
        # Values above 1000 are clipped to keep the x axis readable.
        'pred': np.clip(pred, 0, 1000),
        # Class names live in the data so the legend cannot be mislabelled by
        # the order in which the densities happen to be drawn.
        'class': np.where(is_sharp, 'sharp', 'blurry'),
        'weight': weights,
    })
    ax = sns.kdeplot(data=data, x='pred', hue='class', hue_order=['sharp', 'blurry'],
                     fill=True, weights='weight')
    legend = ax.get_legend()
    if legend is not None:
        # Drop the automatic legend title (the hue column name).
        legend.set_title("")

    plt.vlines([cfg.threshold], 0, plt.ylim()[1], colors=['r'], linestyles=['--'])
    plt.xlabel('Variance of laplacian')
    plt.title("Distribution of Variance of Laplacian\nfor blurry vs sharp images")
|
|
|
|
|
def display_metrics(true: np.ndarray, pred: np.ndarray):
    """
    Displays Roc curve, precision-recall and confusion matrix.

    The variance of the laplacian is first mapped to a score in [0, 1] where a
    higher score means "more blurry" (score = 1 - clip(variance, 0, 1000) / 1000).
    An operating point is taken 30% of the way along the precision-recall arrays.
    It is marked on the precision-recall plot and used for the confusion matrix
    and for the classification report printed after the figure is shown.

    Args:
        true: true labels (0: sharp, 1: blurry)
        pred: variance of laplacian
    """
    # Low variance means blurry, so invert and normalise to get a blurriness score.
    pred = 1 - pred.clip(0, 1000) / 1000
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    fig.suptitle("Blurriness metrics\n0: Sharp, 1: Blurry")

    # ROC curve
    metrics.RocCurveDisplay.from_predictions(true, pred, ax=axes[0])

    # Precision-recall curve
    precision, recall, thresholds = metrics.precision_recall_curve(true, pred)
    average_precision = metrics.average_precision_score(true, pred)
    metrics.PrecisionRecallDisplay(
        precision=precision, recall=recall, average_precision=average_precision
    ).plot(ax=axes[1])

    # `thresholds` is one element shorter than `precision`; clamp to stay in range.
    threshold_idx = min(int(len(precision) * 0.3), len(thresholds) - 1)
    threshold = thresholds[threshold_idx]
    variance_threshold = (1 - threshold) * 1000

    # precision[i] / recall[i] describe predictions with score >= thresholds[i],
    # so the same inclusive comparison is used for the hard decisions below.
    is_blurry = pred >= threshold

    # Confusion matrix at the selected operating point
    metrics.ConfusionMatrixDisplay.from_predictions(true, is_blurry, cmap='Blues', ax=axes[2])

    # Mark the selected operating point on the precision-recall curve
    axes[1].scatter(recall[threshold_idx], precision[threshold_idx], 50, 'r', 'X')
    title = f'Threshold: {threshold:.2f}\n(variance: {variance_threshold:.0f})'
    axes[2].set_title(title)
    plt.tight_layout()
    plt.show()

    print("\n", metrics.classification_report(true, is_blurry))
|
|
|
|
|
def calculate_metrics():
    """
    Calculates and plots various metrics about the task
    (ROC curve, precision/recall, distribution, classification matrix).

    Predictions are loaded from 'blurriness_estimation/preds.npy' when that file
    exists and holds one value per dataset image; otherwise they are recomputed.
    """
    cfg = BlurrinessConfig()
    files, true = get_paths_and_labels(cfg)

    cache_path = 'blurriness_estimation/preds.npy'
    pred = None
    if os.path.exists(cache_path):
        cached = np.load(cache_path)
        # A cache built from a different dataset would not line up with the labels.
        if cached.ndim >= 1 and cached.shape[0] == len(files):
            pred = cached
    if pred is None:
        pred = predict(files, cfg)

    display_statistics(true, pred, cfg)
    display_metrics(true, pred)
|
|
|
|
|
# Script entry point: run the full evaluation only when executed directly.
if __name__ == "__main__":
    calculate_metrics()
|
|