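"""Evaluation script for the variance-of-Laplacian blurriness estimator.

Loads a labelled sharp/blurry dataset, scores every image, and plots the score
distribution, ROC curve, precision-recall curve, and confusion matrix."""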
import os
from typing import Tuple, List

import cv2
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
from sklearn import metrics
from tqdm import tqdm

from blurriness_estimation.blurriness_estimator import BlurrinessEstimator
from config import BlurrinessConfig


def get_paths_and_labels(cfg: BlurrinessConfig) -> Tuple[List[str], np.ndarray]:
    """
    Returns the images paths and corresponding labels for the blur dataset
    Args:
        cfg: blurriness config

    Returns:
        Tuple[List[str], np.ndarray]: images paths and corresponding labels for the blur dataset
    """
    cfg.data_dir = 'data/blur/'  # pin the evaluation dataset location, overriding the config default
    sharp_dir = os.path.join(cfg.data_dir, "sharp")
    blurry_dir = os.path.join(cfg.data_dir, "blurry")
    sharp_paths = [os.path.join(sharp_dir, filename) for filename in os.listdir(sharp_dir)]
    blurry_paths = [os.path.join(blurry_dir, filename) for filename in os.listdir(blurry_dir)]
    files = sharp_paths + blurry_paths
    labels = np.hstack([np.zeros(len(sharp_paths)), np.ones(len(blurry_paths))])
    return files, labels

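# Expected on-disk layout for the evaluation set (inferred from get_paths_and_labels):
#   data/blur/
#   ├── sharp/    -> images labelled 0
#   └── blurry/   -> images labelled 1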

def predict(files: List[str], cfg: BlurrinessConfig) -> np.ndarray:
    """
    Calculates the variance of the laplacian for the given files
    Args:
        files: paths of images to process
        cfg: blurriness config

    Returns:
        np.ndarray: predictions as variance of laplacian values
    """
    estimator = BlurrinessEstimator(cfg)
    pred = []
    for file in tqdm(files):
        img = cv2.cvtColor(cv2.imread(file), cv2.COLOR_BGR2RGB)
        h, w, _ = img.shape
        # cv2.resize expects dsize as (width, height), and the interpolation flag
        # must be passed by keyword (the third positional argument is dst)
        img = cv2.resize(img, (w // 4, h // 4), interpolation=cv2.INTER_AREA)
        is_blurry, vol = estimator(img)  # vol: variance of the Laplacian
        pred.append(vol)
    pred = np.array(pred)
    np.save('blurriness_estimation/preds.npy', pred)  # cache predictions for later runs
    return pred

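# A minimal reference sketch, assuming BlurrinessEstimator follows the classic
# variance-of-Laplacian recipe (the actual implementation lives in
# blurriness_estimation.blurriness_estimator; this helper is illustrative only):
def _variance_of_laplacian_sketch(img: np.ndarray) -> float:
    """Grayscale the image, apply the Laplacian, and score sharpness as the
    variance of the response (low variance suggests a blurry image)."""
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    return float(cv2.Laplacian(gray, cv2.CV_64F).var())
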

def display_statistics(true: np.ndarray, pred: np.ndarray, cfg: BlurrinessConfig):
    """
    Displays the distribution of the variance of the laplacian in out test dataset
    Args:
        true: true labels (0: sharp, 1: blurry)
        pred: variance of laplacian
        cfg: blurriness config (to get the threshold selected)

    Returns:

    """
    pred = np.copy(pred).clip(0, 1000)
    data = pd.DataFrame.from_dict({'pred': pred, 'true': true})
    # Weight each class by the other class's prior so the two KDEs stay comparable
    # despite the class imbalance (counts derived from `true` instead of hardcoded)
    n_sharp, n_blurry = (true == 0).sum(), (true == 1).sum()
    weights = np.where(true == 0, n_blurry / len(true), n_sharp / len(true))
    sns.kdeplot(data, x='pred', hue='true', fill=True, weights=weights)
    plt.vlines([cfg.threshold], 0, plt.ylim()[1], colors=['r'], linestyles=['--'])
    plt.legend(labels=['sharp', 'blurry'])
    plt.xlabel('Variance of the Laplacian')
    plt.title("Distribution of Variance of Laplacian\nfor blurry vs sharp images")


def display_metrics(true: np.ndarray, pred: np.ndarray):
    """
    Displays Roc curve, precision-recall and confusion matrix
    Args:
        true: true labels (0: sharp, 1: blurry)
        pred: variance of laplacian
    """
    # Map the Laplacian variance to a [0, 1] score so sklearn's tools can be used directly:
    # clip to 1000, divide by 1000, then take 1 - pred so that 1 means "blurry"
    # (it's a blurriness detector, not a sharpness detector :)
    pred = 1 - pred.clip(0, 1000) / 1000
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    fig.suptitle("Blurriness metrics\n0: Sharp, 1: Blurry")

    # Roc Curve
    metrics.RocCurveDisplay.from_predictions(true, pred, ax=axes[0])

    # Precision Recall
    precision, recall, thresholds = metrics.precision_recall_curve(true, pred)
    average_precision = metrics.average_precision_score(true, pred)
    metrics.PrecisionRecallDisplay(precision=precision, recall=recall, average_precision=average_precision).plot(ax=axes[1])
    # Heuristic operating point: pick a threshold roughly 30% of the way along the curve
    # (precision/recall have one more entry than thresholds, so index off thresholds)
    threshold_idx = int(len(thresholds) * 0.3)
    threshold = thresholds[threshold_idx]
    variance_threshold = (1 - threshold) * 1000  # map the score back to a Laplacian variance

    # Confusion matrix
    metrics.ConfusionMatrixDisplay.from_predictions(true, pred > threshold, cmap='Blues', ax=axes[2])

    # Plots details
    axes[1].scatter(recall[threshold_idx], precision[threshold_idx], s=50, c='r', marker='X')
    title = f'Threshold: {threshold:.2f}\n(variance: {variance_threshold:.0f})'
    axes[2].set_title(title)
    plt.tight_layout()
    plt.show()

    # Classification report
    print("\n", metrics.classification_report(true, pred > threshold))

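# A hedged alternative to the fixed 30% operating point used above: pick the
# PR-curve point with the highest F1 score. This helper is a sketch, assuming the
# arrays come straight from sklearn's precision_recall_curve (precision and
# recall carry one extra trailing entry relative to thresholds):
def _best_f1_threshold(precision: np.ndarray, recall: np.ndarray, thresholds: np.ndarray) -> float:
    f1 = 2 * precision[:-1] * recall[:-1] / np.clip(precision[:-1] + recall[:-1], 1e-12, None)
    return float(thresholds[int(np.argmax(f1))])
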

def calculate_metrics():
    """
    Calculates and plots various metrics about the task
    (ROC curve, precision/recall, distribution, classification matrix)
    """
    cfg = BlurrinessConfig()
    files, true = get_paths_and_labels(cfg)
    if os.path.exists('blurriness_estimation/preds.npy'):
        # Computing all predictions is slow, so they are cached on disk for faster iteration
        pred = np.load('blurriness_estimation/preds.npy')
    else:
        pred = predict(files, cfg)
    display_statistics(true, pred, cfg)
    display_metrics(true, pred)


if __name__ == "__main__":
    calculate_metrics()