PedroMartelleto committed
Commit d4a3403 · 1 Parent(s): e90a667

Deploying to HF

Files changed (1):
  1. app.py +17 -22
app.py CHANGED
@@ -52,10 +52,10 @@ class Explainer:
         data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
         return PIL.Image.fromarray(data)
 
-    def shap(self):
+    def shap(self, n_samples, stdevs):
         gradient_shap = GradientShap(self.model)
         rand_img_dist = torch.cat([self.input * 0, self.input * 1])
-        attributions_gs = gradient_shap.attribute(self.input, n_samples=50, stdevs=0.0001, baselines=rand_img_dist, target=self.pred_label_idx)
+        attributions_gs = gradient_shap.attribute(self.input, n_samples=int(n_samples), stdevs=stdevs, baselines=rand_img_dist, target=self.pred_label_idx)
         fig, _ = viz.visualize_image_attr_multiple(np.transpose(attributions_gs.squeeze().cpu().detach().numpy(), (1,2,0)),
                                                    np.transpose(self.transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
                                                    ["original_image", "heat_map"],
@@ -65,13 +65,13 @@ class Explainer:
         fig.suptitle("SHAP | " + self.fig_title, fontsize=12)
         return self.convert_fig_to_pil(fig)
 
-    def occlusion(self):
+    def occlusion(self, stride, sliding_window):
         occlusion = Occlusion(model)
 
         attributions_occ = occlusion.attribute(self.input,
                                                target=self.pred_label_idx,
-                                               strides=(3, 8, 8),
-                                               sliding_window_shapes=(3, 15, 15),
+                                               strides=(3, int(stride), int(stride)),
+                                               sliding_window_shapes=(3, int(sliding_window), int(sliding_window)),
                                                baselines=0)
 
         fig, _ = viz.visualize_image_attr_multiple(np.transpose(attributions_occ.squeeze().cpu().detach().numpy(), (1,2,0)),
@@ -103,20 +103,6 @@ class Explainer:
                                                    fig_size=(18, 6))
         fig.suptitle("GradCAM layer3[1].conv2 | " + self.fig_title, fontsize=12)
         return self.convert_fig_to_pil(fig)
-
-    def integrated_gradients(self):
-        integrated_gradients = IntegratedGradients(self.model)
-        attributions_ig = integrated_gradients.attribute(self.input, target=self.pred_label_idx, n_steps=50)
-
-        fig, _ = viz.visualize_image_attr_multiple(np.transpose(attributions_ig.squeeze().cpu().detach().numpy(), (1,2,0)),
-                                                   np.transpose(self.transformed_img.squeeze().cpu().detach().numpy(), (1,2,0)),
-                                                   ["original_image", "heat_map", "masked_image"],
-                                                   ["all", "positive", "positive"],
-                                                   show_colorbar=True,
-                                                   titles=["Original", "Attribution", "Masked"],
-                                                   fig_size=(18, 6))
-        fig.suptitle("Integrated gradients | " + self.fig_title, fontsize=12)
-        return self.convert_fig_to_pil(fig)
 
 def create_model_from_checkpoint():
     # Loads a model from a checkpoint
@@ -129,12 +115,21 @@ def create_model_from_checkpoint():
 model = create_model_from_checkpoint()
 labels = [ "benign", "malignant", "normal" ]
 
-def predict(img):
+def predict(img, shap_samples, shap_stdevs, occlusion_stride, occlusion_window):
     explainer = Explainer(model, img, labels)
-    return [explainer.confidences, explainer.shap(), explainer.occlusion(), explainer.gradcam()]
+    return [explainer.confidences,
+            explainer.shap(shap_samples, shap_stdevs),
+            explainer.occlusion(occlusion_stride, occlusion_window),
+            explainer.gradcam()]
 
 ui = gr.Interface(fn=predict,
-                  inputs=gr.Image(type="pil"),
+                  inputs=[
+                      gr.Image(type="pil"),
+                      gr.Slider(minimum=10, maximum=100, default=50, label="SHAP Samples"),
+                      gr.Slider(minimum=0.0001, maximum=0.01, default=0.0001, label="SHAP Stdevs"),
+                      gr.Slider(minimum=4, maximum=80, default=8, label="Occlusion Stride"),
+                      gr.Slider(minimum=4, maximum=80, default=15, label="Occlusion Window")
+                  ],
                   outputs=[gr.Label(num_top_classes=3), gr.Image(type="pil"), gr.Image(type="pil"), gr.Image(type="pil")],
                   examples=["benign (52).png", "benign (243).png", "malignant (127).png", "malignant (201).png", "normal (81).png", "normal (101).png"]).launch()
 ui.launch(share=True)
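
Note: the SHAP and occlusion hunks above only thread the new arguments through to Captum. For context, the sketch below shows the same GradientShap and Occlusion calls in isolation, with explicit values standing in for the slider inputs. The toy model, input tensor, and variable names are assumptions made here for a self-contained example; only the captum.attr usage mirrors what app.py does.

# Minimal sketch (not the app's code) of the now-parametrized Captum calls.
import torch
import torch.nn as nn
from captum.attr import GradientShap, Occlusion

# Placeholder 3-class classifier standing in for the app's checkpointed model.
toy_model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 3),
).eval()

x = torch.rand(1, 3, 224, 224)            # stand-in for the transformed input image
target = int(toy_model(x).argmax(dim=1))  # predicted class index, as in the app

# GradientShap: n_samples and stdevs are now user-controlled via the UI sliders.
gs = GradientShap(toy_model)
baselines = torch.cat([x * 0, x * 1])     # same baseline distribution the app builds
attr_gs = gs.attribute(x, baselines=baselines, n_samples=50, stdevs=0.0001, target=target)

# Occlusion: stride and window size are now user-controlled; larger strides and
# windows run faster but produce coarser attribution maps.
occ = Occlusion(toy_model)
attr_occ = occ.attribute(x,
                         sliding_window_shapes=(3, 15, 15),
                         strides=(3, 8, 8),
                         baselines=0,
                         target=target)

print(attr_gs.shape, attr_occ.shape)      # both match the input shape: (1, 3, 224, 224)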
 
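A standalone sketch (not part of this commit) of how the new slider inputs reach predict: Gradio passes the components listed in inputs to fn positionally, so the image plus the four slider values arrive as the five parameters in order. The stub function below is hypothetical, and the snippet targets a recent Gradio release where a slider's initial value is set with value=; the app's pinned version uses default=, as shown in the diff above.

# Minimal sketch of the slider-to-argument wiring, separate from app.py.
import gradio as gr

def predict_stub(img, shap_samples, shap_stdevs, occlusion_stride, occlusion_window):
    # Stand-in for the real predict(): just echo the received slider values.
    return (f"n_samples={int(shap_samples)}, stdevs={shap_stdevs}, "
            f"stride={int(occlusion_stride)}, window={int(occlusion_window)}")

demo = gr.Interface(
    fn=predict_stub,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(minimum=10, maximum=100, value=50, label="SHAP Samples"),
        gr.Slider(minimum=0.0001, maximum=0.01, value=0.0001, label="SHAP Stdevs"),
        gr.Slider(minimum=4, maximum=80, value=8, label="Occlusion Stride"),
        gr.Slider(minimum=4, maximum=80, value=15, label="Occlusion Window"),
    ],
    outputs=gr.Textbox(label="Arguments received"),
)

if __name__ == "__main__":
    demo.launch()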