jacopoteneggi commited on
Commit
b30bcef
·
verified ·
1 Parent(s): 7e207f0
Files changed (8) hide show
  1. README.md +1 -1
  2. app_lib/main.py +2 -2
  3. app_lib/test.py +3 -3
  4. app_lib/user_input.py +24 -5
  5. app_lib/utils.py +1 -1
  6. app_lib/viz.py +1 -1
  7. header.md +2 -1
  8. style.css +19 -6
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: I Bet You Did Not Mean That
3
- emoji: 🔎
4
  colorFrom: blue
5
  colorTo: indigo
6
  sdk: streamlit
 
1
  ---
2
  title: I Bet You Did Not Mean That
3
+ emoji: 🤔
4
  colorFrom: blue
5
  colorTo: indigo
6
  sdk: streamlit
app_lib/main.py CHANGED
@@ -56,7 +56,7 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
56
  st.error(error_message)
57
 
58
  with st.container():
59
- significance_level, tau_max, r, cardinality = get_advanced_settings(
60
  concepts, concepts_ready
61
  )
62
 
@@ -79,7 +79,7 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
79
  class_name,
80
  concepts,
81
  cardinality,
82
- "imagenette",
83
  model_name,
84
  device,
85
  )
 
56
  st.error(error_message)
57
 
58
  with st.container():
59
+ significance_level, tau_max, r, cardinality, dataset_name = get_advanced_settings(
60
  concepts, concepts_ready
61
  )
62
 
 
79
  class_name,
80
  concepts,
81
  cardinality,
82
+ dataset_name,
83
  model_name,
84
  device,
85
  )
app_lib/test.py CHANGED
@@ -168,7 +168,7 @@ def test(
168
 
169
  progress_bar = st.progress(
170
  0,
171
- text=f"Testing concepts (can take a few minutes) [0 / {len(concepts)} completed]",
172
  )
173
 
174
  embedding = _load_dataset(dataset_name, model_name)
@@ -179,7 +179,7 @@ def test(
179
 
180
  progress_bar.progress(
181
  1 / (len(concepts) + 1),
182
- text=f"Testing concepts (can take a few minutes) [0 / {len(concepts)} completed]",
183
  )
184
 
185
  with ThreadPoolExecutor() as executor:
@@ -202,7 +202,7 @@ def test(
202
  results.append(future.result())
203
  progress_bar.progress(
204
  (idx + 2) / (len(concepts) + 1),
205
- text=f"Testing concepts (can take a few minutes) [{idx + 1} / {len(concepts)} completed]",
206
  )
207
 
208
  rejected = np.empty((testing_config.r, len(concepts)))
 
168
 
169
  progress_bar = st.progress(
170
  0,
171
+ text=f"Testing concepts (can take up to a minute) [0 / {len(concepts)} completed]",
172
  )
173
 
174
  embedding = _load_dataset(dataset_name, model_name)
 
179
 
180
  progress_bar.progress(
181
  1 / (len(concepts) + 1),
182
+ text=f"Testing concepts (can take up to a minute) [0 / {len(concepts)} completed]",
183
  )
184
 
185
  with ThreadPoolExecutor() as executor:
 
202
  results.append(future.result())
203
  progress_bar.progress(
204
  (idx + 2) / (len(concepts) + 1),
205
+ text=f"Testing concepts (can take up to a minute) [{idx + 1} / {len(concepts)} completed]",
206
  )
207
 
208
  rejected = np.empty((testing_config.r, len(concepts)))
app_lib/user_input.py CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
2
  from PIL import Image
3
  from streamlit_image_select import image_select
4
 
5
- from app_lib.utils import SUPPORTED_MODELS
6
 
7
 
8
  def _validate_class_name(class_name):
@@ -43,7 +43,7 @@ def _get_tau_max():
43
  DEFAULT = 200
44
  return int(
45
  st.slider(
46
- "Duration of test",
47
  help=" ".join(
48
  [
49
  "The maximum number of steps for each test.",
@@ -80,6 +80,7 @@ def _get_number_of_tests():
80
 
81
 
82
  def _get_cardinality(concepts, concepts_ready):
 
83
  return st.slider(
84
  "Size of conditioning set",
85
  help=" ".join(
@@ -90,12 +91,28 @@ def _get_cardinality(concepts, concepts_ready):
90
  ),
91
  min_value=1,
92
  max_value=max(2, len(concepts) - 1),
93
- value=int(len(concepts) / 2),
94
  step=1,
95
  disabled=st.session_state.disabled or not concepts_ready,
96
  )
97
 
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  def get_model_name():
100
  return st.selectbox(
101
  "Model to test",
@@ -150,10 +167,12 @@ def get_concepts():
150
 
151
 
152
  def get_advanced_settings(concepts, concepts_ready):
153
- with st.popover("Advanced settings", disabled=st.session_state.disabled):
 
154
  significance_level = _get_significance_level()
155
  tau_max = _get_tau_max()
156
  r = _get_number_of_tests()
157
  cardinality = _get_cardinality(concepts, concepts_ready)
 
158
 
159
- return significance_level, tau_max, r, cardinality
 
2
  from PIL import Image
3
  from streamlit_image_select import image_select
4
 
5
+ from app_lib.utils import SUPPORTED_MODELS, SUPPORTED_DATASETS
6
 
7
 
8
  def _validate_class_name(class_name):
 
43
  DEFAULT = 200
44
  return int(
45
  st.slider(
46
+ "Length of test",
47
  help=" ".join(
48
  [
49
  "The maximum number of steps for each test.",
 
80
 
81
 
82
  def _get_cardinality(concepts, concepts_ready):
83
+ DEFAULT = lambda concepts: int(len(concepts) / 2)
84
  return st.slider(
85
  "Size of conditioning set",
86
  help=" ".join(
 
91
  ),
92
  min_value=1,
93
  max_value=max(2, len(concepts) - 1),
94
+ value=DEFAULT(concepts),
95
  step=1,
96
  disabled=st.session_state.disabled or not concepts_ready,
97
  )
98
 
99
 
100
+ def _get_dataset_name():
101
+ DEFAULT = SUPPORTED_DATASETS.index("imagenette")
102
+ return st.selectbox(
103
+ "Dataset",
104
+ options=SUPPORTED_DATASETS,
105
+ index=DEFAULT,
106
+ help=" ".join(
107
+ [
108
+ "Name of the dataset to use to train sampler.",
109
+ "Defaults to Imagenette.",
110
+ ]
111
+ ),
112
+ disabled=st.session_state.disabled,
113
+ )
114
+
115
+
116
  def get_model_name():
117
  return st.selectbox(
118
  "Model to test",
 
167
 
168
 
169
  def get_advanced_settings(concepts, concepts_ready):
170
+ with st.expander("Advanced settings"):
171
+ dataset_name = _get_dataset_name()
172
  significance_level = _get_significance_level()
173
  tau_max = _get_tau_max()
174
  r = _get_number_of_tests()
175
  cardinality = _get_cardinality(concepts, concepts_ready)
176
+ st.divider()
177
 
178
+ return significance_level, tau_max, r, cardinality, dataset_name
app_lib/utils.py CHANGED
@@ -20,7 +20,7 @@ with open(supported_models_path, "r") as f:
20
 
21
 
22
  SUPPORTED_DATASETS = []
23
- with open(supported_models_path, "r") as f:
24
  for line in f:
25
  dataset_name = line.strip()
26
  SUPPORTED_DATASETS.append(dataset_name)
 
20
 
21
 
22
  SUPPORTED_DATASETS = []
23
+ with open(supported_datasets_path, "r") as f:
24
  for line in f:
25
  dataset_name = line.strip()
26
  SUPPORTED_DATASETS.append(dataset_name)
app_lib/viz.py CHANGED
@@ -27,7 +27,7 @@ def viz_results():
27
  results = st.session_state.results
28
 
29
  if results is None:
30
- st.info("Run tests to show results", icon="ℹ️")
31
  else:
32
  rank_tab, wealth_tab = st.tabs(["Rank of importance", "Wealth process"])
33
 
 
27
  results = st.session_state.results
28
 
29
  if results is None:
30
+ st.info("Test concepts to show results", icon="ℹ️")
31
  else:
32
  rank_tab, wealth_tab = st.tabs(["Rank of importance", "Wealth process"])
33
 
header.md CHANGED
@@ -1,3 +1,4 @@
1
  # 🤔 I Bet You Did Not Mean That
2
 
3
- Official 🤗 Space for the paper [*I Bet You Did Not Mean That: Testing Semantic Importance via Betting*](https://arxiv.org/pdf/2405.19146), by [Jacopo Teneggi](https://jacopoteneggi.github.io) and [Jeremias Sulam](https://sites.google.com/view/jsulam).
 
 
1
  # 🤔 I Bet You Did Not Mean That
2
 
3
+ Test the importance of semantic concepts for the predictions of a classifier. [[paper]](https://arxiv.org/pdf/2405.19146) [[code]](https://github.com/Sulam-Group/IBYDMT)
4
+
style.css CHANGED
@@ -19,8 +19,23 @@ h1 {
19
  justify-content: center;
20
  }
21
 
22
- [data-testid="stVerticalBlock"]:has(> [data-testid="stPopover"]) {
23
  display: block;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  }
25
 
26
  [data-testid="stPopover"] {
@@ -46,9 +61,7 @@ h1 {
46
  }
47
  }
48
 
49
- [data-testid="stSpinner"] {
50
- >div {
51
- display: flex;
52
- justify-content: center;
53
- }
54
  }
 
19
  justify-content: center;
20
  }
21
 
22
+ [data-testid="stVerticalBlock"]:has(> [data-testid="stExpander"]) {
23
  display: block;
24
+
25
+ details {
26
+ border: 0;
27
+ }
28
+
29
+ summary {
30
+ padding: 0;
31
+ display: inline-flex;
32
+ align-items: center;
33
+ width: fit-content;
34
+ }
35
+
36
+ hr {
37
+ margin-top: 0;
38
+ }
39
  }
40
 
41
  [data-testid="stPopover"] {
 
61
  }
62
  }
63
 
64
+ [data-testid="stSpinner"]>div {
65
+ display: flex;
66
+ justify-content: center;
 
 
67
  }