File size: 5,053 Bytes
80dc74c
 
 
 
5e91161
80dc74c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e207f0
 
f6eb5e3
7e207f0
 
 
 
 
 
 
 
f6eb5e3
7e207f0
 
f6eb5e3
7e207f0
 
 
 
 
f6eb5e3
7e207f0
 
b30bcef
7e207f0
 
 
 
 
 
f6eb5e3
7e207f0
f6eb5e3
7e207f0
 
 
 
 
 
 
f6eb5e3
7e207f0
 
 
 
 
 
 
 
 
f6eb5e3
7e207f0
f6eb5e3
7e207f0
 
 
 
 
 
 
b30bcef
7e207f0
 
 
 
 
 
 
 
 
 
b30bcef
7e207f0
 
 
 
 
b30bcef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80dc74c
 
5ead791
4f55ca2
80dc74c
4f55ca2
80dc74c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f55ca2
80dc74c
 
 
 
 
 
 
 
7e207f0
 
80dc74c
 
4f55ca2
80dc74c
 
 
 
 
 
 
 
 
 
7e207f0
b30bcef
 
7e207f0
 
 
 
b30bcef
7e207f0
b30bcef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import streamlit as st
from PIL import Image
from streamlit_image_select import image_select

from app_lib.utils import SUPPORTED_DATASETS, SUPPORTED_MODELS


def _validate_class_name(class_name):
    if class_name is None:
        return (False, "Class name cannot be empty.")
    if class_name.strip() == "":
        return (False, "Class name cannot be empty.")
    return (True, None)


def _validate_concepts(concepts):
    if len(concepts) < 3:
        return (False, "You must provide at least 3 concepts")
    if len(concepts) > 10:
        return (False, "Maximum 10 concepts allowed")
    return (True, None)


def _get_significance_level():
    """Render the significance-level slider and return the selected value.

    The slider spans ``[0.01, 1.0]`` in steps of 0.01 and starts at 0.05. It
    is disabled whenever ``st.session_state.disabled`` is set.
    """
    step = 0.01
    default = 0.05
    help_text = " ".join(
        ["The level of significance of the tests.", f"Defaults to {default:.2F}."]
    )
    return st.slider(
        "Significance level",
        min_value=step,
        max_value=1.0,
        value=default,
        step=step,
        help=help_text,
        disabled=st.session_state.disabled,
    )


def _get_tau_max():
    """Render the test-length slider and return the chosen step budget.

    The slider spans ``[50, 1000]`` in increments of 50 and starts at 200.
    """
    step, default = 50, 200
    help_text = " ".join(
        ["The maximum number of steps for each test.", f"Defaults to {default}."]
    )
    selected = st.slider(
        "Length of test",
        min_value=step,
        max_value=1000,
        value=default,
        step=step,
        help=help_text,
        disabled=st.session_state.disabled,
    )
    return int(selected)


def _get_number_of_tests():
    """Render the slider for how many tests to average per concept.

    The slider spans ``[5, 100]`` in increments of 5 and starts at 10.
    """
    step, default = 5, 10
    help_text = " ".join(
        ["The number of tests to average for each concept.", f"Defaults to {default}."]
    )
    selected = st.slider(
        "Number of tests per concept",
        min_value=step,
        max_value=100,
        value=default,
        step=step,
        help=help_text,
        disabled=st.session_state.disabled,
    )
    return int(selected)


def _get_cardinality(concepts, concepts_ready):
    DEFAULT = lambda concepts: int(len(concepts) / 2)
    return st.slider(
        "Size of conditioning set",
        help=" ".join(
            [
                "The number of concepts to condition model predictions on.",
                "Defaults to half of the number of concepts.",
            ]
        ),
        min_value=1,
        max_value=max(2, len(concepts) - 1),
        value=DEFAULT(concepts),
        step=1,
        disabled=st.session_state.disabled or not concepts_ready,
    )


def _get_dataset_name():
    """Render the dataset selector with ``"imagenette"`` pre-selected.

    Returns:
        The dataset name chosen from ``SUPPORTED_DATASETS``.
    """
    default_index = SUPPORTED_DATASETS.index("imagenette")
    help_text = " ".join(
        ["Name of the dataset to use to train sampler.", "Defaults to Imagenette."]
    )
    return st.selectbox(
        "Dataset",
        options=SUPPORTED_DATASETS,
        index=default_index,
        help=help_text,
        disabled=st.session_state.disabled,
    )


def get_model_name():
    """Render the model selector and return the chosen ``SUPPORTED_MODELS`` key."""
    model_names = list(SUPPORTED_MODELS)
    return st.selectbox(
        "Model to test",
        options=model_names,
        help="Name of the vision-language model to test the predictions of.",
        disabled=st.session_state.disabled,
    )


def get_image():
    """Let the user upload an image or pick a bundled example, then open it.

    Both widgets live in the sidebar. The example picker is only rendered
    when no file has been uploaded.

    Returns:
        The selected image, opened with ``PIL.Image.open``.
    """
    with st.sidebar:
        source = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
        if not source:
            # NOTE(review): every example entry points to the same asset;
            # presumably placeholders until more examples are added.
            source = image_select(
                label="or select one",
                images=["assets/ace.jpg"] * 4,
            )
    return Image.open(source)


def get_class_name():
    """Render the class-name input and validate its contents.

    Returns:
        ``(class_name, class_ready, class_error)``: the raw input, whether it
        passed validation, and the error message to show (``None`` if valid).
    """
    name = st.text_input(
        "Class to test",
        value="cat",
        help="Name of the class to build the zero-shot CLIP classifier with.",
        disabled=st.session_state.disabled,
    )
    is_ready, error = _validate_class_name(name)
    return name, is_ready, error


def get_concepts():
    concepts = st.text_area(
        "Concepts to test",
        help="List of concepts to test the predictions of the model with. Write one concept per line. Maximum 10 concepts allowed.",
        height=160,
        value="piano\ncute\nwhiskers\nmusic\nwild",
        disabled=st.session_state.disabled,
    )
    concepts = concepts.split("\n")
    concepts = [concept.strip() for concept in concepts]
    concepts = [concept for concept in concepts if concept != ""]
    concepts = list(set(concepts))

    concepts_ready, concepts_error = _validate_concepts(concepts)
    return concepts, concepts_ready, concepts_error


def get_advanced_settings(concepts, concepts_ready):
    """Render the "Advanced settings" expander and collect its values.

    Args:
        concepts: the parsed concepts, used to size the cardinality slider.
        concepts_ready: whether the concepts passed validation; the
            cardinality slider is disabled otherwise.

    Returns:
        ``(significance_level, tau_max, r, cardinality, dataset_name)``.
    """
    with st.expander("Advanced settings"):
        # Streamlit lays widgets out in call order, so keep this sequence.
        dataset = _get_dataset_name()
        alpha = _get_significance_level()
        max_steps = _get_tau_max()
        n_tests = _get_number_of_tests()
        conditioning_size = _get_cardinality(concepts, concepts_ready)
        st.divider()

    return alpha, max_steps, n_tests, conditioning_size, dataset