fg-mindee commited on
Commit
82a65d2
·
1 Parent(s): ea7721e

feat: Added option to retrieve CAMs from multiple layers

Browse files
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -11,7 +11,8 @@ from io import BytesIO
11
  from torchvision import models
12
  from torchvision.transforms.functional import resize, to_tensor, normalize, to_pil_image
13
 
14
- from torchcam import cams
 
15
  from torchcam.utils import overlay_mask
16
 
17
 
@@ -59,14 +60,14 @@ def main():
59
  if tv_model is not None:
60
  with st.spinner('Loading model...'):
61
  model = models.__dict__[tv_model](pretrained=True).eval()
62
- default_layer = cams.utils.locate_candidate_layer(model, (3, 224, 224))
63
 
64
  target_layer = st.sidebar.text_input("Target layer", default_layer)
65
  cam_method = st.sidebar.selectbox("CAM method", CAM_METHODS)
66
  if cam_method is not None:
67
- cam_extractor = cams.__dict__[cam_method](
68
  model,
69
- target_layer=target_layer if len(target_layer) > 0 else None
70
  )
71
 
72
  class_choices = [f"{idx + 1} - {class_name}" for idx, class_name in enumerate(LABEL_MAP)]
@@ -94,16 +95,18 @@ def main():
94
  else:
95
  class_idx = LABEL_MAP.index(class_selection.rpartition(" - ")[-1])
96
  # Retrieve the CAM
97
- activation_map = cam_extractor(class_idx, out)[0]
 
 
98
  # Plot the raw heatmap
99
  fig, ax = plt.subplots()
100
- ax.imshow(activation_map.numpy())
101
  ax.axis('off')
102
  cols[1].pyplot(fig)
103
 
104
  # Overlayed CAM
105
  fig, ax = plt.subplots()
106
- result = overlay_mask(img, to_pil_image(activation_map, mode='F'), alpha=0.5)
107
  ax.imshow(result)
108
  ax.axis('off')
109
  cols[-1].pyplot(fig)
 
11
  from torchvision import models
12
  from torchvision.transforms.functional import resize, to_tensor, normalize, to_pil_image
13
 
14
+ from torchcam import methods
15
+ from torchcam.methods._utils import locate_candidate_layer
16
  from torchcam.utils import overlay_mask
17
 
18
 
 
60
  if tv_model is not None:
61
  with st.spinner('Loading model...'):
62
  model = models.__dict__[tv_model](pretrained=True).eval()
63
+ default_layer = locate_candidate_layer(model, (3, 224, 224))
64
 
65
  target_layer = st.sidebar.text_input("Target layer", default_layer)
66
  cam_method = st.sidebar.selectbox("CAM method", CAM_METHODS)
67
  if cam_method is not None:
68
+ cam_extractor = methods.__dict__[cam_method](
69
  model,
70
+ target_layer=target_layer.split("+") if len(target_layer) > 0 else None
71
  )
72
 
73
  class_choices = [f"{idx + 1} - {class_name}" for idx, class_name in enumerate(LABEL_MAP)]
 
95
  else:
96
  class_idx = LABEL_MAP.index(class_selection.rpartition(" - ")[-1])
97
  # Retrieve the CAM
98
+ cams = cam_extractor(class_idx, out)
99
+ # Fuse the CAMs if there are several
100
+ cam = cams[0] if len(cams) == 1 else cam_extractor.fuse_cams(cams)
101
  # Plot the raw heatmap
102
  fig, ax = plt.subplots()
103
+ ax.imshow(cam.numpy())
104
  ax.axis('off')
105
  cols[1].pyplot(fig)
106
 
107
  # Overlayed CAM
108
  fig, ax = plt.subplots()
109
+ result = overlay_mask(img, to_pil_image(cam, mode='F'), alpha=0.5)
110
  ax.imshow(result)
111
  ax.axis('off')
112
  cols[-1].pyplot(fig)