File size: 780 Bytes
c4b2b37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import matplotlib.pyplot as plt

from generic_utils import generate_visualization


def _plot_image_and_visualization(transform, image, class_index, method):
    """Build a 1x2 figure: *image* on the left, its visualization on the right.

    Shared implementation for ``do_lrp`` and ``do_partial_lrp``, which differ
    only in the ``method`` forwarded to ``generate_visualization``.

    Args:
        transform: Callable applied to *image* before it is passed to
            ``generate_visualization``. The raw, untransformed *image* is
            what gets drawn in the left panel.
        image: Input shown in the left panel via ``Axes.imshow``.
        class_index: Forwarded unchanged to ``generate_visualization``.
            NOTE(review): presumably ``None`` selects the predicted class;
            confirm in generic_utils.
        method: Forwarded unchanged to ``generate_visualization``.

    Returns:
        The matplotlib ``Figure`` created by ``plt.subplots``. It is
        registered with pyplot's figure manager, so callers that create many
        figures should ``plt.close(fig)`` when finished.
    """
    fig, (ax_input, ax_viz) = plt.subplots(1, 2)
    ax_input.imshow(image)
    ax_input.axis("off")

    # Only the visualization sees the transformed input; the left panel keeps
    # the original image so it is displayed as the user supplied it.
    viz = generate_visualization(
        transform(image), class_index=class_index, method=method
    )

    ax_viz.imshow(viz)
    ax_viz.axis("off")
    return fig


def do_lrp(transform, image, class_index=None):
    """Plot *image* beside the ``method="full"`` output of ``generate_visualization``.

    Args:
        transform: Callable applied to *image* before visualization.
        image: Input image, drawn unmodified in the left panel.
        class_index: Forwarded to ``generate_visualization`` (default ``None``).

    Returns:
        A matplotlib ``Figure`` with two axis-less panels.
    """
    return _plot_image_and_visualization(
        transform, image, class_index, method="full"
    )


def do_partial_lrp(transform, image, class_index=None):
    """Plot *image* beside the ``method="last_layer"`` output of ``generate_visualization``.

    Args:
        transform: Callable applied to *image* before visualization.
        image: Input image, drawn unmodified in the left panel.
        class_index: Forwarded to ``generate_visualization`` (default ``None``).

    Returns:
        A matplotlib ``Figure`` with two axis-less panels.
    """
    return _plot_image_and_visualization(
        transform, image, class_index, method="last_layer"
    )