Spaces:
Sleeping
Sleeping
jacopoteneggi
commited on
Update
Browse files- README.md +1 -1
- app_lib/main.py +2 -2
- app_lib/test.py +3 -3
- app_lib/user_input.py +24 -5
- app_lib/utils.py +1 -1
- app_lib/viz.py +1 -1
- header.md +2 -1
- style.css +19 -6
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
title: I Bet You Did Not Mean That
|
3 |
-
emoji:
|
4 |
colorFrom: blue
|
5 |
colorTo: indigo
|
6 |
sdk: streamlit
|
|
|
1 |
---
|
2 |
title: I Bet You Did Not Mean That
|
3 |
+
emoji: 🤔
|
4 |
colorFrom: blue
|
5 |
colorTo: indigo
|
6 |
sdk: streamlit
|
app_lib/main.py
CHANGED
@@ -56,7 +56,7 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
|
|
56 |
st.error(error_message)
|
57 |
|
58 |
with st.container():
|
59 |
-
significance_level, tau_max, r, cardinality = get_advanced_settings(
|
60 |
concepts, concepts_ready
|
61 |
)
|
62 |
|
@@ -79,7 +79,7 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
|
|
79 |
class_name,
|
80 |
concepts,
|
81 |
cardinality,
|
82 |
-
|
83 |
model_name,
|
84 |
device,
|
85 |
)
|
|
|
56 |
st.error(error_message)
|
57 |
|
58 |
with st.container():
|
59 |
+
significance_level, tau_max, r, cardinality, dataset_name = get_advanced_settings(
|
60 |
concepts, concepts_ready
|
61 |
)
|
62 |
|
|
|
79 |
class_name,
|
80 |
concepts,
|
81 |
cardinality,
|
82 |
+
dataset_name,
|
83 |
model_name,
|
84 |
device,
|
85 |
)
|
app_lib/test.py
CHANGED
@@ -168,7 +168,7 @@ def test(
|
|
168 |
|
169 |
progress_bar = st.progress(
|
170 |
0,
|
171 |
-
text=f"Testing concepts (can take a
|
172 |
)
|
173 |
|
174 |
embedding = _load_dataset(dataset_name, model_name)
|
@@ -179,7 +179,7 @@ def test(
|
|
179 |
|
180 |
progress_bar.progress(
|
181 |
1 / (len(concepts) + 1),
|
182 |
-
text=f"Testing concepts (can take a
|
183 |
)
|
184 |
|
185 |
with ThreadPoolExecutor() as executor:
|
@@ -202,7 +202,7 @@ def test(
|
|
202 |
results.append(future.result())
|
203 |
progress_bar.progress(
|
204 |
(idx + 2) / (len(concepts) + 1),
|
205 |
-
text=f"Testing concepts (can take a
|
206 |
)
|
207 |
|
208 |
rejected = np.empty((testing_config.r, len(concepts)))
|
|
|
168 |
|
169 |
progress_bar = st.progress(
|
170 |
0,
|
171 |
+
text=f"Testing concepts (can take up to a minute) [0 / {len(concepts)} completed]",
|
172 |
)
|
173 |
|
174 |
embedding = _load_dataset(dataset_name, model_name)
|
|
|
179 |
|
180 |
progress_bar.progress(
|
181 |
1 / (len(concepts) + 1),
|
182 |
+
text=f"Testing concepts (can take up to a minute) [0 / {len(concepts)} completed]",
|
183 |
)
|
184 |
|
185 |
with ThreadPoolExecutor() as executor:
|
|
|
202 |
results.append(future.result())
|
203 |
progress_bar.progress(
|
204 |
(idx + 2) / (len(concepts) + 1),
|
205 |
+
text=f"Testing concepts (can take up to a minute) [{idx + 1} / {len(concepts)} completed]",
|
206 |
)
|
207 |
|
208 |
rejected = np.empty((testing_config.r, len(concepts)))
|
app_lib/user_input.py
CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
|
|
2 |
from PIL import Image
|
3 |
from streamlit_image_select import image_select
|
4 |
|
5 |
-
from app_lib.utils import SUPPORTED_MODELS
|
6 |
|
7 |
|
8 |
def _validate_class_name(class_name):
|
@@ -43,7 +43,7 @@ def _get_tau_max():
|
|
43 |
DEFAULT = 200
|
44 |
return int(
|
45 |
st.slider(
|
46 |
-
"
|
47 |
help=" ".join(
|
48 |
[
|
49 |
"The maximum number of steps for each test.",
|
@@ -80,6 +80,7 @@ def _get_number_of_tests():
|
|
80 |
|
81 |
|
82 |
def _get_cardinality(concepts, concepts_ready):
|
|
|
83 |
return st.slider(
|
84 |
"Size of conditioning set",
|
85 |
help=" ".join(
|
@@ -90,12 +91,28 @@ def _get_cardinality(concepts, concepts_ready):
|
|
90 |
),
|
91 |
min_value=1,
|
92 |
max_value=max(2, len(concepts) - 1),
|
93 |
-
value=
|
94 |
step=1,
|
95 |
disabled=st.session_state.disabled or not concepts_ready,
|
96 |
)
|
97 |
|
98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
def get_model_name():
|
100 |
return st.selectbox(
|
101 |
"Model to test",
|
@@ -150,10 +167,12 @@ def get_concepts():
|
|
150 |
|
151 |
|
152 |
def get_advanced_settings(concepts, concepts_ready):
|
153 |
-
with st.
|
|
|
154 |
significance_level = _get_significance_level()
|
155 |
tau_max = _get_tau_max()
|
156 |
r = _get_number_of_tests()
|
157 |
cardinality = _get_cardinality(concepts, concepts_ready)
|
|
|
158 |
|
159 |
-
return significance_level, tau_max, r, cardinality
|
|
|
2 |
from PIL import Image
|
3 |
from streamlit_image_select import image_select
|
4 |
|
5 |
+
from app_lib.utils import SUPPORTED_MODELS, SUPPORTED_DATASETS
|
6 |
|
7 |
|
8 |
def _validate_class_name(class_name):
|
|
|
43 |
DEFAULT = 200
|
44 |
return int(
|
45 |
st.slider(
|
46 |
+
"Length of test",
|
47 |
help=" ".join(
|
48 |
[
|
49 |
"The maximum number of steps for each test.",
|
|
|
80 |
|
81 |
|
82 |
def _get_cardinality(concepts, concepts_ready):
|
83 |
+
DEFAULT = lambda concepts: int(len(concepts) / 2)
|
84 |
return st.slider(
|
85 |
"Size of conditioning set",
|
86 |
help=" ".join(
|
|
|
91 |
),
|
92 |
min_value=1,
|
93 |
max_value=max(2, len(concepts) - 1),
|
94 |
+
value=DEFAULT(concepts),
|
95 |
step=1,
|
96 |
disabled=st.session_state.disabled or not concepts_ready,
|
97 |
)
|
98 |
|
99 |
|
100 |
+
def _get_dataset_name():
|
101 |
+
DEFAULT = SUPPORTED_DATASETS.index("imagenette")
|
102 |
+
return st.selectbox(
|
103 |
+
"Dataset",
|
104 |
+
options=SUPPORTED_DATASETS,
|
105 |
+
index=DEFAULT,
|
106 |
+
help=" ".join(
|
107 |
+
[
|
108 |
+
"Name of the dataset to use to train sampler.",
|
109 |
+
"Defaults to Imagenette.",
|
110 |
+
]
|
111 |
+
),
|
112 |
+
disabled=st.session_state.disabled,
|
113 |
+
)
|
114 |
+
|
115 |
+
|
116 |
def get_model_name():
|
117 |
return st.selectbox(
|
118 |
"Model to test",
|
|
|
167 |
|
168 |
|
169 |
def get_advanced_settings(concepts, concepts_ready):
|
170 |
+
with st.expander("Advanced settings"):
|
171 |
+
dataset_name = _get_dataset_name()
|
172 |
significance_level = _get_significance_level()
|
173 |
tau_max = _get_tau_max()
|
174 |
r = _get_number_of_tests()
|
175 |
cardinality = _get_cardinality(concepts, concepts_ready)
|
176 |
+
st.divider()
|
177 |
|
178 |
+
return significance_level, tau_max, r, cardinality, dataset_name
|
app_lib/utils.py
CHANGED
@@ -20,7 +20,7 @@ with open(supported_models_path, "r") as f:
|
|
20 |
|
21 |
|
22 |
SUPPORTED_DATASETS = []
|
23 |
-
with open(
|
24 |
for line in f:
|
25 |
dataset_name = line.strip()
|
26 |
SUPPORTED_DATASETS.append(dataset_name)
|
|
|
20 |
|
21 |
|
22 |
SUPPORTED_DATASETS = []
|
23 |
+
with open(supported_datasets_path, "r") as f:
|
24 |
for line in f:
|
25 |
dataset_name = line.strip()
|
26 |
SUPPORTED_DATASETS.append(dataset_name)
|
app_lib/viz.py
CHANGED
@@ -27,7 +27,7 @@ def viz_results():
|
|
27 |
results = st.session_state.results
|
28 |
|
29 |
if results is None:
|
30 |
-
st.info("
|
31 |
else:
|
32 |
rank_tab, wealth_tab = st.tabs(["Rank of importance", "Wealth process"])
|
33 |
|
|
|
27 |
results = st.session_state.results
|
28 |
|
29 |
if results is None:
|
30 |
+
st.info("Test concepts to show results", icon="ℹ️")
|
31 |
else:
|
32 |
rank_tab, wealth_tab = st.tabs(["Rank of importance", "Wealth process"])
|
33 |
|
header.md
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
# 🤔 I Bet You Did Not Mean That
|
2 |
|
3 |
-
|
|
|
|
1 |
# 🤔 I Bet You Did Not Mean That
|
2 |
|
3 |
+
Test the importance of semantic concepts for the predictions of a classifier. [[paper]](https://arxiv.org/pdf/2405.19146) [[code]](https://github.com/Sulam-Group/IBYDMT)
|
4 |
+
|
style.css
CHANGED
@@ -19,8 +19,23 @@ h1 {
|
|
19 |
justify-content: center;
|
20 |
}
|
21 |
|
22 |
-
[data-testid="stVerticalBlock"]:has(> [data-testid="
|
23 |
display: block;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
}
|
25 |
|
26 |
[data-testid="stPopover"] {
|
@@ -46,9 +61,7 @@ h1 {
|
|
46 |
}
|
47 |
}
|
48 |
|
49 |
-
[data-testid="stSpinner"] {
|
50 |
-
|
51 |
-
|
52 |
-
justify-content: center;
|
53 |
-
}
|
54 |
}
|
|
|
19 |
justify-content: center;
|
20 |
}
|
21 |
|
22 |
+
[data-testid="stVerticalBlock"]:has(> [data-testid="stExpander"]) {
|
23 |
display: block;
|
24 |
+
|
25 |
+
details {
|
26 |
+
border: 0;
|
27 |
+
}
|
28 |
+
|
29 |
+
summary {
|
30 |
+
padding: 0;
|
31 |
+
display: inline-flex;
|
32 |
+
align-items: center;
|
33 |
+
width: fit-content;
|
34 |
+
}
|
35 |
+
|
36 |
+
hr {
|
37 |
+
margin-top: 0;
|
38 |
+
}
|
39 |
}
|
40 |
|
41 |
[data-testid="stPopover"] {
|
|
|
61 |
}
|
62 |
}
|
63 |
|
64 |
+
[data-testid="stSpinner"]>div {
|
65 |
+
display: flex;
|
66 |
+
justify-content: center;
|
|
|
|
|
67 |
}
|