"""Streamlit page: choose an image and concepts, then test the concepts."""

import streamlit as st
import torch

from app_lib.test import get_testing_config, load_precomputed_results, test
from app_lib.user_input import (
    get_advanced_settings,
    get_class_name,
    get_concepts,
    get_image,
    get_model_name,
)
from app_lib.viz import viz_results


def _rerun():
    """Restart the Streamlit script run.

    Uses ``st.rerun`` when the installed Streamlit provides it and falls back
    to the older ``st.experimental_rerun`` otherwise, so the page keeps working
    on releases where the experimental name has been removed.
    """
    rerun = getattr(st, "rerun", None)
    if rerun is None:
        rerun = st.experimental_rerun
    rerun()


def _disable():
    """Mark the page's buttons as disabled.

    Registered as the ``on_click`` callback of the "Test Concepts" button;
    both the "Change Image" and "Test Concepts" buttons read
    ``st.session_state.disabled`` to decide whether they are clickable.
    """
    st.session_state.disabled = True


def _toggle_sidebar(button):
    """Request an expanded sidebar and rerun the app if ``button`` was pressed.

    Args:
        button: Return value of ``st.button`` (truthy on the run in which the
            button was clicked).
    """
    if not button:
        return
    # NOTE(review): presumably the app entry point reads ``sidebar_state``
    # (e.g. for the page config) on the next run — confirm.
    st.session_state.sidebar_state = "expanded"
    _rerun()


def _preload_results(image_name):
    """Load precomputed results for ``image_name`` unless a test was run on it.

    Selecting a different image resets the ``tested`` flag, so results from a
    user-run test are only kept while the same image stays selected.

    Args:
        image_name: Name of the currently selected image, or ``None``.
    """
    if image_name != st.session_state.image_name:
        st.session_state.image_name = image_name
        st.session_state.tested = False

    if st.session_state.image_name is not None and not st.session_state.tested:
        st.session_state.results = load_precomputed_results(image_name)


def demo(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """Render the demo page and run the concept test when requested.

    The left column collects the image, model, class, concepts and advanced
    settings; the right column renders results via ``viz_results``. When
    "Test Concepts" is pressed the test runs, its output is stored in
    ``st.session_state.results`` and the app is rerun. Otherwise precomputed
    results for the selected image are loaded.

    Args:
        device: Torch device forwarded to ``test``. The default (CUDA if
            available, else CPU) is evaluated once, when the module is imported.

    NOTE(review): assumes ``st.session_state`` already holds ``disabled``,
    ``image_name``, ``tested`` and ``results`` — presumably initialized by the
    app entry point; confirm.
    """
    columns = st.columns([0.40, 0.60])

    with columns[0]:
        st.header("Choose Image and Concepts")

        image_col, concepts_col = st.columns(2)

        with image_col:
            image_name, image = get_image()
            st.image(image, use_column_width=True)

            change_image_button = st.button(
                "Change Image",
                use_container_width=False,
                disabled=st.session_state.disabled,
            )
            _toggle_sidebar(change_image_button)

        with concepts_col:
            model_name = get_model_name()
            class_name, class_ready, class_error = get_class_name(image_name)
            concepts, concepts_ready, concepts_error = get_concepts(image_name)

            ready = class_ready and concepts_ready

            # One bullet line per reported problem, shown in a single box.
            errors = [e for e in (class_error, concepts_error) if e is not None]
            if errors:
                st.error("".join(f"- {e}\n" for e in errors))

        with st.container():
            (
                significance_level,
                tau_max,
                r,
                cardinality,
                dataset_name,
            ) = get_advanced_settings(concepts, concepts_ready)

            test_button = st.button(
                "Test Concepts",
                use_container_width=True,
                on_click=_disable,
                disabled=st.session_state.disabled or not ready,
            )

    if test_button:
        # Drop the previous results and redraw the right column before testing
        # (presumably viz_results renders from st.session_state.results — confirm).
        st.session_state.results = None
        with columns[1]:
            viz_results()

        testing_config = get_testing_config(
            significance_level=significance_level, tau_max=tau_max, r=r
        )

        try:
            with columns[0]:
                results = test(
                    testing_config,
                    image,
                    class_name,
                    concepts,
                    cardinality,
                    dataset_name,
                    model_name,
                    device=device,
                )
        finally:
            # _disable() switched the buttons off when "Test Concepts" was
            # clicked. Re-enable them even if the test raises; otherwise the
            # session would stay locked with both buttons disabled.
            st.session_state.disabled = False

        st.session_state.tested = True
        st.session_state.results = results
        _rerun()
    else:
        _preload_results(image_name)

    with columns[1]:
        viz_results()