Don't run activations viz with ResNet model, to save memory
app.py CHANGED
```diff
@@ -27,7 +27,7 @@ from CLIP_Explainability.vit_cam import (
 
 from pytorch_grad_cam.grad_cam import GradCAM
 
-RUN_LITE =
+RUN_LITE = True # Load models for CAM viz for M-CLIP and J-CLIP only
 
 MAX_IMG_WIDTH = 500
 MAX_IMG_HEIGHT = 800
```
```diff
@@ -110,6 +110,10 @@ def clip_search(search_query):
 
 
 def string_search():
+    st.session_state.disable_uploader = (
+        RUN_LITE and st.session_state.active_model == "Legacy (multilingual ResNet)"
+    )
+
     if "search_field_value" in st.session_state:
         clip_search(st.session_state.search_field_value)
 
```
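In effect, explainability is gated by a single boolean that is recomputed on every search: the upload widget (and, further down, the per-image "Explain this" buttons) are withheld only when lite mode is on and the legacy ResNet model is active. A minimal standalone sketch of the pattern (the `selectbox` is illustrative; app.py sets `active_model` through its own controls):

```python
import streamlit as st

RUN_LITE = True  # mirrors the constant at the top of app.py
LEGACY = "Legacy (multilingual ResNet)"

# Illustrative model picker; the real app sets active_model elsewhere.
st.session_state.active_model = st.selectbox("Model", ["M-CLIP", "J-CLIP", LEGACY])

# The gate: explainability widgets are disabled only when running
# lite AND the legacy ResNet model is selected.
st.session_state.disable_uploader = (
    RUN_LITE and st.session_state.active_model == LEGACY
)
st.write("Uploader disabled:", st.session_state.disable_uploader)
```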
```diff
@@ -179,10 +183,9 @@ def init():
     ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
     ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"
 
-
-
-
-    )
+    st.session_state.ja_image_model, st.session_state.ja_image_preprocess = load(
+        ja_model_path, device=device, jit=False
+    )
 
     st.session_state.ja_model = AutoModel.from_pretrained(
         ja_model_name, trust_remote_code=True
```
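The `load` call pulls the multi-gigabyte ViT-H/14 image tower into session state during `init()`, so the cost is paid once per session rather than on every rerun. An alternative worth noting is `st.cache_resource`, which shares one copy across all sessions in the process; a sketch under the assumption that the checkpoint is a plain `torch.load`-able file (the `cached_load` wrapper is hypothetical, standing in for the app's `load(path, device=device, jit=False)` helper):

```python
import streamlit as st
import torch

@st.cache_resource  # one copy per server process, reused across sessions
def cached_load(ckpt_path: str, device: str):
    # Hypothetical stand-in for the app's CLIP-style loader, which
    # also builds the image encoder and its preprocess transform.
    return torch.load(ckpt_path, map_location=device)

device = "cuda" if torch.cuda.is_available() else "cpu"
st.session_state.ja_image_weights = cached_load(
    "./models/ViT-H-14-laion2B-s32B-b79K.bin", device
)
```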
```diff
@@ -216,6 +219,9 @@ def init():
     st.session_state.search_image_ids = []
     st.session_state.search_image_scores = {}
     st.session_state.text_table_df = None
+    st.session_state.disable_uploader = (
+        RUN_LITE and st.session_state.active_model == "Legacy (multilingual ResNet)"
+    )
 
     with st.spinner("Loading models and data, please wait..."):
         load_image_features()
```
```diff
@@ -430,7 +436,7 @@ def visualize_gradcam(image):
 
     header_cols = st.columns([80, 20], vertical_alignment="bottom")
     with header_cols[0]:
-        st.title("Image + query
+        st.title("Image + query activation gradients")
     with header_cols[1]:
         if st.button("Close"):
             st.rerun()
```
```diff
@@ -457,6 +463,8 @@ def visualize_gradcam(image):
             st.session_state.search_field_value
         )
 
+    st.image(image)
+
     with st.spinner("Calculating..."):
         # info_text = st.text("Calculating activation regions...")
 
```
```diff
@@ -743,6 +751,7 @@ with controls[3]:
         key="uploaded_image",
         label_visibility="collapsed",
         on_change=vis_uploaded_image,
+        disabled=st.session_state.disable_uploader,
     )
 
 
```
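Disabling (rather than hiding) the uploader keeps the controls row stable across model switches. A self-contained sketch of the wiring (the callback body is illustrative; in app.py it kicks off the CAM visualization):

```python
import streamlit as st

def vis_uploaded_image():
    # Illustrative callback; app.py runs the CAM visualization here.
    st.session_state.pending_vis = st.session_state.uploaded_image

# Default the gate so the sketch runs standalone.
st.session_state.setdefault("disable_uploader", False)

st.file_uploader(
    "Image to explain",
    type=["png", "jpg", "jpeg"],
    key="uploaded_image",
    label_visibility="collapsed",
    on_change=vis_uploaded_image,
    disabled=st.session_state.disable_uploader,  # greyed out in lite + ResNet mode
)
```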
```diff
@@ -777,7 +786,9 @@ for image_id in batch:
             <div>""",
             unsafe_allow_html=True,
         )
-        if not
+        if not (
+            RUN_LITE and st.session_state.active_model == "Legacy (multilingual ResNet)"
+        ):
             st.button(
                 "Explain this",
                 on_click=vis_known_image,
```
```diff
@@ -785,4 +796,6 @@ for image_id in batch:
                 use_container_width=True,
                 key=image_id,
             )
+        else:
+            st.empty()
         col = (col + 1) % row_size
```
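The `else: st.empty()` branch drops a placeholder into the grid cell where the button would have been, so result cards keep a uniform layout in lite mode. A reduced sketch of the loop (the image IDs and three-column grid are illustrative):

```python
import streamlit as st

RUN_LITE = True
LEGACY = "Legacy (multilingual ResNet)"
batch = ["img-001", "img-002", "img-003"]  # illustrative image IDs

row_size = 3
cols = st.columns(row_size)
col = 0
for image_id in batch:
    with cols[col]:
        st.write(image_id)
        if not (RUN_LITE and st.session_state.get("active_model") == LEGACY):
            # Unique key per image so Streamlit tracks each button separately.
            st.button("Explain this", key=image_id)
        else:
            st.empty()  # placeholder keeps the cell from collapsing
    col = (col + 1) % row_size
```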