fg-mindee
commited on
Commit
·
82a65d2
1
Parent(s):
ea7721e
feat: Added option to retrieve CAMs from multiple layers
Browse files
app.py
CHANGED
@@ -11,7 +11,8 @@ from io import BytesIO
|
|
11 |
from torchvision import models
|
12 |
from torchvision.transforms.functional import resize, to_tensor, normalize, to_pil_image
|
13 |
|
14 |
-
from torchcam import
|
|
|
15 |
from torchcam.utils import overlay_mask
|
16 |
|
17 |
|
@@ -59,14 +60,14 @@ def main():
|
|
59 |
if tv_model is not None:
|
60 |
with st.spinner('Loading model...'):
|
61 |
model = models.__dict__[tv_model](pretrained=True).eval()
|
62 |
-
default_layer =
|
63 |
|
64 |
target_layer = st.sidebar.text_input("Target layer", default_layer)
|
65 |
cam_method = st.sidebar.selectbox("CAM method", CAM_METHODS)
|
66 |
if cam_method is not None:
|
67 |
-
cam_extractor =
|
68 |
model,
|
69 |
-
target_layer=target_layer if len(target_layer) > 0 else None
|
70 |
)
|
71 |
|
72 |
class_choices = [f"{idx + 1} - {class_name}" for idx, class_name in enumerate(LABEL_MAP)]
|
@@ -94,16 +95,18 @@ def main():
|
|
94 |
else:
|
95 |
class_idx = LABEL_MAP.index(class_selection.rpartition(" - ")[-1])
|
96 |
# Retrieve the CAM
|
97 |
-
|
|
|
|
|
98 |
# Plot the raw heatmap
|
99 |
fig, ax = plt.subplots()
|
100 |
-
ax.imshow(
|
101 |
ax.axis('off')
|
102 |
cols[1].pyplot(fig)
|
103 |
|
104 |
# Overlayed CAM
|
105 |
fig, ax = plt.subplots()
|
106 |
-
result = overlay_mask(img, to_pil_image(
|
107 |
ax.imshow(result)
|
108 |
ax.axis('off')
|
109 |
cols[-1].pyplot(fig)
|
|
|
11 |
from torchvision import models
|
12 |
from torchvision.transforms.functional import resize, to_tensor, normalize, to_pil_image
|
13 |
|
14 |
+
from torchcam import methods
|
15 |
+
from torchcam.methods._utils import locate_candidate_layer
|
16 |
from torchcam.utils import overlay_mask
|
17 |
|
18 |
|
|
|
60 |
if tv_model is not None:
|
61 |
with st.spinner('Loading model...'):
|
62 |
model = models.__dict__[tv_model](pretrained=True).eval()
|
63 |
+
default_layer = locate_candidate_layer(model, (3, 224, 224))
|
64 |
|
65 |
target_layer = st.sidebar.text_input("Target layer", default_layer)
|
66 |
cam_method = st.sidebar.selectbox("CAM method", CAM_METHODS)
|
67 |
if cam_method is not None:
|
68 |
+
cam_extractor = methods.__dict__[cam_method](
|
69 |
model,
|
70 |
+
target_layer=target_layer.split("+") if len(target_layer) > 0 else None
|
71 |
)
|
72 |
|
73 |
class_choices = [f"{idx + 1} - {class_name}" for idx, class_name in enumerate(LABEL_MAP)]
|
|
|
95 |
else:
|
96 |
class_idx = LABEL_MAP.index(class_selection.rpartition(" - ")[-1])
|
97 |
# Retrieve the CAM
|
98 |
+
cams = cam_extractor(class_idx, out)
|
99 |
+
# Fuse the CAMs if there are several
|
100 |
+
cam = cams[0] if len(cams) == 1 else cam_extractor.fuse_cams(cams)
|
101 |
# Plot the raw heatmap
|
102 |
fig, ax = plt.subplots()
|
103 |
+
ax.imshow(cam.numpy())
|
104 |
ax.axis('off')
|
105 |
cols[1].pyplot(fig)
|
106 |
|
107 |
# Overlayed CAM
|
108 |
fig, ax = plt.subplots()
|
109 |
+
result = overlay_mask(img, to_pil_image(cam, mode='F'), alpha=0.5)
|
110 |
ax.imshow(result)
|
111 |
ax.axis('off')
|
112 |
cols[-1].pyplot(fig)
|