File size: 3,201 Bytes
80dc74c
5e91161
80dc74c
21d3461
80dc74c
5e91161
80dc74c
 
 
 
 
5ead791
4f55ca2
 
 
 
80dc74c
 
8e05eba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f55ca2
80dc74c
 
 
5ead791
80dc74c
5ead791
80dc74c
5ead791
0aef92c
80dc74c
 
4f55ca2
 
5ead791
4f55ca2
 
8e05eba
 
5ead791
 
0aef92c
 
80dc74c
5ead791
80dc74c
5ead791
 
 
 
 
 
 
 
 
5e91161
 
 
 
 
 
 
7e207f0
80dc74c
5ead791
80dc74c
4f55ca2
 
80dc74c
 
8e05eba
 
5ead791
8e05eba
 
 
 
 
 
 
 
 
7e207f0
 
 
 
 
b30bcef
7e207f0
21d3461
7e207f0
21d3461
8e05eba
 
 
 
 
 
21d3461
8e05eba
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import streamlit as st
import torch

from app_lib.test import get_testing_config, load_precomputed_results, test
from app_lib.user_input import (
    get_advanced_settings,
    get_class_name,
    get_concepts,
    get_image,
    get_model_name,
)
from app_lib.viz import viz_results


def _disable():
    """Flag the page's interactive controls as disabled.

    Used as the ``on_click`` callback of the "Test Concepts" button in
    ``main``; buttons there read ``disabled`` from session state.
    """
    st.session_state["disabled"] = True


def _toggle_sidebar(button):
    """Request an expanded sidebar and rerun the script if *button* is truthy.

    Args:
        button: Return value of ``st.button`` (True on the run where the
            button was clicked).

    NOTE(review): ``sidebar_state`` is presumably consumed by the page
    configuration elsewhere in the app — confirm against the caller.
    """
    if not button:
        return

    st.session_state.sidebar_state = "expanded"
    st.experimental_rerun()


def _preload_results(image_name):
    """Populate ``results`` with precomputed data for the selected image.

    When the selected image differs from the one stored in session state, the
    stored name is updated and ``tested`` is cleared. As long as no live test
    has been run for the current image (``tested`` is False) and an image is
    selected, ``results`` is (re)loaded via ``load_precomputed_results``.

    Args:
        image_name: Name of the currently selected image (may be None).
    """
    state = st.session_state

    if image_name != state.image_name:
        # A different image invalidates any live test results.
        state.image_name = image_name
        state.tested = False

    has_image = state.image_name is not None
    if has_image and not state.tested:
        state.results = load_precomputed_results(image_name)


def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """Render the concept-testing page and run a test when requested.

    The page is split into a left column (image selection, model/class/concept
    inputs, advanced settings and the "Test Concepts" button) and a right
    column where results are shown via ``viz_results``.

    Args:
        device: Torch device forwarded to ``test``. The default is evaluated
            once, when this module is imported: CUDA if available, else CPU.
    """
    columns = st.columns([0.40, 0.60])

    with columns[0]:
        st.header("Choose Image and Concepts")

        image_col, concepts_col = st.columns(2)

        with image_col:
            image_name, image = get_image()
            st.image(image, use_column_width=True)

            change_image_button = st.button(
                "Change Image",
                use_container_width=False,
                disabled=st.session_state.disabled,
            )
            _toggle_sidebar(change_image_button)

        with concepts_col:
            model_name = get_model_name()
            class_name, class_ready, class_error = get_class_name(image_name)
            concepts, concepts_ready, concepts_error = get_concepts(image_name)

        # Testing is only allowed once both the class and the concepts are ready.
        ready = class_ready and concepts_ready

        # One bullet line per reported input error.
        error_message = "".join(
            f"- {error}\n"
            for error in (class_error, concepts_error)
            if error is not None
        )
        if error_message:
            st.error(error_message)

        with st.container():
            (
                significance_level,
                tau_max,
                r,
                cardinality,
                dataset_name,
            ) = get_advanced_settings(concepts, concepts_ready)

            test_button = st.button(
                "Test Concepts",
                use_container_width=True,
                on_click=_disable,
                disabled=st.session_state.disabled or not ready,
            )

    if test_button:
        # Drop previous results before visualizing, presumably so viz_results
        # does not show stale output while the test runs.
        st.session_state.results = None

        with columns[1]:
            viz_results()

        testing_config = get_testing_config(
            significance_level=significance_level, tau_max=tau_max, r=r
        )

        with columns[0]:
            try:
                results = test(
                    testing_config,
                    image,
                    class_name,
                    concepts,
                    cardinality,
                    dataset_name,
                    model_name,
                    device=device,
                )
            finally:
                # `_disable` (the button's on_click) set this flag before the
                # rerun; always clear it so a failing test does not leave the
                # controls disabled for the rest of the session.
                st.session_state.disabled = False

        st.session_state.tested = True
        st.session_state.results = results
        st.experimental_rerun()
    else:
        _preload_results(image_name)

        with columns[1]:
            viz_results()