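# Streamlit demo: explore class activation maps (CAMs) for Hugging Face image
# classification models with torchcam.
# To launch locally (assuming this file is saved as app.py): `streamlit run app.py`
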
from io import BytesIO

import matplotlib.pyplot as plt
import requests
import streamlit as st
from PIL import Image
from torchcam import methods
from torchcam.utils import overlay_mask
from torchvision.transforms.functional import to_pil_image
from transformers import (
    AutoConfig,
    AutoFeatureExtractor,
    AutoModelForImageClassification,
)

# TODO: the remaining torchcam methods currently raise errors with this setup, so only a subset is exposed
# CAM_METHODS = ["CAM", "GradCAM", "GradCAMpp", "SmoothGradCAMpp", "ScoreCAM", "SSCAM", "ISCAM", "XGradCAM", "LayerCAM"]
CAM_METHODS = ["CAM", "GradCAM", "GradCAMpp", "LayerCAM"]

SUPPORTED_MODELS = ["convnext"]
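# NOTE: only ConvNeXt checkpoints are handled for now, since the CAM target layer
# is resolved through the ConvNeXt-specific `model.convnext.encoder.stages` attribute path.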


def main():
    # Wide mode
    st.set_page_config(layout="wide")

    # Designing the interface
    st.title("TorchCAM 📸 and Transformers 🤗")
    st.header("Class activation explorer")
    # For newline
    st.write("\n")
    st.write("`torch-cam`: https://github.com/frgfm/torch-cam")
    st.write("`transformers`: https://github.com/huggingface/transformers")
    st.write("Upload an image, select your CAM method and hit the Compute Cam button!")

    # For newline
    st.write("\n")
    # Set the columns
    cols = st.columns((1, 1))
    cols[0].header("Input image")
    cols[1].header("Overlaid CAM")
    # Sidebar
    # File selection
    st.sidebar.title("Input selection")
    # Disable the file uploader encoding deprecation warning (only needed on older Streamlit versions)
    st.set_option("deprecation.showfileUploaderEncoding", False)
    # Choose your own image
    uploaded_file = st.sidebar.file_uploader(
        "Upload files", type=["png", "jpeg", "jpg"]
    )
    if uploaded_file is not None:
        img = Image.open(BytesIO(uploaded_file.read()), mode="r").convert("RGB")
    else:
        # Fall back to a default image
        r = requests.get(
            "https://i.insider.com/5df126b679d7570ad2044f3e?width=700&format=jpeg&auto=webp",
            timeout=10,
        )
        img = Image.open(BytesIO(r.content)).convert("RGB")
    cols[0].image(img, use_column_width=True)

    model_name = st.sidebar.text_input("Model name", "facebook/convnext-tiny-224")

    if model_name is not None:
        with st.spinner("Loading model..."):
            config = AutoConfig.from_pretrained(model_name)
            model_type = config.model_type
            if model_type not in SUPPORTED_MODELS:
                st.warning(
                    f"{model_type} is not among the supported models: {', '.join(SUPPORTED_MODELS)}"
                )
                # Stop the script run here: the code below relies on a loaded model
                st.stop()
            feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
            model = AutoModelForImageClassification.from_pretrained(model_name)

    cam_method = st.sidebar.selectbox("CAM method", CAM_METHODS)
    if cam_method is not None:
        # Instantiate the selected torchcam extractor; it registers hooks on the target
        # layer so activations (and gradients for the Grad* methods) can be retrieved
        cam_extractor = methods.__dict__[cam_method](
            model, target_layer=model.convnext.encoder.stages[-1].layers[-1]
        )

    # label choices
    class_choices = [
        f"{idx + 1} - {class_name}" for idx, class_name in model.config.id2label.items()
    ]
    class_selection = st.sidebar.selectbox(
        "Class selection", ["Predicted class (argmax)"] + class_choices
    )
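    # NOTE: the "idx + 1" prefix above is only for display; when computing the CAM,
    # the label name after " - " is mapped back to its index through `model.config.label2id`.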
    # for newline
    st.sidebar.write("\n")

    if st.sidebar.button("Compute CAM"):
        if img is None:
            st.sidebar.error("Please upload an image first")
        else:
            with st.spinner("Analyzing..."):
                # Preprocess the image and run the forward pass
                inputs = feature_extractor(img, return_tensors="pt")
                logits = model(**inputs).logits
                # Select the target class
                if class_selection == "Predicted class (argmax)":
                    class_idx = logits.squeeze(0).argmax().item()
                else:
                    class_idx = model.config.label2id[
                        class_selection.rpartition(" - ")[-1]
                    ]
                # Retrieve the CAM for the selected class with the chosen extractor
                cams = cam_extractor(class_idx, logits)
                cam = cams[0] if len(cams) == 1 else cam_extractor.fuse_cams(cams)
                # Resize the CAM and overlay it on the input image
                result = overlay_mask(img, to_pil_image(cam, mode="F"), alpha=0.5)
                # Display the result
                fig, ax = plt.subplots()
                ax.imshow(result)
                ax.axis("off")
                cols[1].pyplot(fig)
                if class_selection == "Predicted class (argmax)":
                    # Show the predicted class
                    st.markdown(
                        f"<p style='text-align: center'>Predicted class is {config.id2label[class_idx]}</p>",
                        unsafe_allow_html=True,
                    )


if __name__ == "__main__":
    main()