IBYDMT / app_lib /main.py
jacopoteneggi's picture
Update
8e05eba verified
raw
history blame
3.2 kB
import streamlit as st
import torch
from app_lib.test import get_testing_config, load_precomputed_results, test
from app_lib.user_input import (
get_advanced_settings,
get_class_name,
get_concepts,
get_image,
get_model_name,
)
from app_lib.viz import viz_results
def _disable():
    """Flag the interactive controls as disabled for the upcoming render.

    Registered as the ``on_click`` callback of the "Test Concepts" button, so
    the buttons that read ``st.session_state.disabled`` are rendered disabled
    while the test runs.
    """
    st.session_state["disabled"] = True
def _toggle_sidebar(button):
    """Expand the sidebar and rerun the script when ``button`` was clicked.

    Args:
        button: Value returned by ``st.button``. It is truthy only on the run
            in which the button was clicked; falsy values are a no-op.
    """
    if not button:
        return
    # Presumably read by ``st.set_page_config`` elsewhere in the app to set
    # the initial sidebar state -- confirm against the caller.
    st.session_state.sidebar_state = "expanded"
    # ``st.experimental_rerun`` is deprecated in favour of ``st.rerun``
    # (Streamlit >= 1.27) and is not available in newer releases. Prefer the
    # stable API and fall back to the old name on older Streamlit versions.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()
def _preload_results(image_name):
    """Show precomputed results for ``image_name`` unless it was tested already.

    Selecting a different image resets the ``tested`` flag, so the new
    image's precomputed results (from ``load_precomputed_results``) are loaded
    into ``st.session_state.results``.
    """
    state = st.session_state
    if image_name != state.image_name:
        state.image_name = image_name
        state.tested = False

    # Nothing to load without an image, or when a live test result is shown.
    if state.image_name is None or state.tested:
        return
    state.results = load_precomputed_results(image_name)
def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """Render the image/concept selection UI and run or display concept tests.

    The left column lets the user choose an image, model, class, concepts and
    advanced testing settings. The right column shows results via
    ``viz_results``.

    Clicking "Test Concepts" runs ``test`` with the chosen settings, stores
    the results in session state and reruns the script. Otherwise, the
    precomputed results for the current image are loaded and displayed.

    Args:
        device: Torch device forwarded to ``test``. The default is evaluated
            once, when this function is defined (module import time).
    """
    columns = st.columns([0.40, 0.60])

    with columns[0]:
        st.header("Choose Image and Concepts")
        image_col, concepts_col = st.columns(2)

        with image_col:
            image_name, image = get_image()
            st.image(image, use_column_width=True)
            change_image_button = st.button(
                "Change Image",
                use_container_width=False,
                disabled=st.session_state.disabled,
            )
            _toggle_sidebar(change_image_button)

        with concepts_col:
            model_name = get_model_name()
            class_name, class_ready, class_error = get_class_name(image_name)
            concepts, concepts_ready, concepts_error = get_concepts(image_name)
            ready = class_ready and concepts_ready

            # One markdown bullet per reported input problem.
            error_message = "".join(
                f"- {error}\n"
                for error in (class_error, concepts_error)
                if error is not None
            )
            if error_message:
                st.error(error_message)

            with st.container():
                (
                    significance_level,
                    tau_max,
                    r,
                    cardinality,
                    dataset_name,
                ) = get_advanced_settings(concepts, concepts_ready)
                test_button = st.button(
                    "Test Concepts",
                    use_container_width=True,
                    on_click=_disable,
                    disabled=st.session_state.disabled or not ready,
                )

    if test_button:
        # Clear any previous results so the results column is refreshed
        # while the new test runs.
        st.session_state.results = None
        with columns[1]:
            viz_results()

        try:
            testing_config = get_testing_config(
                significance_level=significance_level, tau_max=tau_max, r=r
            )
            with columns[0]:
                results = test(
                    testing_config,
                    image,
                    class_name,
                    concepts,
                    cardinality,
                    dataset_name,
                    model_name,
                    device=device,
                )
        finally:
            # ``_disable`` locked the controls when the button was clicked.
            # Release them even if testing fails; otherwise every button
            # stays disabled for the rest of the session.
            st.session_state.disabled = False

        st.session_state.tested = True
        st.session_state.results = results
        # Prefer ``st.rerun`` (Streamlit >= 1.27). Fall back to the deprecated
        # ``st.experimental_rerun`` on older versions.
        rerun = getattr(st, "rerun", None) or st.experimental_rerun
        rerun()
    else:
        _preload_results(image_name)
        with columns[1]:
            viz_results()