import os
from typing import List, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn import metrics
from tqdm import tqdm

from blurriness_estimation.blurriness_estimator import BlurrinessEstimator
from config import BlurrinessConfig


def get_paths_and_labels(cfg: BlurrinessConfig) -> Tuple[List[str], np.ndarray]:
    """
    Returns the image paths and corresponding labels for the blur dataset

    Args:
        cfg: blurriness config

    Returns:
        Tuple[List[str], np.ndarray]: image paths and corresponding labels for the blur dataset
    """
    cfg.data_dir = 'data/blur/'  # pin the evaluation dataset location
    sharp_dir = os.path.join(cfg.data_dir, "sharp")
    blurry_dir = os.path.join(cfg.data_dir, "blurry")
    sharp_paths = [os.path.join(sharp_dir, filename) for filename in os.listdir(sharp_dir)]
    blurry_paths = [os.path.join(blurry_dir, filename) for filename in os.listdir(blurry_dir)]
    files = sharp_paths + blurry_paths
    # 0: sharp, 1: blurry
    labels = np.hstack([np.zeros(len(sharp_paths)), np.ones(len(blurry_paths))])
    return files, labels


def predict(files: List[str], cfg: BlurrinessConfig) -> np.ndarray:
    """
    Calculates the variance of the Laplacian for the given files

    Args:
        files: paths of images to process
        cfg: blurriness config

    Returns:
        np.ndarray: predictions as variance of Laplacian values
    """
    estimator = BlurrinessEstimator(cfg)
    pred = []
    for file in tqdm(files):
        img = cv2.cvtColor(cv2.imread(file), cv2.COLOR_BGR2RGB)
        h, w, _ = img.shape
        # cv2.resize expects dsize as (width, height); downscale 4x for speed
        img = cv2.resize(img, (w // 4, h // 4), interpolation=cv2.INTER_AREA)
        is_blurry, vol = estimator(img)  # vol: variance of Laplacian
        pred.append(vol)
    pred = np.array(pred)
    # Cache the predictions so later runs can skip this loop
    np.save('blurriness_estimation/preds.npy', pred)
    return pred


def display_statistics(true: np.ndarray, pred: np.ndarray, cfg: BlurrinessConfig):
    """
    Displays the distribution of the variance of the Laplacian in our test dataset

    Args:
        true: true labels (0: sharp, 1: blurry)
        pred: variance of Laplacian
        cfg: blurriness config (to get the threshold selected)
    """
    # Clip extreme variance values so the KDE stays readable
    pred = np.copy(pred).clip(0, 1000)
    data = pd.DataFrame.from_dict({'pred': pred, 'true': true})
    # The dataset holds 350 sharp and 700 blurry images (1050 total); weight each
    # sample by the inverse class frequency so both classes contribute equally
    priors = (1 - 350 / 1050, 1 - 700 / 1050)
    weights = np.hstack([priors[0] * np.ones(350), priors[1] * np.ones(700)])
    sns.kdeplot(data, x='pred', hue='true', fill=True, weights=weights)
    # Mark the selected decision threshold
    plt.vlines([cfg.threshold], 0, plt.ylim()[1], colors=['r'], linestyles=['--'])
    plt.legend(labels=['sharp', 'blurry'])
    plt.xlabel('Variance of Laplacian')
    plt.title("Distribution of Variance of Laplacian\nfor blurry vs sharp images")


def display_metrics(true: np.ndarray, pred: np.ndarray):
    """
    Displays ROC curve, precision-recall curve and confusion matrix

    Args:
        true: true labels (0: sharp, 1: blurry)
        pred: variance of Laplacian
    """
    # We need to somehow convert the Laplacian variance value to a range between 0 and 1
    # to easily use sklearn tools.
    # Let's clip this value to 1000 then divide by 1000.
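    # (Worked example of that first step: a variance of 250 becomes 0.25, and
    # anything at or above 1000 saturates to 1.0.)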
    # Then do 1 - pred, because we want 1 to be "blurry" since it's a blurriness
    # detector, not a sharpness detector :)
    pred = 1 - pred.clip(0, 1000) / 1000

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    fig.suptitle("Blurriness metrics\n0: Sharp, 1: Blurry")

    # ROC curve
    metrics.RocCurveDisplay.from_predictions(true, pred, ax=axes[0])

    # Precision-recall curve
    precision, recall, thresholds = metrics.precision_recall_curve(true, pred)
    average_precision = metrics.average_precision_score(true, pred)
    metrics.PrecisionRecallDisplay(precision=precision, recall=recall,
                                   average_precision=average_precision).plot(ax=axes[1])
    # Pick an operating point 30% along the precision-recall curve, then map the
    # normalized score back to a variance-of-Laplacian value
    threshold_idx = int(len(precision) * 0.3)
    threshold = thresholds[threshold_idx]
    variance_threshold = (1 - threshold) * 1000

    # Confusion matrix at the chosen threshold
    metrics.ConfusionMatrixDisplay.from_predictions(true, pred > threshold, cmap='Blues', ax=axes[2])

    # Plot details: mark the chosen operating point on the precision-recall curve
    axes[1].scatter(recall[threshold_idx], precision[threshold_idx], s=50, c='r', marker='X')
    title = f'Threshold: {threshold:.2f}\n(variance: {variance_threshold:.0f})'
    axes[2].set_title(title)
    plt.tight_layout()
    plt.show()

    # Classification report
    print("\n", metrics.classification_report(true, pred > threshold))


def calculate_metrics():
    """
    Calculates and plots various metrics about the task
    (ROC curve, precision/recall, distribution, confusion matrix)
    """
    cfg = BlurrinessConfig()
    files, true = get_paths_and_labels(cfg)
    if os.path.exists('blurriness_estimation/preds.npy'):
        # Computing all the predictions takes a while, so I cached them to iterate faster
        pred = np.load('blurriness_estimation/preds.npy')
    else:
        pred = predict(files, cfg)
    display_statistics(true, pred, cfg)
    display_metrics(true, pred)


if __name__ == "__main__":
    calculate_metrics()
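
# For reference, this script only touches two attributes of BlurrinessConfig.
# A minimal sketch of the assumed shape (the real definition lives in config.py;
# the default values shown here are assumptions, not the project's actual settings):
#
#     @dataclass
#     class BlurrinessConfig:
#         data_dir: str = 'data/blur/'
#         threshold: float = 100.0  # variance-of-Laplacian cutoff below which an image counts as blurry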