Spaces:
Runtime error
Runtime error
PedroMartelleto
commited on
Commit
•
d4a3403
1
Parent(s):
e90a667
Deploying to HF
Browse files
app.py
CHANGED
@@ -52,10 +52,10 @@ class Explainer:
|
|
52 |
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
53 |
return PIL.Image.fromarray(data)
|
54 |
|
55 |
-
def shap(self):
|
56 |
gradient_shap = GradientShap(self.model)
|
57 |
rand_img_dist = torch.cat([self.input * 0, self.input * 1])
|
58 |
-
attributions_gs = gradient_shap.attribute(self.input, n_samples=
|
59 |
fig, _ = viz.visualize_image_attr_multiple(np.transpose(attributions_gs.squeeze().cpu().detach().numpy(), (1,2,0)),
|
60 |
np.transpose(self.transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
|
61 |
["original_image", "heat_map"],
|
@@ -65,13 +65,13 @@ class Explainer:
|
|
65 |
fig.suptitle("SHAP | " + self.fig_title, fontsize=12)
|
66 |
return self.convert_fig_to_pil(fig)
|
67 |
|
68 |
-
def occlusion(self):
|
69 |
occlusion = Occlusion(model)
|
70 |
|
71 |
attributions_occ = occlusion.attribute(self.input,
|
72 |
target=self.pred_label_idx,
|
73 |
-
strides=(3,
|
74 |
-
sliding_window_shapes=(3,
|
75 |
baselines=0)
|
76 |
|
77 |
fig, _ = viz.visualize_image_attr_multiple(np.transpose(attributions_occ.squeeze().cpu().detach().numpy(), (1,2,0)),
|
@@ -103,20 +103,6 @@ class Explainer:
|
|
103 |
fig_size=(18, 6))
|
104 |
fig.suptitle("GradCAM layer3[1].conv2 | " + self.fig_title, fontsize=12)
|
105 |
return self.convert_fig_to_pil(fig)
|
106 |
-
|
107 |
-
def integrated_gradients(self):
|
108 |
-
integrated_gradients = IntegratedGradients(self.model)
|
109 |
-
attributions_ig = integrated_gradients.attribute(self.input, target=self.pred_label_idx, n_steps=50)
|
110 |
-
|
111 |
-
fig, _ = viz.visualize_image_attr_multiple(np.transpose(attributions_ig.squeeze().cpu().detach().numpy(), (1,2,0)),
|
112 |
-
np.transpose(self.transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
|
113 |
-
["original_image", "heat_map", "masked_image"],
|
114 |
-
["all", "positive", "positive"],
|
115 |
-
show_colorbar=True,
|
116 |
-
titles=["Original", "Attribution", "Masked"],
|
117 |
-
fig_size=(18, 6))
|
118 |
-
fig.suptitle("Integrated gradients | " + self.fig_title, fontsize=12)
|
119 |
-
return self.convert_fig_to_pil(fig)
|
120 |
|
121 |
def create_model_from_checkpoint():
|
122 |
# Loads a model from a checkpoint
|
@@ -129,12 +115,21 @@ def create_model_from_checkpoint():
|
|
129 |
model = create_model_from_checkpoint()
|
130 |
labels = [ "benign", "malignant", "normal" ]
|
131 |
|
132 |
-
def predict(img):
|
133 |
explainer = Explainer(model, img, labels)
|
134 |
-
return [explainer.confidences,
|
|
|
|
|
|
|
135 |
|
136 |
ui = gr.Interface(fn=predict,
|
137 |
-
inputs=
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
outputs=[gr.Label(num_top_classes=3), gr.Image(type="pil"), gr.Image(type="pil"), gr.Image(type="pil")],
|
139 |
examples=["benign (52).png", "benign (243).png", "malignant (127).png", "malignant (201).png", "normal (81).png", "normal (101).png"]).launch()
|
140 |
ui.launch(share=True)
|
|
|
52 |
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
53 |
return PIL.Image.fromarray(data)
|
54 |
|
55 |
+
def shap(self, n_samples, stdevs):
|
56 |
gradient_shap = GradientShap(self.model)
|
57 |
rand_img_dist = torch.cat([self.input * 0, self.input * 1])
|
58 |
+
attributions_gs = gradient_shap.attribute(self.input, n_samples=int(n_samples), stdevs=stdevs, baselines=rand_img_dist, target=self.pred_label_idx)
|
59 |
fig, _ = viz.visualize_image_attr_multiple(np.transpose(attributions_gs.squeeze().cpu().detach().numpy(), (1,2,0)),
|
60 |
np.transpose(self.transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
|
61 |
["original_image", "heat_map"],
|
|
|
65 |
fig.suptitle("SHAP | " + self.fig_title, fontsize=12)
|
66 |
return self.convert_fig_to_pil(fig)
|
67 |
|
68 |
+
def occlusion(self, stride, sliding_window):
|
69 |
occlusion = Occlusion(model)
|
70 |
|
71 |
attributions_occ = occlusion.attribute(self.input,
|
72 |
target=self.pred_label_idx,
|
73 |
+
strides=(3, int(stride), int(stride)),
|
74 |
+
sliding_window_shapes=(3, int(sliding_window), int(sliding_window)),
|
75 |
baselines=0)
|
76 |
|
77 |
fig, _ = viz.visualize_image_attr_multiple(np.transpose(attributions_occ.squeeze().cpu().detach().numpy(), (1,2,0)),
|
|
|
103 |
fig_size=(18, 6))
|
104 |
fig.suptitle("GradCAM layer3[1].conv2 | " + self.fig_title, fontsize=12)
|
105 |
return self.convert_fig_to_pil(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
def create_model_from_checkpoint():
|
108 |
# Loads a model from a checkpoint
|
|
|
115 |
model = create_model_from_checkpoint()
|
116 |
labels = [ "benign", "malignant", "normal" ]
|
117 |
|
118 |
+
def predict(img, shap_samples, shap_stdevs, occlusion_stride, occlusion_window):
|
119 |
explainer = Explainer(model, img, labels)
|
120 |
+
return [explainer.confidences,
|
121 |
+
explainer.shap(shap_samples, shap_stdevs),
|
122 |
+
explainer.occlusion(occlusion_stride, occlusion_window),
|
123 |
+
explainer.gradcam()]
|
124 |
|
125 |
ui = gr.Interface(fn=predict,
|
126 |
+
inputs=[
|
127 |
+
gr.Image(type="pil"),
|
128 |
+
gr.Slider(minimum=10, maximum=100, default=50, label="SHAP Samples"),
|
129 |
+
gr.Slider(minimum=0.0001, maximum=0.01, default=0.0001, label="SHAP Stdevs"),
|
130 |
+
gr.Slider(minimum=4, maximum=80, default=8, label="Occlusion Stride"),
|
131 |
+
gr.Slider(minimum=4, maximum=80, default=15, label="Occlusion Window")
|
132 |
+
],
|
133 |
outputs=[gr.Label(num_top_classes=3), gr.Image(type="pil"), gr.Image(type="pil"), gr.Image(type="pil")],
|
134 |
examples=["benign (52).png", "benign (243).png", "malignant (127).png", "malignant (201).png", "normal (81).png", "normal (101).png"]).launch()
|
135 |
ui.launch(share=True)
|