IBYDMT / app_lib /main.py
jacopoteneggi's picture
Update
5e91161 verified
raw
history blame
2.51 kB
import streamlit as st
import torch
from app_lib.test import get_testing_config, test
from app_lib.user_input import (
get_advanced_settings,
get_class_name,
get_concepts,
get_image,
get_model_name,
)
from app_lib.viz import viz_results
def _disable():
    """Button ``on_click`` callback: flag the UI controls as disabled."""
    st.session_state["disabled"] = True
def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """Render the app page and run the concept test when requested.

    Layout: a 40/60 column split. The right column shows the "Results"
    header followed by ``viz_results()``. The left column gathers the image,
    model name, class name, concepts and advanced settings, reports any
    class/concept input errors, and offers the "Test Concepts" button.

    Pressing "Test Concepts" clears ``st.session_state.results``, builds a
    testing configuration with ``get_testing_config`` and calls ``test``.

    Args:
        device: Forwarded unchanged as the last argument of ``test``. The
            default is computed once, when this function is defined: CUDA
            if ``torch.cuda.is_available()`` is true, otherwise CPU.

    NOTE(review): the block nesting below was reconstructed from a copy of
    this file that had lost its indentation -- confirm it matches the
    intended page layout.
    """
    columns = st.columns([0.40, 0.60])

    with columns[1]:
        st.header("Results")
        viz_results()

    with columns[0]:
        st.header("Choose Image and Concepts")

        image_col, concepts_col = st.columns(2)

        with image_col:
            image = get_image()
            # NOTE(review): newer Streamlit releases deprecate
            # `use_column_width`; left unchanged because the minimum
            # supported Streamlit version is not visible from here.
            st.image(image, use_column_width=True)

            change_image_button = st.button(
                "Change Image",
                use_container_width=False,
                # A missing flag means "enabled" rather than an AttributeError.
                disabled=st.session_state.get("disabled", False),
            )
            if change_image_button:
                st.session_state.sidebar_state = "expanded"
                # `st.experimental_rerun` is deprecated in favour of
                # `st.rerun`; prefer the new name and fall back to the old
                # one on Streamlit releases that predate it.
                rerun = getattr(st, "rerun", None) or st.experimental_rerun
                rerun()

        with concepts_col:
            model_name = get_model_name()
            class_name, class_ready, class_error = get_class_name()
            concepts, concepts_ready, concepts_error = get_concepts()

            ready = class_ready and concepts_ready

            # One "- <error>" bullet line per reported input error.
            error_message = "".join(
                f"- {error}\n"
                for error in (class_error, concepts_error)
                if error is not None
            )
            if error_message:
                st.error(error_message)

        with st.container():
            (
                significance_level,
                tau_max,
                r,
                cardinality,
                dataset_name,
            ) = get_advanced_settings(concepts, concepts_ready)

        test_button = st.button(
            "Test Concepts",
            use_container_width=True,
            # The callback sets the `disabled` flag that both buttons read.
            on_click=_disable,
            disabled=st.session_state.get("disabled", False) or not ready,
        )

        if test_button:
            # Drop any previous results before starting a new run.
            st.session_state.results = None

            testing_config = get_testing_config(
                significance_level=significance_level, tau_max=tau_max, r=r
            )
            test(
                testing_config,
                image,
                class_name,
                concepts,
                cardinality,
                dataset_name,
                model_name,
                device,
            )