Spaces:
Runtime error
Runtime error
File size: 4,843 Bytes
0fb8d2f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import streamlit as st
import requests
from PIL import Image
from io import BytesIO
from transformers import (
AutoModelForImageClassification,
AutoFeatureExtractor,
AutoConfig,
)
from torchcam.methods import GradCAM
from torchcam.utils import overlay_mask
import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image
from torchcam import methods
# torchcam CAM extractors offered in the "CAM method" sidebar selector.
# TODO: torchcam also provides SmoothGradCAMpp, ScoreCAM, SSCAM, ISCAM and
# XGradCAM, but they raised errors in this app; re-enable them once fixed.
CAM_METHODS = [
    "CAM",
    "GradCAM",
    "GradCAMpp",
    "LayerCAM",
]
# Accepted values of the Hugging Face `config.model_type` for the chosen model.
SUPPORTED_MODELS = [
    "convnext",
]
# Sample image shown when the user has not uploaded anything.
_DEFAULT_IMAGE_URL = (
    "https://i.insider.com/5df126b679d7570ad2044f3e?width=700&format=jpeg&auto=webp"
)
# Sidebar entry meaning "explain the model's top-scoring class".
_ARGMAX_CHOICE = "Predicted class (argmax)"


def _load_input_image(uploaded_file):
    """Return the image to analyse as an RGB PIL image.

    Args:
        uploaded_file: the object returned by ``st.sidebar.file_uploader``,
            or ``None`` when nothing was uploaded.

    Returns:
        The uploaded image, or else the sample image downloaded from
        ``_DEFAULT_IMAGE_URL``; converted to RGB in both cases so the feature
        extractor and ``overlay_mask`` always receive the same mode.

    Raises:
        requests.HTTPError: if downloading the sample image returns an error status.
    """
    if uploaded_file is not None:
        return Image.open(BytesIO(uploaded_file.read()), mode="r").convert("RGB")
    # Bounded timeout so a stalled download cannot hang the app indefinitely.
    response = requests.get(_DEFAULT_IMAGE_URL, timeout=30)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")


def _format_class_choice(idx: int, class_name: str) -> str:
    """Return the sidebar label for class ``idx`` (shown 1-based to the user)."""
    return f"{idx + 1} - {class_name}"


def _class_index_from_choice(choice: str) -> int:
    """Invert ``_format_class_choice``: recover the 0-based class index.

    The index is parsed from the numeric prefix instead of looking the name up
    in ``label2id``: class names need not be unique (a name -> id mapping keeps
    only one id per duplicated name) and a name may itself contain " - ".
    """
    return int(choice.partition(" - ")[0]) - 1


def main():
    """Render the TorchCAM / Transformers class-activation explorer page.

    Shows the input image (uploaded or a default sample), loads the requested
    Hugging Face image-classification model, and when "Compute CAM" is pressed
    overlays the class activation map of the selected class on the image.
    """
    # Wide mode; set_page_config must be the first Streamlit call of the run.
    st.set_page_config(layout="wide")

    # Designing the interface
    st.title("TorchCAM 📸 and Transformers 🤗")
    st.header("Class activation explorer")
    st.write("\n")  # For newline
    st.write("`torch-cam`: https://github.com/frgfm/torch-cam")
    st.write("`transformers`: https://github.com/huggingface/transformers")
    st.write("Upload an image, select your CAM method and hit the Compute Cam button!")
    st.write("\n")  # For newline

    cols = st.columns((1, 1))
    cols[0].header("Input image")
    cols[1].header("Overlayed CAM")

    # Sidebar: file selection.
    # NOTE: the previous st.set_option("deprecation.showfileUploaderEncoding", False)
    # was removed; that option no longer exists in recent Streamlit releases,
    # where setting it raises an exception at startup.
    st.sidebar.title("Input selection")
    uploaded_file = st.sidebar.file_uploader(
        "Upload files", type=["png", "jpeg", "jpg"]
    )
    img = _load_input_image(uploaded_file)
    cols[0].image(img, use_column_width=True)

    model_name = st.sidebar.text_input("Model name", "facebook/convnext-tiny-224")
    # text_input yields "" (never None) when cleared, so test for emptiness.
    if not model_name:
        return

    # NOTE(review): the model is reloaded on every Streamlit rerun; consider
    # caching it (e.g. st.cache_resource, if the deployed Streamlit has it).
    with st.spinner("Loading model..."):
        config = AutoConfig.from_pretrained(model_name)
        model_type = config.model_type
        if model_type not in SUPPORTED_MODELS:
            st.warning(
                f"{model_type} not in supported models: {','.join(SUPPORTED_MODELS)}"
            )
            return
        feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
        model = AutoModelForImageClassification.from_pretrained(model_name)

    cam_method = st.sidebar.selectbox("CAM method", CAM_METHODS)
    if cam_method is None:
        return
    # Instantiate the extractor the user picked (previously it was silently
    # replaced by a hard-coded GradCAM), hooked on the last ConvNeXt layer.
    target_layer = model.convnext.encoder.stages[-1].layers[-1]
    cam_extractor = getattr(methods, cam_method)(model, target_layer=target_layer)

    # Label choices
    class_choices = [
        _format_class_choice(idx, class_name)
        for idx, class_name in model.config.id2label.items()
    ]
    class_selection = st.sidebar.selectbox(
        "Class selection", [_ARGMAX_CHOICE] + class_choices
    )
    st.sidebar.write("\n")  # For newline

    if not st.sidebar.button("Compute CAM"):
        return

    with st.spinner("Analyzing..."):
        inputs = feature_extractor(img, return_tensors="pt")
        logits = model(**inputs).logits

        # Select the target class.
        if class_selection == _ARGMAX_CHOICE:
            class_idx = logits.squeeze(0).argmax().item()
        else:
            class_idx = _class_index_from_choice(class_selection)

        # Run the CAM extractor; fuse only if several target layers were hooked.
        cams = cam_extractor(class_idx, logits)
        cam = cams[0] if len(cams) == 1 else cam_extractor.fuse_cams(cams)

        # Resize + overlay, then display it.
        result = overlay_mask(img, to_pil_image(cam, mode="F"), alpha=0.5)
        fig, ax = plt.subplots()
        ax.imshow(result)
        ax.axis("off")
        cols[1].pyplot(fig)
        # Free the figure; otherwise every rerun leaks one matplotlib figure.
        plt.close(fig)

        if class_selection == _ARGMAX_CHOICE:
            # Show the predicted class.
            st.markdown(
                f"<p style='text-align: center'> Predicted class is "
                f"{model.config.id2label[class_idx]}</p>",
                unsafe_allow_html=True,
            )


# `streamlit run` executes the script as __main__, so the app still renders,
# while importing this module no longer has UI side effects.
if __name__ == "__main__":
    main()
|