import torch
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

from . import config
from .transforms import test_transforms


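def generate_confidences(
    model,
    input_img,
    num_top_preds,
):
    """Return the preprocessed image tensor and the model's top-k class confidences.

    Assumes the model outputs log-probabilities (e.g. from log_softmax),
    so torch.exp converts them back to probabilities.
    """
    # Apply the test-time transforms and add a batch dimension: (C, H, W) -> (1, C, H, W)
    input_img = test_transforms(image=input_img)["image"]
    input_img = input_img.unsqueeze(0)

    # Run inference in eval mode without tracking gradients, then restore
    # whichever mode (train or eval) the model was in before the call
    was_training = model.training
    model.eval()
    with torch.no_grad():
        log_probs = model(input_img)[0]
    model.train(was_training)
    probs = torch.exp(log_probs)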

    confidences = {
        config.CLASSES[i]: float(probs[i]) for i in range(len(config.CLASSES))
    }
    # Keep only the num_top_preds highest confidences, in descending order
    confidences = dict(
        sorted(confidences.items(), key=lambda item: item[1], reverse=True)[
            :num_top_preds
        ]
    )
    return input_img, confidences
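

# Illustrative sketch (hypothetical helper, not called by this module):
# the top-k selection above sorts every class before slicing. heapq.nlargest
# is documented as equivalent to sorted(..., reverse=True)[:k] but avoids a
# full sort, which only matters when the number of classes is large. The
# name top_k_confidences is an assumption made for this example.
def top_k_confidences(confidences, k):
    """Return the k highest-valued items of confidences, largest first."""
    import heapq  # local import keeps this sketch self-contained

    return dict(heapq.nlargest(k, confidences.items(), key=lambda item: item[1]))


# Example:
# top_k_confidences({"cat": 0.1, "dog": 0.7, "frog": 0.2}, 2)
# returns {"dog": 0.7, "frog": 0.2}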


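def generate_gradcam(
    model,
    org_img,
    input_img,
    show_gradcam,
    gradcam_layer,
    gradcam_opacity,
):
    """Overlay a Grad-CAM heatmap on org_img, or return None if disabled.

    gradcam_layer selects the target layer: -1 for the last block of
    model.l3, -2 for the last block of model.l2. org_img is expected to be
    an RGB array with pixel values in [0, 255] and the same height and width
    as input_img.
    """
    if not show_gradcam:
        return None

    if gradcam_layer == -1:
        target_layers = [model.l3[-1]]
    elif gradcam_layer == -2:
        target_layers = [model.l2[-1]]
    else:
        # Previously an unsupported value left target_layers undefined
        raise ValueError(
            f"Unsupported gradcam_layer {gradcam_layer!r}; expected -1 or -2"
        )

    cam = GradCAM(model=model, target_layers=target_layers)
    # targets=None explains the highest-scoring class; keep the map for the
    # single image in the batch
    grayscale_cam = cam(input_tensor=input_img, targets=None)[0, :]

    # show_cam_on_image expects a float image in [0, 1]; image_weight is the
    # weight of the original image, so a higher opacity strengthens the heatmap
    visualization = show_cam_on_image(
        org_img / 255,
        grayscale_cam,
        use_rgb=True,
        image_weight=(1 - gradcam_opacity),
    )
    return visualization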
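

# Illustrative sketch (hypothetical helper, not called by this module) of the
# core Grad-CAM computation that GradCAM performs for one image: weight each
# activation channel by its spatially averaged gradient, sum the weighted
# channels, clamp negatives with ReLU and min-max rescale to [0, 1]. It
# assumes activations and gradients for the chosen layer are already captured
# as (C, H, W) tensors; the library additionally handles the hooks, batching
# and resizing to the input resolution.
def gradcam_from_activations(activations, gradients):
    """Return an (H, W) Grad-CAM map in [0, 1] from (C, H, W) tensors."""
    weights = gradients.mean(dim=(1, 2))  # one weight per channel, shape (C,)
    cam = torch.relu((weights[:, None, None] * activations).sum(dim=0))
    cam = cam - cam.min()
    max_val = cam.max()
    # Avoid dividing by zero when the whole map is flat
    return cam / max_val if max_val > 0 else cam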


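def generate_missclassified_imgs(
    model,
    show_misclassified,
    num_misclassified,
):
    """Return a plot of num_misclassified misclassified images, or None if disabled.

    Relies on the model providing plot_incorrect_predictions_helper().
    """
    if not show_misclassified:
        return None
    return model.plot_incorrect_predictions_helper(num_misclassified)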