Spaces:
Sleeping
Sleeping
jacopoteneggi
commited on
Update
Browse files- app.py +17 -20
- app_lib/__pycache__/__init__.cpython-310.pyc +0 -0
- app_lib/__pycache__/main.cpython-310.pyc +0 -0
- app_lib/__pycache__/test.cpython-310.pyc +0 -0
- app_lib/__pycache__/user_input.cpython-310.pyc +0 -0
- app_lib/__pycache__/utils.cpython-310.pyc +0 -0
- app_lib/ckde.py +77 -0
- app_lib/main.py +48 -6
- app_lib/test.py +84 -0
- app_lib/user_input.py +5 -2
- app_lib/utils.py +15 -3
- assets/ace.jpg +0 -0
- ibydmt/__init__.py +1 -0
- ibydmt/__pycache__/__init__.cpython-310.pyc +0 -0
- ibydmt/__pycache__/__init__.cpython-311.pyc +0 -0
- ibydmt/__pycache__/bet.cpython-310.pyc +0 -0
- ibydmt/__pycache__/bet.cpython-311.pyc +0 -0
- ibydmt/__pycache__/payoff.cpython-310.pyc +0 -0
- ibydmt/__pycache__/payoff.cpython-311.pyc +0 -0
- ibydmt/__pycache__/test.cpython-310.pyc +0 -0
- ibydmt/__pycache__/test.cpython-311.pyc +0 -0
- ibydmt/__pycache__/utils.cpython-310.pyc +0 -0
- ibydmt/__pycache__/utils.cpython-311.pyc +0 -0
- ibydmt/__pycache__/wealth.cpython-310.pyc +0 -0
- ibydmt/__pycache__/wealth.cpython-311.pyc +0 -0
- ibydmt/bet.py +56 -0
- ibydmt/payoff.py +160 -0
- ibydmt/test.py +159 -0
- ibydmt/utils.py +12 -0
- ibydmt/wealth.py +72 -0
- requirements.txt +1 -1
app.py
CHANGED
@@ -1,16 +1,18 @@
|
|
1 |
import numpy as np
|
2 |
import open_clip
|
3 |
import streamlit as st
|
4 |
-
import torch
|
5 |
|
6 |
from app_lib.main import main
|
7 |
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
-
st.set_page_config(
|
11 |
-
layout="wide",
|
12 |
-
initial_sidebar_state=st.session_state.get("sidebar_state", "collapsed"),
|
13 |
-
)
|
14 |
st.session_state.sidebar_state = "collapsed"
|
15 |
st.markdown(
|
16 |
"""
|
@@ -21,6 +23,15 @@ st.markdown(
|
|
21 |
input {
|
22 |
font-family: monospace !important;
|
23 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
</style>
|
25 |
""",
|
26 |
unsafe_allow_html=True,
|
@@ -36,19 +47,5 @@ st.markdown(
|
|
36 |
""",
|
37 |
)
|
38 |
|
39 |
-
|
40 |
-
def load_clip():
|
41 |
-
model, _, preprocess = open_clip.create_model_and_transforms(
|
42 |
-
"hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
43 |
-
)
|
44 |
-
tokenizer = open_clip.get_tokenizer("hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
|
45 |
-
|
46 |
-
|
47 |
-
def test(
|
48 |
-
image, class_name, concepts, cardinality, model_name, dataset_name="imagenette"
|
49 |
-
):
|
50 |
-
print("test!")
|
51 |
-
|
52 |
-
|
53 |
if __name__ == "__main__":
|
54 |
main()
|
|
|
1 |
import numpy as np
|
2 |
import open_clip
|
3 |
import streamlit as st
|
|
|
4 |
|
5 |
from app_lib.main import main
|
6 |
|
7 |
+
if "sidebar_state" not in st.session_state:
|
8 |
+
st.session_state.sidebar_state = "collapsed"
|
9 |
+
if "disabled" not in st.session_state:
|
10 |
+
st.session_state.disabled = False
|
11 |
+
if "results" not in st.session_state:
|
12 |
+
st.session_state.results = None
|
13 |
+
|
14 |
+
st.set_page_config(layout="wide", initial_sidebar_state=st.session_state.sidebar_state)
|
15 |
|
|
|
|
|
|
|
|
|
16 |
st.session_state.sidebar_state = "collapsed"
|
17 |
st.markdown(
|
18 |
"""
|
|
|
23 |
input {
|
24 |
font-family: monospace !important;
|
25 |
}
|
26 |
+
|
27 |
+
[data-testid="stHorizontalBlock"] {
|
28 |
+
align-items: center;
|
29 |
+
}
|
30 |
+
div.stSpinner > div {
|
31 |
+
text-align:center;
|
32 |
+
align-items: center;
|
33 |
+
justify-content: center;
|
34 |
+
}
|
35 |
</style>
|
36 |
""",
|
37 |
unsafe_allow_html=True,
|
|
|
47 |
""",
|
48 |
)
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
if __name__ == "__main__":
|
51 |
main()
|
app_lib/__pycache__/__init__.cpython-310.pyc
CHANGED
Binary files a/app_lib/__pycache__/__init__.cpython-310.pyc and b/app_lib/__pycache__/__init__.cpython-310.pyc differ
|
|
app_lib/__pycache__/main.cpython-310.pyc
CHANGED
Binary files a/app_lib/__pycache__/main.cpython-310.pyc and b/app_lib/__pycache__/main.cpython-310.pyc differ
|
|
app_lib/__pycache__/test.cpython-310.pyc
ADDED
Binary file (2.56 kB). View file
|
|
app_lib/__pycache__/user_input.cpython-310.pyc
CHANGED
Binary files a/app_lib/__pycache__/user_input.cpython-310.pyc and b/app_lib/__pycache__/user_input.cpython-310.pyc differ
|
|
app_lib/__pycache__/utils.cpython-310.pyc
CHANGED
Binary files a/app_lib/__pycache__/utils.cpython-310.pyc and b/app_lib/__pycache__/utils.cpython-310.pyc differ
|
|
app_lib/ckde.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from scipy.spatial.distance import cdist
|
4 |
+
from scipy.stats import gaussian_kde
|
5 |
+
|
6 |
+
class cKDE:
|
7 |
+
def __init__(self, config, concept_class_name=None, concept_image_idx=None):
|
8 |
+
ckde_config = config.ckde
|
9 |
+
self.image_size = image_size = ckde_config.get("image_size", 128)
|
10 |
+
self.metric = ckde_config.get("metric", "euclidean")
|
11 |
+
self.scale_method = ckde_config.get("scale_method", "neff")
|
12 |
+
self.scale = ckde_config.get("scale", 2000)
|
13 |
+
|
14 |
+
self.Z = self.dataset.Z
|
15 |
+
self.H = self.dataset.H
|
16 |
+
|
17 |
+
def _quantile_scale(self, Z_cond_dist):
|
18 |
+
return np.quantile(Z_cond_dist, self.scale)
|
19 |
+
|
20 |
+
def _neff_scale(self, Z_cond_dist):
|
21 |
+
scales = np.linspace(1e-02, 0.4, 100)[:, None]
|
22 |
+
|
23 |
+
_Z_cond_dist = np.tile(Z_cond_dist, (len(scales), 1))
|
24 |
+
|
25 |
+
weights = np.exp(-(_Z_cond_dist**2) / (2 * scales**2))
|
26 |
+
neff = (np.sum(weights, axis=1) ** 2) / np.sum(weights**2, axis=1)
|
27 |
+
diff = np.abs(neff - self.scale)
|
28 |
+
scale_idx = np.argmin(diff)
|
29 |
+
return scales[scale_idx].item()
|
30 |
+
|
31 |
+
def _sample(self, z, cond_idx, m):
|
32 |
+
sample_idx = list(set(range(len(z))) - set(cond_idx))
|
33 |
+
|
34 |
+
kde, _ = self.kde(z, cond_idx)
|
35 |
+
|
36 |
+
sample_z = np.tile(z, (m, 1))
|
37 |
+
sample_z[:, sample_idx] = kde.resample(m).T
|
38 |
+
|
39 |
+
return sample_z
|
40 |
+
|
41 |
+
def kde(self, z, cond_idx):
|
42 |
+
sample_idx = list(set(range(len(z))) - set(cond_idx))
|
43 |
+
|
44 |
+
Z_sample = self.Z[:, sample_idx]
|
45 |
+
Z_cond = self.Z[:, cond_idx]
|
46 |
+
|
47 |
+
z_cond = z[cond_idx]
|
48 |
+
Z_cond_dist = cdist(z_cond.reshape(1, -1), Z_cond, self.metric).squeeze()
|
49 |
+
|
50 |
+
if self.scale_method == "constant":
|
51 |
+
scale = self.scale
|
52 |
+
if self.scale_method == "quantile":
|
53 |
+
scale = self._quantile_scale(Z_cond_dist)
|
54 |
+
elif self.scale_method == "neff":
|
55 |
+
scale = self._neff_scale(Z_cond_dist)
|
56 |
+
|
57 |
+
weights = np.exp(-(Z_cond_dist**2) / (2 * scale**2))
|
58 |
+
|
59 |
+
return gaussian_kde(Z_sample.T, weights=weights), scale
|
60 |
+
|
61 |
+
def nearest_neighbor(self, z):
|
62 |
+
dist = cdist(z, self.Z, metric=self.metric)
|
63 |
+
return np.argmin(dist, axis=-1)
|
64 |
+
|
65 |
+
def sample(self, z, cond_idx, m=1, return_images=False):
|
66 |
+
if z.ndim == 1:
|
67 |
+
z = z.reshape(1, -1)
|
68 |
+
|
69 |
+
sample_z = np.concatenate([self._sample(_z, cond_idx, m) for _z in z], axis=0)
|
70 |
+
|
71 |
+
nn_idx = self.nearest_neighbor(sample_z)
|
72 |
+
sample_h = self.H[nn_idx]
|
73 |
+
|
74 |
+
if return_images:
|
75 |
+
sample_images = torch.stack([self.dataset[_idx][0] for _idx in nn_idx])
|
76 |
+
return sample_z, sample_h, sample_images
|
77 |
+
return sample_z, sample_h
|
app_lib/main.py
CHANGED
@@ -1,4 +1,6 @@
|
|
|
|
1 |
import streamlit as st
|
|
|
2 |
|
3 |
from app_lib.user_input import (
|
4 |
get_cardinality,
|
@@ -7,9 +9,20 @@ from app_lib.user_input import (
|
|
7 |
get_image,
|
8 |
get_model_name,
|
9 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
|
12 |
-
def main():
|
13 |
columns = st.columns([0.40, 0.60])
|
14 |
|
15 |
with columns[0]:
|
@@ -27,7 +40,11 @@ def main():
|
|
27 |
cardinality = get_cardinality(concepts, concepts_ready)
|
28 |
|
29 |
with row2[0]:
|
30 |
-
change_image_button = st.button(
|
|
|
|
|
|
|
|
|
31 |
if change_image_button:
|
32 |
st.session_state.sidebar_state = "expanded"
|
33 |
st.experimental_rerun()
|
@@ -39,13 +56,38 @@ def main():
|
|
39 |
error_message += f"- {class_error}\n"
|
40 |
if concepts_error is not None:
|
41 |
error_message += f"- {concepts_error}\n"
|
|
|
|
|
42 |
|
43 |
test_button = st.button(
|
44 |
"Test",
|
45 |
-
help=None if ready else error_message,
|
46 |
use_container_width=True,
|
47 |
-
|
|
|
48 |
)
|
49 |
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
import streamlit as st
|
3 |
+
import time
|
4 |
|
5 |
from app_lib.user_input import (
|
6 |
get_cardinality,
|
|
|
9 |
get_image,
|
10 |
get_model_name,
|
11 |
)
|
12 |
+
from app_lib.test import (
|
13 |
+
load_dataset,
|
14 |
+
load_model,
|
15 |
+
encode_image,
|
16 |
+
encode_concepts,
|
17 |
+
encode_class_name,
|
18 |
+
)
|
19 |
+
|
20 |
+
|
21 |
+
def _disable():
|
22 |
+
st.session_state.disabled = True
|
23 |
|
24 |
|
25 |
+
def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
|
26 |
columns = st.columns([0.40, 0.60])
|
27 |
|
28 |
with columns[0]:
|
|
|
40 |
cardinality = get_cardinality(concepts, concepts_ready)
|
41 |
|
42 |
with row2[0]:
|
43 |
+
change_image_button = st.button(
|
44 |
+
"Change Image",
|
45 |
+
use_container_width=True,
|
46 |
+
disabled=st.session_state.disabled,
|
47 |
+
)
|
48 |
if change_image_button:
|
49 |
st.session_state.sidebar_state = "expanded"
|
50 |
st.experimental_rerun()
|
|
|
56 |
error_message += f"- {class_error}\n"
|
57 |
if concepts_error is not None:
|
58 |
error_message += f"- {concepts_error}\n"
|
59 |
+
if error_message:
|
60 |
+
st.error(error_message)
|
61 |
|
62 |
test_button = st.button(
|
63 |
"Test",
|
|
|
64 |
use_container_width=True,
|
65 |
+
on_click=_disable,
|
66 |
+
disabled=st.session_state.disabled or not ready,
|
67 |
)
|
68 |
|
69 |
+
with columns[1]:
|
70 |
+
if test_button:
|
71 |
+
with st.spinner("Loading dataset"):
|
72 |
+
embedding = load_dataset("imagenette", model_name)
|
73 |
+
time.sleep(1)
|
74 |
+
|
75 |
+
with st.spinner("Loading model"):
|
76 |
+
model, preprocess, tokenizer = load_model(model_name, device)
|
77 |
+
time.sleep(1)
|
78 |
+
|
79 |
+
with st.spinner("Encoding concepts"):
|
80 |
+
cbm = encode_concepts(tokenizer, model, concepts, device)
|
81 |
+
time.sleep(1)
|
82 |
+
|
83 |
+
with st.spinner("Preparing zero-shot classifier"):
|
84 |
+
classifier = encode_class_name(tokenizer, model, class_name, device)
|
85 |
+
|
86 |
+
with st.spinner("Encoding image"):
|
87 |
+
h = encode_image(model, preprocess, image, device)
|
88 |
+
z = h @ cbm.T
|
89 |
+
print(h.shape, cbm.shape, z.shape)
|
90 |
+
time.sleep(2)
|
91 |
+
|
92 |
+
st.session_state.disabled = False
|
93 |
+
st.experimental_rerun()
|
app_lib/test.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import clip
|
3 |
+
import open_clip
|
4 |
+
import h5py
|
5 |
+
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
+
|
8 |
+
from app_lib.utils import SUPPORTED_MODELS
|
9 |
+
|
10 |
+
|
11 |
+
def _get_open_clip_model(model_name, device):
|
12 |
+
backbone = model_name.split(":")[-1]
|
13 |
+
|
14 |
+
model, _, preprocess = open_clip.create_model_and_transforms(
|
15 |
+
SUPPORTED_MODELS[model_name], device=device
|
16 |
+
)
|
17 |
+
model.eval()
|
18 |
+
tokenizer = open_clip.get_tokenizer(backbone)
|
19 |
+
return model, preprocess, tokenizer
|
20 |
+
|
21 |
+
|
22 |
+
def _get_clip_model(model_name, device):
|
23 |
+
backbone = model_name.split(":")[-1]
|
24 |
+
model, preprocess = clip.load(backbone, device=device)
|
25 |
+
tokenizer = clip.tokenize
|
26 |
+
return model, preprocess, tokenizer
|
27 |
+
|
28 |
+
|
29 |
+
def load_dataset(dataset_name, model_name):
|
30 |
+
dataset_path = hf_hub_download(
|
31 |
+
repo_id="jacopoteneggi/IBYDMT",
|
32 |
+
filename=f"{dataset_name}_{model_name}_train.h5",
|
33 |
+
repo_type="dataset",
|
34 |
+
)
|
35 |
+
|
36 |
+
with h5py.File(dataset_path, "r") as dataset:
|
37 |
+
embedding = dataset["embedding"][:]
|
38 |
+
return embedding
|
39 |
+
|
40 |
+
|
41 |
+
def load_model(model_name, device):
|
42 |
+
print(model_name)
|
43 |
+
if "open_clip" in model_name:
|
44 |
+
model, preprocess, tokenizer = _get_open_clip_model(model_name, device)
|
45 |
+
elif "clip" in model_name:
|
46 |
+
model, preprocess, tokenizer = _get_clip_model(model_name, device)
|
47 |
+
return model, preprocess, tokenizer
|
48 |
+
|
49 |
+
|
50 |
+
@torch.no_grad()
|
51 |
+
@torch.cuda.amp.autocast()
|
52 |
+
def encode_concepts(tokenizer, model, concepts, device):
|
53 |
+
concepts_text = tokenizer(concepts).to(device)
|
54 |
+
|
55 |
+
concept_features = model.encode_text(concepts_text)
|
56 |
+
concept_features /= torch.linalg.norm(concept_features, dim=-1, keepdim=True)
|
57 |
+
return concept_features.cpu().numpy()
|
58 |
+
|
59 |
+
|
60 |
+
@torch.no_grad()
|
61 |
+
@torch.cuda.amp.autocast()
|
62 |
+
def encode_image(model, preprocess, image, device):
|
63 |
+
image = preprocess(image)
|
64 |
+
image = image.unsqueeze(0)
|
65 |
+
image = image.to(device)
|
66 |
+
|
67 |
+
image_features = model.encode_image(image)
|
68 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
69 |
+
return image_features.cpu().numpy()
|
70 |
+
|
71 |
+
|
72 |
+
@torch.no_grad()
|
73 |
+
@torch.cuda.amp.autocast()
|
74 |
+
def encode_class_name(tokenizer, model, class_name, device):
|
75 |
+
class_text = tokenizer([f"A photo of a {class_name}"]).to(device)
|
76 |
+
|
77 |
+
class_features = model.encode_text(class_text)
|
78 |
+
class_features /= torch.linalg.norm(class_features, dim=-1, keepdim=True)
|
79 |
+
return class_features.cpu().numpy()
|
80 |
+
|
81 |
+
|
82 |
+
def test(image, class_name, concepts, cardinality, dataset_name, model_name, device):
|
83 |
+
model, preprocess = load_model(model_name, device)
|
84 |
+
print(f"loaded {model_name}")
|
app_lib/user_input.py
CHANGED
@@ -24,8 +24,9 @@ def _validate_concepts(concepts):
|
|
24 |
def get_model_name():
|
25 |
return st.selectbox(
|
26 |
"Choose a model to test",
|
27 |
-
options=SUPPORTED_MODELS,
|
28 |
help="Name of the vision-language model to test the predictions of.",
|
|
|
29 |
)
|
30 |
|
31 |
|
@@ -49,6 +50,7 @@ def get_class_name():
|
|
49 |
"Class to test",
|
50 |
help="Name of the class to build the zero-shot CLIP classifier with.",
|
51 |
value="cat",
|
|
|
52 |
)
|
53 |
|
54 |
class_ready, class_error = _validate_class_name(class_name)
|
@@ -61,6 +63,7 @@ def get_concepts():
|
|
61 |
help="List of concepts to test the predictions of the model with. Write one concept per line.",
|
62 |
height=160,
|
63 |
value="piano\ncute\nwhiskers\nmusic\nwild",
|
|
|
64 |
)
|
65 |
concepts = concepts.split("\n")
|
66 |
concepts = [concept.strip() for concept in concepts]
|
@@ -79,5 +82,5 @@ def get_cardinality(concepts, concepts_ready):
|
|
79 |
max_value=max(2, len(concepts) - 1),
|
80 |
value=1,
|
81 |
step=1,
|
82 |
-
disabled=not concepts_ready,
|
83 |
)
|
|
|
24 |
def get_model_name():
|
25 |
return st.selectbox(
|
26 |
"Choose a model to test",
|
27 |
+
options=list(SUPPORTED_MODELS.keys()),
|
28 |
help="Name of the vision-language model to test the predictions of.",
|
29 |
+
disabled=st.session_state.disabled,
|
30 |
)
|
31 |
|
32 |
|
|
|
50 |
"Class to test",
|
51 |
help="Name of the class to build the zero-shot CLIP classifier with.",
|
52 |
value="cat",
|
53 |
+
disabled=st.session_state.disabled,
|
54 |
)
|
55 |
|
56 |
class_ready, class_error = _validate_class_name(class_name)
|
|
|
63 |
help="List of concepts to test the predictions of the model with. Write one concept per line.",
|
64 |
height=160,
|
65 |
value="piano\ncute\nwhiskers\nmusic\nwild",
|
66 |
+
disabled=st.session_state.disabled,
|
67 |
)
|
68 |
concepts = concepts.split("\n")
|
69 |
concepts = [concept.strip() for concept in concepts]
|
|
|
82 |
max_value=max(2, len(concepts) - 1),
|
83 |
value=1,
|
84 |
step=1,
|
85 |
+
disabled=st.session_state.disabled or not concepts_ready,
|
86 |
)
|
app_lib/utils.py
CHANGED
@@ -5,10 +5,22 @@ supported_models_path = hf_hub_download(
|
|
5 |
filename="supported_models.txt",
|
6 |
repo_type="dataset",
|
7 |
)
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
-
SUPPORTED_MODELS =
|
10 |
with open(supported_models_path, "r") as f:
|
11 |
for line in f:
|
12 |
line = line.strip()
|
13 |
-
model_name,
|
14 |
-
SUPPORTED_MODELS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
filename="supported_models.txt",
|
6 |
repo_type="dataset",
|
7 |
)
|
8 |
+
supported_datasets_path = hf_hub_download(
|
9 |
+
repo_id="jacopoteneggi/IBYDMT",
|
10 |
+
filename="supported_datasets.txt",
|
11 |
+
repo_type="dataset",
|
12 |
+
)
|
13 |
|
14 |
+
SUPPORTED_MODELS = {}
|
15 |
with open(supported_models_path, "r") as f:
|
16 |
for line in f:
|
17 |
line = line.strip()
|
18 |
+
model_name, model_url = line.split(",")
|
19 |
+
SUPPORTED_MODELS[model_name] = model_url
|
20 |
+
|
21 |
+
|
22 |
+
SUPPORTED_DATASETS = []
|
23 |
+
with open(supported_models_path, "r") as f:
|
24 |
+
for line in f:
|
25 |
+
dataset_name = line.strip()
|
26 |
+
SUPPORTED_DATASETS.append(dataset_name)
|
assets/ace.jpg
ADDED
ibydmt/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from ibydmt.test import SKIT, cSKIT, xSKIT
|
ibydmt/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (224 Bytes). View file
|
|
ibydmt/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (266 Bytes). View file
|
|
ibydmt/__pycache__/bet.cpython-310.pyc
ADDED
Binary file (2.06 kB). View file
|
|
ibydmt/__pycache__/bet.cpython-311.pyc
ADDED
Binary file (3.52 kB). View file
|
|
ibydmt/__pycache__/payoff.cpython-310.pyc
ADDED
Binary file (5.24 kB). View file
|
|
ibydmt/__pycache__/payoff.cpython-311.pyc
ADDED
Binary file (10.4 kB). View file
|
|
ibydmt/__pycache__/test.cpython-310.pyc
ADDED
Binary file (5.3 kB). View file
|
|
ibydmt/__pycache__/test.cpython-311.pyc
ADDED
Binary file (9.79 kB). View file
|
|
ibydmt/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (599 Bytes). View file
|
|
ibydmt/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (793 Bytes). View file
|
|
ibydmt/__pycache__/wealth.cpython-310.pyc
ADDED
Binary file (2.65 kB). View file
|
|
ibydmt/__pycache__/wealth.cpython-311.pyc
ADDED
Binary file (4.48 kB). View file
|
|
ibydmt/bet.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import Dict
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from ibydmt.utils import _get_cls, _register_cls
|
7 |
+
|
8 |
+
|
9 |
+
class Bet(ABC):
|
10 |
+
def __init__(self):
|
11 |
+
pass
|
12 |
+
|
13 |
+
@abstractmethod
|
14 |
+
def compute(self, *args, **kwargs):
|
15 |
+
pass
|
16 |
+
|
17 |
+
|
18 |
+
_BETS: Dict[str, Bet] = {}
|
19 |
+
|
20 |
+
|
21 |
+
def register_bet(name):
|
22 |
+
return _register_cls(name, dict=_BETS)
|
23 |
+
|
24 |
+
|
25 |
+
def get_bet(name):
|
26 |
+
return _get_cls(name, dict=_BETS)
|
27 |
+
|
28 |
+
|
29 |
+
@register_bet("sign")
|
30 |
+
class Sign(Bet):
|
31 |
+
def __init__(self, config):
|
32 |
+
super().__init__()
|
33 |
+
self.m = config.get("m", 0.5)
|
34 |
+
self.prev_g = []
|
35 |
+
|
36 |
+
def compute(self, g):
|
37 |
+
return self.m * np.sign(g)
|
38 |
+
|
39 |
+
|
40 |
+
@register_bet("tanh")
|
41 |
+
class Tanh(Bet):
|
42 |
+
def __init__(self, config):
|
43 |
+
super().__init__()
|
44 |
+
self.alpha = config.get("alpha", 0.20)
|
45 |
+
self.prev_g = []
|
46 |
+
|
47 |
+
def compute(self, g):
|
48 |
+
if len(self.prev_g) < 2:
|
49 |
+
scale = 1
|
50 |
+
else:
|
51 |
+
l, u = np.quantile(self.prev_g, [self.alpha / 2, 1 - self.alpha / 2])
|
52 |
+
scale = u - l
|
53 |
+
|
54 |
+
self.prev_g.append(g)
|
55 |
+
|
56 |
+
return np.tanh(g / np.clip(scale, 1e-04, None))
|
ibydmt/payoff.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from functools import reduce
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from sklearn.metrics import pairwise_distances
|
6 |
+
from sklearn.metrics.pairwise import linear_kernel, rbf_kernel
|
7 |
+
|
8 |
+
from ibydmt.bet import get_bet
|
9 |
+
|
10 |
+
|
11 |
+
class Payoff(ABC):
|
12 |
+
def __init__(self, config):
|
13 |
+
self.bet = get_bet(config.bet)(config)
|
14 |
+
|
15 |
+
@abstractmethod
|
16 |
+
def compute(self, *args, **kwargs):
|
17 |
+
pass
|
18 |
+
|
19 |
+
|
20 |
+
class Kernel:
|
21 |
+
def __init__(self, kernel: str, scale_method: str, scale: float):
|
22 |
+
if kernel == "linear":
|
23 |
+
self.base_kernel = linear_kernel
|
24 |
+
elif kernel == "rbf":
|
25 |
+
self.base_kernel = rbf_kernel
|
26 |
+
|
27 |
+
self.scale_method = scale_method
|
28 |
+
self.scale = scale
|
29 |
+
|
30 |
+
self.gamma = None
|
31 |
+
self.recompute_gamma = True
|
32 |
+
self.prev = None
|
33 |
+
else:
|
34 |
+
raise NotImplementedError(f"{kernel} is not implemented")
|
35 |
+
|
36 |
+
def __call__(self, x, y):
|
37 |
+
if self.base_kernel == linear_kernel:
|
38 |
+
return self.base_kernel(x, y)
|
39 |
+
if self.base_kernel == rbf_kernel:
|
40 |
+
if self.scale_method == "constant":
|
41 |
+
self.gamma = self.scale
|
42 |
+
elif self.scale_method == "quantile":
|
43 |
+
if self.prev is None:
|
44 |
+
self.prev = y
|
45 |
+
|
46 |
+
if self.recompute_gamma:
|
47 |
+
dist = pairwise_distances(
|
48 |
+
self.prev.reshape(-1, self.prev.shape[-1])
|
49 |
+
)
|
50 |
+
scale = np.quantile(dist, self.scale)
|
51 |
+
gamma = 1 / (2 * scale**2) if scale > 0 else None
|
52 |
+
self.gamma = gamma
|
53 |
+
|
54 |
+
if len(self.prev) > 100:
|
55 |
+
self.recompute_gamma = False
|
56 |
+
self.prev = np.vstack([self.prev, x])
|
57 |
+
else:
|
58 |
+
raise NotImplementedError(
|
59 |
+
f"{self.scale} is not implemented for rbf_kernel"
|
60 |
+
)
|
61 |
+
return self.base_kernel(x, y, gamma=self.gamma)
|
62 |
+
|
63 |
+
|
64 |
+
class KernelPayoff(Payoff):
|
65 |
+
def __init__(self, config):
|
66 |
+
super().__init__(config)
|
67 |
+
|
68 |
+
self.kernel = config.kernel
|
69 |
+
self.scale_method = config.get("kernel_scale_method", "quantile")
|
70 |
+
self.scale = config.get("kernel_scale", 0.5)
|
71 |
+
|
72 |
+
@abstractmethod
|
73 |
+
def witness_function(self, d, prev_d):
|
74 |
+
pass
|
75 |
+
|
76 |
+
def compute(self, d, null_d, prev_d):
|
77 |
+
g = reduce(
|
78 |
+
lambda acc, u: acc
|
79 |
+
+ self.witness_function(u[0], prev_d)
|
80 |
+
- self.witness_function(u[1], prev_d),
|
81 |
+
zip(d, null_d),
|
82 |
+
0,
|
83 |
+
)
|
84 |
+
|
85 |
+
return self.bet.compute(g)
|
86 |
+
|
87 |
+
|
88 |
+
class HSIC(KernelPayoff):
|
89 |
+
def __init__(self, config):
|
90 |
+
super().__init__(config)
|
91 |
+
|
92 |
+
kernel = self.kernel
|
93 |
+
scale_method = self.scale_method
|
94 |
+
scale = self.scale
|
95 |
+
|
96 |
+
self.kernel_y = Kernel(kernel, scale_method, scale)
|
97 |
+
self.kernel_z = Kernel(kernel, scale_method, scale)
|
98 |
+
|
99 |
+
def witness_function(self, d, prev_d):
|
100 |
+
y, z = d
|
101 |
+
prev_y, prev_z = prev_d[:, 0], prev_d[:, 1]
|
102 |
+
|
103 |
+
y_mat = self.kernel_y(y.reshape(-1, 1), prev_y.reshape(-1, 1))
|
104 |
+
z_mat = self.kernel_z(z.reshape(-1, 1), prev_z.reshape(-1, 1))
|
105 |
+
|
106 |
+
mu_joint = np.mean(y_mat * z_mat)
|
107 |
+
mu_prod = np.mean(y_mat, axis=1) @ np.mean(z_mat, axis=1)
|
108 |
+
return mu_joint - mu_prod
|
109 |
+
|
110 |
+
|
111 |
+
class cMMD(KernelPayoff):
|
112 |
+
def __init__(self, config):
|
113 |
+
super().__init__(config)
|
114 |
+
|
115 |
+
kernel = self.kernel
|
116 |
+
scale_method = self.scale_method
|
117 |
+
scale = self.scale
|
118 |
+
|
119 |
+
self.kernel_y = Kernel(kernel, scale_method, scale)
|
120 |
+
self.kernel_zj = Kernel(kernel, scale_method, scale)
|
121 |
+
self.kernel_cond_z = Kernel(kernel, scale_method, scale)
|
122 |
+
|
123 |
+
def witness_function(self, u, prev_d):
|
124 |
+
y, zj, cond_z = u[0], u[1], u[2:]
|
125 |
+
|
126 |
+
prev_y, prev_zj, prev_null_zj, prev_cond_z = (
|
127 |
+
prev_d[:, 0],
|
128 |
+
prev_d[:, 1],
|
129 |
+
prev_d[:, 2],
|
130 |
+
prev_d[:, 3:],
|
131 |
+
)
|
132 |
+
|
133 |
+
y_mat = self.kernel_y(y.reshape(-1, 1), prev_y.reshape(-1, 1))
|
134 |
+
zj_mat = self.kernel_zj(zj.reshape(-1, 1), prev_zj.reshape(-1, 1))
|
135 |
+
cond_z_mat = self.kernel_cond_z(
|
136 |
+
cond_z.reshape(-1, prev_cond_z.shape[1]),
|
137 |
+
prev_cond_z.reshape(-1, prev_cond_z.shape[1]),
|
138 |
+
)
|
139 |
+
|
140 |
+
null_zj_mat = self.kernel_zj(zj.reshape(-1, 1), prev_null_zj.reshape(-1, 1))
|
141 |
+
|
142 |
+
mu = np.mean(y_mat * zj_mat * cond_z_mat)
|
143 |
+
mu_null = np.mean(y_mat * null_zj_mat * cond_z_mat)
|
144 |
+
return mu - mu_null
|
145 |
+
|
146 |
+
|
147 |
+
class xMMD(KernelPayoff):
|
148 |
+
def __init__(self, config):
|
149 |
+
super().__init__(config)
|
150 |
+
|
151 |
+
self.kernel = Kernel(self.kernel, self.scale_method, self.scale)
|
152 |
+
|
153 |
+
def witness_function(self, u, prev_d):
|
154 |
+
prev_y, prev_y_null = prev_d[:, 0], prev_d[:, 1]
|
155 |
+
|
156 |
+
mu_y = np.mean(self.kernel(u.reshape(-1, 1), prev_y.reshape(-1, 1)), axis=1)
|
157 |
+
mu_y_null = np.mean(
|
158 |
+
self.kernel(u.reshape(-1, 1), prev_y_null.reshape(-1, 1)), axis=1
|
159 |
+
)
|
160 |
+
return mu_y - mu_y_null
|
ibydmt/test.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
from collections import deque
|
4 |
+
from typing import Callable, Tuple, Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from jaxtyping import Float
|
9 |
+
|
10 |
+
from ibydmt.payoff import HSIC, cMMD, xMMD
|
11 |
+
from ibydmt.wealth import get_wealth
|
12 |
+
|
13 |
+
Array = Union[np.ndarray, torch.Tensor]
|
14 |
+
|
15 |
+
|
16 |
+
class Tester(ABC):
|
17 |
+
def __init__(self):
|
18 |
+
pass
|
19 |
+
|
20 |
+
@abstractmethod
|
21 |
+
def test(self, *args, **kwargs) -> Tuple[bool, int]:
|
22 |
+
pass
|
23 |
+
|
24 |
+
|
25 |
+
class SequentialTester(Tester):
|
26 |
+
def __init__(self, config):
|
27 |
+
super().__init__()
|
28 |
+
self.wealth = get_wealth(config.wealth)(config)
|
29 |
+
|
30 |
+
self.tau_max = config.tau_max
|
31 |
+
|
32 |
+
|
33 |
+
class SKIT(SequentialTester):
|
34 |
+
"""Global Independence Tester"""
|
35 |
+
|
36 |
+
def __init__(self, config):
|
37 |
+
super().__init__(config)
|
38 |
+
self.payoff = HSIC(config)
|
39 |
+
|
40 |
+
def test(self, Y: Float[Array, "N"], Z: Float[Array, "N"]) -> Tuple[bool, int]:
|
41 |
+
D = np.stack([Y, Z], axis=1)
|
42 |
+
for t in range(1, self.tau_max):
|
43 |
+
d = D[2 * t : 2 * (t + 1)]
|
44 |
+
prev_d = D[: 2 * t]
|
45 |
+
|
46 |
+
null_d = np.stack([d[:, 0], np.flip(d[:, 1])], axis=1)
|
47 |
+
|
48 |
+
payoff = self.payoff.compute(d, null_d, prev_d)
|
49 |
+
self.wealth.update(payoff)
|
50 |
+
|
51 |
+
if self.wealth.rejected:
|
52 |
+
return (True, t)
|
53 |
+
return (False, t)
|
54 |
+
|
55 |
+
|
56 |
+
class cSKIT(SequentialTester):
|
57 |
+
"""Global Conditional Independence Tester"""
|
58 |
+
|
59 |
+
def __init__(self, config):
|
60 |
+
super().__init__(config)
|
61 |
+
self.payoff = cMMD(config)
|
62 |
+
|
63 |
+
def _sample(
|
64 |
+
self,
|
65 |
+
z: Float[Array, "N D"],
|
66 |
+
j: int = None,
|
67 |
+
cond_p: Callable[[Float[Array, "N D"], list[int]], Float[Array, "N D"]] = None,
|
68 |
+
) -> Tuple[Float[Array, "N"], Float[Array, "N"], Float[Array, "N D-1"]]:
|
69 |
+
C = list(set(range(z.shape[1])) - {j})
|
70 |
+
|
71 |
+
zj, cond_z = z[:, [j]], z[:, C]
|
72 |
+
samples = cond_p(z, C)
|
73 |
+
null_zj = samples[:, [j]]
|
74 |
+
return zj, null_zj, cond_z
|
75 |
+
|
76 |
+
def test(
|
77 |
+
self,
|
78 |
+
Y: Float[Array, "N"],
|
79 |
+
Z: Float[Array, "N D"],
|
80 |
+
j: int,
|
81 |
+
cond_p: Callable[[Float[Array, "N D"], list[int]], Float[Array, "N D"]],
|
82 |
+
) -> Tuple[bool, int]:
|
83 |
+
sample = functools.partial(self._sample, j=j, cond_p=cond_p)
|
84 |
+
|
85 |
+
prev_y, prev_z = Y[:1][:, None], Z[:1]
|
86 |
+
prev_zj, prev_null_zj, prev_cond_z = sample(prev_z)
|
87 |
+
prev_d = np.concatenate([prev_y, prev_zj, prev_null_zj, prev_cond_z], axis=-1)
|
88 |
+
for t in range(1, self.tau_max):
|
89 |
+
y, z = Y[[t]][:, None], Z[[t]]
|
90 |
+
zj, null_zj, cond_z = sample(z)
|
91 |
+
|
92 |
+
u = np.concatenate([y, zj, cond_z], axis=-1)
|
93 |
+
null_u = np.concatenate([y, null_zj, cond_z], axis=-1)
|
94 |
+
|
95 |
+
payoff = self.payoff.compute(u, null_u, prev_d)
|
96 |
+
self.wealth.update(payoff)
|
97 |
+
|
98 |
+
d = np.concatenate([y, zj, null_zj, cond_z], axis=-1)
|
99 |
+
prev_d = np.vstack([prev_d, d])
|
100 |
+
|
101 |
+
if self.wealth.rejected:
|
102 |
+
return (True, t)
|
103 |
+
return (False, t)
|
104 |
+
|
105 |
+
|
106 |
+
class xSKIT(SequentialTester):
|
107 |
+
"""Local Conditional Independence Tester"""
|
108 |
+
|
109 |
+
def __init__(self, config):
|
110 |
+
super().__init__(config)
|
111 |
+
self.payoff = xMMD(config)
|
112 |
+
|
113 |
+
self._queue = deque()
|
114 |
+
|
115 |
+
def _sample(
|
116 |
+
self,
|
117 |
+
z: Float[Array, "D"],
|
118 |
+
j: int,
|
119 |
+
C: list[int],
|
120 |
+
cond_p: Callable[[Float[Array, "D"], list[int], int], Float[Array, "N D2"]],
|
121 |
+
model: Callable[[Float[Array, "N D2"]], Float[Array, "N"]],
|
122 |
+
) -> Tuple[Float[Array, "1"], Float[Array, "1"]]:
|
123 |
+
|
124 |
+
if len(self._queue) == 0:
|
125 |
+
Cuj = C + [j]
|
126 |
+
|
127 |
+
h = cond_p(z, Cuj, self.tau_max)
|
128 |
+
null_h = cond_p(z, C, self.tau_max)
|
129 |
+
|
130 |
+
y = model(h)[:, None]
|
131 |
+
null_y = model(null_h)[:, None]
|
132 |
+
|
133 |
+
self._queue.extend(zip(y, null_y))
|
134 |
+
|
135 |
+
return self._queue.pop()
|
136 |
+
|
137 |
+
def test(
|
138 |
+
self,
|
139 |
+
z: Float[Array, "D"],
|
140 |
+
j: int,
|
141 |
+
C: list[int],
|
142 |
+
cond_p: Callable[[Float[Array, "D"], list[int], int], Float[Array, "N D2"]],
|
143 |
+
model: Callable[[Float[Array, "N D2"]], Float[Array, "N"]],
|
144 |
+
) -> Tuple[bool, int]:
|
145 |
+
sample = functools.partial(self._sample, z, j, C, cond_p, model)
|
146 |
+
|
147 |
+
prev_d = np.stack(sample(), axis=1)
|
148 |
+
for t in range(1, self.tau_max):
|
149 |
+
y, null_y = sample()
|
150 |
+
|
151 |
+
payoff = self.payoff.compute(y, null_y, prev_d)
|
152 |
+
self.wealth.update(payoff)
|
153 |
+
|
154 |
+
d = np.stack([y, null_y], axis=1)
|
155 |
+
prev_d = np.vstack([prev_d, d])
|
156 |
+
|
157 |
+
if self.wealth.rejected:
|
158 |
+
return (True, t)
|
159 |
+
return (False, t)
|
ibydmt/utils.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def _register_cls(name, dict=None):
|
2 |
+
def _register(cls):
|
3 |
+
if name in dict:
|
4 |
+
raise ValueError(f"{name} is already registered")
|
5 |
+
|
6 |
+
dict[name] = cls
|
7 |
+
|
8 |
+
return _register
|
9 |
+
|
10 |
+
|
11 |
+
def _get_cls(name, dict=None):
|
12 |
+
return dict[name]
|
ibydmt/wealth.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import Dict
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from ibydmt.utils import _get_cls, _register_cls
|
7 |
+
|
8 |
+
|
9 |
+
class Wealth(ABC):
|
10 |
+
def __init__(self, config):
|
11 |
+
self.significance_level = config.significance_level
|
12 |
+
self.rejected = False
|
13 |
+
|
14 |
+
@abstractmethod
|
15 |
+
def update(self, payoff):
|
16 |
+
pass
|
17 |
+
|
18 |
+
|
19 |
+
_WEALTH: Dict[str, Wealth] = {}
|
20 |
+
|
21 |
+
|
22 |
+
def register_wealth(name):
|
23 |
+
return _register_cls(name, dict=_WEALTH)
|
24 |
+
|
25 |
+
|
26 |
+
def get_wealth(name):
|
27 |
+
return _get_cls(name, dict=_WEALTH)
|
28 |
+
|
29 |
+
|
30 |
+
@register_wealth("mixture")
|
31 |
+
class Mixture(Wealth):
|
32 |
+
def __init__(self, config):
|
33 |
+
super().__init__(config)
|
34 |
+
|
35 |
+
self.grid_size = grid_size = config.grid_size
|
36 |
+
self.wealth = np.ones((grid_size,))
|
37 |
+
self.wealth_flag = np.ones(grid_size, dtype=bool)
|
38 |
+
self.v = np.linspace(0.05, 0.95, grid_size)
|
39 |
+
|
40 |
+
def update(self, payoff):
|
41 |
+
raise NotImplementedError
|
42 |
+
|
43 |
+
|
44 |
+
@register_wealth("ons")
|
45 |
+
class ONS(Wealth):
|
46 |
+
def __init__(self, config):
|
47 |
+
super().__init__(config)
|
48 |
+
|
49 |
+
self.w = 1.0
|
50 |
+
self.v = 0
|
51 |
+
self.a = 1
|
52 |
+
|
53 |
+
self.min_v, self.max_v = config.get("min_v", 0), config.get("max_v", 1 / 2)
|
54 |
+
self.wealth_flag = False
|
55 |
+
|
56 |
+
def _update_v(self, payoff):
|
57 |
+
z = payoff / (1 + self.v * payoff)
|
58 |
+
self.a += z**2
|
59 |
+
self.v = max(
|
60 |
+
self.min_v, min(self.max_v, self.v + 2 / (2 - np.log(3)) * z / self.a)
|
61 |
+
)
|
62 |
+
|
63 |
+
def update(self, payoff):
|
64 |
+
w = self.w * (1 + self.v * payoff)
|
65 |
+
|
66 |
+
if w >= 0 and not self.wealth_flag:
|
67 |
+
self.w = w
|
68 |
+
if self.w >= 1 / self.significance_level:
|
69 |
+
self.rejected = True
|
70 |
+
self._update_v(payoff)
|
71 |
+
else:
|
72 |
+
self.wealth_flag = True
|
requirements.txt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
streamlit-image-select
|
2 |
clip @ git+https://github.com/openai/CLIP@main
|
3 |
-
open_clip_torch
|
|
|
1 |
streamlit-image-select
|
2 |
clip @ git+https://github.com/openai/CLIP@main
|
3 |
+
open_clip_torch
|