Francesco commited on
Commit
2e9772b
·
1 Parent(s): 13497c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -119
app.py CHANGED
@@ -1,24 +1,6 @@
1
  import streamlit as st
2
 
3
- import requests
4
- from PIL import Image
5
- from io import BytesIO
6
- from transformers import (
7
- AutoModelForImageClassification,
8
- AutoFeatureExtractor,
9
- AutoConfig,
10
- )
11
- from torchcam.methods import GradCAM
12
- from torchcam.utils import overlay_mask
13
- import matplotlib.pyplot as plt
14
- from torchvision.transforms.functional import to_pil_image
15
- from torchcam import methods
16
-
17
- # TODO I have an error with those
18
- # CAM_METHODS = ["CAM", "GradCAM", "GradCAMpp", "SmoothGradCAMpp", "ScoreCAM", "SSCAM", "ISCAM", "XGradCAM", "LayerCAM"]
19
- CAM_METHODS = ["CAM", "GradCAM", "GradCAMpp", "LayerCAM"]
20
-
21
- SUPPORTED_MODELS = ["convnext"]
22
 
23
 
24
  def main():
@@ -26,105 +8,7 @@ def main():
26
  st.set_page_config(layout="wide")
27
 
28
  # Designing the interface
29
- st.title("TorchCAM 📸 and Transformers 🤗")
30
- st.header("Class activation explorer")
31
- # For newline
32
- st.write("\n")
33
- st.write("`torch-cam`: https://github.com/frgfm/torch-cam")
34
- st.write("`transformers`: https://github.com/huggingface/transformers")
35
- st.write("Upload an image, select your CAM method and hit the Compute Cam button!")
36
-
37
- # For newline
38
- st.write("\n")
39
- # Set the columns
40
- cols = st.columns((1, 1))
41
- cols[0].header("Input image")
42
- cols[1].header("Overlayed CAM")
43
- # Sidebar
44
- # File selection
45
- st.sidebar.title("Input selection")
46
- # Disabling warning
47
- st.set_option("deprecation.showfileUploaderEncoding", False)
48
- # Choose your own image
49
- uploaded_file = st.sidebar.file_uploader(
50
- "Upload files", type=["png", "jpeg", "jpg"]
51
- )
52
- if uploaded_file is not None:
53
- img = Image.open(BytesIO(uploaded_file.read()), mode="r").convert("RGB")
54
- else:
55
- r = requests.get(
56
- "https://i.insider.com/5df126b679d7570ad2044f3e?width=700&format=jpeg&auto=webp"
57
- )
58
- img = Image.open(BytesIO(r.content))
59
- cols[0].image(img, use_column_width=True)
60
-
61
- model_name = st.sidebar.text_input("Model name", "facebook/convnext-tiny-224")
62
-
63
- if model_name is not None:
64
- with st.spinner("Loading model..."):
65
- config = AutoConfig.from_pretrained(model_name)
66
- model_type = config.model_type
67
- if model_type not in SUPPORTED_MODELS:
68
- st.warning(
69
- f"{model_type} not in supported models: {','.join(SUPPORTED_MODELS)}"
70
- )
71
- else:
72
- feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
73
- model = AutoModelForImageClassification.from_pretrained(model_name)
74
-
75
- cam_method = st.sidebar.selectbox("CAM method", CAM_METHODS)
76
- if cam_method is not None:
77
- cam_extractor = methods.__dict__[cam_method](
78
- model, target_layer=model.convnext.encoder.stages[-1].layers[-1]
79
- )
80
-
81
- # label choices
82
- class_choices = [
83
- f"{idx + 1} - {class_name}" for idx, class_name in model.config.id2label.items()
84
- ]
85
- class_selection = st.sidebar.selectbox(
86
- "Class selection", ["Predicted class (argmax)"] + class_choices
87
- )
88
- # for newline
89
- st.sidebar.write("\n")
90
-
91
- if st.sidebar.button("Compute CAM"):
92
- # compute cam
93
- if img is None:
94
- st.sidebar.error("Please upload an image first")
95
- else:
96
- with st.spinner("Analyzing..."):
97
- # Set your CAM extractor
98
- cam_extractor = GradCAM(
99
- model, target_layer=model.convnext.encoder.stages[-1].layers[-1]
100
- )
101
- inputs = feature_extractor(img, return_tensors="pt")
102
- logits = model(**inputs).logits
103
- # select the target class
104
- if class_selection == "Predicted class (argmax)":
105
- class_idx = logits.squeeze(0).argmax().item()
106
- else:
107
- class_idx = model.config.label2id[
108
- class_selection.rpartition(" - ")[-1]
109
- ]
110
- print(class_idx)
111
- # run the cam extractor
112
- cams = cam_extractor(class_idx, logits)
113
- cam = cams[0] if len(cams) == 1 else cam_extractor.fuse_cams(cams)
114
- # resize + overlay
115
- result = overlay_mask(img, to_pil_image(cam, mode="F"), alpha=0.5)
116
- # display it
117
- fig, ax = plt.subplots()
118
- result = overlay_mask(img, to_pil_image(cam, mode="F"), alpha=0.5)
119
- ax.imshow(result)
120
- ax.axis("off")
121
- cols[1].pyplot(fig)
122
- if class_selection == "Predicted class (argmax)":
123
- # show the predicted class
124
- st.markdown(
125
- f"<p style='text align: center'> Predicted class is {config.id2label[class_idx]}</p>",
126
- unsafe_allow_html=True,
127
- )
128
-
129
 
130
  main()
 
1
  import streamlit as st
2
 
3
+ import sys
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
 
6
  def main():
 
8
  st.set_page_config(layout="wide")
9
 
10
  # Designing the interface
11
+ st.title("sys.version)")
12
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  main()