"""Streamlit page layout for choosing an image/concepts and running concept tests."""

import streamlit as st
import torch
from app_lib.test import get_testing_config, test
from app_lib.user_input import (
    get_advanced_settings,
    get_class_name,
    get_concepts,
    get_image,
    get_model_name,
)
from app_lib.viz import viz_results


def _disable() -> None:
    """Mark the page as busy by setting ``st.session_state.disabled``.

    Registered below as the ``on_click`` callback of the "Test Concepts"
    button. Streamlit runs ``on_click`` callbacks before it reruns the
    script, so on the run that performs the test both buttons are rendered
    with ``disabled=True``.
    """
    # NOTE(review): nothing in this block resets the flag to False;
    # presumably `test` or the app entry point does so -- confirm.
    st.session_state.disabled = True


def main(
    device: torch.device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    ),
) -> None:
    """Render the page and, when requested, run the concept test.

    The page has two columns (40% / 60%). The right column shows the
    "Results" header followed by ``viz_results()``. The left column shows
    the chosen image next to the model/class/concept inputs, the advanced
    settings, and the "Test Concepts" button.

    Args:
        device: Device passed through to ``test``. The default is CUDA when
            ``torch.cuda.is_available()`` returns True, otherwise CPU. Python
            evaluates this default once, when the function is defined (at
            import time), not on every call.

    NOTE(review): ``st.session_state.disabled`` is read here but never
    created in this block. Attribute access on a missing session-state key
    raises, so it is presumably initialized by the app entry point --
    confirm.
    """
    columns = st.columns([0.40, 0.60])

    with columns[1]:
        st.header("Results")
        viz_results()

    with columns[0]:
        st.header("Choose Image and Concepts")

        image_col, concepts_col = st.columns(2)

        with image_col:
            image = get_image()
            st.image(image, use_column_width=True)

            change_image_button = st.button(
                "Change Image",
                use_container_width=False,
                disabled=st.session_state.disabled,
            )
            if change_image_button:
                # Presumably read by the page config to open the sidebar,
                # where the image is picked -- confirm with the app entry point.
                st.session_state.sidebar_state = "expanded"
                # NOTE(review): `st.experimental_rerun` is deprecated in newer
                # Streamlit releases in favour of `st.rerun`.
                st.experimental_rerun()

        with concepts_col:
            model_name = get_model_name()
            # Each getter returns (value, ready flag, error message or None).
            class_name, class_ready, class_error = get_class_name()
            concepts, concepts_ready, concepts_error = get_concepts()

            # Testing is only allowed when both the class and the concepts
            # are ready.
            ready = class_ready and concepts_ready

            # Collect any input errors into one markdown bullet list.
            error_message = ""
            if class_error is not None:
                error_message += f"- {class_error}\n"
            if concepts_error is not None:
                error_message += f"- {concepts_error}\n"
            if error_message:
                st.error(error_message)

        with st.container():
            (
                significance_level,
                tau_max,
                r,
                cardinality,
                dataset_name,
            ) = get_advanced_settings(concepts, concepts_ready)

            test_button = st.button(
                "Test Concepts",
                use_container_width=True,
                on_click=_disable,
                # Disabled while a test is marked in progress or the inputs
                # are not ready.
                disabled=st.session_state.disabled or not ready,
            )

            if test_button:
                # Drop results from any previous run before starting a new test.
                st.session_state.results = None

                testing_config = get_testing_config(
                    significance_level=significance_level, tau_max=tau_max, r=r
                )
                test(
                    testing_config,
                    image,
                    class_name,
                    concepts,
                    cardinality,
                    dataset_name,
                    model_name,
                    device,
                )