Spaces:
Sleeping
Sleeping
jacopoteneggi
commited on
Commit
•
7e207f0
1
Parent(s):
5ead791
Update
Browse files- app.py +5 -7
- app_lib/main.py +23 -24
- app_lib/test.py +106 -71
- app_lib/user_input.py +86 -12
- header.md +1 -3
- ibydmt/test.py +5 -2
- requirements.txt +2 -1
- style.css +15 -0
app.py
CHANGED
@@ -2,11 +2,6 @@ import streamlit as st
|
|
2 |
|
3 |
from app_lib.main import main
|
4 |
|
5 |
-
with open("style.css", "r") as f:
|
6 |
-
style = f.read()
|
7 |
-
with open("header.md", "r") as f:
|
8 |
-
header = f.read()
|
9 |
-
|
10 |
if "sidebar_state" not in st.session_state:
|
11 |
st.session_state.sidebar_state = "collapsed"
|
12 |
if "disabled" not in st.session_state:
|
@@ -16,9 +11,12 @@ if "results" not in st.session_state:
|
|
16 |
|
17 |
st.set_page_config(layout="wide", initial_sidebar_state=st.session_state.sidebar_state)
|
18 |
|
19 |
-
|
20 |
-
|
|
|
|
|
21 |
|
|
|
22 |
st.markdown(header)
|
23 |
|
24 |
if __name__ == "__main__":
|
|
|
2 |
|
3 |
from app_lib.main import main
|
4 |
|
|
|
|
|
|
|
|
|
|
|
5 |
if "sidebar_state" not in st.session_state:
|
6 |
st.session_state.sidebar_state = "collapsed"
|
7 |
if "disabled" not in st.session_state:
|
|
|
11 |
|
12 |
st.set_page_config(layout="wide", initial_sidebar_state=st.session_state.sidebar_state)
|
13 |
|
14 |
+
with open("style.css", "r") as f:
|
15 |
+
style = f.read()
|
16 |
+
with open("header.md", "r") as f:
|
17 |
+
header = f.read()
|
18 |
|
19 |
+
st.markdown(f"<style>{style}</style>", unsafe_allow_html=True)
|
20 |
st.markdown(header)
|
21 |
|
22 |
if __name__ == "__main__":
|
app_lib/main.py
CHANGED
@@ -1,15 +1,14 @@
|
|
1 |
import torch
|
2 |
import streamlit as st
|
3 |
-
import time
|
4 |
|
5 |
from app_lib.user_input import (
|
6 |
-
get_cardinality,
|
7 |
get_class_name,
|
8 |
get_concepts,
|
9 |
get_image,
|
10 |
get_model_name,
|
|
|
11 |
)
|
12 |
-
from app_lib.test import test
|
13 |
from app_lib.viz import viz_results
|
14 |
|
15 |
|
@@ -20,6 +19,10 @@ def _disable():
|
|
20 |
def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
|
21 |
columns = st.columns([0.40, 0.60])
|
22 |
|
|
|
|
|
|
|
|
|
23 |
with columns[0]:
|
24 |
st.header("Choose Image and Concepts")
|
25 |
|
@@ -41,8 +44,6 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
|
|
41 |
model_name = get_model_name()
|
42 |
class_name, class_ready, class_error = get_class_name()
|
43 |
concepts, concepts_ready, concepts_error = get_concepts()
|
44 |
-
cardinality = int(len(concepts) / 2)
|
45 |
-
# get_cardinality(concepts, concepts_ready)
|
46 |
|
47 |
ready = class_ready and concepts_ready
|
48 |
|
@@ -55,6 +56,10 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
|
|
55 |
st.error(error_message)
|
56 |
|
57 |
with st.container():
|
|
|
|
|
|
|
|
|
58 |
test_button = st.button(
|
59 |
"Test Concepts",
|
60 |
use_container_width=True,
|
@@ -62,25 +67,19 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
|
|
62 |
disabled=st.session_state.disabled or not ready,
|
63 |
)
|
64 |
|
65 |
-
with st.popover("Advanced settings", disabled=st.session_state.disabled):
|
66 |
-
st.markdown("Hello World 👋")
|
67 |
-
|
68 |
-
with columns[1]:
|
69 |
-
st.header("Results")
|
70 |
-
|
71 |
if test_button:
|
72 |
st.session_state.results = None
|
73 |
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
1 |
import torch
|
2 |
import streamlit as st
|
|
|
3 |
|
4 |
from app_lib.user_input import (
|
|
|
5 |
get_class_name,
|
6 |
get_concepts,
|
7 |
get_image,
|
8 |
get_model_name,
|
9 |
+
get_advanced_settings,
|
10 |
)
|
11 |
+
from app_lib.test import get_testing_config, test
|
12 |
from app_lib.viz import viz_results
|
13 |
|
14 |
|
|
|
19 |
def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
|
20 |
columns = st.columns([0.40, 0.60])
|
21 |
|
22 |
+
with columns[1]:
|
23 |
+
st.header("Results")
|
24 |
+
viz_results()
|
25 |
+
|
26 |
with columns[0]:
|
27 |
st.header("Choose Image and Concepts")
|
28 |
|
|
|
44 |
model_name = get_model_name()
|
45 |
class_name, class_ready, class_error = get_class_name()
|
46 |
concepts, concepts_ready, concepts_error = get_concepts()
|
|
|
|
|
47 |
|
48 |
ready = class_ready and concepts_ready
|
49 |
|
|
|
56 |
st.error(error_message)
|
57 |
|
58 |
with st.container():
|
59 |
+
significance_level, tau_max, r, cardinality = get_advanced_settings(
|
60 |
+
concepts, concepts_ready
|
61 |
+
)
|
62 |
+
|
63 |
test_button = st.button(
|
64 |
"Test Concepts",
|
65 |
use_container_width=True,
|
|
|
67 |
disabled=st.session_state.disabled or not ready,
|
68 |
)
|
69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
if test_button:
|
71 |
st.session_state.results = None
|
72 |
|
73 |
+
testing_config = get_testing_config(
|
74 |
+
significance_level=significance_level, tau_max=tau_max, r=r
|
75 |
+
)
|
76 |
+
test(
|
77 |
+
testing_config,
|
78 |
+
image,
|
79 |
+
class_name,
|
80 |
+
concepts,
|
81 |
+
cardinality,
|
82 |
+
"imagenette",
|
83 |
+
model_name,
|
84 |
+
device,
|
85 |
+
)
|
app_lib/test.py
CHANGED
@@ -4,7 +4,6 @@ import open_clip
|
|
4 |
import h5py
|
5 |
import streamlit as st
|
6 |
import numpy as np
|
7 |
-
import pandas as pd
|
8 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
9 |
|
10 |
import ml_collections
|
@@ -16,16 +15,6 @@ from app_lib.ckde import cKDE
|
|
16 |
|
17 |
rng = np.random.default_rng()
|
18 |
|
19 |
-
testing_config = ml_collections.ConfigDict()
|
20 |
-
testing_config.significance_level = 0.05
|
21 |
-
testing_config.wealth = "ons"
|
22 |
-
testing_config.bet = "tanh"
|
23 |
-
testing_config.kernel = "rbf"
|
24 |
-
testing_config.kernel_scale_method = "quantile"
|
25 |
-
testing_config.kernel_scale = 0.5
|
26 |
-
testing_config.tau_max = 200
|
27 |
-
testing_config.r = 10
|
28 |
-
|
29 |
|
30 |
def _get_open_clip_model(model_name, device):
|
31 |
backbone = model_name.split(":")[-1]
|
@@ -45,19 +34,7 @@ def _get_clip_model(model_name, device):
|
|
45 |
return model, preprocess, tokenizer
|
46 |
|
47 |
|
48 |
-
def
|
49 |
-
dataset_path = hf_hub_download(
|
50 |
-
repo_id="jacopoteneggi/IBYDMT",
|
51 |
-
filename=f"{dataset_name}_{model_name}_train.h5",
|
52 |
-
repo_type="dataset",
|
53 |
-
)
|
54 |
-
|
55 |
-
with h5py.File(dataset_path, "r") as dataset:
|
56 |
-
embedding = dataset["embedding"][:]
|
57 |
-
return embedding
|
58 |
-
|
59 |
-
|
60 |
-
def load_model(model_name, device):
|
61 |
if "open_clip" in model_name:
|
62 |
model, preprocess, tokenizer = _get_open_clip_model(model_name, device)
|
63 |
elif "clip" in model_name:
|
@@ -67,7 +44,7 @@ def load_model(model_name, device):
|
|
67 |
|
68 |
@torch.no_grad()
|
69 |
@torch.cuda.amp.autocast()
|
70 |
-
def
|
71 |
concepts_text = tokenizer(concepts).to(device)
|
72 |
|
73 |
concept_features = model.encode_text(concepts_text)
|
@@ -77,7 +54,7 @@ def encode_concepts(tokenizer, model, concepts, device):
|
|
77 |
|
78 |
@torch.no_grad()
|
79 |
@torch.cuda.amp.autocast()
|
80 |
-
def
|
81 |
image = preprocess(image)
|
82 |
image = image.unsqueeze(0)
|
83 |
image = image.to(device)
|
@@ -89,7 +66,7 @@ def encode_image(model, preprocess, image, device):
|
|
89 |
|
90 |
@torch.no_grad()
|
91 |
@torch.cuda.amp.autocast()
|
92 |
-
def
|
93 |
class_text = tokenizer([f"A photo of a {class_name}"]).to(device)
|
94 |
|
95 |
class_features = model.encode_text(class_text)
|
@@ -97,12 +74,24 @@ def encode_class_name(tokenizer, model, class_name, device):
|
|
97 |
return class_features.cpu().numpy()
|
98 |
|
99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
def _sample_random_subset(concept_idx, concepts, cardinality):
|
101 |
sample_idx = list(set(range(len(concepts))) - {concept_idx})
|
102 |
return rng.permutation(sample_idx)[:cardinality].tolist()
|
103 |
|
104 |
|
105 |
-
def _test(z, concept_idx, concepts, cardinality, sampler, classifier):
|
106 |
def cond_p(z, cond_idx, m):
|
107 |
_, sample_h = sampler.sample(z, cond_idx, m=m)
|
108 |
return sample_h
|
@@ -118,9 +107,16 @@ def _test(z, concept_idx, concepts, cardinality, sampler, classifier):
|
|
118 |
|
119 |
tester = xSKIT(testing_config)
|
120 |
rejected, tau = tester.test(
|
121 |
-
z,
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
)
|
123 |
wealth = tester.wealth._wealth
|
|
|
124 |
|
125 |
rejected_hist.append(rejected)
|
126 |
tau_hist.append(tau)
|
@@ -136,60 +132,99 @@ def _test(z, concept_idx, concepts, cardinality, sampler, classifier):
|
|
136 |
}
|
137 |
|
138 |
|
139 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
with st.spinner("Loading model"):
|
141 |
-
model, preprocess, tokenizer =
|
142 |
|
143 |
with st.spinner("Encoding concepts"):
|
144 |
-
cbm =
|
145 |
|
146 |
with st.spinner("Encoding image"):
|
147 |
-
h =
|
148 |
z = h @ cbm.T
|
149 |
z = z.squeeze()
|
150 |
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
semantics = embedding @ cbm.T
|
156 |
-
sampler = cKDE(embedding, semantics)
|
157 |
-
|
158 |
-
classifier = encode_class_name(tokenizer, model, class_name, device)
|
159 |
-
|
160 |
-
with ThreadPoolExecutor() as executor:
|
161 |
-
futures = [
|
162 |
-
executor.submit(
|
163 |
-
_test, z, concept_idx, concepts, cardinality, sampler, classifier
|
164 |
-
)
|
165 |
-
for concept_idx in range(len(concepts))
|
166 |
-
]
|
167 |
-
|
168 |
-
results = []
|
169 |
-
for idx, future in enumerate(as_completed(futures)):
|
170 |
-
results.append(future.result())
|
171 |
-
progress_bar.progress((idx + 1) / len(concepts))
|
172 |
-
|
173 |
-
rejected = np.empty((testing_config.r, len(concepts)))
|
174 |
-
tau = np.empty((testing_config.r, len(concepts)))
|
175 |
-
wealth = np.empty((testing_config.r, testing_config.tau_max, len(concepts)))
|
176 |
|
177 |
-
|
178 |
-
|
|
|
179 |
|
180 |
-
|
181 |
-
tau[:, concept_idx] = np.array(_results["tau"])
|
182 |
-
wealth[:, :, concept_idx] = np.array(_results["wealth"])
|
183 |
|
184 |
-
|
|
|
|
|
|
|
185 |
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
|
194 |
st.session_state.disabled = False
|
195 |
st.experimental_rerun()
|
|
|
4 |
import h5py
|
5 |
import streamlit as st
|
6 |
import numpy as np
|
|
|
7 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
8 |
|
9 |
import ml_collections
|
|
|
15 |
|
16 |
rng = np.random.default_rng()
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
def _get_open_clip_model(model_name, device):
|
20 |
backbone = model_name.split(":")[-1]
|
|
|
34 |
return model, preprocess, tokenizer
|
35 |
|
36 |
|
37 |
+
def _load_model(model_name, device):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
if "open_clip" in model_name:
|
39 |
model, preprocess, tokenizer = _get_open_clip_model(model_name, device)
|
40 |
elif "clip" in model_name:
|
|
|
44 |
|
45 |
@torch.no_grad()
|
46 |
@torch.cuda.amp.autocast()
|
47 |
+
def _encode_concepts(tokenizer, model, concepts, device):
|
48 |
concepts_text = tokenizer(concepts).to(device)
|
49 |
|
50 |
concept_features = model.encode_text(concepts_text)
|
|
|
54 |
|
55 |
@torch.no_grad()
|
56 |
@torch.cuda.amp.autocast()
|
57 |
+
def _encode_image(model, preprocess, image, device):
|
58 |
image = preprocess(image)
|
59 |
image = image.unsqueeze(0)
|
60 |
image = image.to(device)
|
|
|
66 |
|
67 |
@torch.no_grad()
|
68 |
@torch.cuda.amp.autocast()
|
69 |
+
def _encode_class_name(tokenizer, model, class_name, device):
|
70 |
class_text = tokenizer([f"A photo of a {class_name}"]).to(device)
|
71 |
|
72 |
class_features = model.encode_text(class_text)
|
|
|
74 |
return class_features.cpu().numpy()
|
75 |
|
76 |
|
77 |
+
def _load_dataset(dataset_name, model_name):
|
78 |
+
dataset_path = hf_hub_download(
|
79 |
+
repo_id="jacopoteneggi/IBYDMT",
|
80 |
+
filename=f"{dataset_name}_{model_name}_train.h5",
|
81 |
+
repo_type="dataset",
|
82 |
+
)
|
83 |
+
|
84 |
+
with h5py.File(dataset_path, "r") as dataset:
|
85 |
+
embedding = dataset["embedding"][:]
|
86 |
+
return embedding
|
87 |
+
|
88 |
+
|
89 |
def _sample_random_subset(concept_idx, concepts, cardinality):
|
90 |
sample_idx = list(set(range(len(concepts))) - {concept_idx})
|
91 |
return rng.permutation(sample_idx)[:cardinality].tolist()
|
92 |
|
93 |
|
94 |
+
def _test(testing_config, z, concept_idx, concepts, cardinality, sampler, classifier):
|
95 |
def cond_p(z, cond_idx, m):
|
96 |
_, sample_h = sampler.sample(z, cond_idx, m=m)
|
97 |
return sample_h
|
|
|
107 |
|
108 |
tester = xSKIT(testing_config)
|
109 |
rejected, tau = tester.test(
|
110 |
+
z,
|
111 |
+
concept_idx,
|
112 |
+
subset_idx,
|
113 |
+
cond_p,
|
114 |
+
f,
|
115 |
+
interrupt_on="max_wealth",
|
116 |
+
max_wealth=100,
|
117 |
)
|
118 |
wealth = tester.wealth._wealth
|
119 |
+
wealth = wealth + [wealth[-1]] * (testing_config.tau_max - len(wealth))
|
120 |
|
121 |
rejected_hist.append(rejected)
|
122 |
tau_hist.append(tau)
|
|
|
132 |
}
|
133 |
|
134 |
|
135 |
+
def get_testing_config(**kwargs):
|
136 |
+
testing_config = st.session_state.testing_config = ml_collections.ConfigDict()
|
137 |
+
testing_config.significance_level = kwargs.get("significance_level", 0.05)
|
138 |
+
testing_config.wealth = kwargs.get("wealth", "ons")
|
139 |
+
testing_config.bet = kwargs.get("bet", "tanh")
|
140 |
+
testing_config.kernel = kwargs.get("kernel", "rbf")
|
141 |
+
testing_config.kernel_scale_method = kwargs.get("kernel_scale_method", "quantile")
|
142 |
+
testing_config.kernel_scale = kwargs.get("kernel_scale", 0.5)
|
143 |
+
testing_config.tau_max = kwargs.get("tau_max", 200)
|
144 |
+
testing_config.r = kwargs.get("r", 10)
|
145 |
+
return testing_config
|
146 |
+
|
147 |
+
|
148 |
+
def test(
|
149 |
+
testing_config,
|
150 |
+
image,
|
151 |
+
class_name,
|
152 |
+
concepts,
|
153 |
+
cardinality,
|
154 |
+
dataset_name,
|
155 |
+
model_name,
|
156 |
+
device,
|
157 |
+
):
|
158 |
with st.spinner("Loading model"):
|
159 |
+
model, preprocess, tokenizer = _load_model(model_name, device)
|
160 |
|
161 |
with st.spinner("Encoding concepts"):
|
162 |
+
cbm = _encode_concepts(tokenizer, model, concepts, device)
|
163 |
|
164 |
with st.spinner("Encoding image"):
|
165 |
+
h = _encode_image(model, preprocess, image, device)
|
166 |
z = h @ cbm.T
|
167 |
z = z.squeeze()
|
168 |
|
169 |
+
progress_bar = st.progress(
|
170 |
+
0,
|
171 |
+
text=f"Testing concepts (can take a few minutes) [0 / {len(concepts)} completed]",
|
172 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
|
174 |
+
embedding = _load_dataset(dataset_name, model_name)
|
175 |
+
semantics = embedding @ cbm.T
|
176 |
+
sampler = cKDE(embedding, semantics)
|
177 |
|
178 |
+
classifier = _encode_class_name(tokenizer, model, class_name, device)
|
|
|
|
|
179 |
|
180 |
+
progress_bar.progress(
|
181 |
+
1 / (len(concepts) + 1),
|
182 |
+
text=f"Testing concepts (can take a few minutes) [0 / {len(concepts)} completed]",
|
183 |
+
)
|
184 |
|
185 |
+
with ThreadPoolExecutor() as executor:
|
186 |
+
futures = [
|
187 |
+
executor.submit(
|
188 |
+
_test,
|
189 |
+
testing_config,
|
190 |
+
z,
|
191 |
+
concept_idx,
|
192 |
+
concepts,
|
193 |
+
cardinality,
|
194 |
+
sampler,
|
195 |
+
classifier,
|
196 |
+
)
|
197 |
+
for concept_idx in range(len(concepts))
|
198 |
+
]
|
199 |
+
|
200 |
+
results = []
|
201 |
+
for idx, future in enumerate(as_completed(futures)):
|
202 |
+
results.append(future.result())
|
203 |
+
progress_bar.progress(
|
204 |
+
(idx + 2) / (len(concepts) + 1),
|
205 |
+
text=f"Testing concepts (can take a few minutes) [{idx + 1} / {len(concepts)} completed]",
|
206 |
+
)
|
207 |
+
|
208 |
+
rejected = np.empty((testing_config.r, len(concepts)))
|
209 |
+
tau = np.empty((testing_config.r, len(concepts)))
|
210 |
+
wealth = np.empty((testing_config.r, testing_config.tau_max, len(concepts)))
|
211 |
+
|
212 |
+
for _results in results:
|
213 |
+
concept_idx = concepts.index(_results["concept"])
|
214 |
+
|
215 |
+
rejected[:, concept_idx] = np.array(_results["rejected"])
|
216 |
+
tau[:, concept_idx] = np.array(_results["tau"])
|
217 |
+
wealth[:, :, concept_idx] = np.array(_results["wealth"])
|
218 |
+
|
219 |
+
tau /= testing_config.tau_max
|
220 |
+
|
221 |
+
st.session_state.results = {
|
222 |
+
"significance_level": testing_config.significance_level,
|
223 |
+
"concepts": concepts,
|
224 |
+
"rejected": rejected,
|
225 |
+
"tau": tau,
|
226 |
+
"wealth": wealth,
|
227 |
+
}
|
228 |
|
229 |
st.session_state.disabled = False
|
230 |
st.experimental_rerun()
|
app_lib/user_input.py
CHANGED
@@ -20,6 +20,82 @@ def _validate_concepts(concepts):
|
|
20 |
return (False, "Maximum 10 concepts allowed")
|
21 |
return (True, None)
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
def get_model_name():
|
24 |
return st.selectbox(
|
25 |
"Model to test",
|
@@ -58,8 +134,8 @@ def get_class_name():
|
|
58 |
|
59 |
def get_concepts():
|
60 |
concepts = st.text_area(
|
61 |
-
"Concepts to test
|
62 |
-
help="List of concepts to test the predictions of the model with. Write one concept per line.",
|
63 |
height=160,
|
64 |
value="piano\ncute\nwhiskers\nmusic\nwild",
|
65 |
disabled=st.session_state.disabled,
|
@@ -73,13 +149,11 @@ def get_concepts():
|
|
73 |
return concepts, concepts_ready, concepts_error
|
74 |
|
75 |
|
76 |
-
def
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
disabled=st.session_state.disabled or not concepts_ready,
|
85 |
-
)
|
|
|
20 |
return (False, "Maximum 10 concepts allowed")
|
21 |
return (True, None)
|
22 |
|
23 |
+
|
24 |
+
def _get_significance_level():
|
25 |
+
DEFAULT = 0.05
|
26 |
+
return st.slider(
|
27 |
+
"Significance level",
|
28 |
+
help=" ".join(
|
29 |
+
[
|
30 |
+
"The level of significance of the tests.",
|
31 |
+
f"Defaults to {DEFAULT:.2F}.",
|
32 |
+
]
|
33 |
+
),
|
34 |
+
min_value=0.01,
|
35 |
+
max_value=1.0,
|
36 |
+
value=DEFAULT,
|
37 |
+
step=0.01,
|
38 |
+
disabled=st.session_state.disabled,
|
39 |
+
)
|
40 |
+
|
41 |
+
|
42 |
+
def _get_tau_max():
|
43 |
+
DEFAULT = 200
|
44 |
+
return int(
|
45 |
+
st.slider(
|
46 |
+
"Duration of test",
|
47 |
+
help=" ".join(
|
48 |
+
[
|
49 |
+
"The maximum number of steps for each test.",
|
50 |
+
f"Defaults to {DEFAULT}.",
|
51 |
+
]
|
52 |
+
),
|
53 |
+
min_value=1,
|
54 |
+
max_value=1000,
|
55 |
+
step=1,
|
56 |
+
value=DEFAULT,
|
57 |
+
disabled=st.session_state.disabled,
|
58 |
+
)
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
def _get_number_of_tests():
|
63 |
+
DEFAULT = 20
|
64 |
+
return int(
|
65 |
+
st.slider(
|
66 |
+
"Number of tests per concept",
|
67 |
+
help=" ".join(
|
68 |
+
[
|
69 |
+
"The number of tests to average for each concept.",
|
70 |
+
f"Defaults to {DEFAULT}.",
|
71 |
+
]
|
72 |
+
),
|
73 |
+
min_value=1,
|
74 |
+
max_value=100,
|
75 |
+
step=1,
|
76 |
+
value=DEFAULT,
|
77 |
+
disabled=st.session_state.disabled,
|
78 |
+
)
|
79 |
+
)
|
80 |
+
|
81 |
+
|
82 |
+
def _get_cardinality(concepts, concepts_ready):
|
83 |
+
return st.slider(
|
84 |
+
"Size of conditioning set",
|
85 |
+
help=" ".join(
|
86 |
+
[
|
87 |
+
"The number of concepts to condition model predictions on.",
|
88 |
+
"Defaults to half of the number of concepts.",
|
89 |
+
]
|
90 |
+
),
|
91 |
+
min_value=1,
|
92 |
+
max_value=max(2, len(concepts) - 1),
|
93 |
+
value=int(len(concepts) / 2),
|
94 |
+
step=1,
|
95 |
+
disabled=st.session_state.disabled or not concepts_ready,
|
96 |
+
)
|
97 |
+
|
98 |
+
|
99 |
def get_model_name():
|
100 |
return st.selectbox(
|
101 |
"Model to test",
|
|
|
134 |
|
135 |
def get_concepts():
|
136 |
concepts = st.text_area(
|
137 |
+
"Concepts to test",
|
138 |
+
help="List of concepts to test the predictions of the model with. Write one concept per line. Maximum 10 concepts allowed.",
|
139 |
height=160,
|
140 |
value="piano\ncute\nwhiskers\nmusic\nwild",
|
141 |
disabled=st.session_state.disabled,
|
|
|
149 |
return concepts, concepts_ready, concepts_error
|
150 |
|
151 |
|
152 |
+
def get_advanced_settings(concepts, concepts_ready):
|
153 |
+
with st.popover("Advanced settings", disabled=st.session_state.disabled):
|
154 |
+
significance_level = _get_significance_level()
|
155 |
+
tau_max = _get_tau_max()
|
156 |
+
r = _get_number_of_tests()
|
157 |
+
cardinality = _get_cardinality(concepts, concepts_ready)
|
158 |
+
|
159 |
+
return significance_level, tau_max, r, cardinality
|
|
|
|
header.md
CHANGED
@@ -1,5 +1,3 @@
|
|
1 |
# 🤔 I Bet You Did Not Mean That
|
2 |
|
3 |
-
Official
|
4 |
-
|
5 |
-
---
|
|
|
1 |
# 🤔 I Bet You Did Not Mean That
|
2 |
|
3 |
+
Official 🤗 Space for the paper [*I Bet You Did Not Mean That: Testing Semantic Importance via Betting*](https://arxiv.org/pdf/2405.19146), by [Jacopo Teneggi](https://jacopoteneggi.github.io) and [Jeremias Sulam](https://sites.google.com/view/jsulam).
|
|
|
|
ibydmt/test.py
CHANGED
@@ -141,7 +141,8 @@ class xSKIT(SequentialTester):
|
|
141 |
C: list[int],
|
142 |
cond_p: Callable[[Float[Array, "D"], list[int], int], Float[Array, "N D2"]],
|
143 |
model: Callable[[Float[Array, "N D2"]], Float[Array, "N"]],
|
144 |
-
|
|
|
145 |
) -> Tuple[bool, int]:
|
146 |
sample = functools.partial(self._sample, z, j, C, cond_p, model)
|
147 |
|
@@ -159,6 +160,8 @@ class xSKIT(SequentialTester):
|
|
159 |
|
160 |
if self.wealth.rejected:
|
161 |
tau = min(tau, t)
|
162 |
-
if
|
|
|
|
|
163 |
break
|
164 |
return (self.wealth.rejected, tau)
|
|
|
141 |
C: list[int],
|
142 |
cond_p: Callable[[Float[Array, "D"], list[int], int], Float[Array, "N D2"]],
|
143 |
model: Callable[[Float[Array, "N D2"]], Float[Array, "N"]],
|
144 |
+
interrupt_on: str = "rejection",
|
145 |
+
max_wealth: float = None,
|
146 |
) -> Tuple[bool, int]:
|
147 |
sample = functools.partial(self._sample, z, j, C, cond_p, model)
|
148 |
|
|
|
160 |
|
161 |
if self.wealth.rejected:
|
162 |
tau = min(tau, t)
|
163 |
+
if interrupt_on == "rejection":
|
164 |
+
break
|
165 |
+
if interrupt_on == "max_wealth" and self.wealth._w >= max_wealth:
|
166 |
break
|
167 |
return (self.wealth.rejected, tau)
|
requirements.txt
CHANGED
@@ -4,4 +4,5 @@ open_clip_torch
|
|
4 |
h5py
|
5 |
ml_collections
|
6 |
jaxtyping
|
7 |
-
scikit-learn
|
|
|
|
4 |
h5py
|
5 |
ml_collections
|
6 |
jaxtyping
|
7 |
+
scikit-learn
|
8 |
+
plotly
|
style.css
CHANGED
@@ -33,7 +33,22 @@ h1 {
|
|
33 |
}
|
34 |
}
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
button:hover>div:first-of-type>p {
|
37 |
text-decoration: underline;
|
38 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
}
|
|
|
33 |
}
|
34 |
}
|
35 |
|
36 |
+
button:active {
|
37 |
+
background: white;
|
38 |
+
}
|
39 |
+
|
40 |
+
button:focus:not(:active) {
|
41 |
+
color: rgb(49, 51, 63);
|
42 |
+
}
|
43 |
+
|
44 |
button:hover>div:first-of-type>p {
|
45 |
text-decoration: underline;
|
46 |
}
|
47 |
+
}
|
48 |
+
|
49 |
+
[data-testid="stSpinner"] {
|
50 |
+
>div {
|
51 |
+
display: flex;
|
52 |
+
justify-content: center;
|
53 |
+
}
|
54 |
}
|