fg-mindee commited on
Commit
56fb801
1 Parent(s): 2254e7f

feat: Added Streamlit app

Browse files
Files changed (1) hide show
  1. app.py +110 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2020-2021, François-Guillaume Fernandez.
2
+
3
+ # This program is licensed under the Apache License version 2.
4
+ # See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.
5
+
6
+ import requests
7
+ import streamlit as st
8
+ import matplotlib.pyplot as plt
9
+ from PIL import Image
10
+ from io import BytesIO
11
+ from torchvision import models
12
+ from torchvision.transforms.functional import resize, to_tensor, normalize, to_pil_image
13
+
14
+ from torchcam import cams
15
+ from torchcam.utils import overlay_mask
16
+
17
+
18
+ CAM_METHODS = ["CAM", "GradCAM", "GradCAMpp", "SmoothGradCAMpp", "ScoreCAM", "SSCAM", "ISCAM", "XGradCAM", "LayerCAM"]
19
+ TV_MODELS = ["resnet18", "resnet50", "mobilenet_v2", "mobilenet_v3_small", "mobilenet_v3_large"]
20
+ LABEL_MAP = requests.get(
21
+ "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
22
+ ).json()
23
+
24
+
25
+ def main():
26
+
27
+ # Wide mode
28
+ st.set_page_config(layout="wide")
29
+
30
+ # Designing the interface
31
+ st.title("TorchCAM: class activation explorer")
32
+ # For newline
33
+ st.write('\n')
34
+ # Set the columns
35
+ cols = st.columns((1, 1, 1))
36
+ cols[0].header("Input image")
37
+ cols[1].header("Raw CAM")
38
+ cols[-1].header("Overlayed CAM")
39
+
40
+ # Sidebar
41
+ # File selection
42
+ st.sidebar.title("Input selection")
43
+ # Disabling warning
44
+ st.set_option('deprecation.showfileUploaderEncoding', False)
45
+ # Choose your own image
46
+ uploaded_file = st.sidebar.file_uploader("Upload files", type=['png', 'jpeg', 'jpg'])
47
+ if uploaded_file is not None:
48
+ img = Image.open(BytesIO(uploaded_file.read()), mode='r').convert('RGB')
49
+
50
+ cols[0].image(img, use_column_width=True)
51
+
52
+ # Model selection
53
+ st.sidebar.title("Setup")
54
+ tv_model = st.sidebar.selectbox("Classification model", TV_MODELS)
55
+ default_layer = ""
56
+ if tv_model is not None:
57
+ with st.spinner('Loading model...'):
58
+ model = models.__dict__[tv_model](pretrained=True).eval()
59
+ default_layer = cams.utils.locate_candidate_layer(model, (3, 224, 224))
60
+
61
+ target_layer = st.sidebar.text_input("Target layer", default_layer)
62
+ cam_method = st.sidebar.selectbox("CAM method", CAM_METHODS)
63
+ if cam_method is not None:
64
+ cam_extractor = cams.__dict__[cam_method](
65
+ model,
66
+ target_layer=target_layer if len(target_layer) > 0 else None
67
+ )
68
+
69
+ class_choices = [f"{idx + 1} - {class_name}" for idx, class_name in enumerate(LABEL_MAP)]
70
+ class_selection = st.sidebar.selectbox("Class selection", ["Predicted class (argmax)"] + class_choices)
71
+
72
+ # For newline
73
+ st.sidebar.write('\n')
74
+
75
+ if st.sidebar.button("Compute CAM"):
76
+
77
+ if uploaded_file is None:
78
+ st.sidebar.error("Please upload an image first")
79
+
80
+ else:
81
+ with st.spinner('Analyzing...'):
82
+
83
+ # Preprocess image
84
+ img_tensor = normalize(to_tensor(resize(img, (224, 224))), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
85
+
86
+ # Forward the image to the model
87
+ out = model(img_tensor.unsqueeze(0))
88
+ # Select the target class
89
+ if class_selection == "Predicted class (argmax)":
90
+ class_idx = out.squeeze(0).argmax().item()
91
+ else:
92
+ class_idx = LABEL_MAP.index(class_selection.rpartition(" - ")[-1])
93
+ # Retrieve the CAM
94
+ activation_map = cam_extractor(class_idx, out)[0]
95
+ # Plot the raw heatmap
96
+ fig, ax = plt.subplots()
97
+ ax.imshow(activation_map.numpy())
98
+ ax.axis('off')
99
+ cols[1].pyplot(fig)
100
+
101
+ # Overlayed CAM
102
+ fig, ax = plt.subplots()
103
+ result = overlay_mask(img, to_pil_image(activation_map, mode='F'), alpha=0.5)
104
+ ax.imshow(result)
105
+ ax.axis('off')
106
+ cols[-1].pyplot(fig)
107
+
108
+
109
+ if __name__ == '__main__':
110
+ main()