Spaces:
Sleeping
Sleeping
File size: 3,201 Bytes
80dc74c 5e91161 80dc74c 21d3461 80dc74c 5e91161 80dc74c 5ead791 4f55ca2 80dc74c 8e05eba 4f55ca2 80dc74c 5ead791 80dc74c 5ead791 80dc74c 5ead791 0aef92c 80dc74c 4f55ca2 5ead791 4f55ca2 8e05eba 5ead791 0aef92c 80dc74c 5ead791 80dc74c 5ead791 5e91161 7e207f0 80dc74c 5ead791 80dc74c 4f55ca2 80dc74c 8e05eba 5ead791 8e05eba 7e207f0 b30bcef 7e207f0 21d3461 7e207f0 21d3461 8e05eba 21d3461 8e05eba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import streamlit as st
import torch
from app_lib.test import get_testing_config, load_precomputed_results, test
from app_lib.user_input import (
get_advanced_settings,
get_class_name,
get_concepts,
get_image,
get_model_name,
)
from app_lib.viz import viz_results
def _disable():
    """Set the session-wide ``disabled`` flag to ``True``.

    Used as the ``on_click`` callback of the "Test Concepts" button; the
    buttons in ``main`` read this flag for their ``disabled=`` argument.
    """
    st.session_state["disabled"] = True
def _toggle_sidebar(button):
    """Request an expanded sidebar and rerun the script if ``button`` is set.

    Args:
        button: Truthy when the triggering button was clicked in this run
            (the return value of ``st.button``). When falsy, nothing happens.
    """
    if not button:
        return

    st.session_state.sidebar_state = "expanded"

    # ``st.experimental_rerun`` is the deprecated predecessor of ``st.rerun``.
    # Prefer the stable name when this Streamlit version provides it and fall
    # back to the old one so older pinned versions keep working.
    rerun = getattr(st, "rerun", None)
    if rerun is None:
        rerun = st.experimental_rerun
    rerun()
def _preload_results(image_name):
    """Track the selected image and load its precomputed results if untested.

    When ``image_name`` differs from the one stored in the session, the
    stored name is replaced and the ``tested`` flag is cleared. Precomputed
    results are then loaded into ``st.session_state.results`` unless no image
    is stored or the current image has already been tested.

    Args:
        image_name: Name of the currently selected image.
    """
    state = st.session_state

    image_changed = image_name != state.image_name
    if image_changed:
        state.image_name = image_name
        state.tested = False

    # Nothing to preload without an image, or once a test has been run.
    if state.image_name is None or state.tested:
        return

    state.results = load_precomputed_results(image_name)
def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """Render the page and run the concept test when it is requested.

    The first column gathers the image, model, class name, concepts and
    advanced settings; the second column displays results via
    ``viz_results``. Pressing "Test Concepts" runs ``test`` and stores its
    output in the session; otherwise precomputed results are preloaded.

    NOTE(review): the nesting of the ``with`` blocks below was reconstructed
    from a copy of this file that had lost its indentation -- confirm the
    layout against the running app.

    Args:
        device: Torch device forwarded to ``test``. The default picks CUDA
            when available and CPU otherwise; it is evaluated once, when the
            function is defined.
    """
    columns = st.columns([0.40, 0.60])

    with columns[0]:
        st.header("Choose Image and Concepts")

        image_col, concepts_col = st.columns(2)

        with image_col:
            image_name, image = get_image()
            st.image(image, use_column_width=True)

            change_image_button = st.button(
                "Change Image",
                use_container_width=False,
                disabled=st.session_state.disabled,
            )
            _toggle_sidebar(change_image_button)

        with concepts_col:
            model_name = get_model_name()
            class_name, class_ready, class_error = get_class_name(image_name)
            concepts, concepts_ready, concepts_error = get_concepts(image_name)

            ready = class_ready and concepts_ready

            # One bullet line per reported problem; empty when there is none.
            error_message = "".join(
                f"- {error}\n"
                for error in (class_error, concepts_error)
                if error is not None
            )
            if error_message:
                st.error(error_message)

        with st.container():
            (
                significance_level,
                tau_max,
                r,
                cardinality,
                dataset_name,
            ) = get_advanced_settings(concepts, concepts_ready)

            test_button = st.button(
                "Test Concepts",
                use_container_width=True,
                on_click=_disable,
                disabled=st.session_state.disabled or not ready,
            )

    if test_button:
        # Drop any previous results before re-rendering the results column.
        st.session_state.results = None

        with columns[1]:
            viz_results()

        testing_config = get_testing_config(
            significance_level=significance_level, tau_max=tau_max, r=r
        )
        with columns[0]:
            results = test(
                testing_config,
                image,
                class_name,
                concepts,
                cardinality,
                dataset_name,
                model_name,
                device=device,
            )

        st.session_state.tested = True
        st.session_state.results = results
        # Undo the lock set by the button's ``on_click`` callback.
        st.session_state.disabled = False

        # ``st.experimental_rerun`` is the deprecated predecessor of
        # ``st.rerun``; use the stable name when available and fall back to
        # the old one so older pinned Streamlit versions keep working.
        rerun = getattr(st, "rerun", None)
        if rerun is None:
            rerun = st.experimental_rerun
        rerun()
    else:
        _preload_results(image_name)

        with columns[1]:
            viz_results()
|