File size: 4,722 Bytes
032c7aa
c396e65
032c7aa
 
 
 
 
ad937a3
032c7aa
ad937a3
 
1a23377
032c7aa
 
 
 
7736fa8
 
 
 
032c7aa
 
 
 
 
 
 
 
 
 
 
 
 
 
1b87171
032c7aa
41433b6
032c7aa
 
41433b6
032c7aa
ffd57e9
032c7aa
 
 
 
 
 
 
1b87171
2134128
1b87171
 
032c7aa
d4a3403
032c7aa
 
d4a3403
032c7aa
 
 
9745f9b
c396e65
 
35677f0
c396e65
 
ad937a3
 
 
 
87fbe80
ad937a3
 
 
 
 
 
19010f0
1a23377
d4a3403
 
 
 
ad937a3
 
d4a3403
 
7d550ac
 
 
 
d4a3403
35677f0
19010f0
 
 
 
 
 
ad937a3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import PIL
from captum.attr import GradientShap, Occlusion, LayerGradCam, LayerAttribution, IntegratedGradients
from captum.attr import visualization as viz
import torch
from torchvision import transforms
from matplotlib.colors import LinearSegmentedColormap
import torch.nn.functional as F
import gradio as gr
from torchvision.models import resnet50
import torch.nn as nn
import torch
import numpy as np

class Explainer:
    """Classify one image and render Captum attribution figures for the top class.

    The constructor preprocesses ``img``, runs ``model`` once under
    ``torch.no_grad()`` and caches the softmax output, the per-class
    confidences and the top-1 prediction. ``shap``, ``occlusion`` and
    ``gradcam`` each attribute that top-1 prediction and return the rendered
    Matplotlib figure as an RGB ``PIL.Image``.

    Args:
        model: classifier called as ``model(batch)``; expected to be in eval
            mode. ``gradcam`` additionally assumes a torchvision ResNet layout
            (it hooks ``model.layer3[1].conv2``).
        img: the image to explain (a ``PIL.Image`` when called from the UI).
        class_names: labels indexed by output-logit position.
    """

    def __init__(self, model, img, class_names):
        self.model = model
        # White-to-black colormap that is already fully black at 25% of the range.
        self.default_cmap = LinearSegmentedColormap.from_list(
            'custom blue',
            [(0, '#ffffff'), (0.25, '#000000'), (1, '#000000')],
            N=256,
        )
        self.class_names = class_names

        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])
        # Standard ImageNet per-channel mean/std.
        transform_normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

        # Normalize needs exactly 3 channels: grayscale ("L") images would
        # fail and RGBA images would push an alpha channel into the model.
        if isinstance(img, PIL.Image.Image) and img.mode != 'RGB':
            img = img.convert('RGB')

        # Un-normalized (3, 224, 224) tensor, kept for display in the figures.
        self.transformed_img = transform(img)
        # Normalized batch of one, (1, 3, 224, 224), fed to the model.
        self.input = transform_normalize(self.transformed_img).unsqueeze(0)

        with torch.no_grad():
            self.output = F.softmax(self.model(self.input), dim=1)

        # One entry per class the model actually scores (previously a
        # hard-coded range(3)); zip stops at the shorter of the two.
        self.confidences = {
            name: float(score)
            for name, score in zip(class_names, self.output[0].tolist())
        }

        # Both are (1, 1) tensors; pred_label_idx is also the Captum target.
        self.pred_score, self.pred_label_idx = torch.topk(self.output, 1)
        self.pred_label = self.class_names[self.pred_label_idx.item()]
        self.fig_title = (
            'Predicted: ' + self.pred_label
            + ' (' + str(round(self.pred_score.item(), 2)) + ')'
        )

    def convert_fig_to_pil(self, fig):
        """Render ``fig`` into an RGB ``PIL.Image`` and close the figure.

        Reads the canvas through ``buffer_rgba()`` rather than the deprecated
        ``tostring_rgb()`` + binary-mode ``np.fromstring`` pair. The figure is
        closed afterwards (even on error): Captum's visualizers create figures
        through pyplot by default, and pyplot keeps every open figure alive,
        which leaks memory in a long-running server.
        """
        import matplotlib.pyplot as plt

        try:
            fig.canvas.draw()
            # (height, width, 4) uint8 view of the rendered canvas.
            rgba = np.asarray(fig.canvas.buffer_rgba())
            # Copy the RGB planes so the image does not alias the canvas buffer.
            return PIL.Image.fromarray(np.ascontiguousarray(rgba[..., :3]))
        finally:
            plt.close(fig)

    @staticmethod
    def _to_hwc(tensor):
        """Return a CHW or 1xCHW tensor as an HWC numpy array for Captum's viz."""
        if tensor.dim() == 4:
            tensor = tensor[0]
        return tensor.detach().cpu().permute(1, 2, 0).numpy()

    def _render(self, attributions, methods, signs, titles, suptitle, **viz_kwargs):
        """Plot ``attributions`` next to the input image and return it as a PIL image."""
        fig, _ = viz.visualize_image_attr_multiple(
            self._to_hwc(attributions),
            self._to_hwc(self.transformed_img),
            methods,
            signs,
            titles=titles,
            show_colorbar=True,
            **viz_kwargs,
        )
        fig.suptitle(suptitle + ' | ' + self.fig_title, fontsize=12)
        return self.convert_fig_to_pil(fig)

    def shap(self, n_samples, stdevs):
        """Explain the top-1 prediction with GradientShap.

        Args:
            n_samples: noisy samples drawn per input (cast to ``int`` because
                UI sliders may deliver floats).
            stdevs: standard deviation of the noise GradientShap adds to samples.
        """
        gradient_shap = GradientShap(self.model)
        # Baseline distribution: an all-zero tensor and the input itself.
        rand_img_dist = torch.cat([self.input * 0, self.input * 1])
        attributions_gs = gradient_shap.attribute(
            self.input,
            n_samples=int(n_samples),
            stdevs=stdevs,
            baselines=rand_img_dist,
            target=self.pred_label_idx,
        )
        # `signs` is a required argument of visualize_image_attr_multiple and
        # `titles` must have one entry per method.
        return self._render(
            attributions_gs,
            ["original_image", "heat_map"],
            ["all", "absolute_value"],
            titles=["Original", "Absolute Attribution"],
            suptitle="SHAP",
            cmap=self.default_cmap,
            fig_size=(18, 6),
        )

    def occlusion(self, stride, window):
        """Explain the top-1 prediction by sliding an occluding patch.

        Args:
            stride: spatial step (pixels) between patch positions.
            window: side length (pixels) of the square patch.
        """
        stride, window = int(stride), int(window)
        occlusion = Occlusion(self.model)
        # A channel stride/window of 3 occludes all colour channels together.
        attributions_occ = occlusion.attribute(
            self.input,
            strides=(3, stride, stride),
            target=self.pred_label_idx,
            sliding_window_shapes=(3, window, window),
            baselines=0,
        )
        return self._render(
            attributions_occ,
            ["original_image", "heat_map", "heat_map", "masked_image"],
            ["all", "positive", "negative", "positive"],
            titles=["Original", "Positive Attribution", "Negative Attribution", "Masked"],
            suptitle="Occlusion",
            fig_size=(18, 6),
        )

    def gradcam(self):
        """Explain the top-1 prediction with Grad-CAM on ``model.layer3[1].conv2``."""
        layer_gradcam = LayerGradCam(self.model, self.model.layer3[1].conv2)
        attributions_lgc = layer_gradcam.attribute(self.input, target=self.pred_label_idx)
        # Upsample the coarse layer map to the input's spatial size (H, W).
        upsamp_attr_lgc = LayerAttribution.interpolate(attributions_lgc, self.input.shape[2:])
        return self._render(
            upsamp_attr_lgc,
            ["original_image", "blended_heat_map", "masked_image"],
            ["all", "positive", "positive"],
            titles=["Original", "Positive Attribution", "Masked"],
            suptitle="GradCAM layer3[1].conv2",
            fig_size=(18, 6),
        )

def create_model_from_checkpoint(checkpoint_path="best_model", num_classes=3, map_location="cpu"):
    """Build a ResNet-50 with a ``num_classes``-way head and load saved weights.

    Args:
        checkpoint_path: file holding a ``state_dict`` written by ``torch.save``.
        num_classes: output width of the replacement ``fc`` layer; must match
            the checkpoint (3 for the benign/malignant/normal labels of this app).
        map_location: device the stored tensors are mapped onto when loading.

    Returns:
        The model with the checkpoint weights loaded, switched to eval mode.
    """
    model = resnet50()
    # Replace torchvision's default classification head with one matching the checkpoint.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(state_dict)
    model.eval()
    return model

# Class names in output-logit order: predict() passes this list to Explainer,
# which pairs labels[i] with output column i (presumably the training order — confirm).
labels = ["benign", "malignant", "normal"]
model = create_model_from_checkpoint()

def predict(img, shap_samples, shap_stdevs, occlusion_stride, occlusion_window):
    """Gradio callback: classify ``img`` and build every explanation panel.

    Args:
        img: uploaded image (``None`` when the image input is empty).
        shap_samples, shap_stdevs: forwarded to ``Explainer.shap``.
        occlusion_stride, occlusion_window: forwarded to ``Explainer.occlusion``.

    Returns:
        A list matching the Interface outputs: the label -> confidence dict,
        then the SHAP, Occlusion and GradCAM figures as PIL images.

    Raises:
        gr.Error: if no image was supplied.
    """
    if img is None:
        # Otherwise the failure surfaces as an opaque error inside the
        # torchvision transforms rather than a message the user can act on.
        raise gr.Error("Please upload an image first.")
    explainer = Explainer(model, img, labels)
    return [
        explainer.confidences,
        explainer.shap(shap_samples, shap_stdevs),
        explainer.occlusion(occlusion_stride, occlusion_window),
        explainer.gradcam(),
    ]

# Gradio UI. Initial slider positions are set with `value=`; the old
# `default=` keyword is not honoured by gr.Slider, so the intended defaults
# (50 / 0.0001 / 8 / 15) were never applied.
ui = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(minimum=10, maximum=100, value=50, label="SHAP Samples", step=1),
        gr.Slider(minimum=0.0001, maximum=0.01, value=0.0001, label="SHAP Stdevs", step=0.0001),
        gr.Slider(minimum=4, maximum=80, value=8, label="Occlusion Stride", step=1),
        gr.Slider(minimum=4, maximum=80, value=15, label="Occlusion Window", step=1),
    ],
    # Label + three figures, in the order predict() returns them.
    outputs=[gr.Label(num_top_classes=3), gr.Image(type="pil"), gr.Image(type="pil"), gr.Image(type="pil")],
    examples=[
        ["benign (52).png", 50, 0.0001, 8, 15],
        ["benign (243).png", 50, 0.0001, 8, 15],
        ["malignant (127).png", 50, 0.0001, 8, 15],
        ["malignant (201).png", 50, 0.0001, 8, 15],
        ["normal (81).png", 50, 0.0001, 8, 15],
        ["normal (101).png", 50, 0.0001, 8, 15],
    ],
)

if __name__ == "__main__":
    # Launch exactly once. `.launch()` used to be chained onto the constructor,
    # so `ui` held launch()'s return value instead of the Interface and the
    # follow-up `ui.launch(share=True)` could never work. Pass share=True here
    # to get a temporary public link when running locally.
    ui.launch()