Spaces:
Sleeping
Sleeping
File size: 2,513 Bytes
80dc74c 5e91161 80dc74c 5e91161 80dc74c 5e91161 80dc74c 5ead791 4f55ca2 80dc74c 4f55ca2 80dc74c 7e207f0 80dc74c 5ead791 80dc74c 5ead791 80dc74c 5ead791 80dc74c 4f55ca2 5ead791 4f55ca2 80dc74c 5ead791 80dc74c 5ead791 80dc74c 5ead791 5e91161 7e207f0 80dc74c 5ead791 80dc74c 4f55ca2 80dc74c 5ead791 7e207f0 b30bcef 7e207f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import streamlit as st
import torch
from app_lib.test import get_testing_config, test
from app_lib.user_input import (
get_advanced_settings,
get_class_name,
get_concepts,
get_image,
get_model_name,
)
from app_lib.viz import viz_results
def _disable():
    """Mark the session as busy so the app's buttons render disabled.

    Used as the ``on_click`` callback of the "Test Concepts" button; both the
    "Change Image" and "Test Concepts" buttons read this flag for their
    ``disabled`` state.
    """
    st.session_state["disabled"] = True
def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """Render the concept-testing page and launch a test when requested.

    The page is split into a narrower input column (image, model, class,
    concepts and advanced settings) and a wider results column.

    Args:
        device: torch device forwarded to ``test``. The default is evaluated
            once, when this function is defined: CUDA if available, else CPU.
    """
    input_col, results_col = st.columns([0.40, 0.60])

    with results_col:
        st.header("Results")
        viz_results()

    with input_col:
        st.header("Choose Image and Concepts")
        left, right = st.columns(2)

        with left:
            image = get_image()
            st.image(image, use_column_width=True)

            if st.button(
                "Change Image",
                use_container_width=False,
                disabled=st.session_state.disabled,
            ):
                # Request an expanded sidebar and rerun the script;
                # presumably ``sidebar_state`` is read on the next run to
                # configure the page — TODO confirm against the app entry.
                st.session_state.sidebar_state = "expanded"
                st.experimental_rerun()

        with right:
            model_name = get_model_name()
            class_name, class_ready, class_error = get_class_name()
            concepts, concepts_ready, concepts_error = get_concepts()

            # One "- <error>" bullet line per reported problem, in order.
            error_message = "".join(
                f"- {problem}\n"
                for problem in (class_error, concepts_error)
                if problem is not None
            )
            if error_message:
                st.error(error_message)

            with st.container():
                advanced = get_advanced_settings(concepts, concepts_ready)
            significance_level, tau_max, r, cardinality, dataset_name = advanced

            # Testing is only allowed once both the class and the concepts
            # report ready, and while no other test has disabled the UI.
            ready = class_ready and concepts_ready
            run_requested = st.button(
                "Test Concepts",
                use_container_width=True,
                on_click=_disable,
                disabled=st.session_state.disabled or not ready,
            )

            if run_requested:
                # Drop previous results before starting a new test.
                st.session_state.results = None
                testing_config = get_testing_config(
                    significance_level=significance_level, tau_max=tau_max, r=r
                )
                test(
                    testing_config,
                    image,
                    class_name,
                    concepts,
                    cardinality,
                    dataset_name,
                    model_name,
                    device,
                )
|