File size: 3,201 Bytes
468f744
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import streamlit as st
import torch

from app_lib.test import get_testing_config, load_precomputed_results, test
from app_lib.user_input import (
    get_advanced_settings,
    get_class_name,
    get_concepts,
    get_image,
    get_model_name,
)
from app_lib.viz import viz_results


def _disable():
    """Lock the page's buttons; used as the "Test Concepts" ``on_click`` callback.

    Buttons rendered with ``disabled=st.session_state.disabled`` become
    unclickable until the flag is cleared again.
    """
    st.session_state["disabled"] = True


def _toggle_sidebar(button):
    """Request an expanded sidebar and rerun the script when ``button`` is truthy.

    Args:
        button: return value of an ``st.button`` call (True on the run in
            which it was clicked).
    """
    # Guard clause: nothing to do on runs where the button was not clicked.
    if not button:
        return

    st.session_state.sidebar_state = "expanded"
    # Rerun so the new sidebar state is picked up on the next script run.
    # NOTE(review): newer Streamlit releases rename this to `st.rerun()` —
    # confirm against the pinned Streamlit version.
    st.experimental_rerun()


def _preload_results(image_name):
    """Sync the selected image into session state and load its precomputed results.

    Selecting a different image clears the ``tested`` flag. While no live test
    has been run for the current image, ``st.session_state.results`` is
    refreshed from ``load_precomputed_results``.

    Args:
        image_name: identifier of the image currently selected in the UI.
    """
    state = st.session_state

    image_changed = image_name != state.image_name
    if image_changed:
        # A new image invalidates any live test results for the old one.
        state.image_name = image_name
        state.tested = False

    has_image = state.image_name is not None
    if has_image and not state.tested:
        state.results = load_precomputed_results(image_name)


def demo(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """Render the concept-testing demo page.

    The page is split 40/60 into two columns. The left column collects the
    image, model, class name, concepts and advanced settings, and hosts the
    "Test Concepts" button. The right column shows results via
    ``viz_results``.

    When "Test Concepts" is clicked, ``test`` runs with the collected inputs.
    Its output is stored in ``st.session_state.results`` and the script is
    rerun. Otherwise, precomputed results for the selected image are loaded
    and displayed.

    Args:
        device: torch device forwarded to ``test``. The default is evaluated
            once, when this function is defined (CUDA if available, else CPU).
    """
    columns = st.columns([0.40, 0.60])

    with columns[0]:
        st.header("Choose Image and Concepts")

        image_col, concepts_col = st.columns(2)

        with image_col:
            image_name, image = get_image()
            st.image(image, use_column_width=True)

            change_image_button = st.button(
                "Change Image",
                use_container_width=False,
                disabled=st.session_state.disabled,
            )
            _toggle_sidebar(change_image_button)

        with concepts_col:
            model_name = get_model_name()
            class_name, class_ready, class_error = get_class_name(image_name)
            concepts, concepts_ready, concepts_error = get_concepts(image_name)

        ready = class_ready and concepts_ready

        # Collect any input errors into a single bulleted message.
        error_message = ""
        if class_error is not None:
            error_message += f"- {class_error}\n"
        if concepts_error is not None:
            error_message += f"- {concepts_error}\n"
        if error_message:
            st.error(error_message)

        with st.container():
            (
                significance_level,
                tau_max,
                r,
                cardinality,
                dataset_name,
            ) = get_advanced_settings(concepts, concepts_ready)

            test_button = st.button(
                "Test Concepts",
                use_container_width=True,
                on_click=_disable,
                disabled=st.session_state.disabled or not ready,
            )

    if test_button:
        # `_disable` (this button's on_click) set `disabled` when the button
        # was clicked. Clear it in `finally` so that a failure while
        # configuring or running the test cannot leave both buttons disabled
        # for the rest of the session. The exception itself still propagates.
        try:
            # Clear stale results so the right column renders an empty state
            # while the test runs.
            st.session_state.results = None

            with columns[1]:
                viz_results()

            testing_config = get_testing_config(
                significance_level=significance_level, tau_max=tau_max, r=r
            )

            with columns[0]:
                results = test(
                    testing_config,
                    image,
                    class_name,
                    concepts,
                    cardinality,
                    dataset_name,
                    model_name,
                    device=device,
                )
        finally:
            st.session_state.disabled = False

        st.session_state.tested = True
        st.session_state.results = results
        # Rerun so the page is redrawn with the new results and re-enabled
        # buttons. The rerun stays outside the try block on purpose.
        st.experimental_rerun()
    else:
        _preload_results(image_name)

        with columns[1]:
            viz_results()