Spaces:
Sleeping
Sleeping
jacopoteneggi
commited on
Update
Browse files- app.py +1 -6
- app_lib/__pycache__/__init__.cpython-310.pyc +0 -0
- app_lib/__pycache__/ckde.cpython-310.pyc +0 -0
- app_lib/__pycache__/main.cpython-310.pyc +0 -0
- app_lib/__pycache__/test.cpython-310.pyc +0 -0
- app_lib/__pycache__/user_input.cpython-310.pyc +0 -0
- app_lib/ckde.py +10 -13
- app_lib/main.py +13 -31
- app_lib/test.py +100 -3
- app_lib/user_input.py +1 -1
- ibydmt/__pycache__/__init__.cpython-310.pyc +0 -0
- ibydmt/__pycache__/bet.cpython-310.pyc +0 -0
- ibydmt/__pycache__/payoff.cpython-310.pyc +0 -0
- ibydmt/__pycache__/test.cpython-310.pyc +0 -0
- ibydmt/__pycache__/utils.cpython-310.pyc +0 -0
- ibydmt/__pycache__/wealth.cpython-310.pyc +0 -0
- ibydmt/payoff.py +1 -0
- ibydmt/test.py +7 -2
- ibydmt/wealth.py +17 -14
app.py
CHANGED
@@ -27,11 +27,6 @@ st.markdown(
|
|
27 |
[data-testid="stHorizontalBlock"] {
|
28 |
align-items: center;
|
29 |
}
|
30 |
-
div.stSpinner > div {
|
31 |
-
text-align:center;
|
32 |
-
align-items: center;
|
33 |
-
justify-content: center;
|
34 |
-
}
|
35 |
</style>
|
36 |
""",
|
37 |
unsafe_allow_html=True,
|
@@ -44,7 +39,7 @@ st.markdown(
|
|
44 |
Official HF Space for the paper [*I Bet You Did Not Mean That: Testing Semantci Importance via Betting*](https://arxiv.org/pdf/2405.19146), by [Jacopo Teneggi](https://jacopoteneggi.github.io) and [Jeremias Sulam](https://sites.google.com/view/jsulam).
|
45 |
|
46 |
---
|
47 |
-
|
48 |
)
|
49 |
|
50 |
if __name__ == "__main__":
|
|
|
27 |
[data-testid="stHorizontalBlock"] {
|
28 |
align-items: center;
|
29 |
}
|
|
|
|
|
|
|
|
|
|
|
30 |
</style>
|
31 |
""",
|
32 |
unsafe_allow_html=True,
|
|
|
39 |
Official HF Space for the paper [*I Bet You Did Not Mean That: Testing Semantci Importance via Betting*](https://arxiv.org/pdf/2405.19146), by [Jacopo Teneggi](https://jacopoteneggi.github.io) and [Jeremias Sulam](https://sites.google.com/view/jsulam).
|
40 |
|
41 |
---
|
42 |
+
""",
|
43 |
)
|
44 |
|
45 |
if __name__ == "__main__":
|
app_lib/__pycache__/__init__.cpython-310.pyc
CHANGED
Binary files a/app_lib/__pycache__/__init__.cpython-310.pyc and b/app_lib/__pycache__/__init__.cpython-310.pyc differ
|
|
app_lib/__pycache__/ckde.cpython-310.pyc
ADDED
Binary file (2.79 kB). View file
|
|
app_lib/__pycache__/main.cpython-310.pyc
CHANGED
Binary files a/app_lib/__pycache__/main.cpython-310.pyc and b/app_lib/__pycache__/main.cpython-310.pyc differ
|
|
app_lib/__pycache__/test.cpython-310.pyc
CHANGED
Binary files a/app_lib/__pycache__/test.cpython-310.pyc and b/app_lib/__pycache__/test.cpython-310.pyc differ
|
|
app_lib/__pycache__/user_input.cpython-310.pyc
CHANGED
Binary files a/app_lib/__pycache__/user_input.cpython-310.pyc and b/app_lib/__pycache__/user_input.cpython-310.pyc differ
|
|
app_lib/ckde.py
CHANGED
@@ -1,18 +1,18 @@
|
|
1 |
import numpy as np
|
2 |
-
import torch
|
3 |
from scipy.spatial.distance import cdist
|
4 |
from scipy.stats import gaussian_kde
|
5 |
|
|
|
6 |
class cKDE:
|
7 |
-
def __init__(
|
8 |
-
|
9 |
-
|
10 |
-
self.metric =
|
11 |
-
self.scale_method =
|
12 |
-
self.scale =
|
13 |
|
14 |
-
self.
|
15 |
-
self.
|
16 |
|
17 |
def _quantile_scale(self, Z_cond_dist):
|
18 |
return np.quantile(Z_cond_dist, self.scale)
|
@@ -62,7 +62,7 @@ class cKDE:
|
|
62 |
dist = cdist(z, self.Z, metric=self.metric)
|
63 |
return np.argmin(dist, axis=-1)
|
64 |
|
65 |
-
def sample(self, z, cond_idx, m=1
|
66 |
if z.ndim == 1:
|
67 |
z = z.reshape(1, -1)
|
68 |
|
@@ -71,7 +71,4 @@ class cKDE:
|
|
71 |
nn_idx = self.nearest_neighbor(sample_z)
|
72 |
sample_h = self.H[nn_idx]
|
73 |
|
74 |
-
if return_images:
|
75 |
-
sample_images = torch.stack([self.dataset[_idx][0] for _idx in nn_idx])
|
76 |
-
return sample_z, sample_h, sample_images
|
77 |
return sample_z, sample_h
|
|
|
1 |
import numpy as np
|
|
|
2 |
from scipy.spatial.distance import cdist
|
3 |
from scipy.stats import gaussian_kde
|
4 |
|
5 |
+
|
6 |
class cKDE:
|
7 |
+
def __init__(
|
8 |
+
self, embedding, semantics, metric="euclidean", scale_method="neff", scale=2000
|
9 |
+
):
|
10 |
+
self.metric = metric
|
11 |
+
self.scale_method = scale_method
|
12 |
+
self.scale = scale
|
13 |
|
14 |
+
self.H = embedding
|
15 |
+
self.Z = semantics
|
16 |
|
17 |
def _quantile_scale(self, Z_cond_dist):
|
18 |
return np.quantile(Z_cond_dist, self.scale)
|
|
|
62 |
dist = cdist(z, self.Z, metric=self.metric)
|
63 |
return np.argmin(dist, axis=-1)
|
64 |
|
65 |
+
def sample(self, z, cond_idx, m=1):
|
66 |
if z.ndim == 1:
|
67 |
z = z.reshape(1, -1)
|
68 |
|
|
|
71 |
nn_idx = self.nearest_neighbor(sample_z)
|
72 |
sample_h = self.H[nn_idx]
|
73 |
|
|
|
|
|
|
|
74 |
return sample_z, sample_h
|
app_lib/main.py
CHANGED
@@ -9,13 +9,7 @@ from app_lib.user_input import (
|
|
9 |
get_image,
|
10 |
get_model_name,
|
11 |
)
|
12 |
-
from app_lib.test import
|
13 |
-
load_dataset,
|
14 |
-
load_model,
|
15 |
-
encode_image,
|
16 |
-
encode_concepts,
|
17 |
-
encode_class_name,
|
18 |
-
)
|
19 |
|
20 |
|
21 |
def _disable():
|
@@ -67,27 +61,15 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
|
|
67 |
)
|
68 |
|
69 |
with columns[1]:
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
with st.spinner("Preparing zero-shot classifier"):
|
84 |
-
classifier = encode_class_name(tokenizer, model, class_name, device)
|
85 |
-
|
86 |
-
with st.spinner("Encoding image"):
|
87 |
-
h = encode_image(model, preprocess, image, device)
|
88 |
-
z = h @ cbm.T
|
89 |
-
print(h.shape, cbm.shape, z.shape)
|
90 |
-
time.sleep(2)
|
91 |
-
|
92 |
-
st.session_state.disabled = False
|
93 |
-
st.experimental_rerun()
|
|
|
9 |
get_image,
|
10 |
get_model_name,
|
11 |
)
|
12 |
+
from app_lib.test import test
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
|
15 |
def _disable():
|
|
|
61 |
)
|
62 |
|
63 |
with columns[1]:
|
64 |
+
_, centercol, _ = st.columns(3)
|
65 |
+
with centercol:
|
66 |
+
if test_button:
|
67 |
+
test(
|
68 |
+
image,
|
69 |
+
class_name,
|
70 |
+
concepts,
|
71 |
+
cardinality,
|
72 |
+
"imagenette",
|
73 |
+
model_name,
|
74 |
+
device,
|
75 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app_lib/test.py
CHANGED
@@ -2,10 +2,29 @@ import torch
|
|
2 |
import clip
|
3 |
import open_clip
|
4 |
import h5py
|
|
|
|
|
|
|
|
|
5 |
|
|
|
6 |
from huggingface_hub import hf_hub_download
|
7 |
|
|
|
8 |
from app_lib.utils import SUPPORTED_MODELS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
|
11 |
def _get_open_clip_model(model_name, device):
|
@@ -39,7 +58,6 @@ def load_dataset(dataset_name, model_name):
|
|
39 |
|
40 |
|
41 |
def load_model(model_name, device):
|
42 |
-
print(model_name)
|
43 |
if "open_clip" in model_name:
|
44 |
model, preprocess, tokenizer = _get_open_clip_model(model_name, device)
|
45 |
elif "clip" in model_name:
|
@@ -79,6 +97,85 @@ def encode_class_name(tokenizer, model, class_name, device):
|
|
79 |
return class_features.cpu().numpy()
|
80 |
|
81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
def test(image, class_name, concepts, cardinality, dataset_name, model_name, device):
|
83 |
-
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import clip
|
3 |
import open_clip
|
4 |
import h5py
|
5 |
+
import streamlit as st
|
6 |
+
import numpy as np
|
7 |
+
import time
|
8 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
9 |
|
10 |
+
import ml_collections
|
11 |
from huggingface_hub import hf_hub_download
|
12 |
|
13 |
+
from ibydmt.test import xSKIT
|
14 |
from app_lib.utils import SUPPORTED_MODELS
|
15 |
+
from app_lib.ckde import cKDE
|
16 |
+
|
17 |
+
rng = np.random.default_rng()
|
18 |
+
|
19 |
+
testing_config = ml_collections.ConfigDict()
|
20 |
+
testing_config.significance_level = 0.05
|
21 |
+
testing_config.wealth = "ons"
|
22 |
+
testing_config.bet = "tanh"
|
23 |
+
testing_config.kernel = "rbf"
|
24 |
+
testing_config.kernel_scale_method = "quantile"
|
25 |
+
testing_config.kernel_scale = 0.5
|
26 |
+
testing_config.tau_max = 200
|
27 |
+
testing_config.r = 10
|
28 |
|
29 |
|
30 |
def _get_open_clip_model(model_name, device):
|
|
|
58 |
|
59 |
|
60 |
def load_model(model_name, device):
|
|
|
61 |
if "open_clip" in model_name:
|
62 |
model, preprocess, tokenizer = _get_open_clip_model(model_name, device)
|
63 |
elif "clip" in model_name:
|
|
|
97 |
return class_features.cpu().numpy()
|
98 |
|
99 |
|
100 |
+
def _sample_random_subset(concept_idx, concepts, cardinality):
|
101 |
+
sample_idx = list(set(range(len(concepts))) - {concept_idx})
|
102 |
+
return rng.permutation(sample_idx)[:cardinality].tolist()
|
103 |
+
|
104 |
+
|
105 |
+
def _test(z, concept_idx, concepts, cardinality, sampler, classifier):
|
106 |
+
def cond_p(z, cond_idx, m):
|
107 |
+
_, sample_h = sampler.sample(z, cond_idx, m=m)
|
108 |
+
return sample_h
|
109 |
+
|
110 |
+
def f(h):
|
111 |
+
output = h @ classifier.T
|
112 |
+
return output.squeeze()
|
113 |
+
|
114 |
+
rejected_hist, tau_hist, wealth_hist, subset_hist = [], [], [], []
|
115 |
+
for _ in range(testing_config.r):
|
116 |
+
subset_idx = _sample_random_subset(concept_idx, concepts, cardinality)
|
117 |
+
subset = [concepts[idx] for idx in subset_idx]
|
118 |
+
|
119 |
+
tester = xSKIT(testing_config)
|
120 |
+
rejected, tau = tester.test(
|
121 |
+
z, concept_idx, subset_idx, cond_p, f, interrupt_on_rejection=False
|
122 |
+
)
|
123 |
+
wealth = tester.wealth._wealth
|
124 |
+
|
125 |
+
rejected_hist.append(rejected)
|
126 |
+
tau_hist.append(tau)
|
127 |
+
wealth_hist.append(wealth)
|
128 |
+
subset_hist.append(subset)
|
129 |
+
|
130 |
+
return {
|
131 |
+
"concept": concepts[concept_idx],
|
132 |
+
"rejected": rejected_hist,
|
133 |
+
"tau": tau_hist,
|
134 |
+
"wealth": wealth_hist,
|
135 |
+
"subset": subset_hist,
|
136 |
+
}
|
137 |
+
|
138 |
+
|
139 |
def test(image, class_name, concepts, cardinality, dataset_name, model_name, device):
|
140 |
+
with st.spinner("Loading model"):
|
141 |
+
model, preprocess, tokenizer = load_model(model_name, device)
|
142 |
+
|
143 |
+
with st.spinner("Encoding concepts"):
|
144 |
+
cbm = encode_concepts(tokenizer, model, concepts, device)
|
145 |
+
|
146 |
+
with st.spinner("Encoding image"):
|
147 |
+
h = encode_image(model, preprocess, image, device)
|
148 |
+
z = h @ cbm.T
|
149 |
+
z = z.squeeze()
|
150 |
+
|
151 |
+
with st.spinner("Testing"):
|
152 |
+
progress_bar = st.progress(0)
|
153 |
+
|
154 |
+
embedding = load_dataset("imagenette", model_name)
|
155 |
+
semantics = embedding @ cbm.T
|
156 |
+
sampler = cKDE(embedding, semantics)
|
157 |
+
|
158 |
+
classifier = encode_class_name(tokenizer, model, class_name, device)
|
159 |
+
|
160 |
+
with ThreadPoolExecutor() as executor:
|
161 |
+
futures = [
|
162 |
+
executor.submit(
|
163 |
+
_test, z, concept_idx, concepts, cardinality, sampler, classifier
|
164 |
+
)
|
165 |
+
for concept_idx in range(len(concepts))
|
166 |
+
]
|
167 |
+
|
168 |
+
results = []
|
169 |
+
for idx, future in enumerate(as_completed(futures)):
|
170 |
+
results.append(future.result())
|
171 |
+
progress_bar.progress((idx + 1) / len(concepts))
|
172 |
+
|
173 |
+
# print(results)
|
174 |
+
# wealth = np.empty((testing_config.tau_max, len(concepts)))
|
175 |
+
# wealth[:] = np.nan
|
176 |
+
# for _results in results:
|
177 |
+
# concept_idx = concepts.index(_results["concept"])
|
178 |
+
# _wealth =
|
179 |
+
|
180 |
+
st.session_state.disabled = False
|
181 |
+
st.experimental_rerun()
|
app_lib/user_input.py
CHANGED
@@ -80,7 +80,7 @@ def get_cardinality(concepts, concepts_ready):
|
|
80 |
help="The number of concepts to condition model predictions on.",
|
81 |
min_value=1,
|
82 |
max_value=max(2, len(concepts) - 1),
|
83 |
-
value=
|
84 |
step=1,
|
85 |
disabled=st.session_state.disabled or not concepts_ready,
|
86 |
)
|
|
|
80 |
help="The number of concepts to condition model predictions on.",
|
81 |
min_value=1,
|
82 |
max_value=max(2, len(concepts) - 1),
|
83 |
+
value=2,
|
84 |
step=1,
|
85 |
disabled=st.session_state.disabled or not concepts_ready,
|
86 |
)
|
ibydmt/__pycache__/__init__.cpython-310.pyc
CHANGED
Binary files a/ibydmt/__pycache__/__init__.cpython-310.pyc and b/ibydmt/__pycache__/__init__.cpython-310.pyc differ
|
|
ibydmt/__pycache__/bet.cpython-310.pyc
CHANGED
Binary files a/ibydmt/__pycache__/bet.cpython-310.pyc and b/ibydmt/__pycache__/bet.cpython-310.pyc differ
|
|
ibydmt/__pycache__/payoff.cpython-310.pyc
CHANGED
Binary files a/ibydmt/__pycache__/payoff.cpython-310.pyc and b/ibydmt/__pycache__/payoff.cpython-310.pyc differ
|
|
ibydmt/__pycache__/test.cpython-310.pyc
CHANGED
Binary files a/ibydmt/__pycache__/test.cpython-310.pyc and b/ibydmt/__pycache__/test.cpython-310.pyc differ
|
|
ibydmt/__pycache__/utils.cpython-310.pyc
CHANGED
Binary files a/ibydmt/__pycache__/utils.cpython-310.pyc and b/ibydmt/__pycache__/utils.cpython-310.pyc differ
|
|
ibydmt/__pycache__/wealth.cpython-310.pyc
CHANGED
Binary files a/ibydmt/__pycache__/wealth.cpython-310.pyc and b/ibydmt/__pycache__/wealth.cpython-310.pyc differ
|
|
ibydmt/payoff.py
CHANGED
@@ -81,6 +81,7 @@ class KernelPayoff(Payoff):
|
|
81 |
zip(d, null_d),
|
82 |
0,
|
83 |
)
|
|
|
84 |
|
85 |
return self.bet.compute(g)
|
86 |
|
|
|
81 |
zip(d, null_d),
|
82 |
0,
|
83 |
)
|
84 |
+
g = g.squeeze().item()
|
85 |
|
86 |
return self.bet.compute(g)
|
87 |
|
ibydmt/test.py
CHANGED
@@ -141,9 +141,12 @@ class xSKIT(SequentialTester):
|
|
141 |
C: list[int],
|
142 |
cond_p: Callable[[Float[Array, "D"], list[int], int], Float[Array, "N D2"]],
|
143 |
model: Callable[[Float[Array, "N D2"]], Float[Array, "N"]],
|
|
|
144 |
) -> Tuple[bool, int]:
|
145 |
sample = functools.partial(self._sample, z, j, C, cond_p, model)
|
146 |
|
|
|
|
|
147 |
prev_d = np.stack(sample(), axis=1)
|
148 |
for t in range(1, self.tau_max):
|
149 |
y, null_y = sample()
|
@@ -155,5 +158,7 @@ class xSKIT(SequentialTester):
|
|
155 |
prev_d = np.vstack([prev_d, d])
|
156 |
|
157 |
if self.wealth.rejected:
|
158 |
-
|
159 |
-
|
|
|
|
|
|
141 |
C: list[int],
|
142 |
cond_p: Callable[[Float[Array, "D"], list[int], int], Float[Array, "N D2"]],
|
143 |
model: Callable[[Float[Array, "N D2"]], Float[Array, "N"]],
|
144 |
+
interrupt_on_rejection: bool = True,
|
145 |
) -> Tuple[bool, int]:
|
146 |
sample = functools.partial(self._sample, z, j, C, cond_p, model)
|
147 |
|
148 |
+
tau = self.tau_max - 1
|
149 |
+
|
150 |
prev_d = np.stack(sample(), axis=1)
|
151 |
for t in range(1, self.tau_max):
|
152 |
y, null_y = sample()
|
|
|
158 |
prev_d = np.vstack([prev_d, d])
|
159 |
|
160 |
if self.wealth.rejected:
|
161 |
+
tau = min(tau, t)
|
162 |
+
if interrupt_on_rejection:
|
163 |
+
break
|
164 |
+
return (self.wealth.rejected, tau)
|
ibydmt/wealth.py
CHANGED
@@ -46,27 +46,30 @@ class ONS(Wealth):
|
|
46 |
def __init__(self, config):
|
47 |
super().__init__(config)
|
48 |
|
49 |
-
self.
|
50 |
-
self.
|
51 |
-
self.
|
52 |
|
53 |
-
self.
|
54 |
-
self.
|
|
|
|
|
55 |
|
56 |
def _update_v(self, payoff):
|
57 |
-
z = payoff / (1 + self.
|
58 |
-
self.
|
59 |
-
self.
|
60 |
-
self.
|
61 |
)
|
62 |
|
63 |
def update(self, payoff):
|
64 |
-
w = self.
|
65 |
|
66 |
-
if w >= 0 and not self.
|
67 |
-
self.
|
68 |
-
|
|
|
69 |
self.rejected = True
|
70 |
self._update_v(payoff)
|
71 |
else:
|
72 |
-
self.
|
|
|
46 |
def __init__(self, config):
|
47 |
super().__init__(config)
|
48 |
|
49 |
+
self._w = 1.0
|
50 |
+
self._v = 0
|
51 |
+
self._a = 1
|
52 |
|
53 |
+
self._min_v, self._max_v = config.get("min_v", 0), config.get("max_v", 1 / 2)
|
54 |
+
self._wealth_flag = False
|
55 |
+
|
56 |
+
self._wealth = [self._w]
|
57 |
|
58 |
def _update_v(self, payoff):
|
59 |
+
z = payoff / (1 + self._v * payoff)
|
60 |
+
self._a += z**2
|
61 |
+
self._v = max(
|
62 |
+
self._min_v, min(self._max_v, self._v + 2 / (2 - np.log(3)) * z / self._a)
|
63 |
)
|
64 |
|
65 |
def update(self, payoff):
|
66 |
+
w = self._w * (1 + self._v * payoff)
|
67 |
|
68 |
+
if w >= 0 and not self._wealth_flag:
|
69 |
+
self._w = w
|
70 |
+
self._wealth.append(self._w)
|
71 |
+
if self._w >= 1 / self.significance_level:
|
72 |
self.rejected = True
|
73 |
self._update_v(payoff)
|
74 |
else:
|
75 |
+
self._wealth_flag = True
|