Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
from app_lib.test import get_testing_config, load_precomputed_results, test | |
from app_lib.user_input import ( | |
get_advanced_settings, | |
get_class_name, | |
get_concepts, | |
get_image, | |
get_model_name, | |
) | |
from app_lib.viz import viz_results | |
def _disable():
    """Lock the demo's buttons by raising ``st.session_state.disabled``.

    Used as the ``on_click`` callback of the "Test Concepts" button, so it runs
    before the script reruns. Both the "Change Image" and "Test Concepts"
    buttons pass this flag as their ``disabled`` argument.
    """
    session = st.session_state
    session.disabled = True
def _toggle_sidebar(button):
    """Request an expanded sidebar when the "Change Image" button was pressed.

    Sets ``st.session_state.sidebar_state`` to ``"expanded"`` and triggers a
    script rerun. The flag is presumably consumed elsewhere (e.g. by
    ``st.set_page_config``); that is not visible here, so confirm. Does nothing
    when *button* is falsy.

    Args:
        button: return value of ``st.button``. Truthy only on the run in which
            the button was clicked.
    """
    if not button:
        return

    st.session_state.sidebar_state = "expanded"

    # ``st.experimental_rerun`` is deprecated (Streamlit >= 1.27 provides
    # ``st.rerun``) and removed in later releases. Use whichever this
    # Streamlit version offers.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()
def _preload_results(image_name):
    """Load precomputed results for *image_name* unless a live test is cached.

    When the selected image differs from ``st.session_state.image_name``, the
    stored name is updated and the ``tested`` flag is cleared. This stops a
    previous test's results from being treated as current for the new image.
    Afterwards, while an image is selected and ``tested`` is unset, the
    precomputed results are loaded into ``st.session_state.results``.

    Args:
        image_name: name of the currently selected image, or ``None``.
    """
    state = st.session_state

    if image_name != state.image_name:
        state.image_name = image_name
        state.tested = False

    should_load = state.image_name is not None and not state.tested
    if should_load:
        state.results = load_precomputed_results(image_name)
def demo(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """Render the concept-testing demo page.

    The left column (40% width) collects the image, model, class name, concepts
    and advanced settings. The right column (60%) shows results through
    ``viz_results``.

    Clicking "Test Concepts" runs ``test`` with the chosen settings, stores its
    output in ``st.session_state.results``, marks the session as tested and
    reruns the script. On every other run, precomputed results for the selected
    image are loaded via ``_preload_results``.

    Args:
        device: torch device forwarded to ``test``. The default is evaluated
            once, when this function is defined: CUDA if available, else CPU.
    """
    columns = st.columns([0.40, 0.60])

    with columns[0]:
        st.header("Choose Image and Concepts")

        image_col, concepts_col = st.columns(2)

        with image_col:
            image_name, image = get_image()
            st.image(image, use_column_width=True)

            change_image_button = st.button(
                "Change Image",
                use_container_width=False,
                disabled=st.session_state.disabled,
            )
            _toggle_sidebar(change_image_button)

        with concepts_col:
            model_name = get_model_name()
            class_name, class_ready, class_error = get_class_name(image_name)
            concepts, concepts_ready, concepts_error = get_concepts(image_name)

        # Testing is allowed only once both the class name and concepts are ready.
        ready = class_ready and concepts_ready

        # Report every input error at once, one bullet per error.
        errors = [err for err in (class_error, concepts_error) if err is not None]
        if errors:
            st.error("".join(f"- {err}\n" for err in errors))

        with st.container():
            (
                significance_level,
                tau_max,
                r,
                cardinality,
                dataset_name,
            ) = get_advanced_settings(concepts, concepts_ready)

            test_button = st.button(
                "Test Concepts",
                use_container_width=True,
                on_click=_disable,
                disabled=st.session_state.disabled or not ready,
            )

    if test_button:
        # Clear previous results before redrawing the right column while the
        # test runs (presumably ``viz_results`` reads
        # ``st.session_state.results`` -- confirm).
        st.session_state.results = None
        with columns[1]:
            viz_results()

        try:
            testing_config = get_testing_config(
                significance_level=significance_level, tau_max=tau_max, r=r
            )
            with columns[0]:
                results = test(
                    testing_config,
                    image,
                    class_name,
                    concepts,
                    cardinality,
                    dataset_name,
                    model_name,
                    device=device,
                )
        finally:
            # ``_disable`` (the button's on_click callback) locked both buttons.
            # Unlock them even if testing raises. Nothing else here resets the
            # flag, so a failure would otherwise leave the UI locked for the
            # rest of the session.
            st.session_state.disabled = False

        st.session_state.tested = True
        st.session_state.results = results

        # ``st.experimental_rerun`` is deprecated (Streamlit >= 1.27 provides
        # ``st.rerun``). Use whichever this Streamlit version offers.
        rerun = getattr(st, "rerun", None) or st.experimental_rerun
        rerun()
    else:
        _preload_results(image_name)
        with columns[1]:
            viz_results()