File size: 5,989 Bytes
0aef92c
 
 
80dc74c
 
 
 
21d3461
5e91161
80dc74c
0aef92c
 
 
 
 
80dc74c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e207f0
 
21d3461
 
7e207f0
 
21d3461
 
7e207f0
21d3461
 
7e207f0
 
 
 
 
21d3461
 
7e207f0
 
b30bcef
21d3461
 
7e207f0
21d3461
 
7e207f0
 
 
 
 
 
21d3461
 
7e207f0
 
 
21d3461
 
 
7e207f0
21d3461
7e207f0
21d3461
 
7e207f0
 
 
 
 
 
21d3461
 
7e207f0
 
21d3461
 
8e05eba
7e207f0
 
 
8e05eba
21d3461
7e207f0
 
 
 
b30bcef
21d3461
 
b30bcef
 
21d3461
 
 
 
 
b30bcef
 
 
 
 
80dc74c
21d3461
 
80dc74c
5ead791
21d3461
 
 
 
 
 
4f55ca2
80dc74c
 
 
 
 
 
0aef92c
 
 
8e05eba
0aef92c
 
 
 
 
 
 
 
 
 
 
21d3461
0aef92c
 
80dc74c
21d3461
80dc74c
21d3461
4f55ca2
0aef92c
80dc74c
 
 
 
 
 
0aef92c
21d3461
0aef92c
 
 
 
80dc74c
7e207f0
21d3461
 
 
 
8e05eba
21d3461
4f55ca2
0aef92c
80dc74c
 
 
 
 
 
 
 
 
 
7e207f0
b30bcef
 
7e207f0
 
 
 
b30bcef
7e207f0
b30bcef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import json
import os

import streamlit as st
from PIL import Image
from streamlit_image_select import image_select

import app_lib.defaults as defaults
from app_lib.utils import SUPPORTED_DATASETS, SUPPORTED_MODELS

# Directory holding the bundled example images offered in the sidebar picker.
IMAGE_DIR = os.path.join("assets", "images")
# File names of the bundled ".jpg" examples, sorted so that indices are stable.
IMAGE_NAMES = sorted(name for name in os.listdir(IMAGE_DIR) if name.endswith(".jpg"))
# Full paths, aligned index-for-index with IMAGE_NAMES.
IMAGE_PATHS = [os.path.join(IMAGE_DIR, name) for name in IMAGE_NAMES]
# Presets looked up by image name without its extension; each entry provides
# "class_name" and "concepts". The context manager closes the file handle
# (the previous bare ``open`` left it to the garbage collector).
with open("assets/image_presets.json", encoding="utf-8") as _presets_file:
    IMAGE_PRESETS = json.load(_presets_file)


def _validate_class_name(class_name):
    if class_name is None:
        return (False, "Class name cannot be empty.")
    if class_name.strip() == "":
        return (False, "Class name cannot be empty.")
    return (True, None)


def _validate_concepts(concepts):
    if len(concepts) < 3:
        return (False, "You must provide at least 3 concepts")
    if len(concepts) > 10:
        return (False, "Maximum 10 concepts allowed")
    return (True, None)


def _get_significance_level():
    """Render the "Significance level" slider and return its current value.

    The slider runs from one step up to 1.0 and starts at the configured
    default; it is disabled while ``st.session_state.disabled`` is set.
    """
    step_size = defaults.SIGNIFICANCE_LEVEL_STEP
    default_level = defaults.SIGNIFICANCE_LEVEL_VALUE
    help_text = (
        "The level of significance of the tests. "
        f"Defaults to {default_level:.2F}."
    )
    slider_options = {
        "help": help_text,
        "min_value": step_size,
        "max_value": 1.0,
        "value": default_level,
        "step": step_size,
        "disabled": st.session_state.disabled,
    }
    return st.slider("Significance level", **slider_options)


def _get_tau_max():
    """Render the "Length of test" slider and return its value as an ``int``.

    The slider runs from one step up to 1000 and starts at the configured
    default; it is disabled while ``st.session_state.disabled`` is set.
    """
    step_size = defaults.TAU_MAX_STEP
    default_tau = defaults.TAU_MAX_VALUE
    help_text = f"The maximum number of steps for each test. Defaults to {default_tau}."
    selected = st.slider(
        "Length of test",
        help=help_text,
        min_value=step_size,
        max_value=1000,
        step=step_size,
        value=default_tau,
        disabled=st.session_state.disabled,
    )
    return int(selected)


def _get_number_of_tests():
    """Render the "Number of tests per concept" slider and return an ``int``.

    The slider runs from one step up to 100 and starts at the configured
    default; it is disabled while ``st.session_state.disabled`` is set.
    """
    step_size = defaults.R_STEP
    default_count = defaults.R_VALUE
    help_text = (
        "The number of tests to average for each concept. "
        f"Defaults to {default_count}."
    )
    selected = st.slider(
        "Number of tests per concept",
        help=help_text,
        min_value=step_size,
        max_value=100,
        step=step_size,
        value=default_count,
        disabled=st.session_state.disabled,
    )
    return int(selected)


def _get_cardinality(concepts, concepts_ready):
    """Render the "Size of conditioning set" slider and return its value.

    Args:
        concepts: Current list of concepts; its length bounds the slider.
        concepts_ready: Whether the concepts passed validation. The slider is
            disabled when this is false or when the app is globally disabled.

    Returns:
        The value selected on the slider.
    """
    default = defaults.CARDINALITY_VALUE
    step = defaults.CARDINALITY_STEP
    return st.slider(
        "Size of conditioning set",
        # The second fragment must be an f-string; without the prefix the
        # literal text "{default}" was shown to the user.
        help=(
            "The number of concepts to condition model predictions on. "
            f"Defaults to {default}."
        ),
        min_value=1,
        # Upper bound is one less than the number of concepts, but never
        # below 2 so the slider keeps a valid range with few concepts.
        max_value=max(2, len(concepts) - 1),
        value=default,
        step=step,
        disabled=st.session_state.disabled or not concepts_ready,
    )


def _get_dataset_name():
    """Render the "Dataset" select box and return the chosen dataset name.

    Options come from ``SUPPORTED_DATASETS``; the initial selection is
    ``defaults.DATASET_NAME``. The widget is disabled while
    ``st.session_state.disabled`` is set.
    """
    options = SUPPORTED_DATASETS
    default_idx = options.index(defaults.DATASET_NAME)
    return st.selectbox(
        "Dataset",
        options=options,
        index=default_idx,
        # Trailing space after the first sentence keeps the two fragments
        # from running together ("sampler.Defaults") in the tooltip.
        help=(
            "Name of the dataset to use to train sampler. "
            f"Defaults to {SUPPORTED_DATASETS[default_idx]}."
        ),
        disabled=st.session_state.disabled,
    )


def get_model_name():
    """Render the "Model to test" select box and return the chosen model name.

    Options are the keys of ``SUPPORTED_MODELS``; the initial selection is
    ``defaults.MODEL_NAME``. The widget is disabled while
    ``st.session_state.disabled`` is set.
    """
    options = list(SUPPORTED_MODELS.keys())
    default_idx = options.index(defaults.MODEL_NAME)
    return st.selectbox(
        "Model to test",
        options=options,
        index=default_idx,
        # Trailing space separates the two sentences ("of.Defaults" before),
        # and the final period matches the other help texts in this module.
        help=(
            "Name of the vision-language model to test the predictions of. "
            f"Defaults to {options[default_idx]}."
        ),
        disabled=st.session_state.disabled,
    )


def get_image():
    """Let the user upload an image or pick a bundled one from the sidebar.

    Returns:
        ``(image_name, image)``. ``image_name`` is ``None`` for an uploaded
        file, otherwise the file name of the selected bundled image; ``image``
        is the opened PIL image.
    """
    with st.sidebar:
        upload = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
        # An uploaded file takes precedence; the picker is not rendered then.
        if upload is not None:
            return (None, Image.open(upload))

        preselected = IMAGE_NAMES.index("bowl_ace.jpg")
        chosen = image_select(
            label="or select one",
            images=IMAGE_PATHS,
            index=preselected,
            return_value="index",
        )
        return (IMAGE_NAMES[chosen], Image.open(IMAGE_PATHS[chosen]))


def get_class_name(image_name=None):
    """Render the "Class to predict" text input and validate its content.

    Args:
        image_name: File name of a bundled image. When given, the input is
            pre-filled with that image's preset class name; otherwise it
            starts empty.

    Returns:
        ``(class_name, class_ready, class_error)`` where the last two come
        from ``_validate_class_name``.
    """
    if image_name:
        # Presets are looked up by the part of the file name before the first dot.
        preset_key = image_name.split(".")[0]
        prefill = IMAGE_PRESETS[preset_key]["class_name"]
    else:
        prefill = ""

    class_name = st.text_input(
        "Class to predict",
        help="Name of the class to build the zero-shot CLIP classifier with.",
        value=prefill,
        disabled=st.session_state.disabled,
        placeholder="Type class name here",
    )

    return (class_name, *_validate_class_name(class_name))


def get_concepts(image_name=None):
    """Render the "Concepts to test" text area and parse it into a list.

    Args:
        image_name: File name of a bundled image. When given, the text area
            is pre-filled with that image's preset concepts, one per line;
            otherwise it starts empty.

    Returns:
        ``(concepts, concepts_ready, concepts_error)``. ``concepts`` holds the
        stripped, non-empty, de-duplicated lines in the order they were
        entered; the other two values come from ``_validate_concepts``.
    """
    default = (
        "\n".join(IMAGE_PRESETS[image_name.split(".")[0]]["concepts"])
        if image_name
        else ""
    )
    raw_text = st.text_area(
        "Concepts to test",
        help=(
            "List of concepts to test the predictions of the model with. "
            "Write one concept per line. Maximum 10 concepts allowed."
        ),
        height=180,
        value=default,
        disabled=st.session_state.disabled,
        placeholder="Type one concept\nper line",
    )
    stripped = (line.strip() for line in raw_text.split("\n"))
    # dict.fromkeys drops duplicates while keeping first-seen order. The
    # previous list(set(...)) returned the concepts in an arbitrary order
    # that could differ between runs because of string hash randomization.
    concepts = list(dict.fromkeys(line for line in stripped if line != ""))

    concepts_ready, concepts_error = _validate_concepts(concepts)
    return concepts, concepts_ready, concepts_error


def get_advanced_settings(concepts, concepts_ready):
    """Render the "Advanced settings" expander and collect its values.

    Args:
        concepts: Current list of concepts, forwarded to the cardinality slider.
        concepts_ready: Whether the concepts passed validation, forwarded to
            the cardinality slider.

    Returns:
        ``(significance_level, tau_max, r, cardinality, dataset_name)``.
        The dataset selector is rendered first but returned last.
    """
    with st.expander("Advanced settings"):
        # Call order determines the on-screen widget order; keep it as is.
        chosen_dataset = _get_dataset_name()
        alpha = _get_significance_level()
        max_steps = _get_tau_max()
        tests_per_concept = _get_number_of_tests()
        conditioning_size = _get_cardinality(concepts, concepts_ready)
        st.divider()

    settings = (alpha, max_steps, tests_per_concept, conditioning_size, chosen_dataset)
    return settings