jacopoteneggi commited on
Commit
4f55ca2
·
verified ·
1 Parent(s): 80dc74c
app.py CHANGED
@@ -1,16 +1,18 @@
1
  import numpy as np
2
  import open_clip
3
  import streamlit as st
4
- import torch
5
 
6
  from app_lib.main import main
7
 
8
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
9
 
10
- st.set_page_config(
11
- layout="wide",
12
- initial_sidebar_state=st.session_state.get("sidebar_state", "collapsed"),
13
- )
14
  st.session_state.sidebar_state = "collapsed"
15
  st.markdown(
16
  """
@@ -21,6 +23,15 @@ st.markdown(
21
  input {
22
  font-family: monospace !important;
23
  }
 
 
 
 
 
 
 
 
 
24
  </style>
25
  """,
26
  unsafe_allow_html=True,
@@ -36,19 +47,5 @@ st.markdown(
36
  """,
37
  )
38
 
39
-
40
- def load_clip():
41
- model, _, preprocess = open_clip.create_model_and_transforms(
42
- "hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
43
- )
44
- tokenizer = open_clip.get_tokenizer("hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
45
-
46
-
47
- def test(
48
- image, class_name, concepts, cardinality, model_name, dataset_name="imagenette"
49
- ):
50
- print("test!")
51
-
52
-
53
  if __name__ == "__main__":
54
  main()
 
1
  import numpy as np
2
  import open_clip
3
  import streamlit as st
 
4
 
5
  from app_lib.main import main
6
 
7
+ if "sidebar_state" not in st.session_state:
8
+ st.session_state.sidebar_state = "collapsed"
9
+ if "disabled" not in st.session_state:
10
+ st.session_state.disabled = False
11
+ if "results" not in st.session_state:
12
+ st.session_state.results = None
13
+
14
+ st.set_page_config(layout="wide", initial_sidebar_state=st.session_state.sidebar_state)
15
 
 
 
 
 
16
  st.session_state.sidebar_state = "collapsed"
17
  st.markdown(
18
  """
 
23
  input {
24
  font-family: monospace !important;
25
  }
26
+
27
+ [data-testid="stHorizontalBlock"] {
28
+ align-items: center;
29
+ }
30
+ div.stSpinner > div {
31
+ text-align:center;
32
+ align-items: center;
33
+ justify-content: center;
34
+ }
35
  </style>
36
  """,
37
  unsafe_allow_html=True,
 
47
  """,
48
  )
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  if __name__ == "__main__":
51
  main()
app_lib/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/app_lib/__pycache__/__init__.cpython-310.pyc and b/app_lib/__pycache__/__init__.cpython-310.pyc differ
 
app_lib/__pycache__/main.cpython-310.pyc CHANGED
Binary files a/app_lib/__pycache__/main.cpython-310.pyc and b/app_lib/__pycache__/main.cpython-310.pyc differ
 
app_lib/__pycache__/test.cpython-310.pyc ADDED
Binary file (2.56 kB). View file
 
app_lib/__pycache__/user_input.cpython-310.pyc CHANGED
Binary files a/app_lib/__pycache__/user_input.cpython-310.pyc and b/app_lib/__pycache__/user_input.cpython-310.pyc differ
 
app_lib/__pycache__/utils.cpython-310.pyc CHANGED
Binary files a/app_lib/__pycache__/utils.cpython-310.pyc and b/app_lib/__pycache__/utils.cpython-310.pyc differ
 
app_lib/ckde.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from scipy.spatial.distance import cdist
4
+ from scipy.stats import gaussian_kde
5
+
6
+ class cKDE:
7
+ def __init__(self, config, concept_class_name=None, concept_image_idx=None):
8
+ ckde_config = config.ckde
9
+ self.image_size = image_size = ckde_config.get("image_size", 128)
10
+ self.metric = ckde_config.get("metric", "euclidean")
11
+ self.scale_method = ckde_config.get("scale_method", "neff")
12
+ self.scale = ckde_config.get("scale", 2000)
13
+
14
+ self.Z = self.dataset.Z
15
+ self.H = self.dataset.H
16
+
17
+ def _quantile_scale(self, Z_cond_dist):
18
+ return np.quantile(Z_cond_dist, self.scale)
19
+
20
+ def _neff_scale(self, Z_cond_dist):
21
+ scales = np.linspace(1e-02, 0.4, 100)[:, None]
22
+
23
+ _Z_cond_dist = np.tile(Z_cond_dist, (len(scales), 1))
24
+
25
+ weights = np.exp(-(_Z_cond_dist**2) / (2 * scales**2))
26
+ neff = (np.sum(weights, axis=1) ** 2) / np.sum(weights**2, axis=1)
27
+ diff = np.abs(neff - self.scale)
28
+ scale_idx = np.argmin(diff)
29
+ return scales[scale_idx].item()
30
+
31
+ def _sample(self, z, cond_idx, m):
32
+ sample_idx = list(set(range(len(z))) - set(cond_idx))
33
+
34
+ kde, _ = self.kde(z, cond_idx)
35
+
36
+ sample_z = np.tile(z, (m, 1))
37
+ sample_z[:, sample_idx] = kde.resample(m).T
38
+
39
+ return sample_z
40
+
41
+ def kde(self, z, cond_idx):
42
+ sample_idx = list(set(range(len(z))) - set(cond_idx))
43
+
44
+ Z_sample = self.Z[:, sample_idx]
45
+ Z_cond = self.Z[:, cond_idx]
46
+
47
+ z_cond = z[cond_idx]
48
+ Z_cond_dist = cdist(z_cond.reshape(1, -1), Z_cond, self.metric).squeeze()
49
+
50
+ if self.scale_method == "constant":
51
+ scale = self.scale
52
+ if self.scale_method == "quantile":
53
+ scale = self._quantile_scale(Z_cond_dist)
54
+ elif self.scale_method == "neff":
55
+ scale = self._neff_scale(Z_cond_dist)
56
+
57
+ weights = np.exp(-(Z_cond_dist**2) / (2 * scale**2))
58
+
59
+ return gaussian_kde(Z_sample.T, weights=weights), scale
60
+
61
+ def nearest_neighbor(self, z):
62
+ dist = cdist(z, self.Z, metric=self.metric)
63
+ return np.argmin(dist, axis=-1)
64
+
65
+ def sample(self, z, cond_idx, m=1, return_images=False):
66
+ if z.ndim == 1:
67
+ z = z.reshape(1, -1)
68
+
69
+ sample_z = np.concatenate([self._sample(_z, cond_idx, m) for _z in z], axis=0)
70
+
71
+ nn_idx = self.nearest_neighbor(sample_z)
72
+ sample_h = self.H[nn_idx]
73
+
74
+ if return_images:
75
+ sample_images = torch.stack([self.dataset[_idx][0] for _idx in nn_idx])
76
+ return sample_z, sample_h, sample_images
77
+ return sample_z, sample_h
app_lib/main.py CHANGED
@@ -1,4 +1,6 @@
 
1
  import streamlit as st
 
2
 
3
  from app_lib.user_input import (
4
  get_cardinality,
@@ -7,9 +9,20 @@ from app_lib.user_input import (
7
  get_image,
8
  get_model_name,
9
  )
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
- def main():
13
  columns = st.columns([0.40, 0.60])
14
 
15
  with columns[0]:
@@ -27,7 +40,11 @@ def main():
27
  cardinality = get_cardinality(concepts, concepts_ready)
28
 
29
  with row2[0]:
30
- change_image_button = st.button("Change Image", use_container_width=True)
 
 
 
 
31
  if change_image_button:
32
  st.session_state.sidebar_state = "expanded"
33
  st.experimental_rerun()
@@ -39,13 +56,38 @@ def main():
39
  error_message += f"- {class_error}\n"
40
  if concepts_error is not None:
41
  error_message += f"- {concepts_error}\n"
 
 
42
 
43
  test_button = st.button(
44
  "Test",
45
- help=None if ready else error_message,
46
  use_container_width=True,
47
- disabled=not ready,
 
48
  )
49
 
50
- if test_button:
51
- test(image, class_name, concepts, cardinality, model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
  import streamlit as st
3
+ import time
4
 
5
  from app_lib.user_input import (
6
  get_cardinality,
 
9
  get_image,
10
  get_model_name,
11
  )
12
+ from app_lib.test import (
13
+ load_dataset,
14
+ load_model,
15
+ encode_image,
16
+ encode_concepts,
17
+ encode_class_name,
18
+ )
19
+
20
+
21
+ def _disable():
22
+ st.session_state.disabled = True
23
 
24
 
25
+ def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
26
  columns = st.columns([0.40, 0.60])
27
 
28
  with columns[0]:
 
40
  cardinality = get_cardinality(concepts, concepts_ready)
41
 
42
  with row2[0]:
43
+ change_image_button = st.button(
44
+ "Change Image",
45
+ use_container_width=True,
46
+ disabled=st.session_state.disabled,
47
+ )
48
  if change_image_button:
49
  st.session_state.sidebar_state = "expanded"
50
  st.experimental_rerun()
 
56
  error_message += f"- {class_error}\n"
57
  if concepts_error is not None:
58
  error_message += f"- {concepts_error}\n"
59
+ if error_message:
60
+ st.error(error_message)
61
 
62
  test_button = st.button(
63
  "Test",
 
64
  use_container_width=True,
65
+ on_click=_disable,
66
+ disabled=st.session_state.disabled or not ready,
67
  )
68
 
69
+ with columns[1]:
70
+ if test_button:
71
+ with st.spinner("Loading dataset"):
72
+ embedding = load_dataset("imagenette", model_name)
73
+ time.sleep(1)
74
+
75
+ with st.spinner("Loading model"):
76
+ model, preprocess, tokenizer = load_model(model_name, device)
77
+ time.sleep(1)
78
+
79
+ with st.spinner("Encoding concepts"):
80
+ cbm = encode_concepts(tokenizer, model, concepts, device)
81
+ time.sleep(1)
82
+
83
+ with st.spinner("Preparing zero-shot classifier"):
84
+ classifier = encode_class_name(tokenizer, model, class_name, device)
85
+
86
+ with st.spinner("Encoding image"):
87
+ h = encode_image(model, preprocess, image, device)
88
+ z = h @ cbm.T
89
+ print(h.shape, cbm.shape, z.shape)
90
+ time.sleep(2)
91
+
92
+ st.session_state.disabled = False
93
+ st.experimental_rerun()
app_lib/test.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import clip
3
+ import open_clip
4
+ import h5py
5
+
6
+ from huggingface_hub import hf_hub_download
7
+
8
+ from app_lib.utils import SUPPORTED_MODELS
9
+
10
+
11
+ def _get_open_clip_model(model_name, device):
12
+ backbone = model_name.split(":")[-1]
13
+
14
+ model, _, preprocess = open_clip.create_model_and_transforms(
15
+ SUPPORTED_MODELS[model_name], device=device
16
+ )
17
+ model.eval()
18
+ tokenizer = open_clip.get_tokenizer(backbone)
19
+ return model, preprocess, tokenizer
20
+
21
+
22
+ def _get_clip_model(model_name, device):
23
+ backbone = model_name.split(":")[-1]
24
+ model, preprocess = clip.load(backbone, device=device)
25
+ tokenizer = clip.tokenize
26
+ return model, preprocess, tokenizer
27
+
28
+
29
+ def load_dataset(dataset_name, model_name):
30
+ dataset_path = hf_hub_download(
31
+ repo_id="jacopoteneggi/IBYDMT",
32
+ filename=f"{dataset_name}_{model_name}_train.h5",
33
+ repo_type="dataset",
34
+ )
35
+
36
+ with h5py.File(dataset_path, "r") as dataset:
37
+ embedding = dataset["embedding"][:]
38
+ return embedding
39
+
40
+
41
+ def load_model(model_name, device):
42
+ print(model_name)
43
+ if "open_clip" in model_name:
44
+ model, preprocess, tokenizer = _get_open_clip_model(model_name, device)
45
+ elif "clip" in model_name:
46
+ model, preprocess, tokenizer = _get_clip_model(model_name, device)
47
+ return model, preprocess, tokenizer
48
+
49
+
50
+ @torch.no_grad()
51
+ @torch.cuda.amp.autocast()
52
+ def encode_concepts(tokenizer, model, concepts, device):
53
+ concepts_text = tokenizer(concepts).to(device)
54
+
55
+ concept_features = model.encode_text(concepts_text)
56
+ concept_features /= torch.linalg.norm(concept_features, dim=-1, keepdim=True)
57
+ return concept_features.cpu().numpy()
58
+
59
+
60
+ @torch.no_grad()
61
+ @torch.cuda.amp.autocast()
62
+ def encode_image(model, preprocess, image, device):
63
+ image = preprocess(image)
64
+ image = image.unsqueeze(0)
65
+ image = image.to(device)
66
+
67
+ image_features = model.encode_image(image)
68
+ image_features /= image_features.norm(dim=-1, keepdim=True)
69
+ return image_features.cpu().numpy()
70
+
71
+
72
+ @torch.no_grad()
73
+ @torch.cuda.amp.autocast()
74
+ def encode_class_name(tokenizer, model, class_name, device):
75
+ class_text = tokenizer([f"A photo of a {class_name}"]).to(device)
76
+
77
+ class_features = model.encode_text(class_text)
78
+ class_features /= torch.linalg.norm(class_features, dim=-1, keepdim=True)
79
+ return class_features.cpu().numpy()
80
+
81
+
82
+ def test(image, class_name, concepts, cardinality, dataset_name, model_name, device):
83
+ model, preprocess = load_model(model_name, device)
84
+ print(f"loaded {model_name}")
app_lib/user_input.py CHANGED
@@ -24,8 +24,9 @@ def _validate_concepts(concepts):
24
  def get_model_name():
25
  return st.selectbox(
26
  "Choose a model to test",
27
- options=SUPPORTED_MODELS,
28
  help="Name of the vision-language model to test the predictions of.",
 
29
  )
30
 
31
 
@@ -49,6 +50,7 @@ def get_class_name():
49
  "Class to test",
50
  help="Name of the class to build the zero-shot CLIP classifier with.",
51
  value="cat",
 
52
  )
53
 
54
  class_ready, class_error = _validate_class_name(class_name)
@@ -61,6 +63,7 @@ def get_concepts():
61
  help="List of concepts to test the predictions of the model with. Write one concept per line.",
62
  height=160,
63
  value="piano\ncute\nwhiskers\nmusic\nwild",
 
64
  )
65
  concepts = concepts.split("\n")
66
  concepts = [concept.strip() for concept in concepts]
@@ -79,5 +82,5 @@ def get_cardinality(concepts, concepts_ready):
79
  max_value=max(2, len(concepts) - 1),
80
  value=1,
81
  step=1,
82
- disabled=not concepts_ready,
83
  )
 
24
  def get_model_name():
25
  return st.selectbox(
26
  "Choose a model to test",
27
+ options=list(SUPPORTED_MODELS.keys()),
28
  help="Name of the vision-language model to test the predictions of.",
29
+ disabled=st.session_state.disabled,
30
  )
31
 
32
 
 
50
  "Class to test",
51
  help="Name of the class to build the zero-shot CLIP classifier with.",
52
  value="cat",
53
+ disabled=st.session_state.disabled,
54
  )
55
 
56
  class_ready, class_error = _validate_class_name(class_name)
 
63
  help="List of concepts to test the predictions of the model with. Write one concept per line.",
64
  height=160,
65
  value="piano\ncute\nwhiskers\nmusic\nwild",
66
+ disabled=st.session_state.disabled,
67
  )
68
  concepts = concepts.split("\n")
69
  concepts = [concept.strip() for concept in concepts]
 
82
  max_value=max(2, len(concepts) - 1),
83
  value=1,
84
  step=1,
85
+ disabled=st.session_state.disabled or not concepts_ready,
86
  )
app_lib/utils.py CHANGED
@@ -5,10 +5,22 @@ supported_models_path = hf_hub_download(
5
  filename="supported_models.txt",
6
  repo_type="dataset",
7
  )
 
 
 
 
 
8
 
9
- SUPPORTED_MODELS = []
10
  with open(supported_models_path, "r") as f:
11
  for line in f:
12
  line = line.strip()
13
- model_name, _ = line.split(",")
14
- SUPPORTED_MODELS.append(model_name)
 
 
 
 
 
 
 
 
5
  filename="supported_models.txt",
6
  repo_type="dataset",
7
  )
8
+ supported_datasets_path = hf_hub_download(
9
+ repo_id="jacopoteneggi/IBYDMT",
10
+ filename="supported_datasets.txt",
11
+ repo_type="dataset",
12
+ )
13
 
14
+ SUPPORTED_MODELS = {}
15
  with open(supported_models_path, "r") as f:
16
  for line in f:
17
  line = line.strip()
18
+ model_name, model_url = line.split(",")
19
+ SUPPORTED_MODELS[model_name] = model_url
20
+
21
+
22
+ SUPPORTED_DATASETS = []
23
+ with open(supported_models_path, "r") as f:
24
+ for line in f:
25
+ dataset_name = line.strip()
26
+ SUPPORTED_DATASETS.append(dataset_name)
assets/ace.jpg ADDED
ibydmt/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from ibydmt.test import SKIT, cSKIT, xSKIT
ibydmt/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (224 Bytes). View file
 
ibydmt/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (266 Bytes). View file
 
ibydmt/__pycache__/bet.cpython-310.pyc ADDED
Binary file (2.06 kB). View file
 
ibydmt/__pycache__/bet.cpython-311.pyc ADDED
Binary file (3.52 kB). View file
 
ibydmt/__pycache__/payoff.cpython-310.pyc ADDED
Binary file (5.24 kB). View file
 
ibydmt/__pycache__/payoff.cpython-311.pyc ADDED
Binary file (10.4 kB). View file
 
ibydmt/__pycache__/test.cpython-310.pyc ADDED
Binary file (5.3 kB). View file
 
ibydmt/__pycache__/test.cpython-311.pyc ADDED
Binary file (9.79 kB). View file
 
ibydmt/__pycache__/utils.cpython-310.pyc ADDED
Binary file (599 Bytes). View file
 
ibydmt/__pycache__/utils.cpython-311.pyc ADDED
Binary file (793 Bytes). View file
 
ibydmt/__pycache__/wealth.cpython-310.pyc ADDED
Binary file (2.65 kB). View file
 
ibydmt/__pycache__/wealth.cpython-311.pyc ADDED
Binary file (4.48 kB). View file
 
ibydmt/bet.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict
3
+
4
+ import numpy as np
5
+
6
+ from ibydmt.utils import _get_cls, _register_cls
7
+
8
+
9
+ class Bet(ABC):
10
+ def __init__(self):
11
+ pass
12
+
13
+ @abstractmethod
14
+ def compute(self, *args, **kwargs):
15
+ pass
16
+
17
+
18
+ _BETS: Dict[str, Bet] = {}
19
+
20
+
21
+ def register_bet(name):
22
+ return _register_cls(name, dict=_BETS)
23
+
24
+
25
+ def get_bet(name):
26
+ return _get_cls(name, dict=_BETS)
27
+
28
+
29
+ @register_bet("sign")
30
+ class Sign(Bet):
31
+ def __init__(self, config):
32
+ super().__init__()
33
+ self.m = config.get("m", 0.5)
34
+ self.prev_g = []
35
+
36
+ def compute(self, g):
37
+ return self.m * np.sign(g)
38
+
39
+
40
+ @register_bet("tanh")
41
+ class Tanh(Bet):
42
+ def __init__(self, config):
43
+ super().__init__()
44
+ self.alpha = config.get("alpha", 0.20)
45
+ self.prev_g = []
46
+
47
+ def compute(self, g):
48
+ if len(self.prev_g) < 2:
49
+ scale = 1
50
+ else:
51
+ l, u = np.quantile(self.prev_g, [self.alpha / 2, 1 - self.alpha / 2])
52
+ scale = u - l
53
+
54
+ self.prev_g.append(g)
55
+
56
+ return np.tanh(g / np.clip(scale, 1e-04, None))
ibydmt/payoff.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from functools import reduce
3
+
4
+ import numpy as np
5
+ from sklearn.metrics import pairwise_distances
6
+ from sklearn.metrics.pairwise import linear_kernel, rbf_kernel
7
+
8
+ from ibydmt.bet import get_bet
9
+
10
+
11
+ class Payoff(ABC):
12
+ def __init__(self, config):
13
+ self.bet = get_bet(config.bet)(config)
14
+
15
+ @abstractmethod
16
+ def compute(self, *args, **kwargs):
17
+ pass
18
+
19
+
20
+ class Kernel:
21
+ def __init__(self, kernel: str, scale_method: str, scale: float):
22
+ if kernel == "linear":
23
+ self.base_kernel = linear_kernel
24
+ elif kernel == "rbf":
25
+ self.base_kernel = rbf_kernel
26
+
27
+ self.scale_method = scale_method
28
+ self.scale = scale
29
+
30
+ self.gamma = None
31
+ self.recompute_gamma = True
32
+ self.prev = None
33
+ else:
34
+ raise NotImplementedError(f"{kernel} is not implemented")
35
+
36
+ def __call__(self, x, y):
37
+ if self.base_kernel == linear_kernel:
38
+ return self.base_kernel(x, y)
39
+ if self.base_kernel == rbf_kernel:
40
+ if self.scale_method == "constant":
41
+ self.gamma = self.scale
42
+ elif self.scale_method == "quantile":
43
+ if self.prev is None:
44
+ self.prev = y
45
+
46
+ if self.recompute_gamma:
47
+ dist = pairwise_distances(
48
+ self.prev.reshape(-1, self.prev.shape[-1])
49
+ )
50
+ scale = np.quantile(dist, self.scale)
51
+ gamma = 1 / (2 * scale**2) if scale > 0 else None
52
+ self.gamma = gamma
53
+
54
+ if len(self.prev) > 100:
55
+ self.recompute_gamma = False
56
+ self.prev = np.vstack([self.prev, x])
57
+ else:
58
+ raise NotImplementedError(
59
+ f"{self.scale} is not implemented for rbf_kernel"
60
+ )
61
+ return self.base_kernel(x, y, gamma=self.gamma)
62
+
63
+
64
+ class KernelPayoff(Payoff):
65
+ def __init__(self, config):
66
+ super().__init__(config)
67
+
68
+ self.kernel = config.kernel
69
+ self.scale_method = config.get("kernel_scale_method", "quantile")
70
+ self.scale = config.get("kernel_scale", 0.5)
71
+
72
+ @abstractmethod
73
+ def witness_function(self, d, prev_d):
74
+ pass
75
+
76
+ def compute(self, d, null_d, prev_d):
77
+ g = reduce(
78
+ lambda acc, u: acc
79
+ + self.witness_function(u[0], prev_d)
80
+ - self.witness_function(u[1], prev_d),
81
+ zip(d, null_d),
82
+ 0,
83
+ )
84
+
85
+ return self.bet.compute(g)
86
+
87
+
88
+ class HSIC(KernelPayoff):
89
+ def __init__(self, config):
90
+ super().__init__(config)
91
+
92
+ kernel = self.kernel
93
+ scale_method = self.scale_method
94
+ scale = self.scale
95
+
96
+ self.kernel_y = Kernel(kernel, scale_method, scale)
97
+ self.kernel_z = Kernel(kernel, scale_method, scale)
98
+
99
+ def witness_function(self, d, prev_d):
100
+ y, z = d
101
+ prev_y, prev_z = prev_d[:, 0], prev_d[:, 1]
102
+
103
+ y_mat = self.kernel_y(y.reshape(-1, 1), prev_y.reshape(-1, 1))
104
+ z_mat = self.kernel_z(z.reshape(-1, 1), prev_z.reshape(-1, 1))
105
+
106
+ mu_joint = np.mean(y_mat * z_mat)
107
+ mu_prod = np.mean(y_mat, axis=1) @ np.mean(z_mat, axis=1)
108
+ return mu_joint - mu_prod
109
+
110
+
111
+ class cMMD(KernelPayoff):
112
+ def __init__(self, config):
113
+ super().__init__(config)
114
+
115
+ kernel = self.kernel
116
+ scale_method = self.scale_method
117
+ scale = self.scale
118
+
119
+ self.kernel_y = Kernel(kernel, scale_method, scale)
120
+ self.kernel_zj = Kernel(kernel, scale_method, scale)
121
+ self.kernel_cond_z = Kernel(kernel, scale_method, scale)
122
+
123
+ def witness_function(self, u, prev_d):
124
+ y, zj, cond_z = u[0], u[1], u[2:]
125
+
126
+ prev_y, prev_zj, prev_null_zj, prev_cond_z = (
127
+ prev_d[:, 0],
128
+ prev_d[:, 1],
129
+ prev_d[:, 2],
130
+ prev_d[:, 3:],
131
+ )
132
+
133
+ y_mat = self.kernel_y(y.reshape(-1, 1), prev_y.reshape(-1, 1))
134
+ zj_mat = self.kernel_zj(zj.reshape(-1, 1), prev_zj.reshape(-1, 1))
135
+ cond_z_mat = self.kernel_cond_z(
136
+ cond_z.reshape(-1, prev_cond_z.shape[1]),
137
+ prev_cond_z.reshape(-1, prev_cond_z.shape[1]),
138
+ )
139
+
140
+ null_zj_mat = self.kernel_zj(zj.reshape(-1, 1), prev_null_zj.reshape(-1, 1))
141
+
142
+ mu = np.mean(y_mat * zj_mat * cond_z_mat)
143
+ mu_null = np.mean(y_mat * null_zj_mat * cond_z_mat)
144
+ return mu - mu_null
145
+
146
+
147
+ class xMMD(KernelPayoff):
148
+ def __init__(self, config):
149
+ super().__init__(config)
150
+
151
+ self.kernel = Kernel(self.kernel, self.scale_method, self.scale)
152
+
153
+ def witness_function(self, u, prev_d):
154
+ prev_y, prev_y_null = prev_d[:, 0], prev_d[:, 1]
155
+
156
+ mu_y = np.mean(self.kernel(u.reshape(-1, 1), prev_y.reshape(-1, 1)), axis=1)
157
+ mu_y_null = np.mean(
158
+ self.kernel(u.reshape(-1, 1), prev_y_null.reshape(-1, 1)), axis=1
159
+ )
160
+ return mu_y - mu_y_null
ibydmt/test.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ from abc import ABC, abstractmethod
3
+ from collections import deque
4
+ from typing import Callable, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from jaxtyping import Float
9
+
10
+ from ibydmt.payoff import HSIC, cMMD, xMMD
11
+ from ibydmt.wealth import get_wealth
12
+
13
+ Array = Union[np.ndarray, torch.Tensor]
14
+
15
+
16
+ class Tester(ABC):
17
+ def __init__(self):
18
+ pass
19
+
20
+ @abstractmethod
21
+ def test(self, *args, **kwargs) -> Tuple[bool, int]:
22
+ pass
23
+
24
+
25
+ class SequentialTester(Tester):
26
+ def __init__(self, config):
27
+ super().__init__()
28
+ self.wealth = get_wealth(config.wealth)(config)
29
+
30
+ self.tau_max = config.tau_max
31
+
32
+
33
+ class SKIT(SequentialTester):
34
+ """Global Independence Tester"""
35
+
36
+ def __init__(self, config):
37
+ super().__init__(config)
38
+ self.payoff = HSIC(config)
39
+
40
+ def test(self, Y: Float[Array, "N"], Z: Float[Array, "N"]) -> Tuple[bool, int]:
41
+ D = np.stack([Y, Z], axis=1)
42
+ for t in range(1, self.tau_max):
43
+ d = D[2 * t : 2 * (t + 1)]
44
+ prev_d = D[: 2 * t]
45
+
46
+ null_d = np.stack([d[:, 0], np.flip(d[:, 1])], axis=1)
47
+
48
+ payoff = self.payoff.compute(d, null_d, prev_d)
49
+ self.wealth.update(payoff)
50
+
51
+ if self.wealth.rejected:
52
+ return (True, t)
53
+ return (False, t)
54
+
55
+
56
+ class cSKIT(SequentialTester):
57
+ """Global Conditional Independence Tester"""
58
+
59
+ def __init__(self, config):
60
+ super().__init__(config)
61
+ self.payoff = cMMD(config)
62
+
63
+ def _sample(
64
+ self,
65
+ z: Float[Array, "N D"],
66
+ j: int = None,
67
+ cond_p: Callable[[Float[Array, "N D"], list[int]], Float[Array, "N D"]] = None,
68
+ ) -> Tuple[Float[Array, "N"], Float[Array, "N"], Float[Array, "N D-1"]]:
69
+ C = list(set(range(z.shape[1])) - {j})
70
+
71
+ zj, cond_z = z[:, [j]], z[:, C]
72
+ samples = cond_p(z, C)
73
+ null_zj = samples[:, [j]]
74
+ return zj, null_zj, cond_z
75
+
76
+ def test(
77
+ self,
78
+ Y: Float[Array, "N"],
79
+ Z: Float[Array, "N D"],
80
+ j: int,
81
+ cond_p: Callable[[Float[Array, "N D"], list[int]], Float[Array, "N D"]],
82
+ ) -> Tuple[bool, int]:
83
+ sample = functools.partial(self._sample, j=j, cond_p=cond_p)
84
+
85
+ prev_y, prev_z = Y[:1][:, None], Z[:1]
86
+ prev_zj, prev_null_zj, prev_cond_z = sample(prev_z)
87
+ prev_d = np.concatenate([prev_y, prev_zj, prev_null_zj, prev_cond_z], axis=-1)
88
+ for t in range(1, self.tau_max):
89
+ y, z = Y[[t]][:, None], Z[[t]]
90
+ zj, null_zj, cond_z = sample(z)
91
+
92
+ u = np.concatenate([y, zj, cond_z], axis=-1)
93
+ null_u = np.concatenate([y, null_zj, cond_z], axis=-1)
94
+
95
+ payoff = self.payoff.compute(u, null_u, prev_d)
96
+ self.wealth.update(payoff)
97
+
98
+ d = np.concatenate([y, zj, null_zj, cond_z], axis=-1)
99
+ prev_d = np.vstack([prev_d, d])
100
+
101
+ if self.wealth.rejected:
102
+ return (True, t)
103
+ return (False, t)
104
+
105
+
106
+ class xSKIT(SequentialTester):
107
+ """Local Conditional Independence Tester"""
108
+
109
+ def __init__(self, config):
110
+ super().__init__(config)
111
+ self.payoff = xMMD(config)
112
+
113
+ self._queue = deque()
114
+
115
+ def _sample(
116
+ self,
117
+ z: Float[Array, "D"],
118
+ j: int,
119
+ C: list[int],
120
+ cond_p: Callable[[Float[Array, "D"], list[int], int], Float[Array, "N D2"]],
121
+ model: Callable[[Float[Array, "N D2"]], Float[Array, "N"]],
122
+ ) -> Tuple[Float[Array, "1"], Float[Array, "1"]]:
123
+
124
+ if len(self._queue) == 0:
125
+ Cuj = C + [j]
126
+
127
+ h = cond_p(z, Cuj, self.tau_max)
128
+ null_h = cond_p(z, C, self.tau_max)
129
+
130
+ y = model(h)[:, None]
131
+ null_y = model(null_h)[:, None]
132
+
133
+ self._queue.extend(zip(y, null_y))
134
+
135
+ return self._queue.pop()
136
+
137
+ def test(
138
+ self,
139
+ z: Float[Array, "D"],
140
+ j: int,
141
+ C: list[int],
142
+ cond_p: Callable[[Float[Array, "D"], list[int], int], Float[Array, "N D2"]],
143
+ model: Callable[[Float[Array, "N D2"]], Float[Array, "N"]],
144
+ ) -> Tuple[bool, int]:
145
+ sample = functools.partial(self._sample, z, j, C, cond_p, model)
146
+
147
+ prev_d = np.stack(sample(), axis=1)
148
+ for t in range(1, self.tau_max):
149
+ y, null_y = sample()
150
+
151
+ payoff = self.payoff.compute(y, null_y, prev_d)
152
+ self.wealth.update(payoff)
153
+
154
+ d = np.stack([y, null_y], axis=1)
155
+ prev_d = np.vstack([prev_d, d])
156
+
157
+ if self.wealth.rejected:
158
+ return (True, t)
159
+ return (False, t)
ibydmt/utils.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def _register_cls(name, dict=None):
2
+ def _register(cls):
3
+ if name in dict:
4
+ raise ValueError(f"{name} is already registered")
5
+
6
+ dict[name] = cls
7
+
8
+ return _register
9
+
10
+
11
+ def _get_cls(name, dict=None):
12
+ return dict[name]
ibydmt/wealth.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict
3
+
4
+ import numpy as np
5
+
6
+ from ibydmt.utils import _get_cls, _register_cls
7
+
8
+
9
+ class Wealth(ABC):
10
+ def __init__(self, config):
11
+ self.significance_level = config.significance_level
12
+ self.rejected = False
13
+
14
+ @abstractmethod
15
+ def update(self, payoff):
16
+ pass
17
+
18
+
19
+ _WEALTH: Dict[str, Wealth] = {}
20
+
21
+
22
+ def register_wealth(name):
23
+ return _register_cls(name, dict=_WEALTH)
24
+
25
+
26
+ def get_wealth(name):
27
+ return _get_cls(name, dict=_WEALTH)
28
+
29
+
30
+ @register_wealth("mixture")
31
+ class Mixture(Wealth):
32
+ def __init__(self, config):
33
+ super().__init__(config)
34
+
35
+ self.grid_size = grid_size = config.grid_size
36
+ self.wealth = np.ones((grid_size,))
37
+ self.wealth_flag = np.ones(grid_size, dtype=bool)
38
+ self.v = np.linspace(0.05, 0.95, grid_size)
39
+
40
+ def update(self, payoff):
41
+ raise NotImplementedError
42
+
43
+
44
+ @register_wealth("ons")
45
+ class ONS(Wealth):
46
+ def __init__(self, config):
47
+ super().__init__(config)
48
+
49
+ self.w = 1.0
50
+ self.v = 0
51
+ self.a = 1
52
+
53
+ self.min_v, self.max_v = config.get("min_v", 0), config.get("max_v", 1 / 2)
54
+ self.wealth_flag = False
55
+
56
+ def _update_v(self, payoff):
57
+ z = payoff / (1 + self.v * payoff)
58
+ self.a += z**2
59
+ self.v = max(
60
+ self.min_v, min(self.max_v, self.v + 2 / (2 - np.log(3)) * z / self.a)
61
+ )
62
+
63
+ def update(self, payoff):
64
+ w = self.w * (1 + self.v * payoff)
65
+
66
+ if w >= 0 and not self.wealth_flag:
67
+ self.w = w
68
+ if self.w >= 1 / self.significance_level:
69
+ self.rejected = True
70
+ self._update_v(payoff)
71
+ else:
72
+ self.wealth_flag = True
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
  streamlit-image-select
2
  clip @ git+https://github.com/openai/CLIP@main
3
- open_clip_torch
 
1
  streamlit-image-select
2
  clip @ git+https://github.com/openai/CLIP@main
3
+ open_clip_torch