jacopoteneggi commited on
Commit
a40e67a
·
verified ·
1 Parent(s): 7e74032
app.py CHANGED
@@ -27,11 +27,6 @@ st.markdown(
27
  [data-testid="stHorizontalBlock"] {
28
  align-items: center;
29
  }
30
- div.stSpinner > div {
31
- text-align:center;
32
- align-items: center;
33
- justify-content: center;
34
- }
35
  </style>
36
  """,
37
  unsafe_allow_html=True,
@@ -44,7 +39,7 @@ st.markdown(
44
  Official HF Space for the paper [*I Bet You Did Not Mean That: Testing Semantci Importance via Betting*](https://arxiv.org/pdf/2405.19146), by [Jacopo Teneggi](https://jacopoteneggi.github.io) and [Jeremias Sulam](https://sites.google.com/view/jsulam).
45
 
46
  ---
47
- """,
48
  )
49
 
50
  if __name__ == "__main__":
 
27
  [data-testid="stHorizontalBlock"] {
28
  align-items: center;
29
  }
 
 
 
 
 
30
  </style>
31
  """,
32
  unsafe_allow_html=True,
 
39
  Official HF Space for the paper [*I Bet You Did Not Mean That: Testing Semantci Importance via Betting*](https://arxiv.org/pdf/2405.19146), by [Jacopo Teneggi](https://jacopoteneggi.github.io) and [Jeremias Sulam](https://sites.google.com/view/jsulam).
40
 
41
  ---
42
+ """,
43
  )
44
 
45
  if __name__ == "__main__":
app_lib/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/app_lib/__pycache__/__init__.cpython-310.pyc and b/app_lib/__pycache__/__init__.cpython-310.pyc differ
 
app_lib/__pycache__/ckde.cpython-310.pyc ADDED
Binary file (2.79 kB). View file
 
app_lib/__pycache__/main.cpython-310.pyc CHANGED
Binary files a/app_lib/__pycache__/main.cpython-310.pyc and b/app_lib/__pycache__/main.cpython-310.pyc differ
 
app_lib/__pycache__/test.cpython-310.pyc CHANGED
Binary files a/app_lib/__pycache__/test.cpython-310.pyc and b/app_lib/__pycache__/test.cpython-310.pyc differ
 
app_lib/__pycache__/user_input.cpython-310.pyc CHANGED
Binary files a/app_lib/__pycache__/user_input.cpython-310.pyc and b/app_lib/__pycache__/user_input.cpython-310.pyc differ
 
app_lib/ckde.py CHANGED
@@ -1,18 +1,18 @@
1
  import numpy as np
2
- import torch
3
  from scipy.spatial.distance import cdist
4
  from scipy.stats import gaussian_kde
5
 
 
6
  class cKDE:
7
- def __init__(self, config, concept_class_name=None, concept_image_idx=None):
8
- ckde_config = config.ckde
9
- self.image_size = image_size = ckde_config.get("image_size", 128)
10
- self.metric = ckde_config.get("metric", "euclidean")
11
- self.scale_method = ckde_config.get("scale_method", "neff")
12
- self.scale = ckde_config.get("scale", 2000)
13
 
14
- self.Z = self.dataset.Z
15
- self.H = self.dataset.H
16
 
17
  def _quantile_scale(self, Z_cond_dist):
18
  return np.quantile(Z_cond_dist, self.scale)
@@ -62,7 +62,7 @@ class cKDE:
62
  dist = cdist(z, self.Z, metric=self.metric)
63
  return np.argmin(dist, axis=-1)
64
 
65
- def sample(self, z, cond_idx, m=1, return_images=False):
66
  if z.ndim == 1:
67
  z = z.reshape(1, -1)
68
 
@@ -71,7 +71,4 @@ class cKDE:
71
  nn_idx = self.nearest_neighbor(sample_z)
72
  sample_h = self.H[nn_idx]
73
 
74
- if return_images:
75
- sample_images = torch.stack([self.dataset[_idx][0] for _idx in nn_idx])
76
- return sample_z, sample_h, sample_images
77
  return sample_z, sample_h
 
1
  import numpy as np
 
2
  from scipy.spatial.distance import cdist
3
  from scipy.stats import gaussian_kde
4
 
5
+
6
  class cKDE:
7
+ def __init__(
8
+ self, embedding, semantics, metric="euclidean", scale_method="neff", scale=2000
9
+ ):
10
+ self.metric = metric
11
+ self.scale_method = scale_method
12
+ self.scale = scale
13
 
14
+ self.H = embedding
15
+ self.Z = semantics
16
 
17
  def _quantile_scale(self, Z_cond_dist):
18
  return np.quantile(Z_cond_dist, self.scale)
 
62
  dist = cdist(z, self.Z, metric=self.metric)
63
  return np.argmin(dist, axis=-1)
64
 
65
+ def sample(self, z, cond_idx, m=1):
66
  if z.ndim == 1:
67
  z = z.reshape(1, -1)
68
 
 
71
  nn_idx = self.nearest_neighbor(sample_z)
72
  sample_h = self.H[nn_idx]
73
 
 
 
 
74
  return sample_z, sample_h
app_lib/main.py CHANGED
@@ -9,13 +9,7 @@ from app_lib.user_input import (
9
  get_image,
10
  get_model_name,
11
  )
12
- from app_lib.test import (
13
- load_dataset,
14
- load_model,
15
- encode_image,
16
- encode_concepts,
17
- encode_class_name,
18
- )
19
 
20
 
21
  def _disable():
@@ -67,27 +61,15 @@ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
67
  )
68
 
69
  with columns[1]:
70
- if test_button:
71
- with st.spinner("Loading dataset"):
72
- embedding = load_dataset("imagenette", model_name)
73
- time.sleep(1)
74
-
75
- with st.spinner("Loading model"):
76
- model, preprocess, tokenizer = load_model(model_name, device)
77
- time.sleep(1)
78
-
79
- with st.spinner("Encoding concepts"):
80
- cbm = encode_concepts(tokenizer, model, concepts, device)
81
- time.sleep(1)
82
-
83
- with st.spinner("Preparing zero-shot classifier"):
84
- classifier = encode_class_name(tokenizer, model, class_name, device)
85
-
86
- with st.spinner("Encoding image"):
87
- h = encode_image(model, preprocess, image, device)
88
- z = h @ cbm.T
89
- print(h.shape, cbm.shape, z.shape)
90
- time.sleep(2)
91
-
92
- st.session_state.disabled = False
93
- st.experimental_rerun()
 
9
  get_image,
10
  get_model_name,
11
  )
12
+ from app_lib.test import test
 
 
 
 
 
 
13
 
14
 
15
  def _disable():
 
61
  )
62
 
63
  with columns[1]:
64
+ _, centercol, _ = st.columns(3)
65
+ with centercol:
66
+ if test_button:
67
+ test(
68
+ image,
69
+ class_name,
70
+ concepts,
71
+ cardinality,
72
+ "imagenette",
73
+ model_name,
74
+ device,
75
+ )
 
 
 
 
 
 
 
 
 
 
 
 
app_lib/test.py CHANGED
@@ -2,10 +2,29 @@ import torch
2
  import clip
3
  import open_clip
4
  import h5py
 
 
 
 
5
 
 
6
  from huggingface_hub import hf_hub_download
7
 
 
8
  from app_lib.utils import SUPPORTED_MODELS
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
 
11
  def _get_open_clip_model(model_name, device):
@@ -39,7 +58,6 @@ def load_dataset(dataset_name, model_name):
39
 
40
 
41
  def load_model(model_name, device):
42
- print(model_name)
43
  if "open_clip" in model_name:
44
  model, preprocess, tokenizer = _get_open_clip_model(model_name, device)
45
  elif "clip" in model_name:
@@ -79,6 +97,85 @@ def encode_class_name(tokenizer, model, class_name, device):
79
  return class_features.cpu().numpy()
80
 
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  def test(image, class_name, concepts, cardinality, dataset_name, model_name, device):
83
- model, preprocess = load_model(model_name, device)
84
- print(f"loaded {model_name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import clip
3
  import open_clip
4
  import h5py
5
+ import streamlit as st
6
+ import numpy as np
7
+ import time
8
+ from concurrent.futures import ThreadPoolExecutor, as_completed
9
 
10
+ import ml_collections
11
  from huggingface_hub import hf_hub_download
12
 
13
+ from ibydmt.test import xSKIT
14
  from app_lib.utils import SUPPORTED_MODELS
15
+ from app_lib.ckde import cKDE
16
+
17
+ rng = np.random.default_rng()
18
+
19
+ testing_config = ml_collections.ConfigDict()
20
+ testing_config.significance_level = 0.05
21
+ testing_config.wealth = "ons"
22
+ testing_config.bet = "tanh"
23
+ testing_config.kernel = "rbf"
24
+ testing_config.kernel_scale_method = "quantile"
25
+ testing_config.kernel_scale = 0.5
26
+ testing_config.tau_max = 200
27
+ testing_config.r = 10
28
 
29
 
30
  def _get_open_clip_model(model_name, device):
 
58
 
59
 
60
  def load_model(model_name, device):
 
61
  if "open_clip" in model_name:
62
  model, preprocess, tokenizer = _get_open_clip_model(model_name, device)
63
  elif "clip" in model_name:
 
97
  return class_features.cpu().numpy()
98
 
99
 
100
+ def _sample_random_subset(concept_idx, concepts, cardinality):
101
+ sample_idx = list(set(range(len(concepts))) - {concept_idx})
102
+ return rng.permutation(sample_idx)[:cardinality].tolist()
103
+
104
+
105
+ def _test(z, concept_idx, concepts, cardinality, sampler, classifier):
106
+ def cond_p(z, cond_idx, m):
107
+ _, sample_h = sampler.sample(z, cond_idx, m=m)
108
+ return sample_h
109
+
110
+ def f(h):
111
+ output = h @ classifier.T
112
+ return output.squeeze()
113
+
114
+ rejected_hist, tau_hist, wealth_hist, subset_hist = [], [], [], []
115
+ for _ in range(testing_config.r):
116
+ subset_idx = _sample_random_subset(concept_idx, concepts, cardinality)
117
+ subset = [concepts[idx] for idx in subset_idx]
118
+
119
+ tester = xSKIT(testing_config)
120
+ rejected, tau = tester.test(
121
+ z, concept_idx, subset_idx, cond_p, f, interrupt_on_rejection=False
122
+ )
123
+ wealth = tester.wealth._wealth
124
+
125
+ rejected_hist.append(rejected)
126
+ tau_hist.append(tau)
127
+ wealth_hist.append(wealth)
128
+ subset_hist.append(subset)
129
+
130
+ return {
131
+ "concept": concepts[concept_idx],
132
+ "rejected": rejected_hist,
133
+ "tau": tau_hist,
134
+ "wealth": wealth_hist,
135
+ "subset": subset_hist,
136
+ }
137
+
138
+
139
  def test(image, class_name, concepts, cardinality, dataset_name, model_name, device):
140
+ with st.spinner("Loading model"):
141
+ model, preprocess, tokenizer = load_model(model_name, device)
142
+
143
+ with st.spinner("Encoding concepts"):
144
+ cbm = encode_concepts(tokenizer, model, concepts, device)
145
+
146
+ with st.spinner("Encoding image"):
147
+ h = encode_image(model, preprocess, image, device)
148
+ z = h @ cbm.T
149
+ z = z.squeeze()
150
+
151
+ with st.spinner("Testing"):
152
+ progress_bar = st.progress(0)
153
+
154
+ embedding = load_dataset("imagenette", model_name)
155
+ semantics = embedding @ cbm.T
156
+ sampler = cKDE(embedding, semantics)
157
+
158
+ classifier = encode_class_name(tokenizer, model, class_name, device)
159
+
160
+ with ThreadPoolExecutor() as executor:
161
+ futures = [
162
+ executor.submit(
163
+ _test, z, concept_idx, concepts, cardinality, sampler, classifier
164
+ )
165
+ for concept_idx in range(len(concepts))
166
+ ]
167
+
168
+ results = []
169
+ for idx, future in enumerate(as_completed(futures)):
170
+ results.append(future.result())
171
+ progress_bar.progress((idx + 1) / len(concepts))
172
+
173
+ # print(results)
174
+ # wealth = np.empty((testing_config.tau_max, len(concepts)))
175
+ # wealth[:] = np.nan
176
+ # for _results in results:
177
+ # concept_idx = concepts.index(_results["concept"])
178
+ # _wealth =
179
+
180
+ st.session_state.disabled = False
181
+ st.experimental_rerun()
app_lib/user_input.py CHANGED
@@ -80,7 +80,7 @@ def get_cardinality(concepts, concepts_ready):
80
  help="The number of concepts to condition model predictions on.",
81
  min_value=1,
82
  max_value=max(2, len(concepts) - 1),
83
- value=1,
84
  step=1,
85
  disabled=st.session_state.disabled or not concepts_ready,
86
  )
 
80
  help="The number of concepts to condition model predictions on.",
81
  min_value=1,
82
  max_value=max(2, len(concepts) - 1),
83
+ value=2,
84
  step=1,
85
  disabled=st.session_state.disabled or not concepts_ready,
86
  )
ibydmt/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/ibydmt/__pycache__/__init__.cpython-310.pyc and b/ibydmt/__pycache__/__init__.cpython-310.pyc differ
 
ibydmt/__pycache__/bet.cpython-310.pyc CHANGED
Binary files a/ibydmt/__pycache__/bet.cpython-310.pyc and b/ibydmt/__pycache__/bet.cpython-310.pyc differ
 
ibydmt/__pycache__/payoff.cpython-310.pyc CHANGED
Binary files a/ibydmt/__pycache__/payoff.cpython-310.pyc and b/ibydmt/__pycache__/payoff.cpython-310.pyc differ
 
ibydmt/__pycache__/test.cpython-310.pyc CHANGED
Binary files a/ibydmt/__pycache__/test.cpython-310.pyc and b/ibydmt/__pycache__/test.cpython-310.pyc differ
 
ibydmt/__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/ibydmt/__pycache__/utils.cpython-310.pyc and b/ibydmt/__pycache__/utils.cpython-310.pyc differ
 
ibydmt/__pycache__/wealth.cpython-310.pyc CHANGED
Binary files a/ibydmt/__pycache__/wealth.cpython-310.pyc and b/ibydmt/__pycache__/wealth.cpython-310.pyc differ
 
ibydmt/payoff.py CHANGED
@@ -81,6 +81,7 @@ class KernelPayoff(Payoff):
81
  zip(d, null_d),
82
  0,
83
  )
 
84
 
85
  return self.bet.compute(g)
86
 
 
81
  zip(d, null_d),
82
  0,
83
  )
84
+ g = g.squeeze().item()
85
 
86
  return self.bet.compute(g)
87
 
ibydmt/test.py CHANGED
@@ -141,9 +141,12 @@ class xSKIT(SequentialTester):
141
  C: list[int],
142
  cond_p: Callable[[Float[Array, "D"], list[int], int], Float[Array, "N D2"]],
143
  model: Callable[[Float[Array, "N D2"]], Float[Array, "N"]],
 
144
  ) -> Tuple[bool, int]:
145
  sample = functools.partial(self._sample, z, j, C, cond_p, model)
146
 
 
 
147
  prev_d = np.stack(sample(), axis=1)
148
  for t in range(1, self.tau_max):
149
  y, null_y = sample()
@@ -155,5 +158,7 @@ class xSKIT(SequentialTester):
155
  prev_d = np.vstack([prev_d, d])
156
 
157
  if self.wealth.rejected:
158
- return (True, t)
159
- return (False, t)
 
 
 
141
  C: list[int],
142
  cond_p: Callable[[Float[Array, "D"], list[int], int], Float[Array, "N D2"]],
143
  model: Callable[[Float[Array, "N D2"]], Float[Array, "N"]],
144
+ interrupt_on_rejection: bool = True,
145
  ) -> Tuple[bool, int]:
146
  sample = functools.partial(self._sample, z, j, C, cond_p, model)
147
 
148
+ tau = self.tau_max - 1
149
+
150
  prev_d = np.stack(sample(), axis=1)
151
  for t in range(1, self.tau_max):
152
  y, null_y = sample()
 
158
  prev_d = np.vstack([prev_d, d])
159
 
160
  if self.wealth.rejected:
161
+ tau = min(tau, t)
162
+ if interrupt_on_rejection:
163
+ break
164
+ return (self.wealth.rejected, tau)
ibydmt/wealth.py CHANGED
@@ -46,27 +46,30 @@ class ONS(Wealth):
46
  def __init__(self, config):
47
  super().__init__(config)
48
 
49
- self.w = 1.0
50
- self.v = 0
51
- self.a = 1
52
 
53
- self.min_v, self.max_v = config.get("min_v", 0), config.get("max_v", 1 / 2)
54
- self.wealth_flag = False
 
 
55
 
56
  def _update_v(self, payoff):
57
- z = payoff / (1 + self.v * payoff)
58
- self.a += z**2
59
- self.v = max(
60
- self.min_v, min(self.max_v, self.v + 2 / (2 - np.log(3)) * z / self.a)
61
  )
62
 
63
  def update(self, payoff):
64
- w = self.w * (1 + self.v * payoff)
65
 
66
- if w >= 0 and not self.wealth_flag:
67
- self.w = w
68
- if self.w >= 1 / self.significance_level:
 
69
  self.rejected = True
70
  self._update_v(payoff)
71
  else:
72
- self.wealth_flag = True
 
46
  def __init__(self, config):
47
  super().__init__(config)
48
 
49
+ self._w = 1.0
50
+ self._v = 0
51
+ self._a = 1
52
 
53
+ self._min_v, self._max_v = config.get("min_v", 0), config.get("max_v", 1 / 2)
54
+ self._wealth_flag = False
55
+
56
+ self._wealth = [self._w]
57
 
58
  def _update_v(self, payoff):
59
+ z = payoff / (1 + self._v * payoff)
60
+ self._a += z**2
61
+ self._v = max(
62
+ self._min_v, min(self._max_v, self._v + 2 / (2 - np.log(3)) * z / self._a)
63
  )
64
 
65
  def update(self, payoff):
66
+ w = self._w * (1 + self._v * payoff)
67
 
68
+ if w >= 0 and not self._wealth_flag:
69
+ self._w = w
70
+ self._wealth.append(self._w)
71
+ if self._w >= 1 / self.significance_level:
72
  self.rejected = True
73
  self._update_v(payoff)
74
  else:
75
+ self._wealth_flag = True