Spaces:
Sleeping
Sleeping
Commit
·
b1beb2e
1
Parent(s):
8d12475
add demos
Browse files- app.py +101 -0
- lcpfn/.ipynb_checkpoints/__init__-checkpoint.py +53 -0
- lcpfn/.ipynb_checkpoints/curves-checkpoint.py +277 -0
- lcpfn/.ipynb_checkpoints/domhan_prior-checkpoint.py +195 -0
- lcpfn/__init__.py +53 -0
- lcpfn/__pycache__/__init__.cpython-310.pyc +0 -0
- lcpfn/__pycache__/bar_distribution.cpython-310.pyc +0 -0
- lcpfn/__pycache__/curves.cpython-310.pyc +0 -0
- lcpfn/__pycache__/domhan_prior.cpython-310.pyc +0 -0
- lcpfn/__pycache__/encoders.cpython-310.pyc +0 -0
- lcpfn/__pycache__/layer.cpython-310.pyc +0 -0
- lcpfn/__pycache__/model.cpython-310.pyc +0 -0
- lcpfn/__pycache__/positional_encodings.cpython-310.pyc +0 -0
- lcpfn/__pycache__/train.cpython-310.pyc +0 -0
- lcpfn/__pycache__/train_lcpfn.cpython-310.pyc +0 -0
- lcpfn/__pycache__/transformer.cpython-310.pyc +0 -0
- lcpfn/__pycache__/utils.cpython-310.pyc +0 -0
- lcpfn/bar_distribution.py +269 -0
- lcpfn/curves.py +277 -0
- lcpfn/decoders.py +30 -0
- lcpfn/domhan_prior.py +195 -0
- lcpfn/encoders.py +161 -0
- lcpfn/initializers.py +9 -0
- lcpfn/layer.py +126 -0
- lcpfn/model.py +29 -0
- lcpfn/positional_encodings.py +70 -0
- lcpfn/priors/__init__.py +1 -0
- lcpfn/priors/__pycache__/__init__.cpython-310.pyc +0 -0
- lcpfn/priors/__pycache__/gp.cpython-310.pyc +0 -0
- lcpfn/priors/__pycache__/prior.cpython-310.pyc +0 -0
- lcpfn/priors/__pycache__/ridge.cpython-310.pyc +0 -0
- lcpfn/priors/__pycache__/utils.cpython-310.pyc +0 -0
- lcpfn/priors/binarized_regression.py +19 -0
- lcpfn/priors/fast_gp.py +143 -0
- lcpfn/priors/fast_gp_mix.py +394 -0
- lcpfn/priors/gp.py +69 -0
- lcpfn/priors/prior.py +25 -0
- lcpfn/priors/pyro.py +41 -0
- lcpfn/priors/ridge.py +37 -0
- lcpfn/priors/stroke.py +143 -0
- lcpfn/priors/utils.py +151 -0
- lcpfn/train.py +602 -0
- lcpfn/train_lcpfn.py +92 -0
- lcpfn/transformer.py +226 -0
- lcpfn/utils.py +258 -0
- requirements.txt +3 -0
app.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import lcpfn
|
6 |
+
import torch
|
7 |
+
|
8 |
+
model = lcpfn.LCPFN()
|
9 |
+
|
10 |
+
def line_plot_fn(data, cutoff, ci_form):
    """Extrapolate a learning curve with the LC-PFN model and plot the result.

    Args:
        data: DataFrame with a ``y`` column holding the curve values. Cells
            equal to ``""`` are treated as missing and are only accepted as
            one contiguous run at the very end of the table.
        cutoff: Number of leading points handed to the model as observed
            data; converted with ``int``.
        ci_form: Width of the central interval in percent (e.g. ``90``);
            converted with ``int``.

    Returns:
        The matplotlib figure showing the entered curve, the median
        extrapolation, the requested interval and the cutoff marker.

    Raises:
        gr.Error: If the cutoff is smaller than 1 or larger than the number
            of entered points, if missing values are not a trailing run, or
            if a value cannot be converted to ``float``.
    """
    cutoff = int(cutoff)
    ci = int(ci_form)

    if cutoff < 1:
        raise gr.Error("Cutoff must be at least 1.")

    empty_values = list(data[data.y == ""].index)
    if empty_values:
        # Missing cells may only form the tail of the curve: a gap in the
        # middle would silently shift every later point to an earlier step.
        trailing_labels = list(data.index[-len(empty_values):])
        if empty_values != trailing_labels:
            raise gr.Error("Please enter a valid learning curve.")
        data = data[data.y != ""]

    if len(data) < cutoff:
        raise gr.Error(f"Cutoff ({cutoff}) cannot be greater than the number of data points ({len(data)}).")

    try:
        # Work on a separate Series instead of assigning into the (possibly
        # filtered) frame, which would trigger pandas' copy warning.
        y_values = data["y"].astype(float)
    except (TypeError, ValueError) as err:
        raise gr.Error("Please enter a valid learning curve.") from err

    x = torch.arange(1, 51).unsqueeze(1)
    y = torch.from_numpy(y_values.values).float().unsqueeze(1)

    # Split the probability mass outside the interval evenly on both sides.
    rest_prob = (1 - (ci / 100)) / 2
    predictions = model.predict_quantiles(
        x_train=x[:cutoff],
        y_train=y[:cutoff],
        x_test=x[(cutoff - 1):],
        qs=[rest_prob, 0.5, 1 - rest_prob],
    )

    fig, ax = plt.subplots()

    # Only as many x positions as there are entered values; the curve may be
    # shorter than 50 points when trailing cells were left empty.
    ax.plot(x[:len(y_values)].flatten(), y_values, "black", label="target")

    # plot extrapolation
    ax.plot(x[(cutoff - 1):], predictions[:, 1], "blue", label="Extrapolation by PFN")
    ax.fill_between(
        x[(cutoff - 1):].flatten(),
        predictions[:, 0],
        predictions[:, 2],
        color="blue",
        alpha=0.2,
        label=f"CI of {ci}%",  # reflect the interval that was actually requested
    )

    # plot cutoff
    ax.vlines(cutoff, 0, 1, linewidth=0.5, color="k", label="cutoff", linestyles="dashed")
    ax.set_ylim(0, 1)
    ax.set_xlim(0, 50)
    ax.legend(loc="lower right")
    ax.set_xlabel("t")
    ax.set_ylabel("y")

    return fig
|
55 |
+
|
56 |
+
# Build the example curves offered in the UI: ten independent draws from the
# LC-PFN prior, each shown either as the clean curve or as its noisy version
# (chosen with probability 0.5), truncated to the first 50 steps.
examples = []
for _ in range(10):
    prior = lcpfn.sample_from_prior(np.random)
    clean_curve, noisy_curve = prior()
    curve = noisy_curve if np.random.rand() < 0.5 else clean_curve
    df = pd.DataFrame.from_records(curve[:50][..., np.newaxis], columns=["y"])
    df["t"] = list(range(1, 50 + 1))
    # NOTE(review): the trailing 10 looks like a default cutoff for the
    # example row - confirm it is consumed where the examples are registered.
    examples.append([df[["t", "y"]], 10])
|
68 |
+
|
69 |
+
# NOTE(review): this column and its two Number fields are created at module
# level and `components` is not referenced anywhere else in the visible code;
# presumably leftover scaffolding - confirm before removing.
with gr.Column() as components:
|
70 |
+
gr.Number(value=10)
|
71 |
+
gr.Number(value=10)
|
72 |
+
|
73 |
+
# Gradio UI: an editable 50-row (t, y) table, a cutoff field, an interval
# selector, a Run button and a plot area.
with gr.Blocks() as demo:
|
74 |
+
with gr.Row():
|
75 |
+
with gr.Column():
|
76 |
+
# Pre-filled with the first generated example curve.
dataform = gr.Dataframe(
|
77 |
+
value=examples[0][0],
|
78 |
+
headers=["t", "y"],
|
79 |
+
datatype=["number", "number"],
|
80 |
+
row_count=(50, "fixed"),
|
81 |
+
col_count=(2, "fixed"),
|
82 |
+
type="pandas",
|
83 |
+
)
|
84 |
+
with gr.Row():
|
85 |
+
cutoffform = gr.Number(label="cutoff", value=10)
|
86 |
+
# Choices are (label, value) pairs; the integer value is what line_plot_fn receives.
ci_form = gr.Dropdown(label="Confidence Interval", choices=[
|
87 |
+
("90%", 90),
|
88 |
+
("95%", 95),
|
89 |
+
("99%", 99)
|
90 |
+
], value=90)
|
91 |
+
btn = gr.Button("Run")
|
92 |
+
outputform = gr.Plot()
|
93 |
+
# Clicking Run passes the table, cutoff and interval to line_plot_fn and shows the returned figure.
btn.click(fn=line_plot_fn, inputs=[dataform, cutoffform, ci_form], outputs=outputform)
|
94 |
+
# NOTE(review): each entry of `examples` holds two values (curve, 10) but only
# `dataform` is listed as an input here - confirm whether `cutoffform` should be included.
gr.Examples(examples, inputs=[dataform], label="Examples of synthetic learning curves")
|
95 |
+
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
if __name__ == "__main__":
|
100 |
+
demo.launch()
|
101 |
+
|
lcpfn/.ipynb_checkpoints/__init__-checkpoint.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
sys.path.insert(0, os.path.dirname(__file__))
|
3 |
+
|
4 |
+
|
5 |
+
model_path = 'trained_models'
|
6 |
+
|
7 |
+
def prepare_models():
    """Make sure the pretrained weight files exist next to this package.

    For every known model file that is missing, the gzip archive is
    downloaded (unless it is already present) and unpacked in place; the
    archive itself is kept. Failures are reported on stdout and do not
    raise, so the caller can continue with whichever models are available.
    """
    # Local imports, in line with the lazy ``requests`` import below.
    import gzip
    import shutil

    pfns4bo_dir = os.path.dirname(__file__)
    model_names = ['pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt',
                   'pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt']

    for name in model_names:
        weights_path = os.path.join(pfns4bo_dir, model_path, name)
        compressed_weights_path = os.path.join(pfns4bo_dir, model_path, name + '.gz')
        if not os.path.exists(weights_path):
            if not os.path.exists(compressed_weights_path):
                print("Downloading", os.path.abspath(compressed_weights_path))
                import requests
                url = f'https://github.com/automl/lcpfn/raw/main/lcpfn/trained_models/{name + ".gz"}'
                try:
                    r = requests.get(url, allow_redirects=True, timeout=60)
                    # Do not store an HTTP error page as if it were the archive.
                    r.raise_for_status()
                except requests.RequestException as err:
                    print("Download failed:", err)
                else:
                    os.makedirs(os.path.dirname(compressed_weights_path), exist_ok=True)
                    with open(compressed_weights_path, 'wb') as f:
                        f.write(r.content)
            if os.path.exists(compressed_weights_path):
                print("Unzipping", name)
                # Decompress in-process rather than through a shell command,
                # so paths with spaces or shell metacharacters are handled.
                try:
                    with gzip.open(compressed_weights_path, 'rb') as src, \
                            open(weights_path, 'wb') as dst:
                        shutil.copyfileobj(src, dst)
                except (OSError, EOFError) as err:
                    print("Failed to unzip", compressed_weights_path, "-", err)
                    # Remove a partially written file so it is not mistaken
                    # for valid weights on the next call.
                    if os.path.exists(weights_path):
                        os.remove(weights_path)
            else:
                print("Failed to find", compressed_weights_path)
                print("Make sure you have an internet connection to download the model automatically..")
        if os.path.exists(weights_path):
            print("Successfully located model at", weights_path)
|
32 |
+
|
33 |
+
|
34 |
+
# Maps the public model name to the absolute path of its weight file inside
# the package's model directory.
model_dict = {
    key: os.path.join(os.path.dirname(__file__), model_path, filename)
    for key, filename in (
        ('EMSIZE512_NLAYERS12_NBUCKETS1000', 'pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt'),
        ('EMSIZE512_NLAYERS6_NBUCKETS1000', 'pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt'),
    )
}
|
40 |
+
|
41 |
+
|
42 |
+
def __getattr__(name):
    """Resolve model names as module attributes.

    Accessing a key of ``model_dict`` on the module returns the path of the
    corresponding weight file; if that file does not exist yet,
    ``prepare_models`` is run first. Any other name raises
    ``AttributeError``.
    """
    if name not in model_dict:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

    if not os.path.exists(model_dict[name]):
        print("Can't find", os.path.abspath(model_dict[name]), "thus unzipping/downloading models now.")
        print("This might take a while..")
        prepare_models()
    return model_dict[name]
|
50 |
+
|
51 |
+
from lcpfn.model import LCPFN
|
52 |
+
from lcpfn.train_lcpfn import train_lcpfn
|
53 |
+
from lcpfn.domhan_prior import sample_from_prior, create_get_batch_func
|
lcpfn/.ipynb_checkpoints/curves-checkpoint.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from collections import OrderedDict
|
3 |
+
|
4 |
+
# Parameter priors per curve family and prior flavour ("uniform" / "peaked").
# Each entry names a distribution type and its two parameters, as consumed by
# ``prior_sampler``.
prior = {
    "pow3": {
        "uniform": OrderedDict([
            ("a", dict(type="uniform", param1=-1, param2=1)),
            ("c", dict(type="uniform", param1=0, param2=1)),
            ("alpha", dict(type="uniform", param1=0, param2=1)),
        ]),
        "peaked": OrderedDict([
            ("a", dict(type="uniform", param1=-0.6, param2=0.6)),
            ("c", dict(type="uniform", param1=0, param2=1.25)),
            ("alpha", dict(type="log_normal", param1=0, param2=2)),
        ]),
    },
    "ilog2": {
        "uniform": OrderedDict([
            ("c", dict(type="uniform", param1=0, param2=1)),
            ("a", dict(type="uniform", param1=-1, param2=1)),
        ]),
        "peaked": OrderedDict([
            ("c", dict(type="uniform", param1=0, param2=1)),
            ("a", dict(type="uniform", param1=-0.5, param2=0.5)),
        ]),
    },
    "janoschek": {
        "uniform": OrderedDict([
            ("a", dict(type="uniform", param1=0, param2=1)),
            ("beta", dict(type="uniform", param1=0, param2=2)),
            ("k", dict(type="uniform", param1=0, param2=1)),
            ("delta", dict(type="uniform", param1=-5, param2=5)),
        ]),
        "peaked": OrderedDict([
            ("a", dict(type="uniform", param1=0, param2=1)),
            ("beta", dict(type="uniform", param1=0, param2=2)),
            ("k", dict(type="log_normal", param1=-2, param2=1)),
            ("delta", dict(type="log_normal", param1=0, param2=0.5)),
        ]),
    },
}
|
42 |
+
|
43 |
+
|
44 |
+
def prior_sampler(rng, type, param1, param2):
    """Draw one value from the distribution described by a prior entry.

    Args:
        rng: Random source exposing ``uniform`` and ``lognormal``.
        type: Distribution name, ``"uniform"`` or ``"log_normal"``. The name
            shadows the builtin but is kept because it is part of the call
            signature.
        param1: First distribution parameter, passed through unchanged.
        param2: Second distribution parameter, passed through unchanged.

    Returns:
        Whatever the selected ``rng`` method returns.

    Raises:
        ValueError: If ``type`` is not a known distribution name.
    """
    if type == "uniform":
        return rng.uniform(param1, param2)
    if type == "log_normal":
        return rng.lognormal(param1, param2)
    # ValueError is still an Exception, so existing broad handlers keep working.
    raise ValueError("Unknown prior type: {}".format(type))
|
50 |
+
|
51 |
+
|
52 |
+
def pow3(x, c, a, alpha):
    """Three-parameter power law ``c - a * x ** (-alpha)``."""
    decay = x ** (-alpha)
    return c - a * decay
|
54 |
+
|
55 |
+
|
56 |
+
def prior_pow3(rng):
    """Sample the ``pow3`` parameters a, c, alpha (in that order) from the "peaked" prior table."""
    specs = prior["pow3"]["peaked"]
    params = {}
    for p in ("a", "c", "alpha"):
        spec = specs[p]
        params[p] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return params
|
66 |
+
|
67 |
+
|
68 |
+
def uniform_prior_pow3(rng):
    """Sample the ``pow3`` parameters a, c, alpha (in that order) from the "uniform" prior table."""
    specs = prior["pow3"]["uniform"]
    params = {}
    for p in ("a", "c", "alpha"):
        spec = specs[p]
        params[p] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return params
|
78 |
+
|
79 |
+
|
80 |
+
def ilog2(x, c, a):
    """Inverse-log curve ``c - a / log(x + 1)``."""
    log_term = np.log(x + 1)
    return c - a / log_term
|
82 |
+
|
83 |
+
|
84 |
+
def prior_ilog2(rng):
    """Sample the ``ilog2`` parameters a, c (in that order) from the "peaked" prior table."""
    specs = prior["ilog2"]["peaked"]
    params = {}
    for p in ("a", "c"):
        spec = specs[p]
        params[p] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return params
|
94 |
+
|
95 |
+
|
96 |
+
def uniform_prior_ilog2(rng):
    """Sample the ``ilog2`` parameters a, c (in that order) from the "uniform" prior table."""
    specs = prior["ilog2"]["uniform"]
    params = {}
    for p in ("a", "c"):
        spec = specs[p]
        params[p] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return params
|
106 |
+
|
107 |
+
|
108 |
+
def janoschek(x, a, beta, k, delta):
    """
    Janoschek growth curve ``a - (a - beta) * exp(-k * x**delta)``.

    http://www.pisces-conservation.com/growthhelp/janoschek.htm
    """
    gap = a - beta
    damping = np.exp(-k * x**delta)
    return a - gap * damping
|
113 |
+
|
114 |
+
|
115 |
+
def prior_janoschek(rng):
    """Sample the ``janoschek`` parameters a, beta, k, delta (in that order) from the "peaked" prior table."""
    specs = prior["janoschek"]["peaked"]
    params = {}
    for p in ("a", "beta", "k", "delta"):
        spec = specs[p]
        params[p] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return params
|
125 |
+
|
126 |
+
|
127 |
+
def uniform_prior_janoschek(rng):
    """Sample the ``janoschek`` parameters a, beta, k, delta (in that order) from the "uniform" prior table."""
    specs = prior["janoschek"]["uniform"]
    params = {}
    for p in ("a", "beta", "k", "delta"):
        spec = specs[p]
        params[p] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return params
|
137 |
+
|
138 |
+
|
139 |
+
def log_power(x, a, b, c):
    """Log-power curve ``a / (1 + (x / e**b) ** c)``.

    a: upper bound
    c: growth rate
    value at x = 1: a / (1 + (1 / e**b) ** c)
    """
    scaled = x / np.exp(b)
    return a / (1.0 + scaled ** c)
|
144 |
+
|
145 |
+
|
146 |
+
def prior_log_power(rng):
    """Sample ``log_power`` parameters: a ~ N(0.8, 0.1), b ~ N(1, 1), c ~ U(-3, 0)."""
    params = {}
    params["a"] = rng.normal(0.8, 0.1)
    params["b"] = rng.normal(1.0, 1.0)
    params["c"] = rng.uniform(-3.0, 0.0)
    return params
|
154 |
+
|
155 |
+
|
156 |
+
def weibull(x, alpha, beta, kappa, delta):
    """
    Weibull model ``alpha - (alpha - beta) * exp(-(kappa * x) ** delta)``.

    http://www.pisces-conservation.com/growthhelp/index.html?morgan_mercer_floden.htm
    alpha: upper asymptote
    beta: lower asymptote
    kappa: growth rate
    delta: controls the x-ordinate for the point of inflection
    """
    span = alpha - beta
    damping = np.exp(-((kappa * x) ** delta))
    return alpha - span * damping
|
166 |
+
|
167 |
+
|
168 |
+
def prior_weibull(rng):
    """Sample ``weibull`` parameters: alpha ~ U(0, 1.5), beta ~ U(0, 1), ln(kappa) ~ N(-2, 1), ln(delta) ~ N(0, 0.5)."""
    upper = rng.uniform(0.0, 1.5)
    lower = rng.uniform(0.0, 1)
    log_kappa = rng.normal(-2.0, 1.0)
    log_delta = rng.normal(0, 0.5)
    return {
        "alpha": upper,
        "beta": lower,
        "kappa": np.exp(log_kappa),
        "delta": np.exp(log_delta),
    }
|
174 |
+
|
175 |
+
|
176 |
+
def mmf(x, alpha, beta, kappa, delta):
    """
    Morgan-Mercer-Flodin curve ``alpha - (alpha - beta) / (1 + (kappa * x) ** delta)``.

    description:
    Nonlinear Regression page 342
    http://bit.ly/1jodG17
    http://www.pisces-conservation.com/growthhelp/index.html?morgan_mercer_floden.htm
    alpha: upper asymptote
    kappa: growth rate
    beta: initial value
    delta: controls the point of inflection
    """
    span = alpha - beta
    denominator = 1.0 + (kappa * x) ** delta
    return alpha - span / denominator
|
189 |
+
|
190 |
+
|
191 |
+
def prior_mmf(rng):
    """Sample ``mmf`` parameters: alpha ~ N(0.8, 0.1), beta ~ N(0.2, 0.1), ln(kappa) ~ N(0, 2), ln(delta) ~ N(0, 1)."""
    upper = rng.normal(0.8, 0.1)
    initial = rng.normal(0.2, 0.1)
    log_kappa = rng.normal(0, 2)
    log_delta = rng.normal(0, 1)
    return {
        "alpha": upper,
        "beta": initial,
        "kappa": np.exp(log_kappa),
        "delta": np.exp(log_delta),
    }
|
201 |
+
|
202 |
+
|
203 |
+
def vap(x, a, b, c):
    """Vapor pressure model ``exp(a + b / x + c * log(x))``."""
    # no upper bound if c > 0
    # a = ln(upper bound) for c=0
    # a+b = ln(initial)
    exponent = a + b / x + c * np.log(x)
    return np.exp(exponent)
|
209 |
+
|
210 |
+
|
211 |
+
def prior_vap(rng):
    """Sample ``vap`` parameters: a ~ U(-2, 0), b ~ U(-4, 0), ln(c) ~ U(-8, 0)."""
    offset = rng.uniform(-2.0, 0.0)  # @heri: range check
    slope = rng.uniform(-4.0, 0.0)  # @heri: range check
    log_c = rng.uniform(-8.0, 0.0)  # @heri: same as weights
    return {"a": offset, "b": slope, "c": np.exp(log_c)}
|
216 |
+
|
217 |
+
|
218 |
+
def loglog_linear(x, a, b):
    """Log of a line in log-space: ``log(a * log(x) + b)``."""
    log_x = np.log(x)
    return np.log(a * log_x + b)
|
221 |
+
|
222 |
+
|
223 |
+
def prior_loglog_linear(rng):
    """Sample ``loglog_linear`` parameters: ln(a) ~ N(-2, 1), ln(b) ~ U(0, 1)."""
    log_a = rng.normal(-2.0, 1.0)
    log_b = rng.uniform(0.0, 1.0)
    return {"a": np.exp(log_a), "b": np.exp(log_b)}
|
229 |
+
|
230 |
+
|
231 |
+
def exp4(x, c, a, b, alpha):
    """Four-parameter exponential ``c - exp(-a * x**alpha + b)``."""
    exponent = -a * (x**alpha) + b
    return c - np.exp(exponent)
|
233 |
+
|
234 |
+
|
235 |
+
def prior_exp4(rng):
    """Sample ``exp4`` parameters.

    Draw order: c ~ N(0.8, 0.1), ln(a) ~ N(-2, 1), ln(alpha) ~ N(0, 1),
    ln(b) ~ N(0, 0.5).
    """
    asymptote = rng.normal(0.8, 0.1)
    log_a = rng.normal(-2, 1)
    log_alpha = rng.normal(0, 1)
    log_b = rng.normal(0, 0.5)
    return {
        "a": np.exp(log_a),
        "b": np.exp(log_b),
        "c": asymptote,
        "alpha": np.exp(log_alpha),
    }
|
245 |
+
|
246 |
+
|
247 |
+
def pow4(x, c, a, b, alpha):
    """Four-parameter power law ``c - (a * x + b) ** (-alpha)``."""
    base = a * x + b
    return c - base ** (-alpha)
|
249 |
+
|
250 |
+
|
251 |
+
def prior_pow4(rng):
    """Sample ``pow4`` parameters.

    Draw order: ln(1 - c) ~ U(-5, 0), ln(a) ~ N(-3, 2), ln(alpha) ~ N(0, 1),
    ln(b) ~ U(0, 1).
    """
    log_one_minus_c = rng.uniform(-5.0, 0)
    log_a = rng.normal(-3.0, 2)
    log_alpha = rng.normal(0, 1)
    log_b = rng.uniform(0, 1)
    return {
        "a": np.exp(log_a),
        "b": np.exp(log_b),
        "c": 1 - np.exp(log_one_minus_c),
        "alpha": np.exp(log_alpha),
    }
|
261 |
+
|
262 |
+
|
263 |
+
def dr_hill_zero_background(x, theta, eta, kappa):
    """Hill dose-response curve with zero background: ``theta * x**eta / (kappa**eta + x**eta)``.

    theta: upper bound
    eta: growth rate
    value at x = 1: theta / (kappa**eta + 1)
    """
    x_pow = x**eta
    return (theta * x_pow) / (kappa**eta + x_pow)
|
268 |
+
|
269 |
+
|
270 |
+
def prior_dr_hill_zero_background(rng):
    """Sample Hill-curve parameters: theta ~ N(0.8, 0.1), ln(eta) ~ N(1, 1), ln(kappa) ~ N(1, 2)."""
    upper = rng.normal(0.8, 0.1)
    log_eta = rng.normal(1.0, 1.0)
    log_kappa = rng.normal(1.0, 2.0)
    return {"theta": upper, "eta": np.exp(log_eta), "kappa": np.exp(log_kappa)}
|
lcpfn/.ipynb_checkpoints/domhan_prior-checkpoint.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from lcpfn.curves import (
|
5 |
+
pow3,
|
6 |
+
ilog2,
|
7 |
+
janoschek,
|
8 |
+
log_power,
|
9 |
+
prior_ilog2,
|
10 |
+
uniform_prior_pow3,
|
11 |
+
weibull,
|
12 |
+
mmf,
|
13 |
+
vap,
|
14 |
+
loglog_linear,
|
15 |
+
exp4,
|
16 |
+
pow4,
|
17 |
+
dr_hill_zero_background,
|
18 |
+
)
|
19 |
+
from lcpfn.curves import (
|
20 |
+
prior_pow3,
|
21 |
+
prior_janoschek,
|
22 |
+
prior_log_power,
|
23 |
+
prior_weibull,
|
24 |
+
prior_mmf,
|
25 |
+
prior_vap,
|
26 |
+
prior_loglog_linear,
|
27 |
+
prior_exp4,
|
28 |
+
prior_pow4,
|
29 |
+
prior_dr_hill_zero_background,
|
30 |
+
)
|
31 |
+
from lcpfn.curves import (
|
32 |
+
uniform_prior_pow3,
|
33 |
+
uniform_prior_ilog2,
|
34 |
+
uniform_prior_janoschek,
|
35 |
+
)
|
36 |
+
|
37 |
+
|
38 |
+
def prior_weights(rng, components=None):
    """Draw one U(0, 1) mixture weight per curve component.

    Args:
        rng: Random source exposing ``uniform(low, high, size=...)``.
        components: Sequence of component names. ``None`` (the default)
            selects all eleven known curve families.

    Returns:
        Dict mapping each component name to its sampled weight, in the
        order of ``components``.
    """
    if components is None:
        # Built per call instead of sharing one mutable default list.
        components = (
            "pow3",
            "ilog2",
            "janoschek",
            "log_power",
            "weibull",
            "mmf",
            "vap",
            "loglog_linear",
            "exp4",
            "pow4",
            "dr_hill_zero_background",
        )
    K = len(components)
    weights = rng.uniform(0.0, 1, size=(K,))
    return {f: weights[i] for i, f in enumerate(components)}
|
57 |
+
|
58 |
+
|
59 |
+
def sample_from_prior(rng, seq_len=100):
    """Sample a curve generator from the default prior.

    Delegates to ``sample_prior_comb`` with the pow3, ilog2 and janoschek
    components and the "peaked" parameter distribution.
    """
    curve_families = ["pow3", "ilog2", "janoschek"]
    return sample_prior_comb(
        rng=rng,
        seq_len=seq_len,
        components=curve_families,
        distribution="peaked",
    )
|
63 |
+
|
64 |
+
|
65 |
+
def sample_prior_comb(
    rng,
    components,
    distribution,
    var_lnloc=-4,
    var_lnscale=1,
    range_constraint=True,
    seq_len=100,
):
    """Sample a weighted combination of parametric curves by rejection sampling.

    Args:
        rng: Random source used for weights, curve parameters and noise.
        components: Names of the curve families to combine; each must have
            a prior in the selected ``distribution``.
        distribution: ``"peaked"`` or ``"uniform"`` (the latter only covers
            pow3, ilog2 and janoschek).
        var_lnloc: Mean of the normal the log noise scale is drawn from.
        var_lnscale: Standard deviation of that normal.
        range_constraint: If true, curves leaving [0, 1] are rejected.
        seq_len: Number of points; the curve is evaluated at 1..seq_len.

    Returns:
        A zero-argument function returning ``(y, y_noisy)``: the accepted
        noiseless curve and a fresh noisy copy of it on every call.

    Raises:
        NotImplementedError: If ``distribution`` is not a known name.

    Note:
        Sampling repeats until a curve is accepted, so a prior that never
        yields an admissible curve makes this loop forever.
    """
    f_components = {
        "pow3": pow3,
        "ilog2": ilog2,
        "janoschek": janoschek,
        "log_power": log_power,
        "weibull": weibull,
        "mmf": mmf,
        "vap": vap,
        "loglog_linear": loglog_linear,
        "exp4": exp4,
        "pow4": pow4,
        "dr_hill_zero_background": dr_hill_zero_background,
    }

    if distribution == "peaked":
        f_priors = {
            "pow3": prior_pow3,
            "ilog2": prior_ilog2,
            "janoschek": prior_janoschek,
            "log_power": prior_log_power,
            "weibull": prior_weibull,
            "mmf": prior_mmf,
            "vap": prior_vap,
            "loglog_linear": prior_loglog_linear,
            "exp4": prior_exp4,
            "pow4": prior_pow4,
            "dr_hill_zero_background": prior_dr_hill_zero_background,
        }
    elif distribution == "uniform":
        f_priors = {
            "pow3": uniform_prior_pow3,
            "ilog2": uniform_prior_ilog2,
            "janoschek": uniform_prior_janoschek,
        }
    else:
        # ``NotImplemented`` is a constant, not an exception class; calling
        # it raised a confusing TypeError instead of the intended error.
        raise NotImplementedError(f"Unknown distribution: {distribution!r}")

    x = np.arange(1, seq_len + 1)

    while True:
        # sample the noiseless curve
        weights = prior_weights(rng, components=components)
        y = np.zeros(x.shape, dtype="float")
        for f, w in weights.items():
            kwargs = f_priors[f](rng)
            y += w * f_components[f](x, **kwargs)
        # add noise (can exceed [0,1], but afaik no way to implement this prior in Tobis work)
        var = np.exp(
            rng.normal(var_lnloc, var_lnscale)
        )  # @heri: ln_prob =+ log(normal.pdf(log(var), loc=var_lnloc, scale=var_lnscale))

        # reject any curves that are non-increasing, exceed the [0,1] range
        rejected = (
            y[-1] <= y[0]
            or (range_constraint and (np.any(y < 0) or np.any(y > 1)))
            or np.isnan(y).any()
        )
        if not rejected:
            break

    def curve():  # generates a sample from the same model, but with independent noise
        y_noisy = y + rng.normal(np.zeros_like(y), var)
        return y, y_noisy

    return curve
|
142 |
+
|
143 |
+
|
144 |
+
def generate_prior_dataset(n, prior=None, seed=42):
    """
    Returns a fixed sample from the prior (with fixed seq_len) as an n x seq_len np.ndarray

    Args:
        n: Number of curves to draw; must be at least 1 because the rows
            are combined with ``np.stack``.
        prior: Callable taking an rng and returning a curve generator whose
            call yields ``(y, y_noisy)``. ``None`` (the default) uses
            ``sample_from_prior``; the previous default,
            ``sample_prior_comb``, could not be called with the rng alone.
        seed: Seed for the ``np.random.RandomState`` driving the sampling.

    Returns:
        Array holding the noisy curve of each draw, one row per draw.
    """
    if prior is None:
        prior = sample_from_prior
    rng = np.random.RandomState(seed)
    prior_data = np.stack([prior(rng)()[1] for _ in range(n)])
    return prior_data
|
151 |
+
|
152 |
+
|
153 |
+
def create_get_batch_func(prior):
    """Return ``get_batch_domhan`` with its ``prior`` argument pre-bound to the given prior."""
    batch_func = partial(get_batch_domhan, prior=prior)
    return batch_func
|
155 |
+
|
156 |
+
# function producing batches for PFN training
|
157 |
+
def get_batch_domhan(
    batch_size,
    seq_len,
    num_features,
    prior,
    device="cpu",
    noisy_target=True,
    **_,
):
    """Produce one training batch of synthetic learning curves.

    Args:
        batch_size: Number of curves in the batch.
        seq_len: Number of points per curve.
        num_features: Must be 1; the only input feature is the step index.
        prior: Callable ``prior(rng, seq_len=...)`` returning a curve
            generator whose call yields ``(y, y_noisy)``.
        device: Torch device the tensors are moved to.
        noisy_target: If true the target equals the noisy curve, otherwise
            the noiseless curve.
        **_: Additional keyword arguments are accepted and ignored.

    Returns:
        Tuple ``(x, y_noisy, y_target)`` of float tensors with shapes
        ``(seq_len, batch_size, num_features)``, ``(seq_len, batch_size)``
        and ``(seq_len, batch_size)``.

    Raises:
        ValueError: If ``num_features`` is not 1.
    """
    # Explicit check rather than ``assert``, which is stripped under ``-O``.
    if num_features != 1:
        raise ValueError(f"num_features must be 1, got {num_features}")

    y_target = np.empty((batch_size, seq_len), dtype=float)
    y_noisy = np.empty((batch_size, seq_len), dtype=float)

    for i in range(batch_size):
        curve_func = prior(np.random, seq_len=seq_len)  # uses numpy rng
        clean, noisy = curve_func()
        y_noisy[i] = noisy
        y_target[i] = noisy if noisy_target else clean

    # turn numpy arrays into correctly shaped torch tensors & move them to device
    x = (
        torch.arange(1, seq_len + 1)
        .repeat((num_features, batch_size, 1))
        .transpose(2, 0)
        .to(device)
    )
    y_target = torch.from_numpy(y_target).transpose(1, 0).to(device)
    y_noisy = torch.from_numpy(y_noisy).transpose(1, 0).to(device)

    # the model expects float32 inputs
    x = x.float()
    y_target = y_target.float()
    y_noisy = y_noisy.float()

    return x, y_noisy, y_target
|
lcpfn/__init__.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
sys.path.insert(0, os.path.dirname(__file__))
|
3 |
+
|
4 |
+
|
5 |
+
model_path = 'trained_models'
|
6 |
+
|
7 |
+
def prepare_models():
    """Make sure the pretrained weight files exist next to this package.

    For every known model file that is missing, the gzip archive is
    downloaded (unless it is already present) and unpacked in place; the
    archive itself is kept. Failures are reported on stdout and do not
    raise, so the caller can continue with whichever models are available.
    """
    # Local imports, in line with the lazy ``requests`` import below.
    import gzip
    import shutil

    pfns4bo_dir = os.path.dirname(__file__)
    model_names = ['pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt',
                   'pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt']

    for name in model_names:
        weights_path = os.path.join(pfns4bo_dir, model_path, name)
        compressed_weights_path = os.path.join(pfns4bo_dir, model_path, name + '.gz')
        if not os.path.exists(weights_path):
            if not os.path.exists(compressed_weights_path):
                print("Downloading", os.path.abspath(compressed_weights_path))
                import requests
                url = f'https://github.com/automl/lcpfn/raw/main/lcpfn/trained_models/{name + ".gz"}'
                try:
                    r = requests.get(url, allow_redirects=True, timeout=60)
                    # Do not store an HTTP error page as if it were the archive.
                    r.raise_for_status()
                except requests.RequestException as err:
                    print("Download failed:", err)
                else:
                    os.makedirs(os.path.dirname(compressed_weights_path), exist_ok=True)
                    with open(compressed_weights_path, 'wb') as f:
                        f.write(r.content)
            if os.path.exists(compressed_weights_path):
                print("Unzipping", name)
                # Decompress in-process rather than through a shell command,
                # so paths with spaces or shell metacharacters are handled.
                try:
                    with gzip.open(compressed_weights_path, 'rb') as src, \
                            open(weights_path, 'wb') as dst:
                        shutil.copyfileobj(src, dst)
                except (OSError, EOFError) as err:
                    print("Failed to unzip", compressed_weights_path, "-", err)
                    # Remove a partially written file so it is not mistaken
                    # for valid weights on the next call.
                    if os.path.exists(weights_path):
                        os.remove(weights_path)
            else:
                print("Failed to find", compressed_weights_path)
                print("Make sure you have an internet connection to download the model automatically..")
        if os.path.exists(weights_path):
            print("Successfully located model at", weights_path)
|
32 |
+
|
33 |
+
|
34 |
+
# Maps the public model name to the absolute path of its weight file inside
# the package's model directory.
model_dict = {
    key: os.path.join(os.path.dirname(__file__), model_path, filename)
    for key, filename in (
        ('EMSIZE512_NLAYERS12_NBUCKETS1000', 'pfn_EPOCH1000_EMSIZE512_NLAYERS12_NBUCKETS1000.pt'),
        ('EMSIZE512_NLAYERS6_NBUCKETS1000', 'pfn_EPOCH1000_EMSIZE512_NLAYERS6_NBUCKETS1000.pt'),
    )
}
|
40 |
+
|
41 |
+
|
42 |
+
def __getattr__(name):
    """Resolve model names as module attributes.

    Accessing a key of ``model_dict`` on the module returns the path of the
    corresponding weight file; if that file does not exist yet,
    ``prepare_models`` is run first. Any other name raises
    ``AttributeError``.
    """
    if name not in model_dict:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

    if not os.path.exists(model_dict[name]):
        print("Can't find", os.path.abspath(model_dict[name]), "thus unzipping/downloading models now.")
        print("This might take a while..")
        prepare_models()
    return model_dict[name]
|
50 |
+
|
51 |
+
from lcpfn.model import LCPFN
|
52 |
+
from lcpfn.train_lcpfn import train_lcpfn
|
53 |
+
from lcpfn.domhan_prior import sample_from_prior, create_get_batch_func
|
lcpfn/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (2.03 kB). View file
|
|
lcpfn/__pycache__/bar_distribution.cpython-310.pyc
ADDED
Binary file (9.96 kB). View file
|
|
lcpfn/__pycache__/curves.cpython-310.pyc
ADDED
Binary file (6.81 kB). View file
|
|
lcpfn/__pycache__/domhan_prior.cpython-310.pyc
ADDED
Binary file (3.92 kB). View file
|
|
lcpfn/__pycache__/encoders.cpython-310.pyc
ADDED
Binary file (8.02 kB). View file
|
|
lcpfn/__pycache__/layer.cpython-310.pyc
ADDED
Binary file (4.64 kB). View file
|
|
lcpfn/__pycache__/model.cpython-310.pyc
ADDED
Binary file (1.8 kB). View file
|
|
lcpfn/__pycache__/positional_encodings.cpython-310.pyc
ADDED
Binary file (2.86 kB). View file
|
|
lcpfn/__pycache__/train.cpython-310.pyc
ADDED
Binary file (13.5 kB). View file
|
|
lcpfn/__pycache__/train_lcpfn.cpython-310.pyc
ADDED
Binary file (2.82 kB). View file
|
|
lcpfn/__pycache__/transformer.cpython-310.pyc
ADDED
Binary file (8.04 kB). View file
|
|
lcpfn/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (10.7 kB). View file
|
|
lcpfn/bar_distribution.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
|
5 |
+
class BarDistribution(nn.Module):
    """Piecewise-constant density over the interval spanned by ``borders``.

    The model emits one logit per bucket; the softmax of the logits is the
    probability mass of each bucket and the mass is spread uniformly inside
    the bucket.
    """

    def __init__(self, borders: torch.Tensor, smoothing=.0):
        """
        :param borders: 1-d sorted tensor; first entry is the minimum and the last
                        entry the maximum of the support.
        :param smoothing: weight of the label-smoothing term (only used in training mode).
        """
        super().__init__()
        assert len(borders.shape) == 1
        # buffers (not parameters) so they follow the module across devices
        self.register_buffer('borders', borders)
        self.register_buffer('smoothing', torch.tensor(smoothing))
        self.register_buffer('bucket_widths', self.borders[1:] - self.borders[:-1])
        width_gap = self.bucket_widths.sum() - (self.borders[-1] - self.borders[0])
        assert width_gap.abs() < 1e-4, f'diff: {width_gap}'
        sort_perm = torch.argsort(borders)
        identity_perm = torch.arange(len(borders)).to(sort_perm.device)
        assert (sort_perm == identity_perm).all(), "Please provide sorted borders!"
        self.num_bars = len(borders) - 1

    def _bucket_centers(self):
        # midpoint of every bucket
        return self.borders[:-1] + self.bucket_widths/2

    def map_to_bucket_idx(self, y):
        """Return, for every entry of ``y``, the index of the bucket containing it."""
        bucket_idx = torch.searchsorted(self.borders, y) - 1
        # values sitting exactly on the outer borders belong to the outer buckets
        bucket_idx[y == self.borders[0]] = 0
        bucket_idx[y == self.borders[-1]] = self.num_bars - 1
        return bucket_idx

    def forward(self, logits, y):
        """Negative log density (the loss); y: T x B, logits: T x B x self.num_bars."""
        bucket_idx = self.map_to_bucket_idx(y)
        inside_support = (bucket_idx >= 0).all() and (bucket_idx < self.num_bars).all()
        assert inside_support, f'y {y} not in support set for borders (min_y, max_y) {self.borders}'
        assert logits.shape[-1] == self.num_bars, f'{logits.shape[-1]} vs {self.num_bars}'

        # log mass divided by the bucket width gives the log density
        log_density = torch.log_softmax(logits, -1) - torch.log(self.bucket_widths)
        nll_loss = -log_density.gather(-1, bucket_idx.unsqueeze(-1)).squeeze(-1)
        smooth_loss = -log_density.mean(dim=-1)

        smoothing = self.smoothing if self.training else 0.
        return (1. - smoothing) * nll_loss + smoothing * smooth_loss

    def mean(self, logits):
        """Expected value: bucket probabilities times bucket midpoints."""
        return torch.softmax(logits, -1) @ self._bucket_centers()

    def icdf(self, logits, left_prob):
        """
        Implementation of the quantile function
        :param logits: Tensor of any shape, with the last dimension being logits
        :param left_prob: float: The probability mass to the left of the result.
        :return: Position with `left_prob` probability weight to the left.
        """
        probs = logits.softmax(-1)
        cumprobs = torch.cumsum(probs, -1)
        query = left_prob * torch.ones(*cumprobs.shape[:-1], 1, device=probs.device)
        # this might not do the right for outliers
        idx = torch.searchsorted(cumprobs, query).squeeze(-1).clamp(0, cumprobs.shape[-1] - 1)
        # prepend a zero so that cumprobs[..., i] is the mass left of bucket i
        leading_zero = torch.zeros(*cumprobs.shape[:-1], 1, device=logits.device)
        cumprobs = torch.cat([leading_zero, cumprobs], -1)

        mass_left_of_bucket = cumprobs.gather(-1, idx[..., None]).squeeze(-1)
        mass_in_bucket = probs.gather(-1, idx[..., None]).squeeze(-1)
        rest_prob = left_prob - mass_left_of_bucket
        left_border = self.borders[idx]
        right_border = self.borders[idx+1]
        # linear interpolation inside the selected bucket
        return left_border + (right_border - left_border) * rest_prob / mass_in_bucket

    def quantile(self, logits, center_prob=.682):
        """Stack the lower and upper end of the central ``center_prob`` interval along a new last dim."""
        side_probs = (1.-center_prob)/2
        lower = self.icdf(logits, side_probs)
        upper = self.icdf(logits, 1.-side_probs)
        return torch.stack((lower, upper), -1)

    def ucb(self, logits, best_f, rest_prob=(1-.682)/2, maximize=True):
        """
        UCB utility. Rest Prob is the amount of utility above (below) the confidence interval that is ignored.
        Higher rest_prob is equivalent to lower beta in the standard GP-UCB formulation.
        :param logits: Logits, as returned by the Transformer.
        :param best_f: Only here, since the other utilities have it.
        :param rest_prob: The amount of utility above (below) the confidence interval that is ignored.
        The default is equivalent to using GP-UCB with `beta=1`.
        To get the corresponding `beta`, where `beta` is from
        the standard GP definition of UCB `ucb_utility = mean + beta * std`,
        you can use this computation: `beta = math.sqrt(2)*torch.erfinv(torch.tensor(2*rest_prob-1))`.
        :param maximize:
        :return: utility
        """
        target_prob = 1 - rest_prob if maximize else rest_prob
        return self.icdf(logits, target_prob)

    def mode(self, logits):
        """Midpoint of the bucket with the largest logit."""
        return self._bucket_centers()[logits.argmax(-1)]

    def ei(self, logits, best_f, maximize=True):
        """Expected-improvement style utility; logits: evaluation_points x batch x feature_dim."""
        lower_edges, upper_edges = self.borders[:-1], self.borders[1:]
        if maximize:
            gains = [max((hi + max(lo, best_f)) / 2 - best_f, 0)
                     for lo, hi in zip(lower_edges, upper_edges)]
        else:
            # min on max instead of max on min, and compare min < instead of max >
            gains = [-min((min(hi, best_f) + lo) / 2 - best_f, 0)
                     for lo, hi in zip(lower_edges, upper_edges)]
        bucket_contributions = torch.tensor(gains, dtype=logits.dtype, device=logits.device)
        return torch.softmax(logits, -1) @ bucket_contributions

    def pi(self, logits, best_f, maximize=True):
        """
        Acquisition Function: Probability of Improvement
        :param logits: as returned by Transformer (evaluation_points x batch x feature_dim)
        :param best_f: best evaluation so far (the incumbent)
        :param maximize: whether to maximize
        :return: utility
        """
        assert maximize is True
        probs = torch.softmax(logits, -1)
        border_widths = self.borders[1:] - self.borders[:-1]
        # fraction of each bucket that lies above best_f
        fraction_above = 1. - ((best_f - self.borders[:-1]) / border_widths).clamp(0., 1.)
        return (probs * fraction_above).sum(-1)

    def mean_of_square(self, logits):
        """
        Computes E[x^2].
        :param logits: Output of the model.
        """
        lo = self.borders[:-1]
        hi = self.borders[1:]
        # second moment of a uniform distribution on [lo, hi]
        bucket_mean_of_square = (lo.square() + hi.square() + lo*hi)/3.
        return torch.softmax(logits, -1) @ bucket_mean_of_square

    def variance(self, logits):
        """Variance as E[x^2] - E[x]^2."""
        return self.mean_of_square(logits) - self.mean(logits).square()
|
137 |
+
|
138 |
+
|
139 |
+
class FullSupportBarDistribution(BarDistribution):
    """Bar distribution whose first and last bucket are replaced by half-normal tails.

    The left tail is anchored at ``borders[1]`` and the right tail at
    ``borders[-2]``; targets that fall outside the borders are assigned
    (by clamping) to the first / last bucket.
    """

    @staticmethod
    def halfnormal_with_p_weight_before(range_max, p=.5):
        """Half-normal scaled so that a fraction ``p`` of its mass lies below ``range_max``."""
        unit_quantile = torch.distributions.HalfNormal(torch.tensor(1.)).icdf(torch.tensor(p))
        return torch.distributions.HalfNormal(range_max / unit_quantile)

    def forward(self, logits, y):
        """Negative log density (the loss); y: T x B, logits: T x B x self.num_bars."""
        assert self.num_bars > 1
        bucket_idx = self.map_to_bucket_idx(y)
        # out-of-range targets are handled by the outer (tail) buckets
        bucket_idx.clamp_(0, self.num_bars - 1)
        assert logits.shape[-1] == self.num_bars

        log_density = torch.log_softmax(logits, -1) - torch.log(self.bucket_widths)
        log_probs = log_density.gather(-1, bucket_idx.unsqueeze(-1)).squeeze(-1)

        left_tail = self.halfnormal_with_p_weight_before(self.bucket_widths[0])
        right_tail = self.halfnormal_with_p_weight_before(self.bucket_widths[-1])

        in_first = bucket_idx == 0
        in_last = bucket_idx == self.num_bars - 1

        # TODO look over it again
        # In the outer buckets the uniform density (-log width) is swapped for the
        # half-normal density of the distance to the inner border of that bucket.
        log_probs[in_first] += left_tail.log_prob(
            (self.borders[1] - y[in_first]).clamp(min=.00000001)
        ) + torch.log(self.bucket_widths[0])
        log_probs[in_last] += right_tail.log_prob(
            y[in_last] - self.borders[-2]
        ) + torch.log(self.bucket_widths[-1])

        nll_loss = -log_probs
        smooth_loss = -log_density.mean(dim=-1)
        smoothing = self.smoothing if self.training else 0.
        return (1. - smoothing) * nll_loss + smoothing * smooth_loss

    def mean(self, logits):
        """Expected value, using the half-normal means for the two outer buckets."""
        bucket_means = self.borders[:-1] + self.bucket_widths / 2
        left_tail = self.halfnormal_with_p_weight_before(self.bucket_widths[0])
        right_tail = self.halfnormal_with_p_weight_before(self.bucket_widths[-1])
        bucket_means[0] = -left_tail.mean + self.borders[1]
        bucket_means[-1] = right_tail.mean + self.borders[-2]
        return torch.softmax(logits, -1) @ bucket_means
|
180 |
+
|
181 |
+
|
182 |
+
|
183 |
+
def get_bucket_limits_(num_outputs:int, full_range:tuple=None, ys:torch.Tensor=None, verbose:bool=False):
    """Compute ``num_outputs + 1`` bucket borders.

    If ``ys`` is given, the inner borders are placed so that every bucket holds
    the same number of the (sorted) ``ys``; otherwise ``full_range`` is split
    into ``num_outputs`` equally wide buckets.

    :param num_outputs: number of buckets.
    :param full_range: (min, max) of the support; when ``ys`` is given it must enclose all ys.
    :param ys: sample values used to estimate the borders.
    :param verbose: print the usage message and the range a second time.
    :return: 1-d tensor with ``num_outputs + 1`` sorted borders.
    """
    assert (ys is not None) or (full_range is not None)
    if ys is not None:
        ys = ys.flatten()
        # remember how many values are dropped *before* trimming, so the message is correct
        num_cut = len(ys) % num_outputs
        if num_cut:
            ys = ys[:-num_cut]
        assert len(ys) > 0, f'Need at least {num_outputs} ys to estimate {num_outputs} buckets.'
        usage_msg = f'Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {num_cut} ys.'
        print(usage_msg)
        ys_per_bucket = len(ys) // num_outputs
        if full_range is None:
            full_range = (ys.min(), ys.max())
        else:
            assert full_range[0] <= ys.min() and full_range[1] >= ys.max()
        full_range = torch.tensor(full_range)
        ys_sorted, _ = ys.sort(0)
        # each inner border is the midpoint between the last value of one bucket
        # and the first value of the next
        bucket_limits = (ys_sorted[ys_per_bucket-1::ys_per_bucket][:-1]+ys_sorted[ys_per_bucket::ys_per_bucket])/2
        if verbose:
            print(usage_msg)
            print(full_range)
        bucket_limits = torch.cat([full_range[0].unsqueeze(0), bucket_limits, full_range[1].unsqueeze(0)],0)

    else:
        class_width = (full_range[1] - full_range[0]) / num_outputs
        bucket_limits = torch.cat([full_range[0] + torch.arange(num_outputs).float()*class_width, torch.tensor(full_range[1]).unsqueeze(0)], 0)

    assert len(bucket_limits) - 1 == num_outputs and full_range[0] == bucket_limits[0] and full_range[-1] == bucket_limits[-1]
    return bucket_limits
|
208 |
+
|
209 |
+
|
210 |
+
def get_bucket_limits(
    num_outputs: int,
    full_range: tuple = None,
    ys: torch.Tensor = None,
    verbose: bool = False,
):
    """Compute ``num_outputs + 1`` bucket borders from samples or from a range.

    Exactly one of ``ys`` and ``full_range`` must be passed. With ``ys`` the
    inner borders are chosen so that every bucket holds the same number of the
    (sorted, NaN-free) values; with ``full_range`` the range is split into
    equally wide buckets.

    :param num_outputs: number of buckets.
    :param full_range: (min, max) of the support.
    :param ys: sample values used to estimate the borders; NaNs are ignored.
    :param verbose: print the usage message and the range a second time.
    :return: 1-d tensor with ``num_outputs + 1`` sorted borders.
    """
    assert (ys is None) != (
        full_range is None
    ), "Either full_range or ys must be passed."

    if ys is not None:
        ys = ys.flatten()
        ys = ys[~torch.isnan(ys)]
        # remember how many values are dropped *before* trimming, so the message is correct
        num_cut = len(ys) % num_outputs
        if num_cut:
            ys = ys[:-num_cut]
        assert (
            len(ys) > 0
        ), f"Need at least {num_outputs} non-NaN ys to estimate {num_outputs} buckets."
        usage_msg = f"Using {len(ys)} y evals to estimate {num_outputs} buckets. Cut off the last {num_cut} ys."
        print(usage_msg)
        ys_per_bucket = len(ys) // num_outputs
        if full_range is None:
            full_range = (ys.min(), ys.max())
        else:
            # not reachable while the assertion at the top enforces "exactly one of"
            assert (
                full_range[0] <= ys.min() and full_range[1] >= ys.max()
            ), f"full_range {full_range} not in range of ys {ys.min(), ys.max()}"
        full_range = torch.tensor(full_range)
        ys_sorted, _ = ys.sort(0)
        # each inner border is the midpoint between the last value of one bucket
        # and the first value of the next
        bucket_limits = (
            ys_sorted[ys_per_bucket - 1 :: ys_per_bucket][:-1]
            + ys_sorted[ys_per_bucket::ys_per_bucket]
        ) / 2
        if verbose:
            print(usage_msg)
            print(full_range)
        bucket_limits = torch.cat(
            [full_range[0].unsqueeze(0), bucket_limits, full_range[1].unsqueeze(0)], 0
        )

    else:
        class_width = (full_range[1] - full_range[0]) / num_outputs
        bucket_limits = torch.cat(
            [
                full_range[0] + torch.arange(num_outputs).float() * class_width,
                torch.tensor(full_range[1]).unsqueeze(0),
            ],
            0,
        )

    assert (
        len(bucket_limits) - 1 == num_outputs
    ), f"len(bucket_limits) - 1 == {len(bucket_limits) - 1} != {num_outputs} == num_outputs"
    assert full_range[0] == bucket_limits[0], f"{full_range[0]} != {bucket_limits[0]}"
    assert (
        full_range[-1] == bucket_limits[-1]
    ), f"{full_range[-1]} != {bucket_limits[-1]}"

    return bucket_limits
|
269 |
+
|
lcpfn/curves.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from collections import OrderedDict
|
3 |
+
|
4 |
+
# Parameter priors per curve family and prior flavour ("uniform" / "peaked").
# Every parameter entry names a sampler type and its two arguments.
prior = {
    "pow3": {
        "uniform": OrderedDict([
            ("a", {"type": "uniform", "param1": -1, "param2": 1}),
            ("c", {"type": "uniform", "param1": 0, "param2": 1}),
            ("alpha", {"type": "uniform", "param1": 0, "param2": 1}),
        ]),
        "peaked": OrderedDict([
            ("a", {"type": "uniform", "param1": -0.6, "param2": 0.6}),
            ("c", {"type": "uniform", "param1": 0, "param2": 1.25}),
            ("alpha", {"type": "log_normal", "param1": 0, "param2": 2}),
        ]),
    },
    "ilog2": {
        "uniform": OrderedDict([
            ("c", {"type": "uniform", "param1": 0, "param2": 1}),
            ("a", {"type": "uniform", "param1": -1, "param2": 1}),
        ]),
        "peaked": OrderedDict([
            ("c", {"type": "uniform", "param1": 0, "param2": 1}),
            ("a", {"type": "uniform", "param1": -0.5, "param2": 0.5}),
        ]),
    },
    "janoschek": {
        "uniform": OrderedDict([
            ("a", {"type": "uniform", "param1": 0, "param2": 1}),
            ("beta", {"type": "uniform", "param1": 0, "param2": 2}),
            ("k", {"type": "uniform", "param1": 0, "param2": 1}),
            ("delta", {"type": "uniform", "param1": -5, "param2": 5}),
        ]),
        "peaked": OrderedDict([
            ("a", {"type": "uniform", "param1": 0, "param2": 1}),
            ("beta", {"type": "uniform", "param1": 0, "param2": 2}),
            ("k", {"type": "log_normal", "param1": -2, "param2": 1}),
            ("delta", {"type": "log_normal", "param1": 0, "param2": 0.5}),
        ]),
    },
}
|
42 |
+
|
43 |
+
|
44 |
+
def prior_sampler(rng, type, param1, param2):
    """Draw one value from ``rng`` using the sampler named by ``type``.

    ``"uniform"`` calls ``rng.uniform(param1, param2)`` and ``"log_normal"``
    calls ``rng.lognormal(param1, param2)``; any other name raises.
    """
    if type not in ("uniform", "log_normal"):
        raise Exception("Unknown prior type: {}".format(type))
    draw = rng.uniform if type == "uniform" else rng.lognormal
    return draw(param1, param2)
|
50 |
+
|
51 |
+
|
52 |
+
def pow3(x, c, a, alpha):
    """Power-law curve ``c - a * x^(-alpha)``."""
    decay = x ** (-alpha)
    return c - a * decay
|
54 |
+
|
55 |
+
|
56 |
+
def prior_pow3(rng):
    """Sample the pow3 parameters ``a``, ``c``, ``alpha`` (in that order) from the "peaked" prior table."""
    table = prior["pow3"]["peaked"]
    sampled = {}
    for name in ("a", "c", "alpha"):
        spec = table[name]
        sampled[name] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return sampled
|
66 |
+
|
67 |
+
|
68 |
+
def uniform_prior_pow3(rng):
    """Sample the pow3 parameters ``a``, ``c``, ``alpha`` (in that order) from the "uniform" prior table."""
    table = prior["pow3"]["uniform"]
    sampled = {}
    for name in ("a", "c", "alpha"):
        spec = table[name]
        sampled[name] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return sampled
|
78 |
+
|
79 |
+
|
80 |
+
def ilog2(x, c, a):
    """Inverse-log curve ``c - a / log(x + 1)``."""
    log_term = np.log(x + 1)
    return c - a / log_term
|
82 |
+
|
83 |
+
|
84 |
+
def prior_ilog2(rng):
    """Sample the ilog2 parameters ``a``, ``c`` (in that order) from the "peaked" prior table."""
    table = prior["ilog2"]["peaked"]
    sampled = {}
    for name in ("a", "c"):
        spec = table[name]
        sampled[name] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return sampled
|
94 |
+
|
95 |
+
|
96 |
+
def uniform_prior_ilog2(rng):
    """Sample the ilog2 parameters ``a``, ``c`` (in that order) from the "uniform" prior table."""
    table = prior["ilog2"]["uniform"]
    sampled = {}
    for name in ("a", "c"):
        spec = table[name]
        sampled[name] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return sampled
|
106 |
+
|
107 |
+
|
108 |
+
def janoschek(x, a, beta, k, delta):
    """
    Janoschek growth curve ``a - (a - beta) * exp(-k * x^delta)``.
    http://www.pisces-conservation.com/growthhelp/janoschek.htm
    """
    gap = a - beta
    return a - gap * np.exp(-k * x**delta)
|
113 |
+
|
114 |
+
|
115 |
+
def prior_janoschek(rng):
    """Sample the janoschek parameters ``a``, ``beta``, ``k``, ``delta`` (in that order) from the "peaked" prior table."""
    table = prior["janoschek"]["peaked"]
    sampled = {}
    for name in ("a", "beta", "k", "delta"):
        spec = table[name]
        sampled[name] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return sampled
|
125 |
+
|
126 |
+
|
127 |
+
def uniform_prior_janoschek(rng):
    """Sample the janoschek parameters ``a``, ``beta``, ``k``, ``delta`` (in that order) from the "uniform" prior table."""
    table = prior["janoschek"]["uniform"]
    sampled = {}
    for name in ("a", "beta", "k", "delta"):
        spec = table[name]
        sampled[name] = prior_sampler(
            rng, spec["type"], param1=spec["param1"], param2=spec["param2"]
        )
    return sampled
|
137 |
+
|
138 |
+
|
139 |
+
def log_power(x, a, b, c):
    """Log-power curve ``a / (1 + (x / e^b)^c)``.

    a: upper bound
    c: growth rate
    initial = a / (1 + (1/e^b)^c)
    """
    scaled_x = x / np.exp(b)
    return a / (1.0 + scaled_x ** c)
|
144 |
+
|
145 |
+
|
146 |
+
def prior_log_power(rng):
    """Sample log_power parameters: a ~ N(0.8, 0.1), b ~ N(1, 1), c ~ U(-3, 0)."""
    # the dict literal is evaluated top to bottom, so the draw order is a, b, c
    return {
        "a": rng.normal(0.8, 0.1),
        "b": rng.normal(1.0, 1.0),
        "c": rng.uniform(-3.0, 0.0),
    }
|
154 |
+
|
155 |
+
|
156 |
+
def weibull(x, alpha, beta, kappa, delta):
    """
    Weibull model ``alpha - (alpha - beta) * exp(-(kappa * x)^delta)``
    http://www.pisces-conservation.com/growthhelp/index.html?morgan_mercer_floden.htm
    alpha: upper asymptote
    beta: lower asymptote
    kappa: growth rate
    delta: controls the x-ordinate for the point of inflection
    """
    exponent = -((kappa * x) ** delta)
    return alpha - (alpha - beta) * np.exp(exponent)
|
166 |
+
|
167 |
+
|
168 |
+
def prior_weibull(rng):
    """Sample weibull parameters: alpha ~ U(0, 1.5), beta ~ U(0, 1), ln(kappa) ~ N(-2, 1), ln(delta) ~ N(0, 0.5)."""
    # the dict literal is evaluated top to bottom, so the draw order is alpha, beta, kappa, delta
    return {
        "alpha": rng.uniform(0.0, 1.5),
        "beta": rng.uniform(0.0, 1),
        "kappa": np.exp(rng.normal(-2.0, 1.0)),
        "delta": np.exp(rng.normal(0, 0.5)),
    }
|
174 |
+
|
175 |
+
|
176 |
+
def mmf(x, alpha, beta, kappa, delta):
    """
    Morgan-Mercer-Flodin curve ``alpha - (alpha - beta) / (1 + (kappa * x)^delta)``
    description:
    Nonlinear Regression page 342
    http://bit.ly/1jodG17
    http://www.pisces-conservation.com/growthhelp/index.html?morgan_mercer_floden.htm
    alpha: upper asymptote
    kappa: growth rate
    beta: initial value
    delta: controls the point of inflection
    """
    denominator = 1.0 + (kappa * x) ** delta
    return alpha - (alpha - beta) / denominator
|
189 |
+
|
190 |
+
|
191 |
+
def prior_mmf(rng):
    """Sample mmf parameters: alpha ~ N(0.8, 0.1), beta ~ N(0.2, 0.1), ln(kappa) ~ N(0, 2), ln(delta) ~ N(0, 1)."""
    # the dict literal is evaluated top to bottom, so the draw order is alpha, beta, kappa, delta
    return {
        "alpha": rng.normal(0.8, 0.1),
        "beta": rng.normal(0.2, 0.1),
        "kappa": np.exp(rng.normal(0, 2)),
        "delta": np.exp(rng.normal(0, 1)),
    }
|
201 |
+
|
202 |
+
|
203 |
+
def vap(x, a, b, c):
    """Vapor pressure model ``exp(a + b / x + c * log(x))``.

    no upper bound if c > 0
    a = ln(upper bound) for c=0
    a+b = ln(initial)
    """
    exponent = a + b / x + c * np.log(x)
    return np.exp(exponent)
|
209 |
+
|
210 |
+
|
211 |
+
def prior_vap(rng):
    """Sample vap parameters: a ~ U(-2, 0), b ~ U(-4, 0), ln(c) ~ U(-8, 0)."""
    # the dict literal is evaluated top to bottom, so the draw order is a, b, c
    return {
        "a": rng.uniform(-2.0, 0.0),  # @heri: range check
        "b": rng.uniform(-4.0, 0.0),  # @heri: range check
        "c": np.exp(rng.uniform(-8.0, 0.0)),  # @heri: same as weights
    }
|
216 |
+
|
217 |
+
|
218 |
+
def loglog_linear(x, a, b):
    """Log of a linear function of ``log(x)``: ``log(a * log(x) + b)``."""
    log_x = np.log(x)
    return np.log(a * log_x + b)
|
221 |
+
|
222 |
+
|
223 |
+
def prior_loglog_linear(rng):
    """Sample loglog_linear parameters: ln(a) ~ N(-2, 1), ln(b) ~ U(0, 1)."""
    # the dict literal is evaluated top to bottom, so the draw order is a, b
    return {
        "a": np.exp(rng.normal(-2.0, 1.0)),
        "b": np.exp(rng.uniform(0.0, 1.0)),
    }
|
229 |
+
|
230 |
+
|
231 |
+
def exp4(x, c, a, b, alpha):
    """Exponential curve ``c - exp(-a * x^alpha + b)``."""
    exponent = -a * (x**alpha) + b
    return c - np.exp(exponent)
|
233 |
+
|
234 |
+
|
235 |
+
def prior_exp4(rng):
    """Sample exp4 parameters: c ~ N(0.8, 0.1), ln(a) ~ N(-2, 1), ln(alpha) ~ N(0, 1), ln(b) ~ N(0, 0.5)."""
    # draw order is c, a, alpha, b (tuple elements are evaluated left to right)
    c = rng.normal(0.8, 0.1)
    log_a, log_alpha, log_b = rng.normal(-2, 1), rng.normal(0, 1), rng.normal(0, 0.5)
    return {"a": np.exp(log_a), "b": np.exp(log_b), "c": c, "alpha": np.exp(log_alpha)}
|
245 |
+
|
246 |
+
|
247 |
+
def pow4(x, c, a, b, alpha):
    """Power-law curve ``c - (a * x + b)^(-alpha)``."""
    base = a * x + b
    return c - base ** -alpha
|
249 |
+
|
250 |
+
|
251 |
+
def prior_pow4(rng):
    """Sample pow4 parameters: ln(1 - c) ~ U(-5, 0), ln(a) ~ N(-3, 2), ln(alpha) ~ N(0, 1), ln(b) ~ U(0, 1)."""
    # draw order is c, a, alpha, b (tuple elements are evaluated left to right)
    log_one_minus_c, log_a, log_alpha, log_b = (
        rng.uniform(-5.0, 0),
        rng.normal(-3.0, 2),
        rng.normal(0, 1),
        rng.uniform(0, 1),
    )
    return {
        "a": np.exp(log_a),
        "b": np.exp(log_b),
        "c": 1 - np.exp(log_one_minus_c),
        "alpha": np.exp(log_alpha),
    }
|
261 |
+
|
262 |
+
|
263 |
+
def dr_hill_zero_background(x, theta, eta, kappa):
    """Hill curve ``theta * x^eta / (kappa^eta + x^eta)``.

    theta: upper bound
    eta: growth rate
    initial = theta/(kappa^eta + 1)
    """
    x_pow = x**eta
    return (theta * x_pow) / (kappa**eta + x_pow)
|
268 |
+
|
269 |
+
|
270 |
+
def prior_dr_hill_zero_background(rng):
    """Sample Hill-curve parameters: theta ~ N(0.8, 0.1), ln(eta) ~ N(1, 1), ln(kappa) ~ N(1, 2)."""
    # the dict literal is evaluated top to bottom, so the draw order is theta, eta, kappa
    return {
        "theta": rng.normal(0.8, 0.1),
        "eta": np.exp(rng.normal(1.0, 1.0)),
        "kappa": np.exp(rng.normal(1.0, 2.0)),
    }
|
lcpfn/decoders.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import random
|
4 |
+
|
5 |
+
|
6 |
+
class ScaledDecoder(nn.Module):
    """Two-layer decoder whose output is divided by a predicted temperature.

    A hidden layer feeds two heads: ``linear1`` produces the ``nout`` outputs and
    ``linear2`` produces 10 logits whose softmax weights a fixed grid of
    temperatures; the outputs are divided by the resulting temperature.
    """

    def __init__(self, ninp, nhid, nout):
        """
        :param ninp: size of the input features.
        :param nhid: size of the hidden layer.
        :param nout: number of outputs.
        """
        super().__init__()
        self.linear = nn.Linear(ninp, nhid)
        self.linear1 = nn.Linear(nhid, nout)
        self.linear2 = nn.Linear(nhid, 10)

    def forward(self, x):
        """Return ``linear1(h) / temperature(h)`` with ``h = GELU(linear(x))``; last dim of x is ``ninp``."""
        x = self.linear(x)
        x = nn.GELU()(x)
        grid_weights = self.linear2(x).softmax(-1)
        temperature_grid = torch.tensor(
            [1.,1.4,1.7,2.,5.,10.,20.,40.,80.,160.],
            device=grid_weights.device,
            dtype=grid_weights.dtype,
        )
        temps = grid_weights @ temperature_grid
        # NOTE: the former random debug print (`temps[:, :2]`) was removed; it raised an
        # IndexError for inputs with a single leading dimension and consumed the global
        # `random` stream on every call.
        return self.linear1(x) / temps.unsqueeze(-1)
|
21 |
+
|
22 |
+
class FixedScaledDecoder(nn.Module):
    """MLP decoder whose output is divided by one learned, input-independent temperature.

    The temperature is the sum of the parameter vector ``T`` (10000 entries
    initialised to 1/10000).
    """

    def __init__(self, ninp, nhid, nout):
        super().__init__()
        layers = [nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, nout)]
        self.mapper = nn.Sequential(*layers)
        self.T = nn.Parameter(torch.ones(10000)/10000)

    def forward(self, x):
        temperature = self.T.sum()
        return self.mapper(x)/temperature
|
30 |
+
|
lcpfn/domhan_prior.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from lcpfn.curves import (
|
5 |
+
pow3,
|
6 |
+
ilog2,
|
7 |
+
janoschek,
|
8 |
+
log_power,
|
9 |
+
prior_ilog2,
|
10 |
+
uniform_prior_pow3,
|
11 |
+
weibull,
|
12 |
+
mmf,
|
13 |
+
vap,
|
14 |
+
loglog_linear,
|
15 |
+
exp4,
|
16 |
+
pow4,
|
17 |
+
dr_hill_zero_background,
|
18 |
+
)
|
19 |
+
from lcpfn.curves import (
|
20 |
+
prior_pow3,
|
21 |
+
prior_janoschek,
|
22 |
+
prior_log_power,
|
23 |
+
prior_weibull,
|
24 |
+
prior_mmf,
|
25 |
+
prior_vap,
|
26 |
+
prior_loglog_linear,
|
27 |
+
prior_exp4,
|
28 |
+
prior_pow4,
|
29 |
+
prior_dr_hill_zero_background,
|
30 |
+
)
|
31 |
+
from lcpfn.curves import (
|
32 |
+
uniform_prior_pow3,
|
33 |
+
uniform_prior_ilog2,
|
34 |
+
uniform_prior_janoschek,
|
35 |
+
)
|
36 |
+
|
37 |
+
|
38 |
+
def prior_weights(
    rng,
    components=(
        "pow3",
        "ilog2",
        "janoschek",
        "log_power",
        "weibull",
        "mmf",
        "vap",
        "loglog_linear",
        "exp4",
        "pow4",
        "dr_hill_zero_background",
    ),
):
    """Draw one U(0, 1) mixture weight per component.

    :param rng: random generator providing ``uniform(low, high, size=...)``.
    :param components: names of the curve families to weight (any sequence).
    :return: dict mapping each component name to its sampled weight.
    """
    # the default is an immutable tuple so it can never be altered between calls
    weights = rng.uniform(0.0, 1, size=(len(components),))
    return dict(zip(components, weights))
|
57 |
+
|
58 |
+
|
59 |
+
def sample_from_prior(rng, seq_len=100):
    """Sample a curve generator from the peaked pow3 / ilog2 / janoschek combination prior."""
    basis_curves = ["pow3", "ilog2", "janoschek"]
    return sample_prior_comb(
        rng=rng,
        seq_len=seq_len,
        components=basis_curves,
        distribution="peaked",
    )
|
63 |
+
|
64 |
+
|
65 |
+
def sample_prior_comb(
    rng,
    components,
    distribution,
    var_lnloc=-4,
    var_lnscale=1,
    range_constraint=True,
    seq_len=100,
):
    """Sample a learning curve as a weighted sum of parametric basis curves.

    Candidate curves are drawn until one is accepted; a candidate is rejected
    when its last value is not larger than its first, when it contains NaNs, or
    (with ``range_constraint``) when it leaves [0, 1].

    :param rng: numpy-style random generator.
    :param components: names of the basis curves to combine.
    :param distribution: ``"peaked"`` or ``"uniform"``; the uniform priors only
        exist for pow3, ilog2 and janoschek.
    :param var_lnloc: mean of the normal from which the log noise level is drawn.
    :param var_lnscale: std of the normal from which the log noise level is drawn.
    :param range_constraint: reject curves with values outside [0, 1].
    :param seq_len: number of points; the curve is evaluated at x = 1..seq_len.
    :return: a function without arguments returning ``(y, y_noisy)``; each call
        adds fresh noise to the same noiseless curve ``y``.
    :raises NotImplementedError: for an unknown ``distribution``.
    """
    # fail fast, before any sampling work is done
    if distribution not in ("peaked", "uniform"):
        raise NotImplementedError(f"Unknown distribution: {distribution!r}")

    f_components = {
        "pow3": pow3,
        "ilog2": ilog2,
        "janoschek": janoschek,
        "log_power": log_power,
        "weibull": weibull,
        "mmf": mmf,
        "vap": vap,
        "loglog_linear": loglog_linear,
        "exp4": exp4,
        "pow4": pow4,
        "dr_hill_zero_background": dr_hill_zero_background,
    }

    if distribution == "peaked":
        f_priors = {
            "pow3": prior_pow3,
            "ilog2": prior_ilog2,
            "janoschek": prior_janoschek,
            "log_power": prior_log_power,
            "weibull": prior_weibull,
            "mmf": prior_mmf,
            "vap": prior_vap,
            "loglog_linear": prior_loglog_linear,
            "exp4": prior_exp4,
            "pow4": prior_pow4,
            "dr_hill_zero_background": prior_dr_hill_zero_background,
        }
    else:  # "uniform"
        f_priors = {
            "pow3": uniform_prior_pow3,
            "ilog2": uniform_prior_ilog2,
            "janoschek": uniform_prior_janoschek,
        }

    x = np.arange(1, seq_len + 1)

    while True:
        # sample the noiseless curve
        weights = prior_weights(rng, components=components)
        y = np.zeros(x.shape, dtype="float")
        for f, w in weights.items():
            kwargs = f_priors[f](rng)
            y += w * f_components[f](x, **kwargs)
        # add noise (can exceed [0,1], but afaik no way to implement this prior in Tobis work)
        # NOTE: despite its name, `var` is passed to rng.normal as the scale (std) below.
        # It is drawn inside the loop, so rejected candidates also consume a draw.
        var = np.exp(
            rng.normal(var_lnloc, var_lnscale)
        )  # @heri: ln_prob =+ log(normal.pdf(log(var), loc=var_lnloc, scale=var_lnscale))

        # reject any curves that are non-increasing, exceed the [0,1] range
        rejected = (
            y[-1] <= y[0]
            or (range_constraint and (np.any(y < 0) or np.any(y > 1)))
            or np.isnan(y).any()
        )
        if not rejected:
            break

    def curve():  # generates a sample from the same model, but with independent noise
        y_noisy = y + rng.normal(np.zeros_like(y), var)
        return y, y_noisy

    return curve
|
142 |
+
|
143 |
+
|
144 |
+
def generate_prior_dataset(n, prior=None, seed=42):
    """
    Returns a fixed sample from the prior (with fixed seq_len) as an n x seq_len np.ndarray

    :param n: number of curves to sample.
    :param prior: callable ``prior(rng)`` returning a curve function whose call
        yields ``(y, y_noisy)``. Defaults to ``sample_from_prior``; the previous
        default ``sample_prior_comb`` could not be called with ``rng`` alone.
    :param seed: seed of the ``np.random.RandomState`` used for sampling.
    :return: array holding the noisy curves, one per row.
    """
    if prior is None:
        prior = sample_from_prior
    rng = np.random.RandomState(seed)
    # index 1 of the returned pair is the noisy curve
    prior_data = np.stack([prior(rng)()[1] for _ in range(n)])
    return prior_data
|
151 |
+
|
152 |
+
|
153 |
+
def create_get_batch_func(prior):
    """Return ``get_batch_domhan`` with its ``prior`` argument bound to ``prior``."""
    bound_get_batch = partial(get_batch_domhan, prior=prior)
    return bound_get_batch
|
155 |
+
|
156 |
+
# function producing batches for PFN training
|
157 |
+
def get_batch_domhan(
    batch_size,
    seq_len,
    num_features,
    prior,
    device="cpu",
    noisy_target=True,
    **_,
):
    """Sample a batch of learning curves for PFN training.

    :param batch_size: number of curves in the batch.
    :param seq_len: number of points per curve.
    :param num_features: must be 1 (the only feature is the position 1..seq_len).
    :param prior: callable ``prior(rng, seq_len=...)`` returning a curve function
        whose call yields ``(y, y_noisy)``.
    :param device: device the returned tensors are moved to.
    :param noisy_target: if True the target equals the noisy curve, otherwise the noiseless one.
    :return: float tensors ``(x, y_noisy, y_target)`` of shape
        (seq_len, batch_size, num_features), (seq_len, batch_size), (seq_len, batch_size).
    """
    assert num_features == 1

    y_target = np.empty((batch_size, seq_len), dtype=float)
    y_noisy = np.empty((batch_size, seq_len), dtype=float)

    for i in range(batch_size):
        curve_func = prior(np.random, seq_len=seq_len)  # uses numpy rng
        clean, noisy = curve_func()
        y_noisy[i] = noisy
        y_target[i] = noisy if noisy_target else clean

    # turn numpy arrays into correctly shaped torch tensors & move them to device
    x = (
        torch.arange(1, seq_len + 1)
        .repeat((num_features, batch_size, 1))
        .transpose(2, 0)
        .to(device)
    )
    y_target = torch.from_numpy(y_target).transpose(1, 0).to(device)
    y_noisy = torch.from_numpy(y_noisy).transpose(1, 0).to(device)

    return x.float(), y_noisy.float(), y_target.float()
|
lcpfn/encoders.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from lcpfn.utils import normalize_data
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
8 |
+
|
9 |
+
|
10 |
+
class StyleEncoder(nn.Module):
    """Linear embedding of a hyperparameter vector into ``em_size`` dimensions."""

    def __init__(self, em_size, hyperparameter_definitions):
        """
        :param em_size: size of the produced embedding.
        :param hyperparameter_definitions: object whose ``shape[0]`` is the number of hyperparameters.
        """
        super().__init__()
        self.em_size = em_size
        num_hps = hyperparameter_definitions.shape[0]
        self.embedding = nn.Linear(num_hps, self.em_size)

    def forward(self, hyperparameters):  # T x B x num_hps
        embedded = self.embedding(hyperparameters)
        return embedded
|
18 |
+
|
19 |
+
|
20 |
+
class _PositionalEncoding(nn.Module):
    """Sinusoidal encoding of continuous feature values.

    Every feature is expanded into ``d_model // num_features`` interleaved
    sin/cos values at geometrically growing frequencies.
    """

    def __init__(self, d_model, dropout=0.):
        """
        :param d_model: size of the produced encoding (summed over all features).
        :param dropout: dropout probability applied to the encoding.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.d_model = d_model
        # only used to find out on which device the module lives
        self.device_test_tensor = nn.Parameter(torch.tensor(1.))

    def forward(self, x):# T x B x num_features
        num_features = x.shape[-1]
        # every feature needs an even number of slots (sin/cos pairs); the former check
        # `d_model % num_features*2 == 0` parsed as `(d_model % num_features) * 2 == 0`
        assert self.d_model % (num_features * 2) == 0, \
            f'd_model ({self.d_model}) must be divisible by 2 * num_features ({2 * num_features})'
        d_per_feature = self.d_model // num_features
        device = self.device_test_tensor.device
        pe = torch.zeros(*x.shape, d_per_feature, device=device)
        interval_size = 10
        # angular frequencies: 2*pi/interval_size, multiplied by 2 from one sin/cos pair to the next
        div_term = (1./interval_size) * 2*math.pi*torch.exp(torch.arange(0, d_per_feature, 2, device=device).float()*math.log(math.sqrt(2)))
        angles = x.unsqueeze(-1) * div_term
        pe[..., 0::2] = torch.sin(angles)
        pe[..., 1::2] = torch.cos(angles)
        return self.dropout(pe).view(x.shape[0],x.shape[1],self.d_model)
|
38 |
+
|
39 |
+
|
40 |
+
# Encoder factory with the usual (num_features, emsize) call shape: the first
# argument is ignored and emsize becomes the d_model of a _PositionalEncoding.
Positional = lambda _, emsize: _PositionalEncoding(d_model=emsize)
|
41 |
+
|
42 |
+
class EmbeddingEncoder(nn.Module):
    """Encode real-valued features by binning them and averaging bin embeddings.

    Each feature owns ``num_embs`` consecutive rows of one shared embedding
    table. A value is mapped to one of ``num_embs`` equal-width bins spanning
    ``min_max`` (values outside that range fall into the first/last bin) and
    the embeddings of all features are averaged.

    Args:
        num_features: number of input features.
        em_size: embedding size.
        num_embs: number of bins (table rows) per feature.
    """

    def __init__(self, num_features, em_size, num_embs=100):
        super().__init__()
        self.num_embs = num_embs
        # NOTE(review): max_norm=True is kept as in the original; presumably a
        # norm bound of 1.0 is intended -- confirm.
        self.embeddings = nn.Embedding(num_embs * num_features, em_size, max_norm=True)
        self.init_weights(.1)
        self.min_max = (-2, +2)

    @property
    def width(self):
        """Length of the value range that is split into bins."""
        lower, upper = self.min_max
        return upper - lower

    def init_weights(self, initrange):
        """Fill the table uniformly from [-initrange, initrange]."""
        self.embeddings.weight.data.uniform_(-initrange, initrange)

    def discretize(self, x):
        """Return the bin index (int tensor in [0, num_embs - 1]) of every value."""
        split_size = self.width / self.num_embs
        # Shift by the lower bound *before* dividing. The previous expression
        # `x - self.min_max[0] // split_size` floor-divided only the bound, so
        # nearly all inputs ended up in a handful of bins around the middle.
        bin_idx = torch.floor((x - self.min_max[0]) / split_size)
        return bin_idx.int().clamp(0, self.num_embs - 1)

    def forward(self, x):  # T x B x num_features
        # Offset feature i by i * num_embs so every feature uses its own rows.
        offsets = torch.arange(x.shape[-1], device=x.device).view(1, 1, -1) * self.num_embs
        x_idxs = self.discretize(x) + offsets
        return self.embeddings(x_idxs).mean(-2)
|
66 |
+
|
67 |
+
|
68 |
+
class Normalize(nn.Module):
    """Standardize inputs with fixed statistics: ``(x - mean) / std``."""

    def __init__(self, mean, std):
        super().__init__()
        self.mean, self.std = mean, std

    def forward(self, x):
        centered = x - self.mean
        return centered / self.std
|
76 |
+
|
77 |
+
|
78 |
+
def get_normalized_uniform_encoder(encoder_creator):
    """Wrap an encoder factory so its input is standardized first.

    Intended for encoders fed with uniform samples in [0, 1]: the inputs are
    shifted by 0.5 and divided by sqrt(1/12) (mean and std of that
    distribution) before reaching the wrapped encoder. Example:
    ``encoder_creator = get_normalized_uniform_encoder(encoders.Linear)``,
    afterwards initialized with ``encoder_creator(feature_dim, in_dim)``.

    :param encoder_creator: callable ``(in_dim, out_dim) -> nn.Module``.
    :return: callable with the same ``(in_dim, out_dim)`` signature.
    """
    def create_encoder(in_dim, out_dim):
        standardize = Normalize(.5, math.sqrt(1/12))
        return nn.Sequential(standardize, encoder_creator(in_dim, out_dim))

    return create_encoder
|
87 |
+
|
88 |
+
|
89 |
+
# NOTE(review): this alias is rebound by `class Linear(nn.Linear)` defined
# further down in this module, so importers see the class, not nn.Linear.
Linear = nn.Linear
|
90 |
+
# Two-layer perceptron encoder factory. Note that the first layer expects
# num_features + 1 inputs and the hidden width is twice the embedding size.
MLP = lambda num_features, emsize: nn.Sequential(
    nn.Linear(num_features + 1, emsize * 2),
    nn.ReLU(),
    nn.Linear(emsize * 2, emsize),
)
|
93 |
+
|
94 |
+
class NanHandlingEncoder(nn.Module):
    """Linear encoder that tolerates NaN inputs.

    NaNs are replaced by 0 before the linear layer. With ``keep_nans`` the
    input is additionally concatenated with a normalized indicator tensor
    (-1 for NaN, 1 for +inf, 2 for -inf, 0 otherwise), which doubles the width
    seen by the linear layer.
    """

    def __init__(self, num_features, emsize, keep_nans=True):
        super().__init__()
        self.keep_nans = keep_nans
        self.emsize = emsize
        # The indicator channels double the input width.
        self.num_features = 2 * num_features if keep_nans else num_features
        self.layer = nn.Linear(self.num_features, self.emsize)

    def forward(self, x):
        filled = torch.nan_to_num(x, nan=0.0)
        if not self.keep_nans:
            return self.layer(filled)

        is_inf = torch.isinf(x)
        sign = torch.sign(x)
        indicator = (torch.isnan(x) * -1
                     + torch.logical_and(is_inf, sign == 1) * 1
                     + torch.logical_and(is_inf, sign == -1) * 2)
        return self.layer(torch.cat([filled, normalize_data(indicator)], -1))
|
111 |
+
|
112 |
+
|
113 |
+
class Linear(nn.Linear):
    """``nn.Linear`` with the (num_features, emsize) encoder signature that zeroes NaN inputs."""

    def __init__(self, num_features, emsize):
        super().__init__(num_features, emsize)
        self.num_features, self.emsize = num_features, emsize

    def forward(self, x):
        # Replace NaNs by 0 so they do not propagate through the projection.
        cleaned = torch.nan_to_num(x, nan=0.0)
        return super().forward(cleaned)
|
122 |
+
|
123 |
+
|
124 |
+
class Conv(nn.Module):
    """Encode flattened square images with a small CNN followed by a linear layer.

    The last input dimension must be a perfect square; it is reshaped to a
    one-channel ``size x size`` image. Up to five 3x3 convolutions (64
    channels, ReLU) are applied while the spatial side is at least 4, the
    result is average-pooled to one value per channel and projected to
    ``emsize``.

    Any number of leading dimensions is supported (e.g. T x B x features);
    they are flattened for the convolutions and restored on the output.

    NOTE: if the side is below 4 no convolution runs, the tensor still has a
    single channel, and the final 64-input linear layer will reject it.

    Args:
        input_size: accepted for signature compatibility; not used.
        emsize: output embedding size.
    """

    def __init__(self, input_size, emsize):
        super().__init__()
        self.convs = torch.nn.ModuleList([nn.Conv2d(64 if i else 1, 64, 3) for i in range(5)])
        self.linear = nn.Linear(64, emsize)

    def forward(self, x):
        size = math.isqrt(x.shape[-1])
        assert size*size == x.shape[-1]
        lead_shape = x.shape[:-1]
        # Conv2d only accepts 3-D/4-D input, so collapse all leading dims into
        # a single batch dimension instead of keeping them separate.
        x = x.reshape(-1, 1, size, size)
        for conv in self.convs:
            if x.shape[-1] < 4:
                break
            x = conv(x)
            x.relu_()
        x = nn.AdaptiveAvgPool2d((1, 1))(x).squeeze(-1).squeeze(-1)
        return self.linear(x).reshape(*lead_shape, -1)
|
141 |
+
|
142 |
+
|
143 |
+
class CanEmb(nn.Embedding):
    """Embedding for whole-number features, concatenating one embedding per feature.

    The requested ``embedding_dim`` is split evenly across ``num_features``;
    each feature value is looked up in the shared table and the per-feature
    vectors are concatenated along the last dimension.
    """

    def __init__(self, num_features, num_embeddings: int, embedding_dim: int, *args, **kwargs):
        assert embedding_dim % num_features == 0
        per_feature_dim = embedding_dim // num_features
        super().__init__(num_embeddings, per_feature_dim, *args, **kwargs)

    def forward(self, x):
        indices = x.long()
        assert (indices == x).all(), "CanEmb only works with tensors of whole numbers"
        embedded = super().forward(indices)
        # Merge the trailing (num_features, per_feature_dim) axes into one.
        return embedded.view(*embedded.shape[:-2], -1)
|
154 |
+
|
155 |
+
|
156 |
+
def get_Canonical(num_classes):
    """Return an encoder factory ``(num_features, emsize) -> CanEmb`` with ``num_classes`` table rows."""
    def create_encoder(num_features, emsize):
        return CanEmb(num_features, num_classes, emsize)

    return create_encoder
|
158 |
+
|
159 |
+
|
160 |
+
def get_Embedding(num_embs_per_feature=100):
    """Return an encoder factory ``(num_features, emsize) -> EmbeddingEncoder`` using the given bin count."""
    def create_encoder(num_features, emsize):
        return EmbeddingEncoder(num_features, emsize, num_embs=num_embs_per_feature)

    return create_encoder
|
lcpfn/initializers.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
def get_NormalInitializer(std):
    """Return an initializer suitable for ``module.apply``.

    The returned function re-draws weight and bias of every ``nn.Linear`` it
    is given from N(0, std) and leaves all other module types untouched.

    Args:
        std: standard deviation of the normal distribution.
    """
    def initializer(m):
        if not isinstance(m, nn.Linear):
            return
        nn.init.normal_(m.weight, 0, std)
        # Linear layers built with bias=False carry bias=None.
        if m.bias is not None:
            nn.init.normal_(m.bias, 0, std)

    return initializer
|
lcpfn/layer.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
from typing import Optional
|
3 |
+
from torch import Tensor
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn.modules.transformer import *
|
6 |
+
from torch.nn.modules.transformer import _get_activation_fn
|
7 |
+
|
8 |
+
from torch.utils.checkpoint import checkpoint
|
9 |
+
|
10 |
+
|
11 |
+
class TransformerEncoderLayer(nn.Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.

    This encoder layer is based on the paper "Attention Is All You Need"
    (Vaswani et al., 2017) and extends the standard layer with optional
    pre-normalization, attention recomputation (checkpointing) and a
    three-part attention mask.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``.
        pre_norm: If ``True``, LayerNorm is applied to the input of each
            sub-block instead of to its residual output. Default: ``False``.
        recompute_attn: If ``True``, the attention call is wrapped in
            ``torch.utils.checkpoint.checkpoint``. Default: ``False``.

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)

    Alternatively, when ``batch_first`` is ``True``:
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)
    """
    __constants__ = ['batch_first']

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
                 layer_norm_eps=1e-5, batch_first=False, pre_norm=False,
                 device=None, dtype=None, recompute_attn=False) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        # Building blocks are referenced through `nn.` explicitly so the class
        # does not depend on which names the star-import from
        # torch.nn.modules.transformer happens to re-export.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                               **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.pre_norm = pre_norm
        self.recompute_attn = recompute_attn

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # States pickled without an activation fall back to ReLU.
        if 'activation' not in state:
            state['activation'] = nn.functional.relu
        super().__setstate__(state)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional). May also be a
                tuple ``(global_src_mask, trainset_src_mask, valset_src_mask)``;
                the sequence is then split into global, train and eval tokens
                using the first dimension of the first two masks, and each
                part attends separately (requires ``batch_first=False`` and no
                key padding mask).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        import torch  # local import: only `Tensor` and `nn` are imported by name at module level

        src_ = self.norm1(src) if self.pre_norm else src

        if isinstance(src_mask, tuple):
            # global attention setup
            assert not self.self_attn.batch_first
            assert src_key_padding_mask is None

            global_src_mask, trainset_src_mask, valset_src_mask = src_mask

            num_global_tokens = global_src_mask.shape[0]
            num_train_tokens = trainset_src_mask.shape[0]
            num_context_tokens = num_global_tokens + num_train_tokens

            global_tokens_src = src_[:num_global_tokens]
            train_tokens_src = src_[num_global_tokens:num_context_tokens]
            global_and_train_tokens_src = src_[:num_context_tokens]
            eval_tokens_src = src_[num_context_tokens:]

            attn = partial(checkpoint, self.self_attn) if self.recompute_attn else self.self_attn

            # Positional arguments after (query, key, value) are
            # key_padding_mask=None, need_weights=True, attn_mask.
            # Global tokens attend to global + train tokens.
            global_tokens_src2 = attn(global_tokens_src, global_and_train_tokens_src,
                                      global_and_train_tokens_src, None, True, global_src_mask)[0]
            # Train tokens attend to the global tokens only.
            train_tokens_src2 = attn(train_tokens_src, global_tokens_src, global_tokens_src,
                                     None, True, trainset_src_mask)[0]
            # Eval tokens attend to the full sequence.
            eval_tokens_src2 = attn(eval_tokens_src, src_, src_,
                                    None, True, valset_src_mask)[0]

            src2 = torch.cat([global_tokens_src2, train_tokens_src2, eval_tokens_src2], dim=0)
        elif self.recompute_attn:
            src2 = checkpoint(self.self_attn, src_, src_, src_, src_key_padding_mask, True, src_mask)[0]
        else:
            src2 = self.self_attn(src_, src_, src_, attn_mask=src_mask,
                                  key_padding_mask=src_key_padding_mask)[0]

        src = src + self.dropout1(src2)
        if not self.pre_norm:
            src = self.norm1(src)

        src_ = self.norm2(src) if self.pre_norm else src
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src_))))
        src = src + self.dropout2(src2)

        if not self.pre_norm:
            src = self.norm2(src)
        return src
|
lcpfn/model.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import lcpfn
|
3 |
+
|
4 |
+
class LCPFN(torch.nn.Module):
    """Inference wrapper around a model object loaded with ``torch.load``.

    ``forward`` feeds the concatenated train/test inputs to the wrapped model
    and returns its output; the ``predict_*`` helpers and ``nll_loss``
    post-process that output with the model's ``criterion``.
    """

    def __init__(self, model_name="EMSIZE512_NLAYERS12_NBUCKETS1000"):
        super().__init__()
        # A name listed in ``lcpfn.model_dict`` resolves to the attribute of
        # the same name on the ``lcpfn`` package; anything else is handed to
        # torch.load unchanged.
        if model_name in lcpfn.model_dict:
            source = getattr(lcpfn, model_name)
        else:
            source = model_name
        # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
        self.model = torch.load(source)
        self.model.eval()

    @torch.no_grad()
    def predict_mean(self, x_train, y_train, x_test):
        """Return ``criterion.mean`` of the model output for ``x_test``."""
        output = self(x_train=x_train, y_train=y_train, x_test=x_test)
        return self.model.criterion.mean(output)

    @torch.no_grad()
    def predict_quantiles(self, x_train, y_train, x_test, qs):
        """Return ``criterion.icdf`` for every q in ``qs``, concatenated along dim 1."""
        output = self(x_train=x_train, y_train=y_train, x_test=x_test)
        per_quantile = [self.model.criterion.icdf(output, q) for q in qs]
        return torch.cat(per_quantile, dim=1)

    @torch.no_grad()
    def nll_loss(self, x_train, y_train, x_test, y_test):
        """Return the criterion evaluated on the model output and ``y_test``."""
        output = self(x_train=x_train, y_train=y_train, x_test=x_test)
        return self.model.criterion(output, y_test)

    def forward(self, x_train, y_train, x_test):
        # Train rows come first; single_eval_pos tells the model where they end.
        num_train = x_train.shape[0]
        stacked_x = torch.cat([x_train, x_test], dim=0).unsqueeze(1)
        return self.model((stacked_x, y_train.unsqueeze(1)), single_eval_pos=num_train)
|
lcpfn/positional_encodings.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
|
7 |
+
# Protocol for positional encodings.
|
8 |
+
# __init__(d_model, max_len=..[, more optionals])
|
9 |
+
# forward(x: (seq_len, bs, d_model)) -> Tensor of shape (*x.shape[:2],d_model) containing pos. embeddings
|
10 |
+
|
11 |
+
|
12 |
+
class NoPositionalEncoding(nn.Module):
    """Positional-encoding placeholder that returns its input unchanged."""

    def __init__(self, d_model, max_len=None):
        # Both arguments exist only to satisfy the shared constructor protocol.
        super().__init__()

    def forward(self, x):
        return x
|
19 |
+
|
20 |
+
|
21 |
+
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input.

    A table of shape (max_len, 1, d_model) is precomputed and stored as the
    buffer ``pe``; even channels hold sines, odd channels cosines.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        frequencies = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = positions * frequencies

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        # Insert a singleton batch axis so the table broadcasts over the batch.
        self.register_buffer('pe', table.unsqueeze(0).transpose(0, 1))

    def forward(self, x):
        seq_len = x.size(0)
        return self.pe[:seq_len, :] + x
|
35 |
+
|
36 |
+
|
37 |
+
class LearnedPositionalEncoding(nn.Module):
    """Trainable positional embeddings (one row per position) added to the input."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.max_seq_len = max_len
        table = torch.empty(max_len, d_model)
        self.positional_embeddings = nn.Parameter(table)
        nn.init.normal_(self.positional_embeddings, mean=0, std=d_model ** -0.5)

    def forward(self, x):
        seq_len, bs, d_model = x.shape
        assert seq_len <= len(self.positional_embeddings), 'seq_len can be at most max_len.'
        # Take the first seq_len rows and repeat them across the batch axis.
        rows = self.positional_embeddings[:seq_len].unsqueeze(1)
        return rows.expand(seq_len, bs, d_model) + x
|
50 |
+
|
51 |
+
|
52 |
+
class PairedScrambledPositionalEncodings(LearnedPositionalEncoding):
    """Learned positional embeddings whose rows are drawn in a new random order on every call.

    NOTE(review): the table is viewed as (max_len, -1, 2) and permuted along
    dim 0, i.e. whole rows are shuffled; whether pairs of positions were meant
    to stay together should be confirmed.
    """
    # TODO check whether it is a problem to use the same perm. for full batch

    def forward(self, x):
        seq_len, bs, d_model = x.shape
        table = self.positional_embeddings
        num_rows = len(table)
        assert seq_len <= num_rows, 'seq_len can be at most max_len.'
        assert num_rows % 2 == 0, 'Please specify an even max_len.'

        paired_embs = table.view(num_rows, -1, 2)
        order = torch.randperm(len(paired_embs))
        shuffled = paired_embs[order].view(*table.shape)
        pos_emb = shuffled[:seq_len]

        return pos_emb.unsqueeze(1).expand(seq_len, bs, d_model) + x
|
63 |
+
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
|
lcpfn/priors/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import gp, ridge
|
lcpfn/priors/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (216 Bytes). View file
|
|
lcpfn/priors/__pycache__/gp.cpython-310.pyc
ADDED
Binary file (2.17 kB). View file
|
|
lcpfn/priors/__pycache__/prior.cpython-310.pyc
ADDED
Binary file (1.11 kB). View file
|
|
lcpfn/priors/__pycache__/ridge.cpython-310.pyc
ADDED
Binary file (1.44 kB). View file
|
|
lcpfn/priors/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (6.26 kB). View file
|
|
lcpfn/priors/binarized_regression.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from . import fast_gp, fast_gp_mix
|
2 |
+
from .utils import get_batch_to_dataloader
|
3 |
+
|
4 |
+
def regression_prior_to_binary(get_batch_function):
    """Turn a regression batch sampler into a binary-classification sampler.

    The wrapped sampler's ``y`` is squashed with a sigmoid and used as the
    success probability of a Bernoulli draw; the drawn labels are returned as
    both input targets and loss targets. Passing ``assert_on=True`` checks
    that the sampler returned the same object for ``y`` and ``target_y``.
    """
    def binarized_get_batch_function(*args, assert_on=False, **kwargs):
        batch = get_batch_function(*args, **kwargs)
        x, y, target_y = batch
        if assert_on:
            assert y is target_y, "y == target_y is assumed by this function"
        labels = y.sigmoid().bernoulli()
        return x, labels, labels

    return binarized_get_batch_function
|
14 |
+
|
15 |
+
|
16 |
+
# Data loader over fast_gp.get_batch batches whose targets were binarized by regression_prior_to_binary.
Binarized_fast_gp_dataloader = get_batch_to_dataloader(regression_prior_to_binary(fast_gp.get_batch))
|
17 |
+
|
18 |
+
|
19 |
+
# Same construction as above, but sampling from fast_gp_mix.get_batch.
Binarized_fast_gp_mix_dataloader = get_batch_to_dataloader(regression_prior_to_binary(fast_gp_mix.get_batch))
|
lcpfn/priors/fast_gp.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
import gpytorch
|
6 |
+
|
7 |
+
from .utils import get_batch_to_dataloader
|
8 |
+
from utils import default_device
|
9 |
+
|
10 |
+
|
11 |
+
# We will use the simplest form of GP model, exact inference
|
12 |
+
class ExactGPModel(gpytorch.models.ExactGP):
    """Exact-inference GP with a constant mean and a scaled RBF kernel."""

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        rbf_kernel = gpytorch.kernels.RBFKernel()
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf_kernel)

    def forward(self, x):
        mean = self.mean_module(x)
        covariance = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, covariance)
|
22 |
+
|
23 |
+
|
24 |
+
def get_model(x, y, hyperparameters):
    """Build an ExactGPModel and its Gaussian likelihood with fixed hyperparameters.

    ``hyperparameters`` must provide "noise", "outputscale" and "lengthscale";
    each value is broadcast over the corresponding parameter tensor.

    :return: tuple ``(model, likelihood)``.
    """
    noise_floor = gpytorch.constraints.GreaterThan(1.e-9)
    likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=noise_floor)
    model = ExactGPModel(x, y, likelihood)

    kernel = model.covar_module
    model.likelihood.noise = torch.ones_like(model.likelihood.noise) * hyperparameters["noise"]
    kernel.outputscale = torch.ones_like(kernel.outputscale) * hyperparameters["outputscale"]
    kernel.base_kernel.lengthscale = \
        torch.ones_like(kernel.base_kernel.lengthscale) * hyperparameters["lengthscale"]
    return model, likelihood
|
32 |
+
|
33 |
+
|
34 |
+
@torch.no_grad()
def get_batch(batch_size, seq_len, num_features, device=default_device, hyperparameters=None,
              equidistant_x=False, fix_x=None, **kwargs):
    """Sample a batch of functions from the GP prior.

    :param batch_size: number of functions to sample.
    :param seq_len: number of points per function.
    :param num_features: input dimensionality (must be 1 with ``equidistant_x``).
    :param device: device the inputs and the GP are placed on.
    :param hyperparameters: dict with "noise", "outputscale" and "lengthscale"
        (plus optional "sampling", "verbose", "fast_computations",
        "observation_noise"), or a tuple/list in the positional order
        noise, outputscale, lengthscale, is_binary_classification, <unused>,
        normalize_by_used_features, order_y, sampling. Shorter tuples are
        accepted; missing trailing entries are simply not set.
    :param equidistant_x: use an equidistant grid on [0, 1] as inputs.
    :param fix_x: optional (seq_len, num_features) tensor used for every function.
    :return: ``(x, y, target_y)`` where x is transposed to put the sequence
        dimension first, y is the noisy sample and target_y is the noisy sample
        unless ``observation_noise`` is False, in which case it is the
        noise-free sample.
    """
    if isinstance(hyperparameters, (tuple, list)):
        # Position 4 ("num_features_used") is deliberately skipped. zip() stops
        # at the shorter sequence, so e.g. a 3-tuple only sets the first three
        # keys instead of raising IndexError.
        positional_names = ("noise", "outputscale", "lengthscale", "is_binary_classification",
                            None, "normalize_by_used_features", "order_y", "sampling")
        hyperparameters = {name: value
                           for name, value in zip(positional_names, hyperparameters)
                           if name is not None}
    elif hyperparameters is None:
        hyperparameters = {"noise": .1, "outputscale": .1, "lengthscale": .1}

    if 'verbose' in hyperparameters and hyperparameters['verbose']:
        print({"noise": hyperparameters['noise'], "outputscale": hyperparameters['outputscale']
                  , "lengthscale": hyperparameters['lengthscale'], 'batch_size': batch_size, 'sampling': hyperparameters['sampling']})

    assert not (equidistant_x and (fix_x is not None))

    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations', (True, True, True))):
        if equidistant_x:
            assert num_features == 1
            # Created directly on `device`, like the inputs of the other branches.
            x = torch.linspace(0, 1., seq_len, device=device).unsqueeze(0).repeat(batch_size, 1).unsqueeze(-1)
        elif fix_x is not None:
            assert fix_x.shape == (seq_len, num_features)
            x = fix_x.unsqueeze(0).repeat(batch_size, 1, 1).to(device)
        elif hyperparameters.get('sampling', 'uniform') == 'uniform':
            x = torch.rand(batch_size, seq_len, num_features, device=device)
        else:
            x = torch.randn(batch_size, seq_len, num_features, device=device)

        successful_sample = False
        while not successful_sample:
            try:
                with gpytorch.settings.prior_mode(True):
                    # A fresh model is built for every attempt.
                    model, likelihood = get_model(x, torch.Tensor(), hyperparameters)
                    model.to(device)

                    d = model(x)
                    sample_wo_noise = d.sample().transpose(0, 1)  # this will be the target for the loss
                    sample = likelihood(sample_wo_noise).sample()  # this will be the input to the Transformer
                    successful_sample = True
            except RuntimeError:  # This can happen when torch.linalg.eigh fails. Restart with new init resolves this.
                print('GP Sampling unsuccessful, retrying.. ')
                print(x)
                print(hyperparameters)

    if torch.isnan(x).any().item():
        print({"noise": hyperparameters['noise'], "outputscale": hyperparameters['outputscale']
                  , "lengthscale": hyperparameters['lengthscale'], 'batch_size': batch_size})

    # TODO: Multi output
    return x.transpose(0, 1), sample, sample if hyperparameters.get("observation_noise", True) else sample_wo_noise
|
96 |
+
|
97 |
+
# Wrap this module's get_batch with get_batch_to_dataloader (imported from .utils).
DataLoader = get_batch_to_dataloader(get_batch)
|
98 |
+
|
99 |
+
def get_model_on_device(x,y,hyperparameters,device):
    """Build the GP via get_model, move the model to ``device`` and return ``(model, likelihood)``."""
    gp_model, gp_likelihood = get_model(x, y, hyperparameters)
    gp_model.to(device)
    return gp_model, gp_likelihood
|
103 |
+
|
104 |
+
|
105 |
+
@torch.no_grad()
def evaluate(x, y, y_non_noisy, use_mse=False, hyperparameters={}, get_model_on_device=get_model_on_device, device=default_device, step_size=1, start_pos=0):
    """Score GP predictions for growing context sizes.

    For every evaluated position ``t`` a GP is conditioned on ``x[:t], y[:t]``
    and asked to predict ``y[t]`` at ``x[t]``; the loss is the squared error
    of the predictive mean (``use_mse``) or the negative log-likelihood.
    ``y_non_noisy`` is accepted but not used.

    :return: tuple of (stacked per-item losses, mean loss per position,
        elapsed seconds); both tensors are on the CPU. When ``start_pos`` is 0
        the mean-loss list starts with a placeholder 0.
    """
    started = time.time()
    mean_losses = [] if start_pos != 0 else [.0]
    per_item_losses = []
    fast_flags = hyperparameters.get('fast_computations', (True, True, True))

    with gpytorch.settings.fast_computations(*fast_flags), gpytorch.settings.fast_pred_var(False):
        for t in range(max(start_pos, 1), len(x), step_size):
            context_x = x[:t].transpose(0, 1)
            context_y = y[:t].transpose(0, 1)
            model, likelihood = get_model_on_device(context_x, context_y, hyperparameters, device)
            model.eval()

            predictive = likelihood(model(x[t].unsqueeze(1)))
            means = predictive.mean.squeeze()
            varis = predictive.covariance_matrix.squeeze()

            assert len(means.shape) == len(varis.shape) == 1
            assert len(means) == len(varis) == x.shape[1]

            if use_mse:
                step_loss = nn.MSELoss(reduction='none')(means, y[t])
            else:
                step_loss = -predictive.log_prob(y[t].unsqueeze(1))

            mean_losses.append(step_loss.mean())
            per_item_losses.append(step_loss.flatten())
        return torch.stack(per_item_losses).to('cpu'), torch.tensor(mean_losses).to('cpu'), time.time() - started
|
138 |
+
|
139 |
+
if __name__ == '__main__':
    # get_model() looks hyperparameters up by key and evaluate() calls .get()
    # on them, so a dict is required here. The tuple form understood by
    # get_batch() reads indices up to 7 and is not accepted by evaluate().
    hps = {"noise": .1, "outputscale": .1, "lengthscale": .1}
    for redo_idx in range(1):
        print(
            evaluate(*get_batch(1000, 10, hyperparameters=hps, num_features=10), use_mse=False, hyperparameters=hps))
|
lcpfn/priors/fast_gp_mix.py
ADDED
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import functools
|
3 |
+
import random
|
4 |
+
import math
|
5 |
+
import traceback
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from torch import nn
|
10 |
+
import gpytorch
|
11 |
+
from botorch.models import SingleTaskGP
|
12 |
+
from botorch.models.gp_regression import MIN_INFERRED_NOISE_LEVEL
|
13 |
+
from botorch.fit import fit_gpytorch_model
|
14 |
+
from gpytorch.mlls import ExactMarginalLogLikelihood
|
15 |
+
from gpytorch.likelihoods import GaussianLikelihood
|
16 |
+
from gpytorch.priors.torch_priors import GammaPrior, UniformPrior
|
17 |
+
from gpytorch.constraints import GreaterThan
|
18 |
+
|
19 |
+
|
20 |
+
from bar_distribution import BarDistribution
|
21 |
+
from utils import default_device
|
22 |
+
from .utils import get_batch_to_dataloader
|
23 |
+
from . import fast_gp
|
24 |
+
|
25 |
+
def get_model(x, y, hyperparameters: dict, sample=True):
    """Build a GP (optionally with hyperparameters drawn from their priors).

    With ``hyperparameters['handmade']`` set, a small ExactGP with a Matern
    kernel and uniform priors is built; otherwise a botorch ``SingleTaskGP``
    with Gamma priors whose parameters can be overridden via
    ``hyperparameters``.

    :param sample: if True, return a model sampled via
        ``pyro_sample_from_prior`` together with its likelihood; if False,
        return the un-sampled model (not allowed together with the 'sigmoid'
        or 'y_minmax_norm' hyperparameters).
    :return: tuple ``(model, likelihood)``.
    """
    hp = hyperparameters.get

    if hp('handmade', False):
        # We will use the simplest form of GP model, exact inference
        class ExactGPModel(gpytorch.models.ExactGP):
            def __init__(self, train_x, train_y, likelihood):
                super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
                self.mean_module = gpytorch.means.ConstantMean()
                self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.MaternKernel())
                self.mean_module.register_prior("mean_prior", UniformPrior(-1, 1), "constant")
                self.covar_module.base_kernel.register_prior(
                    "lengthscale_prior", UniformPrior(0.01, 0.5), "lengthscale")
                self.covar_module.register_prior("outputscale_prior", UniformPrior(1, 2), "outputscale")
                likelihood.register_prior("noise_prior", UniformPrior(0.001, 0.01), "noise")
                # Match the dtype/device of the inputs captured from the enclosing call.
                self.to(x)

            def forward(self, x):
                mean_x = self.mean_module(x)
                covar_x = self.covar_module(x)
                return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

        likelihood = gpytorch.likelihoods.GaussianLikelihood(
            noise_constraint=gpytorch.constraints.Positive())
        model = ExactGPModel(x, y, likelihood)
    else:
        # A throw-away SingleTaskGP is built only to read its augmented batch shape.
        aug_batch_shape = SingleTaskGP(x, y.unsqueeze(-1))._aug_batch_shape
        noise_prior = GammaPrior(hp('noise_concentration', 1.1), hp('noise_rate', 0.05))
        noise_prior_mode = (noise_prior.concentration - 1) / noise_prior.rate
        noise_constraint = GreaterThan(
            MIN_INFERRED_NOISE_LEVEL,
            transform=None,
            initial_value=noise_prior_mode,
        )
        likelihood = GaussianLikelihood(
            noise_prior=noise_prior,
            batch_shape=aug_batch_shape,
            noise_constraint=noise_constraint,
        )

        matern = gpytorch.kernels.MaternKernel(
            nu=hp('nu', 2.5),
            ard_num_dims=x.shape[-1],
            batch_shape=aug_batch_shape,
            lengthscale_prior=gpytorch.priors.GammaPrior(
                hp('lengthscale_concentration', 3.0), hp('lengthscale_rate', 6.0)),
        )
        scaled_kernel = gpytorch.kernels.ScaleKernel(
            matern,
            batch_shape=aug_batch_shape,
            outputscale_prior=gpytorch.priors.GammaPrior(
                hp('outputscale_concentration', .5), hp('outputscale_rate', 0.15)),
        )
        model = SingleTaskGP(x, y.unsqueeze(-1), covar_module=scaled_kernel, likelihood=likelihood)

        likelihood = model.likelihood
        model.to(x.device)

    if not sample:
        assert not(hyperparameters.get('sigmoid', False)) and not(hyperparameters.get('y_minmax_norm', False)), "Sigmoid and y_minmax_norm can only be used to sample models..."
        return model, likelihood

    sampled_model = model.pyro_sample_from_prior()
    return sampled_model, sampled_model.likelihood
|
84 |
+
|
85 |
+
|
86 |
+
@torch.no_grad()
def get_batch(batch_size, seq_len, num_features, device=default_device, hyperparameters=None,
              batch_size_per_gp_sample=None,
              fix_to_range=None, equidistant_x=False, **kwargs):
    '''
    Sample a batch of datasets from a mixture of GP priors.

    This function is very similar to the equivalent in .fast_gp. The only difference is that this function operates
    over a mixture of GP priors: `get_model` is called once per group of `batch_size_per_gp_sample` datasets.

    :param batch_size: number of datasets in the returned batch.
    :param seq_len: number of points per dataset.
    :param num_features: number of input features per point.
    :param device: device used for the inputs, the model and the likelihood.
    :param hyperparameters: optional dict. Keys read here: 'fast_computations', 'handmade', 'observation_noise',
        'y_minmax_norm' and 'sigmoid'. The dict is also forwarded to `get_model`.
    :param batch_size_per_gp_sample: number of datasets drawn from each model; defaults to
        `max(batch_size // 10, 1)` and has to divide `batch_size`.
    :param fix_to_range: optional `(low, high)`. If given, twice as many candidates are drawn and only datasets whose
        samples all satisfy `low <= y < high` are kept; sampling is repeated until enough candidates survive.
    :param equidistant_x: use an equidistant grid on [0, 1] as inputs (requires `num_features == 1`).
    :param kwargs: accepted and ignored.
    :return: `(x, sample, target_sample)` where x has shape (seq_len, batch_size, num_features) and the samples have
        shape (seq_len, batch_size). `target_sample` is `sample` itself when sampling with observation noise,
        otherwise the noise-free draws.
    '''
    hyperparameters = hyperparameters or {}
    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations', (True, True, True))):
        batch_size_per_gp_sample = (batch_size_per_gp_sample or max(batch_size // 10, 1))
        assert batch_size % batch_size_per_gp_sample == 0

        # With a target range we draw twice as many candidates as needed and filter afterwards.
        candidate_factor = 2 ** (fix_to_range is not None)
        total_num_candidates = batch_size * candidate_factor
        num_candidates = batch_size_per_gp_sample * candidate_factor
        if equidistant_x:
            assert num_features == 1
            # Created on `device` so that it matches the model, like the random inputs below.
            x = torch.linspace(0, 1., seq_len, device=device).unsqueeze(0).repeat(total_num_candidates, 1).unsqueeze(-1)
        else:
            x = torch.rand(total_num_candidates, seq_len, num_features, device=device)

        sampling_with_observation_noise = hyperparameters.get("observation_noise", True)
        use_minmax_norm = hyperparameters.get('y_minmax_norm')
        use_sigmoid = hyperparameters.get('sigmoid')

        def squash(values):
            # values: candidates x seq_len. keepdim is required so that the per-dataset
            # minimum / maximum broadcast along the sequence dimension.
            if use_minmax_norm:
                lowest = values.min(1, keepdim=True)[0]
                highest = values.max(1, keepdim=True)[0]
                values = (values - lowest) / (highest - lowest)
            if use_sigmoid:
                values = values.sigmoid()
            return values

        kept_xs = []
        samples = []
        samples_wo_noise = []
        # Accumulated over all groups; the debug print below averages it over the number of groups.
        throwaway_share = 0.
        for i in range(0, total_num_candidates, num_candidates):
            x_candidates = x[i:i + num_candidates]
            model, likelihood = get_model(x_candidates, torch.zeros(num_candidates, x.shape[1], device=device),
                                          hyperparameters)
            model.to(device)
            likelihood.to(device)
            if hyperparameters.get('handmade', False):
                model.covar_module.base_kernel.lengthscale = model.covar_module.base_kernel.lengthscale.to(device)
                model.covar_module.outputscale = model.covar_module.outputscale.to(device)
                likelihood.noise = likelihood.noise.to(device)
                model.mean_module.constant = model.mean_module.constant.to(device)

            failed_attempts = 0
            while True:
                with gpytorch.settings.prior_mode(True):
                    latent = model(x_candidates)
                    if sampling_with_observation_noise:
                        sample_wo_noise = None
                        sample = likelihood(latent).sample()  # num_candidates x seq_len
                    else:
                        sample_wo_noise = latent.sample()
                        sample = likelihood(sample_wo_noise).sample()

                    sample = squash(sample)
                    if sample_wo_noise is not None:
                        sample_wo_noise = squash(sample_wo_noise)

                    if fix_to_range is None:
                        x_kept = x_candidates
                    else:
                        out_of_range = ((sample < fix_to_range[0]) | (sample >= fix_to_range[1])).any(1)
                        in_range_mask = ~out_of_range
                        throwaway_share += (~in_range_mask[:batch_size_per_gp_sample]).sum() / batch_size_per_gp_sample
                        if in_range_mask.sum() < batch_size_per_gp_sample:
                            failed_attempts += 1
                            # Only complain once the rejection sampling is clearly struggling.
                            if failed_attempts > 100:
                                print("Please change hyper-parameters (e.g. decrease outputscale_mean) it"
                                      "seems like the range is set to tight for your hyper-parameters.")
                            continue

                        # Apply the same selection to inputs, noisy and noise-free samples so they stay aligned.
                        x_kept = x_candidates[in_range_mask][:batch_size_per_gp_sample]
                        sample = sample[in_range_mask][:batch_size_per_gp_sample]
                        if sample_wo_noise is not None:
                            sample_wo_noise = sample_wo_noise[in_range_mask][:batch_size_per_gp_sample]

                    kept_xs.append(x_kept)
                    samples.append(sample.transpose(0, 1))
                    if sample_wo_noise is not None:
                        samples_wo_noise.append(sample_wo_noise.transpose(0, 1))
                    break

        if random.random() < .01:
            print('throwaway share', throwaway_share/(batch_size//batch_size_per_gp_sample))

        # TODO think about subtracting the first value of every sample here.
        x = torch.cat(kept_xs, 0).transpose(0, 1)
        sample = torch.cat(samples, 1)

        if sampling_with_observation_noise:
            target_sample = sample
        else:
            target_sample = torch.cat(samples_wo_noise, 1)

        assert x.shape[:2] == sample.shape[:2]

        return x, sample, target_sample  # x.shape = (T,B,H)
|
198 |
+
|
199 |
+
|
200 |
+
class DataLoader(get_batch_to_dataloader(get_batch)):
    """Data loader over `get_batch` with an additional MSE-based validation routine."""

    @torch.no_grad()
    def validate(self, model, step_size=1, start_pos=0):
        """
        Compute the MSE of the model's mean prediction for a range of evaluation positions.

        :param model: model exposing `criterion`, `eval()`, `train()` and being callable as
            `model((x, y), single_eval_pos=...)`.
        :param step_size: stride between evaluated positions.
        :param start_pos: first evaluated position.
        :return: stacked losses, one per evaluated position, if `model.criterion` is a `BarDistribution`;
            the constant `123.` otherwise.
        """
        if not isinstance(model.criterion, BarDistribution):
            return 123.

        (_, x, y), target_y, _ = self.gbm(**self.get_batch_kwargs)
        model.eval()
        mse = nn.MSELoss()
        losses = []
        for position in range(start_pos, len(x), step_size):
            logits = model((x, y), single_eval_pos=position)
            means = model.criterion.mean(logits)  # num_evals x batch_size
            losses.append(mse(means[0], target_y[position]))
        # The model is always switched back to training mode afterwards.
        model.train()
        return torch.stack(losses)
|
216 |
+
|
217 |
+
|
218 |
+
@torch.enable_grad()
def get_fitted_model(x, y, hyperparameters, device):
    """
    Build the GP via `get_model(..., sample=False)` and fit it on (x, y).

    The fit maximises the exact marginal log likelihood with `fit_gpytorch_model`.

    :return: `(model, likelihood)` after fitting.
    """
    gp, gp_likelihood = get_model(x, y, hyperparameters, sample=False)
    gp.to(device)
    gp.train()
    fit_gpytorch_model(ExactMarginalLogLikelihood(gp_likelihood, gp))
    return gp, gp_likelihood
|
229 |
+
|
230 |
+
|
231 |
+
evaluate = functools.partial(fast_gp.evaluate, get_model_on_device=get_fitted_model)
|
232 |
+
|
233 |
+
def get_mcmc_model(x, y, hyperparameters, device, num_samples, warmup_steps, obs=True):
    """
    Run NUTS over the hyper-parameters of the GP and load the drawn samples into the model.

    :param x: training inputs, moved to `device`.
    :param y: training targets, moved to `device`.
    :param hyperparameters: forwarded to `get_model`.
    :param device: device for data and model.
    :param num_samples: number of MCMC samples to draw.
    :param warmup_steps: number of MCMC warm-up steps.
    :param obs: if True the pyro model conditions on `y`; if False it only samples from the prior.
    :return: `(model, likelihood)` with the model in eval mode.
    """
    import pyro
    from pyro.infer.mcmc import MCMC, NUTS

    x, y = x.to(device), y.to(device)
    model, likelihood = get_model(x, y, hyperparameters, sample=False)
    model.to(device)

    def pyro_model(inputs, targets):
        prior_draw = model.pyro_sample_from_prior()
        predictive = prior_draw.likelihood(prior_draw(inputs))
        if not obs:
            return None
        return pyro.sample("obs", predictive, obs=targets)

    sampler = MCMC(NUTS(pyro_model), num_samples=num_samples, warmup_steps=warmup_steps, num_chains=1)
    sampler.run(x, y)
    # Open question from the original author: use pyro.infer here instead?
    model.pyro_load_from_samples(sampler.get_samples())
    model.eval()
    return model, likelihood
|
264 |
+
|
265 |
+
|
266 |
+
|
267 |
+
|
268 |
+
def get_mean_logdensity(dists, x: torch.Tensor, full_range=None):
    """
    Log density of `x` under an equally weighted mixture of normal distributions.

    The means and variances of all entries of `dists` are flattened and concatenated; every resulting component
    is treated as a univariate normal and paired with the corresponding entry of `x`.

    :param dists: iterable of objects exposing `.mean` and `.variance` tensors. After squeezing, each has to be
        at most one-dimensional.
    :param x: values to evaluate, broadcastable against the concatenated means.
    :param full_range: optional `(low, high)`. If given, each component is renormalised by the probability mass
        it puts on that interval.
    :return: scalar tensor `logsumexp(log_probs) - log(num_components)`.
    """
    # atleast_1d: squeezing a single-element tensor yields a 0-dim tensor, which torch.cat rejects.
    means = torch.cat([torch.atleast_1d(d.mean.squeeze()) for d in dists], 0)
    variances = torch.cat([torch.atleast_1d(d.variance.squeeze()) for d in dists], 0)
    assert len(means.shape) == 1 and len(variances.shape) == 1
    dist = torch.distributions.Normal(means, variances.sqrt())
    logprobs = dist.log_prob(x)
    if full_range is not None:
        # Build the bounds with the dtype / device of the components.
        lower = torch.as_tensor(full_range[0], dtype=means.dtype, device=means.device)
        upper = torch.as_tensor(full_range[1], dtype=means.dtype, device=means.device)
        used_weight = 1. - (dist.cdf(lower) + (1. - dist.cdf(upper)))
        log_used_weight = torch.log(used_weight)
        if torch.isinf(log_used_weight).any():
            print('factor is inf', -log_used_weight)
        logprobs = logprobs - log_used_weight
    assert len(logprobs.shape) == 1
    return torch.logsumexp(logprobs, 0) - math.log(len(logprobs))
|
283 |
+
|
284 |
+
|
285 |
+
def evaluate_(x, y, y_non_noisy, hyperparameters=None, device=default_device, num_samples=100, warmup_steps=300,
              full_range=None, min_seq_len=0, use_likelihood=False, obs=True):
    """
    Evaluate an MCMC (NUTS) GP baseline: negative mean log density of the next point, for every prefix length.

    For every position t and every dataset b_i, a GP is built on `x[:t, b_i]`, its hyper-parameters are sampled
    with NUTS, and the density of `y[t, b_i]` under the resulting mixture is scored via `get_mean_logdensity`.

    :param x: inputs, indexed as x[t, b_i]; converted to double on `device`.
    :param y: targets, indexed as y[t, b_i]; converted to double on `device`.
    :param y_non_noisy: unused.
    :param hyperparameters: optional dict, forwarded to `get_model`; 'fast_computations' is read here.
    :param device: device for data and models.
    :param num_samples: number of MCMC samples per fit.
    :param warmup_steps: number of MCMC warm-up steps per fit.
    :param full_range: forwarded to `get_mean_logdensity`.
    :param min_seq_len: first evaluated position (at least 1 is always used).
    :param use_likelihood: if True the predictive distribution is passed through the likelihood.
    :param obs: if False the pyro model does not condition on the observations.
    :return: `(losses_after_t, elapsed_seconds, all_losses)`; `all_losses` holds the per-dataset losses per step.
    """
    import pyro
    from pyro.infer.mcmc import MCMC, NUTS

    # The default of None used to crash on `.get` below.
    hyperparameters = hyperparameters or {}
    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations', (True, True, True))), \
            gpytorch.settings.fast_pred_var(False):
        x = x.to(device).double()
        y = y.to(device).double()
        start_time = time.time()
        losses_after_t = [.0] if min_seq_len == 0 else []
        all_losses = []

        for t in range(max(min_seq_len, 1), len(x)):
            loss_sum = 0.
            step_losses = []
            start_step = time.time()
            print(x.shape, y.shape)
            for b_i in range(x.shape[1]):
                x_train = x[:t, b_i]
                y_train = y[:t, b_i]
                print(x_train.shape, y_train.shape)
                model, likelihood = get_model(x_train, y_train, hyperparameters, sample=False)
                model.to(device)

                def pyro_model(x, y):
                    sampled_model = model.pyro_sample_from_prior()
                    output = sampled_model.likelihood(sampled_model(x))
                    if obs:
                        return pyro.sample("obs", output, obs=y)

                nuts_kernel = NUTS(pyro_model)
                mcmc_run = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps, num_chains=1,
                                disable_progbar=True)
                mcmc_run.run(x_train, y_train)
                model.pyro_load_from_samples(mcmc_run.get_samples())
                model.eval()

                with torch.no_grad():
                    # One copy of the query point per MCMC sample.
                    dists = model(x[t, b_i, :].unsqueeze(0).repeat(num_samples, 1, 1))
                    if use_likelihood:
                        dists = likelihood(dists)
                    l = -get_mean_logdensity([dists], y[t, b_i].repeat(num_samples), full_range)
                    print(l)

                step_losses.append(l.item())
                print(f'current average loss at step {t} is {sum(step_losses)/len(step_losses)} with {(time.time()-start_step)/len(step_losses)} s per eval.')
                loss_sum += l

            loss_sum /= x.shape[1]
            all_losses.append(step_losses)
            print(f'loss after step {t} is {loss_sum}')
            losses_after_t.append(loss_sum)
            print(f'losses so far {torch.tensor(losses_after_t)}')
        return torch.tensor(losses_after_t), time.time() - start_time, all_losses
|
344 |
+
|
345 |
+
|
346 |
+
|
347 |
+
|
348 |
+
|
349 |
+
if __name__ == '__main__':
    # Command line entry point: sample one batch from the prior and score the MCMC baseline on it.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--seq_len', type=int)
    parser.add_argument('--min_seq_len', type=int, default=0)
    parser.add_argument('--warmup_steps', type=int)
    parser.add_argument('--num_samples', type=int)
    parser.add_argument('--min_y', type=int)
    parser.add_argument('--max_y', type=int)
    parser.add_argument('--dim', type=int, default=1)
    parser.add_argument('--use_likelihood', action='store_true')
    parser.add_argument('--device', default='cpu')
    # The misspelled flag is kept as an alias so that existing invocations keep working.
    parser.add_argument('--outputscale_concentration', '--outputscale_concentraion', default=2., type=float)
    parser.add_argument('--noise_concentration', default=1.1, type=float)
    parser.add_argument('--noise_rate', default=.05, type=float)
    parser.add_argument('--handmade', action='store_true')
    parser.add_argument('--no_obs', action='store_true')
    parser.add_argument('--seed', type=int, default=0)

    args = parser.parse_args()
    if args.min_y is not None and args.max_y is None:
        # (min_y, None) would only fail later, deep inside the sampling loop.
        parser.error('--max_y is required when --min_y is given.')

    import pyro
    import gpytorch
    print(pyro.__version__)
    print(gpytorch.__version__)

    print('min_y:', args.min_y)
    full_range = (None if args.min_y is None else (args.min_y, args.max_y))

    hps = {'handmade': args.handmade, 'outputscale_concentration': args.outputscale_concentration,
           'noise_concentration': args.noise_concentration,
           'noise_rate': args.noise_rate, 'fast_computations': (False, False, False)}
    # A seed of 0 (the default) leaves the random number generators unseeded.
    if args.seed:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)
    x, y, _ = get_batch(args.batch_size, args.seq_len, args.dim, fix_to_range=full_range, hyperparameters=hps)
    print('RESULT:', evaluate_(x, y, y, device=args.device, warmup_steps=args.warmup_steps,
                               num_samples=args.num_samples, full_range=full_range, min_seq_len=args.min_seq_len,
                               hyperparameters=hps, use_likelihood=args.use_likelihood, obs=not args.no_obs))
|
393 |
+
|
394 |
+
|
lcpfn/priors/gp.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import random
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
from sklearn.gaussian_process import GaussianProcessRegressor
|
8 |
+
from sklearn.gaussian_process.kernels import RBF, DotProduct, WhiteKernel
|
9 |
+
from .utils import get_batch_to_dataloader
|
10 |
+
|
11 |
+
|
12 |
+
length_scale_sampling_gp = .6
|
13 |
+
|
14 |
+
def get_gp(length_scale=None):
    """
    Create an unfitted GP regressor with a fixed RBF kernel.

    :param length_scale: RBF length scale; a falsy value selects `length_scale_sampling_gp`.
    :return: `GaussianProcessRegressor` with `random_state=0` and no hyper-parameter optimizer.
    """
    chosen_scale = length_scale if length_scale else length_scale_sampling_gp
    rbf_kernel = RBF(length_scale=chosen_scale, length_scale_bounds='fixed')
    return GaussianProcessRegressor(kernel=rbf_kernel, random_state=0, optimizer=None)
|
18 |
+
|
19 |
+
|
20 |
+
def get_batch(batch_size, seq_len, num_features, noisy_std=None):
    """
    Sample a batch of functions from a GP prior with a fixed RBF kernel.

    :param batch_size: number of datasets.
    :param seq_len: number of points per dataset.
    :param num_features: number of input features per point.
    :param noisy_std: despite its name, this value is passed to `get_gp` as the RBF length scale.
    :return: `(x, y, y)` with x of shape (seq_len, batch_size, num_features) and y of shape (seq_len, batch_size);
        targets and noisy targets are the same tensor.
    """
    x_t = torch.rand(batch_size, seq_len, num_features)
    gpr = get_gp(noisy_std)
    y_t = torch.zeros(batch_size, seq_len)

    for i in range(batch_size):
        # randint is inclusive on both ends; seeds have to stay below 2 ** 32.
        seed = random.randint(0, 2 ** 32 - 1)
        draw = gpr.sample_y(x_t[i].numpy(), random_state=seed)
        # Explicit conversion and reshape: also valid for seq_len == 1, where squeeze() would yield a 0-dim value.
        y_t[i] = torch.as_tensor(draw, dtype=y_t.dtype).reshape(seq_len)
    x, y = x_t.transpose(0, 1), y_t.transpose(0, 1)
    return x, y, y
|
36 |
+
|
37 |
+
|
38 |
+
DataLoader = get_batch_to_dataloader(get_batch)
|
39 |
+
|
40 |
+
def evaluate(x, y, y_non_noisy, use_mse=False, length_scale=length_scale_sampling_gp):
    """
    Score a GP with a fixed RBF kernel on one-step-ahead prediction for every prefix length.

    :param x: inputs, indexed as x[t, b_i].
    :param y: targets, indexed as y[t, b_i].
    :param y_non_noisy: unused.
    :param use_mse: score with the MSE of the predictive mean instead of the Gaussian NLL.
    :param length_scale: RBF length scale of the evaluated GP.
    :return: `(losses_after_t, elapsed_seconds)`; the first loss entry is a 0.0 placeholder for position 0.
    """
    started = time.time()
    num_datasets = x.shape[1]
    losses_after_t = [.0]
    for t in range(1, len(x)):
        step_total = 0.
        for b_i in range(num_datasets):
            fitted = get_gp(length_scale).fit(x[:t, b_i], y[:t, b_i])
            means, stds = fitted.predict(x[t, b_i].unsqueeze(0), return_std=True)
            assert len(means) == 1 == len(stds)
            prediction = torch.tensor(means)
            target = y[t, b_i].unsqueeze(-1)
            if use_mse:
                step_total += nn.MSELoss()(prediction, target)
            else:
                step_total += nn.GaussianNLLLoss(full=True)(prediction, target, var=torch.tensor(stds) ** 2)
        losses_after_t.append(step_total / num_datasets)

    return torch.tensor(losses_after_t), time.time() - started
|
62 |
+
|
63 |
+
if __name__ == '__main__':
    # Compare the true sampling length scale against slightly mis-specified ones.
    ls = .1
    for alpha in {ls, ls * 1.1, ls * .9}:
        print(alpha)
        demo_batch = get_batch(1000, 10, noisy_std=ls, num_features=10)
        print(evaluate(*demo_batch, use_mse=False, length_scale=alpha))
|
lcpfn/priors/prior.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABCMeta, abstractmethod
|
2 |
+
from torch.utils.data import DataLoader
|
3 |
+
|
4 |
+
|
5 |
+
class PriorDataLoader(DataLoader, metaclass=ABCMeta):
    """Abstract interface of the data loaders that generate batches from a prior."""

    @abstractmethod
    def __init__(self, num_steps, batch_size, eval_pos_seq_len_sampler, seq_len_maximum, device, **kwargs):
        """
        Set up the prior data loader; concrete subclasses have to override this.

        :param num_steps: int, first argument, the number of steps to take per epoch, i.e. iteration of the DataLoader
        :param batch_size: int, number of datasets per batch
        :param eval_pos_seq_len_sampler: callable, it takes no arguments and returns a tuple (single eval pos, bptt)
        :param seq_len_maximum: presumably the largest sequence length the sampler can return -- TODO confirm
        :param device: presumably the device the batches are created on -- TODO confirm
        :param kwargs: for future compatibility it is good to have a final catch-all, as new kwargs might be introduced
        """

    # Expected of subclasses:
    #   * a class or object variable `num_features`: int
    #   * optionally a `validate` function that accepts a transformer model
    #
    # Iterating the DataLoader should yield batches of the form ([style], x, y), target_y, single_eval_pos.
    # We follow sequence len (s) first, batch size (b) second. So x: (s,b,num_features), y,target_y: (s,b)
    # and style: Optional[(b,num_style_params)]; style can be omitted or set to None if it is not intended to be used.
    #
    # For more references, see `priors/utils.py` for a pretty general implementation of a DataLoader
    # and `train.py` for the only call of it.
|
lcpfn/priors/pyro.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
from utils import default_device
|
7 |
+
from .utils import get_batch_to_dataloader
|
8 |
+
|
9 |
+
|
10 |
+
def get_batch(batch_size, seq_len, batch_size_per_gp_sample=None, **config):
    """
    Sample a batch of datasets from freshly constructed models.

    :param batch_size: number of datasets in the batch.
    :param seq_len: forwarded to every model call as `seq_len=`.
    :param batch_size_per_gp_sample: datasets drawn from each model; defaults to `max(batch_size // 16, 1)` and
        has to divide `batch_size`.
    :param config: `config['model']` is a zero-argument factory returning a callable that maps `seq_len=` to an
        `(x, y)` pair of tensors. Truthy `normalize_y` standardises x and y; otherwise truthy `normalize`
        standardises only x.
    :return: `(x, y, y)` where x and y are the per-dataset tensors stacked along dimension 1 (y with its last
        dimension squeezed).
    """
    # max(..., 1) keeps batches smaller than 16 from producing a zero divisor below.
    batch_size_per_gp_sample = batch_size_per_gp_sample or max(batch_size // 16, 1)
    assert batch_size % batch_size_per_gp_sample == 0, 'Please choose a batch_size divisible by batch_size_per_gp_sample.'
    num_models = batch_size // batch_size_per_gp_sample
    # standard kaiming uniform init currently...

    models = [config['model']() for _ in range(num_models)]

    # Same order as before: all datasets of the first model, then all of the second, ...
    sample = [model(seq_len=seq_len) for model in models for _ in range(batch_size_per_gp_sample)]

    def normalize_data(data):
        # Standardise along dimension 0; the epsilon guards against a zero standard deviation.
        mean = data.mean(0)
        std = data.std(0) + .000001
        return (data - mean) / std

    x, y = zip(*sample)

    y = torch.stack(y, 1).squeeze(-1).detach()
    x = torch.stack(x, 1).detach()

    if config.get('normalize_y'):
        x, y = normalize_data(x), normalize_data(y)
    elif config.get('normalize'):
        # Previously guarded by the wrong key ('normalize_y'), so this branch could never run without a KeyError risk.
        x = normalize_data(x)

    return x, y, y
|
38 |
+
|
39 |
+
|
40 |
+
DataLoader = get_batch_to_dataloader(get_batch)
|
41 |
+
|
lcpfn/priors/ridge.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import time
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
from sklearn.linear_model import Ridge
|
8 |
+
from .utils import get_batch_to_dataloader
|
9 |
+
|
10 |
+
def get_batch(batch_size, seq_len, num_features, noisy_std = .1):
    """
    Sample noisy linear regression datasets (no intercept).

    One weight vector per dataset is drawn from N(0, 0.1^2), inputs are uniform on [0, 1).

    :param noisy_std: standard deviation of the Gaussian noise added to the targets.
    :return: `(x, y, y_non_noisy)` with x of shape (seq_len, batch_size, num_features) and both targets of shape
        (seq_len, batch_size).
    """
    weights = torch.normal(0., .1, size=(batch_size, num_features))
    x = torch.rand(seq_len, batch_size, num_features)
    y_non_noisy = torch.einsum('bf,tbf->tb', weights, x)
    noise = torch.normal(torch.zeros_like(y_non_noisy), noisy_std)  # noisy_std is alpha
    return x, y_non_noisy + noise, y_non_noisy
|
17 |
+
|
18 |
+
DataLoader = get_batch_to_dataloader(get_batch)
|
19 |
+
|
20 |
+
|
21 |
+
def evaluate(x,y,y_non_noisy, alpha=0.):
    """
    Score ridge regression on one-step-ahead prediction for every prefix length.

    For each position t and dataset b_i a `Ridge(alpha)` model is fitted on the first t points and its
    prediction for point t is compared to the noise-free target with the MSE.

    :return: `(losses_after_t, elapsed_seconds)`; the first loss entry is a 0.0 placeholder for position 0.
    """
    began = time.time()
    num_datasets = x.shape[1]
    losses_after_t = [.0]
    for t in range(1, len(x)):
        step_total = 0.
        for b_i in range(num_datasets):
            ridge = Ridge(alpha=alpha)
            ridge.fit(x[:t, b_i], y[:t, b_i])
            prediction = ridge.predict(x[t, b_i].unsqueeze(0))
            step_total += nn.MSELoss()(y_non_noisy[t, b_i].unsqueeze(0), torch.tensor(prediction))
        losses_after_t.append(step_total / num_datasets)
    return torch.tensor(losses_after_t), time.time() - began
|
34 |
+
|
35 |
+
if __name__ == '__main__':
    # Small demo: ridge baseline for several regularisation strengths.
    # get_batch requires num_features; 10 is an arbitrary demo value.
    for alpha in [.001,.01,.5,1.]:
        print(alpha, evaluate(*get_batch(1000, 10, num_features=10, noisy_std=.01), alpha=alpha))
|
lcpfn/priors/stroke.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image, ImageDraw, ImageFilter
|
2 |
+
import random
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
from .utils import get_batch_to_dataloader
|
8 |
+
|
9 |
+
def mnist_prior(num_classes=2, size=28, min_max_strokes=(1,3), min_max_len=(5/28,20/28), min_max_start=(2/28,25/28),
                min_max_width=(1/28,4/28), max_offset=4/28, max_target_offset=2/28):
    """
    Create random stroke-based image classes and return one image generator per class.

    Every class is a set of straight strokes (start point, length, direction). Each call of a returned generator
    draws the strokes of its class with a random line width, a random global offset and a random perturbation of
    the stroke end points.

    :param num_classes: number of classes / returned generators.
    :param size: side length of the square images in pixels.
    :param min_max_strokes: inclusive bounds for the number of strokes per class.
    :param min_max_len: bounds for the stroke length, as fractions of `size`.
    :param min_max_start: bounds for the stroke start coordinates, as fractions of `size`.
    :param min_max_width: bounds for the line width, as fractions of `size`.
    :param max_offset: maximal global offset per image, as a fraction of `size`.
    :param max_target_offset: maximal perturbation of the stroke end point, as a fraction of `size`.
    :return: list of zero-argument callables, each returning a PIL image.
    """
    len_lo, len_hi = int(size * min_max_len[0]), int(size * min_max_len[1])
    start_lo, start_hi = int(size * min_max_start[0]), int(size * min_max_start[1])

    def random_length():
        return random.randint(len_lo, len_hi)

    def random_start():
        return (random.randint(start_lo, start_hi), random.randint(start_lo, start_hi))

    classes = []
    for _ in range(num_classes):
        num_strokes = random.randint(*min_max_strokes)
        len_strokes = [random_length() for _ in range(num_strokes)]
        stroke_start_points = [random_start() for _ in range(num_strokes)]
        stroke_directions = []
        for stroke_idx in range(num_strokes):
            attempts = 0
            while True:
                # Start point and length are re-drawn on the first and then on every third attempt,
                # directions on every attempt, until the stroke ends inside the image.
                if attempts % 3 == 0:
                    length = random_length()
                    sp = random_start()
                    stroke_start_points[stroke_idx], len_strokes[stroke_idx] = sp, length
                radians = random.random() * 2 * math.pi
                x_vel = math.cos(radians) * length
                y_vel = math.sin(radians) * length
                new_p = (sp[0] + x_vel, sp[1] + y_vel)
                if not any(n > size - 1 or n < 0 for n in new_p):
                    break
                attempts += 1
            stroke_directions.append(radians)
        classes.append((len_strokes, stroke_start_points, stroke_directions))

    generator_functions = []
    for c in classes:
        # c=c binds the current class; a plain closure would see only the last one.
        def g(c=c):
            len_strokes, stroke_start_points, stroke_directions = c
            canvas = Image.fromarray(np.zeros((size, size), dtype=np.uint8))
            draw = ImageDraw.Draw(canvas)
            width = random.randint(int(size * min_max_width[0]), int(size * min_max_width[1]))
            offset = (random.randint(int(-size * max_offset), int(size * max_offset)),
                      random.randint(int(-size * max_offset), int(size * max_offset)))
            for sp, length, radians in zip(stroke_start_points, len_strokes, stroke_directions):
                sp = (sp[0] + offset[0], sp[1] + offset[1])
                x_vel = math.cos(radians) * length + random.randint(int(-size * max_target_offset), int(size * max_target_offset))
                y_vel = math.sin(radians) * length + random.randint(int(-size * max_target_offset), int(size * max_target_offset))
                new_p = (sp[0] + x_vel, sp[1] + y_vel)
                # The class definition is read-only here; it used to grow by one direction per drawn stroke.
                draw.line([round(x) for x in sp + new_p], fill=128, width=width)
            pixels = np.array(canvas)
            # Replace the flat stroke colour by random bright values.
            pixels[pixels == 128] = np.random.randint(200, 255, size=pixels.shape)[pixels == 128]
            return Image.fromarray(pixels).filter(ImageFilter.GaussianBlur(.2))

        generator_functions.append(g)
    return generator_functions
|
64 |
+
|
65 |
+
|
66 |
+
# g1,g2 = mnist_prior(2)
|
67 |
+
|
68 |
+
# for i in [g1() for _ in range(10)]:
|
69 |
+
# display(i.resize((200,200)))
|
70 |
+
|
71 |
+
from torchvision.transforms import ToTensor, ToPILImage
|
72 |
+
|
73 |
+
|
74 |
+
def normalize(x):
    """Standardise x with its global mean and standard deviation (plus a small epsilon)."""
    centered = x - x.mean()
    scale = x.std() + .000001
    return centered / scale
|
76 |
+
|
77 |
+
from os import path, listdir
|
78 |
+
import random
|
79 |
+
|
80 |
+
def get_batch(batch_size, seq_len, num_features=None, noisy_std=None, only_train_for_last_idx=False, normalize_x=False, num_outputs=2, use_saved_from=None, **kwargs): # num_features = 28*28=784
    """
    Sample a batch of stroke-image classification datasets.

    :param batch_size: number of datasets.
    :param seq_len: number of images per dataset.
    :param num_features: number of pixels per image; has to be a square number.
    :param noisy_std: unused.
    :param only_train_for_last_idx: if True every class appears equally often among the first `seq_len - 1`
        images and only the last position carries a target (all others are -100).
    :param normalize_x: standardise every image with `normalize`.
    :param num_outputs: number of classes per dataset.
    :param use_saved_from: if given, a random file from the matching sub-directory is loaded and returned instead
        of sampling. NOTE: torch.load unpickles the file, so only point this at trusted directories.
    :param kwargs: forwarded to `mnist_prior`.
    :return: `(x, y, target_y)` with x of shape (seq_len, batch_size, num_features).
    """
    if use_saved_from is not None:
        directory = path.join(use_saved_from, f'len_{seq_len}_out_{num_outputs}_features_{num_features}_bs_{batch_size}')
        chosen_file = random.choice(listdir(directory))
        return torch.load(path.join(directory, chosen_file))

    size = math.isqrt(num_features)
    assert size * size == num_features, 'num_features needs to be the square of an integer.'
    if only_train_for_last_idx:
        assert (seq_len-1) % num_outputs == 0

    xs, ys, targets = [], [], []
    for _ in range(batch_size):
        gs = mnist_prior(num_outputs, size, **kwargs)
        if only_train_for_last_idx:
            repeats = (seq_len - 1) // num_outputs
            class_ids = [class_id for class_id in range(len(gs)) for _ in range(repeats)]
            random.shuffle(class_ids)
            class_ids.append(random.randint(0, len(gs) - 1))
            target = [-100] * (len(class_ids) - 1) + [class_ids[-1]]
        else:
            class_ids = [random.randint(0, len(gs) - 1) for _ in range(seq_len)]
            target = class_ids
        images = [ToTensor()(gs[class_id]()) for class_id in class_ids]
        if normalize_x:
            images = [normalize(image) for image in images]
        xs.append(torch.cat(images, 0))
        ys.append(torch.tensor(class_ids))
        targets.append(torch.tensor(target))
    x = torch.stack(xs, 1).view(seq_len, batch_size, -1)
    y = torch.stack(ys, 1)
    target_y = torch.stack(targets, 1)
    return x,y,target_y
|
115 |
+
|
116 |
+
DataLoader = get_batch_to_dataloader(get_batch)
|
117 |
+
DataLoader.num_outputs = 2
|
118 |
+
|
119 |
+
if __name__ == '__main__':
    # Smoke test: build generators and one small batch, and convert every sample back to an image.
    g1, g2 = mnist_prior(2, size=3)

    size = 10
    # get_batch returns three values: inputs, labels and targets.
    x, y, target_y = get_batch(1, 10, num_features=size * size)

    x = x.squeeze(1)  # drop the batch dimension of size 1
    y = y.squeeze(1)
    target_y = target_y.squeeze(1)

    for flat_image, label in zip(x, y):
        # Every row holds all size * size pixels of one image.
        img = ToPILImage()(flat_image.view(size, size))
        # display(img.resize((200,200)))

    print(y, target_y)
|
lcpfn/priors/utils.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from lcpfn.utils import set_locals_in_self
|
7 |
+
from itertools import repeat
|
8 |
+
from .prior import PriorDataLoader
|
9 |
+
from torch import nn
|
10 |
+
import numpy as np
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
import matplotlib.gridspec as gridspec
|
13 |
+
import scipy.stats as stats
|
14 |
+
import math
|
15 |
+
|
16 |
+
def get_batch_to_dataloader(get_batch_method_):
    """
    Wrap a `get_batch` function into a `PriorDataLoader` subclass.

    :param get_batch_method_: function returning `(x, y, target_y)` or `(x, y, target_y, style)`.
    :return: the data loader class (not an instance).
    """
    class DL(PriorDataLoader):
        get_batch_method = get_batch_method_

        # Caution, you might need to set self.num_features manually if it is not part of the args.
        def __init__(self, num_steps, **get_batch_kwargs):
            # Has to stay the first statement: it receives exactly the locals that exist at this point.
            set_locals_in_self(locals())

            # The fallback on the right-hand side is a class attribute set before instantiation.
            self.num_features = get_batch_kwargs.get('num_features') or self.num_features
            print('DataLoader.__dict__', self.__dict__)

        @staticmethod
        def gbm(*args, eval_pos_seq_len_sampler, **kwargs):
            single_eval_pos, seq_len = eval_pos_seq_len_sampler()
            kwargs['single_eval_pos'] = single_eval_pos
            kwargs['seq_len'] = seq_len
            # Scales the batch size dynamically with the power of 'dynamic_batch_size'.
            # A transformer with quadratic memory usage in the seq len would need a power of 2 to keep memory constant.
            if 'dynamic_batch_size' in kwargs and kwargs['dynamic_batch_size'] > 0:
                power = kwargs['dynamic_batch_size']
                scale = math.floor(math.pow(kwargs['seq_len_maximum'], power) / math.pow(kwargs['seq_len'], power))
                kwargs['batch_size'] = kwargs['batch_size'] * scale
            batch = get_batch_method_(*args, **kwargs)
            if len(batch) == 4:
                x, y, target_y, style = batch
            else:
                x, y, target_y = batch[0], batch[1], batch[2]
                style = None
            return (style, x, y), target_y, kwargs['single_eval_pos']

        def __len__(self):
            return self.num_steps

        def __iter__(self):
            # A fresh batch is generated lazily for every step.
            return (self.gbm(**self.get_batch_kwargs) for _ in range(self.num_steps))

    return DL
|
46 |
+
|
47 |
+
"""
|
48 |
+
import seaborn as sns
|
49 |
+
def plot_features(data, targets, fig=None):
|
50 |
+
if torch.is_tensor(data):
|
51 |
+
data = data.detach().cpu().numpy()
|
52 |
+
targets = targets.detach().cpu().numpy()
|
53 |
+
fig2 = plt.figure(figsize=(8, 8))
|
54 |
+
spec2 = gridspec.GridSpec(ncols=data.shape[1], nrows=data.shape[1], figure=fig2)
|
55 |
+
for d in range(0, data.shape[1]):
|
56 |
+
for d2 in range(0, data.shape[1]):
|
57 |
+
sub_ax = fig2.add_subplot(spec2[d, d2])
|
58 |
+
if d == d2:
|
59 |
+
sns.kdeplot(data[:, d],hue=targets[:],ax=sub_ax,legend=False, palette="deep")
|
60 |
+
sub_ax.set(ylabel=None)
|
61 |
+
else:
|
62 |
+
sns.scatterplot(data[:, d], data[:, d2],
|
63 |
+
hue=targets[:],legend=False, palette="deep")
|
64 |
+
#plt.scatter(data[:, d], data[:, d2],
|
65 |
+
# c=targets[:])
|
66 |
+
sub_ax.get_xaxis().set_ticks([])
|
67 |
+
sub_ax.get_yaxis().set_ticks([])
|
68 |
+
plt.subplots_adjust(wspace=0.05, hspace=0.05)
|
69 |
+
fig2.show()
|
70 |
+
|
71 |
+
|
72 |
+
def plot_prior(prior):
|
73 |
+
s = np.array([prior() for _ in range(0, 1000)])
|
74 |
+
count, bins, ignored = plt.hist(s, 50, density=True)
|
75 |
+
print(s.min())
|
76 |
+
plt.show()
|
77 |
+
"""
|
78 |
+
|
79 |
+
def trunc_norm_sampler_f(mu, sigma):
    """Return a sampler drawing from a normal(mu, sigma) truncated to [0, 1000000]."""
    def sample():
        lower = (0 - mu) / sigma
        upper = (1000000 - mu) / sigma
        return stats.truncnorm(lower, upper, loc=mu, scale=sigma).rvs(1)[0]
    return sample


def beta_sampler_f(a, b):
    """Return a sampler drawing from ``np.random.beta(a, b)``."""
    def sample():
        return np.random.beta(a, b)
    return sample


def gamma_sampler_f(a, b):
    """Return a sampler drawing from ``np.random.gamma(a, b)``."""
    def sample():
        return np.random.gamma(a, b)
    return sample


def uniform_sampler_f(a, b):
    """Return a sampler drawing from ``np.random.uniform(a, b)``."""
    def sample():
        return np.random.uniform(a, b)
    return sample


def uniform_int_sampler_f(a, b):
    """Return a sampler drawing a uniform value in [a, b) rounded with ``round``."""
    def sample():
        return round(np.random.uniform(a, b))
    return sample
|
84 |
+
def zipf_sampler_f(a, b, c):
    """Return a sampler over ``np.arange(b, c)`` with weights proportional to ``x ** -a``.

    Args:
        a: Exponent of the power law (integer or float).
        b: First value of the support (inclusive).
        c: End of the support (exclusive).

    Returns:
        A zero-argument callable that draws from the distribution.

    Raises:
        ValueError: If ``np.arange(b, c)`` is empty.
    """
    support = np.arange(b, c)
    if support.size == 0:
        raise ValueError(f"zipf_sampler_f needs a non-empty support, got b={b}, c={c}")

    # Compute the weights in floating point: NumPy refuses to raise an integer
    # array to a negative integer power, and the in-place normalisation below
    # needs a float array as well.
    weights = support.astype(float) ** (-a)
    weights /= weights.sum()

    # Build the distribution once instead of on every draw.
    dist = stats.rv_discrete(name='bounded_zipf', values=(support, weights))

    # NOTE(review): the positional ``1`` is kept exactly as in the original
    # call. For a value-based rv_discrete it appears to bind to ``loc`` rather
    # than ``size`` - confirm against the SciPy documentation before changing.
    return lambda: dist.rvs(1)
|
89 |
+
def scaled_beta_sampler_f(a, b, scale, minimum):
    """Return a sampler of ``minimum + round(beta(a, b) * (scale - minimum))``."""
    def sample():
        fraction = beta_sampler_f(a, b)()
        return minimum + round(fraction * (scale - minimum))
    return sample
|
90 |
+
|
91 |
+
|
92 |
+
def normalize_by_used_features_f(x, num_features_used, num_features, normalize_with_sqrt=False):
    """Divide ``x`` by the fraction of used features, or by its square root.

    Args:
        x: Value to rescale.
        num_features_used: Number of features actually in use.
        num_features: Total number of features.
        normalize_with_sqrt: If true, divide by the square root of the fraction.

    Returns:
        The rescaled value.
    """
    divisor = num_features_used / num_features
    if normalize_with_sqrt:
        divisor = divisor ** (1 / 2)
    return x / divisor
|
96 |
+
|
97 |
+
|
98 |
+
def order_by_y(x, y):
    """Reorder ``x`` and ``y`` along dim 0 by the values of ``y[:, 0, 0]``.

    The sort direction (ascending or descending) is chosen at random. The
    sorted index sequence is then split into two halves that are interleaved,
    so the length along dim 0 has to be even for the reshape to succeed.

    Returns:
        The reordered ``(x, y)`` pair.
    """
    sort_key = y if random.randint(0, 1) else -y
    ranking = torch.argsort(sort_key, dim=0)[:, 0, 0]
    # [first half | second half] -> first[0], second[0], first[1], second[1], ...
    interleaved = ranking.reshape(2, -1).transpose(0, 1).reshape(-1)
    return x[interleaved], y[interleaved]
|
105 |
+
|
106 |
+
def randomize_classes(x, num_classes):
    """Relabel class ids in ``x`` through a random permutation of ``0..num_classes-1``.

    Entries of ``x`` that match none of the ids ``0..num_classes-1`` come out
    as 0, because nothing is selected for them in the final sum.
    """
    class_ids = torch.arange(0, num_classes, device=x.device)
    new_labels = torch.randperm(num_classes, device=x.device).type(x.type())
    # Boolean one-hot over the class axis, weighted by the permuted labels.
    membership = x.unsqueeze(-1) == class_ids
    return (membership * new_labels).sum(-1)
|
111 |
+
|
112 |
+
|
113 |
+
class CategoricalActivation(nn.Module):
    """Activation that turns a random subset of hidden units into discrete classes.

    After a softsign, each (batch, hidden) column is selected with probability
    ``categorical_p``; selected columns are replaced by the count of sampled
    boundaries they exceed, shifted by ``num_classes / 2``. A further random
    subset of the selected columns (probability ``ordered_p``) is passed
    through ``randomize_classes``.
    """

    def __init__(
        self,
        categorical_p=0.1,
        ordered_p=0.7,
        keep_activation_size=False,
        num_classes_sampler=zipf_sampler_f(0.8, 1, 10),
    ):
        super().__init__()
        self.categorical_p = categorical_p
        self.ordered_p = ordered_p
        self.keep_activation_size = keep_activation_size
        self.num_classes_sampler = num_classes_sampler

    def forward(self, x):
        # x shape: T, B, H
        x = nn.Softsign()(x)

        num_classes = self.num_classes_sampler()
        if self.keep_activation_size:
            hid_strength = torch.abs(x).mean(0).unsqueeze(0)
        else:
            hid_strength = None

        n_rows, n_batch, n_hidden = x.shape[0], x.shape[1], x.shape[2]
        n_boundaries = num_classes - 1

        categorical_mask = torch.rand((n_batch, n_hidden)) < self.categorical_p
        boundaries = torch.zeros(
            (n_boundaries, n_batch, n_hidden), device=x.device, dtype=x.dtype
        )
        # Boundaries are activations taken at randomly drawn rows; a fresh set
        # of row indices is drawn for every (batch, hidden) pair.
        for batch_idx in range(n_batch):
            for hidden_idx in range(n_hidden):
                rows = torch.randint(0, n_rows, (n_boundaries,))
                boundaries[:, batch_idx, hidden_idx] = x[rows, batch_idx, hidden_idx]

        for batch_idx in range(n_batch):
            selected = categorical_mask[batch_idx]
            values = x[:, batch_idx, selected]
            thresholds = boundaries[:, batch_idx, selected].unsqueeze(1)
            exceeded = (values > thresholds).sum(dim=0).float()
            x[:, batch_idx, selected] = exceeded - num_classes / 2

        ordered_mask = torch.rand((n_batch, n_hidden)) < self.ordered_p
        ordered_mask = torch.logical_and(ordered_mask, categorical_mask)
        x[:, ordered_mask] = randomize_classes(x[:, ordered_mask], num_classes)

        if self.keep_activation_size:
            x = x * hid_strength

        return x
|
lcpfn/train.py
ADDED
@@ -0,0 +1,602 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import itertools
|
3 |
+
import argparse
|
4 |
+
import time
|
5 |
+
import datetime
|
6 |
+
import yaml
|
7 |
+
from contextlib import nullcontext
|
8 |
+
|
9 |
+
import pickle
|
10 |
+
import torch
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
from lcpfn import utils
|
14 |
+
from lcpfn.transformer import TransformerModel
|
15 |
+
from lcpfn.bar_distribution import (
|
16 |
+
BarDistribution,
|
17 |
+
FullSupportBarDistribution,
|
18 |
+
get_bucket_limits,
|
19 |
+
)
|
20 |
+
from lcpfn.utils import (
|
21 |
+
get_cosine_schedule_with_warmup,
|
22 |
+
get_openai_lr,
|
23 |
+
StoreDictKeyPair,
|
24 |
+
get_weighted_single_eval_pos_sampler,
|
25 |
+
get_uniform_single_eval_pos_sampler,
|
26 |
+
)
|
27 |
+
from lcpfn import priors
|
28 |
+
from lcpfn import encoders
|
29 |
+
from lcpfn import positional_encodings
|
30 |
+
from lcpfn.utils import init_dist
|
31 |
+
from torch.cuda.amp import autocast, GradScaler
|
32 |
+
|
33 |
+
|
34 |
+
class Losses:
    """Namespace of loss criteria, all configured with ``reduction="none"``."""

    gaussian = nn.GaussianNLLLoss(full=True, reduction="none")
    mse = nn.MSELoss(reduction="none")

    @staticmethod
    def ce(num_classes):
        """Return a cross-entropy loss with uniform weights over ``num_classes`` classes.

        Declared as a static method so that it behaves the same whether it is
        looked up on the class or on an instance; a plain function attribute
        would receive the instance as ``num_classes`` in the latter case.
        """
        return nn.CrossEntropyLoss(reduction="none", weight=torch.ones(num_classes))

    bce = nn.BCEWithLogitsLoss(reduction="none")
    get_BarDistribution = BarDistribution
|
42 |
+
|
43 |
+
|
44 |
+
def train(
    priordataloader_class,
    criterion,
    encoder_generator,
    emsize=200,
    nhid=200,
    nlayers=6,
    nhead=2,
    dropout=0.2,
    epochs=10,
    steps_per_epoch=100,
    batch_size=200,
    bptt=10,
    lr=None,
    weight_decay=0.0,
    warmup_epochs=10,
    input_normalization=False,
    y_encoder_generator=None,
    pos_encoder_generator=None,
    decoder=None,
    extra_prior_kwargs_dict=None,
    scheduler=get_cosine_schedule_with_warmup,
    load_weights_from_this_state_dict=None,
    validation_period=10,
    single_eval_pos_gen=None,
    bptt_extra_samples=None,
    gpu_device="cuda:0",
    aggregate_k_gradients=1,
    verbose=True,
    style_encoder_generator=None,
    epoch_callback=None,
    initializer=None,
    initialize_with_model=None,
    train_mixed_precision=False,
    saving_period=10,
    checkpoint_file=None,
    load_optimizer_from_this_state_dict=None,
    output_path=None,
    **model_extra_args,
):
    """Build a ``TransformerModel`` and train it on batches from a prior data loader.

    Args:
        priordataloader_class: Data loader class; instantiated with
            ``num_steps``, ``batch_size``, ``eval_pos_seq_len_sampler``,
            ``seq_len_maximum``, ``device`` and ``extra_prior_kwargs_dict``.
        criterion: Loss object. Its type determines the number of model outputs.
        encoder_generator: Called as ``encoder_generator(dl.num_features, emsize)``.
        emsize, nhid, nlayers, nhead, dropout: Forwarded to ``TransformerModel``.
        epochs: Number of epochs; ``None`` trains until interrupted.
        steps_per_epoch: Batches per epoch; must be divisible by
            ``aggregate_k_gradients``.
        batch_size: Batch size handed to the data loader.
        bptt: Sequence length used when ``bptt_extra_samples`` is not set.
        lr: Learning rate; ``None`` selects ``get_openai_lr(model)``.
        weight_decay: Weight decay for AdamW.
        warmup_epochs: Forwarded to ``scheduler``.
        input_normalization, decoder, initializer: Forwarded to ``TransformerModel``.
        y_encoder_generator: Called as ``y_encoder_generator(1, emsize)``.
        pos_encoder_generator: Positional encoding class; defaults to
            ``positional_encodings.NoPositionalEncoding``.
        extra_prior_kwargs_dict: Extra keyword arguments for the data loader.
        scheduler: Factory called as ``scheduler(optimizer, warmup_epochs, epochs)``.
        load_weights_from_this_state_dict: Optional model state dict to load.
        validation_period: Validate every this many epochs if the loader has
            a ``validate`` method.
        single_eval_pos_gen: Callable returning the evaluation position, or a
            fixed value (including ``None``) that is used for every batch.
        bptt_extra_samples: If set, sequence length is the evaluation position
            plus this value.
        gpu_device: Device used when CUDA is available.
        aggregate_k_gradients: Number of batches to accumulate per optimizer step.
        verbose: Print a summary line after every epoch.
        style_encoder_generator: Used to build a style encoder when the loader
            yields a style.
        epoch_callback: Called on rank 0 as ``epoch_callback(model, epoch / epochs)``.
        initialize_with_model: Optional model passed to ``init_from_small_model``.
        train_mixed_precision: Enable autocast and gradient scaling.
        saving_period: Write a checkpoint every this many epochs.
        checkpoint_file: Checkpoint path; ``None`` disables checkpointing.
        load_optimizer_from_this_state_dict: Optional optimizer state dict to load.
        output_path: If set, the final model is saved here on rank 0.
        **model_extra_args: Forwarded to ``TransformerModel``.

    Returns:
        On rank 0: ``(total_loss, total_positional_losses, model, dl)`` with
        the model moved to the CPU; the first two hold the values of the last
        finished epoch (``inf`` if none finished). Other ranks return ``None``.
    """
    if extra_prior_kwargs_dict is None:
        extra_prior_kwargs_dict = {}

    device = gpu_device if torch.cuda.is_available() else "cpu:0"
    print(f"Using {device} device")
    using_dist, rank, device = init_dist(device)

    if not callable(single_eval_pos_gen):
        # Bind the fixed value to its own name: a lambda that refers to
        # `single_eval_pos_gen` would see the rebound variable and return
        # itself instead of the value.
        fixed_single_eval_pos = single_eval_pos_gen

        def single_eval_pos_gen():
            return fixed_single_eval_pos

    def eval_pos_seq_len_sampler():
        single_eval_pos = single_eval_pos_gen()
        if bptt_extra_samples:
            return single_eval_pos, single_eval_pos + bptt_extra_samples
        else:
            return single_eval_pos, bptt

    dl = priordataloader_class(
        num_steps=steps_per_epoch,
        batch_size=batch_size,
        eval_pos_seq_len_sampler=eval_pos_seq_len_sampler,
        seq_len_maximum=bptt + (bptt_extra_samples if bptt_extra_samples else 0),
        device=device,
        **extra_prior_kwargs_dict,
    )

    encoder = encoder_generator(dl.num_features, emsize)
    # A batch is ((style, x, y), target, single_eval_pos); take the style.
    style_def = next(iter(dl))[0][0]
    print(f"Style definition: {style_def}")
    style_encoder = (
        style_encoder_generator(hyperparameter_definitions=style_def[0], em_size=emsize)
        if (style_def is not None)
        else None
    )
    if isinstance(criterion, nn.GaussianNLLLoss):
        n_out = 2
    elif (
        isinstance(criterion, BarDistribution)
        or "BarDistribution" in criterion.__class__.__name__
    ):  # TODO remove this fix (only for dev)
        n_out = criterion.num_bars
    elif isinstance(criterion, nn.CrossEntropyLoss):
        n_out = criterion.weight.shape[0]
    else:
        n_out = 1
    model = TransformerModel(
        encoder,
        n_out,
        emsize,
        nhead,
        nhid,
        nlayers,
        dropout,
        style_encoder=style_encoder,
        y_encoder=y_encoder_generator(1, emsize),
        input_normalization=input_normalization,
        pos_encoder=(
            pos_encoder_generator or positional_encodings.NoPositionalEncoding
        )(emsize, bptt * 2),
        decoder=decoder,
        init_method=initializer,
        **model_extra_args,
    )
    model.criterion = criterion
    if load_weights_from_this_state_dict is not None:
        model.load_state_dict(load_weights_from_this_state_dict)
    if initialize_with_model is not None:
        model.init_from_small_model(initialize_with_model)

    print(
        f"Using a Transformer with {sum(p.numel() for p in model.parameters())/1000/1000:.{2}f} M parameters"
    )

    # Best-effort debug output comparing the weights with `initialize_with_model`;
    # any failure (including `initialize_with_model` being None) is ignored.
    try:
        for (k, v), (k2, v2) in zip(
            model.state_dict().items(), initialize_with_model.state_dict().items()
        ):
            print(k, ((v - v2) / v).abs().mean(), v.shape)
    except Exception:
        pass

    model.to(device)
    if using_dist:
        print("Distributed training")
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[rank], output_device=rank, broadcast_buffers=False
        )

    # learning rate
    if lr is None:
        lr = get_openai_lr(model)
        print(f"Using OpenAI max lr of {lr}.")
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = scheduler(
        optimizer, warmup_epochs, epochs if epochs is not None else 100
    )  # when training for fixed time lr schedule takes 100 steps

    if load_optimizer_from_this_state_dict is not None:
        optimizer.load_state_dict(load_optimizer_from_this_state_dict)
    scaler = GradScaler() if train_mixed_precision else None

    # check that everything uses up-to-date APIs
    utils.check_compatibility(dl)

    def train_epoch():
        """Run one pass over ``dl`` and return loss and timing statistics."""
        model.train()  # Turn on the train mode
        total_loss = 0.0
        total_positional_losses = 0.0
        total_positional_losses_recorded = 0
        before_get_batch = time.time()
        assert (
            len(dl) % aggregate_k_gradients == 0
        ), "Please set the number of steps per epoch s.t. `aggregate_k_gradients` divides it."
        for batch, (data, targets, single_eval_pos) in enumerate(dl):
            # Skip gradient synchronisation on all but the last accumulation step.
            if using_dist and not (
                batch % aggregate_k_gradients == aggregate_k_gradients - 1
            ):
                cm = model.no_sync()
            else:
                cm = nullcontext()
            with cm:
                time_to_get_batch = time.time() - before_get_batch
                before_forward = time.time()

                with autocast(enabled=scaler is not None):
                    # If style is set to None, it should not be transferred to device
                    output = model(
                        tuple(e.to(device) if torch.is_tensor(e) else e for e in data)
                        if isinstance(data, tuple)
                        else data.to(device),
                        single_eval_pos=single_eval_pos,
                    )

                    forward_time = time.time() - before_forward

                    if single_eval_pos is not None:
                        targets = targets[single_eval_pos:]
                    if isinstance(criterion, nn.GaussianNLLLoss):
                        assert (
                            output.shape[-1] == 2
                        ), "need to write a little bit of code to handle multiple regression targets at once"

                        mean_pred = output[..., 0]
                        var_pred = output[..., 1].abs()
                        losses = criterion(
                            mean_pred.flatten(),
                            targets.to(device).flatten(),
                            var=var_pred.flatten(),
                        )
                    elif isinstance(criterion, (nn.MSELoss, nn.BCEWithLogitsLoss)):
                        losses = criterion(
                            output.flatten(), targets.to(device).flatten()
                        )
                    elif isinstance(criterion, nn.CrossEntropyLoss):
                        losses = criterion(
                            output.reshape(-1, n_out),
                            targets.to(device).long().flatten(),
                        )
                    else:
                        losses = criterion(output, targets)
                    losses = losses.view(*output.shape[0:2])
                    loss = losses.mean() / aggregate_k_gradients

                if scaler:
                    loss = scaler.scale(loss)
                loss.backward()

                if batch % aggregate_k_gradients == aggregate_k_gradients - 1:
                    if scaler:
                        scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    try:
                        if scaler:
                            scaler.step(optimizer)
                            scaler.update()
                        else:
                            optimizer.step()
                    except Exception:
                        # Keep training on a failed step, but let
                        # KeyboardInterrupt/SystemExit propagate.
                        print("Invalid optimization step encountered")
                    optimizer.zero_grad()

                step_time = time.time() - before_forward

                if not torch.isnan(loss):
                    total_loss += losses.mean().cpu().detach()
                    total_positional_losses += (
                        losses.mean(1).cpu().detach()
                        if single_eval_pos is None
                        else nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)
                        * losses[: bptt - single_eval_pos].mean().cpu().detach()
                    )

                    total_positional_losses_recorded += (
                        torch.ones(bptt)
                        if single_eval_pos is None
                        else nn.functional.one_hot(torch.tensor(single_eval_pos), bptt)
                    )

            before_get_batch = time.time()
        return (
            total_loss / steps_per_epoch,
            (total_positional_losses / total_positional_losses_recorded).tolist(),
            time_to_get_batch,
            forward_time,
            step_time,
        )

    total_loss = float("inf")
    total_positional_losses = float("inf")
    list_losses = []
    try:
        for epoch in range(1, epochs + 1) if epochs is not None else itertools.count(1):

            epoch_start_time = time.time()
            (
                total_loss,
                total_positional_losses,
                time_to_get_batch,
                forward_time,
                step_time,
            ) = train_epoch()
            list_losses.append(total_loss.item())
            if hasattr(dl, "validate") and epoch % validation_period == 0:
                with torch.no_grad():
                    val_score = dl.validate(model)

            else:
                val_score = None

            if epoch % saving_period == 0 and checkpoint_file is not None:
                checkpoint = {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "epoch": epoch,
                }
                torch.save(checkpoint, checkpoint_file)
                # Strip only the file extension; splitting on the first "."
                # would cut the path at any dot in a directory name.
                full_model_path = os.path.splitext(checkpoint_file)[0] + "_full_model.pt"
                torch.save(model, full_model_path)

            if verbose:
                print("-" * 89)
                print(
                    f"| end of epoch {epoch:3d} | time: {(time.time() - epoch_start_time):5.2f}s | mean loss {total_loss:5.2f} | "
                    f"pos losses {','.join([f'{l:5.2f}' for l in total_positional_losses])}, lr {scheduler.get_last_lr()[0]}"
                    f" data time {time_to_get_batch:5.2f} step time {step_time:5.2f}"
                    f" forward time {forward_time:5.2f}"
                    + (f"val score {val_score}" if val_score is not None else "")
                )
                print("-" * 89)

            # stepping with wallclock time based scheduler
            # NOTE(review): `epoch / epochs` raises TypeError when `epochs` is
            # None and a callback is set - confirm the intended behaviour.
            if epoch_callback is not None and rank == 0:
                epoch_callback(model, epoch / epochs)
            scheduler.step()
    except KeyboardInterrupt:
        pass

    if rank == 0:  # trivially true for non-parallel training
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            model = model.module
            dl = None
        if output_path is not None:
            torch.save(model.to("cpu"), output_path)
            print("Checkpoint stored at ", output_path)
        return total_loss, total_positional_losses, model.to("cpu"), dl
|
350 |
+
|
351 |
+
|
352 |
+
def _parse_args(config_parser, parser):
    """Parse the command line, optionally seeding defaults from a YAML config.

    Args:
        config_parser: Parser that only knows ``--config``.
        parser: Main parser; receives the arguments ``config_parser`` left over.

    Returns:
        ``(args, args_text)`` where ``args_text`` is a YAML dump of ``args``.
    """
    config_args, leftover = config_parser.parse_known_args()

    # Values from the config file replace the parser's built-in defaults;
    # explicit command-line arguments still take precedence.
    config_path = config_args.config
    if config_path:
        with open(config_path, "r") as handle:
            loaded = yaml.safe_load(handle)
        parser.set_defaults(**loaded)

    parsed = parser.parse_args(leftover)

    # Text form of the arguments, e.g. for saving alongside the outputs.
    dumped = yaml.safe_dump(parsed.__dict__, default_flow_style=False)
    return parsed, dumped
|
367 |
+
|
368 |
+
|
369 |
+
# Command-line entry point: build the prior, criterion and encoders from the
# arguments and hand everything to `train`.
if __name__ == "__main__":
    config_parser = argparse.ArgumentParser(
        description="Only used as a first parser for the config file path."
    )
    config_parser.add_argument("--config")
    parser = argparse.ArgumentParser()
    parser.add_argument("prior")
    parser.add_argument("--loss_function", default="barnll")
    # Optional Arg's for `--loss_function barnll`
    parser.add_argument(
        "--min_y",
        type=float,
        help="barnll can only model y in strict ranges, this is the minimum y can take.",
    )
    parser.add_argument(
        "--max_y",
        type=float,
        help="barnll can only model y in strict ranges, this is the maximum y can take.",
    )
    parser.add_argument("--num_buckets", default=100, type=int)
    # parser.add_argument('--num_features', default=None, type=int, help='Specify depending on the prior.')
    parser.add_argument(
        "--extra_prior_kwargs_dict",
        default={},
        dest="extra_prior_kwargs_dict",
        action=StoreDictKeyPair,
        nargs="+",
        metavar="KEY=VAL",
        help="Specify depending on the prior.",
    )
    parser.add_argument(
        "--encoder", default="linear", type=str, help="Specify depending on the prior."
    )
    parser.add_argument(
        "--y_encoder",
        default="linear",
        type=str,
        help="Specify depending on the prior. You should specify this if you do not fuse x and y.",
    )
    parser.add_argument(
        "--pos_encoder",
        default="none",
        type=str,
        help="Specify depending on the prior.",
    )
    parser.add_argument("--bptt", default=10, type=int)
    parser.add_argument("--epochs", default=200, type=int)
    parser.add_argument("--warmup_epochs", default=50, type=int)
    parser.add_argument("--validation_period", default=10, type=int)
    parser.add_argument(
        "--permutation_invariant_max_eval_pos",
        default=None,
        type=int,
        help="Set this to an int to ",
    )
    parser.add_argument(
        "--permutation_invariant_sampling",
        default="weighted",
        help="Only relevant if --permutation_invariant_max_eval_pos is set.",
    )
    parser.add_argument("--train_mixed_precision", action="store_true")

    # these can likely be mostly left at defaults
    parser.add_argument(
        "--emsize", default=512, type=int
    )  # sometimes even larger is better e.g. 1024
    parser.add_argument("--nlayers", default=6, type=int)
    parser.add_argument("--nhid", default=None, type=int)  # 2*emsize is the default
    parser.add_argument(
        "--nhead", default=4, type=int
    )  # nhead = emsize / 64 in the original paper
    parser.add_argument("--dropout", default=0.0, type=float)
    parser.add_argument("--steps_per_epoch", default=10, type=int)
    parser.add_argument("--batch_size", default=1000, type=int)
    parser.add_argument(
        "--lr", "--learning_rate", default=0.001, type=float
    )  # try also .0003, .0001, go lower with lower batch size
    parser.add_argument("--gpu_device", default="cuda", type=str)

    # for model checkpointing
    parser.add_argument(
        "--checkpoint_file",
        help="absolute or relative-to-the-project-rootdir path to the file storing the state dicts.",
        default=None,
        type=str,
    )
    # Must be an int: `train` evaluates `epoch % saving_period`.
    parser.add_argument("--saving_period", default=10, type=int)

    args, _ = _parse_args(config_parser, parser)

    if args.nhid is None:
        args.nhid = 2 * args.emsize

    # Everything popped from `args.__dict__` below is consumed here; whatever
    # remains is forwarded to `train` as keyword arguments.
    prior = args.__dict__.pop("prior")

    if prior == "gp":
        prior = priors.fast_gp.DataLoader
    elif prior == "ridge":
        prior = priors.ridge.DataLoader
    elif prior == "stroke":
        prior = priors.stroke.DataLoader
    elif prior == "mix_gp":
        prior = priors.fast_gp_mix.DataLoader
    else:
        raise NotImplementedError(f"Prior == {prior}.")

    loss_function = args.__dict__.pop("loss_function")

    criterion = nn.GaussianNLLLoss(reduction="none", full=True)
    num_buckets = args.__dict__.pop("num_buckets")
    max_y = args.__dict__.pop("max_y")
    min_y = args.__dict__.pop("min_y")
    # criterion = nn.MSELoss(reduction='none')

    device = args.gpu_device if torch.cuda.is_available() else "cpu:0"

    def get_y_sample():
        """Draw one large batch from the prior and return its targets."""
        # The sampler is only needed for this throw-away loader, so it is
        # removed from the shared kwargs dict again afterwards.
        args.__dict__["extra_prior_kwargs_dict"]["eval_pos_seq_len_sampler"] = lambda: (
            args.bptt,
            args.bptt,
        )
        dl = prior(
            num_steps=1,
            batch_size=args.batch_size * args.steps_per_epoch,
            seq_len=args.bptt,
            device=device,
            **args.extra_prior_kwargs_dict,
        )
        args.__dict__["extra_prior_kwargs_dict"].pop("eval_pos_seq_len_sampler")

        y_sample = next(iter(dl))[-2]
        print(
            f"Creating Bar distribution with borders from y sample of size {y_sample.numel()}"
        )
        return y_sample

    if loss_function == "ce":
        criterion = nn.CrossEntropyLoss(reduction="none")
    elif loss_function == "gaussnll":
        criterion = nn.GaussianNLLLoss(reduction="none", full=True)
    elif loss_function == "mse":
        criterion = nn.MSELoss(reduction="none")
    elif loss_function == "barnll":
        criterion = BarDistribution(
            borders=get_bucket_limits(num_buckets, full_range=(min_y, max_y))
        )
    elif loss_function == "adaptivebarnll":
        borders = get_bucket_limits(
            num_buckets, ys=get_y_sample(), full_range=(min_y, max_y)
        )
        criterion = BarDistribution(borders=borders)
    elif loss_function == "adaptivefullsupportbarnll":
        assert (
            min_y is None and max_y is None
        ), "Please do not specify `min_y` and `max_y` with `adaptivefullsupportbarnll`."
        borders = get_bucket_limits(num_buckets, ys=get_y_sample())
        criterion = FullSupportBarDistribution(borders=borders)
    else:
        raise NotImplementedError(f"loss_function == {loss_function}.")

    encoder = args.__dict__.pop("encoder")
    y_encoder = args.__dict__.pop("y_encoder")

    def get_encoder_generator(encoder):
        """Map an encoder name from the command line to its class."""
        if encoder == "linear":
            encoder_generator = encoders.Linear
        elif encoder == "mlp":
            encoder_generator = encoders.MLP
        elif encoder == "positional":
            encoder_generator = encoders.Positional
        else:
            raise NotImplementedError(f"A {encoder} encoder is not valid.")
        return encoder_generator

    encoder_generator = get_encoder_generator(encoder)
    y_encoder_generator = get_encoder_generator(y_encoder)

    pos_encoder = args.__dict__.pop("pos_encoder")

    if pos_encoder == "none":
        pos_encoder_generator = None
    elif pos_encoder == "sinus":
        pos_encoder_generator = positional_encodings.PositionalEncoding
    elif pos_encoder == "learned":
        pos_encoder_generator = positional_encodings.LearnedPositionalEncoding
    elif pos_encoder == "paired_scrambled_learned":
        pos_encoder_generator = positional_encodings.PairedScrambledPositionalEncodings
    else:
        raise NotImplementedError(f"pos_encoer == {pos_encoder} is not valid.")

    permutation_invariant_max_eval_pos = args.__dict__.pop(
        "permutation_invariant_max_eval_pos"
    )
    permutation_invariant_sampling = args.__dict__.pop("permutation_invariant_sampling")
    if permutation_invariant_max_eval_pos is not None:
        if permutation_invariant_sampling == "weighted":
            get_sampler = get_weighted_single_eval_pos_sampler
        elif permutation_invariant_sampling == "uniform":
            get_sampler = get_uniform_single_eval_pos_sampler
        else:
            raise ValueError()
        args.__dict__["single_eval_pos_gen"] = get_sampler(
            permutation_invariant_max_eval_pos
        )

    print("ARGS for `train`:", args.__dict__)

    if args.__dict__["checkpoint_file"] is not None:
        # A relative path is resolved against the directory of this file.
        rootdir = os.path.dirname(os.path.realpath(__file__))
        args.__dict__["checkpoint_file"] = os.path.join(
            rootdir, args.__dict__["checkpoint_file"]
        )

        if os.path.exists(args.__dict__["checkpoint_file"]):
            # NOTE: torch.load unpickles the file; only resume from
            # checkpoints that come from a trusted source.
            state_dicts = torch.load(args.__dict__["checkpoint_file"])
            args.__dict__["load_weights_from_this_state_dict"] = state_dicts[
                "model_state_dict"
            ]
            args.__dict__["load_optimizer_from_this_state_dict"] = state_dicts[
                "optimizer_state_dict"
            ]
        else:
            args.__dict__["load_weights_from_this_state_dict"] = None
            args.__dict__["load_optimizer_from_this_state_dict"] = None

    train(
        prior,
        criterion,
        encoder_generator,
        y_encoder_generator=y_encoder_generator,
        pos_encoder_generator=pos_encoder_generator,
        **args.__dict__,
    )
|
lcpfn/train_lcpfn.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
from lcpfn import bar_distribution, encoders, priors, train
|
6 |
+
from lcpfn import utils
|
7 |
+
|
8 |
+
|
9 |
+
def train_lcpfn(
    get_batch_func,
    seq_len: int = 100,
    emsize: int = 512,
    nlayers: int = 12,
    num_borders: int = 1000,
    lr: float = 0.001,
    batch_size: int = 100,
    epochs: int = 1000,
):
    """
    Train a LCPFN model using the specified hyperparameters.

    Args:
        get_batch_func (callable): A function that returns a batch of learning curves.
            It is called here once as
            ``get_batch_func(10_000, seq_len, 1, hyperparameters={}, single_eval_pos=seq_len)``
            and element ``[2]`` of its result is used to fit the bucket borders.
        seq_len (int, optional): The length of the input sequence. Defaults to 100.
        emsize (int, optional): The size of the embedding layer. Defaults to 512.
            The number of attention heads is derived as ``emsize // 128`` (at least 1).
        nlayers (int, optional): The number of layers in the model. Defaults to 12.
        num_borders (int, optional): The number of borders to use. Defaults to 1000.
        lr (float, optional): The learning rate for the optimizer. Defaults to 0.001.
        batch_size (int, optional): The batch size for training. Defaults to 100.
        epochs (int, optional): The number of epochs to train for. Defaults to 1000.

    Returns:
        The value returned by ``train.train`` for this configuration
        (described upstream as the trained model).
    """

    hps = {}

    # PFN training hyperparameters
    dataloader = priors.utils.get_batch_to_dataloader(get_batch_func)  # type: ignore

    num_features = 1

    # Draw one large batch up front; only its targets are needed to place the borders.
    ys = get_batch_func(
        10_000,
        seq_len,
        num_features,
        hyperparameters=hps,
        single_eval_pos=seq_len,
    )

    bucket_limits = bar_distribution.get_bucket_limits(num_borders, ys=ys[2])

    # Discretization of the predictive distributions
    criterion = bar_distribution.FullSupportBarDistribution(bucket_limits)

    config = dict(
        nlayers=nlayers,
        priordataloader_class=dataloader,
        criterion=criterion,
        encoder_generator=lambda in_dim, out_dim: nn.Sequential(
            encoders.Normalize(0.0, 101.0),
            encoders.Normalize(0.5, math.sqrt(1 / 12)),
            encoders.Linear(in_dim, out_dim),
        ),
        emsize=emsize,
        # Guard against emsize < 128, which would otherwise request zero heads.
        nhead=max(1, emsize // 128),
        warmup_epochs=(epochs // 4),
        y_encoder_generator=encoders.get_normalized_uniform_encoder(encoders.Linear),
        batch_size=batch_size,
        scheduler=utils.get_cosine_schedule_with_warmup,
        extra_prior_kwargs_dict={
            # "num_workers": 10,
            "num_features": num_features,
            "hyperparameters": {
                **hps,
            },
        },
        epochs=epochs,
        lr=lr,
        bptt=seq_len,
        single_eval_pos_gen=utils.get_uniform_single_eval_pos_sampler(seq_len, min_len=1),
        aggregate_k_gradients=1,
        nhid=(emsize * 2),
        steps_per_epoch=100,
        train_mixed_precision=False,
    )

    return train.train(**config)
|
lcpfn/transformer.py
ADDED
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch import Tensor
|
7 |
+
from torch.nn import Module, TransformerEncoder
|
8 |
+
|
9 |
+
from lcpfn.layer import TransformerEncoderLayer, _get_activation_fn
|
10 |
+
from lcpfn.utils import SeqBN, bool_mask_to_att_mask
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
class TransformerModel(nn.Module):
    """Transformer encoder over ``(x, y)`` or ``(style, x, y)`` sequence inputs.

    In ``forward`` the first ``single_eval_pos`` sequence positions are fed as
    the sum of the encoded ``x`` and encoded ``y``; the remaining positions are
    fed as encoded ``x`` only, and the decoder output for those remaining
    positions is returned.
    """

    def __init__(self, encoder, n_out, ninp, nhead, nhid, nlayers, dropout=0.0, style_encoder=None, y_encoder=None,
                 pos_encoder=None, decoder=None, input_normalization=False, init_method=None, pre_norm=False,
                 activation='gelu', recompute_attn=False, num_global_att_tokens=0, full_attention=False,
                 all_layers_same_init=True):
        """Assemble the encoder stack, the decoder and the optional components.

        Args:
            encoder: module applied to ``x`` in ``forward``.
            n_out: output width of the default decoder (also passed to a custom ``decoder`` factory).
            ninp: model width handed to every ``TransformerEncoderLayer``.
            nhead, nhid, dropout, activation, pre_norm, recompute_attn: forwarded to
                ``TransformerEncoderLayer``.
            nlayers: number of encoder layers.
            style_encoder: optional module applied to the style input.
            y_encoder: module applied to ``y`` in ``forward``.
            pos_encoder: optional module applied to the assembled sequence.
            decoder: optional factory called as ``decoder(ninp, nhid, n_out)``; when ``None`` a
                Linear-GELU-Linear stack is used.
            input_normalization: when true, a ``SeqBN(ninp)`` is applied to the assembled sequence.
            init_method: optional callable passed to ``self.apply`` in ``init_weights``.
            num_global_att_tokens: number of learned global tokens; falsy disables them.
            full_attention: when true (and no global tokens), an all-ones attention mask is used.
            all_layers_same_init: selects ``torch.nn.TransformerEncoder`` (one layer object passed in)
                versus ``TransformerEncoderDiffInit`` (one factory call per layer).
        """
        super().__init__()
        self.model_type = 'Transformer'
        encoder_layer_creator = lambda: TransformerEncoderLayer(ninp, nhead, nhid, dropout, activation=activation,
                                                                pre_norm=pre_norm, recompute_attn=recompute_attn)
        self.transformer_encoder = TransformerEncoder(encoder_layer_creator(), nlayers)\
            if all_layers_same_init else TransformerEncoderDiffInit(encoder_layer_creator, nlayers)
        self.ninp = ninp
        self.encoder = encoder
        self.y_encoder = y_encoder
        self.pos_encoder = pos_encoder
        self.decoder = decoder(ninp, nhid, n_out) if decoder is not None else nn.Sequential(nn.Linear(ninp, nhid), nn.GELU(), nn.Linear(nhid, n_out))
        self.input_ln = SeqBN(ninp) if input_normalization else None
        self.style_encoder = style_encoder
        self.init_method = init_method
        # NOTE(review): the default num_global_att_tokens=0 is "not None", so full_attention=True
        # only passes this assertion when num_global_att_tokens=None is given explicitly.
        if num_global_att_tokens is not None:
            assert not full_attention
        self.global_att_embeddings = nn.Embedding(num_global_att_tokens, ninp) if num_global_att_tokens else None
        self.full_attention = full_attention

        self.n_out = n_out
        self.nhid = nhid

        self.init_weights()

    @staticmethod
    def generate_square_subsequent_mask(sz):
        """Return an ``sz x sz`` attention mask that is 0 on and below the diagonal and -inf above it."""
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        return bool_mask_to_att_mask(mask)

    @staticmethod
    def generate_D_q_matrix(sz, query_size):
        """Return an ``sz x sz`` attention mask for a train/query split.

        Every row may attend to the first ``sz - query_size`` columns and to its
        own position; all other entries are -inf.
        """
        train_size = sz-query_size
        # Start from an all-True boolean matrix.
        mask = torch.zeros(sz,sz) == 0
        # Forbid attention to the query columns ...
        mask[:,train_size:].zero_()
        # ... except for each position attending to itself.
        mask |= torch.eye(sz) == 1
        return bool_mask_to_att_mask(mask)

    @staticmethod
    def generate_global_att_query_matrix(num_global_att_tokens, seq_len, num_query_tokens):
        """Return the ``num_query_tokens x (seq_len + num_global_att_tokens)`` mask for query rows.

        Query rows may attend to the first ``train_size`` columns and, within the
        trailing query columns, only to their own position.
        """
        train_size = seq_len + num_global_att_tokens - num_query_tokens
        sz = seq_len + num_global_att_tokens
        mask = torch.zeros(num_query_tokens, sz) == 0
        mask[:,train_size:].zero_()
        mask[:,train_size:] |= torch.eye(num_query_tokens) == 1
        return bool_mask_to_att_mask(mask)

    @staticmethod
    def generate_global_att_trainset_matrix(num_global_att_tokens, seq_len, num_query_tokens):
        """Return an all-zero (fully permissive) mask of shape
        ``(seq_len - num_query_tokens) x num_global_att_tokens``."""
        # NOTE: train_size is computed but not used below.
        train_size = seq_len + num_global_att_tokens - num_query_tokens
        trainset_size = seq_len - num_query_tokens
        mask = torch.zeros(trainset_size, num_global_att_tokens) == 0
        #mask[:,num_global_att_tokens:].zero_()
        #mask[:,num_global_att_tokens:] |= torch.eye(trainset_size) == 1
        return bool_mask_to_att_mask(mask)

    @staticmethod
    def generate_global_att_globaltokens_matrix(num_global_att_tokens, seq_len, num_query_tokens):
        """Return an all-zero (fully permissive) mask of shape
        ``num_global_att_tokens x (num_global_att_tokens + seq_len - num_query_tokens)``."""
        mask = torch.zeros(num_global_att_tokens, num_global_att_tokens+seq_len-num_query_tokens) == 0
        return bool_mask_to_att_mask(mask)

    def init_weights(self):
        """Apply ``init_method`` (if given) and zero the output projections of every encoder layer.

        For each layer, ``linear2`` and the attention ``out_proj`` weights and
        biases are set to zero.
        """
        # NOTE: initrange is only referenced by the commented-out code below.
        initrange = 1.
        # if isinstance(self.encoder,EmbeddingEncoder):
        #    self.encoder.weight.data.uniform_(-initrange, initrange)
        # self.decoder.bias.data.zero_()
        # self.decoder.weight.data.uniform_(-initrange, initrange)
        if self.init_method is not None:
            self.apply(self.init_method)
        for layer in self.transformer_encoder.layers:
            nn.init.zeros_(layer.linear2.weight)
            nn.init.zeros_(layer.linear2.bias)
            # self_attn may be a single module or a ModuleList of them.
            attns = layer.self_attn if isinstance(layer.self_attn, nn.ModuleList) else [layer.self_attn]
            for attn in attns:
                nn.init.zeros_(attn.out_proj.weight)
                nn.init.zeros_(attn.out_proj.bias)

    def forward(self, src, src_mask=None, single_eval_pos=None):
        """Run the model on ``src`` and return decoder outputs for the evaluation positions.

        Args:
            src: tuple ``(x, y)`` or ``(style, x, y)``.
            src_mask: optional attention mask; when ``None`` one is generated from
                ``single_eval_pos``. With global attention tokens it must be a tuple.
            single_eval_pos: index separating the positions that receive ``y`` from
                those that are predicted. NOTE(review): appears to be required whenever
                ``src_mask`` is ``None`` and ``full_attention`` is off, since it is used in
                arithmetic below - confirm against callers.

        Returns:
            Decoder output sliced from ``single_eval_pos`` (offset by the number of
            global tokens, if any) to the end of the sequence.
        """
        assert isinstance(src, tuple), 'inputs (src) have to be given as (x,y) or (style,x,y) tuple'

        if len(src) == 2:  # (x,y) and no style
            src = (None,) + src

        # A style input, when present, occupies one extra sequence position.
        style_src, style_src_size = (src[0], (0 if (src[0] is None) else 1))
        if src_mask is not None: assert self.global_att_embeddings is None or isinstance(src_mask, tuple)
        if src_mask is None:
            x_src = src[1]
            if self.global_att_embeddings is None:
                full_len = len(x_src) + style_src_size
                if self.full_attention:
                    src_mask = bool_mask_to_att_mask(torch.ones((full_len, full_len), dtype=torch.bool)).to(x_src.device)
                else:
                    src_mask = self.generate_D_q_matrix(len(x_src) + style_src_size, len(x_src) + style_src_size -single_eval_pos).to(x_src.device)
            else:
                # Three separate masks: global tokens, training positions, query positions.
                src_mask_args = (self.global_att_embeddings.num_embeddings,
                                 len(x_src) + style_src_size,
                                 len(x_src) + style_src_size - single_eval_pos)
                src_mask = (self.generate_global_att_globaltokens_matrix(*src_mask_args).to(x_src.device),
                            self.generate_global_att_trainset_matrix(*src_mask_args).to(x_src.device),
                            self.generate_global_att_query_matrix(*src_mask_args).to(x_src.device))

        style_src, x_src, y_src = src
        x_src = self.encoder(x_src)
        # Give y a trailing feature dimension when it has fewer dimensions than the encoded x.
        y_src = self.y_encoder(y_src.unsqueeze(-1) if len(y_src.shape) < len(x_src.shape) else y_src)
        style_src = self.style_encoder(style_src).unsqueeze(0) if self.style_encoder else torch.tensor([], device=x_src.device)
        global_src = torch.tensor([], device=x_src.device) if self.global_att_embeddings is None else \
            self.global_att_embeddings.weight.unsqueeze(1).repeat(1, x_src.shape[1], 1)
        # Training positions carry x and y embeddings summed together.
        train_x = x_src[:single_eval_pos] + y_src[:single_eval_pos]
        src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0)

        if self.input_ln is not None:
            src = self.input_ln(src)

        if self.pos_encoder is not None:
            src = self.pos_encoder(src)

        # If we have style input, drop its output
        output = self.transformer_encoder(src, src_mask)[style_src_size:]
        output = self.decoder(output)
        return output[single_eval_pos+(self.global_att_embeddings.num_embeddings if self.global_att_embeddings else 0):]

    @torch.no_grad()
    def init_from_small_model(self, small_model):
        """Copy the weights of ``small_model`` into the leading slices of this model's weights.

        NOTE(review): the assertion requires ``self.decoder`` to be an ``nn.Linear``,
        whereas the default decoder built in ``__init__`` is an ``nn.Sequential``.
        """
        assert isinstance(self.decoder, nn.Linear) and isinstance(self.encoder, (nn.Linear, nn.Sequential)) \
               and isinstance(self.y_encoder, (nn.Linear, nn.Sequential))

        def set_encoder_weights(my_encoder, small_model_encoder):
            # For a Sequential encoder, the last sub-module is taken as the linear layer.
            my_encoder_linear, small_encoder_linear = (my_encoder, small_model_encoder) \
                if isinstance(my_encoder, nn.Linear) else (my_encoder[-1], small_model_encoder[-1])
            small_in_dim = small_encoder_linear.out_features
            # Zero everything, then fill the leading rows from the small model.
            my_encoder_linear.weight.zero_()
            my_encoder_linear.bias.zero_()
            my_encoder_linear.weight[:small_in_dim] = small_encoder_linear.weight
            my_encoder_linear.bias[:small_in_dim] = small_encoder_linear.bias

        set_encoder_weights(self.encoder, small_model.encoder)
        set_encoder_weights(self.y_encoder, small_model.y_encoder)

        small_in_dim = small_model.decoder.in_features

        self.decoder.weight[:, :small_in_dim] = small_model.decoder.weight
        # This assigns the small model's bias object itself rather than copying its values.
        self.decoder.bias = small_model.decoder.bias

        for my_layer, small_layer in zip(self.transformer_encoder.layers, small_model.transformer_encoder.layers):
            small_hid_dim = small_layer.linear1.out_features
            my_in_dim = my_layer.linear1.in_features

            # packed along q,k,v order in first dim
            my_in_proj_w = my_layer.self_attn.in_proj_weight
            small_in_proj_w = small_layer.self_attn.in_proj_weight

            my_in_proj_w.view(3, my_in_dim, my_in_dim)[:, :small_in_dim, :small_in_dim] = small_in_proj_w.view(3,
                                                                                                               small_in_dim,
                                                                                                               small_in_dim)
            my_layer.self_attn.in_proj_bias.view(3, my_in_dim)[:,
            :small_in_dim] = small_layer.self_attn.in_proj_bias.view(3, small_in_dim)

            my_layer.self_attn.out_proj.weight[:small_in_dim, :small_in_dim] = small_layer.self_attn.out_proj.weight
            my_layer.self_attn.out_proj.bias[:small_in_dim] = small_layer.self_attn.out_proj.bias

            my_layer.linear1.weight[:small_hid_dim, :small_in_dim] = small_layer.linear1.weight
            my_layer.linear1.bias[:small_hid_dim] = small_layer.linear1.bias

            my_layer.linear2.weight[:small_in_dim, :small_hid_dim] = small_layer.linear2.weight
            my_layer.linear2.bias[:small_in_dim] = small_layer.linear2.bias

            # Norm weights are rescaled by sqrt(small_in_dim / my_in_dim) when copied.
            my_layer.norm1.weight[:small_in_dim] = math.sqrt(small_in_dim / my_in_dim) * small_layer.norm1.weight
            my_layer.norm2.weight[:small_in_dim] = math.sqrt(small_in_dim / my_in_dim) * small_layer.norm2.weight

            my_layer.norm1.bias[:small_in_dim] = small_layer.norm1.bias
            my_layer.norm2.bias[:small_in_dim] = small_layer.norm2.bias
|
189 |
+
|
190 |
+
|
191 |
+
class TransformerEncoderDiffInit(Module):
    r"""Stack of N encoder layers, each produced by its own factory call.

    Args:
        encoder_layer_creator: zero-argument callable returning a fresh encoder layer (required).
        num_layers: how many layers to create and chain (required).
        norm: optional module applied to the output of the last layer.
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer_creator, num_layers, norm=None):
        super().__init__()
        built_layers = []
        for _ in range(num_layers):
            built_layers.append(encoder_layer_creator())
        self.layers = nn.ModuleList(built_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Feed ``src`` through every layer in order, then through ``norm`` if one was given.

        Args:
            src: input sequence for the first layer (required).
            mask: forwarded to each layer as ``src_mask`` (optional).
            src_key_padding_mask: forwarded unchanged to each layer (optional).
        """
        hidden = src
        for layer in self.layers:
            hidden = layer(hidden, src_mask=mask, src_key_padding_mask=src_key_padding_mask)

        if self.norm is None:
            return hidden
        return self.norm(hidden)
|
lcpfn/utils.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
import argparse
|
4 |
+
import random
|
5 |
+
import datetime
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
from torch.optim.lr_scheduler import LambdaLR
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
# copied from huggingface
|
13 |
+
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1):
    """Build a ``LambdaLR`` with linear warmup followed by cosine decay.

    The multiplier rises linearly from 0 to 1 over ``num_warmup_steps`` and then
    follows ``0.5 * (1 + cos(2 * pi * num_cycles * progress))``, clamped at 0,
    where ``progress`` is the fraction of post-warmup steps completed.
    """

    def _warmup_factor(step):
        return float(step) / float(max(1, num_warmup_steps))

    def _cosine_factor(step):
        remaining_span = float(max(1, num_training_steps - num_warmup_steps))
        progress = float(step - num_warmup_steps) / remaining_span
        cosine = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
        return max(0.0, 0.5 * (1.0 + cosine))

    def lr_lambda(current_step):
        in_warmup = current_step < num_warmup_steps
        return _warmup_factor(current_step) if in_warmup else _cosine_factor(current_step)

    return LambdaLR(optimizer, lr_lambda, last_epoch)
|
26 |
+
|
27 |
+
# copied from huggingface
|
28 |
+
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """
    Build a ``LambdaLR`` whose multiplier ramps up linearly from 0 to 1 during warmup and then
    falls linearly back to 0 at ``num_training_steps``.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            Optimizer whose learning rate is scheduled.
        num_warmup_steps (:obj:`int`):
            Length of the warmup phase in steps.
        num_training_steps (:obj:`int`):
            Total number of training steps.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            Index of the last epoch, used when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step: int):
        if not current_step < num_warmup_steps:
            steps_left = float(num_training_steps - current_step)
            decay_span = float(max(1, num_training_steps - num_warmup_steps))
            return max(0.0, steps_left / decay_span)
        return float(current_step) / float(max(1, num_warmup_steps))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
|
55 |
+
|
56 |
+
|
57 |
+
def get_openai_lr(transformer_model):
    """Return ``0.003239 - 0.0001395 * ln(P)`` where ``P`` is the model's total parameter count."""
    total_params = 0
    for param in transformer_model.parameters():
        total_params += param.numel()
    return 0.003239 - 0.0001395 * math.log(total_params)
|
60 |
+
|
61 |
+
|
62 |
+
def get_weighted_single_eval_pos_sampler(max_len):
    """
    Build a sampler for `single_eval_pos` that draws a position p from `range(max_len)` with
    weight `1 / (max_len - p)`, so later positions are drawn more often.
    At most `max_len` - 1 examples are shown to the Transformer.
    :return: Sampler that can be fed to `train()` as `single_eval_pos_gen`.
    """

    def draw_position():
        positions = range(max_len)
        weights = []
        for pos in positions:
            weights.append(1 / (max_len - pos))
        picked = random.choices(positions, weights)
        return picked[0]

    return draw_position
|
69 |
+
|
70 |
+
|
71 |
+
def get_uniform_single_eval_pos_sampler(max_len, min_len=0):
    """
    Build a sampler that draws an evaluation position uniformly from `range(min_len, max_len)`.
    :return: Sampler that can be fed to `train()` as `single_eval_pos_gen`.
    """

    def draw_position():
        candidates = range(min_len, max_len)
        picked = random.choices(candidates)
        return picked[0]

    return draw_position
|
77 |
+
|
78 |
+
|
79 |
+
class SeqBN(nn.Module):
    """Batch normalisation over the last dimension of an arbitrarily shaped tensor.

    All leading dimensions are flattened into one batch dimension, a
    ``BatchNorm1d(d_model)`` is applied, and the result is restored to the
    input's shape.
    """

    def __init__(self, d_model):
        """Create the wrapped ``BatchNorm1d`` for feature width ``d_model``."""
        super().__init__()
        self.bn = nn.BatchNorm1d(d_model)
        self.d_model = d_model

    def forward(self, x):
        """Normalise ``x`` (last dimension must equal ``d_model``) and return a tensor of the same shape."""
        assert self.d_model == x.shape[-1]
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. transposed tensors.
        flat_x = x.reshape(-1, self.d_model)
        flat_x = self.bn(flat_x)
        return flat_x.reshape(x.shape)
|
90 |
+
|
91 |
+
|
92 |
+
def set_locals_in_self(locals):
    """
    Copy every entry of a `locals()` mapping (except `self`) onto the object stored under `self`.
    Intended to be called as `set_locals_in_self(locals())`, typically as the first line of `__init__`.
    :param locals: `locals()`
    """
    target = locals['self']
    for name in locals:
        if name == 'self':
            continue
        setattr(target, name, locals[name])
|
101 |
+
|
102 |
+
|
103 |
+
default_device = 'cuda:0' if torch.cuda.is_available() else 'cpu:0'
|
104 |
+
|
105 |
+
|
106 |
+
# Copied from StackOverflow, but we do an eval on the values additionally
|
107 |
+
class StoreDictKeyPair(argparse.Action):
    """argparse action that turns ``KEY=VALUE`` tokens into a dict stored on the namespace.

    Each value is passed through ``eval``; when that fails because the text is
    not a known name or not a valid expression, the raw string is kept instead.
    """

    def __init__(self, option_strings, dest, nargs=None, **kwargs):
        self._nargs = nargs
        super(StoreDictKeyPair, self).__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        """Parse ``values`` into a dict, store it under ``self.dest`` and print it.

        Raises:
            ValueError: if a token contains no ``=`` at all.
        """
        my_dict = {}
        for kv in values:
            # Split on the first '=' only, so the value itself may contain '='.
            k, v = kv.split("=", 1)
            try:
                # NOTE(security): eval executes arbitrary expressions taken from the command line.
                # Only use this action with trusted input.
                my_dict[k] = eval(v)
            except (NameError, SyntaxError):
                # Not an evaluable expression (e.g. a bare word or a path): keep the raw text.
                my_dict[k] = v
        setattr(namespace, self.dest, my_dict)
        print("dict values: {}".format(my_dict))
|
122 |
+
|
123 |
+
def get_nan_value(v, set_value_to_nan=0.0):
    """Return ``v`` with probability ``set_value_to_nan``, otherwise a random pick from [-999, 0, 1, 999]."""
    keep_given_value = random.random() < set_value_to_nan
    if not keep_given_value:
        return random.choice([-999, 0, 1, 999])
    return v
|
128 |
+
|
129 |
+
def to_ranking(data):
    """Replace each entry by the number of entries along dim 0 that are less than or equal to it.

    Uses a full pairwise comparison, so the intermediate tensor has one extra
    leading dimension of the same size as dim 0.
    """
    not_greater = data.unsqueeze(-3) <= data
    return not_greater.sum(0)
|
133 |
+
# TODO: Is there a better way to do this?
|
134 |
+
# 1. Comparing to unique elements: When all values are different we still get quadratic blowup
|
135 |
+
# 2. Argsort(Argsort()) returns ranking, but with duplicate values there is an ordering which is problematic
|
136 |
+
# 3. Argsort(Argsort(Unique))->Scatter seems a bit complicated, doesn't have quadratic blowup, but how fast?
|
137 |
+
def to_ranking_low_mem(data):
    """Column-by-column variant of the pairwise ranking.

    For each index of the last dimension, every entry is replaced by the number
    of entries along dim 0 that are less than or equal to it. The result has the
    dtype and shape of ``data``. Indexing assumes a 3-dimensional input.
    """
    ranks = torch.zeros_like(data)
    n_cols = data.shape[-1]
    for col in range(n_cols):
        column = data[:, :, col]
        not_greater = column >= column.unsqueeze(-2)
        ranks[:, :, col] = not_greater.sum(0)
    return ranks
|
144 |
+
|
145 |
+
def nan_handling_missing_for_unknown_reason_value(set_value_to_nan=0.0):
    """Delegate to ``get_nan_value`` with NaN as the candidate value."""
    marker = float('nan')
    return get_nan_value(marker, set_value_to_nan)
|
147 |
+
|
148 |
+
def nan_handling_missing_for_no_reason_value(set_value_to_nan=0.0):
    """Delegate to ``get_nan_value`` with negative infinity as the candidate value."""
    marker = float('-inf')
    return get_nan_value(marker, set_value_to_nan)
|
150 |
+
|
151 |
+
def nan_handling_missing_for_a_reason_value(set_value_to_nan=0.0):
    """Delegate to ``get_nan_value`` with positive infinity as the candidate value."""
    marker = float('inf')
    return get_nan_value(marker, set_value_to_nan)
|
153 |
+
|
154 |
+
def torch_nanmean(x, axis=0):
    """Mean of ``x`` along ``axis`` with NaN entries left out of both the sum and the count."""
    nan_mask = torch.isnan(x)
    zeros = torch.full_like(x, 0)
    valid_count = torch.where(nan_mask, zeros, torch.full_like(x, 1)).sum(axis=axis)
    valid_sum = torch.where(nan_mask, zeros, x).sum(axis=axis)
    return valid_sum / valid_count
|
158 |
+
|
159 |
+
def torch_nanstd(x, axis=0):
    """Standard deviation of ``x`` along ``axis`` ignoring NaN entries.

    The squared deviations of the non-NaN entries are divided by
    ``(number of non-NaN entries - 1)``.
    """
    nan_mask = torch.isnan(x)
    zeros = torch.full_like(x, 0)
    valid_count = torch.where(nan_mask, zeros, torch.full_like(x, 1)).sum(axis=axis)
    valid_sum = torch.where(nan_mask, zeros, x).sum(axis=axis)
    mean = valid_sum / valid_count
    # Broadcasting the mean back along `axis` gives the same deviations as tiling it.
    deviations = mean.unsqueeze(axis) - x
    sum_of_squares = torch.nansum(torch.square(deviations), axis=axis)
    return torch.sqrt(sum_of_squares / (valid_count - 1))
|
165 |
+
|
166 |
+
def normalize_data(data, normalize_positions=-1):
    """Standardise ``data`` along dim 0 using NaN-aware statistics, then clip to [-100, 100].

    When ``normalize_positions`` is positive, mean and standard deviation are
    taken from the first ``normalize_positions`` rows only; otherwise from all
    rows. A small constant is added to the standard deviation before dividing.
    """
    reference = data[:normalize_positions] if normalize_positions > 0 else data
    mean = torch_nanmean(reference, axis=0)
    std = torch_nanstd(reference, axis=0) + .000001

    standardised = (data - mean) / std
    return torch.clip(standardised, min=-100, max=100)
|
177 |
+
|
178 |
+
def remove_outliers(X, n_sigma=4):
    """Softly pull values outside an ``n_sigma`` band back towards the band.

    Bounds are computed twice along dim 0: first on the raw data, then again
    after the entries outside the first band have been set to NaN. Values
    beyond the second band are replaced by the bound plus or minus
    ``log(1 + |value|)``.
    """
    # Expects T, B, H
    assert len(X.shape) == 3, "X must be T,B,H"

    def _band(values):
        # Lower and upper bound at n_sigma NaN-aware standard deviations around the mean.
        centre, spread = torch_nanmean(values, axis=0), torch_nanstd(values, axis=0)
        half_width = spread * n_sigma
        return centre - half_width, centre + half_width

    lower, upper = _band(X)

    # Recompute the band without the entries that fall outside the first one.
    data_clean = X.clone()
    outside = torch.logical_or(X > upper, X < lower)
    data_clean[outside] = np.nan
    lower, upper = _band(data_clean)

    X = torch.maximum(-torch.log(1+torch.abs(X)) + lower, X)
    X = torch.minimum(torch.log(1+torch.abs(X)) + upper, X)
    return X
|
198 |
+
|
199 |
+
def bool_mask_to_att_mask(mask):
    """Convert a boolean mask into an additive float mask: 0.0 where true, -inf where false."""
    as_float = mask.float()
    with_blocked = as_float.masked_fill(mask == 0, float('-inf'))
    return with_blocked.masked_fill(mask == 1, float(0.0))
|
201 |
+
|
202 |
+
def print_on_master_only(is_master):
    """Replace the built-in ``print`` with a version that is silent unless allowed.

    The replacement forwards to the previous ``print`` only when ``is_master``
    is true or the call passes ``force=True``; the ``force`` keyword itself is
    never forwarded.
    """
    import builtins

    previous_print = builtins.print

    def print(*args, force=False, **kwargs):
        if not (is_master or force):
            return
        previous_print(*args, **kwargs)

    builtins.print = print
|
213 |
+
|
214 |
+
|
215 |
+
def init_dist(device):
    """Initialise ``torch.distributed`` when a distributed launch is detected.

    Two launch modes are recognised from the environment: ``LOCAL_RANK`` set,
    or ``SLURM_PROCID`` set together with more than one visible CUDA device.
    In both cases the NCCL process group is initialised and printing is
    restricted to rank 0. Otherwise nothing distributed is set up.

    Returns:
        A tuple ``(using_dist, rank, device)``: ``(True, rank, 'cuda:<rank>')`` for the
        distributed modes, ``(False, 0, device)`` otherwise.
    """
    print('init dist')

    def _join_process_group(rank):
        # Shared tail of both distributed launch modes.
        torch.distributed.init_process_group(backend="nccl", init_method="env://",
                                             timeout=datetime.timedelta(seconds=20),
                                             world_size=torch.cuda.device_count(), rank=rank)
        torch.distributed.barrier()
        print_on_master_only(rank == 0)
        print(f"Distributed training on {torch.cuda.device_count()} GPUs, this is rank {rank}, "
              "only I can print, but when using print(..., force=True) it will print on all ranks.")
        return True, rank, f'cuda:{rank}'

    if 'LOCAL_RANK' in os.environ:
        # launched with torch.distributed.launch
        rank = int(os.environ["LOCAL_RANK"])
        print('torch.distributed.launch and my rank is', rank)
        torch.cuda.set_device(rank)
        os.environ['CUDA_VISIBLE_DEVICES'] = str(rank)
        return _join_process_group(rank)

    if 'SLURM_PROCID' in os.environ and torch.cuda.device_count() > 1:
        # this is for multi gpu when starting with submitit
        assert device != 'cpu:0'
        rank = int(os.environ['SLURM_PROCID'])
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        torch.cuda.set_device(rank)
        os.environ['CUDA_VISIBLE_DEVICES'] = str(rank)
        print('distributed submitit launch and my rank is', rank)
        return _join_process_group(rank)

    print('Not using distributed')
    # will not change any of the behavior of print, but allows putting the force=True in the print calls
    print_on_master_only(True)
    return False, 0, device
|
252 |
+
|
253 |
+
|
254 |
+
def check_compatibility(dl):
    """Warn about, and validate, the deprecated ``num_outputs`` attribute of a data loader.

    Nothing happens when ``dl`` has no ``num_outputs`` attribute. When it has
    one, a deprecation notice is printed and the value must be 1.

    Raises:
        AssertionError: if ``dl.num_outputs`` is present and not equal to 1.
    """
    if hasattr(dl, 'num_outputs'):
        print('`num_outputs` for the DataLoader is deprecated. It is assumed to be 1 from now on.')
        # The message states that 1 is the assumed value, so the check must pass for 1
        # and fail for anything else (the original condition was inverted).
        assert dl.num_outputs == 1, "We assume num_outputs to be 1. Instead of the num_outputs change your loss. " \
                                    "We specify the number of classes in the CE loss."
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
torch==1.11.0
|
2 |
+
numpy>=1.21.2
|
3 |
+
# lcpfn @ git+https://github.com/automl/lcpfn.git
|