File size: 2,513 Bytes
80dc74c
5e91161
80dc74c
5e91161
80dc74c
5e91161
80dc74c
 
 
 
 
5ead791
4f55ca2
 
 
 
80dc74c
 
4f55ca2
80dc74c
 
7e207f0
 
 
 
80dc74c
5ead791
80dc74c
5ead791
80dc74c
5ead791
80dc74c
 
 
4f55ca2
 
5ead791
4f55ca2
 
80dc74c
 
 
5ead791
 
 
 
80dc74c
5ead791
80dc74c
5ead791
 
 
 
 
 
 
 
 
5e91161
 
 
 
 
 
 
7e207f0
80dc74c
5ead791
80dc74c
4f55ca2
 
80dc74c
 
5ead791
 
 
7e207f0
 
 
 
 
 
 
 
 
b30bcef
7e207f0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import streamlit as st
import torch

from app_lib.test import get_testing_config, test
from app_lib.user_input import (
    get_advanced_settings,
    get_class_name,
    get_concepts,
    get_image,
    get_model_name,
)
from app_lib.viz import viz_results


def _disable():
    """Set the ``disabled`` flag in Streamlit's session state to ``True``."""
    st.session_state["disabled"] = True


def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """Lay out the page and run a concept test when the user requests one.

    The page is split into two columns with relative widths 0.40 and 0.60.
    The second column shows the results via ``viz_results``; the first one
    collects the image, model, class name, concepts and advanced settings,
    and hosts the "Test Concepts" button.

    Args:
        device: Passed through unchanged as the last argument of ``test``.
            The default is CUDA when ``torch.cuda.is_available()`` is true,
            otherwise CPU; it is evaluated once, when the function is defined.
    """
    inputs_col, results_col = st.columns([0.40, 0.60])

    # The results column is populated first, although it is the second column.
    with results_col:
        st.header("Results")
        viz_results()

    with inputs_col:
        st.header("Choose Image and Concepts")

        picture_col, choices_col = st.columns(2)

        with picture_col:
            image = get_image()
            st.image(image, use_column_width=True)

            if st.button(
                "Change Image",
                use_container_width=False,
                disabled=st.session_state.disabled,
            ):
                # Flag the sidebar as expanded, then rerun the script.
                st.session_state.sidebar_state = "expanded"
                st.experimental_rerun()
        with choices_col:
            model_name = get_model_name()
            class_name, class_ready, class_error = get_class_name()
            concepts, concepts_ready, concepts_error = get_concepts()

        ready = class_ready and concepts_ready

        # Report every input error that was returned, one bullet per line.
        problems = [
            problem
            for problem in (class_error, concepts_error)
            if problem is not None
        ]
        if problems:
            st.error("".join(f"- {problem}\n" for problem in problems))

        with st.container():
            advanced = get_advanced_settings(concepts, concepts_ready)
            significance_level, tau_max, r, cardinality, dataset_name = advanced

            # Clicking runs ``_disable`` (sets ``st.session_state.disabled``);
            # the button is inactive while that flag is set or inputs are
            # not ready.
            run_requested = st.button(
                "Test Concepts",
                use_container_width=True,
                on_click=_disable,
                disabled=st.session_state.disabled or not ready,
            )

        if not run_requested:
            return

        # Clear the stored results before starting a new test.
        st.session_state.results = None

        testing_config = get_testing_config(
            significance_level=significance_level, tau_max=tau_max, r=r
        )
        test(
            testing_config,
            image,
            class_name,
            concepts,
            cardinality,
            dataset_name,
            model_name,
            device,
        )