Spaces:
Runtime error
Runtime error
Anonymous
committed on
Commit
•
74a9978
1
Parent(s):
47cd265
Update To Sklearn interface
Browse files
- TabPFN/PrepareDatasets.ipynb +37 -109
- TabPFN/PriorFittingCustomPrior.ipynb +353 -0
- TabPFN/{TabPFNPredictionOnly.ipynb → QuickPredictionDemo.ipynb} +166 -37
- TabPFN/README.md +1 -4
- TabPFN/RunFullDatasetAnalyses.ipynb +833 -0
- TabPFN/SyntheticGPAblation.ipynb +0 -392
- TabPFN/TabularEvaluationVisualization.ipynb +0 -0
- TabPFN/TrainingTuningAndPrediction.ipynb +0 -0
- TabPFN/differentiable_pfn_evaluation.py +0 -345
- TabPFN/layer.py +6 -0
- TabPFN/model_builder.py +0 -273
- TabPFN/models_diff/gp_ablation_model.cpkt +0 -3
- TabPFN/models_diff/prior_diff_real_checkpoint_n_8x_lr0.0003_epoch_49.cpkt +0 -3
- TabPFN/prior_tuning_result.pkl +0 -3
- TabPFN/scripts/differentiable_pfn_evaluation.py +25 -143
- TabPFN/scripts/model_configs.py +92 -17
- TabPFN/scripts/tabular_baselines.py +1140 -39
- TabPFN/scripts/tabular_baselines_deep.py +74 -0
- TabPFN/scripts/tabular_evaluation.py +51 -23
- TabPFN/scripts/tabular_metrics.py +38 -7
- TabPFN/scripts/transformer_prediction_interface.py +1 -1
- TabPFN/tabular_evaluation.py +0 -283
- encoders.py +0 -243
TabPFN/PrepareDatasets.ipynb
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
-
"execution_count":
|
6 |
"metadata": {},
|
7 |
"outputs": [],
|
8 |
"source": [
|
@@ -14,7 +14,7 @@
|
|
14 |
},
|
15 |
{
|
16 |
"cell_type": "code",
|
17 |
-
"execution_count":
|
18 |
"metadata": {},
|
19 |
"outputs": [],
|
20 |
"source": [
|
@@ -25,18 +25,9 @@
|
|
25 |
},
|
26 |
{
|
27 |
"cell_type": "code",
|
28 |
-
"execution_count":
|
29 |
"metadata": {},
|
30 |
-
"outputs": [
|
31 |
-
{
|
32 |
-
"name": "stdout",
|
33 |
-
"output_type": "stream",
|
34 |
-
"text": [
|
35 |
-
"The autoreload extension is already loaded. To reload it, use:\n",
|
36 |
-
" %reload_ext autoreload\n"
|
37 |
-
]
|
38 |
-
}
|
39 |
-
],
|
40 |
"source": [
|
41 |
"%load_ext autoreload\n",
|
42 |
"\n",
|
@@ -54,7 +45,7 @@
|
|
54 |
},
|
55 |
{
|
56 |
"cell_type": "code",
|
57 |
-
"execution_count":
|
58 |
"metadata": {},
|
59 |
"outputs": [],
|
60 |
"source": [
|
@@ -63,42 +54,16 @@
|
|
63 |
},
|
64 |
{
|
65 |
"cell_type": "code",
|
66 |
-
"execution_count":
|
67 |
"metadata": {},
|
68 |
-
"outputs": [
|
69 |
-
{
|
70 |
-
"data": {
|
71 |
-
"text/plain": [
|
72 |
-
"OrderedDict([(99,\n",
|
73 |
-
" {'id': 99,\n",
|
74 |
-
" 'alias': 'OpenML-CC18',\n",
|
75 |
-
" 'main_entity_type': 'task',\n",
|
76 |
-
" 'name': 'OpenML-CC18 Curated Classification benchmark',\n",
|
77 |
-
" 'status': 'active',\n",
|
78 |
-
" 'creation_date': '2019-02-21 18:47:13',\n",
|
79 |
-
" 'creator': 1}),\n",
|
80 |
-
" (225,\n",
|
81 |
-
" {'id': 225,\n",
|
82 |
-
" 'alias': 'OpenML-friendly',\n",
|
83 |
-
" 'main_entity_type': 'task',\n",
|
84 |
-
" 'name': 'OpenML100-friendly',\n",
|
85 |
-
" 'status': 'active',\n",
|
86 |
-
" 'creation_date': '2019-09-16 19:41:46',\n",
|
87 |
-
" 'creator': 1})])"
|
88 |
-
]
|
89 |
-
},
|
90 |
-
"execution_count": 8,
|
91 |
-
"metadata": {},
|
92 |
-
"output_type": "execute_result"
|
93 |
-
}
|
94 |
-
],
|
95 |
"source": [
|
96 |
"openml.study.list_suites()"
|
97 |
]
|
98 |
},
|
99 |
{
|
100 |
"cell_type": "code",
|
101 |
-
"execution_count":
|
102 |
"metadata": {},
|
103 |
"outputs": [],
|
104 |
"source": [
|
@@ -108,7 +73,7 @@
|
|
108 |
},
|
109 |
{
|
110 |
"cell_type": "code",
|
111 |
-
"execution_count":
|
112 |
"metadata": {},
|
113 |
"outputs": [],
|
114 |
"source": [
|
@@ -120,7 +85,7 @@
|
|
120 |
},
|
121 |
{
|
122 |
"cell_type": "code",
|
123 |
-
"execution_count":
|
124 |
"metadata": {},
|
125 |
"outputs": [],
|
126 |
"source": [
|
@@ -130,27 +95,16 @@
|
|
130 |
},
|
131 |
{
|
132 |
"cell_type": "code",
|
133 |
-
"execution_count":
|
134 |
"metadata": {},
|
135 |
-
"outputs": [
|
136 |
-
{
|
137 |
-
"data": {
|
138 |
-
"text/plain": [
|
139 |
-
"30"
|
140 |
-
]
|
141 |
-
},
|
142 |
-
"execution_count": 12,
|
143 |
-
"metadata": {},
|
144 |
-
"output_type": "execute_result"
|
145 |
-
}
|
146 |
-
],
|
147 |
"source": [
|
148 |
"len(tids)"
|
149 |
]
|
150 |
},
|
151 |
{
|
152 |
"cell_type": "code",
|
153 |
-
"execution_count":
|
154 |
"metadata": {},
|
155 |
"outputs": [],
|
156 |
"source": [
|
@@ -159,7 +113,7 @@
|
|
159 |
},
|
160 |
{
|
161 |
"cell_type": "code",
|
162 |
-
"execution_count":
|
163 |
"metadata": {},
|
164 |
"outputs": [],
|
165 |
"source": [
|
@@ -169,20 +123,23 @@
|
|
169 |
{
|
170 |
"cell_type": "code",
|
171 |
"execution_count": null,
|
172 |
-
"outputs": [],
|
173 |
-
"source": [
|
174 |
-
"open_ml_datasets, open_ml_datasets_df = load_openml_list(test_dids_classification, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = 100000, num_feats=100, return_capped=True)\n"
|
175 |
-
],
|
176 |
"metadata": {
|
177 |
"collapsed": false,
|
|
|
|
|
|
|
178 |
"pycharm": {
|
179 |
"name": "#%%\n"
|
180 |
}
|
181 |
-
}
|
|
|
|
|
|
|
|
|
182 |
},
|
183 |
{
|
184 |
"cell_type": "code",
|
185 |
-
"execution_count":
|
186 |
"metadata": {},
|
187 |
"outputs": [],
|
188 |
"source": [
|
@@ -191,41 +148,9 @@
|
|
191 |
},
|
192 |
{
|
193 |
"cell_type": "code",
|
194 |
-
"execution_count":
|
195 |
"metadata": {},
|
196 |
-
"outputs": [
|
197 |
-
{
|
198 |
-
"name": "stdout",
|
199 |
-
"output_type": "stream",
|
200 |
-
"text": [
|
201 |
-
"\\begin{tabular}{lrrrrrrr}\n",
|
202 |
-
"\\toprule\n",
|
203 |
-
" Name & \\# Features & \\# Categorical Features & \\# Instances & \\# Classes & \\# NaNs & Minority Class Size & id \\\\\n",
|
204 |
-
"\\midrule\n",
|
205 |
-
" KDDCup09\\_appetency & 231 & 39 & 50000 & 2 & 8024152 & 890 & 1111 \\\\\n",
|
206 |
-
" airlines & 8 & 5 & 539383 & 2 & 0 & 240264 & 1169 \\\\\n",
|
207 |
-
" bank-marketing & 17 & 10 & 45211 & 2 & 0 & 5289 & 1461 \\\\\n",
|
208 |
-
" nomao & 119 & 30 & 34465 & 2 & 0 & 9844 & 1486 \\\\\n",
|
209 |
-
" adult & 15 & 9 & 48842 & 2 & 6465 & 11687 & 1590 \\\\\n",
|
210 |
-
" covertype & 55 & 45 & 581012 & 7 & 0 & 2747 & 1596 \\\\\n",
|
211 |
-
" numerai28.6 & 22 & 1 & 96320 & 2 & 0 & 47662 & 23517 \\\\\n",
|
212 |
-
" connect-4 & 43 & 43 & 67557 & 3 & 0 & 6449 & 40668 \\\\\n",
|
213 |
-
"jungle\\_chess\\_2pcs\\_raw\\_endgame\\_complete & 7 & 1 & 44819 & 3 & 0 & 4335 & 41027 \\\\\n",
|
214 |
-
" APSFailure & 171 & 1 & 76000 & 2 & 1078695 & 1375 & 41138 \\\\\n",
|
215 |
-
" albert & 79 & 53 & 425240 & 2 & 2734000 & 212620 & 41147 \\\\\n",
|
216 |
-
" MiniBooNE & 51 & 1 & 130064 & 2 & 0 & 36499 & 41150 \\\\\n",
|
217 |
-
" guillermo & 4297 & 1 & 20000 & 2 & 0 & 8003 & 41159 \\\\\n",
|
218 |
-
" riccardo & 4297 & 1 & 20000 & 2 & 0 & 5000 & 41161 \\\\\n",
|
219 |
-
" volkert & 181 & 1 & 58310 & 10 & 0 & 1361 & 41166 \\\\\n",
|
220 |
-
" dionis & 61 & 1 & 416188 & 355 & 0 & 878 & 41167 \\\\\n",
|
221 |
-
" jannis & 55 & 1 & 83733 & 4 & 0 & 1687 & 41168 \\\\\n",
|
222 |
-
" helena & 28 & 1 & 65196 & 100 & 0 & 111 & 41169 \\\\\n",
|
223 |
-
"\\bottomrule\n",
|
224 |
-
"\\end{tabular}\n",
|
225 |
-
"\n"
|
226 |
-
]
|
227 |
-
}
|
228 |
-
],
|
229 |
"source": [
|
230 |
"print_table = open_ml_datasets_df\n",
|
231 |
"print_table = print_table[['name', 'NumberOfFeatures', 'NumberOfSymbolicFeatures', 'NumberOfInstances', 'NumberOfClasses', 'NumberOfMissingValues', 'MinorityClassSize']].copy()\n",
|
@@ -247,6 +172,15 @@
|
|
247 |
{
|
248 |
"cell_type": "code",
|
249 |
"execution_count": null,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
"outputs": [],
|
251 |
"source": [
|
252 |
"open_cc_datasets, open_cc_datasets_df = load_openml_list(open_cc_dids, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = 2000, num_feats=100, return_capped=True)\n",
|
@@ -319,13 +253,7 @@
|
|
319 |
"\n",
|
320 |
"# Remove too easy\n",
|
321 |
"openml_list = openml_list[openml_list.CfsSubsetEval_DecisionStumpAUC != 1]"
|
322 |
-
]
|
323 |
-
"metadata": {
|
324 |
-
"collapsed": false,
|
325 |
-
"pycharm": {
|
326 |
-
"name": "#%%\n"
|
327 |
-
}
|
328 |
-
}
|
329 |
},
|
330 |
{
|
331 |
"cell_type": "code",
|
@@ -365,9 +293,9 @@
|
|
365 |
"name": "python",
|
366 |
"nbconvert_exporter": "python",
|
367 |
"pygments_lexer": "ipython3",
|
368 |
-
"version": "3.
|
369 |
}
|
370 |
},
|
371 |
"nbformat": 4,
|
372 |
"nbformat_minor": 4
|
373 |
-
}
|
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
"metadata": {},
|
7 |
"outputs": [],
|
8 |
"source": [
|
|
|
14 |
},
|
15 |
{
|
16 |
"cell_type": "code",
|
17 |
+
"execution_count": null,
|
18 |
"metadata": {},
|
19 |
"outputs": [],
|
20 |
"source": [
|
|
|
25 |
},
|
26 |
{
|
27 |
"cell_type": "code",
|
28 |
+
"execution_count": null,
|
29 |
"metadata": {},
|
30 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
"source": [
|
32 |
"%load_ext autoreload\n",
|
33 |
"\n",
|
|
|
45 |
},
|
46 |
{
|
47 |
"cell_type": "code",
|
48 |
+
"execution_count": null,
|
49 |
"metadata": {},
|
50 |
"outputs": [],
|
51 |
"source": [
|
|
|
54 |
},
|
55 |
{
|
56 |
"cell_type": "code",
|
57 |
+
"execution_count": null,
|
58 |
"metadata": {},
|
59 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
"source": [
|
61 |
"openml.study.list_suites()"
|
62 |
]
|
63 |
},
|
64 |
{
|
65 |
"cell_type": "code",
|
66 |
+
"execution_count": null,
|
67 |
"metadata": {},
|
68 |
"outputs": [],
|
69 |
"source": [
|
|
|
73 |
},
|
74 |
{
|
75 |
"cell_type": "code",
|
76 |
+
"execution_count": null,
|
77 |
"metadata": {},
|
78 |
"outputs": [],
|
79 |
"source": [
|
|
|
85 |
},
|
86 |
{
|
87 |
"cell_type": "code",
|
88 |
+
"execution_count": null,
|
89 |
"metadata": {},
|
90 |
"outputs": [],
|
91 |
"source": [
|
|
|
95 |
},
|
96 |
{
|
97 |
"cell_type": "code",
|
98 |
+
"execution_count": null,
|
99 |
"metadata": {},
|
100 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
"source": [
|
102 |
"len(tids)"
|
103 |
]
|
104 |
},
|
105 |
{
|
106 |
"cell_type": "code",
|
107 |
+
"execution_count": null,
|
108 |
"metadata": {},
|
109 |
"outputs": [],
|
110 |
"source": [
|
|
|
113 |
},
|
114 |
{
|
115 |
"cell_type": "code",
|
116 |
+
"execution_count": null,
|
117 |
"metadata": {},
|
118 |
"outputs": [],
|
119 |
"source": [
|
|
|
123 |
{
|
124 |
"cell_type": "code",
|
125 |
"execution_count": null,
|
|
|
|
|
|
|
|
|
126 |
"metadata": {
|
127 |
"collapsed": false,
|
128 |
+
"jupyter": {
|
129 |
+
"outputs_hidden": false
|
130 |
+
},
|
131 |
"pycharm": {
|
132 |
"name": "#%%\n"
|
133 |
}
|
134 |
+
},
|
135 |
+
"outputs": [],
|
136 |
+
"source": [
|
137 |
+
"open_ml_datasets, open_ml_datasets_df = load_openml_list(test_dids_classification, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = 100000, num_feats=100, return_capped=True)\n"
|
138 |
+
]
|
139 |
},
|
140 |
{
|
141 |
"cell_type": "code",
|
142 |
+
"execution_count": null,
|
143 |
"metadata": {},
|
144 |
"outputs": [],
|
145 |
"source": [
|
|
|
148 |
},
|
149 |
{
|
150 |
"cell_type": "code",
|
151 |
+
"execution_count": null,
|
152 |
"metadata": {},
|
153 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
"source": [
|
155 |
"print_table = open_ml_datasets_df\n",
|
156 |
"print_table = print_table[['name', 'NumberOfFeatures', 'NumberOfSymbolicFeatures', 'NumberOfInstances', 'NumberOfClasses', 'NumberOfMissingValues', 'MinorityClassSize']].copy()\n",
|
|
|
172 |
{
|
173 |
"cell_type": "code",
|
174 |
"execution_count": null,
|
175 |
+
"metadata": {
|
176 |
+
"collapsed": false,
|
177 |
+
"jupyter": {
|
178 |
+
"outputs_hidden": false
|
179 |
+
},
|
180 |
+
"pycharm": {
|
181 |
+
"name": "#%%\n"
|
182 |
+
}
|
183 |
+
},
|
184 |
"outputs": [],
|
185 |
"source": [
|
186 |
"open_cc_datasets, open_cc_datasets_df = load_openml_list(open_cc_dids, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = 2000, num_feats=100, return_capped=True)\n",
|
|
|
253 |
"\n",
|
254 |
"# Remove too easy\n",
|
255 |
"openml_list = openml_list[openml_list.CfsSubsetEval_DecisionStumpAUC != 1]"
|
256 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
},
|
258 |
{
|
259 |
"cell_type": "code",
|
|
|
293 |
"name": "python",
|
294 |
"nbconvert_exporter": "python",
|
295 |
"pygments_lexer": "ipython3",
|
296 |
+
"version": "3.9.6"
|
297 |
}
|
298 |
},
|
299 |
"nbformat": 4,
|
300 |
"nbformat_minor": 4
|
301 |
+
}
|
TabPFN/PriorFittingCustomPrior.ipynb
ADDED
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"tags": []
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"## Setup"
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": null,
|
15 |
+
"metadata": {},
|
16 |
+
"outputs": [],
|
17 |
+
"source": [
|
18 |
+
"%load_ext autoreload\n",
|
19 |
+
"\n",
|
20 |
+
"%autoreload 2"
|
21 |
+
]
|
22 |
+
},
|
23 |
+
{
|
24 |
+
"cell_type": "code",
|
25 |
+
"execution_count": null,
|
26 |
+
"metadata": {},
|
27 |
+
"outputs": [],
|
28 |
+
"source": [
|
29 |
+
"import random\n",
|
30 |
+
"import time\n",
|
31 |
+
"import warnings\n",
|
32 |
+
"from datetime import datetime\n",
|
33 |
+
"\n",
|
34 |
+
"import torch\n",
|
35 |
+
"\n",
|
36 |
+
"import numpy as np\n",
|
37 |
+
"\n",
|
38 |
+
"import matplotlib.pyplot as plt\n",
|
39 |
+
"from scripts.differentiable_pfn_evaluation import eval_model_range\n",
|
40 |
+
"from scripts.model_builder import get_model, get_default_spec, save_model, load_model\n",
|
41 |
+
"from scripts.transformer_prediction_interface import transformer_predict, get_params_from_config, load_model_workflow\n",
|
42 |
+
"\n",
|
43 |
+
"from scripts.model_configs import *\n",
|
44 |
+
"\n",
|
45 |
+
"from datasets import load_openml_list, open_cc_dids, open_cc_valid_dids\n",
|
46 |
+
"from priors.utils import plot_prior, plot_features\n",
|
47 |
+
"from priors.utils import uniform_int_sampler_f\n",
|
48 |
+
"\n",
|
49 |
+
"from scripts.tabular_metrics import calculate_score_per_method, calculate_score\n",
|
50 |
+
"from scripts.tabular_evaluation import evaluate\n",
|
51 |
+
"\n",
|
52 |
+
"from priors.differentiable_prior import DifferentiableHyperparameterList, draw_random_style, merge_style_with_info\n",
|
53 |
+
"from scripts import tabular_metrics\n",
|
54 |
+
"from notebook_utils import *"
|
55 |
+
]
|
56 |
+
},
|
57 |
+
{
|
58 |
+
"cell_type": "code",
|
59 |
+
"execution_count": null,
|
60 |
+
"metadata": {},
|
61 |
+
"outputs": [],
|
62 |
+
"source": [
|
63 |
+
"large_datasets = True\n",
|
64 |
+
"max_samples = 10000 if large_datasets else 5000\n",
|
65 |
+
"bptt = 10000 if large_datasets else 3000\n",
|
66 |
+
"suite='cc'"
|
67 |
+
]
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"cell_type": "code",
|
71 |
+
"execution_count": null,
|
72 |
+
"metadata": {},
|
73 |
+
"outputs": [],
|
74 |
+
"source": [
|
75 |
+
"device = 'cpu'\n",
|
76 |
+
"base_path = '.'\n",
|
77 |
+
"max_features = 100"
|
78 |
+
]
|
79 |
+
},
|
80 |
+
{
|
81 |
+
"cell_type": "code",
|
82 |
+
"execution_count": null,
|
83 |
+
"metadata": {},
|
84 |
+
"outputs": [],
|
85 |
+
"source": [
|
86 |
+
"def print_models(model_string):\n",
|
87 |
+
" print(model_string)\n",
|
88 |
+
"\n",
|
89 |
+
" for i in range(80):\n",
|
90 |
+
" for e in range(50):\n",
|
91 |
+
" exists = Path(os.path.join(base_path, f'models_diff/prior_diff_real_checkpoint{model_string}_n_{i}_epoch_{e}.cpkt')).is_file()\n",
|
92 |
+
" if exists:\n",
|
93 |
+
" print(os.path.join(base_path, f'models_diff/prior_diff_real_checkpoint{model_string}_n_{i}_epoch_{e}.cpkt'))\n",
|
94 |
+
" print()"
|
95 |
+
]
|
96 |
+
},
|
97 |
+
{
|
98 |
+
"cell_type": "code",
|
99 |
+
"execution_count": null,
|
100 |
+
"metadata": {},
|
101 |
+
"outputs": [],
|
102 |
+
"source": [
|
103 |
+
"def train_function(config_sample, i, add_name=''):\n",
|
104 |
+
" start_time = time.time()\n",
|
105 |
+
" N_epochs_to_save = 50\n",
|
106 |
+
" \n",
|
107 |
+
" def save_callback(model, epoch):\n",
|
108 |
+
" if not hasattr(model, 'last_saved_epoch'):\n",
|
109 |
+
" model.last_saved_epoch = 0\n",
|
110 |
+
" if ((time.time() - start_time) / (maximum_runtime * 60 / N_epochs_to_save)) > model.last_saved_epoch:\n",
|
111 |
+
" print('Saving model..')\n",
|
112 |
+
" config_sample['epoch_in_training'] = epoch\n",
|
113 |
+
" save_model(model, base_path, f'models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{model.last_saved_epoch}.cpkt',\n",
|
114 |
+
" config_sample)\n",
|
115 |
+
" model.last_saved_epoch = model.last_saved_epoch + 1 # TODO: Rename to checkpoint\n",
|
116 |
+
" \n",
|
117 |
+
" model = get_model(config_sample\n",
|
118 |
+
" , device\n",
|
119 |
+
" , should_train=True\n",
|
120 |
+
" , verbose=1\n",
|
121 |
+
" , epoch_callback = save_callback)\n",
|
122 |
+
" \n",
|
123 |
+
" return"
|
124 |
+
]
|
125 |
+
},
|
126 |
+
{
|
127 |
+
"cell_type": "markdown",
|
128 |
+
"metadata": {
|
129 |
+
"tags": []
|
130 |
+
},
|
131 |
+
"source": [
|
132 |
+
"## Define prior settings"
|
133 |
+
]
|
134 |
+
},
|
135 |
+
{
|
136 |
+
"cell_type": "code",
|
137 |
+
"execution_count": 17,
|
138 |
+
"metadata": {
|
139 |
+
"scrolled": true
|
140 |
+
},
|
141 |
+
"outputs": [],
|
142 |
+
"source": [
|
143 |
+
"def reload_config(config_type='causal', task_type='multiclass', longer=0):\n",
|
144 |
+
" config = get_prior_config(config_type=config_type)\n",
|
145 |
+
" \n",
|
146 |
+
" config['prior_type'], config['differentiable'], config['flexible'] = 'prior_bag', True, True\n",
|
147 |
+
" \n",
|
148 |
+
" model_string = ''\n",
|
149 |
+
" \n",
|
150 |
+
" config['epochs'] = 12000\n",
|
151 |
+
" config['recompute_attn'] = True\n",
|
152 |
+
"\n",
|
153 |
+
" config['max_num_classes'] = 10\n",
|
154 |
+
" config['num_classes'] = uniform_int_sampler_f(2, config['max_num_classes'])\n",
|
155 |
+
" config['balanced'] = False\n",
|
156 |
+
" model_string = model_string + '_multiclass'\n",
|
157 |
+
" \n",
|
158 |
+
" model_string = model_string + '_'+datetime.now().strftime(\"%m_%d_%Y_%H_%M_%S\")\n",
|
159 |
+
" \n",
|
160 |
+
" return config, model_string"
|
161 |
+
]
|
162 |
+
},
|
163 |
+
{
|
164 |
+
"cell_type": "markdown",
|
165 |
+
"metadata": {
|
166 |
+
"tags": []
|
167 |
+
},
|
168 |
+
"source": [
|
169 |
+
"## Visualize Prior samples"
|
170 |
+
]
|
171 |
+
},
|
172 |
+
{
|
173 |
+
"cell_type": "code",
|
174 |
+
"execution_count": 19,
|
175 |
+
"metadata": {
|
176 |
+
"tags": []
|
177 |
+
},
|
178 |
+
"outputs": [],
|
179 |
+
"source": [
|
180 |
+
"config, model_string = reload_config(longer=1)\n",
|
181 |
+
"\n",
|
182 |
+
"config['bptt_extra_samples'] = None\n",
|
183 |
+
"\n",
|
184 |
+
"# diff\n",
|
185 |
+
"config['output_multiclass_ordered_p'] = 0.\n",
|
186 |
+
"del config['differentiable_hyperparameters']['output_multiclass_ordered_p']\n",
|
187 |
+
"\n",
|
188 |
+
"config['multiclass_type'] = 'rank'\n",
|
189 |
+
"del config['differentiable_hyperparameters']['multiclass_type']\n",
|
190 |
+
"\n",
|
191 |
+
"config['sampling'] = 'normal' # vielleicht schlecht?\n",
|
192 |
+
"del config['differentiable_hyperparameters']['sampling']\n",
|
193 |
+
"\n",
|
194 |
+
"config['pre_sample_causes'] = True\n",
|
195 |
+
"# end diff\n",
|
196 |
+
"\n",
|
197 |
+
"config['multiclass_loss_type'] = 'nono' # 'compatible'\n",
|
198 |
+
"config['normalize_to_ranking'] = False # False\n",
|
199 |
+
"\n",
|
200 |
+
"config['categorical_feature_p'] = .2 # diff: .0\n",
|
201 |
+
"\n",
|
202 |
+
"# turn this back on in a random search!?\n",
|
203 |
+
"config['nan_prob_no_reason'] = .0\n",
|
204 |
+
"config['nan_prob_unknown_reason'] = .0 # diff: .0\n",
|
205 |
+
"config['set_value_to_nan'] = .1 # diff: 1.\n",
|
206 |
+
"\n",
|
207 |
+
"config['normalize_with_sqrt'] = False\n",
|
208 |
+
"\n",
|
209 |
+
"config['new_mlp_per_example'] = True\n",
|
210 |
+
"config['prior_mlp_scale_weights_sqrt'] = True\n",
|
211 |
+
"config['batch_size_per_gp_sample'] = None\n",
|
212 |
+
"\n",
|
213 |
+
"config['normalize_ignore_label_too'] = False\n",
|
214 |
+
"\n",
|
215 |
+
"config['differentiable_hps_as_style'] = False\n",
|
216 |
+
"config['max_eval_pos'] = 1000\n",
|
217 |
+
"\n",
|
218 |
+
"config['random_feature_rotation'] = True\n",
|
219 |
+
"config['rotate_normalized_labels'] = True\n",
|
220 |
+
"\n",
|
221 |
+
"config[\"mix_activations\"] = False # False heisst eig True\n",
|
222 |
+
"\n",
|
223 |
+
"config['emsize'] = 512\n",
|
224 |
+
"config['nhead'] = config['emsize'] // 128\n",
|
225 |
+
"config['bptt'] = 1024+128\n",
|
226 |
+
"config['canonical_y_encoder'] = False\n",
|
227 |
+
"\n",
|
228 |
+
" \n",
|
229 |
+
"config['aggregate_k_gradients'] = 8\n",
|
230 |
+
"config['batch_size'] = 8*config['aggregate_k_gradients']\n",
|
231 |
+
"config['num_steps'] = 1024//config['aggregate_k_gradients']\n",
|
232 |
+
"config['epochs'] = 400\n",
|
233 |
+
"config['total_available_time_in_s'] = None #60*60*22 # 22 hours for some safety...\n",
|
234 |
+
"\n",
|
235 |
+
"config['train_mixed_precision'] = True\n",
|
236 |
+
"config['efficient_eval_masking'] = True\n",
|
237 |
+
"\n",
|
238 |
+
"config_sample = evaluate_hypers(config)"
|
239 |
+
]
|
240 |
+
},
|
241 |
+
{
|
242 |
+
"cell_type": "code",
|
243 |
+
"execution_count": 25,
|
244 |
+
"metadata": {},
|
245 |
+
"outputs": [
|
246 |
+
{
|
247 |
+
"name": "stdout",
|
248 |
+
"output_type": "stream",
|
249 |
+
"text": [
|
250 |
+
"Using style prior: True\n",
|
251 |
+
"MODEL BUILDER <module 'priors.differentiable_prior' from '/home/hollmann/TabPFN/priors/differentiable_prior.py'> <function get_model.<locals>.make_get_batch.<locals>.new_get_batch at 0x7f24bd339af0>\n",
|
252 |
+
"Using cpu:0 device\n",
|
253 |
+
"init dist\n",
|
254 |
+
"Not using distributed\n",
|
255 |
+
"DataLoader.__dict__ {'num_steps': 33554432, 'get_batch_kwargs': {'batch_size': 1, 'eval_pos_seq_len_sampler': <function train.<locals>.eval_pos_seq_len_sampler at 0x7f24bd493ee0>, 'seq_len_maximum': 1152, 'device': 'cpu:0', 'num_features': 100, 'hyperparameters': {'lr': 0.00011555441385381896, 'dropout': 0.0, 'emsize': 512, 'batch_size': 1, 'nlayers': 12, 'num_features': 100, 'nhead': 4, 'nhid_factor': 2, 'bptt': 1152, 'eval_positions': [1094], 'seq_len_used': 50, 'sampling': 'normal', 'epochs': 400, 'num_steps': 33554432, 'verbose': True, 'mix_activations': False, 'pre_sample_causes': True, 'multiclass_type': 'rank', 'nan_prob_unknown_reason_reason_prior': 0.5, 'categorical_feature_p': 0.2, 'nan_prob_no_reason': 0.0, 'nan_prob_unknown_reason': 0.0, 'nan_prob_a_reason': 0.0, 'max_num_classes': 10, 'num_classes': <function <lambda>.<locals>.<lambda> at 0x7f24c2d03ee0>, 'noise_type': 'Gaussian', 'balanced': False, 'normalize_to_ranking': False, 'set_value_to_nan': 0.1, 'normalize_by_used_features': True, 'num_features_used': <function <lambda>.<locals>.<lambda> at 0x7f24c2d03e50>, 'num_categorical_features_sampler_a': -1.0, 'differentiable_hyperparameters': {'distribution': 'uniform', 'min': 2.0, 'max': 10.0}, 'prior_type': 'prior_bag', 'differentiable': True, 'flexible': True, 'recompute_attn': True, 'bptt_extra_samples': None, 'output_multiclass_ordered_p': 0.0, 'multiclass_loss_type': 'nono', 'normalize_with_sqrt': False, 'new_mlp_per_example': True, 'prior_mlp_scale_weights_sqrt': True, 'batch_size_per_gp_sample': None, 'normalize_ignore_label_too': False, 'differentiable_hps_as_style': False, 'max_eval_pos': 1000, 'random_feature_rotation': True, 'rotate_normalized_labels': True, 'canonical_y_encoder': False, 'aggregate_k_gradients': 8, 'total_available_time_in_s': None, 'train_mixed_precision': True, 'efficient_eval_masking': True, 'prior_bag_get_batch': (<function get_model.<locals>.make_get_batch.<locals>.new_get_batch at 0x7f24bf3e8550>, <function 
get_model.<locals>.make_get_batch.<locals>.new_get_batch at 0x7f24bd339e50>), 'prior_bag_exp_weights_1': 2.0, 'normalize_labels': True, 'check_is_compatible': True}, 'batch_size_per_gp_sample': None, 'get_batch': <function get_model.<locals>.make_get_batch.<locals>.new_get_batch at 0x7f24bd339af0>, 'differentiable_hyperparameters': {'prior_bag_exp_weights_1': {'distribution': 'uniform', 'min': 2.0, 'max': 10.0}, 'num_layers': {'distribution': 'meta_gamma', 'max_alpha': 2, 'max_scale': 3, 'round': True, 'lower_bound': 2}, 'prior_mlp_hidden_dim': {'distribution': 'meta_gamma', 'max_alpha': 3, 'max_scale': 100, 'round': True, 'lower_bound': 4}, 'prior_mlp_dropout_prob': {'distribution': 'meta_beta', 'scale': 0.6, 'min': 0.1, 'max': 5.0}, 'noise_std': {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': 0.3, 'min_mean': 0.0001, 'round': False, 'lower_bound': 0.0}, 'init_std': {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': 10.0, 'min_mean': 0.01, 'round': False, 'lower_bound': 0.0}, 'num_causes': {'distribution': 'meta_gamma', 'max_alpha': 3, 'max_scale': 7, 'round': True, 'lower_bound': 2}, 'is_causal': {'distribution': 'meta_choice', 'choice_values': [True, False]}, 'pre_sample_weights': {'distribution': 'meta_choice', 'choice_values': [True, False]}, 'y_is_effect': {'distribution': 'meta_choice', 'choice_values': [True, False]}, 'prior_mlp_activations': {'distribution': 'meta_choice_mixed', 'choice_values': [<class 'torch.nn.modules.activation.Tanh'>, <class 'torch.nn.modules.linear.Identity'>, <class 'torch.nn.modules.activation.ReLU'>]}, 'block_wise_dropout': {'distribution': 'meta_choice', 'choice_values': [True, False]}, 'sort_features': {'distribution': 'meta_choice', 'choice_values': [True, False]}, 'in_clique': {'distribution': 'meta_choice', 'choice_values': [True, False]}, 'outputscale': {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': 10.0, 'min_mean': 1e-05, 'round': False, 'lower_bound': 0}, 'lengthscale': {'distribution': 
'meta_trunc_norm_log_scaled', 'max_mean': 10.0, 'min_mean': 1e-05, 'round': False, 'lower_bound': 0}, 'noise': {'distribution': 'meta_choice', 'choice_values': [1e-05, 0.0001, 0.01]}}}, 'num_features': 100, 'epoch_count': 0}\n",
|
256 |
+
"PRIOR_BAG: tensor([1.0000, 2.3162]) [1]\n",
|
257 |
+
"{'is_causal': False, 'num_causes': 4, 'prior_mlp_hidden_dim': 6, 'num_layers': 2, 'noise_std': 0.0021951181710037487, 'y_is_effect': True, 'pre_sample_weights': True, 'prior_mlp_dropout_prob': 0.11217365522242403, 'pre_sample_causes': True}\n",
|
258 |
+
"Hparams dict_keys(['prior_bag_exp_weights_1', 'num_layers_alpha', 'num_layers_scale', 'prior_mlp_hidden_dim_alpha', 'prior_mlp_hidden_dim_scale', 'prior_mlp_dropout_prob_b', 'prior_mlp_dropout_prob_k', 'noise_std_log_mean', 'noise_std_log_std', 'init_std_log_mean', 'init_std_log_std', 'num_causes_alpha', 'num_causes_scale', 'is_causal_choice_1_weight', 'pre_sample_weights_choice_1_weight', 'y_is_effect_choice_1_weight', 'prior_mlp_activations_choice_1_weight', 'prior_mlp_activations_choice_2_weight', 'block_wise_dropout_choice_1_weight', 'sort_features_choice_1_weight', 'in_clique_choice_1_weight', 'outputscale_log_mean', 'outputscale_log_std', 'lengthscale_log_mean', 'lengthscale_log_std', 'noise_choice_1_weight', 'noise_choice_2_weight'])\n",
|
259 |
+
"Style definition of first 3 examples: None\n",
|
260 |
+
"Using a Transformer with 25.82 M parameters\n",
|
261 |
+
"PRIOR_BAG: tensor([1.0000, 7.0192]) [1]\n",
|
262 |
+
"{'is_causal': True, 'num_causes': 2, 'prior_mlp_hidden_dim': 10, 'num_layers': 2, 'noise_std': 0.0031679113358953426, 'y_is_effect': False, 'pre_sample_weights': True, 'prior_mlp_dropout_prob': 0.009754962364049987, 'pre_sample_causes': True}\n",
|
263 |
+
"Hparams dict_keys(['prior_bag_exp_weights_1', 'num_layers_alpha', 'num_layers_scale', 'prior_mlp_hidden_dim_alpha', 'prior_mlp_hidden_dim_scale', 'prior_mlp_dropout_prob_b', 'prior_mlp_dropout_prob_k', 'noise_std_log_mean', 'noise_std_log_std', 'init_std_log_mean', 'init_std_log_std', 'num_causes_alpha', 'num_causes_scale', 'is_causal_choice_1_weight', 'pre_sample_weights_choice_1_weight', 'y_is_effect_choice_1_weight', 'prior_mlp_activations_choice_1_weight', 'prior_mlp_activations_choice_2_weight', 'block_wise_dropout_choice_1_weight', 'sort_features_choice_1_weight', 'in_clique_choice_1_weight', 'outputscale_log_mean', 'outputscale_log_std', 'lengthscale_log_mean', 'lengthscale_log_std', 'noise_choice_1_weight', 'noise_choice_2_weight'])\n"
|
264 |
+
]
|
265 |
+
},
|
266 |
+
{
|
267 |
+
"data": {
|
268 |
+
"image/png": "\n",
|
269 |
+
"text/plain": [
|
270 |
+
"<Figure size 576x576 with 10 Axes>"
|
271 |
+
]
|
272 |
+
},
|
273 |
+
"metadata": {},
|
274 |
+
"output_type": "display_data"
|
275 |
+
},
|
276 |
+
{
|
277 |
+
"data": {
|
278 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAARAAAAEDCAYAAAD9SFsgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAUbElEQVR4nO3df5BdZX3H8fc3u5vfptklm7AkCEHWEH4I2JWfWhkiGsCStFNacLAZm5qOVURGhya1FjvVKbaOYseCZoiaKo1FYEykCMZVRrQaWYTBhCQkgvnFkmwoJJhEdjf77R/POfeuYTcbvifrPRs+rxnm3nvuc8/9niX3M8855znPMXdHRCRiVK0LEJGRSwEiImEKEBEJU4CISJgCRETCFCAiElaKADGzuWa20cw2m9niWtdzKDM70cx+aGbrzWydmd2QLW8ys9Vmtil7bKx1rf2ZWZ2ZPWZm92WvS1uvmU02s7vNbEP2d76w5PXemP1bWGtmK8xsbNnqNbOvmNkuM1vbb9mgNZrZkuw3uNHM3nUk31HzADGzOuA/gMuB04Frzez02lb1Cr3AR919NnAB8MGsxsVAu7u3Au3Z6zK5AVjf73WZ6/0C8IC7nwacTaq7lPWa2XTgw0Cbu58J1AHXUL56vwbMPWTZgDVm/56vAc7IPnNb9ts8PHev6X/AhcCD/V4vAZbUuq4hal4JXAZsBFqyZS3AxlrX1q/GGdk/kEuB+7JlpawXmAQ8A9ghy8ta73RgG9AE1AP3Ae8sY73AycDaof6mh/7ugAeBC4daf817IFT/Z+S2Z8tKycxOBs4F1gDT3L0TIHucWsPSDnUrcBPQ129ZWes9BegCvprtct1hZhMoab3uvgP4LLAV6AT2uPv3KGm9hxisxtDvsAwBYgMsK+X4ejObCNwDfMTd99a6nsGY2buBXe7+aK1rOUL1wJuB2939XGAfte/+Dyo7bjAPmAmcAEwws+tqW1Vhod9hGQJkO3Biv9czgGdrVMugzKyBFB53uvu92eKdZtaSvd8C7KpVfYe4GLjKzH4NfBO41My+QXnr3Q5sd/c12eu7SYFS1nrfATzj7l3u3gPcC1xEeevtb7AaQ7/DMgTII0Crmc00s9GkAzmralzT7zAzA5YB6939c/3eWgUsyJ4vIB0bqTl3X+LuM9z9ZNLf8wfufh3lrfc5YJuZzcoWzQGepKT1knZdLjCz8dm/jTmkg75lrbe/wWpcBVxjZmPMbCbQCvx8yLXV+iBPdsDmCuAp4FfAx2tdzwD1vZXUnXsCeDz77wrgONKByk3ZY1Otax2g9kuoHkQtbb3AOUBH9jf+NtBY8nr/CdgArAW+DowpW73ACtIxmh5SD2Ph4WoEPp79BjcClx/Jd1j2QRGRV60MuzAiMkIpQEQkTAEiImEKEBEJU4CISNiwBcirvcLWzBYNVy3DZaTVrHqH12ux3mEJkOAVtiPqj58ZaTWr3uH1mqt3uHog5wGb3f1pd+8mDaeeN0zfJSI1MiwDyczsz4C57v7X2ev3Aue7+4cGaj+lqc4njB/Frt7jARjVU32vryFbZ3ZNacOebgC8ob76fb0H07L6NH1B74SUiwfHpPdHdVfX59nHGn6TttstXUPUMzFbvq/atm7PAQB+e8I4AM6aNu136u7q6qK5uXmgTSol1Tu8Rlq96x7fyOhR49jb07Xb3UOF1w/dJGTIK/uy/a9FAA2va+T4az/BMx+9DYA3ffZvK+32nZiSY8zuFAqvf2DPK1bcNzptRvfk0QB0nZtS57i3dwKwbcuUSttJU38DQOOXU2K88MbUdvRlu1Nd3zqu0rZx+U8B2P2lNwLQceWnB91gkZFm7uwlADy44ZYt0XUM1y7MkFf2uftSd29z97a68ROGqQwRGU7D1QOpXGEL7CBdEfqewRqP6oHxO/sqPY8nPnZb5b22T3wAgLqe1IEZteW59EbjH1Ta1B1MvRSvmwzAhB1ps7Y/leZKmbyhOj
PbgalpCkjztJ/UtDE9Pjsl9TxaunpfUd9La7NeyZWDb7DISOPbis+aMSwB4u69ZvYh0rRodcBX3H3dcHyXiNTOcPVAcPf7gfuHa/0iUnvDFiCvRl8D7J86qnLANN9tAej459sBmL8pzTK/9qx0QLPxyernd5+fzsKc/sl0LOiq29LjskfeCsDs9zxVaXvKhHSwdNXetwEw7Yo0DeT8pvR41/S2StvZj6QD037SgWIbKFJCz77/nPTk1v8Mr0ND2UUkrBQ9EOuDhn1eOVWbHzCFas/j260PAjDr4Q+8oo29nD7nPemA6EsHxwJQvzudou3cP6nStrsvHVDNx5e8eCCN8dj0UuptWH11EnMbnRr17h1dbANFSqiuu/gYMPVARCSsFD2Qhj3dHP/Adg5OST2Fyqlaqsc88p7Hxr9Kx0TmzntvpU3TwucB2HfRqQB86ydpHNusmx8DYNOy2ZW2W15Io11nrE+na19/VZqU+mc/SW3GvVAdA9c9M50GPu32l9KCkXalg8hh7LskG3Z92+HbHY56ICISVooeiDfU09vSiPWksyn9B4nlZ1vyYx55z+OBlV+vtLns2vcBMP6RdA/hpuPPBqD7ojMAaLmnupm9Y1MPY1znfgC2fib1cJrHp/Xnx0YA6takL//tpW8qsnkipXTqTS8AsLnAOtQDEZEwBYiIhJViF8Z6D1K/ay8909KuS35tC1QHieWnavMDpvluC8DqFV8F4G0f+hsAnj8vHSCdeu8zADR8a1yl7bYXJqcn97wOgBs/uQKAJWv+FIBRO8ZW2k54+1kA1O8/WGj7RMpo663pN8D8+DrUAxGRsFL0QLy+jt6pkyrzeeRX1UJ1eHo+SCw/VZsfMIVqz+PhL34ZgCvOuhSAbe87DYDp76ze4vOkN2S9kV3pSsRlm68CoLUnzTpk69dX2toprweg6/zqHCEix4qpX0y/hSeHaHc46oGISFgpeiC9E0ax8y0T6cl2yfL5PKB6YVw+PD0fJJafqoXqMY+853H/L38AwBv+O52infKjiZW2TaPT+r7/7bcAcOplTwMwY/yLADy05dRK25k3pmW7L64OhRc5Vjx/Rjbn5/fj61APRETCStEDOTgG9rYeZMasNKw8n0kMqpfk5xfG5cPT80FiUD3bkh/zyHsev/qLLwFwyvcWVtqe15rajkljaNi5bCYAM65P67VH+/U2etOl/x+76MFswd9FN1GkdFrueByAXxZYh3ogIhJWih7IqG6YsLWObWPT7On95zDNJwPKL8nPL4zrPzw9H+eRn23Jj3nkPY+n37ms0nZdd5oc6E+mzwLgU3+expA016UL5r7/plmVtgePT2dfbt/wRwBcf1qRrRQpl42fSeOcGPBmK0dGPRARCVOAiEhYKXZhvB5ePs4rN33Kb70A1TlM85nE8vk88qtqoTo8PR8klp+qzQ+Y5rstAGeMTm08i84dPU0AzB2Xrs4dM7Z6W7y6PWnZ+SfsLrR9ImU0Znfd0I2GoB6IiISVogfS8Bun5ScHsezgZ37TJ6jOnp7P05HPJJbP5wFULozLh6fng8TyU7X5AVOo9jw2/WWa2ezMf083s/q3yZ59ptqz6Xvu1wD84uvnpAXnRbZOpJxmtKff0FNDtDsc9UBEJKwUPRA3o6/B2PuGVE5+u0mo3rclnz09n8M0n0kMqpfk5xfG5cPT80Fi+alaqB7zyHseaz+cJoT81O50jnb52gsqbbsvTKeMuzWSXY5BLy7OevEPx9ehHoiIhJWiB9IzEZ67cBSTT09nO/IbXUP1jnH5fVvy2dPzOUyhOhlQfkl+fmFcPjw9HyQG1bMt+TGPvOfxD1M2ALBizB9W2o7dnCYvmvuvO4ptoEgJHXh4SuF1qAciImGl6IE07IOpHU7PutTzaOnqrbyX36s2v2Ncft+W/rOn59MQ5pMB5Zfk5xfG9R+eno/zyM+25Mc88p7HugvvrLS9kj8G4P7vpDafP7fARoqUzKQtfUM3GoJ6IC
ISpgARkbBS7MLU7TnApPueoG///le8N/uR7KbX2Y2u89tN5jd9gurs6fkcpvlMYvl8HvlVtem70nfkg8TyU7X5AdN8twXgf376HQAuP/WitOAfAxsnUlJuQ7cZinogIhJWih7Ib08Yx+YbzmbS7NQLeGlttcfgJ6UL4Xr3phnb8xtd97/dZH7flnz29HwO03wmsXw+D6heGJcPT88HieWnavMDplDteXx38/8W2j6RMmr68fbC61APRETCzN2HbjXM2travKOjo9ZliLwmmdmj7t4W+Wy4B2JmJ5rZD81svZmtM7MbsuVNZrbazDZlj41DrUtERqYiuzC9wEfdfTZwAfBBMzsdWAy0u3sr0J69FpFjUDhA3L3T3X+RPX8JWA9MB+YBy7Nmyyl0614RKbOjchDVzE4GzgXWANPcvRNSyABTB/nMIjPrMLOOrq6uo1GGiPyeFQ4QM5sI3AN8xN33Hunn3H2pu7e5e1tzc3PRMkSkBgoFiJk1kMLjTne/N1u808xasvdbgF3FShSRsipyFsaAZcB6d/9cv7dWAQuy5wuAlfHyRKTMioxEvRh4L/BLM3s8W/b3wC3AXWa2ENgKXF2oQhEprXCAuPuPgcEux5kTXa+IjBwayi4iYQoQEQlTgIhImAJERMIUICISpgARkTAFiIiEKUBEJEwBIiJhChARCVOAiEiYAkREwhQgIhKmABGRMAWIiIQpQEQkTAEiImEKEBEJU4CISJgCRETCFCAiEqYAEZEwBYiIhClARCRMASIiYQoQEQlTgIhImAJERMIUICISpgARkTAFiIiEKUBEJEwBIiJhChARCVOAiEhY4QAxszoze8zM7steN5nZajPblD02Fi9TRMroaPRAbgDW93u9GGh391agPXstIsegQgFiZjOAK4E7+i2eByzPni8H5hf5DhEpr6I9kFuBm4C+fsumuXsnQPY4daAPmtkiM+sws46urq6CZYhILYQDxMzeDexy90cjn3f3pe7e5u5tzc3N0TJEpIbqC3z2YuAqM7sCGAtMMrNvADvNrMXdO82sBdh1NAoVkfIJ90DcfYm7z3D3k4FrgB+4+3XAKmBB1mwBsLJwlSJSSsMxDuQW4DIz2wRclr0WkWNQkV2YCnd/CHgoe/48MOdorFdEyk0jUUUkTAEiImEKEBEJU4CISJgCRETCFCAiEqYAEZEwBYiIhClARCRMASIiYQoQEQlTgIhImAJERMIUICISpgARkTAFiIiEKUBEJEwBIiJhChARCVOAiEiYAkREwhQgIhKmABGRMAWIiIQpQEQkTAEiImEKEBEJU4CISJgCRETCFCAiEqYAEZEwBYiIhClARCRMASIiYQoQEQkrFCBmNtnM7jazDWa23swuNLMmM1ttZpuyx8ajVayIlEvRHsgXgAfc/TTgbGA9sBhod/dWoD17LSLHoHCAmNkk4I+AZQDu3u3uLwLzgOVZs+XA/GIlikhZFemBnAJ0AV81s8fM7A4zmwBMc/dOgOxx6kAfNrNFZtZhZh1dXV0FyhCRWikSIPXAm4Hb3f1cYB+vYnfF3Ze6e5u7tzU3NxcoQ0RqpUiAbAe2u/ua7PXdpEDZaWYtANnjrmIlikhZhQPE3Z8DtpnZrGzRHOBJYBWwIFu2AFhZqEIRKa36gp+/HrjTzEYDTwPvI4XSXWa2ENgKXF3wO0SkpAoFiLs/DrQN8NacIusVkZFBI1FFJEwBIiJhChARCVOAiEiYAkREwhQgIhKmABGRMAWIiIQpQEQkTAEiImEKEBEJU4CISJgCRETCFCAiEqYAEZEwBYiIhClARCRMASIiYQoQEQlTgIhImAJERMIUICISpgARkTAFiIiEKUBEJEwBIiJhChARCVOAiEiYAkREwhQgIhKmABGRMAWIiIQpQEQkTAEiImEKEBEJKxQgZnajma0zs7VmtsLMxppZk5mtNrNN2WPj0SpWRMolHCBmNh34MNDm7mcCdcA1wGKg3d1bgfbstYgcg4ruwt
QD48ysHhgPPAvMA5Zn7y8H5hf8DhEpqXCAuPsO4LPAVqAT2OPu3wOmuXtn1qYTmHo0ChWR8imyC9NI6m3MBE4AJpjZda/i84vMrMPMOrq6uqJliEgNFdmFeQfwjLt3uXsPcC9wEbDTzFoAssddA33Y3Ze6e5u7tzU3NxcoQ0RqpUiAbAUuMLPxZmbAHGA9sApYkLVZAKwsVqKIlFV99IPuvsbM7gZ+AfQCjwFLgYnAXWa2kBQyVx+NQkWkfMIBAuDuNwM3H7L4ZVJvRESOcRqJKiJhChARCVOAiEiYAkREwhQgIhKmABGRMAWIiIQpQEQkTAEiImEKEBEJU4CISJgCRETCFCAiEqYAEZEwBYiIhClARCRMASIiYQoQEQlTgIhImAJERMIUICISpgARkTAFiIiEKUBEJEwBIiJhChARCVOAiEiYAkREwhQgIhKmABGRMAWIiIQpQEQkTAEiImEKEBEJU4CISNiQAWJmXzGzXWa2tt+yJjNbbWabssfGfu8tMbPNZrbRzN41XIWLSO0dSQ/ka8DcQ5YtBtrdvRVoz15jZqcD1wBnZJ+5zczqjlq1IlIqQwaIu/8I+L9DFs8DlmfPlwPz+y3/pru/7O7PAJuB845OqSJSNvXBz01z904Ad+80s6nZ8unAz/q1254tO6xNT2zl8unX45MmAuDbnq289+z7zwGgrtsB2HfJPgBOvemFSputt74OgKlfHAfA82eMAaDljscB2PiZsyptx+xOHaIZ7fsBeHFxejzw8BQAJm3pq7R1S49NP94OwHe3fH6oTREZMfqeay28jmiADMYGWOYDNjRbBCwCGFs38SiXISK/D+Y+4O/7dxuZnQzc5+5nZq83ApdkvY8W4CF3n2VmSwDc/V+ydg8Cn3T3nw6x/i5gH7C7yMbUwBRGVs2qd3iN1HpPcvfmyAqiPZBVwALgluxxZb/l/2VmnwNOAFqBnw+1MndvNrMOd28L1lMTI61m1Tu8Xov1DhkgZrYCuASYYmbbgZtJwXGXmS0EtgJXA7j7OjO7C3gS6AU+6O4HixQoIuU1ZIC4+7WDvDVnkPafBj5dpCgRGRnKNBJ1aa0LCBhpNave4fWaq/eIDqKKiAykTD0QERlhFCAiEqYAEZEwBYiIhClARCRMASIiYf8Pmu3ntIpnOt4AAAAASUVORK5CYII=\n",
|
279 |
+
"text/plain": [
|
280 |
+
"<Figure size 288x288 with 1 Axes>"
|
281 |
+
]
|
282 |
+
},
|
283 |
+
"metadata": {
|
284 |
+
"needs_background": "light"
|
285 |
+
},
|
286 |
+
"output_type": "display_data"
|
287 |
+
}
|
288 |
+
],
|
289 |
+
"source": [
|
290 |
+
"config_sample['batch_size'] = 4\n",
|
291 |
+
"model = get_model(config_sample, device, should_train=False, verbose=2) # , state_dict=model[2].state_dict()\n",
|
292 |
+
"(hp_embedding, data, _), targets, single_eval_pos = next(iter(model[3]))\n",
|
293 |
+
"\n",
|
294 |
+
"from utils import normalize_data\n",
|
295 |
+
"fig = plt.figure(figsize=(8, 8))\n",
|
296 |
+
"N = 100\n",
|
297 |
+
"plot_features(data[0:N, 0, 0:4], targets[0:N, 0], fig=fig)\n",
|
298 |
+
"\n",
|
299 |
+
"d = np.concatenate([data[:, 0, :].T, np.expand_dims(targets[:, 0], -1).T])\n",
|
300 |
+
"d[np.isnan(d)] = 0\n",
|
301 |
+
"c = np.corrcoef(d)\n",
|
302 |
+
"plt.matshow(np.abs(c), vmin=0, vmax=1)\n",
|
303 |
+
"plt.show()"
|
304 |
+
]
|
305 |
+
},
|
306 |
+
{
|
307 |
+
"cell_type": "markdown",
|
308 |
+
"metadata": {
|
309 |
+
"tags": []
|
310 |
+
},
|
311 |
+
"source": [
|
312 |
+
"## Training"
|
313 |
+
]
|
314 |
+
},
|
315 |
+
{
|
316 |
+
"cell_type": "code",
|
317 |
+
"execution_count": null,
|
318 |
+
"metadata": {},
|
319 |
+
"outputs": [],
|
320 |
+
"source": [
|
321 |
+
"model = get_model(config_sample, device, should_train=True, verbose=0)"
|
322 |
+
]
|
323 |
+
},
|
324 |
+
{
|
325 |
+
"cell_type": "code",
|
326 |
+
"execution_count": null,
|
327 |
+
"metadata": {},
|
328 |
+
"outputs": [],
|
329 |
+
"source": []
|
330 |
+
}
|
331 |
+
],
|
332 |
+
"metadata": {
|
333 |
+
"kernelspec": {
|
334 |
+
"display_name": "Python 3 (ipykernel)",
|
335 |
+
"language": "python",
|
336 |
+
"name": "python3"
|
337 |
+
},
|
338 |
+
"language_info": {
|
339 |
+
"codemirror_mode": {
|
340 |
+
"name": "ipython",
|
341 |
+
"version": 3
|
342 |
+
},
|
343 |
+
"file_extension": ".py",
|
344 |
+
"mimetype": "text/x-python",
|
345 |
+
"name": "python",
|
346 |
+
"nbconvert_exporter": "python",
|
347 |
+
"pygments_lexer": "ipython3",
|
348 |
+
"version": "3.9.6"
|
349 |
+
}
|
350 |
+
},
|
351 |
+
"nbformat": 4,
|
352 |
+
"nbformat_minor": 4
|
353 |
+
}
|
TabPFN/{TabPFNPredictionOnly.ipynb → QuickPredictionDemo.ipynb}
RENAMED
@@ -44,14 +44,15 @@
|
|
44 |
"import torch\n",
|
45 |
"import numpy as np\n",
|
46 |
"import os\n",
|
47 |
-
"import random\n",
|
48 |
"\n",
|
49 |
-
"from model_builder import get_model, get_default_spec, save_model, load_model\n",
|
50 |
"from scripts.transformer_prediction_interface import transformer_predict, get_params_from_config, TabPFNClassifier\n",
|
|
|
51 |
"\n",
|
52 |
-
"from datasets import load_openml_list, open_cc_dids, open_cc_valid_dids\n",
|
53 |
"\n",
|
54 |
-
"from scripts import tabular_metrics"
|
|
|
55 |
]
|
56 |
},
|
57 |
{
|
@@ -66,6 +67,7 @@
|
|
66 |
{
|
67 |
"cell_type": "markdown",
|
68 |
"metadata": {
|
|
|
69 |
"tags": []
|
70 |
},
|
71 |
"source": [
|
@@ -76,9 +78,6 @@
|
|
76 |
"cell_type": "code",
|
77 |
"execution_count": null,
|
78 |
"metadata": {
|
79 |
-
"jupyter": {
|
80 |
-
"outputs_hidden": true
|
81 |
-
},
|
82 |
"tags": []
|
83 |
},
|
84 |
"outputs": [],
|
@@ -96,27 +95,6 @@
|
|
96 |
"random.shuffle(cc_valid_datasets_multiclass)"
|
97 |
]
|
98 |
},
|
99 |
-
{
|
100 |
-
"cell_type": "code",
|
101 |
-
"execution_count": null,
|
102 |
-
"metadata": {},
|
103 |
-
"outputs": [],
|
104 |
-
"source": [
|
105 |
-
"from datasets import get_openml_classification"
|
106 |
-
]
|
107 |
-
},
|
108 |
-
{
|
109 |
-
"cell_type": "code",
|
110 |
-
"execution_count": null,
|
111 |
-
"metadata": {},
|
112 |
-
"outputs": [],
|
113 |
-
"source": [
|
114 |
-
"dataset = openml.datasets.get_dataset(31)\n",
|
115 |
-
"X, y, categorical_indicator, attribute_names = dataset.get_data(\n",
|
116 |
-
" dataset_format=\"array\", target=dataset.default_target_attribute\n",
|
117 |
-
" )"
|
118 |
-
]
|
119 |
-
},
|
120 |
{
|
121 |
"cell_type": "code",
|
122 |
"execution_count": null,
|
@@ -156,7 +134,7 @@
|
|
156 |
"tags": []
|
157 |
},
|
158 |
"source": [
|
159 |
-
"###
|
160 |
]
|
161 |
},
|
162 |
{
|
@@ -174,7 +152,7 @@
|
|
174 |
"metadata": {},
|
175 |
"outputs": [],
|
176 |
"source": [
|
177 |
-
"evaluation_dataset_index =
|
178 |
"ds = test_datasets[evaluation_dataset_index]\n",
|
179 |
"print(f'Evaluation dataset name: {ds[0]} shape {ds[1].shape}')"
|
180 |
]
|
@@ -191,13 +169,36 @@
|
|
191 |
"test_xs, test_ys = xs[eval_position:], ys[eval_position:]"
|
192 |
]
|
193 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
{
|
195 |
"cell_type": "markdown",
|
196 |
"metadata": {
|
|
|
197 |
"tags": []
|
198 |
},
|
199 |
"source": [
|
200 |
-
"###
|
|
|
201 |
]
|
202 |
},
|
203 |
{
|
@@ -206,9 +207,35 @@
|
|
206 |
"metadata": {},
|
207 |
"outputs": [],
|
208 |
"source": [
|
209 |
-
"
|
210 |
-
"
|
211 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
]
|
213 |
},
|
214 |
{
|
@@ -217,8 +244,110 @@
|
|
217 |
"metadata": {},
|
218 |
"outputs": [],
|
219 |
"source": [
|
220 |
-
"
|
221 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
]
|
223 |
},
|
224 |
{
|
@@ -245,7 +374,7 @@
|
|
245 |
"name": "python",
|
246 |
"nbconvert_exporter": "python",
|
247 |
"pygments_lexer": "ipython3",
|
248 |
-
"version": "3.
|
249 |
}
|
250 |
},
|
251 |
"nbformat": 4,
|
|
|
44 |
"import torch\n",
|
45 |
"import numpy as np\n",
|
46 |
"import os\n",
|
|
|
47 |
"\n",
|
48 |
+
"from scripts.model_builder import get_model, get_default_spec, save_model, load_model\n",
|
49 |
"from scripts.transformer_prediction_interface import transformer_predict, get_params_from_config, TabPFNClassifier\n",
|
50 |
+
"from scripts.differentiable_pfn_evaluation import eval_model, eval_model_range\n",
|
51 |
"\n",
|
52 |
+
"from datasets import load_openml_list, open_cc_dids, open_cc_valid_dids, test_dids_classification\n",
|
53 |
"\n",
|
54 |
+
"from scripts import tabular_metrics\n",
|
55 |
+
"import random"
|
56 |
]
|
57 |
},
|
58 |
{
|
|
|
67 |
{
|
68 |
"cell_type": "markdown",
|
69 |
"metadata": {
|
70 |
+
"jp-MarkdownHeadingCollapsed": true,
|
71 |
"tags": []
|
72 |
},
|
73 |
"source": [
|
|
|
78 |
"cell_type": "code",
|
79 |
"execution_count": null,
|
80 |
"metadata": {
|
|
|
|
|
|
|
81 |
"tags": []
|
82 |
},
|
83 |
"outputs": [],
|
|
|
95 |
"random.shuffle(cc_valid_datasets_multiclass)"
|
96 |
]
|
97 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
{
|
99 |
"cell_type": "code",
|
100 |
"execution_count": null,
|
|
|
134 |
"tags": []
|
135 |
},
|
136 |
"source": [
|
137 |
+
"### Run on a single dataset"
|
138 |
]
|
139 |
},
|
140 |
{
|
|
|
152 |
"metadata": {},
|
153 |
"outputs": [],
|
154 |
"source": [
|
155 |
+
"evaluation_dataset_index = 0 # Index of the dataset to predict\n",
|
156 |
"ds = test_datasets[evaluation_dataset_index]\n",
|
157 |
"print(f'Evaluation dataset name: {ds[0]} shape {ds[1].shape}')"
|
158 |
]
|
|
|
169 |
"test_xs, test_ys = xs[eval_position:], ys[eval_position:]"
|
170 |
]
|
171 |
},
|
172 |
+
{
|
173 |
+
"cell_type": "code",
|
174 |
+
"execution_count": null,
|
175 |
+
"metadata": {},
|
176 |
+
"outputs": [],
|
177 |
+
"source": [
|
178 |
+
"classifier = TabPFNClassifier(device='cpu')\n",
|
179 |
+
"classifier.fit(train_xs, train_ys)\n",
|
180 |
+
"prediction_ = classifier.predict_proba(test_xs)"
|
181 |
+
]
|
182 |
+
},
|
183 |
+
{
|
184 |
+
"cell_type": "code",
|
185 |
+
"execution_count": null,
|
186 |
+
"metadata": {},
|
187 |
+
"outputs": [],
|
188 |
+
"source": [
|
189 |
+
"roc, ce = tabular_metrics.auc_metric(test_ys, prediction_), tabular_metrics.cross_entropy(test_ys, prediction_)\n",
|
190 |
+
"'AUC', float(roc), 'Cross Entropy', float(ce)"
|
191 |
+
]
|
192 |
+
},
|
193 |
{
|
194 |
"cell_type": "markdown",
|
195 |
"metadata": {
|
196 |
+
"jp-MarkdownHeadingCollapsed": true,
|
197 |
"tags": []
|
198 |
},
|
199 |
"source": [
|
200 |
+
"### Run on all datasets\n",
|
201 |
+
"This section runs a differentiable hyperparameter tuning run and saves the results to a results file, which can be inserted in TabularEval.ipynb to compare to other baselines."
|
202 |
]
|
203 |
},
|
204 |
{
|
|
|
207 |
"metadata": {},
|
208 |
"outputs": [],
|
209 |
"source": [
|
210 |
+
"eval_positions=[1000]\n",
|
211 |
+
"bptt=2000\n",
|
212 |
+
"\n",
|
213 |
+
"N_models = 3\n",
|
214 |
+
"models_per_block = 1\n",
|
215 |
+
"\n",
|
216 |
+
"eval_addition = 'user_run'\n",
|
217 |
+
"device = 'cpu'\n",
|
218 |
+
"\n",
|
219 |
+
"eval_model_range(i_range=[0], e=-1\n",
|
220 |
+
" , valid_datasets=[]#cc_valid_datasets_multiclass\n",
|
221 |
+
" , test_datasets=cc_test_datasets_multiclass\n",
|
222 |
+
" , train_datasets=[]\n",
|
223 |
+
" , eval_positions_test=eval_positions\n",
|
224 |
+
" , bptt_test=bptt\n",
|
225 |
+
" , add_name=model_string\n",
|
226 |
+
" , base_path=base_path\n",
|
227 |
+
" , selection_metric='auc'\n",
|
228 |
+
" , best_grad_steps=0\n",
|
229 |
+
" , eval_addition=eval_addition\n",
|
230 |
+
" , N_ensemble_configurations_list = [32]\n",
|
231 |
+
" , device=device)#range(0, 10)"
|
232 |
+
]
|
233 |
+
},
|
234 |
+
{
|
235 |
+
"cell_type": "markdown",
|
236 |
+
"metadata": {},
|
237 |
+
"source": [
|
238 |
+
"### Run generalization experiments"
|
239 |
]
|
240 |
},
|
241 |
{
|
|
|
244 |
"metadata": {},
|
245 |
"outputs": [],
|
246 |
"source": [
|
247 |
+
"# Loading longer OpenML Datasets for generalization experiments (optional)\n",
|
248 |
+
"test_datasets_multiclass, test_datasets_multiclass_df = load_openml_list(test_dids_classification, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = 10000, num_feats=100, return_capped=True)\n"
|
249 |
+
]
|
250 |
+
},
|
251 |
+
{
|
252 |
+
"cell_type": "code",
|
253 |
+
"execution_count": null,
|
254 |
+
"metadata": {},
|
255 |
+
"outputs": [],
|
256 |
+
"source": [
|
257 |
+
"test_datasets_longer_generalization = [ds for ds in test_datasets_multiclass if ds[1].shape[0] >= 10000]"
|
258 |
+
]
|
259 |
+
},
|
260 |
+
{
|
261 |
+
"cell_type": "code",
|
262 |
+
"execution_count": null,
|
263 |
+
"metadata": {},
|
264 |
+
"outputs": [],
|
265 |
+
"source": [
|
266 |
+
"def test_gen(classifier_key, split):\n",
|
267 |
+
" if classifier_key == 'tabpfn':\n",
|
268 |
+
" model = TabPFNClassifier(device='cuda', base_path='/work/dlclarge1/hollmann-PFN_Tabular/',\n",
|
269 |
+
" model_string=model_string, N_ensemble_configurations=4\n",
|
270 |
+
" , no_preprocess_mode=False, i=i, feature_shift_decoder=False)\n",
|
271 |
+
" else:\n",
|
272 |
+
" model = classifier_dict[classifier_key]\n",
|
273 |
+
" \n",
|
274 |
+
" ces = []\n",
|
275 |
+
" for k in tqdm(range(0, len(test_datasets_longer_generalization))):\n",
|
276 |
+
" x, y = test_datasets_longer_generalization[k][1], test_datasets_longer_generalization[k][2].numpy()\n",
|
277 |
+
" x = normalize_data(x).numpy()\n",
|
278 |
+
" x[np.isnan(x)] = 0.0\n",
|
279 |
+
" print(x.shape[0])\n",
|
280 |
+
" \n",
|
281 |
+
" if x.shape[0] < 10000:\n",
|
282 |
+
" continue\n",
|
283 |
+
" if len(np.unique(y)) > 2:\n",
|
284 |
+
" continue\n",
|
285 |
+
"\n",
|
286 |
+
" for bptt_ in [500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000, 5500, 6000, 6500, 7000, 7500, 8000, 8500, 9000, 9500, 10000]:\n",
|
287 |
+
" bptt_ = bptt_ // 2\n",
|
288 |
+
" #model = classifier_dict[classifier_key]\n",
|
289 |
+
" x_, y_ = x.copy(), y.copy()\n",
|
290 |
+
" x_train, x_test, y_train, y_test = train_test_split(x_, y_, test_size=0.5, random_state=split)\n",
|
291 |
+
" x_train, y_train = x_train[0:bptt_], y_train[0:bptt_]\n",
|
292 |
+
" model.fit(x_train, y_train) # ranking[0:j]\n",
|
293 |
+
" pred = model.predict_proba(x_test) # ranking[0:j]\n",
|
294 |
+
" ce = tabular_metrics.auc_metric(y_test, pred)\n",
|
295 |
+
" ces += [{'bptt': bptt_, 'k': k, 'm': float(ce), 'method': classifier_key, 'split': split}]\n",
|
296 |
+
" print(x_train.shape, ce)\n",
|
297 |
+
" with open(f'generalization_{classifier_key}_{split}.obj',\"wb\") as fh:\n",
|
298 |
+
" pickle.dump(ces,fh)"
|
299 |
+
]
|
300 |
+
},
|
301 |
+
{
|
302 |
+
"cell_type": "code",
|
303 |
+
"execution_count": null,
|
304 |
+
"metadata": {},
|
305 |
+
"outputs": [],
|
306 |
+
"source": [
|
307 |
+
"test_gen('tabpfn', 0)"
|
308 |
+
]
|
309 |
+
},
|
310 |
+
{
|
311 |
+
"cell_type": "code",
|
312 |
+
"execution_count": null,
|
313 |
+
"metadata": {},
|
314 |
+
"outputs": [],
|
315 |
+
"source": [
|
316 |
+
"ces = []\n",
|
317 |
+
"for classifier_key in classifier_dict:\n",
|
318 |
+
" for split in range(0,5):\n",
|
319 |
+
" try:\n",
|
320 |
+
" with open(f'generalization_{classifier_key}_{split}.obj',\"rb\") as fh:\n",
|
321 |
+
" ces += pickle.load(fh)\n",
|
322 |
+
" except:\n",
|
323 |
+
" pass\n",
|
324 |
+
"df = pd.DataFrame(ces)"
|
325 |
+
]
|
326 |
+
},
|
327 |
+
{
|
328 |
+
"cell_type": "code",
|
329 |
+
"execution_count": null,
|
330 |
+
"metadata": {},
|
331 |
+
"outputs": [],
|
332 |
+
"source": [
|
333 |
+
"df = df.groupby(['bptt', 'split', 'method']).mean().reset_index()\n",
|
334 |
+
"fig, ax = plt.subplots(1,1, figsize=(8, 6)) # , sharey=True\n",
|
335 |
+
"\n",
|
336 |
+
"colors = iter(sns.color_palette(\"tab10\"))\n",
|
337 |
+
"for classifier_key in ['tabpfn']:#df.method.unique():\n",
|
338 |
+
" c = next(colors)\n",
|
339 |
+
" sns.lineplot(x='bptt', y='m', data=df[df.method==classifier_key], label=relabeler[classifier_key], color=c, ax = ax)\n",
|
340 |
+
" #ax.text(x = df[df.method==classifier_key].iloc[50].bptt, # x-coordinate position of data label\n",
|
341 |
+
" # y = df[df.method==classifier_key].iloc[50].m, # y-coordinate position of data label, adjusted to be 150 below the data point\n",
|
342 |
+
" # s = classifier_key, # data label, formatted to ignore decimals\n",
|
343 |
+
" # color = c, size=12) # set colour of line\n",
|
344 |
+
" \n",
|
345 |
+
"ax.get_legend().remove()\n",
|
346 |
+
"ax.set(xlabel='Number of training samples')\n",
|
347 |
+
"ax.set(ylabel='ROC AUC')\n",
|
348 |
+
"plt.axvline(x=1024, linestyle='dashed', color='red')\n",
|
349 |
+
"plt.ylim((0.73,0.79))\n",
|
350 |
+
"plt.xlim((250,5000))"
|
351 |
]
|
352 |
},
|
353 |
{
|
|
|
374 |
"name": "python",
|
375 |
"nbconvert_exporter": "python",
|
376 |
"pygments_lexer": "ipython3",
|
377 |
+
"version": "3.9.6"
|
378 |
}
|
379 |
},
|
380 |
"nbformat": 4,
|
TabPFN/README.md
CHANGED
@@ -2,11 +2,8 @@
|
|
2 |
|
3 |
## Installation
|
4 |
```
|
5 |
-
git clone git@github.com:automl/TabPFN.git
|
6 |
-
cd TabPFN
|
7 |
conda create -n TabPFN python=3.7
|
8 |
-
|
9 |
-
pip install -r requirements.txt
|
10 |
```
|
11 |
|
12 |
To run the autogluon baseline please create a separate environment and install autogluon==0.4.0, installation in the same environment as our other baselines is not possible.
|
|
|
2 |
|
3 |
## Installation
|
4 |
```
|
|
|
|
|
5 |
conda create -n TabPFN python=3.7
|
6 |
+
$environment_path$/pip install -r requirements.txt
|
|
|
7 |
```
|
8 |
|
9 |
To run the autogluon baseline please create a separate environment and install autogluon==0.4.0, installation in the same environment as our other baselines is not possible.
|
TabPFN/RunFullDatasetAnalyses.ipynb
ADDED
@@ -0,0 +1,833 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import matplotlib.pyplot as plt\n",
|
10 |
+
"\n",
|
11 |
+
"from scripts import tabular_baselines\n",
|
12 |
+
"\n",
|
13 |
+
"import seaborn as sns\n",
|
14 |
+
"import numpy as np\n",
|
15 |
+
"\n",
|
16 |
+
"from datasets import load_openml_list, valid_dids_classification, test_dids_classification, open_cc_dids\n",
|
17 |
+
"from scripts.tabular_baselines import *\n",
|
18 |
+
"from scripts.tabular_evaluation import evaluate\n",
|
19 |
+
"from scripts.tabular_metrics import calculate_score, make_ranks_and_wins_table, make_metric_matrix\n",
|
20 |
+
"from scripts import tabular_metrics"
|
21 |
+
]
|
22 |
+
},
|
23 |
+
{
|
24 |
+
"cell_type": "code",
|
25 |
+
"execution_count": null,
|
26 |
+
"metadata": {},
|
27 |
+
"outputs": [],
|
28 |
+
"source": [
|
29 |
+
"from notebook_utils import *"
|
30 |
+
]
|
31 |
+
},
|
32 |
+
{
|
33 |
+
"cell_type": "code",
|
34 |
+
"execution_count": null,
|
35 |
+
"metadata": {},
|
36 |
+
"outputs": [],
|
37 |
+
"source": [
|
38 |
+
"%load_ext autoreload\n",
|
39 |
+
"\n",
|
40 |
+
"%autoreload 2"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "markdown",
|
45 |
+
"metadata": {
|
46 |
+
"tags": []
|
47 |
+
},
|
48 |
+
"source": [
|
49 |
+
"# Datasets"
|
50 |
+
]
|
51 |
+
},
|
52 |
+
{
|
53 |
+
"cell_type": "code",
|
54 |
+
"execution_count": null,
|
55 |
+
"metadata": {
|
56 |
+
"tags": []
|
57 |
+
},
|
58 |
+
"outputs": [],
|
59 |
+
"source": [
|
60 |
+
"cc_test_datasets_multiclass, cc_test_datasets_multiclass_df = load_openml_list(open_cc_dids, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = 10000, num_feats=100, return_capped=True)\n"
|
61 |
+
]
|
62 |
+
},
|
63 |
+
{
|
64 |
+
"cell_type": "code",
|
65 |
+
"execution_count": null,
|
66 |
+
"metadata": {},
|
67 |
+
"outputs": [],
|
68 |
+
"source": [
|
69 |
+
"def get_datasets(selector, task_type, suite='openml'):\n",
|
70 |
+
" if task_type == 'binary':\n",
|
71 |
+
" ds = valid_datasets_binary if selector == 'valid' else test_datasets_binary\n",
|
72 |
+
" else:\n",
|
73 |
+
" if suite == 'openml':\n",
|
74 |
+
" ds = valid_datasets_multiclass if selector == 'valid' else test_datasets_multiclass\n",
|
75 |
+
" elif suite == 'cc':\n",
|
76 |
+
" ds = valid_datasets_multiclass if selector == 'valid' else cc_test_datasets_multiclass\n",
|
77 |
+
" else:\n",
|
78 |
+
" raise Exception(\"Unknown suite\")\n",
|
79 |
+
" return ds"
|
80 |
+
]
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"cell_type": "markdown",
|
84 |
+
"metadata": {
|
85 |
+
"tags": []
|
86 |
+
},
|
87 |
+
"source": [
|
88 |
+
"# Setting params"
|
89 |
+
]
|
90 |
+
},
|
91 |
+
{
|
92 |
+
"cell_type": "code",
|
93 |
+
"execution_count": null,
|
94 |
+
"metadata": {},
|
95 |
+
"outputs": [],
|
96 |
+
"source": [
|
97 |
+
"eval_positions = [1000]\n",
|
98 |
+
"max_features = 100\n",
|
99 |
+
"bptt = 2000\n",
|
100 |
+
"selector = 'test'\n",
|
101 |
+
"base_path = os.path.join('.')\n",
|
102 |
+
"overwrite=False\n",
|
103 |
+
"max_times = [0.5, 1, 15, 30, 60, 60*5, 60*15, 60*60]\n",
|
104 |
+
"metric_used = tabular_metrics.auc_metric\n",
|
105 |
+
"methods = ['transformer', 'logistic', 'gp', 'knn', 'catboost', 'xgb', 'autosklearn2', 'autogluon']\n",
|
106 |
+
"task_type = 'multiclass'"
|
107 |
+
]
|
108 |
+
},
|
109 |
+
{
|
110 |
+
"cell_type": "code",
|
111 |
+
"execution_count": null,
|
112 |
+
"metadata": {},
|
113 |
+
"outputs": [],
|
114 |
+
"source": [
|
115 |
+
"suite = 'cc'\n",
|
116 |
+
"test_datasets = get_datasets('test',task_type, suite=suite)"
|
117 |
+
]
|
118 |
+
},
|
119 |
+
{
|
120 |
+
"cell_type": "code",
|
121 |
+
"execution_count": null,
|
122 |
+
"metadata": {},
|
123 |
+
"outputs": [],
|
124 |
+
"source": [
|
125 |
+
"clf_dict= {'gp': gp_metric\n",
|
126 |
+
" , 'knn': knn_metric\n",
|
127 |
+
" , 'catboost': catboost_metric\n",
|
128 |
+
" , 'xgb': xgb_metric\n",
|
129 |
+
" , 'transformer': transformer_metric\n",
|
130 |
+
" , 'logistic': logistic_metric\n",
|
131 |
+
" , 'autosklearn': autosklearn_metric\n",
|
132 |
+
" , 'autosklearn2': autosklearn2_metric\n",
|
133 |
+
" , 'autogluon': autogluon_metric}"
|
134 |
+
]
|
135 |
+
},
|
136 |
+
{
|
137 |
+
"cell_type": "code",
|
138 |
+
"execution_count": null,
|
139 |
+
"metadata": {},
|
140 |
+
"outputs": [],
|
141 |
+
"source": [
|
142 |
+
"device = 'cpu'\n",
|
143 |
+
"\n",
|
144 |
+
"def eval_method(task_type, method, dids, selector, eval_positions, max_time, metric_used, split_number, append_metric=True, fetch_only=False, verbose=False):\n",
|
145 |
+
" \n",
|
146 |
+
" dids = dids if type(dids) is list else [dids]\n",
|
147 |
+
" \n",
|
148 |
+
" for did in dids:\n",
|
149 |
+
"\n",
|
150 |
+
" ds = get_datasets(selector, task_type, suite=suite)\n",
|
151 |
+
"\n",
|
152 |
+
" ds = ds if did is None else ds[did:did+1]\n",
|
153 |
+
"\n",
|
154 |
+
" clf = clf_dict[method]\n",
|
155 |
+
"\n",
|
156 |
+
" time_string = '_time_'+str(max_time) if max_time else ''\n",
|
157 |
+
" metric_used_string = '_'+tabular_baselines.get_scoring_string(metric_used, usage='') if append_metric else ''\n",
|
158 |
+
"\n",
|
159 |
+
" result = evaluate(datasets=ds\n",
|
160 |
+
" , model=clf\n",
|
161 |
+
" , method=method+time_string+metric_used_string\n",
|
162 |
+
" , bptt=bptt, base_path=base_path\n",
|
163 |
+
" , eval_positions=eval_positions\n",
|
164 |
+
" , device=device, max_splits=1\n",
|
165 |
+
" , overwrite=overwrite\n",
|
166 |
+
" , save=True\n",
|
167 |
+
" , metric_used=metric_used\n",
|
168 |
+
" , path_interfix=task_type\n",
|
169 |
+
" , fetch_only=fetch_only\n",
|
170 |
+
" , split_number=split_number\n",
|
171 |
+
" , verbose=verbose\n",
|
172 |
+
" , max_time=max_time)\n",
|
173 |
+
" \n",
|
174 |
+
" return result"
|
175 |
+
]
|
176 |
+
},
|
177 |
+
{
|
178 |
+
"cell_type": "markdown",
|
179 |
+
"metadata": {
|
180 |
+
"tags": []
|
181 |
+
},
|
182 |
+
"source": [
|
183 |
+
"# Baseline Evaluation\n",
|
184 |
+
"This section runs baselines and saves results locally."
|
185 |
+
]
|
186 |
+
},
|
187 |
+
{
|
188 |
+
"cell_type": "code",
|
189 |
+
"execution_count": null,
|
190 |
+
"metadata": {},
|
191 |
+
"outputs": [],
|
192 |
+
"source": [
|
193 |
+
"!mkdir {base_path}/results\n",
|
194 |
+
"!mkdir {base_path}/results/tabular/\n",
|
195 |
+
"!mkdir {base_path}/results/tabular/multiclass/"
|
196 |
+
]
|
197 |
+
},
|
198 |
+
{
|
199 |
+
"cell_type": "code",
|
200 |
+
"execution_count": null,
|
201 |
+
"metadata": {
|
202 |
+
"tags": []
|
203 |
+
},
|
204 |
+
"outputs": [],
|
205 |
+
"source": [
|
206 |
+
"# RUN ONE METHOD ON ONE DATASET AND SPLIT\n",
|
207 |
+
"overwrite=True\n",
|
208 |
+
"dataset_id = 0\n",
|
209 |
+
"split_number = 1\n",
|
210 |
+
"maximum_runtime = 30\n",
|
211 |
+
"r = eval_method(task_type, 'transformer', dataset_id, 'test', eval_positions, maximum_runtime, metric_used, split_number)"
|
212 |
+
]
|
213 |
+
},
|
214 |
+
{
|
215 |
+
"cell_type": "code",
|
216 |
+
"execution_count": null,
|
217 |
+
"metadata": {
|
218 |
+
"tags": []
|
219 |
+
},
|
220 |
+
"outputs": [],
|
221 |
+
"source": [
|
222 |
+
"# RUN ALL METHODS, SPLITS AND DATASETS\n",
|
223 |
+
"test_datasets = get_datasets('test',task_type, suite=suite)\n",
|
224 |
+
"\n",
|
225 |
+
"overwrite=True\n",
|
226 |
+
"jobs = [\n",
|
227 |
+
" eval_method(task_type, m, did, selector, eval_positions, max_time, metric_used, split_number)\n",
|
228 |
+
" for did in range(0, len(test_datasets))\n",
|
229 |
+
" for selector in ['test']\n",
|
230 |
+
" for m in methods\n",
|
231 |
+
" for max_time in max_times\n",
|
232 |
+
" for split_number in [1, 2, 3, 4, 5]\n",
|
233 |
+
"]"
|
234 |
+
]
|
235 |
+
},
|
236 |
+
{
|
237 |
+
"cell_type": "markdown",
|
238 |
+
"metadata": {
|
239 |
+
"tags": []
|
240 |
+
},
|
241 |
+
"source": [
|
242 |
+
"# Comparison"
|
243 |
+
]
|
244 |
+
},
|
245 |
+
{
|
246 |
+
"cell_type": "code",
|
247 |
+
"execution_count": null,
|
248 |
+
"metadata": {
|
249 |
+
"tags": []
|
250 |
+
},
|
251 |
+
"outputs": [],
|
252 |
+
"source": [
|
253 |
+
"pos = str(eval_positions[0])\n",
|
254 |
+
"\n",
|
255 |
+
"global_results = {}\n",
|
256 |
+
"overwrite=False\n",
|
257 |
+
"\n",
|
258 |
+
"for method in baseline_methods:\n",
|
259 |
+
" for max_time in max_times:\n",
|
260 |
+
" for split_number in range(1,5+1):\n",
|
261 |
+
" global_results[method+'_time_'+str(max_time)+tabular_baselines.get_scoring_string(metric_used, usage='')+'_split_'+str(split_number)] = eval_method(task_type, method, None, selector, \n",
|
262 |
+
" eval_positions, fetch_only=True, \n",
|
263 |
+
" verbose=False, max_time=max_time,\n",
|
264 |
+
" metric_used=metric_used, split_number=split_number)"
|
265 |
+
]
|
266 |
+
},
|
267 |
+
{
|
268 |
+
"cell_type": "code",
|
269 |
+
"execution_count": null,
|
270 |
+
"metadata": {},
|
271 |
+
"outputs": [],
|
272 |
+
"source": [
|
273 |
+
"path_ = 'prior_tuning_result.pkl'\n",
|
274 |
+
"\n",
|
275 |
+
"try:\n",
|
276 |
+
" output = open(path_, 'rb')\n",
|
277 |
+
" _, metrics, _, _, _, _ = CustomUnpickler(output).load()\n",
|
278 |
+
"except:\n",
|
279 |
+
" output = open(path_, 'rb')\n",
|
280 |
+
" _, metrics, _, _, _ = CustomUnpickler(output).load()\n",
|
281 |
+
"if isinstance(metrics, list):\n",
|
282 |
+
" for i in range(1, len(metrics[1])+1):\n",
|
283 |
+
" global_results['transformer_split_'+str(i)] = metrics[2][i-1]"
|
284 |
+
]
|
285 |
+
},
|
286 |
+
{
|
287 |
+
"cell_type": "code",
|
288 |
+
"execution_count": null,
|
289 |
+
"metadata": {},
|
290 |
+
"outputs": [],
|
291 |
+
"source": [
|
292 |
+
"# Verify integrity of results\n",
|
293 |
+
"for bl in set(global_results.keys()):\n",
|
294 |
+
" if 'split_1' in bl:\n",
|
295 |
+
" for ds in test_datasets:\n",
|
296 |
+
" if f'{ds[0]}_ys_at_1000' not in global_results[bl]:\n",
|
297 |
+
" continue\n",
|
298 |
+
" match = (global_results[bl][f'{ds[0]}_ys_at_1000'] == global_results['transformer_split_1'][f'{ds[0]}_ys_at_1000']).float().mean()\n",
|
299 |
+
" if not match:\n",
|
300 |
+
" raise Exception(\"Not the same labels used\")\n",
|
301 |
+
" "
|
302 |
+
]
|
303 |
+
},
|
304 |
+
{
|
305 |
+
"cell_type": "code",
|
306 |
+
"execution_count": null,
|
307 |
+
"metadata": {
|
308 |
+
"tags": []
|
309 |
+
},
|
310 |
+
"outputs": [],
|
311 |
+
"source": [
|
312 |
+
"limit_to = ''\n",
|
313 |
+
"calculate_score(tabular_metrics.auc_metric, 'roc', global_results, test_datasets, eval_positions + [-1], limit_to=limit_to)\n",
|
314 |
+
"calculate_score(tabular_metrics.cross_entropy, 'cross_entropy', global_results, test_datasets, eval_positions + [-1], limit_to=limit_to)\n",
|
315 |
+
"calculate_score(tabular_metrics.accuracy_metric, 'acc', global_results, test_datasets, eval_positions + [-1])\n",
|
316 |
+
"calculate_score(tabular_metrics.time_metric, 'time', global_results, test_datasets, eval_positions + [-1], aggregator='sum', limit_to=limit_to)\n",
|
317 |
+
"calculate_score(tabular_metrics.time_metric, 'time', global_results, test_datasets, eval_positions + [-1], aggregator='mean', limit_to=limit_to)\n",
|
318 |
+
"calculate_score(tabular_metrics.count_metric, 'count', global_results, test_datasets, eval_positions + [-1], aggregator='sum', limit_to=limit_to)"
|
319 |
+
]
|
320 |
+
},
|
321 |
+
{
|
322 |
+
"cell_type": "markdown",
|
323 |
+
"metadata": {
|
324 |
+
"tags": []
|
325 |
+
},
|
326 |
+
"source": [
|
327 |
+
"#### ROC and AUC plots from TabPFN Paper"
|
328 |
+
]
|
329 |
+
},
|
330 |
+
{
|
331 |
+
"cell_type": "code",
|
332 |
+
"execution_count": null,
|
333 |
+
"metadata": {},
|
334 |
+
"outputs": [],
|
335 |
+
"source": [
|
336 |
+
"def generate_ranks_and_wins_table(global_results_filtered, metric_key, max_time, split_number, time_matrix):\n",
|
337 |
+
" global_results_filtered_split = {**global_results_filtered}\n",
|
338 |
+
" global_results_filtered_split = {k: global_results_filtered_split[k] for k in global_results_filtered_split.keys() if '_time_'+str(max_time)+tabular_baselines.get_scoring_string(metric_used, usage='')+'_split_'+str(split_number) in k or 'transformer_split_'+str(split_number) in k}\n",
|
339 |
+
"\n",
|
340 |
+
" matrix, matrix_stds = make_metric_matrix(global_results_filtered_split, methods, pos, metric_key, test_datasets)\n",
|
341 |
+
" for method in methods:\n",
|
342 |
+
" if time_matrix[method] > max_time * 2:\n",
|
343 |
+
" matrix[method] = np.nan\n",
|
344 |
+
" # = np.nan\n",
|
345 |
+
"\n",
|
346 |
+
" if metric_key == 'cross_entropy':\n",
|
347 |
+
" matrix = -(matrix.fillna(-100))\n",
|
348 |
+
" else:\n",
|
349 |
+
" matrix = matrix.fillna(-1)\n",
|
350 |
+
" return make_ranks_and_wins_table(matrix.copy())"
|
351 |
+
]
|
352 |
+
},
|
353 |
+
{
|
354 |
+
"cell_type": "code",
|
355 |
+
"execution_count": null,
|
356 |
+
"metadata": {
|
357 |
+
"tags": []
|
358 |
+
},
|
359 |
+
"outputs": [],
|
360 |
+
"source": [
|
361 |
+
"%matplotlib inline\n",
|
362 |
+
"\n",
|
363 |
+
"df_ = []\n",
|
364 |
+
"metric_keys = ['roc', 'cross_entropy', 'time']\n",
|
365 |
+
"\n",
|
366 |
+
"for max_time in max_times:\n",
|
367 |
+
" global_results_filtered = {**global_results}\n",
|
368 |
+
" global_results_filtered = {k: global_results_filtered[k] for k in global_results_filtered.keys() if '_time_'+str(max_time)+tabular_baselines.get_scoring_string(metric_used, usage='')+'_' in k or 'transformer' in k}\n",
|
369 |
+
" \n",
|
370 |
+
" time_matrix, _ = make_metric_matrix(global_results_filtered, methods, pos, 'time', test_datasets)\n",
|
371 |
+
" time_matrix = time_matrix.mean()\n",
|
372 |
+
" \n",
|
373 |
+
" if len(global_results_filtered) == 0:\n",
|
374 |
+
" continue\n",
|
375 |
+
" \n",
|
376 |
+
" # Calculate ranks and wins per split\n",
|
377 |
+
" for metric_key in metric_keys:\n",
|
378 |
+
" for split_number in range(1,6):\n",
|
379 |
+
" ranks, wins = generate_ranks_and_wins_table(global_results_filtered, metric_key, max_time, split_number, time_matrix)\n",
|
380 |
+
"\n",
|
381 |
+
" for method in methods:\n",
|
382 |
+
" method_ = method+'_time_'+str(max_time)+tabular_baselines.get_scoring_string(metric_used, usage='') if method != 'transformer' else method\n",
|
383 |
+
" global_results[method_+'_split_'+str(split_number)]['mean_rank_'+metric_key+f'_at_{pos}'] = ranks[method]\n",
|
384 |
+
" global_results[method_+'_split_'+str(split_number)]['mean_wins_'+metric_key+f'_at_{pos}'] = wins[method]\n",
|
385 |
+
" \n",
|
386 |
+
" #for method in global_results.keys():\n",
|
387 |
+
" # global_results[method]['mean_rank_'+metric_key+f'_at_{pos}'] = ranks[]\n",
|
388 |
+
" \n",
|
389 |
+
" avg_times = {}\n",
|
390 |
+
" for method_ in methods:\n",
|
391 |
+
" avg_times[method_] = []\n",
|
392 |
+
" for split_number in range(1,6):\n",
|
393 |
+
" if method_ != 'transformer':\n",
|
394 |
+
" method = method_+'_time_'+str(max_time)+tabular_baselines.get_scoring_string(metric_used, usage='')+'_split_'+str(split_number)\n",
|
395 |
+
" else:\n",
|
396 |
+
" method = method_+'_split_'+str(split_number)\n",
|
397 |
+
" avg_times[method_] += [global_results[method][f'mean_time_at_{pos}']]\n",
|
398 |
+
" avg_times = pd.DataFrame(avg_times).mean()\n",
|
399 |
+
" \n",
|
400 |
+
" for metric_key in metric_keys:\n",
|
401 |
+
" for ranking in ['', 'rank_', 'wins_']:\n",
|
402 |
+
" for method_ in methods:\n",
|
403 |
+
" for split_number in range(1,6):\n",
|
404 |
+
" method = method_\n",
|
405 |
+
" if method_ != 'transformer':\n",
|
406 |
+
" method = method_+'_time_'+str(max_time)+tabular_baselines.get_scoring_string(metric_used, usage='')+'_split_'+str(split_number)\n",
|
407 |
+
" else:\n",
|
408 |
+
" method = method_+'_split_'+str(split_number)\n",
|
409 |
+
"\n",
|
410 |
+
" if global_results[method][f'sum_count_at_{pos}'] <= 29:\n",
|
411 |
+
" print('Warning not all datasets generated for '+method+' '+ str(global_results[method][f'sum_count_at_{pos}']))\n",
|
412 |
+
" \n",
|
413 |
+
" time = global_results[method]['mean_time'] if ranking == '' else max_time\n",
|
414 |
+
" time = max_time # Todo: This is not the real time\n",
|
415 |
+
" df_ += [{'metric'+ranking+metric_key: global_results[method]['mean_'+ranking+metric_key+f'_at_{pos}'], 'real_time': avg_times[method_], 'time': time, 'method': method_, 'split_number': split_number}]\n",
|
416 |
+
" # For Roc AUC Plots\n",
|
417 |
+
" #if 'transformer' in method:\n",
|
418 |
+
" # df_ += [{'metric'+ranking+metric_key: global_results[method]['mean_'+ranking+metric_key+f'_at_{pos}'], 'real_time': avg_times[method_], 'time': time, 'method': method_, 'split_number': split_number}]\n",
|
419 |
+
" # df_ += [{'metric'+ranking+metric_key: global_results[method]['mean_'+ranking+metric_key+f'_at_{pos}'], 'real_time': max(avg_times), 'time': max(max_times), 'method': method_, 'split_number': split_number}]\n",
|
420 |
+
" \n",
|
421 |
+
" \n",
|
422 |
+
"df_ = pd.DataFrame(df_)"
|
423 |
+
]
|
424 |
+
},
|
425 |
+
{
|
426 |
+
"cell_type": "code",
|
427 |
+
"execution_count": null,
|
428 |
+
"metadata": {},
|
429 |
+
"outputs": [],
|
430 |
+
"source": [
|
431 |
+
"metric_renamer = {'roc': 'ROC AUC', 'cross_entropy': 'Cross entropy'\n",
|
432 |
+
" , 'rank_roc': 'Mean ROC AUC Rank', 'rank_cross_entropy': 'Mean Cross entropy Rank'\n",
|
433 |
+
" , 'wins_roc': 'Mean ROC AUC Wins', 'wins_cross_entropy': 'Mean Cross entropy Wins'\n",
|
434 |
+
" , 'time': 'actual time taken'}\n",
|
435 |
+
"max_times_renamer = {0.5: \"0.5s\", 1: \"1s\", 5: \"5s\", 15: \"15s\", 30: \"30s\", 60: \"1min\", 300: \"5min\", 900: \"15min\", 3600: \"1h\", 14400: \"4h\"}\n",
|
436 |
+
"\n",
|
437 |
+
"def make_tabular_results_plot(metric_key, exclude, max_times, df_, grouping=True):\n",
|
438 |
+
" f, ax = plt.subplots(figsize=(7, 7))\n",
|
439 |
+
" #ax.set(xscale=\"log\")\n",
|
440 |
+
" \n",
|
441 |
+
" df_.loc[:, 'time_log'] = np.log10(df_.time)\n",
|
442 |
+
" df_.loc[:, 'real_time_log'] = np.log10(df_.real_time)\n",
|
443 |
+
" time_column = 'time_log' if grouping else 'real_time_log'\n",
|
444 |
+
"\n",
|
445 |
+
" sns.set_palette(\"tab10\")\n",
|
446 |
+
" for method in methods:\n",
|
447 |
+
" if method in exclude or method=='transformer':\n",
|
448 |
+
" continue\n",
|
449 |
+
" df_method = df_[df_.method==method].copy()\n",
|
450 |
+
" ax = sns.lineplot(time_column, 'metric'+metric_key, data=df_method, marker='o', label=method, ax=ax)\n",
|
451 |
+
" #sns.scatterplot(data=df_, x='time', y='metric', hue='method', ax=ax, style='method') #\n",
|
452 |
+
" df_trans = df_[df_.method=='transformer']\n",
|
453 |
+
" if time_column == 'real_time_log':\n",
|
454 |
+
" # Removing dots for line for transformers\n",
|
455 |
+
" df_trans = df_trans[np.logical_or(df_trans.real_time == df_trans.real_time.min(), df_trans.real_time == df_trans.real_time.max())]\n",
|
456 |
+
" df_trans.loc[:, 'metric'+metric_key] = df_trans['metric'+metric_key].mean()\n",
|
457 |
+
" df_trans.loc[:, time_column] = np.log(1) # Hacky code to get the right time from our measurements\n",
|
458 |
+
" ax = sns.lineplot(time_column, 'metric'+metric_key, data=df_trans, linestyle='--', marker='o', ci=\"sd\", ax=ax)\n",
|
459 |
+
" \n",
|
460 |
+
" #ax = sns.scatterplot(data = df_trans, x=time_column, y='metric'+metric_key, s=800, marker='*', color='grey') #\n",
|
461 |
+
" #ax = plt.scatter(df_trans[time_column], df_trans['metric'+metric_key], s=600, marker=['*']) #\n",
|
462 |
+
" \n",
|
463 |
+
" if grouping:\n",
|
464 |
+
" ax.set_xlabel(\"Time (s, requested, not actual)\")\n",
|
465 |
+
" else:\n",
|
466 |
+
" ax.set_xlabel(\"Time taken\")\n",
|
467 |
+
" ax.set_ylabel(metric_renamer[metric_key])\n",
|
468 |
+
"\n",
|
469 |
+
" #ax.legend()\n",
|
470 |
+
" \n",
|
471 |
+
" times = np.log10(max_times)\n",
|
472 |
+
" ax.set_xticks(times)\n",
|
473 |
+
" ax.set_xticklabels([max_times_renamer[t] for t in max_times])\n",
|
474 |
+
" \n",
|
475 |
+
" #ax.legend([],[], frameon=False)\n",
|
476 |
+
" \n",
|
477 |
+
" return ax"
|
478 |
+
]
|
479 |
+
},
|
480 |
+
{
|
481 |
+
"cell_type": "code",
|
482 |
+
"execution_count": null,
|
483 |
+
"metadata": {},
|
484 |
+
"outputs": [],
|
485 |
+
"source": [
|
486 |
+
"df_absolute = df_.copy()"
|
487 |
+
]
|
488 |
+
},
|
489 |
+
{
|
490 |
+
"cell_type": "code",
|
491 |
+
"execution_count": null,
|
492 |
+
"metadata": {},
|
493 |
+
"outputs": [],
|
494 |
+
"source": [
|
495 |
+
"df_absolute = df_.copy()\n",
|
496 |
+
"df_absolute = df_absolute[np.logical_or(df_.method != 'autogluon', df_.time >= 30)] # Autogluon did not yield any useful results before 30s\n",
|
497 |
+
"\n",
|
498 |
+
"knn_extend = df_absolute[np.logical_and(df_absolute.method=='knn', df_absolute.time == 3600)].copy()\n",
|
499 |
+
"knn_extend['real_time'] = 14400\n",
|
500 |
+
"knn_extend['time'] = 14400\n",
|
501 |
+
"df_absolute = df_absolute.append(knn_extend, ignore_index=True).reindex()\n",
|
502 |
+
"\n",
|
503 |
+
"knn_extend = df_absolute[np.logical_and(df_absolute.method=='logistic', df_absolute.time == 3600)].copy()\n",
|
504 |
+
"knn_extend['real_time'] = 14400\n",
|
505 |
+
"knn_extend['time'] = 14400\n",
|
506 |
+
"\n",
|
507 |
+
"df_absolute = df_absolute.append(knn_extend, ignore_index=True).reindex()"
|
508 |
+
]
|
509 |
+
},
|
510 |
+
{
|
511 |
+
"cell_type": "code",
|
512 |
+
"execution_count": null,
|
513 |
+
"metadata": {
|
514 |
+
"tags": []
|
515 |
+
},
|
516 |
+
"outputs": [],
|
517 |
+
"source": [
|
518 |
+
"exclude=['']\n",
|
519 |
+
"#ax = make_tabular_results_plot('time', exclude=exclude)\n",
|
520 |
+
"ax = make_tabular_results_plot('roc', df_=df_absolute, exclude=exclude, grouping=False, max_times=[1, 5, 30, 60*5, 60*60])\n",
|
521 |
+
"ax.set_ylim([0.84, 0.9])\n",
|
522 |
+
"ax.set_xlim([np.log10(0.7), np.log10(3600)])\n",
|
523 |
+
"ax.legend([],[], frameon=False)\n",
|
524 |
+
"\n",
|
525 |
+
"#tikzplotlib.save(f'roc_over_time.tex', axis_height='5cm', axis_width='6cm', strict=True)"
|
526 |
+
]
|
527 |
+
},
|
528 |
+
{
|
529 |
+
"cell_type": "code",
|
530 |
+
"execution_count": null,
|
531 |
+
"metadata": {},
|
532 |
+
"outputs": [],
|
533 |
+
"source": [
|
534 |
+
"ax = make_tabular_results_plot('rank_roc', df_=df_[df_.time >= 1].copy(), exclude=['tabnet'], max_times=[1, 5, 30, 60*5, 60*60])\n",
|
535 |
+
"ax.invert_yaxis()\n",
|
536 |
+
"ax.set_xlim([np.log10(1.0), np.log10(3600)])\n",
|
537 |
+
"ax.legend([],[], frameon=False)\n",
|
538 |
+
"tikzplotlib.save(f'roc_raks_tabular.tex', axis_height='5cm', axis_width='6cm', strict=True)"
|
539 |
+
]
|
540 |
+
},
|
541 |
+
{
|
542 |
+
"cell_type": "code",
|
543 |
+
"execution_count": null,
|
544 |
+
"metadata": {},
|
545 |
+
"outputs": [],
|
546 |
+
"source": [
|
547 |
+
"ax = make_tabular_results_plot('wins_roc', df_=df_[df_.time >= 1].copy(), exclude=exclude, max_times=[1, 5, 30, 60*5, 60*60])\n",
|
548 |
+
"ax.set_xlim([np.log10(1.0), np.log10(3600)])\n",
|
549 |
+
"ax.legend([],[], frameon=False)\n",
|
550 |
+
"tikzplotlib.save(f'roc_wins_tabular.tex', axis_height='5cm', axis_width='6cm', strict=True)"
|
551 |
+
]
|
552 |
+
},
|
553 |
+
{
|
554 |
+
"cell_type": "markdown",
|
555 |
+
"metadata": {
|
556 |
+
"tags": []
|
557 |
+
},
|
558 |
+
"source": [
|
559 |
+
"#### Big Table metrics"
|
560 |
+
]
|
561 |
+
},
|
562 |
+
{
|
563 |
+
"cell_type": "code",
|
564 |
+
"execution_count": null,
|
565 |
+
"metadata": {},
|
566 |
+
"outputs": [],
|
567 |
+
"source": [
|
568 |
+
"max_time = '3600'"
|
569 |
+
]
|
570 |
+
},
|
571 |
+
{
|
572 |
+
"cell_type": "code",
|
573 |
+
"execution_count": null,
|
574 |
+
"metadata": {},
|
575 |
+
"outputs": [],
|
576 |
+
"source": [
|
577 |
+
"global_results_filtered = {**global_results}\n",
|
578 |
+
"global_results_filtered = {k: global_results_filtered[k] for k in global_results_filtered.keys() if '_time_'+str(max_time)+tabular_baselines.get_scoring_string(metric_used, usage='')+'_' in k or 'transformer' in k}\n"
|
579 |
+
]
|
580 |
+
},
|
581 |
+
{
|
582 |
+
"cell_type": "code",
|
583 |
+
"execution_count": null,
|
584 |
+
"metadata": {},
|
585 |
+
"outputs": [],
|
586 |
+
"source": [
|
587 |
+
"roc_matrix, roc_matrix_stds = make_metric_matrix(global_results_filtered, methods, pos, 'roc', test_datasets_multiclass_filtered)\n",
|
588 |
+
"acc_matrix, acc_matrix_stds = make_metric_matrix(global_results_filtered, methods, pos, 'acc', test_datasets_multiclass_filtered)\n",
|
589 |
+
"cross_entropy_matrix, cross_entropy_matrix_stds = make_metric_matrix(global_results_filtered, methods, pos, 'cross_entropy', test_datasets_multiclass_filtered)\n",
|
590 |
+
"time_matrix, time_matrix_stds = make_metric_matrix(global_results_filtered, methods, pos, 'time', test_datasets_multiclass_filtered)\n",
|
591 |
+
"\n",
|
592 |
+
"roc_rank, rocs_wins = make_ranks_and_wins_table(roc_matrix.copy())\n",
|
593 |
+
"acc_rank, acc_wins = make_ranks_and_wins_table(acc_matrix.copy())\n",
|
594 |
+
"cross_entropy_rank, cross_entropy_wins = make_ranks_and_wins_table(-cross_entropy_matrix.copy())"
|
595 |
+
]
|
596 |
+
},
|
597 |
+
{
|
598 |
+
"cell_type": "code",
|
599 |
+
"execution_count": null,
|
600 |
+
"metadata": {},
|
601 |
+
"outputs": [],
|
602 |
+
"source": [
|
603 |
+
"def wins_vs_idx(matrix, idx):\n",
|
604 |
+
" wins_auc = np.array([[(matrix.values[:, j] < matrix.values[:, i]).sum() if i != j else 0 for i,method in enumerate(methods)] for j in [idx]])\n",
|
605 |
+
" ties_auc = np.array([[(matrix.values[:, j] == matrix.values[:, i]).sum() if i != j else 0 for i,method in enumerate(methods)] for j in [idx]])\n",
|
606 |
+
" losses_auc = np.array([[(matrix.values[:, j] > matrix.values[:, i]).sum() if i != j else 0 for i,method in enumerate(methods)] for j in [idx]])\n",
|
607 |
+
" \n",
|
608 |
+
" return wins_auc, ties_auc, losses_auc\n",
|
609 |
+
"\n",
|
610 |
+
"transformer_idx = np.where(roc_matrix.columns == 'transformer')[0][0]\n",
|
611 |
+
"\n",
|
612 |
+
"wins_roc_vs_us, ties_roc_vs_us, losses_roc_vs_us = wins_vs_idx(roc_matrix, transformer_idx)\n",
|
613 |
+
"wins_acc_vs_us, ties_acc_vs_us, losses_acc_vs_us = wins_vs_idx(acc_matrix, transformer_idx)\n",
|
614 |
+
"wins_ce_vs_us, ties_ce_vs_us, losses_ce_vs_us = wins_vs_idx(-cross_entropy_matrix, transformer_idx)"
|
615 |
+
]
|
616 |
+
},
|
617 |
+
{
|
618 |
+
"cell_type": "code",
|
619 |
+
"execution_count": null,
|
620 |
+
"metadata": {},
|
621 |
+
"outputs": [],
|
622 |
+
"source": [
|
623 |
+
"def rename(table):\n",
|
624 |
+
" return table.rename(columns=relabeler).T.rename(columns={'blood-transfusion-service-center': 'blood-transfus..'\n",
|
625 |
+
" , 'jungle_chess_2pcs_raw_endgame_complete': 'jungle\\_chess..', 'bank-marketing': 'bank-market..'}).T\n",
|
626 |
+
"\n",
|
627 |
+
"def get_suffix(i, k):\n",
|
628 |
+
" suffix = ''\n",
|
629 |
+
" suffix = suffix+'s' if test_datasets[i][5]['samples_capped'] == True else suffix\n",
|
630 |
+
" suffix = suffix+'f' if test_datasets[i][5]['feats_capped'] == True else suffix\n",
|
631 |
+
" suffix = suffix+'c' if test_datasets[i][5]['classes_capped'] == True else suffix\n",
|
632 |
+
" suffix = '' if len(suffix) == 0 else f' [{suffix}]'\n",
|
633 |
+
" \n",
|
634 |
+
" return k + suffix"
|
635 |
+
]
|
636 |
+
},
|
637 |
+
{
|
638 |
+
"cell_type": "code",
|
639 |
+
"execution_count": null,
|
640 |
+
"metadata": {},
|
641 |
+
"outputs": [],
|
642 |
+
"source": [
|
643 |
+
"relabeler = {'transformer': 'Tabular PFN'\n",
|
644 |
+
" , 'autogluon': 'Autogluon'\n",
|
645 |
+
" , 'autosklearn2': 'Autosklearn2'\n",
|
646 |
+
" , 'gp': 'GP (RBF)'\n",
|
647 |
+
" , 'logistic': 'Log. Regr.'\n",
|
648 |
+
" , 'knn': 'KNN'\n",
|
649 |
+
" , 'catboost': 'Catboost'\n",
|
650 |
+
" , 'xgb': 'XGB'}"
|
651 |
+
]
|
652 |
+
},
|
653 |
+
{
|
654 |
+
"cell_type": "code",
|
655 |
+
"execution_count": null,
|
656 |
+
"metadata": {},
|
657 |
+
"outputs": [],
|
658 |
+
"source": [
|
659 |
+
"table = roc_matrix.copy()\n",
|
660 |
+
"#table = roc_ovr_matrix.copy()\n",
|
661 |
+
"#table = acc_matrix.copy()\n",
|
662 |
+
"#table = cross_entropy_matrix.copy()\n",
|
663 |
+
"\n",
|
664 |
+
"#table = table_acc\n",
|
665 |
+
"table.index = [get_suffix(i, k) for i, k in enumerate(table.index[0:table.shape[0]])]\n",
|
666 |
+
"\n",
|
667 |
+
"table.loc['Wins AUC OVO'] = rocs_wins.values\n",
|
668 |
+
"#table.loc['Mean AUC OVR'] = roc_ovr_matrix.mean(skipna=True)\n",
|
669 |
+
"table.loc['Wins Acc.'] = acc_wins.values\n",
|
670 |
+
"#table.loc['Mean Bal. Acc.'] = balanced_acc_matrix.mean()\n",
|
671 |
+
"table.loc['Wins CE'] = cross_entropy_wins.values\n",
|
672 |
+
"\n",
|
673 |
+
"table.loc['Win/T/L AUC vs Us'] = [\"{:d}/{:d}/{:d}\".format(w, t, l) for w,t,l in zip(wins_roc_vs_us[-1, :], ties_roc_vs_us[-1, :], losses_roc_vs_us[-1, :])]\n",
|
674 |
+
"table.loc['Win/T/L Acc vs Us'] = [\"{:d}/{:d}/{:d}\".format(w, t, l) for w,t,l in zip(wins_acc_vs_us[-1, :], ties_acc_vs_us[-1, :], losses_acc_vs_us[-1, :])]\n",
|
675 |
+
"table.loc['Win/T/L CE vs Us'] = [\"{:d}/{:d}/{:d}\".format(w, t, l) for w,t,l in zip(wins_ce_vs_us[-1, :], ties_ce_vs_us[-1, :], losses_ce_vs_us[-1, :])]\n",
|
676 |
+
"\n",
|
677 |
+
"table.loc['Mean AUC OVO'] = roc_matrix.mean(skipna=True)\n",
|
678 |
+
"table.loc['Mean AUC OVO Stds'] = roc_matrix_stds.mean(skipna=True)\n",
|
679 |
+
"\n",
|
680 |
+
"#table.loc['Mean AUC OVR'] = roc_ovr_matrix.mean(skipna=True)\n",
|
681 |
+
"table.loc['Mean Acc.'] = acc_matrix.mean()\n",
|
682 |
+
"table.loc['Mean Acc. Stds'] = acc_matrix_stds.mean(skipna=True)\n",
|
683 |
+
"\n",
|
684 |
+
"#table.loc['Mean Bal. Acc.'] = balanced_acc_matrix.mean()\n",
|
685 |
+
"table.loc['Mean CE'] = cross_entropy_matrix.mean()\n",
|
686 |
+
"table.loc['Mean CE Stds'] = cross_entropy_matrix_stds.mean()\n",
|
687 |
+
"\n",
|
688 |
+
"table.loc['M. rank AUC OVO'] = roc_rank.values\n",
|
689 |
+
"#table.loc['Mean rank AUC OVR'] = roc_ovr_rank.values\n",
|
690 |
+
"table.loc['Mean rank Acc.'] = acc_rank.values\n",
|
691 |
+
"#table.loc['Mean rank Bal. Acc.'] = balanced_acc_rank.values\n",
|
692 |
+
"table.loc['Mean rank CE'] = cross_entropy_rank.values\n",
|
693 |
+
"\n",
|
694 |
+
"table.loc['Mean time (s)'] = time_matrix.mean()\n",
|
695 |
+
"table.loc['Mean time (s)', 'knn'] = 0.5\n",
|
696 |
+
"table.loc['Mean time (s)', 'logistic'] = 60\n",
|
697 |
+
"\n",
|
698 |
+
"table = table[['knn', 'logistic', 'gp', 'catboost', 'xgb', 'autosklearn2', 'autogluon', 'transformer']]\n",
|
699 |
+
"rename(table).round(decimals=3).style.highlight_max(axis = 1, props= 'font-weight: bold;').format(precision=3)"
|
700 |
+
]
|
701 |
+
},
|
702 |
+
{
|
703 |
+
"cell_type": "code",
|
704 |
+
"execution_count": null,
|
705 |
+
"metadata": {},
|
706 |
+
"outputs": [],
|
707 |
+
"source": [
|
708 |
+
"def bold_extreme_values(data, format_string=\"%.3g\", max_=True):\n",
|
709 |
+
" data = data.astype(float).round(3)\n",
|
710 |
+
" if max_:\n",
|
711 |
+
" extrema = data != data.max()\n",
|
712 |
+
" else:\n",
|
713 |
+
" extrema = data != data.min()\n",
|
714 |
+
" bolded = data.apply(lambda x : \"\\\\textbf{%s}\" % format_string % x)\n",
|
715 |
+
" formatted = data.apply(lambda x : format_string % x)\n",
|
716 |
+
" return formatted.where(extrema, bolded) \n",
|
717 |
+
"\n",
|
718 |
+
"def to_str(data, format_string=\"%.3g\"):\n",
|
719 |
+
" formatted = data.apply(lambda x : format_string % x)\n",
|
720 |
+
" return formatted"
|
721 |
+
]
|
722 |
+
},
|
723 |
+
{
|
724 |
+
"cell_type": "code",
|
725 |
+
"execution_count": null,
|
726 |
+
"metadata": {},
|
727 |
+
"outputs": [],
|
728 |
+
"source": [
|
729 |
+
"keys_max = [\"Mean rank CE\", \"Mean rank Acc.\", \"Mean rank AUC OVO\", \"Mean rank AUC OVR\", \"Mean rank Bal. Acc.\", \"Mean AUC OVO\", \"Mean Acc.\"]\n",
|
730 |
+
"keys_max = [\"Mean AUC OVO\", \"Mean Acc.\", \"Wins AUC OVO\", \"Wins Acc.\", \"Wins CE\"]\n",
|
731 |
+
"\n",
|
732 |
+
"keys_min = [\"Mean rank CE\", \"Mean rank Acc.\", \"M. rank AUC OVO\", \"Mean CE\"]\n",
|
733 |
+
"\n",
|
734 |
+
"table_latex = rename(table).copy()\n",
|
735 |
+
"\n",
|
736 |
+
"table_latex.iloc[0:30] = table_latex.iloc[0:30].apply(lambda data : bold_extreme_values(data),axis=1)\n",
|
737 |
+
"table_latex.loc[[\"Mean time (s)\"]] = table_latex.loc[[\"Mean time (s)\"]].apply(lambda data : bold_extreme_values(data, format_string=\"%.4g\", max_=False), axis=1)\n",
|
738 |
+
"table_latex.loc[keys_max] = table_latex.loc[keys_max].apply(lambda data : bold_extreme_values(data),axis=1)\n",
|
739 |
+
"table_latex.loc[keys_min] = table_latex.loc[keys_min].apply(lambda data : bold_extreme_values(data, max_=False),axis=1)\n",
|
740 |
+
"\n",
|
741 |
+
"table_latex.loc[['Mean CE Stds']] = table_latex.loc[['Mean CE Stds']].apply(lambda data : to_str(data, format_string=\"%.2g\"),axis=1)\n",
|
742 |
+
"table_latex.loc['Mean CE'] = table_latex.loc['Mean CE'] + '$\\pm$' + table_latex.loc['Mean CE Stds']\n",
|
743 |
+
"table_latex = table_latex.drop(['Mean CE Stds'])\n",
|
744 |
+
"\n",
|
745 |
+
"table_latex.loc[['Mean Acc. Stds']] = table_latex.loc[['Mean Acc. Stds']].apply(lambda data : to_str(data, format_string=\"%.2g\"),axis=1)\n",
|
746 |
+
"table_latex.loc['Mean Acc.'] = table_latex.loc['Mean Acc.'] + '$\\pm$' + table_latex.loc['Mean Acc. Stds']\n",
|
747 |
+
"table_latex = table_latex.drop(['Mean Acc. Stds'])\n",
|
748 |
+
"\n",
|
749 |
+
"table_latex.loc[['Mean AUC OVO Stds']] = table_latex.loc[['Mean AUC OVO Stds']].apply(lambda data : to_str(data, format_string=\"%.2g\"),axis=1)\n",
|
750 |
+
"table_latex.loc['Mean AUC OVO'] = table_latex.loc['Mean AUC OVO'] + '$\\pm$' + table_latex.loc['Mean AUC OVO Stds']\n",
|
751 |
+
"table_latex = table_latex.drop(['Mean AUC OVO Stds'])\n",
|
752 |
+
"\n",
|
753 |
+
"table_latex\n",
|
754 |
+
"#print(table_latex.to_latex(escape=False))"
|
755 |
+
]
|
756 |
+
},
|
757 |
+
{
|
758 |
+
"cell_type": "code",
|
759 |
+
"execution_count": null,
|
760 |
+
"metadata": {
|
761 |
+
"tags": []
|
762 |
+
},
|
763 |
+
"outputs": [],
|
764 |
+
"source": [
|
765 |
+
"print(table_latex.to_latex(escape=False))"
|
766 |
+
]
|
767 |
+
},
|
768 |
+
{
|
769 |
+
"cell_type": "code",
|
770 |
+
"execution_count": null,
|
771 |
+
"metadata": {},
|
772 |
+
"outputs": [],
|
773 |
+
"source": [
|
774 |
+
"table_latex_small = table_latex.iloc[-len(keys_min+keys_max)-1-3:]\n",
|
775 |
+
"table_latex_small"
|
776 |
+
]
|
777 |
+
},
|
778 |
+
{
|
779 |
+
"cell_type": "code",
|
780 |
+
"execution_count": null,
|
781 |
+
"metadata": {},
|
782 |
+
"outputs": [],
|
783 |
+
"source": [
|
784 |
+
"print(table_latex_small.to_latex(escape=False))"
|
785 |
+
]
|
786 |
+
},
|
787 |
+
{
|
788 |
+
"cell_type": "code",
|
789 |
+
"execution_count": null,
|
790 |
+
"metadata": {},
|
791 |
+
"outputs": [],
|
792 |
+
"source": [
|
793 |
+
"table_latex = table.copy()\n",
|
794 |
+
"\n",
|
795 |
+
"table_latex.iloc[:-5] = table_latex.iloc[:-5].apply(lambda data : bold_extreme_values(data),axis=1)\n",
|
796 |
+
"table_latex.iloc[-5:-5] = table_latex.iloc[-5:-5].apply(lambda data : bold_extreme_values(data, max_=False),axis=1)\n",
|
797 |
+
"\n",
|
798 |
+
"table_latex\n",
|
799 |
+
"#print(table_latex.to_latex(escape=False))"
|
800 |
+
]
|
801 |
+
},
|
802 |
+
{
|
803 |
+
"cell_type": "code",
|
804 |
+
"execution_count": null,
|
805 |
+
"metadata": {},
|
806 |
+
"outputs": [],
|
807 |
+
"source": [
|
808 |
+
"rename(table[-7:]).round(decimals=3).style.highlight_min(axis = 1, props= 'font-weight: bold;').format(precision=3)"
|
809 |
+
]
|
810 |
+
}
|
811 |
+
],
|
812 |
+
"metadata": {
|
813 |
+
"kernelspec": {
|
814 |
+
"display_name": "Python 3 (ipykernel)",
|
815 |
+
"language": "python",
|
816 |
+
"name": "python3"
|
817 |
+
},
|
818 |
+
"language_info": {
|
819 |
+
"codemirror_mode": {
|
820 |
+
"name": "ipython",
|
821 |
+
"version": 3
|
822 |
+
},
|
823 |
+
"file_extension": ".py",
|
824 |
+
"mimetype": "text/x-python",
|
825 |
+
"name": "python",
|
826 |
+
"nbconvert_exporter": "python",
|
827 |
+
"pygments_lexer": "ipython3",
|
828 |
+
"version": "3.9.6"
|
829 |
+
}
|
830 |
+
},
|
831 |
+
"nbformat": 4,
|
832 |
+
"nbformat_minor": 4
|
833 |
+
}
|
TabPFN/SyntheticGPAblation.ipynb
DELETED
@@ -1,392 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"cells": [
|
3 |
-
{
|
4 |
-
"cell_type": "code",
|
5 |
-
"execution_count": 1,
|
6 |
-
"metadata": {},
|
7 |
-
"outputs": [],
|
8 |
-
"source": [
|
9 |
-
"%load_ext autoreload\n",
|
10 |
-
"\n",
|
11 |
-
"%autoreload 2"
|
12 |
-
]
|
13 |
-
},
|
14 |
-
{
|
15 |
-
"cell_type": "code",
|
16 |
-
"execution_count": 2,
|
17 |
-
"metadata": {},
|
18 |
-
"outputs": [],
|
19 |
-
"source": [
|
20 |
-
"import os\n",
|
21 |
-
"import time\n",
|
22 |
-
"\n",
|
23 |
-
"import torch\n",
|
24 |
-
"\n",
|
25 |
-
"import numpy as np\n",
|
26 |
-
"\n",
|
27 |
-
"import matplotlib.pyplot as plt\n",
|
28 |
-
"\n",
|
29 |
-
"from model_builder import get_model, get_default_spec, save_model, load_model\n",
|
30 |
-
"\n",
|
31 |
-
"from scripts.model_configs import *"
|
32 |
-
]
|
33 |
-
},
|
34 |
-
{
|
35 |
-
"cell_type": "markdown",
|
36 |
-
"metadata": {
|
37 |
-
"tags": []
|
38 |
-
},
|
39 |
-
"source": [
|
40 |
-
"# Setting params"
|
41 |
-
]
|
42 |
-
},
|
43 |
-
{
|
44 |
-
"cell_type": "code",
|
45 |
-
"execution_count": 6,
|
46 |
-
"metadata": {},
|
47 |
-
"outputs": [],
|
48 |
-
"source": [
|
49 |
-
"device = 'cuda'\n",
|
50 |
-
"base_path = os.path.join('.')"
|
51 |
-
]
|
52 |
-
},
|
53 |
-
{
|
54 |
-
"cell_type": "code",
|
55 |
-
"execution_count": 7,
|
56 |
-
"metadata": {},
|
57 |
-
"outputs": [],
|
58 |
-
"source": [
|
59 |
-
"def train_function(config_sample, i, add_name=''):\n",
|
60 |
-
" start_time = time.time()\n",
|
61 |
-
" N_epochs_to_save = 50\n",
|
62 |
-
" \n",
|
63 |
-
" def save_callback(model, epoch):\n",
|
64 |
-
" if not hasattr(model, 'last_saved_epoch'):\n",
|
65 |
-
" model.last_saved_epoch = 0\n",
|
66 |
-
" if ((time.time() - start_time) / (maximum_runtime * 60 / N_epochs_to_save)) > model.last_saved_epoch:\n",
|
67 |
-
" print('Saving model..')\n",
|
68 |
-
" config_sample['epoch_in_training'] = epoch\n",
|
69 |
-
" save_model(model, base_path, f'models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{model.last_saved_epoch}.cpkt',\n",
|
70 |
-
" config_sample)\n",
|
71 |
-
" model.last_saved_epoch = model.last_saved_epoch + 1 # TODO: Rename to checkpoint\n",
|
72 |
-
" \n",
|
73 |
-
" model = get_model(config_sample\n",
|
74 |
-
" , device\n",
|
75 |
-
" , should_train=True\n",
|
76 |
-
" , verbose=1\n",
|
77 |
-
" , epoch_callback = save_callback)\n",
|
78 |
-
" \n",
|
79 |
-
" return"
|
80 |
-
]
|
81 |
-
},
|
82 |
-
{
|
83 |
-
"cell_type": "markdown",
|
84 |
-
"metadata": {
|
85 |
-
"heading_collapsed": true,
|
86 |
-
"tags": []
|
87 |
-
},
|
88 |
-
"source": [
|
89 |
-
"# Check synthetic data fitting"
|
90 |
-
]
|
91 |
-
},
|
92 |
-
{
|
93 |
-
"cell_type": "markdown",
|
94 |
-
"metadata": {
|
95 |
-
"tags": []
|
96 |
-
},
|
97 |
-
"source": [
|
98 |
-
"#### Workflow functions"
|
99 |
-
]
|
100 |
-
},
|
101 |
-
{
|
102 |
-
"cell_type": "code",
|
103 |
-
"execution_count": 8,
|
104 |
-
"metadata": {
|
105 |
-
"hidden": true,
|
106 |
-
"tags": []
|
107 |
-
},
|
108 |
-
"outputs": [],
|
109 |
-
"source": [
|
110 |
-
"def generate_test_data(test_gp_params):\n",
|
111 |
-
" # Generate test data\n",
|
112 |
-
" config = {**test_gp_params}\n",
|
113 |
-
"\n",
|
114 |
-
" config['verbose'] = False\n",
|
115 |
-
" config['differentiable'] = False\n",
|
116 |
-
" #config['bptt'] = config['bptt_in_training']\n",
|
117 |
-
"\n",
|
118 |
-
" model_test_data = get_model(config, device, should_train=False, verbose=True)\n",
|
119 |
-
" (hp_embedding, data, targets_), targets = next(iter(model_test_data[3]))\n",
|
120 |
-
" (hp_embedding, data, targets_), targets = (hp_embedding, data.to(device), targets_.to(device)), targets.to(device)\n",
|
121 |
-
" \n",
|
122 |
-
" return (hp_embedding, data, targets_), targets\n",
|
123 |
-
"\n",
|
124 |
-
"def evaluate_hp_range(model, hparam_true, vary_hparam_ind, data, targets, eval_pos, plot_step_size):\n",
|
125 |
-
" losses, hparams = [], []\n",
|
126 |
-
" for l in np.arange(-1.74, 1.74, plot_step_size):\n",
|
127 |
-
" hparam = [*hparam_true]\n",
|
128 |
-
" hparam[vary_hparam_ind] = l\n",
|
129 |
-
" hp_embedding_used = torch.tensor(hparam).to(device).float()\n",
|
130 |
-
" with torch.inference_mode():\n",
|
131 |
-
" outputs = torch.sigmoid(model[2]((hp_embedding_used.repeat(data.shape[1], 1), data, targets.float()), single_eval_pos=eval_pos)).squeeze(-1)\n",
|
132 |
-
" \n",
|
133 |
-
" loss = torch.nn.BCELoss()(outputs.flatten(), targets[eval_pos:].flatten()).detach().cpu()\n",
|
134 |
-
" losses += [loss]\n",
|
135 |
-
" hparam_real = [diff_hparams_f[i][1](hp) for i, hp in enumerate(hparam)]\n",
|
136 |
-
" hparams += [hparam_real]\n",
|
137 |
-
" \n",
|
138 |
-
" print(loss, hparam_real, hparam, outputs.shape)\n",
|
139 |
-
" return np.array(losses), np.array(hparams)"
|
140 |
-
]
|
141 |
-
},
|
142 |
-
{
|
143 |
-
"cell_type": "code",
|
144 |
-
"execution_count": 9,
|
145 |
-
"metadata": {},
|
146 |
-
"outputs": [],
|
147 |
-
"source": [
|
148 |
-
"def differentiable_hparam_tuning_workflow(config_sample, hparam_label, batch_size=4, N_grad_steps=50, plot_step_size=0.1):\n",
|
149 |
-
" test_gp_params = {\n",
|
150 |
-
" \"lengthscale\": 1.0,\n",
|
151 |
-
" #\"lengthscale_mean\": true_lengthscale,\n",
|
152 |
-
" #\"lengthscale_std\": 0.5,\n",
|
153 |
-
" \"noise\": 0.2,\n",
|
154 |
-
" \"outputscale\": 1.0,\n",
|
155 |
-
" 'batch_size': batch_size\n",
|
156 |
-
" }\n",
|
157 |
-
" config_sample.update(test_gp_params)\n",
|
158 |
-
" (hp_embedding, data, targets_), targets = generate_test_data(config_sample)\n",
|
159 |
-
" hparam_true = [diff_hparams_f[i][0](test_gp_params[hp]) for i, hp in enumerate(diff_hparams_keys)]\n",
|
160 |
-
" #hparam_true = [test_gp_params[hp] for i, hp in enumerate(diff_hparams_keys)]\n",
|
161 |
-
"\n",
|
162 |
-
" for vary_hparam_ind, vary_hparam_name in hparam_label:\n",
|
163 |
-
" print(vary_hparam_name)\n",
|
164 |
-
"\n",
|
165 |
-
" losses, hparams = evaluate_hp_range(model, hparam_true, vary_hparam_ind, data, targets, eval_pos, plot_step_size=plot_step_size)\n",
|
166 |
-
"\n",
|
167 |
-
" # TODO: Make only one parameter diffable\n",
|
168 |
-
" hparam = torch.tensor([*hparam_true]).to(device).float()\n",
|
169 |
-
" hparam[vary_hparam_ind] = hparam[vary_hparam_ind] + 0.1 #random.random() * 2 - 1\n",
|
170 |
-
" hparam = torch.nn.Parameter(hparam, requires_grad=True)\n",
|
171 |
-
" hparam_grad_mask = torch.zeros_like(hparam)\n",
|
172 |
-
" hparam_grad_mask[vary_hparam_ind] = 1\n",
|
173 |
-
"\n",
|
174 |
-
" optimizer = torch.optim.Adam([hparam], lr=0.1)\n",
|
175 |
-
" \n",
|
176 |
-
" for t in range(N_grad_steps):\n",
|
177 |
-
" style = hparam.repeat(data.shape[1], 1)\n",
|
178 |
-
" outputs = torch.sigmoid(model[2]((style, data, targets.float()), single_eval_pos=eval_pos)).squeeze(-1)\n",
|
179 |
-
" loss = torch.nn.BCELoss()(outputs.flatten(), targets[eval_pos:].flatten())\n",
|
180 |
-
" optimizer.zero_grad()\n",
|
181 |
-
" loss.backward()\n",
|
182 |
-
" with torch.no_grad():\n",
|
183 |
-
" hparam.grad *= hparam_grad_mask\n",
|
184 |
-
" optimizer.step()\n",
|
185 |
-
" print('loss:', loss, 'hparams', diff_hparams_f[vary_hparam_ind][1](hparam[vary_hparam_ind]), 'true', diff_hparams_f[vary_hparam_ind][1](hparam_true[vary_hparam_ind]))\n",
|
186 |
-
" inferred_param = diff_hparams_f[vary_hparam_ind][1](hparam[vary_hparam_ind].cpu().detach().numpy())\n",
|
187 |
-
" return hparams, losses, inferred_param, vary_hparam_ind, hparam_true\n",
|
188 |
-
" "
|
189 |
-
]
|
190 |
-
},
|
191 |
-
{
|
192 |
-
"cell_type": "markdown",
|
193 |
-
"metadata": {
|
194 |
-
"tags": []
|
195 |
-
},
|
196 |
-
"source": [
|
197 |
-
"#### Fitting a PFN with HP-Diffable GP Prior"
|
198 |
-
]
|
199 |
-
},
|
200 |
-
{
|
201 |
-
"cell_type": "code",
|
202 |
-
"execution_count": 10,
|
203 |
-
"metadata": {
|
204 |
-
"hidden": true,
|
205 |
-
"tags": []
|
206 |
-
},
|
207 |
-
"outputs": [],
|
208 |
-
"source": [
|
209 |
-
"num_features = 5\n",
|
210 |
-
"bptt = 200\n",
|
211 |
-
"eval_positions = [100]\n",
|
212 |
-
"\n",
|
213 |
-
"config_general = get_general_config(num_features, bptt, eval_positions)\n",
|
214 |
-
"config_flexible_categorical = get_flexible_categorical_config(num_features)\n",
|
215 |
-
"\n",
|
216 |
-
"config_gp = {'noise': 0.2, \"lengthscale\": 1.0, \"outputscale\": 1.0}\n",
|
217 |
-
"config_diff_gp = {'differentiable_hyperparameters': {\n",
|
218 |
-
" 'outputscale': {'distribution': 'uniform', 'min': 0., 'max': 10.0},\n",
|
219 |
-
" 'lengthscale': {'distribution': 'uniform', 'min': 0., 'max': 10.0},\n",
|
220 |
-
" 'noise': {'distribution': 'uniform', 'min': 0.0000001, 'max': 0.5},\n",
|
221 |
-
" }\n",
|
222 |
-
"}\n",
|
223 |
-
"\n",
|
224 |
-
"config = {**config_general, **config_flexible_categorical, **config_diff_gp, **config_gp}\n",
|
225 |
-
"\n",
|
226 |
-
"config['prior_type'], config['differentiable'], config['flexible'] = 'gp', True, True\n",
|
227 |
-
"config['num_features'], config['num_features_used'] = num_features, num_features\n",
|
228 |
-
"config['epochs'], config['num_steps'], config['verbose'] = 500, 100, False\n",
|
229 |
-
"config[\"lr\"] = 0.00001\n",
|
230 |
-
"config[\"dropout\"] = 0\n",
|
231 |
-
"config[\"emsize\"] = 512\n",
|
232 |
-
"config[\"batch_size\"] = 128\n",
|
233 |
-
"config[\"aggregate_k_gradients\"] = 1\n",
|
234 |
-
"config['set_value_to_nan'] = 0.0\n",
|
235 |
-
"config['output_multiclass_ordered_p'] = 1.0\n",
|
236 |
-
"config['categorical_feature_p'] = 0.0\n",
|
237 |
-
"config['nan_prob_a_reason'] = 0.0\n",
|
238 |
-
"config['nan_prob_no_reason'] = 0.0\n",
|
239 |
-
"config['nan_prob_unknown_reason'] = 0.0\n",
|
240 |
-
"config[\"nlayers\"] = 8\n",
|
241 |
-
"\n",
|
242 |
-
"# TODO: This should not be sampled, but be one config\n",
|
243 |
-
"# TODO: This uses old hyperparam sampler throws error\n",
|
244 |
-
"config_sample = evaluate_hypers(config)"
|
245 |
-
]
|
246 |
-
},
|
247 |
-
{
|
248 |
-
"cell_type": "code",
|
249 |
-
"execution_count": 11,
|
250 |
-
"metadata": {
|
251 |
-
"hidden": true,
|
252 |
-
"tags": []
|
253 |
-
},
|
254 |
-
"outputs": [
|
255 |
-
{
|
256 |
-
"name": "stdout",
|
257 |
-
"output_type": "stream",
|
258 |
-
"text": [
|
259 |
-
"Using style prior: True\n",
|
260 |
-
"Using cpu:0 device\n",
|
261 |
-
"Not using distributed\n",
|
262 |
-
"DataLoader.__dict__ {'num_steps': 100, 'fuse_x_y': False, 'get_batch_kwargs': {'batch_size': 128, 'seq_len': 200, 'seq_len_maximum': 200, 'device': 'cpu:0', 'num_features': 5, 'hyperparameters': {'lr': 1e-05, 'dropout': 0, 'emsize': 512, 'batch_size': 128, 'nlayers': 8, 'num_features': 5, 'nhead': 4, 'nhid_factor': 2, 'bptt': 200, 'eval_positions': None, 'seq_len_used': 200, 'sampling': 'normal', 'epochs': 500, 'num_steps': 100, 'verbose': False, 'pre_sample_causes': True, 'mix_activations': False, 'nan_prob_unknown_reason_reason_prior': 1.0, 'categorical_feature_p': 0.0, 'nan_prob_no_reason': 0.0, 'nan_prob_unknown_reason': 0.0, 'nan_prob_a_reason': 0.0, 'max_num_classes': 2, 'num_classes': 2, 'noise_type': 'Gaussian', 'balanced': True, 'normalize_to_ranking': False, 'set_value_to_nan': 0.0, 'normalize_by_used_features': True, 'num_features_used': 5, 'differentiable_hyperparameters': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'noise': 0.2, 'lengthscale': 1.0, 'outputscale': 1.0, 'prior_type': 'gp', 'differentiable': True, 'flexible': True, 'aggregate_k_gradients': 1, 'output_multiclass_ordered_p': 1.0, 'recompute_attn': False}, 'num_outputs': 1, 'dynamic_batch_size': 2, 'get_batch': <function get_model.<locals>.make_get_batch.<locals>.<lambda> at 0x7f39ad8dcf80>, 'differentiable_hyperparameters': {'outputscale': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'lengthscale': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'noise': {'distribution': 'uniform', 'min': 1e-07, 'max': 0.5}}}, 'num_features': 5, 'num_outputs': 1}\n",
|
263 |
-
"Using a Transformer with 17.35 M parameters\n"
|
264 |
-
]
|
265 |
-
}
|
266 |
-
],
|
267 |
-
"source": [
|
268 |
-
"device = 'cuda'\n",
|
269 |
-
"train_function(config_sample, 0, add_name='gp_experiments_diff_with_noise_no_meta_new')"
|
270 |
-
]
|
271 |
-
},
|
272 |
-
{
|
273 |
-
"cell_type": "markdown",
|
274 |
-
"metadata": {
|
275 |
-
"tags": []
|
276 |
-
},
|
277 |
-
"source": [
|
278 |
-
"#### Evaluating a PFN (with pretrained model)"
|
279 |
-
]
|
280 |
-
},
|
281 |
-
{
|
282 |
-
"cell_type": "code",
|
283 |
-
"execution_count": 13,
|
284 |
-
"metadata": {
|
285 |
-
"hidden": true,
|
286 |
-
"tags": []
|
287 |
-
},
|
288 |
-
"outputs": [
|
289 |
-
{
|
290 |
-
"name": "stdout",
|
291 |
-
"output_type": "stream",
|
292 |
-
"text": [
|
293 |
-
"Using style prior: True\n",
|
294 |
-
"Using cpu:0 device\n",
|
295 |
-
"Not using distributed\n",
|
296 |
-
"DataLoader.__dict__ {'num_steps': 100, 'fuse_x_y': False, 'get_batch_kwargs': {'batch_size': 1, 'seq_len': 10, 'seq_len_maximum': 10, 'device': 'cpu:0', 'num_features': 5, 'hyperparameters': {'lr': 1e-05, 'dropout': 0, 'emsize': 512, 'batch_size': 1, 'nlayers': 8, 'num_features': 5, 'nhead': 4, 'nhid_factor': 2, 'bptt': 10, 'eval_positions': [190], 'seq_len_used': 200, 'sampling': 'normal', 'epochs': 500, 'num_steps': 100, 'verbose': False, 'pre_sample_causes': True, 'mix_activations': False, 'nan_prob_unknown_reason_reason_prior': 1.0, 'output_multiclass_ordered_p': 1.0, 'categorical_feature_p': 0.0, 'nan_prob_no_reason': 0.0, 'nan_prob_unknown_reason': 0.0, 'nan_prob_a_reason': 0.0, 'max_num_classes': 2, 'num_classes': 2, 'noise_type': 'Gaussian', 'balanced': True, 'multiclass_type': 'rank', 'normalize_to_ranking': False, 'set_value_to_nan': 0.0, 'normalize_by_used_features': True, 'num_features_used': <function load_model.<locals>.<lambda> at 0x7f39ad8534d0>, 'differentiable_hyperparameters': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'noise': 0.03, 'lengthscale': 1.0, 'outputscale': 1.0, 'prior_type': 'gp', 'differentiable': True, 'flexible': True, 'aggregate_k_gradients': 1, 'recompute_attn': False, 'bptt_extra_samples': None, 'epoch_in_training': 0.998, 'categorical_features_sampler': <function load_model.<locals>.<lambda> at 0x7f39ad853680>, 'num_features_used_in_training': 5, 'num_classes_in_training': 2, 'batch_size_in_training': 128, 'bptt_in_training': 200, 'bptt_extra_samples_in_training': None}, 'num_outputs': 1, 'dynamic_batch_size': 2, 'get_batch': <function get_model.<locals>.make_get_batch.<locals>.<lambda> at 0x7f39ad81ab90>, 'differentiable_hyperparameters': {'outputscale': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'lengthscale': {'distribution': 'uniform', 'min': 0.0, 'max': 10.0}, 'noise': {'distribution': 'uniform', 'min': 1e-07, 'max': 0.5}}}, 'num_features': 5, 'num_outputs': 1}\n",
|
297 |
-
"Using a Transformer with 17.35 M parameters\n"
|
298 |
-
]
|
299 |
-
}
|
300 |
-
],
|
301 |
-
"source": [
|
302 |
-
"device = 'cpu'\n",
|
303 |
-
"model, c = load_model(base_path, f'models_diff/gp_ablation_model.cpkt', device, eval_positions, verbose=False)"
|
304 |
-
]
|
305 |
-
},
|
306 |
-
{
|
307 |
-
"cell_type": "code",
|
308 |
-
"execution_count": 14,
|
309 |
-
"metadata": {},
|
310 |
-
"outputs": [],
|
311 |
-
"source": [
|
312 |
-
"from priors.differentiable_prior import DifferentiableHyperparameterList\n",
|
313 |
-
"diff_list = DifferentiableHyperparameterList(c['differentiable_hyperparameters'], 512, device)\n",
|
314 |
-
"diff_hparams_keys, diff_hparams_f = diff_list.get_hyperparameter_info()"
|
315 |
-
]
|
316 |
-
},
|
317 |
-
{
|
318 |
-
"cell_type": "code",
|
319 |
-
"execution_count": null,
|
320 |
-
"metadata": {
|
321 |
-
"tags": []
|
322 |
-
},
|
323 |
-
"outputs": [],
|
324 |
-
"source": [
|
325 |
-
"model[2].eval()\n",
|
326 |
-
"eval_pos = 100\n",
|
327 |
-
"\n",
|
328 |
-
"hparam_label = [(1, 'outputscale')]\n",
|
329 |
-
"hparam_label = [(0, 'lengthscale')]\n",
|
330 |
-
"hparam_label = [(2, 'noise')]\n",
|
331 |
-
"hparam_labels = [[(1, 'outputscale')], [(2, 'noise')], [(0, 'lengthscale')]]\n",
|
332 |
-
"#hparam_labels = [[(2, 'noise')]]\n",
|
333 |
-
"\n",
|
334 |
-
"hparams, losses, inferred_param, vary_hparam_ind, hparam_true = {}, {}, {}, {}, {}\n",
|
335 |
-
"\n",
|
336 |
-
"for hparam_label in hparam_labels:\n",
|
337 |
-
" (hparams[hparam_label[0][1]], losses[hparam_label[0][1]], inferred_param[hparam_label[0][1]], vary_hparam_ind[hparam_label[0][1]], \n",
|
338 |
-
" hparam_true[hparam_label[0][1]]) = differentiable_hparam_tuning_workflow(config_sample, \n",
|
339 |
-
" hparam_label=hparam_label, \n",
|
340 |
-
" batch_size=256, \n",
|
341 |
-
" N_grad_steps=50,\n",
|
342 |
-
" plot_step_size=0.05)\n"
|
343 |
-
]
|
344 |
-
},
|
345 |
-
{
|
346 |
-
"cell_type": "code",
|
347 |
-
"execution_count": null,
|
348 |
-
"metadata": {},
|
349 |
-
"outputs": [],
|
350 |
-
"source": [
|
351 |
-
"label = 'lengthscale'\n",
|
352 |
-
"\n",
|
353 |
-
"#import tikzplotlib\n",
|
354 |
-
"\n",
|
355 |
-
"inferred = losses[label]\n",
|
356 |
-
"\n",
|
357 |
-
"plt.plot(hparams[label][:, vary_hparam_ind[label]], losses[label])\n",
|
358 |
-
"true = diff_hparams_f[vary_hparam_ind[label]][1](hparam_true[label][vary_hparam_ind[label]])\n",
|
359 |
-
"plt.axvline(x=inferred_param[label], linestyle='solid', color='red')\n",
|
360 |
-
"plt.axvline(x=true, linestyle='dashed')\n",
|
361 |
-
"\n",
|
362 |
-
"plt.ylabel('Cross entropy Loss')\n",
|
363 |
-
"plt.xlabel(label)\n",
|
364 |
-
"\n",
|
365 |
-
"#tikzplotlib.save(f'diff_inferred_params_{label}.tex', axis_height='5.2cm', axis_width='5.2cm', strict=True)\n",
|
366 |
-
"\n",
|
367 |
-
"plt.show()"
|
368 |
-
]
|
369 |
-
}
|
370 |
-
],
|
371 |
-
"metadata": {
|
372 |
-
"kernelspec": {
|
373 |
-
"display_name": "Python 3 (ipykernel)",
|
374 |
-
"language": "python",
|
375 |
-
"name": "python3"
|
376 |
-
},
|
377 |
-
"language_info": {
|
378 |
-
"codemirror_mode": {
|
379 |
-
"name": "ipython",
|
380 |
-
"version": 3
|
381 |
-
},
|
382 |
-
"file_extension": ".py",
|
383 |
-
"mimetype": "text/x-python",
|
384 |
-
"name": "python",
|
385 |
-
"nbconvert_exporter": "python",
|
386 |
-
"pygments_lexer": "ipython3",
|
387 |
-
"version": "3.7.13"
|
388 |
-
}
|
389 |
-
},
|
390 |
-
"nbformat": 4,
|
391 |
-
"nbformat_minor": 4
|
392 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
TabPFN/TabularEvaluationVisualization.ipynb
DELETED
The diff for this file is too large to render.
See raw diff
|
|
TabPFN/TrainingTuningAndPrediction.ipynb
DELETED
The diff for this file is too large to render.
See raw diff
|
|
TabPFN/differentiable_pfn_evaluation.py
DELETED
@@ -1,345 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import torch
|
3 |
-
import numpy as np
|
4 |
-
import time
|
5 |
-
import pickle
|
6 |
-
from scripts import tabular_metrics
|
7 |
-
from scripts.tabular_metrics import calculate_score_per_method
|
8 |
-
from scripts.tabular_evaluation import evaluate
|
9 |
-
from priors.differentiable_prior import draw_random_style
|
10 |
-
from tqdm import tqdm
|
11 |
-
import random
|
12 |
-
from scripts.transformer_prediction_interface import get_params_from_config, load_model_workflow
|
13 |
-
|
14 |
-
"""
|
15 |
-
===============================
|
16 |
-
PUBLIC FUNCTIONS FOR EVALUATION
|
17 |
-
===============================
|
18 |
-
"""
|
19 |
-
|
20 |
-
|
21 |
-
def eval_model_range(i_range, *args, **kwargs):
|
22 |
-
for i in i_range:
|
23 |
-
eval_model(i, *args, **kwargs)
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
def eval_model(i, e, valid_datasets, test_datasets, train_datasets, eval_positions_valid, eval_positions_test,
|
28 |
-
bptt_valid,
|
29 |
-
bptt_test, add_name, base_path, device='cpu', eval_addition='', **extra_tuning_args):
|
30 |
-
"""
|
31 |
-
Differentiable model evaliation workflow. Evaluates and saves results to disk.
|
32 |
-
|
33 |
-
:param i:
|
34 |
-
:param e:
|
35 |
-
:param valid_datasets:
|
36 |
-
:param test_datasets:
|
37 |
-
:param train_datasets:
|
38 |
-
:param eval_positions_valid:
|
39 |
-
:param eval_positions_test:
|
40 |
-
:param bptt_valid:
|
41 |
-
:param bptt_test:
|
42 |
-
:param add_name:
|
43 |
-
:param base_path:
|
44 |
-
:param device:
|
45 |
-
:param eval_addition:
|
46 |
-
:param extra_tuning_args:
|
47 |
-
:return:
|
48 |
-
"""
|
49 |
-
model, c, results_file = load_model_workflow(i, e, add_name, base_path, device, eval_addition)
|
50 |
-
params = {'bptt': bptt_valid
|
51 |
-
, 'bptt_final': bptt_test
|
52 |
-
, 'eval_positions': eval_positions_valid
|
53 |
-
, 'eval_positions_test': eval_positions_test
|
54 |
-
, 'valid_datasets': valid_datasets
|
55 |
-
, 'test_datasets': test_datasets
|
56 |
-
, 'train_datasets': train_datasets
|
57 |
-
, 'verbose': True
|
58 |
-
, 'device': device
|
59 |
-
}
|
60 |
-
|
61 |
-
params.update(get_params_from_config(c))
|
62 |
-
|
63 |
-
start = time.time()
|
64 |
-
metrics, metrics_valid, style, temperature, optimization_route = evaluate_differentiable_model(model, **params,
|
65 |
-
**extra_tuning_args)
|
66 |
-
print('Evaluation time: ', time.time() - start)
|
67 |
-
|
68 |
-
print(results_file)
|
69 |
-
r = [c.copy(), metrics, metrics_valid, style.to('cpu'), temperature.to('cpu'), optimization_route]
|
70 |
-
with open(results_file, 'wb') as output:
|
71 |
-
del r[0]['num_features_used']
|
72 |
-
del r[0]['categorical_features_sampler']
|
73 |
-
pickle.dump(r, output)
|
74 |
-
|
75 |
-
_, _, _, style, temperature, _ = r
|
76 |
-
|
77 |
-
return r, model
|
78 |
-
|
79 |
-
"""
|
80 |
-
===============================
|
81 |
-
INTERNAL HELPER FUNCTIONS
|
82 |
-
===============================
|
83 |
-
"""
|
84 |
-
|
85 |
-
def evaluate_differentiable_model(model
|
86 |
-
, valid_datasets
|
87 |
-
, test_datasets
|
88 |
-
, train_datasets
|
89 |
-
, N_draws=100
|
90 |
-
, N_grad_steps=10
|
91 |
-
, eval_positions=None
|
92 |
-
, eval_positions_test=None
|
93 |
-
, bptt=100
|
94 |
-
, bptt_final=200
|
95 |
-
, style=None
|
96 |
-
, n_parallel_configurations=1
|
97 |
-
, device='cpu'
|
98 |
-
, selection_metric='auc'
|
99 |
-
, final_splits=[1, 2, 3, 4, 5]
|
100 |
-
, N_ensemble_configurations_list=[1, 5, 10, 20, 50, 100]
|
101 |
-
, **kwargs):
|
102 |
-
"""
|
103 |
-
Evaluation function for diffable model evaluation. Returns a list of results.
|
104 |
-
|
105 |
-
:param model:
|
106 |
-
:param valid_datasets:
|
107 |
-
:param test_datasets:
|
108 |
-
:param train_datasets:
|
109 |
-
:param N_draws:
|
110 |
-
:param N_grad_steps:
|
111 |
-
:param eval_positions:
|
112 |
-
:param eval_positions_test:
|
113 |
-
:param bptt:
|
114 |
-
:param bptt_final:
|
115 |
-
:param style:
|
116 |
-
:param n_parallel_configurations:
|
117 |
-
:param device:
|
118 |
-
:param selection_metric:
|
119 |
-
:param final_splits:
|
120 |
-
:param N_ensemble_configurations_list:
|
121 |
-
:param kwargs:
|
122 |
-
:return:
|
123 |
-
"""
|
124 |
-
torch.manual_seed(0)
|
125 |
-
np.random.seed(0)
|
126 |
-
random.seed(0)
|
127 |
-
|
128 |
-
diffable_metric = tabular_metrics.cross_entropy
|
129 |
-
evaluation_metric = tabular_metrics.auc_metric
|
130 |
-
if selection_metric in ('auc', 'roc'):
|
131 |
-
selection_metric_min_max = 'max'
|
132 |
-
selection_metric = tabular_metrics.auc_metric
|
133 |
-
evaluation_metric = selection_metric
|
134 |
-
elif selection_metric in ('ce', 'selection_metric'):
|
135 |
-
selection_metric_min_max = 'min'
|
136 |
-
selection_metric = tabular_metrics.cross_entropy
|
137 |
-
evaluation_metric = selection_metric
|
138 |
-
|
139 |
-
print('Diffable metric', diffable_metric, ' Selection metric', selection_metric, ' Evaluation metric',
|
140 |
-
evaluation_metric)
|
141 |
-
print('N PARALLEL CONFIGURATIONS', n_parallel_configurations)
|
142 |
-
print('eval_positions', eval_positions)
|
143 |
-
|
144 |
-
def evaluate_valid(style, softmax_temperature, results, results_tracked):
|
145 |
-
result_valid = eval_step(valid_datasets, style, softmax_temperature=softmax_temperature,
|
146 |
-
return_tensor=False, inference_mode=True, selection_metric=selection_metric,
|
147 |
-
evaluation_metric=evaluation_metric, eval_positions=eval_positions, bptt=bptt, model=model[2])
|
148 |
-
result_valid = [float(result_valid[f'mean_select_at_{pos}']) for pos in eval_positions]
|
149 |
-
results += [result_valid]
|
150 |
-
results_tracked += [np.nanmean(result_valid)]
|
151 |
-
|
152 |
-
model[2].to(device)
|
153 |
-
model[2].eval()
|
154 |
-
|
155 |
-
results_on_valid, results_on_valid_tracked = [], []
|
156 |
-
best_style, best_softmax_temperature = style, torch.cat(
|
157 |
-
[torch.tensor([0.0]).to(device) for n in range(0, n_parallel_configurations)], 0)
|
158 |
-
optimization_routes = []
|
159 |
-
|
160 |
-
best_style = torch.cat([draw_random_style(model[3], device).detach() for n in range(0, n_parallel_configurations)],
|
161 |
-
0)
|
162 |
-
best_softmax_temperature = torch.cat([torch.tensor([0.0]).to(device) for n in range(0, n_parallel_configurations)],
|
163 |
-
0)
|
164 |
-
|
165 |
-
|
166 |
-
for _ in tqdm(range(0, N_draws), desc='Iterate over Optimization initializations'): # Evaluates N hparam draws
|
167 |
-
style = torch.cat([draw_random_style(model[3], device).detach() for n in range(0, n_parallel_configurations)],
|
168 |
-
0)
|
169 |
-
softmax_temperature = torch.cat([torch.tensor([0.0]).to(device) for n in range(0, n_parallel_configurations)],
|
170 |
-
0)
|
171 |
-
|
172 |
-
evaluate_valid(style, softmax_temperature, results_on_valid, results_on_valid_tracked)
|
173 |
-
|
174 |
-
print(f'Draw --> Valid Selection metric: {results_on_valid[-1]}')
|
175 |
-
|
176 |
-
if N_grad_steps > 0:
|
177 |
-
gradient_optimize_result = gradient_optimize_style(model, style, N_grad_steps
|
178 |
-
, softmax_temperature=softmax_temperature
|
179 |
-
, model=model[2]
|
180 |
-
, train_datasets=train_datasets
|
181 |
-
, valid_datasets=valid_datasets
|
182 |
-
, selection_metric_min_max=selection_metric_min_max
|
183 |
-
, **kwargs)
|
184 |
-
optimization_routes += [gradient_optimize_result['optimization_route']]
|
185 |
-
|
186 |
-
evaluate_valid(gradient_optimize_result['best_style']
|
187 |
-
, gradient_optimize_result['best_temperature']
|
188 |
-
, results_on_valid, results_on_valid_tracked)
|
189 |
-
|
190 |
-
print(f'After diff --> Valid Selection metric: {results_on_valid[-1]}')
|
191 |
-
|
192 |
-
if selection_metric_min_max == 'min':
|
193 |
-
is_best = (results_on_valid_tracked[-1] <= min(results_on_valid_tracked))
|
194 |
-
else:
|
195 |
-
is_best = (results_on_valid_tracked[-1] >= max(results_on_valid_tracked))
|
196 |
-
|
197 |
-
if is_best or best_style is None:
|
198 |
-
best_style = gradient_optimize_result['best_style'].clone()
|
199 |
-
best_softmax_temperature = gradient_optimize_result['best_temperature'].clone()
|
200 |
-
torch.cuda.empty_cache()
|
201 |
-
|
202 |
-
def final_evaluation():
|
203 |
-
print('Running eval dataset with final params (no gradients)..')
|
204 |
-
print(best_style, best_softmax_temperature)
|
205 |
-
result_test = []
|
206 |
-
for N_ensemble_configurations in N_ensemble_configurations_list:
|
207 |
-
print(f'Running with {N_ensemble_configurations} ensemble_configurations')
|
208 |
-
kwargs['N_ensemble_configurations'] = N_ensemble_configurations
|
209 |
-
splits = []
|
210 |
-
for split in final_splits:
|
211 |
-
splits += [eval_step(test_datasets, best_style, softmax_temperature=best_softmax_temperature
|
212 |
-
, return_tensor=False, eval_positions=eval_positions_test,
|
213 |
-
bptt=bptt_final, inference_mode=True, split_number=split, model=model[2]
|
214 |
-
, selection_metric=selection_metric, evaluation_metric=evaluation_metric)]
|
215 |
-
result_test += [splits]
|
216 |
-
|
217 |
-
print('Running valid dataset with final params (no gradients)..')
|
218 |
-
result_valid = eval_step(valid_datasets, best_style, softmax_temperature=best_softmax_temperature
|
219 |
-
, return_tensor=False, eval_positions=eval_positions_test,
|
220 |
-
bptt=bptt_final, inference_mode=True, model=model[2]
|
221 |
-
, selection_metric=selection_metric, evaluation_metric=evaluation_metric)
|
222 |
-
|
223 |
-
return result_test, result_valid
|
224 |
-
|
225 |
-
result_test, result_valid = final_evaluation()
|
226 |
-
|
227 |
-
return result_test, result_valid, best_style, best_softmax_temperature, optimization_routes
|
228 |
-
|
229 |
-
|
230 |
-
def eval_step(ds, used_style, selection_metric, evaluation_metric, eval_positions, return_tensor=True, **kwargs):
|
231 |
-
def step():
|
232 |
-
return evaluate(datasets=ds,
|
233 |
-
method='transformer'
|
234 |
-
, overwrite=True
|
235 |
-
, style=used_style
|
236 |
-
, eval_positions=eval_positions
|
237 |
-
, metric_used=selection_metric
|
238 |
-
, save=False
|
239 |
-
, path_interfix=None
|
240 |
-
, base_path=None
|
241 |
-
, verbose=True
|
242 |
-
, **kwargs)
|
243 |
-
|
244 |
-
if return_tensor:
|
245 |
-
r = step()
|
246 |
-
else:
|
247 |
-
with torch.no_grad():
|
248 |
-
r = step()
|
249 |
-
|
250 |
-
calculate_score_per_method(selection_metric, 'select', r, ds, eval_positions, aggregator='mean')
|
251 |
-
calculate_score_per_method(evaluation_metric, 'eval', r, ds, eval_positions, aggregator='mean')
|
252 |
-
|
253 |
-
return r
|
254 |
-
|
255 |
-
|
256 |
-
def gradient_optimize_style(model, init_style, steps, softmax_temperature, train_datasets, valid_datasets, learning_rate=0.03, optimize_all=False,
                            limit_style=True, N_datasets_sampled=90, optimize_softmax_temperature=True, selection_metric_min_max='max', **kwargs):
    """
    Uses gradient based methods to optimize 'style' on the 'train_datasets' and uses stopping with 'valid_datasets'.

    :param model: Indexable model container; ``model[2]`` is the module passed to ``eval_step`` (and the
        module whose parameters are optimized when ``optimize_all`` is set).
    :param init_style: Tensor holding the initial style; it is detached and wrapped in a trainable parameter.
    :param steps: Number of optimization steps.
    :param learning_rate: Learning rate of the Adam optimizer.
    :param softmax_temperature: Tensor holding the initial softmax temperature.
    :param train_datasets: Sequence of datasets; a random subset is drawn from it at every step and
        used to accumulate gradients.
    :param valid_datasets: Datasets evaluated without gradient propagation at every step to pick the best style.
    :param optimize_all: If True, optimize ``model[2].parameters()`` instead of style and temperature.
    :param limit_style: If True, clamp the style to [-1.74, 1.74] after every optimizer step.
    :param N_datasets_sampled: Number of train datasets sampled per step (capped at ``len(train_datasets)``).
    :param optimize_softmax_temperature: Whether the softmax temperature requires gradients.
    :param selection_metric_min_max: 'min' if a lower validation selection metric is better, 'max' otherwise.
    :param kwargs: Forwarded unchanged to ``eval_step``.
    :return: Dict with 'best_style', 'best_temperature' and 'optimization_route' (per-step train/valid
        selection metric and differentiable loss).
    """
    grad_style = torch.nn.Parameter(init_style.detach(), requires_grad=True)

    # Clone the initial values so the "best" snapshots never alias tensors the optimizer mutates in place.
    best_style, best_temperature = grad_style.detach().clone(), softmax_temperature.detach().clone()
    best_selection_metric, best_diffable_metric = None, None
    softmax_temperature = torch.nn.Parameter(softmax_temperature.detach(), requires_grad=optimize_softmax_temperature)
    variables_to_optimize = model[2].parameters() if optimize_all else [grad_style, softmax_temperature]
    optimizer = torch.optim.Adam(variables_to_optimize, lr=learning_rate)

    optimization_route_selection, optimization_route_diffable = [], []
    optimization_route_selection_valid, optimization_route_diffable_valid = [], []

    # random.sample raises if asked for more items than are available.
    n_datasets_per_step = min(N_datasets_sampled, len(train_datasets))

    def eval_opt(ds, return_tensor=True, inference_mode=False):
        # Evaluate the current style/temperature and return (differentiable metric, selection metric).
        result = eval_step(ds, grad_style, softmax_temperature=softmax_temperature, return_tensor=return_tensor
                           , inference_mode=inference_mode, model=model[2], **kwargs)

        diffable_metric = result['mean_metric']
        selection_metric = result['mean_select']

        return diffable_metric, selection_metric

    def eval_all_datasets(datasets, propagate=True):
        # Evaluate dataset by dataset; gradients are accumulated per dataset when `propagate` is set.
        selection_metrics_this_step, diffable_metrics_this_step = [], []
        for ds in datasets:
            diffable_metric_train, selection_metric_train = eval_opt([ds], inference_mode=(not propagate))
            # Datasets that produce a NaN loss are skipped entirely (no gradient, no contribution to the means).
            if not torch.isnan(diffable_metric_train).any():
                if propagate and diffable_metric_train.requires_grad:
                    diffable_metric_train.backward()
                selection_metrics_this_step += [selection_metric_train]
                diffable_metrics_this_step += [float(diffable_metric_train.detach().cpu().numpy())]
        diffable_metric_train = np.nanmean(diffable_metrics_this_step)
        selection_metric_train = np.nanmean(selection_metrics_this_step)

        return diffable_metric_train, selection_metric_train

    for t in tqdm(range(steps), desc='Iterate over Optimization steps'):
        optimizer.zero_grad()

        # Select subset of datasets (seeded by the step index, so the subset per step is reproducible)
        random.seed(t)
        train_datasets_ = random.sample(train_datasets, n_datasets_per_step)

        # Get score on train
        diffable_metric_train, selection_metric_train = eval_all_datasets(train_datasets_, propagate=True)
        optimization_route_selection += [float(selection_metric_train)]
        optimization_route_diffable += [float(diffable_metric_train)]

        # Get score on valid
        diffable_metric_valid, selection_metric_valid = eval_all_datasets(valid_datasets, propagate=False)
        optimization_route_selection_valid += [float(selection_metric_valid)]
        optimization_route_diffable_valid += [float(diffable_metric_valid)]

        # Check for "no best yet" first: comparing None with a number raises a TypeError.
        is_best = best_selection_metric is None
        is_best = is_best or (selection_metric_min_max == 'min' and best_selection_metric > selection_metric_valid)
        is_best = is_best or (selection_metric_min_max == 'max' and best_selection_metric < selection_metric_valid)
        if not np.isnan(selection_metric_valid) and is_best:
            print('New best', best_selection_metric, selection_metric_valid)
            best_style = grad_style.detach().clone()
            best_temperature = softmax_temperature.detach().clone()
            best_selection_metric, best_diffable_metric = selection_metric_valid, diffable_metric_valid

        optimizer.step()

        if limit_style:
            # Clamp in place so `grad_style` stays the Parameter that the optimizer updates;
            # rebinding the name to a detached tensor would silently stop the style optimization.
            with torch.no_grad():
                grad_style.clamp_(-1.74, 1.74)

        print(f'Valid: Diffable metric={diffable_metric_valid} Selection metric={selection_metric_valid};' +
              f'Train: Diffable metric={diffable_metric_train} Selection metric={selection_metric_train}')

    print(f'Return best:{best_style} {best_selection_metric}')
    return {'best_style': best_style, 'best_temperature': best_temperature
            , 'optimization_route': {'select': optimization_route_selection, 'loss': optimization_route_diffable,
                                     'test_select': optimization_route_selection_valid, 'test_loss': optimization_route_diffable_valid}}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
TabPFN/layer.py
CHANGED
@@ -103,6 +103,12 @@ class TransformerEncoderLayer(Module):
|
|
103 |
|
104 |
src2 = torch.cat([global_tokens_src2, train_tokens_src2, eval_tokens_src2], dim=0)
|
105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
else:
|
107 |
if self.recompute_attn:
|
108 |
src2 = checkpoint(self.self_attn, src_, src_, src_, src_key_padding_mask, True, src_mask)[0]
|
|
|
103 |
|
104 |
src2 = torch.cat([global_tokens_src2, train_tokens_src2, eval_tokens_src2], dim=0)
|
105 |
|
106 |
+
elif isinstance(src_mask, int):
|
107 |
+
assert src_key_padding_mask is None
|
108 |
+
single_eval_position = src_mask
|
109 |
+
src_left = self.self_attn(src_[:single_eval_position], src_[:single_eval_position], src_[:single_eval_position])[0]
|
110 |
+
src_right = self.self_attn(src_[single_eval_position:], src_[:single_eval_position], src_[:single_eval_position])[0]
|
111 |
+
src2 = torch.cat([src_left, src_right], dim=0)
|
112 |
else:
|
113 |
if self.recompute_attn:
|
114 |
src2 = checkpoint(self.self_attn, src_, src_, src_, src_key_padding_mask, True, src_mask)[0]
|
TabPFN/model_builder.py
DELETED
@@ -1,273 +0,0 @@
|
|
1 |
-
from train import train, Losses
|
2 |
-
import priors
|
3 |
-
import encoders
|
4 |
-
|
5 |
-
from collections import defaultdict
|
6 |
-
|
7 |
-
from priors.utils import trunc_norm_sampler_f, gamma_sampler_f
|
8 |
-
from utils import get_uniform_single_eval_pos_sampler
|
9 |
-
import torch
|
10 |
-
import math
|
11 |
-
|
12 |
-
def save_model(model, path, filename, config_sample):
    """Save the model weights together with a serializable copy of its config.

    The checkpoint written to ``os.path.join(path, filename)`` is the tuple
    ``(model.state_dict(), None, config)``; the middle slot is always ``None``.
    The caller's ``config_sample`` is not modified.

    :param model: Object exposing ``state_dict()``.
    :param path: Directory of the checkpoint file.
    :param filename: Name of the checkpoint file.
    :param config_sample: Mapping with the configuration; callables inside it are stored as ``str(callable)``.
    """
    def _to_serializable(value):
        # Dicts and lists are rebuilt recursively; any callable is replaced by its string form.
        if isinstance(value, dict):
            value = {key: _to_serializable(item) for key, item in value.items()}
        if isinstance(value, list):
            value = [_to_serializable(item) for item in value]
        if callable(value):
            value = str(value)
        return value

    serializable_config = _to_serializable({**config_sample})
    checkpoint = (model.state_dict(), None, serializable_config)
    torch.save(checkpoint, os.path.join(path, filename))
|
33 |
-
|
34 |
-
|
35 |
-
import subprocess as sp
|
36 |
-
import os
|
37 |
-
|
38 |
-
def get_gpu_memory():
    """Run ``nvidia-smi`` and return its complete output decoded as ASCII text."""
    raw_output = sp.check_output("nvidia-smi".split())
    return raw_output.decode('ascii')
|
42 |
-
|
43 |
-
|
44 |
-
def load_model(path, filename, device, eval_positions, verbose):
    """Load a checkpoint and rebuild the model with evaluation settings.

    The stored config is adjusted for evaluation: the training values of several keys are kept
    under ``<key>_in_training`` and the keys themselves are overwritten. The model is built via
    ``get_model`` with ``should_train=False`` and the stored weights are loaded into ``model[2]``.

    :param path: Directory of the checkpoint file.
    :param filename: Name of the checkpoint file.
    :param device: Device handed to ``get_model`` and to which ``model[2]`` is moved.
    :param eval_positions: Not used inside this function.
    :param verbose: Verbosity handed to ``get_model``.
    :return: Tuple ``(model, config_sample)``.
    """
    # TODO: This function only restores evaluation functionality but training can't be continued. It is also not flexible.
    checkpoint_file = os.path.join(path, filename)
    model_state, optimizer_state, config_sample = torch.load(checkpoint_file, map_location='cpu')

    if 'differentiable_hyperparameters' in config_sample:
        diff_hyperparameters = config_sample['differentiable_hyperparameters']
        if 'prior_mlp_activations' in diff_hyperparameters:
            activations = diff_hyperparameters['prior_mlp_activations']
            # Keep the stored choices and substitute one torch.nn.Tanh per stored choice.
            activations['choice_values_used'] = activations['choice_values']
            activations['choice_values'] = [torch.nn.Tanh for _ in activations['choice_values']]

    config_sample['categorical_features_sampler'] = lambda: lambda x: ([], [], [])

    # (key, value used for evaluation); the training value is preserved under '<key>_in_training'.
    evaluation_overrides = (
        ('num_features_used', lambda: config_sample['num_features']),
        ('num_classes', 2),
        ('batch_size', 1),
        ('bptt', 10),
        ('bptt_extra_samples', None),
    )
    for key, evaluation_value in evaluation_overrides:
        config_sample[f'{key}_in_training'] = config_sample[key]
        config_sample[key] = evaluation_value

    model = get_model(config_sample, device=device, should_train=False, verbose=verbose)

    # Remove every 'module.' occurrence from the parameter names before loading them.
    cleaned_state = {}
    for name, tensor in model_state.items():
        cleaned_state[name.replace('module.', '')] = tensor
    model[2].load_state_dict(cleaned_state)
    model[2].to(device)

    return model, config_sample
|
79 |
-
|
80 |
-
def fix_loaded_config_sample(loaded_config_sample, config):
    """Overwrite selected entries of ``loaded_config_sample`` in place with the values from ``config``.

    The copied entries are ``num_features_used``, ``num_classes`` and
    ``differentiable_hyperparameters -> prior_mlp_activations -> choice_values``.

    :param loaded_config_sample: Nested mapping that is modified in place.
    :param config: Nested mapping the values are read from.
    """
    key_paths = (
        ('num_features_used',),
        ('num_classes',),
        ('differentiable_hyperparameters', 'prior_mlp_activations', 'choice_values'),
    )
    for *parent_keys, leaf_key in key_paths:
        target, source = loaded_config_sample, config
        # Walk both mappings down to the container that holds the leaf entry.
        for key in parent_keys:
            target, source = target[key], source[key]
        target[leaf_key] = source[leaf_key]
|
90 |
-
|
91 |
-
def load_config_sample(path, template_config):
    """Load the config stored in the checkpoint at ``path`` and patch it with ``template_config``.

    The checkpoint must be a 3-tuple whose last element is the config; it is patched in place
    by ``fix_loaded_config_sample`` before being returned.

    :param path: Path of the checkpoint file.
    :param template_config: Config the patched entries are taken from.
    :return: The patched config of the checkpoint.
    """
    checkpoint = torch.load(path, map_location='cpu')
    _, _, loaded_config_sample = checkpoint
    fix_loaded_config_sample(loaded_config_sample, template_config)
    return loaded_config_sample
|
95 |
-
|
96 |
-
def get_default_spec(test_datasets, valid_datasets):
    """Return the default evaluation specification for the given datasets.

    :param test_datasets: Iterable of 6-tuples whose second element has a ``shape`` attribute.
    :param valid_datasets: Iterable of 6-tuples whose second element has a ``shape`` attribute.
    :return: Tuple ``(bptt, eval_positions, max_features, max_splits)`` where ``max_features`` is the
        largest ``shape[1]`` over both dataset collections; the other values are fixed.
    """
    feature_counts = []
    for datasets in (test_datasets, valid_datasets):
        for (_, X, _, _, _, _) in datasets:
            feature_counts.append(X.shape[1])

    return 10000, [1000, 2000, 3000, 4000, 5000], max(feature_counts), 5
|
103 |
-
|
104 |
-
def get_mlp_prior_hyperparameters(config):
    """Return a flattened copy of ``config`` with the MLP-prior samplers added.

    Every value that is exactly a ``dict`` is replaced by its first value. If
    ``prior_sigma_gamma_k`` is present, ``init_std`` is set to a ``gamma_sampler_f`` built from
    it and ``prior_sigma_gamma_theta``; likewise ``noise_std`` from the ``prior_noise_std_gamma_*`` keys.

    :param config: Mapping of hyperparameter names to values; it is not modified.
    :return: New dict with the flattened hyperparameters.
    """
    hyperparameters = {}
    for name, value in config.items():
        if type(value) is dict:
            value = list(value.values())[0]
        hyperparameters[name] = value

    if "prior_sigma_gamma_k" in hyperparameters:
        hyperparameters['init_std'] = gamma_sampler_f(hyperparameters["prior_sigma_gamma_k"],
                                                      hyperparameters["prior_sigma_gamma_theta"])
    if "prior_noise_std_gamma_k" in hyperparameters:
        hyperparameters['noise_std'] = gamma_sampler_f(hyperparameters["prior_noise_std_gamma_k"],
                                                       hyperparameters["prior_noise_std_gamma_theta"])

    return hyperparameters
|
115 |
-
|
116 |
-
|
117 |
-
def get_gp_mix_prior_hyperparameters(config):
    """Translate the ``prior_*`` entries of ``config`` into the hyperparameter dict of the GP-mix prior.

    :param config: Mapping containing the keys ``prior_lengthscale_concentration``, ``prior_nu``,
        ``prior_outputscale_concentration``, ``prior_y_minmax_norm``, ``prior_noise_concentration``
        and ``prior_noise_rate``.
    :return: Dict with the keys ``lengthscale_concentration``, ``nu``, ``outputscale_concentration``,
        ``categorical_data``, ``y_minmax_norm``, ``noise_concentration`` and ``noise_rate``.
    :raises KeyError: If one of the required ``prior_*`` keys is missing.
    """
    return {'lengthscale_concentration': config["prior_lengthscale_concentration"],
            'nu': config["prior_nu"],
            'outputscale_concentration': config["prior_outputscale_concentration"],
            # NOTE(review): 'categorical_data' is fed from 'prior_y_minmax_norm'; this looks like a
            # copy-paste slip as well, but no dedicated config key is visible here - confirm.
            'categorical_data': config["prior_y_minmax_norm"],
            # Previously read 'prior_lengthscale_concentration', which handed a concentration value
            # to the y-normalization setting.
            'y_minmax_norm': config["prior_y_minmax_norm"],
            'noise_concentration': config["prior_noise_concentration"],
            'noise_rate': config["prior_noise_rate"]}
|
125 |
-
|
126 |
-
def get_gp_prior_hyperparameters(config):
    """Return a copy of ``config`` in which every value that is exactly a ``dict`` is replaced by its first value.

    :param config: Mapping of hyperparameter names to values; it is not modified.
    :return: New dict with the flattened hyperparameters.
    """
    hyperparameters = {}
    for name, value in config.items():
        if type(value) is dict:
            value = list(value.values())[0]
        hyperparameters[name] = value
    return hyperparameters
|
128 |
-
|
129 |
-
|
130 |
-
def get_meta_gp_prior_hyperparameters(config):
    """Return a flattened copy of ``config`` with the meta-GP samplers added.

    Every value that is exactly a ``dict`` is replaced by its first value. If ``outputscale_mean``
    is present, ``outputscale`` is set to ``trunc_norm_sampler_f(mean, mean * outputscale_std_f)``;
    ``lengthscale`` is derived the same way from ``lengthscale_mean`` and ``lengthscale_std_f``.

    :param config: Mapping of hyperparameter names to values; it is not modified.
    :return: New dict with the flattened hyperparameters.
    """
    hyperparameters = {}
    for name, value in config.items():
        if type(value) is dict:
            value = list(value.values())[0]
        hyperparameters[name] = value

    if "outputscale_mean" in hyperparameters:
        outputscale_mean = hyperparameters["outputscale_mean"]
        hyperparameters['outputscale'] = trunc_norm_sampler_f(
            outputscale_mean, hyperparameters["outputscale_mean"] * hyperparameters["outputscale_std_f"])
    if "lengthscale_mean" in hyperparameters:
        lengthscale_mean = hyperparameters["lengthscale_mean"]
        hyperparameters['lengthscale'] = trunc_norm_sampler_f(
            lengthscale_mean, hyperparameters["lengthscale_mean"] * hyperparameters["lengthscale_std_f"])

    return hyperparameters
|
143 |
-
|
144 |
-
|
145 |
-
def get_model(config, device, should_train=True, verbose=False, state_dict=None, epoch_callback=None):
    """Assemble prior, encoder and loss from ``config`` and pass them to ``train``.

    ``config`` is modified in place: derived entries (``verbose``, ``aggregate_k_gradients``,
    ``num_steps``, ``batch_size``, ``eval_positions``) are written and missing optional entries
    receive their defaults.

    :param config: Dict with the model, prior and training configuration.
    :param device: Passed to ``train`` as ``gpu_device``.
    :param should_train: If False, ``train`` is called with ``epochs=0``.
    :param verbose: Level >= 1 enables training output, level >= 2 additionally sets ``config['verbose']``.
    :param state_dict: Passed to ``train`` as ``load_weights_from_this_state_dict``.
    :param epoch_callback: Passed to ``train`` unchanged.
    :return: Whatever ``train`` returns.
    :raises Exception: If ``config['prior_type']`` is not one of 'prior_bag', 'mlp', 'gp', 'gp_mix'.
    """
    extra_kwargs = {}
    verbose_train = verbose >= 1
    config['verbose'] = verbose >= 2

    if config.get('aggregate_k_gradients') is None:
        config['aggregate_k_gradients'] = math.ceil(config['batch_size'] * ((config['nlayers'] * config['emsize'] * config['bptt'] * config['bptt']) / 10824640000))

    k_gradients = config['aggregate_k_gradients']
    config['num_steps'] = math.ceil(config['num_steps'] * k_gradients)
    config['batch_size'] = math.ceil(config['batch_size'] / k_gradients)
    config.setdefault('recompute_attn', False)

    def make_get_batch(model_proto, **batch_kwargs):
        # Bind `model_proto` and the optional wrapped `get_batch` as lambda defaults so later
        # rebinding of the outer names does not affect an already created batch function.
        wrapped_get_batch = batch_kwargs.get('get_batch')
        wrapped_prior_bag_priors = batch_kwargs.get('prior_bag_priors')
        return (lambda batch_size, seq_len, num_features, hyperparameters
                , device, model_proto=model_proto, get_batch=wrapped_get_batch
                , prior_bag_priors=wrapped_prior_bag_priors: model_proto.get_batch(
            batch_size=batch_size
            , seq_len=seq_len
            , device=device
            , get_batch=get_batch
            , hyperparameters=hyperparameters
            , num_features=num_features))

    is_flexible = 'flexible' in config and config['flexible']

    if config['prior_type'] == 'prior_bag':
        # Prior bag combines priors
        get_batch_gp = make_get_batch(priors.fast_gp)
        get_batch_mlp = make_get_batch(priors.mlp)
        if is_flexible:
            get_batch_gp = make_get_batch(priors.flexible_categorical, get_batch=get_batch_gp)
            get_batch_mlp = make_get_batch(priors.flexible_categorical, get_batch=get_batch_mlp)
        prior_hyperparameters = {
            **get_mlp_prior_hyperparameters(config),
            **get_gp_prior_hyperparameters(config),
            'prior_bag_get_batch': (get_batch_gp, get_batch_mlp),
            'prior_bag_exp_weights_1': 2.0,
        }
        model_proto = priors.prior_bag
    else:
        if config['prior_type'] == 'mlp':
            prior_hyperparameters = get_mlp_prior_hyperparameters(config)
            model_proto = priors.mlp
        elif config['prior_type'] == 'gp':
            prior_hyperparameters = get_gp_prior_hyperparameters(config)
            model_proto = priors.fast_gp
        elif config['prior_type'] == 'gp_mix':
            prior_hyperparameters = get_gp_mix_prior_hyperparameters(config)
            model_proto = priors.fast_gp_mix
        else:
            raise Exception()

        if is_flexible:
            extra_kwargs['get_batch'] = make_get_batch(model_proto)
            model_proto = priors.flexible_categorical

    use_style = False
    if 'differentiable' in config and config['differentiable']:
        differentiable_base = make_get_batch(model_proto, **extra_kwargs)
        extra_kwargs = {'get_batch': differentiable_base,
                        'differentiable_hyperparameters': config['differentiable_hyperparameters']}
        model_proto = priors.differentiable_prior
        use_style = True
    print(f"Using style prior: {use_style}")

    # A NaN-aware encoder is needed as soon as any NaN probability is configured above zero.
    nan_probability_keys = ('nan_prob_no_reason', 'nan_prob_a_reason', 'nan_prob_unknown_reason')
    if any(key in config and config[key] > 0.0 for key in nan_probability_keys):
        encoder = encoders.NanHandlingEncoder
    else:
        encoder = encoders.Linear

    num_outputs = config.get('num_outputs', 1)

    # NOTE(review): JointBCELossWithLogits, BarDistribution and get_bucket_limits are not among the
    # imports visible in this file - confirm where they are expected to come from.
    if config['max_num_classes'] == 2:
        loss = JointBCELossWithLogits if ('joint_loss' in config and config['joint_loss']) else Losses.bce
    elif config['max_num_classes'] > 2:
        loss = Losses.ce(torch.ones((config['max_num_classes'])))
    else:
        loss = BarDistribution(borders=get_bucket_limits(500, full_range=(-10, 10)))

    aggregate_k_gradients = config.get('aggregate_k_gradients', 1)
    check_is_compatible = config.get('multiclass_loss_type') == 'compatible'
    config.setdefault('multiclass_type', 'rank')
    config.setdefault('mix_activations', False)

    config.setdefault('bptt_extra_samples', None)
    if config['bptt_extra_samples'] is None:
        config['eval_positions'] = [int(config['bptt'] * 0.95)]
    else:
        config['eval_positions'] = [int(config['bptt'])]

    if config.get('canonical_y_encoder', False):
        y_encoder_generator = encoders.get_Canonical(config['max_num_classes'])
    else:
        y_encoder_generator = encoders.Linear

    extra_prior_kwargs_dict = {
        'num_features': config['num_features'],
        'fuse_x_y': False,
        'hyperparameters': prior_hyperparameters,
        'num_outputs': num_outputs,
        'dynamic_batch_size': 1 if ('num_global_att_tokens' in config and config['num_global_att_tokens']) else 2,
        **extra_kwargs,
    }

    return train(model_proto.DataLoader
                 , loss
                 , encoder
                 , style_encoder_generator=encoders.StyleEncoder if use_style else None
                 , emsize=config['emsize']
                 , nhead=config['nhead']
                 , y_encoder_generator=y_encoder_generator
                 , pos_encoder_generator=None
                 , batch_size=config['batch_size']
                 , nlayers=config['nlayers']
                 , nhid=config['emsize'] * config['nhid_factor']
                 , epochs=config['epochs'] if should_train else 0
                 , total_available_time_in_s=config.get('total_available_time_in_s', None)
                 , warmup_epochs=20
                 , bptt=config['bptt']
                 , gpu_device=device
                 , dropout=config['dropout']
                 , steps_per_epoch=config['num_steps']
                 , single_eval_pos_gen=get_uniform_single_eval_pos_sampler(config['bptt'])
                 , load_weights_from_this_state_dict=state_dict
                 , aggregate_k_gradients=aggregate_k_gradients
                 , check_is_compatible=check_is_compatible
                 , recompute_attn=config['recompute_attn']
                 , epoch_callback=epoch_callback
                 , bptt_extra_samples=config['bptt_extra_samples']
                 , extra_prior_kwargs_dict=extra_prior_kwargs_dict
                 , lr=config['lr']
                 , verbose=verbose_train
                 , weight_decay=config.get('weight_decay', 0.0)
                 , normalize_labels=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
TabPFN/models_diff/gp_ablation_model.cpkt
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:c7b0c8febc553cca3fdee265b5a1cd7567dbf83da855969940be4707a9218ffb
|
3 |
-
size 69460013
|
|
|
|
|
|
|
|
TabPFN/models_diff/prior_diff_real_checkpoint_n_8x_lr0.0003_epoch_49.cpkt
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:dae97f45bd53d719fc2b23fac4ec55eab16d63892196d939b1bb1c3b408be242
|
3 |
-
size 103616779
|
|
|
|
|
|
|
|
TabPFN/prior_tuning_result.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:24d2189bbc836aeea888cf6c540f2c1b45b5351822931189e8bf10a0bc80a0b6
|
3 |
-
size 18668851
|
|
|
|
|
|
|
|
TabPFN/scripts/differentiable_pfn_evaluation.py
CHANGED
@@ -10,8 +10,9 @@ from priors.differentiable_prior import draw_random_style
|
|
10 |
from tqdm import tqdm
|
11 |
from pathlib import Path
|
12 |
import random
|
13 |
-
from model_builder import load_model
|
14 |
from scripts.transformer_prediction_interface import get_params_from_config
|
|
|
15 |
|
16 |
"""
|
17 |
===============================
|
@@ -24,55 +25,9 @@ def eval_model_range(i_range, *args, **kwargs):
|
|
24 |
for i in i_range:
|
25 |
eval_model(i, *args, **kwargs)
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
Workflow for loading a model and setting appropriate parameters for diffable hparam tuning.
|
31 |
-
|
32 |
-
:param i:
|
33 |
-
:param e:
|
34 |
-
:param eval_positions_valid:
|
35 |
-
:param add_name:
|
36 |
-
:param base_path:
|
37 |
-
:param device:
|
38 |
-
:param eval_addition:
|
39 |
-
:return:
|
40 |
-
"""
|
41 |
-
def check_file(e):
|
42 |
-
model_file = f'models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{e}.cpkt'
|
43 |
-
model_path = os.path.join(base_path, model_file)
|
44 |
-
# print('Evaluate ', model_path)
|
45 |
-
results_file = os.path.join(base_path,
|
46 |
-
f'models_diff/prior_diff_real_results{add_name}_n_{i}_epoch_{e}_{eval_addition}.pkl')
|
47 |
-
if not Path(model_path).is_file(): # or Path(results_file).is_file():
|
48 |
-
return None, None, None
|
49 |
-
return model_file, model_path, results_file
|
50 |
-
|
51 |
-
model_file = None
|
52 |
-
if e == -1:
|
53 |
-
for e_ in range(100, -1, -1):
|
54 |
-
model_file_, model_path_, results_file_ = check_file(e_)
|
55 |
-
if model_file_ is not None:
|
56 |
-
e = e_
|
57 |
-
model_file, model_path, results_file = model_file_, model_path_, results_file_
|
58 |
-
break
|
59 |
-
else:
|
60 |
-
model_file, model_path, results_file = check_file(e)
|
61 |
-
|
62 |
-
if model_file is None:
|
63 |
-
print('No checkpoint found')
|
64 |
-
return None
|
65 |
-
|
66 |
-
print(f'Loading {model_file}')
|
67 |
-
|
68 |
-
model, c = load_model(base_path, model_file, device, eval_positions=[], verbose=False)
|
69 |
-
|
70 |
-
return model, c, results_file
|
71 |
-
|
72 |
-
|
73 |
-
def eval_model(i, e, valid_datasets, test_datasets, train_datasets, eval_positions_valid, eval_positions_test,
|
74 |
-
bptt_valid,
|
75 |
-
bptt_test, add_name, base_path, device='cpu', eval_addition='', **extra_tuning_args):
|
76 |
"""
|
77 |
Differentiable model evaliation workflow. Evaluates and saves results to disk.
|
78 |
|
@@ -107,12 +62,12 @@ def eval_model(i, e, valid_datasets, test_datasets, train_datasets, eval_positio
|
|
107 |
params.update(get_params_from_config(c))
|
108 |
|
109 |
start = time.time()
|
110 |
-
metrics, metrics_valid, style, temperature, optimization_route =
|
111 |
-
|
112 |
print('Evaluation time: ', time.time() - start)
|
113 |
|
114 |
print(results_file)
|
115 |
-
r = [c.copy(), metrics, metrics_valid, style.to('cpu'), temperature.to('cpu'), optimization_route]
|
116 |
with open(results_file, 'wb') as output:
|
117 |
del r[0]['num_features_used']
|
118 |
del r[0]['categorical_features_sampler']
|
@@ -128,22 +83,18 @@ INTERNAL HELPER FUNCTIONS
|
|
128 |
===============================
|
129 |
"""
|
130 |
|
131 |
-
def
|
132 |
, valid_datasets
|
133 |
, test_datasets
|
134 |
, train_datasets
|
135 |
-
, N_draws=100
|
136 |
-
, N_grad_steps=10
|
137 |
-
, eval_positions=None
|
138 |
, eval_positions_test=None
|
139 |
-
, bptt=100
|
140 |
, bptt_final=200
|
141 |
-
, style=None
|
142 |
-
, n_parallel_configurations=1
|
143 |
, device='cpu'
|
144 |
, selection_metric='auc'
|
145 |
, final_splits=[1, 2, 3, 4, 5]
|
146 |
, N_ensemble_configurations_list=[1, 5, 10, 20, 50, 100]
|
|
|
|
|
147 |
, **kwargs):
|
148 |
"""
|
149 |
Evaluation function for diffable model evaluation. Returns a list of results.
|
@@ -171,107 +122,38 @@ def evaluate_differentiable_model(model
|
|
171 |
np.random.seed(0)
|
172 |
random.seed(0)
|
173 |
|
174 |
-
diffable_metric = tabular_metrics.cross_entropy
|
175 |
evaluation_metric = tabular_metrics.auc_metric
|
176 |
-
|
177 |
-
selection_metric_min_max = 'max'
|
178 |
-
selection_metric = tabular_metrics.auc_metric
|
179 |
-
evaluation_metric = selection_metric
|
180 |
-
elif selection_metric in ('ce', 'selection_metric'):
|
181 |
-
selection_metric_min_max = 'min'
|
182 |
-
selection_metric = tabular_metrics.cross_entropy
|
183 |
-
evaluation_metric = selection_metric
|
184 |
-
|
185 |
-
print('Diffable metric', diffable_metric, ' Selection metric', selection_metric, ' Evaluation metric',
|
186 |
-
evaluation_metric)
|
187 |
-
print('N PARALLEL CONFIGURATIONS', n_parallel_configurations)
|
188 |
-
print('eval_positions', eval_positions)
|
189 |
-
|
190 |
-
def evaluate_valid(style, softmax_temperature, results, results_tracked):
|
191 |
-
result_valid = eval_step(valid_datasets, style, softmax_temperature=softmax_temperature,
|
192 |
-
return_tensor=False, inference_mode=True, selection_metric=selection_metric,
|
193 |
-
evaluation_metric=evaluation_metric, eval_positions=eval_positions, bptt=bptt, model=model[2])
|
194 |
-
result_valid = [float(result_valid[f'mean_select_at_{pos}']) for pos in eval_positions]
|
195 |
-
results += [result_valid]
|
196 |
-
results_tracked += [np.nanmean(result_valid)]
|
197 |
|
198 |
model[2].to(device)
|
199 |
model[2].eval()
|
200 |
|
201 |
-
results_on_valid, results_on_valid_tracked = [], []
|
202 |
-
best_style, best_softmax_temperature = style, torch.cat(
|
203 |
-
[torch.tensor([0.0]).to(device) for n in range(0, n_parallel_configurations)], 0)
|
204 |
-
optimization_routes = []
|
205 |
-
|
206 |
-
best_style = torch.cat([draw_random_style(model[3], device).detach() for n in range(0, n_parallel_configurations)],
|
207 |
-
0)
|
208 |
-
best_softmax_temperature = torch.cat([torch.tensor([0.0]).to(device) for n in range(0, n_parallel_configurations)],
|
209 |
-
0)
|
210 |
-
|
211 |
-
|
212 |
-
for _ in tqdm(range(0, N_draws), desc='Iterate over Optimization initializations'): # Evaluates N hparam draws
|
213 |
-
style = torch.cat([draw_random_style(model[3], device).detach() for n in range(0, n_parallel_configurations)],
|
214 |
-
0)
|
215 |
-
softmax_temperature = torch.cat([torch.tensor([0.0]).to(device) for n in range(0, n_parallel_configurations)],
|
216 |
-
0)
|
217 |
-
|
218 |
-
evaluate_valid(style, softmax_temperature, results_on_valid, results_on_valid_tracked)
|
219 |
-
|
220 |
-
print(f'Draw --> Valid Selection metric: {results_on_valid[-1]}')
|
221 |
-
|
222 |
-
if N_grad_steps > 0:
|
223 |
-
gradient_optimize_result = gradient_optimize_style(model, style, N_grad_steps
|
224 |
-
, softmax_temperature=softmax_temperature
|
225 |
-
, model=model[2]
|
226 |
-
, train_datasets=train_datasets
|
227 |
-
, valid_datasets=valid_datasets
|
228 |
-
, selection_metric_min_max=selection_metric_min_max
|
229 |
-
, **kwargs)
|
230 |
-
optimization_routes += [gradient_optimize_result['optimization_route']]
|
231 |
-
|
232 |
-
evaluate_valid(gradient_optimize_result['best_style']
|
233 |
-
, gradient_optimize_result['best_temperature']
|
234 |
-
, results_on_valid, results_on_valid_tracked)
|
235 |
-
|
236 |
-
print(f'After diff --> Valid Selection metric: {results_on_valid[-1]}')
|
237 |
-
|
238 |
-
if selection_metric_min_max == 'min':
|
239 |
-
is_best = (results_on_valid_tracked[-1] <= min(results_on_valid_tracked))
|
240 |
-
else:
|
241 |
-
is_best = (results_on_valid_tracked[-1] >= max(results_on_valid_tracked))
|
242 |
-
|
243 |
-
if is_best or best_style is None:
|
244 |
-
best_style = gradient_optimize_result['best_style'].clone()
|
245 |
-
best_softmax_temperature = gradient_optimize_result['best_temperature'].clone()
|
246 |
-
torch.cuda.empty_cache()
|
247 |
-
|
248 |
def final_evaluation():
|
249 |
print('Running eval dataset with final params (no gradients)..')
|
250 |
-
print(best_style, best_softmax_temperature)
|
251 |
result_test = []
|
252 |
for N_ensemble_configurations in N_ensemble_configurations_list:
|
253 |
print(f'Running with {N_ensemble_configurations} ensemble_configurations')
|
254 |
kwargs['N_ensemble_configurations'] = N_ensemble_configurations
|
255 |
splits = []
|
256 |
for split in final_splits:
|
257 |
-
splits += [eval_step(test_datasets,
|
258 |
, return_tensor=False, eval_positions=eval_positions_test,
|
259 |
-
bptt=bptt_final,
|
260 |
-
, selection_metric=selection_metric, evaluation_metric=evaluation_metric
|
|
|
261 |
result_test += [splits]
|
262 |
|
263 |
print('Running valid dataset with final params (no gradients)..')
|
264 |
-
result_valid = eval_step(valid_datasets,
|
265 |
, return_tensor=False, eval_positions=eval_positions_test,
|
266 |
-
bptt=bptt_final,
|
267 |
-
, selection_metric=selection_metric, evaluation_metric=evaluation_metric)
|
268 |
|
269 |
return result_test, result_valid
|
270 |
|
271 |
result_test, result_valid = final_evaluation()
|
272 |
|
273 |
-
return result_test, result_valid,
|
274 |
-
|
275 |
|
276 |
def eval_step(ds, used_style, selection_metric, evaluation_metric, eval_positions, return_tensor=True, **kwargs):
|
277 |
def step():
|
@@ -284,7 +166,6 @@ def eval_step(ds, used_style, selection_metric, evaluation_metric, eval_position
|
|
284 |
, save=False
|
285 |
, path_interfix=None
|
286 |
, base_path=None
|
287 |
-
, verbose=True
|
288 |
, **kwargs)
|
289 |
|
290 |
if return_tensor:
|
@@ -299,7 +180,7 @@ def eval_step(ds, used_style, selection_metric, evaluation_metric, eval_position
|
|
299 |
return r
|
300 |
|
301 |
|
302 |
-
def gradient_optimize_style(model, init_style, steps, softmax_temperature, train_datasets, valid_datasets, learning_rate=0.03, optimize_all=False,
|
303 |
limit_style=True, N_datasets_sampled=90, optimize_softmax_temperature=True, selection_metric_min_max='max', **kwargs):
|
304 |
"""
|
305 |
Uses gradient based methods to optimize 'style' on the 'train_datasets' and uses stopping with 'valid_datasets'.
|
@@ -331,7 +212,7 @@ def gradient_optimize_style(model, init_style, steps, softmax_temperature, train
|
|
331 |
|
332 |
def eval_opt(ds, return_tensor=True, inference_mode=False):
|
333 |
result = eval_step(ds, grad_style, softmax_temperature=softmax_temperature, return_tensor=return_tensor
|
334 |
-
, inference_mode=inference_mode, model=model[2], **kwargs)
|
335 |
|
336 |
diffable_metric = result['mean_metric']
|
337 |
selection_metric = result['mean_select']
|
@@ -369,9 +250,10 @@ def gradient_optimize_style(model, init_style, steps, softmax_temperature, train
|
|
369 |
optimization_route_selection_valid += [float(selection_metric_valid)]
|
370 |
optimization_route_diffable_valid += [float(diffable_metric_valid)]
|
371 |
|
372 |
-
is_best = (
|
|
|
373 |
is_best = is_best or (selection_metric_min_max == 'max' and best_selection_metric < selection_metric_valid)
|
374 |
-
if (
|
375 |
print('New best', best_selection_metric, selection_metric_valid)
|
376 |
best_style = grad_style.detach().clone()
|
377 |
best_temperature = softmax_temperature.detach().clone()
|
|
|
10 |
from tqdm import tqdm
|
11 |
from pathlib import Path
|
12 |
import random
|
13 |
+
from scripts.model_builder import load_model
|
14 |
from scripts.transformer_prediction_interface import get_params_from_config
|
15 |
+
from scripts.transformer_prediction_interface import load_model_workflow
|
16 |
|
17 |
"""
|
18 |
===============================
|
|
|
25 |
for i in i_range:
|
26 |
eval_model(i, *args, **kwargs)
|
27 |
|
28 |
+
def eval_model(i, e, valid_datasets, test_datasets, train_datasets, add_name, base_path, eval_positions_valid=[1000], eval_positions_test=[1000],
|
29 |
+
bptt_valid=2000,
|
30 |
+
bptt_test=2000, device='cpu', eval_addition='', differentiable=False, **extra_tuning_args):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
"""
|
32 |
Differentiable model evaluation workflow. Evaluates and saves results to disk.
|
33 |
|
|
|
62 |
params.update(get_params_from_config(c))
|
63 |
|
64 |
start = time.time()
|
65 |
+
metrics, metrics_valid, style, temperature, optimization_route = evaluate_point_model(model, **params,
|
66 |
+
**extra_tuning_args)
|
67 |
print('Evaluation time: ', time.time() - start)
|
68 |
|
69 |
print(results_file)
|
70 |
+
r = [c.copy(), metrics, metrics_valid, style.to('cpu') if style else style, temperature.to('cpu') if temperature else temperature, optimization_route]
|
71 |
with open(results_file, 'wb') as output:
|
72 |
del r[0]['num_features_used']
|
73 |
del r[0]['categorical_features_sampler']
|
|
|
83 |
===============================
|
84 |
"""
|
85 |
|
86 |
+
def evaluate_point_model(model
|
87 |
, valid_datasets
|
88 |
, test_datasets
|
89 |
, train_datasets
|
|
|
|
|
|
|
90 |
, eval_positions_test=None
|
|
|
91 |
, bptt_final=200
|
|
|
|
|
92 |
, device='cpu'
|
93 |
, selection_metric='auc'
|
94 |
, final_splits=[1, 2, 3, 4, 5]
|
95 |
, N_ensemble_configurations_list=[1, 5, 10, 20, 50, 100]
|
96 |
+
, bptt=None
|
97 |
+
, eval_positions=None
|
98 |
, **kwargs):
|
99 |
"""
|
100 |
Evaluation function for diffable model evaluation. Returns a list of results.
|
|
|
122 |
np.random.seed(0)
|
123 |
random.seed(0)
|
124 |
|
|
|
125 |
evaluation_metric = tabular_metrics.auc_metric
|
126 |
+
selection_metric = tabular_metrics.auc_metric
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
model[2].to(device)
|
129 |
model[2].eval()
|
130 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
def final_evaluation():
|
132 |
print('Running eval dataset with final params (no gradients)..')
|
|
|
133 |
result_test = []
|
134 |
for N_ensemble_configurations in N_ensemble_configurations_list:
|
135 |
print(f'Running with {N_ensemble_configurations} ensemble_configurations')
|
136 |
kwargs['N_ensemble_configurations'] = N_ensemble_configurations
|
137 |
splits = []
|
138 |
for split in final_splits:
|
139 |
+
splits += [eval_step(test_datasets, None, softmax_temperature=torch.tensor([0])
|
140 |
, return_tensor=False, eval_positions=eval_positions_test,
|
141 |
+
bptt=bptt_final, split_number=split, model=model[2], device=device
|
142 |
+
, selection_metric=selection_metric, evaluation_metric=evaluation_metric
|
143 |
+
, **kwargs)]
|
144 |
result_test += [splits]
|
145 |
|
146 |
print('Running valid dataset with final params (no gradients)..')
|
147 |
+
result_valid = eval_step(valid_datasets, None, softmax_temperature=torch.tensor([0])
|
148 |
, return_tensor=False, eval_positions=eval_positions_test,
|
149 |
+
bptt=bptt_final, model=model[2], device=device
|
150 |
+
, selection_metric=selection_metric, evaluation_metric=evaluation_metric,**kwargs)
|
151 |
|
152 |
return result_test, result_valid
|
153 |
|
154 |
result_test, result_valid = final_evaluation()
|
155 |
|
156 |
+
return result_test, result_valid, None, None, None
|
|
|
157 |
|
158 |
def eval_step(ds, used_style, selection_metric, evaluation_metric, eval_positions, return_tensor=True, **kwargs):
|
159 |
def step():
|
|
|
166 |
, save=False
|
167 |
, path_interfix=None
|
168 |
, base_path=None
|
|
|
169 |
, **kwargs)
|
170 |
|
171 |
if return_tensor:
|
|
|
180 |
return r
|
181 |
|
182 |
|
183 |
+
def gradient_optimize_style(model, init_style, steps, softmax_temperature, train_datasets, valid_datasets, bptt, learning_rate=0.03, optimize_all=False,
|
184 |
limit_style=True, N_datasets_sampled=90, optimize_softmax_temperature=True, selection_metric_min_max='max', **kwargs):
|
185 |
"""
|
186 |
Uses gradient based methods to optimize 'style' on the 'train_datasets' and uses stopping with 'valid_datasets'.
|
|
|
212 |
|
213 |
def eval_opt(ds, return_tensor=True, inference_mode=False):
|
214 |
result = eval_step(ds, grad_style, softmax_temperature=softmax_temperature, return_tensor=return_tensor
|
215 |
+
, inference_mode=inference_mode, model=model[2], bptt=bptt, **kwargs)
|
216 |
|
217 |
diffable_metric = result['mean_metric']
|
218 |
selection_metric = result['mean_select']
|
|
|
250 |
optimization_route_selection_valid += [float(selection_metric_valid)]
|
251 |
optimization_route_diffable_valid += [float(diffable_metric_valid)]
|
252 |
|
253 |
+
is_best = (best_selection_metric is None)
|
254 |
+
is_best = is_best or (selection_metric_min_max == 'min' and best_selection_metric > selection_metric_valid)
|
255 |
is_best = is_best or (selection_metric_min_max == 'max' and best_selection_metric < selection_metric_valid)
|
256 |
+
if (not np.isnan(selection_metric_valid) and is_best):
|
257 |
print('New best', best_selection_metric, selection_metric_valid)
|
258 |
best_style = grad_style.detach().clone()
|
259 |
best_temperature = softmax_temperature.detach().clone()
|
TabPFN/scripts/model_configs.py
CHANGED
@@ -12,10 +12,10 @@ def get_general_config(max_features, bptt, eval_positions=None):
|
|
12 |
Returns the general PFN training hyperparameters.
|
13 |
"""
|
14 |
config_general = {
|
15 |
-
"lr": CSH.UniformFloatHyperparameter('lr', lower=0.
|
16 |
"dropout": CSH.CategoricalHyperparameter('dropout', [0.0]),
|
17 |
"emsize": CSH.CategoricalHyperparameter('emsize', [2 ** i for i in range(8, 9)]), ## upper bound is -1
|
18 |
-
"batch_size": CSH.CategoricalHyperparameter('batch_size', [2 ** i for i in range(
|
19 |
"nlayers": CSH.CategoricalHyperparameter('nlayers', [12]),
|
20 |
"num_features": max_features,
|
21 |
"nhead": CSH.CategoricalHyperparameter('nhead', [4]),
|
@@ -27,8 +27,9 @@ def get_general_config(max_features, bptt, eval_positions=None):
|
|
27 |
"epochs": 80,
|
28 |
"num_steps": 100,
|
29 |
"verbose": False,
|
30 |
-
"
|
31 |
-
"
|
|
|
32 |
}
|
33 |
|
34 |
return config_general
|
@@ -38,9 +39,9 @@ def get_flexible_categorical_config(max_features):
|
|
38 |
Returns the configuration parameters for the tabular multiclass wrapper.
|
39 |
"""
|
40 |
config_flexible_categorical = {
|
41 |
-
"nan_prob_unknown_reason_reason_prior": CSH.CategoricalHyperparameter('nan_prob_unknown_reason_reason_prior', [
|
42 |
-
"categorical_feature_p": CSH.CategoricalHyperparameter('categorical_feature_p', [0.0]),
|
43 |
-
"nan_prob_no_reason": CSH.CategoricalHyperparameter('nan_prob_no_reason', [0.0, 0.1
|
44 |
"nan_prob_unknown_reason": CSH.CategoricalHyperparameter('nan_prob_unknown_reason', [0.0]),
|
45 |
"nan_prob_a_reason": CSH.CategoricalHyperparameter('nan_prob_a_reason', [0.0]),
|
46 |
# "num_classes": lambda : random.randint(2, 10), "balanced": False,
|
@@ -66,6 +67,7 @@ def get_diff_flex():
|
|
66 |
# "num_categorical_features_sampler_a": hp.choice('num_categorical_features_sampler_a',
|
67 |
# [{'distribution': 'uniform', 'min': 0.3, 'max': 0.9}, None]),
|
68 |
# "num_categorical_features_sampler_b": {'distribution': 'uniform', 'min': 0.3, 'max': 0.9},
|
|
|
69 |
"output_multiclass_ordered_p": {'distribution': 'uniform', 'min': 0.0, 'max': 0.5}, #CSH.CategoricalHyperparameter('output_multiclass_ordered_p', [0.0, 0.1, 0.2]),
|
70 |
"multiclass_type": {'distribution': 'meta_choice', 'choice_values': ['value', 'rank']},
|
71 |
}
|
@@ -91,34 +93,41 @@ def get_diff_causal():
|
|
91 |
Returns the configuration parameters for a differentiable wrapper around MLP / Causal mixture.
|
92 |
"""
|
93 |
diff_causal = {
|
94 |
-
"
|
|
|
|
|
|
|
95 |
'lower_bound': 2},
|
96 |
# Better beta?
|
97 |
-
"prior_mlp_hidden_dim": {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': 130, 'min_mean': 5,
|
98 |
-
|
|
|
99 |
|
100 |
-
"prior_mlp_dropout_prob": {'distribution': 'meta_beta', 'scale': 0.
|
101 |
# This mustn't be too high since activations get too large otherwise
|
102 |
|
103 |
"noise_std": {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': .3, 'min_mean': 0.0001, 'round': False,
|
104 |
'lower_bound': 0.0},
|
105 |
"init_std": {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': 10.0, 'min_mean': 0.01, 'round': False,
|
106 |
'lower_bound': 0.0},
|
107 |
-
"num_causes": {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': 12, 'min_mean': 1, 'round': True,
|
108 |
-
|
|
|
|
|
|
|
109 |
"is_causal": {'distribution': 'meta_choice', 'choice_values': [True, False]},
|
110 |
"pre_sample_weights": {'distribution': 'meta_choice', 'choice_values': [True, False]},
|
111 |
"y_is_effect": {'distribution': 'meta_choice', 'choice_values': [True, False]},
|
|
|
112 |
"prior_mlp_activations": {'distribution': 'meta_choice_mixed', 'choice_values': [
|
113 |
torch.nn.Tanh
|
114 |
-
, torch.nn.ReLU
|
115 |
, torch.nn.Identity
|
116 |
-
,
|
117 |
-
, torch.nn.ELU
|
118 |
]},
|
119 |
"block_wise_dropout": {'distribution': 'meta_choice', 'choice_values': [True, False]},
|
120 |
"sort_features": {'distribution': 'meta_choice', 'choice_values': [True, False]},
|
121 |
"in_clique": {'distribution': 'meta_choice', 'choice_values': [True, False]},
|
|
|
122 |
}
|
123 |
|
124 |
return diff_causal
|
@@ -128,7 +137,7 @@ def get_diff_prior_bag():
|
|
128 |
Returns the configuration parameters for a GP and MLP / Causal mixture.
|
129 |
"""
|
130 |
diff_prior_bag = {
|
131 |
-
'prior_bag_exp_weights_1': {'distribution': 'uniform', 'min':
|
132 |
# MLP Weight (Biased, since MLP works better, 1.0 is weight for prior number 0)
|
133 |
}
|
134 |
|
@@ -148,6 +157,72 @@ def get_diff_config():
|
|
148 |
return config_diff
|
149 |
|
150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
def sample_differentiable(config):
|
152 |
""""
|
153 |
Returns sampled hyperparameters from a differentiable wrapper, that is it makes a non-differentiable out of
|
|
|
12 |
Returns the general PFN training hyperparameters.
|
13 |
"""
|
14 |
config_general = {
|
15 |
+
"lr": CSH.UniformFloatHyperparameter('lr', lower=0.0001, upper=0.00015, log=True),
|
16 |
"dropout": CSH.CategoricalHyperparameter('dropout', [0.0]),
|
17 |
"emsize": CSH.CategoricalHyperparameter('emsize', [2 ** i for i in range(8, 9)]), ## upper bound is -1
|
18 |
+
"batch_size": CSH.CategoricalHyperparameter('batch_size', [2 ** i for i in range(6, 8)]),
|
19 |
"nlayers": CSH.CategoricalHyperparameter('nlayers', [12]),
|
20 |
"num_features": max_features,
|
21 |
"nhead": CSH.CategoricalHyperparameter('nhead', [4]),
|
|
|
27 |
"epochs": 80,
|
28 |
"num_steps": 100,
|
29 |
"verbose": False,
|
30 |
+
"mix_activations": False,
|
31 |
+
"pre_sample_causes": True,
|
32 |
+
"multiclass_type": 'rank'
|
33 |
}
|
34 |
|
35 |
return config_general
|
|
|
39 |
Returns the configuration parameters for the tabular multiclass wrapper.
|
40 |
"""
|
41 |
config_flexible_categorical = {
|
42 |
+
"nan_prob_unknown_reason_reason_prior": CSH.CategoricalHyperparameter('nan_prob_unknown_reason_reason_prior', [0.5]),
|
43 |
+
"categorical_feature_p": CSH.CategoricalHyperparameter('categorical_feature_p', [0.0, 0.1, 0.2]),
|
44 |
+
"nan_prob_no_reason": CSH.CategoricalHyperparameter('nan_prob_no_reason', [0.0, 0.1]),
|
45 |
"nan_prob_unknown_reason": CSH.CategoricalHyperparameter('nan_prob_unknown_reason', [0.0]),
|
46 |
"nan_prob_a_reason": CSH.CategoricalHyperparameter('nan_prob_a_reason', [0.0]),
|
47 |
# "num_classes": lambda : random.randint(2, 10), "balanced": False,
|
|
|
67 |
# "num_categorical_features_sampler_a": hp.choice('num_categorical_features_sampler_a',
|
68 |
# [{'distribution': 'uniform', 'min': 0.3, 'max': 0.9}, None]),
|
69 |
# "num_categorical_features_sampler_b": {'distribution': 'uniform', 'min': 0.3, 'max': 0.9},
|
70 |
+
|
71 |
"output_multiclass_ordered_p": {'distribution': 'uniform', 'min': 0.0, 'max': 0.5}, #CSH.CategoricalHyperparameter('output_multiclass_ordered_p', [0.0, 0.1, 0.2]),
|
72 |
"multiclass_type": {'distribution': 'meta_choice', 'choice_values': ['value', 'rank']},
|
73 |
}
|
|
|
93 |
Returns the configuration parameters for a differentiable wrapper around MLP / Causal mixture.
|
94 |
"""
|
95 |
diff_causal = {
|
96 |
+
#"mix_activations": {'distribution': 'meta_choice', 'choice_values': [True, False]},
|
97 |
+
#"num_layers": {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': 6, 'min_mean': 1, 'round': True,
|
98 |
+
# 'lower_bound': 2},
|
99 |
+
"num_layers": {'distribution': 'meta_gamma', 'max_alpha': 2, 'max_scale': 3, 'round': True,
|
100 |
'lower_bound': 2},
|
101 |
# Better beta?
|
102 |
+
#"prior_mlp_hidden_dim": {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': 130, 'min_mean': 5,
|
103 |
+
# 'round': True, 'lower_bound': 4},
|
104 |
+
"prior_mlp_hidden_dim": {'distribution': 'meta_gamma', 'max_alpha': 3, 'max_scale': 100, 'round': True, 'lower_bound': 4},
|
105 |
|
106 |
+
"prior_mlp_dropout_prob": {'distribution': 'meta_beta', 'scale': 0.6, 'min': 0.1, 'max': 5.0},
|
107 |
# This mustn't be too high since activations get too large otherwise
|
108 |
|
109 |
"noise_std": {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': .3, 'min_mean': 0.0001, 'round': False,
|
110 |
'lower_bound': 0.0},
|
111 |
"init_std": {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': 10.0, 'min_mean': 0.01, 'round': False,
|
112 |
'lower_bound': 0.0},
|
113 |
+
#"num_causes": {'distribution': 'meta_trunc_norm_log_scaled', 'max_mean': 12, 'min_mean': 1, 'round': True,
|
114 |
+
# 'lower_bound': 1},
|
115 |
+
"num_causes": {'distribution': 'meta_gamma', 'max_alpha': 3, 'max_scale': 7, 'round': True,
|
116 |
+
'lower_bound': 2},
|
117 |
+
|
118 |
"is_causal": {'distribution': 'meta_choice', 'choice_values': [True, False]},
|
119 |
"pre_sample_weights": {'distribution': 'meta_choice', 'choice_values': [True, False]},
|
120 |
"y_is_effect": {'distribution': 'meta_choice', 'choice_values': [True, False]},
|
121 |
+
"sampling": {'distribution': 'meta_choice', 'choice_values': ['normal', 'mixed']},
|
122 |
"prior_mlp_activations": {'distribution': 'meta_choice_mixed', 'choice_values': [
|
123 |
torch.nn.Tanh
|
|
|
124 |
, torch.nn.Identity
|
125 |
+
, torch.nn.ReLU
|
|
|
126 |
]},
|
127 |
"block_wise_dropout": {'distribution': 'meta_choice', 'choice_values': [True, False]},
|
128 |
"sort_features": {'distribution': 'meta_choice', 'choice_values': [True, False]},
|
129 |
"in_clique": {'distribution': 'meta_choice', 'choice_values': [True, False]},
|
130 |
+
#'pre_sample_causes': {'distribution': 'meta_choice', 'choice_values': [True, False]},
|
131 |
}
|
132 |
|
133 |
return diff_causal
|
|
|
137 |
Returns the configuration parameters for a GP and MLP / Causal mixture.
|
138 |
"""
|
139 |
diff_prior_bag = {
|
140 |
+
'prior_bag_exp_weights_1': {'distribution': 'uniform', 'min': 2.0, 'max': 10.0},
|
141 |
# MLP Weight (Biased, since MLP works better, 1.0 is weight for prior number 0)
|
142 |
}
|
143 |
|
|
|
157 |
return config_diff
|
158 |
|
159 |
|
160 |
+
def get_prior_config(config_type):
|
161 |
+
if config_type == 'causal':
|
162 |
+
return get_prior_config_causal()
|
163 |
+
elif config_type == 'gp':
|
164 |
+
return get_prior_config_gp()
|
165 |
+
elif config_type == 'bnn':
|
166 |
+
return get_prior_config_bnn()
|
167 |
+
|
168 |
+
|
169 |
+
def get_prior_config_gp(max_features=100):
|
170 |
+
config_general = get_general_config(max_features, 50, eval_positions=[30])
|
171 |
+
config_general_real_world = {**config_general}
|
172 |
+
|
173 |
+
config_flexible_categorical = get_flexible_categorical_config(max_features)
|
174 |
+
config_flexible_categorical_real_world = {**config_flexible_categorical}
|
175 |
+
|
176 |
+
config_gp = {}
|
177 |
+
|
178 |
+
config_diff = get_diff_config()
|
179 |
+
|
180 |
+
config = {**config_general_real_world, **config_flexible_categorical_real_world, **config_diff, **config_gp}
|
181 |
+
|
182 |
+
config['differentiable_hyperparameters']['prior_bag_exp_weights_1'] = {'distribution': 'uniform', 'min': 0.0,
|
183 |
+
'max': .01} # Never select MLP
|
184 |
+
|
185 |
+
|
186 |
+
def get_prior_config_bnn(max_features=100):
|
187 |
+
config_general = get_general_config(max_features, 50, eval_positions=[30])
|
188 |
+
config_general_real_world = {**config_general}
|
189 |
+
|
190 |
+
config_flexible_categorical = get_flexible_categorical_config(max_features)
|
191 |
+
config_flexible_categorical_real_world = {**config_flexible_categorical}
|
192 |
+
|
193 |
+
config_gp = {}
|
194 |
+
config_mlp = {}
|
195 |
+
|
196 |
+
config_diff = get_diff_config()
|
197 |
+
|
198 |
+
config = {**config_general_real_world, **config_flexible_categorical_real_world, **config_diff, **config_gp,
|
199 |
+
**config_mlp}
|
200 |
+
|
201 |
+
config['differentiable_hyperparameters']['prior_bag_exp_weights_1'] = {'distribution': 'uniform',
|
202 |
+
'min': 1000.0,
|
203 |
+
'max': 1001.0} # Always select MLP
|
204 |
+
|
205 |
+
|
206 |
+
def get_prior_config_causal(max_features=100):
|
207 |
+
config_general = get_general_config(max_features, 50, eval_positions=[30])
|
208 |
+
config_general_real_world = {**config_general}
|
209 |
+
|
210 |
+
config_flexible_categorical = get_flexible_categorical_config(max_features)
|
211 |
+
config_flexible_categorical_real_world = {**config_flexible_categorical}
|
212 |
+
config_flexible_categorical_real_world[
|
213 |
+
'num_categorical_features_sampler_a'] = -1.0 # Categorical features disabled by default
|
214 |
+
|
215 |
+
config_gp = {}
|
216 |
+
config_mlp = {}
|
217 |
+
|
218 |
+
config_diff = get_diff_config()
|
219 |
+
|
220 |
+
config = {**config_general_real_world, **config_flexible_categorical_real_world, **config_diff, **config_gp,
|
221 |
+
**config_mlp}
|
222 |
+
|
223 |
+
return config
|
224 |
+
|
225 |
+
|
226 |
def sample_differentiable(config):
|
227 |
""""
|
228 |
Returns sampled hyperparameters from a differentiable wrapper, that is it makes a non-differentiable out of
|
TabPFN/scripts/tabular_baselines.py
CHANGED
@@ -1,19 +1,34 @@
|
|
|
|
1 |
from catboost import CatBoostClassifier, Pool
|
|
|
|
|
|
|
2 |
|
|
|
|
|
3 |
import math
|
4 |
-
|
|
|
|
|
|
|
|
|
5 |
from sklearn.impute import SimpleImputer
|
6 |
|
|
|
|
|
7 |
import xgboost as xgb
|
8 |
from sklearn import neighbors
|
9 |
from sklearn.gaussian_process import GaussianProcessClassifier
|
10 |
from sklearn.gaussian_process.kernels import RBF
|
11 |
import numpy as np
|
12 |
-
|
|
|
13 |
from scripts import tabular_metrics
|
14 |
import pandas as pd
|
|
|
|
|
15 |
|
16 |
-
from sklearn.linear_model import LogisticRegression
|
17 |
from sklearn.model_selection import cross_val_score
|
18 |
import time
|
19 |
|
@@ -37,18 +52,28 @@ def get_scoring_direction(metric_used):
|
|
37 |
else:
|
38 |
raise Exception('No scoring string found for metric')
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
def get_scoring_string(metric_used, multiclass=True, usage="sklearn_cv"):
|
41 |
if metric_used == tabular_metrics.auc_metric:
|
42 |
if usage == 'sklearn_cv':
|
43 |
return 'roc_auc_ovo'
|
44 |
elif usage == 'autogluon':
|
45 |
-
return 'log_loss' # Autogluon crashes when using 'roc_auc' with some datasets usning logloss gives better scores;
|
46 |
# We might be able to fix this, but doesn't work out of box.
|
47 |
# File bug report? Error happens with dataset robert and fabert
|
48 |
if multiclass:
|
49 |
return 'roc_auc_ovo_macro'
|
50 |
else:
|
51 |
return 'roc_auc'
|
|
|
|
|
52 |
elif usage == 'autosklearn':
|
53 |
if multiclass:
|
54 |
return autosklearn.metrics.log_loss # roc_auc only works for binary, use logloss instead
|
@@ -58,25 +83,72 @@ def get_scoring_string(metric_used, multiclass=True, usage="sklearn_cv"):
|
|
58 |
return 'MultiClass' # Effectively LogLoss, ROC not available
|
59 |
elif usage == 'xgb':
|
60 |
return 'logloss'
|
|
|
|
|
|
|
|
|
|
|
61 |
return 'roc_auc'
|
62 |
elif metric_used == tabular_metrics.cross_entropy:
|
63 |
if usage == 'sklearn_cv':
|
64 |
return 'neg_log_loss'
|
65 |
elif usage == 'autogluon':
|
66 |
return 'log_loss'
|
|
|
|
|
67 |
elif usage == 'autosklearn':
|
68 |
return autosklearn.metrics.log_loss
|
69 |
elif usage == 'catboost':
|
70 |
return 'MultiClass' # Effectively LogLoss
|
71 |
return 'logloss'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
else:
|
73 |
raise Exception('No scoring string found for metric')
|
74 |
|
75 |
def eval_f(params, clf_, x, y, metric_used, start_time, max_time):
|
76 |
if time.time() - start_time > max_time:
|
77 |
return np.nan
|
78 |
-
scores = cross_val_score(clf_(**params), x, y, cv=CV, scoring=get_scoring_string(metric_used))
|
79 |
-
|
|
|
80 |
return -np.nanmean(scores)
|
81 |
|
82 |
def preprocess_impute(x, y, test_x, test_y, impute, one_hot, standardize, cat_features=[]):
|
@@ -110,10 +182,26 @@ def preprocess_impute(x, y, test_x, test_y, impute, one_hot, standardize, cat_fe
|
|
110 |
x, test_x = scaler.transform(x), scaler.transform(test_x)
|
111 |
|
112 |
return x, y, test_x, test_y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
## Auto Gluon
|
|
|
115 |
def autogluon_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
116 |
-
from autogluon.tabular import TabularPredictor
|
117 |
x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y
|
118 |
, one_hot=False
|
119 |
, cat_features=cat_features
|
@@ -121,12 +209,15 @@ def autogluon_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=3
|
|
121 |
, standardize=False)
|
122 |
train_data = pd.DataFrame(np.concatenate([x, y[:, np.newaxis]], 1))
|
123 |
test_data = pd.DataFrame(np.concatenate([test_x, test_y[:, np.newaxis]], 1))
|
124 |
-
|
|
|
|
|
|
|
125 |
# AutoGluon automatically infers datatypes, we don't specify the categorical labels
|
126 |
predictor = TabularPredictor(
|
127 |
label=train_data.columns[-1],
|
128 |
eval_metric=get_scoring_string(metric_used, usage='autogluon', multiclass=(len(np.unique(y)) > 2)),
|
129 |
-
problem_type=
|
130 |
## seed=int(y[:].sum()) doesn't accept seed
|
131 |
).fit(
|
132 |
train_data=train_data,
|
@@ -135,19 +226,717 @@ def autogluon_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=3
|
|
135 |
# The seed is deterministic but varies for each dataset and each split of it
|
136 |
)
|
137 |
|
138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
|
140 |
metric = metric_used(test_y, pred)
|
141 |
|
142 |
return metric, pred, predictor.fit_summary()
|
143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
## AUTO Sklearn
|
145 |
def autosklearn_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
|
|
146 |
return autosklearn2_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=max_time, version=1)
|
147 |
|
148 |
-
from autosklearn.experimental.askl2 import AutoSklearn2Classifier
|
149 |
-
from autosklearn.classification import AutoSklearnClassifier
|
150 |
def autosklearn2_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300, version=2):
|
|
|
|
|
|
|
151 |
x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y
|
152 |
, one_hot=False
|
153 |
, cat_features=cat_features
|
@@ -163,7 +952,12 @@ def autosklearn2_metric(x, y, test_x, test_y, cat_features, metric_used, max_tim
|
|
163 |
x = make_pd_from_np(x)
|
164 |
test_x = make_pd_from_np(test_x)
|
165 |
|
166 |
-
|
|
|
|
|
|
|
|
|
|
|
167 |
clf = clf_(time_left_for_this_task=max_time,
|
168 |
memory_limit=4000,
|
169 |
n_jobs=MULTITHREAD,
|
@@ -174,17 +968,141 @@ def autosklearn2_metric(x, y, test_x, test_y, cat_features, metric_used, max_tim
|
|
174 |
# fit model to data
|
175 |
clf.fit(x, y)
|
176 |
|
177 |
-
|
|
|
|
|
|
|
178 |
metric = metric_used(test_y, pred)
|
179 |
|
180 |
return metric, pred, None
|
181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
param_grid_hyperopt['logistic'] = {
|
183 |
'penalty': hp.choice('penalty', ['l1', 'l2', 'none'])
|
184 |
-
, 'max_iter': hp.randint('max_iter',
|
185 |
, 'fit_intercept': hp.choice('fit_intercept', [True, False])
|
186 |
, 'C': hp.loguniform('C', -5, math.log(5.0))} # 'normalize': [False],
|
187 |
|
|
|
188 |
def logistic_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
189 |
x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y
|
190 |
, one_hot=True, impute=True, standardize=True
|
@@ -225,7 +1143,9 @@ def knn_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
|
225 |
cat_features=cat_features)
|
226 |
|
227 |
def clf_(**params):
|
228 |
-
|
|
|
|
|
229 |
|
230 |
start_time = time.time()
|
231 |
|
@@ -245,7 +1165,10 @@ def knn_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
|
245 |
clf = clf_(**best)
|
246 |
clf.fit(x, y)
|
247 |
|
248 |
-
|
|
|
|
|
|
|
249 |
metric = metric_used(test_y, pred)
|
250 |
|
251 |
return metric, pred, best
|
@@ -253,8 +1176,7 @@ def knn_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
|
253 |
## GP
|
254 |
param_grid_hyperopt['gp'] = {
|
255 |
'params_y_scale': hp.loguniform('params_y_scale', math.log(0.05), math.log(5.0)),
|
256 |
-
'params_length_scale': hp.loguniform('params_length_scale', math.log(0.1), math.log(1.0))
|
257 |
-
'n_jobs': hp.choice('njobs', [1])
|
258 |
}
|
259 |
def gp_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
260 |
x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y,
|
@@ -262,7 +1184,10 @@ def gp_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
|
262 |
cat_features=cat_features)
|
263 |
|
264 |
def clf_(params_y_scale,params_length_scale, **params):
|
265 |
-
|
|
|
|
|
|
|
266 |
|
267 |
start_time = time.time()
|
268 |
def stop(trial):
|
@@ -282,11 +1207,89 @@ def gp_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
|
282 |
clf = clf_(**best)
|
283 |
clf.fit(x, y)
|
284 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
285 |
pred = clf.predict_proba(test_x)
|
286 |
metric = metric_used(test_y, pred)
|
287 |
|
288 |
return metric, pred, best
|
289 |
|
|
|
|
|
290 |
|
291 |
# Catboost
|
292 |
# Hyperparameter space: https://arxiv.org/pdf/2106.03253.pdf
|
@@ -301,8 +1304,6 @@ param_grid_hyperopt['catboost'] = {
|
|
301 |
}
|
302 |
|
303 |
def catboost_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
304 |
-
print(x)
|
305 |
-
|
306 |
x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y
|
307 |
, one_hot=False
|
308 |
, cat_features=cat_features
|
@@ -323,14 +1324,24 @@ def catboost_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=30
|
|
323 |
test_x = make_pd_from_np(test_x)
|
324 |
|
325 |
def clf_(**params):
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
|
335 |
start_time = time.time()
|
336 |
def stop(trial):
|
@@ -348,8 +1359,10 @@ def catboost_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=30
|
|
348 |
|
349 |
clf = clf_(**best)
|
350 |
clf.fit(x, y)
|
351 |
-
|
352 |
-
|
|
|
|
|
353 |
metric = metric_used(test_y, pred)
|
354 |
|
355 |
return metric, pred, best
|
@@ -371,6 +1384,7 @@ param_grid_hyperopt['xgb'] = {
|
|
371 |
}
|
372 |
|
373 |
def xgb_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
|
|
374 |
# XGB Documentation:
|
375 |
# XGB handles categorical data appropriately without using One Hot Encoding, categorical features are experimetal
|
376 |
# XGB handles missing values appropriately without imputation
|
@@ -382,11 +1396,18 @@ def xgb_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
|
382 |
, standardize=False)
|
383 |
|
384 |
def clf_(**params):
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
390 |
|
391 |
start_time = time.time()
|
392 |
def stop(trial):
|
@@ -405,17 +1426,97 @@ def xgb_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
|
405 |
clf = clf_(**best)
|
406 |
clf.fit(x, y)
|
407 |
|
408 |
-
|
|
|
|
|
|
|
409 |
metric = metric_used(test_y, pred)
|
410 |
|
411 |
return metric, pred, best
|
412 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
413 |
|
414 |
clf_dict = {'gp': gp_metric
|
415 |
, 'knn': knn_metric
|
416 |
, 'catboost': catboost_metric
|
|
|
417 |
, 'xgb': xgb_metric
|
|
|
418 |
, 'logistic': logistic_metric
|
419 |
, 'autosklearn': autosklearn_metric
|
420 |
, 'autosklearn2': autosklearn2_metric
|
421 |
-
, 'autogluon': autogluon_metric
|
|
|
|
1 |
+
import pandas
|
2 |
from catboost import CatBoostClassifier, Pool
|
3 |
+
from sklearn.model_selection import GridSearchCV
|
4 |
+
from sklearn.model_selection import KFold
|
5 |
+
from sklearn.model_selection import ParameterGrid
|
6 |
|
7 |
+
import tempfile
|
8 |
+
import random
|
9 |
import math
|
10 |
+
import os
|
11 |
+
#from pytorch_tabnet.tab_model import TabNetClassifier, TabNetRegressor
|
12 |
+
from sklearn import preprocessing
|
13 |
+
from torch import nn
|
14 |
+
from sklearn.metrics import make_scorer
|
15 |
from sklearn.impute import SimpleImputer
|
16 |
|
17 |
+
|
18 |
+
from sklearn.base import BaseEstimator, ClassifierMixin
|
19 |
import xgboost as xgb
|
20 |
from sklearn import neighbors
|
21 |
from sklearn.gaussian_process import GaussianProcessClassifier
|
22 |
from sklearn.gaussian_process.kernels import RBF
|
23 |
import numpy as np
|
24 |
+
import torch
|
25 |
+
import itertools
|
26 |
from scripts import tabular_metrics
|
27 |
import pandas as pd
|
28 |
+
from tqdm import tqdm
|
29 |
+
from utils import remove_outliers
|
30 |
|
31 |
+
from sklearn.linear_model import LogisticRegression, Ridge
|
32 |
from sklearn.model_selection import cross_val_score
|
33 |
import time
|
34 |
|
|
|
52 |
else:
|
53 |
raise Exception('No scoring string found for metric')
|
54 |
|
55 |
+
def is_classification(metric_used):
    """Tell whether ``metric_used`` is one of the classification metrics.

    Args:
        metric_used: A metric callable from ``tabular_metrics``.

    Returns:
        The truthy string ``'classification'`` when the metric is
        ``tabular_metrics.auc_metric`` or ``tabular_metrics.cross_entropy``;
        ``None`` (falsy) for every other metric. Callers in this module only
        test the result for truthiness.
    """
    # The original had a second ``elif metric_used == tabular_metrics.auc_metric``
    # branch returning -1. It could never run (the first condition already
    # covers auc_metric) and -1 would have been truthy anyway, so it is dropped.
    if metric_used in (tabular_metrics.auc_metric, tabular_metrics.cross_entropy):
        return 'classification'
    return None
|
60 |
+
|
61 |
+
# Loss
|
62 |
+
|
63 |
def get_scoring_string(metric_used, multiclass=True, usage="sklearn_cv"):
|
64 |
if metric_used == tabular_metrics.auc_metric:
|
65 |
if usage == 'sklearn_cv':
|
66 |
return 'roc_auc_ovo'
|
67 |
elif usage == 'autogluon':
|
68 |
+
#return 'log_loss' # Autogluon crashes when using 'roc_auc' with some datasets usning logloss gives better scores;
|
69 |
# We might be able to fix this, but doesn't work out of box.
|
70 |
# File bug report? Error happens with dataset robert and fabert
|
71 |
if multiclass:
|
72 |
return 'roc_auc_ovo_macro'
|
73 |
else:
|
74 |
return 'roc_auc'
|
75 |
+
elif usage == 'tabnet':
|
76 |
+
return 'logloss' if multiclass else 'auc'
|
77 |
elif usage == 'autosklearn':
|
78 |
if multiclass:
|
79 |
return autosklearn.metrics.log_loss # roc_auc only works for binary, use logloss instead
|
|
|
83 |
return 'MultiClass' # Effectively LogLoss, ROC not available
|
84 |
elif usage == 'xgb':
|
85 |
return 'logloss'
|
86 |
+
elif usage == 'lightgbm':
|
87 |
+
if multiclass:
|
88 |
+
return 'auc'
|
89 |
+
else:
|
90 |
+
return 'binary'
|
91 |
return 'roc_auc'
|
92 |
elif metric_used == tabular_metrics.cross_entropy:
|
93 |
if usage == 'sklearn_cv':
|
94 |
return 'neg_log_loss'
|
95 |
elif usage == 'autogluon':
|
96 |
return 'log_loss'
|
97 |
+
elif usage == 'tabnet':
|
98 |
+
return 'logloss'
|
99 |
elif usage == 'autosklearn':
|
100 |
return autosklearn.metrics.log_loss
|
101 |
elif usage == 'catboost':
|
102 |
return 'MultiClass' # Effectively LogLoss
|
103 |
return 'logloss'
|
104 |
+
elif metric_used == tabular_metrics.r2_metric:
|
105 |
+
if usage == 'autosklearn':
|
106 |
+
return autosklearn.metrics.r2
|
107 |
+
elif usage == 'sklearn_cv':
|
108 |
+
return 'r2' # tabular_metrics.neg_r2
|
109 |
+
elif usage == 'autogluon':
|
110 |
+
return 'r2'
|
111 |
+
elif usage == 'xgb': # XGB cannot directly optimize r2
|
112 |
+
return 'rmse'
|
113 |
+
elif usage == 'catboost': # Catboost cannot directly optimize r2 ("Can't be used for optimization." - docu)
|
114 |
+
return 'RMSE'
|
115 |
+
else:
|
116 |
+
return 'r2'
|
117 |
+
elif metric_used == tabular_metrics.root_mean_squared_error_metric:
|
118 |
+
if usage == 'autosklearn':
|
119 |
+
return autosklearn.metrics.root_mean_squared_error
|
120 |
+
elif usage == 'sklearn_cv':
|
121 |
+
return 'neg_root_mean_squared_error' # tabular_metrics.neg_r2
|
122 |
+
elif usage == 'autogluon':
|
123 |
+
return 'rmse'
|
124 |
+
elif usage == 'xgb':
|
125 |
+
return 'rmse'
|
126 |
+
elif usage == 'catboost':
|
127 |
+
return 'RMSE'
|
128 |
+
else:
|
129 |
+
return 'neg_root_mean_squared_error'
|
130 |
+
elif metric_used == tabular_metrics.mean_absolute_error_metric:
|
131 |
+
if usage == 'autosklearn':
|
132 |
+
return autosklearn.metrics.mean_absolute_error
|
133 |
+
elif usage == 'sklearn_cv':
|
134 |
+
return 'neg_mean_absolute_error' # tabular_metrics.neg_r2
|
135 |
+
elif usage == 'autogluon':
|
136 |
+
return 'mae'
|
137 |
+
elif usage == 'xgb':
|
138 |
+
return 'mae'
|
139 |
+
elif usage == 'catboost':
|
140 |
+
return 'MAE'
|
141 |
+
else:
|
142 |
+
return 'neg_mean_absolute_error'
|
143 |
else:
|
144 |
raise Exception('No scoring string found for metric')
|
145 |
|
146 |
def eval_f(params, clf_, x, y, metric_used, start_time, max_time):
    """Cross-validate one hyperparameter configuration and return its objective value.

    Args:
        params: Keyword arguments forwarded to ``clf_`` to build the estimator.
        clf_: Factory returning an estimator when called as ``clf_(**params)``.
        x: Training features passed to ``cross_val_score``.
        y: Training targets passed to ``cross_val_score``.
        metric_used: Metric callable; mapped to an sklearn scoring string via
            ``get_scoring_string(..., usage='sklearn_cv')``.
        start_time: Timestamp (``time.time()`` scale) at which the search started.
        max_time: Search budget in seconds.

    Returns:
        ``np.nan`` when the budget is already exhausted; otherwise the
        NaN-ignoring mean of the cross-validation scores, negated for every
        scoring string except ``'r2'``.
    """
    # Budget check first so no model is fitted once the search has timed out.
    if time.time() - start_time > max_time:
        return np.nan

    # Resolve the scoring string once; it is needed for both the CV call and
    # the sign decision below.
    scoring = get_scoring_string(metric_used, usage='sklearn_cv')
    scores = cross_val_score(clf_(**params), x, y, cv=CV, scoring=scoring)

    # NOTE(review): 'r2' is returned un-negated while all other scores are
    # negated - confirm this sign is what the calling optimizer expects.
    if scoring == 'r2':
        return np.nanmean(scores)
    return -np.nanmean(scores)
|
153 |
|
154 |
def preprocess_impute(x, y, test_x, test_y, impute, one_hot, standardize, cat_features=[]):
|
|
|
182 |
x, test_x = scaler.transform(x), scaler.transform(test_x)
|
183 |
|
184 |
return x, y, test_x, test_y
|
185 |
+
import torch
|
186 |
+
import random
|
187 |
+
from tqdm import tqdm
|
188 |
+
def transformer_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
    """Fit a TabPFN classifier on the training split and score it on the test split.

    ``cat_features`` and ``max_time`` are accepted for signature parity with
    the other ``*_metric`` baselines but are not used here.

    Returns:
        Tuple ``(metric, pred, None)``: the value of ``metric_used(test_y, pred)``,
        the predicted class probabilities, and ``None`` in the slot where other
        baselines return their best hyperparameters / fit summary.
    """
    from scripts.transformer_prediction_interface import TabPFNClassifier

    tabpfn = TabPFNClassifier(
        device='cpu',
        base_path='.',
        model_string='',
    )
    tabpfn.fit(x, y)
    print('Train data shape', x.shape, ' Test data shape', test_x.shape)

    probabilities = tabpfn.predict_proba(test_x)
    score = metric_used(test_y, probabilities)
    return score, probabilities, None
|
200 |
|
201 |
## Auto Gluon
|
202 |
+
# WARNING: Crashes for some predictors for regression
|
203 |
def autogluon_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
204 |
+
from autogluon.tabular import TabularPredictor
|
205 |
x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y
|
206 |
, one_hot=False
|
207 |
, cat_features=cat_features
|
|
|
209 |
, standardize=False)
|
210 |
train_data = pd.DataFrame(np.concatenate([x, y[:, np.newaxis]], 1))
|
211 |
test_data = pd.DataFrame(np.concatenate([test_x, test_y[:, np.newaxis]], 1))
|
212 |
+
if is_classification(metric_used):
|
213 |
+
problem_type = 'multiclass' if len(np.unique(y)) > 2 else 'binary'
|
214 |
+
else:
|
215 |
+
problem_type = 'regression'
|
216 |
# AutoGluon automatically infers datatypes, we don't specify the categorical labels
|
217 |
predictor = TabularPredictor(
|
218 |
label=train_data.columns[-1],
|
219 |
eval_metric=get_scoring_string(metric_used, usage='autogluon', multiclass=(len(np.unique(y)) > 2)),
|
220 |
+
problem_type=problem_type
|
221 |
## seed=int(y[:].sum()) doesn't accept seed
|
222 |
).fit(
|
223 |
train_data=train_data,
|
|
|
226 |
# The seed is deterministic but varies for each dataset and each split of it
|
227 |
)
|
228 |
|
229 |
+
if is_classification(metric_used):
|
230 |
+
pred = predictor.predict_proba(test_data, as_multiclass=True).values
|
231 |
+
else:
|
232 |
+
pred = predictor.predict(test_data).values
|
233 |
+
|
234 |
+
metric = metric_used(test_y, pred)
|
235 |
+
|
236 |
+
return metric, pred, predictor.fit_summary()
|
237 |
+
|
238 |
+
|
239 |
+
from autogluon.core.models import AbstractModel
|
240 |
+
from scripts.transformer_prediction_interface import TabPFNClassifier
|
241 |
+
|
242 |
+
|
243 |
+
class TabPFNModel(AbstractModel):
    """AutoGluon custom model that delegates fitting and prediction to ``TabPFNClassifier``.

    The wrapped classifier is created in ``_fit`` with a fixed device, checkpoint
    location and ensemble size; data is passed through as NumPy arrays.
    """

    def __init__(self, **kwargs):
        # Nothing model-specific to set up; all kwargs go straight to AbstractModel.
        super().__init__(**kwargs)

    # `_preprocess` is called by `preprocess` and is used during model fit and model inference.
    def _preprocess(self, X: pd.DataFrame, is_train=False, **kwargs) -> pd.DataFrame:
        """Return ``X`` unchanged (still a DataFrame; conversion to NumPy happens at the call sites)."""
        return X

    def _fit(self,
             X: pd.DataFrame,  # training data
             y: pd.Series,  # training labels
             **kwargs):  # other AutoGluon inputs (validation data, time limit, ...) are ignored
        """Create the TabPFN classifier and fit it on the training data."""
        # NOTE(review): base_path / model_string are hard-coded to one specific
        # checkpoint location - confirm they exist in the target environment.
        self.model = TabPFNClassifier(device='cpu', base_path='/work/dlclarge1/hollmann-PFN_Tabular/',
                                      model_string='_longer_multiclass_causal_05_02_2022_12_49_44_sams',
                                      N_ensemble_configurations=10)
        self.model.fit(X.to_numpy(), y.to_numpy())

    def _predict_proba(self, X, **kwargs):
        """Return class probabilities for ``X`` in AutoGluon's unified format."""
        X = self.preprocess(X, **kwargs)

        # Only the classification path is implemented; there is no regression branch.
        y_pred_proba = self.model.predict_proba(X.to_numpy())
        return super()._convert_proba_to_unified_form(y_pred_proba)

    # The `_set_default_params` method defines the default hyperparameters of the model.
    # User-specified parameters will override these values on a key-by-key basis.
    def _set_default_params(self):
        """Register default hyperparameters (currently none)."""
        default_params = {
        }
        # Previously the dict was built and then dropped; registering each entry
        # ensures that defaults added here later actually take effect.
        for param, val in default_params.items():
            self._set_default_param_value(param, val)
|
282 |
+
from autogluon.tabular.configs.hyperparameter_configs import get_hyperparameter_config
|
283 |
+
def autogluon_tabpfn_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
    """Run AutoGluon restricted to the custom ``TabPFNModel`` and score it on the test split.

    The data goes through ``preprocess_impute`` with one-hot encoding,
    imputation and standardisation all switched off, is wrapped into DataFrames
    whose last column is the label, and is fitted with the ``best_quality``
    preset under a ``max_time`` second limit.

    Returns:
        Tuple ``(metric, pred, fit_summary)``: ``metric_used(test_y, pred)``,
        the predictions (class probabilities for classification metrics, point
        predictions otherwise) and ``predictor.fit_summary()``.
    """
    from autogluon.tabular import TabularPredictor

    x, y, test_x, test_y = preprocess_impute(
        x, y, test_x, test_y,
        one_hot=False,
        cat_features=cat_features,
        impute=False,
        standardize=False,
    )

    def _frame_with_label(features, labels):
        # Label is appended as the last column of the frame.
        return pd.DataFrame(np.concatenate([features, labels[:, np.newaxis]], 1))

    train_data = _frame_with_label(x, y)
    test_data = _frame_with_label(test_x, test_y)

    classification = is_classification(metric_used)
    more_than_two_classes = len(np.unique(y)) > 2
    if not classification:
        problem_type = 'regression'
    elif more_than_two_classes:
        problem_type = 'multiclass'
    else:
        problem_type = 'binary'

    # AutoGluon automatically infers datatypes, we don't specify the categorical labels
    custom_hyperparameters = {TabPFNModel: {}}

    predictor = TabularPredictor(
        label=train_data.columns[-1],
        eval_metric=get_scoring_string(metric_used, usage='autogluon', multiclass=more_than_two_classes),
        problem_type=problem_type,
    )
    predictor = predictor.fit(
        train_data=train_data,
        time_limit=max_time,
        presets=['best_quality'],
        hyperparameters=custom_hyperparameters,
    )

    if classification:
        pred = predictor.predict_proba(test_data, as_multiclass=True).values
    else:
        pred = predictor.predict(test_data).values

    metric = metric_used(test_y, pred)
    return metric, pred, predictor.fit_summary()
|
320 |
|
321 |
+
def get_updates_for_regularization_cocktails(
        categorical_indicator: np.ndarray,
        epochs: int = 105,
        device: str = 'cpu'):
    """
    Build the fixed AutoPyTorch configuration that replicates the
    regularization cocktail paper search space.

    Every hyperparameter listed here is pinned to a single value (its value
    range contains exactly the default value).

    Args:
        categorical_indicator (np.ndarray):
            An array that indicates whether a feature is categorical or not.
        epochs (int):
            Used for both 'min_epochs' and 'epochs' of the pipeline. Defaults
            to 105, the value this function previously hard-wired.
        device (str):
            Training device written to the pipeline update. Defaults to 'cpu'.

    Returns:
        pipeline_update, search_space_updates, include_updates
        (Tuple[dict, HyperparameterSearchSpaceUpdates, dict]):
            The pipeline updates (early stopping disabled, epochs, device).
            The search space updates pinning the individual hyperparameters.
            Lastly the include updates, restricting embedding and init components.
    """
    from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates

    # The previous implementation built a full argparse parser only to read the
    # default number of epochs; no other argument was used, so it is replaced
    # by the `epochs` parameter.

    include_updates = {
        'network_embedding': ['NoEmbedding'],
        'network_init': ['NoInit'],
    }

    has_cat_features = any(categorical_indicator)
    has_numerical_features = not all(categorical_indicator)

    search_space_updates = HyperparameterSearchSpaceUpdates()

    def _pin(entries):
        # Fix each (node, hyperparameter) to one value, in the given order.
        for node_name, hyperparameter, value in entries:
            search_space_updates.append(
                node_name=node_name,
                hyperparameter=hyperparameter,
                value_range=[value],
                default_value=value,
            )

    _pin([
        # architecture head
        ('network_head', '__choice__', 'no_head'),
        ('network_head', 'no_head:activation', 'relu'),
        # backbone architecture
        ('network_backbone', '__choice__', 'ShapedResNetBackbone'),
        ('network_backbone', 'ShapedResNetBackbone:resnet_shape', 'brick'),
        ('network_backbone', 'ShapedResNetBackbone:num_groups', 2),
        ('network_backbone', 'ShapedResNetBackbone:blocks_per_group', 2),
        ('network_backbone', 'ShapedResNetBackbone:output_dim', 512),
        ('network_backbone', 'ShapedResNetBackbone:max_units', 512),
        ('network_backbone', 'ShapedResNetBackbone:activation', 'relu'),
        ('network_backbone', 'ShapedResNetBackbone:shake_shake_update_func', 'even-even'),
        # training updates
        ('lr_scheduler', '__choice__', 'CosineAnnealingWarmRestarts'),
        ('lr_scheduler', 'CosineAnnealingWarmRestarts:n_restarts', 3),
        ('optimizer', '__choice__', 'AdamWOptimizer'),
        ('optimizer', 'AdamWOptimizer:lr', 1e-3),
        ('data_loader', 'batch_size', 128),
        # preprocessing
        ('feature_preprocessor', '__choice__', 'NoFeaturePreprocessor'),
    ])

    if has_numerical_features:
        print('has numerical features')
        _pin([
            ('imputer', 'numerical_strategy', 'median'),
            ('scaler', '__choice__', 'StandardScaler'),
        ])

    if has_cat_features:
        print('has cat features')
        _pin([
            ('imputer', 'categorical_strategy', 'constant_!missing!'),
            ('encoder', '__choice__', 'OneHotEncoder'),
        ])

    _pin([
        ('optimizer', 'AdamWOptimizer:beta1', 0.9),
        ('optimizer', 'AdamWOptimizer:beta2', 0.999),
    ])

    # No early stopping; min_epochs == epochs forces the full training length.
    pipeline_update = {
        'early_stopping': -1,
        'min_epochs': epochs,
        'epochs': epochs,
        "device": device,
    }

    return pipeline_update, search_space_updates, include_updates
|
664 |
+
|
665 |
+
def get_smac_object(
    scenario_dict,
    seed: int,
    ta,
    ta_kwargs,
    n_jobs: int,
    initial_budget: int,
    max_budget: int,
    dask_client,
):
    """
    Create the SMAC optimizer used to search over pipelines.

    Args:
        scenario_dict (typing.Dict[str, typing.Any]): constraints on how to run
            the jobs; wrapped in a ``Scenario``.
        seed (int): used both as the random seed and as the run id.
        ta (typing.Callable): the target algorithm runner handed to SMAC.
        ta_kwargs (typing.Dict[str, typing.Any]): arguments for ``ta``.
        n_jobs (int): number of cores to use for this task.
        initial_budget (int): accepted for interface compatibility; unused
            because multi-fidelity is disabled.
        max_budget (int): accepted for interface compatibility; unused
            because multi-fidelity is disabled.
        dask_client (dask.distributed.Client): user provided scheduler.

    Returns:
        (SMAC4AC): sequential model algorithm configuration object
    """
    from smac.facade.smac_ac_facade import SMAC4AC
    from smac.intensification.simple_intensifier import SimpleIntensifier
    from smac.runhistory.runhistory2epm import RunHistory2EPM4LogCost
    from smac.scenario.scenario import Scenario

    scenario = Scenario(scenario_dict)
    smac_arguments = dict(
        scenario=scenario,
        rng=seed,
        runhistory2epm=RunHistory2EPM4LogCost,
        tae_runner=ta,
        tae_runner_kwargs=ta_kwargs,
        initial_configurations=None,
        run_id=seed,
        intensifier=SimpleIntensifier,
        dask_client=dask_client,
        n_jobs=n_jobs,
    )
    return SMAC4AC(**smac_arguments)
|
713 |
+
|
714 |
+
|
715 |
+
def get_incumbent_results(
    run_history_file: str,
    search_space
):
    """
    Load a saved SMAC run history and return its lowest-cost entry.

    Args:
        run_history_file (str):
            Path of the run-history JSON written by the AutoPyTorch search.
        search_space (ConfigSpace.ConfigurationSpace):
            The configuration space the search was run with; required to
            load the history.

    Returns:
        config, incumbent_run_value:
            The configuration of the run with the smallest ``cost`` and the
            run-history value record of that run.
    """
    from smac.runhistory.runhistory import RunHistory

    history = RunHistory()
    history.load_json(run_history_file, search_space)

    def _cost(entry):
        _run_key, run_value = entry
        return run_value.cost

    # Cheapest run first; its key gives the config id to look up.
    ranked = sorted(history.data.items(), key=_cost)
    best_key, best_value = ranked[0]
    best_config = history.ids_config[best_key.config_id]
    return best_config, best_value
|
745 |
+
|
746 |
+
|
747 |
+
def well_tuned_simple_nets_metric(X_train, y_train, X_test, y_test, categorical_indicator, metric_used, max_time=300, nr_workers=1):
    """Run the Auto-PyTorch "regularization cocktails" baseline and score it.

    The function searches hyperparameters with Auto-PyTorch (SMAC), reads the
    incumbent configuration from the SMAC run history, refits that
    configuration for a fixed number of epochs and evaluates it on the test
    split.

    Install notes (Auto-PyTorch, ``regularization_cocktails`` branch)::

        git clone https://github.com/automl/Auto-PyTorch.git
        cd Auto-PyTorch
        git checkout regularization_cocktails
        # listed on the project page, reportedly optional:
        #   conda install gxx_linux-64 gcc_linux-64 swig
        conda create --clone CONDANAME --name CLONENAME
        conda activate CLONENAME
        pip install -r requirements.txt   # a cloned env is recommended
        pip install -e .

    Args:
        X_train, y_train, X_test, y_test: objects exposing ``.cpu().numpy()``
            (presumably torch tensors); labels are cast with ``.long()``.
        categorical_indicator: collection of column indices that are
            categorical; converted to a boolean mask over all columns.
        metric_used: callable ``metric_used(y_true, predictions)``.
        max_time: total wall-clock budget (seconds) for the search.
        nr_workers: number of parallel workers passed to Auto-PyTorch.

    Returns:
        Tuple ``(metric, test_predictions, None)``.

    Raises:
        RuntimeError: if Auto-PyTorch returns no fitted pipeline.
    """
    # Turn the collection of categorical column indices into a per-column mask.
    categorical_indicator = np.array([i in categorical_indicator for i in range(X_train.shape[1])])
    with tempfile.TemporaryDirectory(prefix=f"{len(X_train)}_{len(X_test)}_{max_time}") as temp_dir:
        # Local imports: Auto-PyTorch is only required when this baseline runs.
        from autoPyTorch.api.tabular_classification import TabularClassificationTask
        from autoPyTorch.datasets.resampling_strategy import HoldoutValTypes, NoResamplingStrategyTypes
        from autoPyTorch.data.tabular_validator import TabularInputValidator
        from autoPyTorch.datasets.tabular_dataset import TabularDataset

        # append random folder to temp_dir to avoid collisions
        rand_int = str(random.randint(1, 1000))
        temp_dir = os.path.join(temp_dir, 'temp_' + rand_int)
        out_dir = os.path.join(temp_dir, 'out_' + rand_int)

        start_time = time.time()

        X_train, y_train, X_test, y_test = X_train.cpu().numpy(), y_train.cpu().long().numpy(), X_test.cpu().numpy(), y_test.cpu().long().numpy()

        def safe_int(x):
            # Categorical columns must hold integer codes (or contain NaNs).
            assert np.all(x.astype('int64') == x) or np.any(x != x), np.unique(x)  # second condition for ignoring nans
            return pd.Series(x, dtype='category')

        X_train = pd.DataFrame({i: safe_int(X_train[:, i]) if c else X_train[:, i] for i, c in enumerate(categorical_indicator)})
        X_test = pd.DataFrame({i: safe_int(X_test[:, i]) if c else X_test[:, i] for i, c in enumerate(categorical_indicator)})

        # NOTE: the labels were cast with ``.long()`` above, so they are integer
        # arrays; the former ``isinstance(y[1], bool)`` re-casts could never
        # trigger (and indexed element 1 of possibly length-1 arrays).

        number_of_configurations_limit = 840  # hard coded in the paper
        epochs = 105
        func_eval_time = min(1000, max_time / 2)
        # Deterministic seed that still varies per dataset / split.
        seed = int(y_train[:].sum())

        resampling_strategy_args = {
            'val_share': len(y_test) / (len(y_test) + len(y_train)),
        }

        pipeline_update, search_space_updates, include_updates = get_updates_for_regularization_cocktails(
            categorical_indicator,
        )
        print(search_space_updates)

        ############################################################################
        # Build and fit a classifier
        # ==========================
        # if we use HPO, we can use multiple workers in parallel
        if number_of_configurations_limit == 0:
            nr_workers = 1

        api = TabularClassificationTask(
            temporary_directory=temp_dir,
            output_directory=out_dir,
            delete_tmp_folder_after_terminate=False,
            delete_output_folder_after_terminate=False,
            resampling_strategy=HoldoutValTypes.stratified_holdout_validation,
            resampling_strategy_args=resampling_strategy_args,
            ensemble_size=1,
            ensemble_nbest=1,
            max_models_on_disc=10,
            include_components=include_updates,
            search_space_updates=search_space_updates,
            seed=seed,
            n_jobs=nr_workers,
            n_threads=1,
        )

        api.set_pipeline_config(**pipeline_update)
        ############################################################################
        # Search for the best hp configuration
        # ====================================
        # We search for the best hp configuration only in the case of a cocktail ingredient
        # that has hyperparameters.
        print(X_train, X_test)
        print('temp_dir', temp_dir)
        print(max_time, min(func_eval_time, max_time, number_of_configurations_limit))

        if number_of_configurations_limit != 0:
            api.search(
                X_train=X_train.copy(),
                y_train=y_train.copy(),
                X_test=X_test.copy(),
                y_test=y_test.copy(),
                optimize_metric='balanced_accuracy',
                total_walltime_limit=max_time,
                memory_limit=12000,
                func_eval_time_limit_secs=min(func_eval_time, max_time),
                enable_traditional_pipeline=False,
                get_smac_object_callback=get_smac_object,
                smac_scenario_args={
                    'runcount_limit': number_of_configurations_limit,
                },
            )

        ############################################################################
        # Refit on the best hp configuration
        # ==================================
        input_validator = TabularInputValidator(
            is_classification=True,
        )
        input_validator.fit(
            X_train=X_train.copy(),
            y_train=y_train.copy(),
            X_test=X_test.copy(),
            y_test=y_test.copy(),
        )

        dataset = TabularDataset(
            X=X_train,
            Y=y_train,
            X_test=X_test,
            Y_test=y_test,
            seed=seed,
            validator=input_validator,
            resampling_strategy=NoResamplingStrategyTypes.no_resampling,
        )
        dataset.is_small_preprocess = False
        print(f"Fitting pipeline with {epochs} epochs")

        search_space = api.get_search_space(dataset)
        # only when we perform hpo will there be an incumbent configuration
        # otherwise take a default configuration.
        if number_of_configurations_limit != 0:
            configuration, incumbent_run_value = get_incumbent_results(
                os.path.join(
                    temp_dir,
                    'smac3-output',
                    'run_{}'.format(seed),
                    'runhistory.json'),
                search_space,
            )
            print(f"Incumbent configuration: {configuration}")
            print(f"Incumbent trajectory: {api.trajectory}")
        else:
            # default configuration
            configuration = search_space.get_default_configuration()
            print(f"Default configuration: {configuration}")

        fitted_pipeline, run_info, run_value, dataset = api.fit_pipeline(
            configuration=configuration,
            budget_type='epochs',
            budget=epochs,
            dataset=dataset,
            run_time_limit_secs=func_eval_time,
            eval_metric='balanced_accuracy',
            memory_limit=12000,
        )

        # Use the tensors as transformed by the dataset / validator.
        X_test = dataset.test_tensors[0]
        y_test = dataset.test_tensors[1]

        if fitted_pipeline is None:
            # Previously this branch was a no-op and the next line crashed with
            # an opaque AttributeError on ``None``; fail with context instead.
            raise RuntimeError(
                f"Auto-PyTorch did not return a fitted pipeline (run_value: {run_value})")

        test_predictions = fitted_pipeline.predict(X_test)

        metric = metric_used(y_test, test_predictions.squeeze())
        duration = time.time() - start_time

        print(f'Time taken: {duration} for {metric} metric')
        print(test_predictions[:10])
        return metric, test_predictions, None
|
928 |
+
|
929 |
+
|
930 |
+
|
931 |
## AUTO Sklearn
|
932 |
def autosklearn_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
    """Run the auto-sklearn (version 1) baseline.

    Thin wrapper that delegates to ``autosklearn2_metric`` with ``version=1``
    and returns whatever that function returns.
    """
    # Imported up front; raises ImportError if auto-sklearn is unavailable.
    import autosklearn.classification
    outcome = autosklearn2_metric(
        x, y, test_x, test_y, cat_features, metric_used,
        max_time=max_time, version=1,
    )
    return outcome
|
935 |
|
|
|
|
|
936 |
def autosklearn2_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300, version=2):
|
937 |
+
from autosklearn.experimental.askl2 import AutoSklearn2Classifier
|
938 |
+
from autosklearn.classification import AutoSklearnClassifier
|
939 |
+
from autosklearn.regression import AutoSklearnRegressor
|
940 |
x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y
|
941 |
, one_hot=False
|
942 |
, cat_features=cat_features
|
|
|
952 |
x = make_pd_from_np(x)
|
953 |
test_x = make_pd_from_np(test_x)
|
954 |
|
955 |
+
if is_classification(metric_used):
|
956 |
+
clf_ = AutoSklearn2Classifier if version == 2 else AutoSklearnClassifier
|
957 |
+
else:
|
958 |
+
if version == 2:
|
959 |
+
raise Exception("AutoSklearn 2 doesn't do regression.")
|
960 |
+
clf_ = AutoSklearnRegressor
|
961 |
clf = clf_(time_left_for_this_task=max_time,
|
962 |
memory_limit=4000,
|
963 |
n_jobs=MULTITHREAD,
|
|
|
968 |
# fit model to data
|
969 |
clf.fit(x, y)
|
970 |
|
971 |
+
if is_classification(metric_used):
|
972 |
+
pred = clf.predict_proba(test_x)
|
973 |
+
else:
|
974 |
+
pred = clf.predict(test_x)
|
975 |
metric = metric_used(test_y, pred)
|
976 |
|
977 |
return metric, pred, None
|
978 |
|
979 |
+
# Hyperopt search space for the Ridge regression baseline.
param_grid_hyperopt['ridge'] = {
    'max_iter': hp.randint('max_iter', 50, 500),
    'fit_intercept': hp.choice('fit_intercept', [True, False]),
    'alpha': hp.loguniform('alpha', -5, math.log(5.0)),
}
|
983 |
+
|
984 |
+
def ridge_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
    """Tune and evaluate a Ridge regressor with random search (hyperopt).

    Only valid for regression metrics. Data is one-hot encoded, imputed and
    standardized by ``preprocess_impute`` before the search.

    Returns:
        Tuple ``(metric, pred, best)`` where ``best`` is the selected
        hyperparameter dict.

    Raises:
        Exception: if ``metric_used`` is a classification metric.
    """
    if is_classification(metric_used):
        raise Exception("Ridge is only applicable to pointwise Regression.")

    x, y, test_x, test_y = preprocess_impute(
        x, y, test_x, test_y,
        one_hot=True, impute=True, standardize=True,
        cat_features=cat_features)

    def make_model(**hyperparams):
        return Ridge(tol=1e-4, **hyperparams)

    search_start = time.time()

    def out_of_time(trials):
        # hyperopt early-stop callback: (stop?, extra args for next call)
        elapsed = time.time() - search_start
        return elapsed > max_time, []

    def objective(hyperparams):
        return eval_f(hyperparams, make_model, x, y, metric_used, search_start, max_time)

    space = param_grid_hyperopt['ridge']
    # The seed is deterministic but varies for each dataset and each split of it
    rng = np.random.RandomState(int(y[:].sum()))
    raw_best = fmin(
        fn=objective,
        space=space,
        algo=rand.suggest,
        rstate=rng,
        early_stop_fn=out_of_time,
        max_evals=10000)
    best = space_eval(space, raw_best)

    # Refit on the full training data with the selected configuration.
    model = make_model(**best)
    model.fit(x, y)

    pred = model.predict(test_x)
    return metric_used(test_y, pred), pred, best
|
1016 |
+
|
1017 |
+
from lightautoml.automl.presets.tabular_presets import TabularAutoML, TabularUtilizedAutoML
|
1018 |
+
from lightautoml.tasks import Task
|
1019 |
+
|
1020 |
+
def lightautoml_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
    """Fit LightAutoML (multiclass task) and score its test probabilities.

    The training frame stores the features in string-named columns
    ``"0" .. "n-1"`` and the target in column ``"n"``. Predicted probability
    columns are re-ordered to the original class ids when the reader exposes
    a class mapping.

    Returns:
        Tuple ``(metric, probabilities_mapped, None)``.
    """
    x, y, test_x, test_y = preprocess_impute(
        x, y, test_x, test_y,
        one_hot=False, impute=False, standardize=False,
        cat_features=cat_features)

    n_features = x.shape[-1]
    feature_names = [str(k) for k in range(0, n_features)]
    target_name = str(n_features)

    roles = {'target': target_name}
    task = Task('multiclass', metric=lambda x, y: metric_used(x, y, numpy=True))
    automl = TabularUtilizedAutoML(
        task=task,
        timeout=max_time,
        cpu_limit=4,  # Optimal for Kaggle kernels
        general_params={'use_algos': [['linear_l2', 'lgb', 'lgb_tuned']]})

    # Features and target side by side; the last column is the target.
    train_matrix = np.concatenate([x, np.expand_dims(y, -1)], -1)
    train_frame = pd.DataFrame(train_matrix, columns=feature_names + [target_name])
    automl.fit_predict(train_frame, roles=roles)
    test_frame = pd.DataFrame(test_x, columns=feature_names)

    probabilities = automl.predict(test_frame).data
    probabilities_mapped = probabilities.copy()

    class_map = automl.outer_pipes[0].ml_algos[0].models[0][0].reader.class_mapping
    if class_map:
        # Invert class -> output column so each column lands at its class id.
        column_to_class = {col: class_ for class_, col in class_map.items()}
        for col in range(len(column_to_class)):
            probabilities_mapped[:, int(column_to_class[col])] = probabilities[:, col]

    metric = metric_used(test_y, probabilities_mapped)
    return metric, probabilities_mapped, None
|
1050 |
+
|
1051 |
+
# Hyperopt search space for the LightGBM baseline.
param_grid_hyperopt['lightgbm'] = {
    'num_leaves': hp.randint('num_leaves', 5, 50),
    'max_depth': hp.randint('max_depth', 3, 20),
    'learning_rate': hp.loguniform('learning_rate', -3, math.log(1.0)),
    'n_estimators': hp.randint('n_estimators', 50, 2000),
    'min_child_weight': hp.choice('min_child_weight', [1e-5, 1e-3, 1e-2, 1e-1, 1, 1e1, 1e2, 1e3, 1e4]),
    'subsample': hp.uniform('subsample', 0.2, 0.8),
    'colsample_bytree': hp.uniform('colsample_bytree', 0.2, 0.8),
    'reg_alpha': hp.choice('reg_alpha', [0, 1e-1, 1, 2, 5, 7, 10, 50, 100]),
    'reg_lambda': hp.choice('reg_lambda', [0, 1e-1, 1, 5, 10, 20, 50, 100]),
}
|
1064 |
+
|
1065 |
+
from lightgbm import LGBMClassifier
|
1066 |
+
|
1067 |
+
def lightgbm_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
    """Tune and evaluate an ``LGBMClassifier`` with random search (hyperopt).

    No one-hot encoding, imputation or standardization is requested from
    ``preprocess_impute``; categorical columns are passed to LightGBM through
    ``categorical_feature``.

    Returns:
        Tuple ``(metric, pred, best)`` where ``pred`` holds the predicted
        class probabilities on the test split.
    """
    x, y, test_x, test_y = preprocess_impute(
        x, y, test_x, test_y,
        one_hot=False, impute=False, standardize=False,
        cat_features=cat_features)

    def make_model(**hyperparams):
        is_multiclass = len(np.unique(y)) > 2
        objective_name = get_scoring_string(metric_used, usage='lightgbm', multiclass=is_multiclass)
        return LGBMClassifier(categorical_feature=cat_features, use_missing=True,
                              objective=objective_name, **hyperparams)

    search_start = time.time()

    def out_of_time(trials):
        # hyperopt early-stop callback: (stop?, extra args for next call)
        elapsed = time.time() - search_start
        return elapsed > max_time, []

    def objective(hyperparams):
        return eval_f(hyperparams, make_model, x, y, metric_used, search_start, max_time)

    space = param_grid_hyperopt['lightgbm']
    # The seed is deterministic but varies for each dataset and each split of it
    rng = np.random.RandomState(int(y[:].sum()))
    raw_best = fmin(
        fn=objective,
        space=space,
        algo=rand.suggest,
        rstate=rng,
        early_stop_fn=out_of_time,
        max_evals=10000)
    best = space_eval(space, raw_best)

    # Refit on the full training data with the selected configuration.
    model = make_model(**best)
    model.fit(x, y)

    pred = model.predict_proba(test_x)
    return metric_used(test_y, pred), pred, best
|
1098 |
+
|
1099 |
param_grid_hyperopt['logistic'] = {
|
1100 |
'penalty': hp.choice('penalty', ['l1', 'l2', 'none'])
|
1101 |
+
, 'max_iter': hp.randint('max_iter', 50, 500)
|
1102 |
, 'fit_intercept': hp.choice('fit_intercept', [True, False])
|
1103 |
, 'C': hp.loguniform('C', -5, math.log(5.0))} # 'normalize': [False],
|
1104 |
|
1105 |
+
|
1106 |
def logistic_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
1107 |
x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y
|
1108 |
, one_hot=True, impute=True, standardize=True
|
|
|
1143 |
cat_features=cat_features)
|
1144 |
|
1145 |
def clf_(**params):
|
1146 |
+
if is_classification(metric_used):
|
1147 |
+
return neighbors.KNeighborsClassifier(n_jobs=1, **params)
|
1148 |
+
return neighbors.KNeighborsRegressor(n_jobs=1, **params)
|
1149 |
|
1150 |
start_time = time.time()
|
1151 |
|
|
|
1165 |
clf = clf_(**best)
|
1166 |
clf.fit(x, y)
|
1167 |
|
1168 |
+
if is_classification(metric_used):
|
1169 |
+
pred = clf.predict_proba(test_x)
|
1170 |
+
else:
|
1171 |
+
pred = clf.predict(test_x)
|
1172 |
metric = metric_used(test_y, pred)
|
1173 |
|
1174 |
return metric, pred, best
|
|
|
1176 |
## GP
|
1177 |
param_grid_hyperopt['gp'] = {
|
1178 |
'params_y_scale': hp.loguniform('params_y_scale', math.log(0.05), math.log(5.0)),
|
1179 |
+
'params_length_scale': hp.loguniform('params_length_scale', math.log(0.1), math.log(1.0))
|
|
|
1180 |
}
|
1181 |
def gp_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
1182 |
x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y,
|
|
|
1184 |
cat_features=cat_features)
|
1185 |
|
1186 |
def clf_(params_y_scale,params_length_scale, **params):
|
1187 |
+
if is_classification(metric_used):
|
1188 |
+
return GaussianProcessClassifier(kernel= params_y_scale * RBF(params_length_scale), **params)
|
1189 |
+
else:
|
1190 |
+
return GaussianProcessRegressor(kernel= params_y_scale * RBF(params_length_scale), **params)
|
1191 |
|
1192 |
start_time = time.time()
|
1193 |
def stop(trial):
|
|
|
1207 |
clf = clf_(**best)
|
1208 |
clf.fit(x, y)
|
1209 |
|
1210 |
+
if is_classification(metric_used):
|
1211 |
+
pred = clf.predict_proba(test_x)
|
1212 |
+
else:
|
1213 |
+
pred = clf.predict(test_x)
|
1214 |
+
metric = metric_used(test_y, pred)
|
1215 |
+
|
1216 |
+
return metric, pred, best
|
1217 |
+
|
1218 |
+
## Tabnet
|
1219 |
+
# https://github.com/dreamquark-ai/tabnet
|
1220 |
+
#param_grid['tabnet'] = {'n_d': [2, 4], 'n_steps': [2,4,6], 'gamma': [1.3], 'optimizer_params': [{'lr': 2e-2}, {'lr': 2e-1}]}
|
1221 |
+
|
1222 |
+
# Hyperparameter space from dreamquarks implementation recommendations
|
1223 |
+
# Hyperopt search space for TabNet. Note that the 'gamma' entry is sampled
# under the hyperopt label 'relax'.
param_grid_hyperopt['tabnet'] = dict(
    n_d=hp.randint('n_d', 8, 64),
    n_steps=hp.randint('n_steps', 3, 10),
    max_epochs=hp.randint('max_epochs', 50, 200),
    gamma=hp.uniform('relax', 1.0, 2.0),
    momentum=hp.uniform('momentum', 0.01, 0.4),
)
|
1230 |
+
|
1231 |
+
def tabnet_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
    """Tune and evaluate a ``TabNetClassifier`` with random search (hyperopt).

    Each sampled configuration is scored by shuffled K-fold cross-validation
    on the training data; the best one is refit on all training data and
    evaluated on the test split.

    Returns:
        Tuple ``(metric, pred, best)``. ``pred`` holds test class
        probabilities; ``best`` is the chosen configuration with
        ``max_epochs`` removed (it is passed to ``fit`` instead).
    """
    from pytorch_tabnet.tab_model import TabNetClassifier
    # TabNet inputs raw tabular data without any preprocessing and is trained using gradient descent-based optimisation.
    # However Tabnet cannot handle nans so we impute with mean

    x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y, impute=True, one_hot=False, standardize=False)

    def clf_(**params):
        # n_a is tied to n_d; the seed is deterministic per dataset/split.
        return TabNetClassifier(cat_idxs=cat_features, verbose=True, n_a=params['n_d'], seed=int(y[:].sum()), **params)

    def tabnet_eval_f(params, clf_, x, y, metric_used, start_time, max_time):
        """Return the negated mean CV metric for one configuration (NaN when out of time)."""
        if time.time() - start_time > max_time:
            return np.nan

        kf = KFold(n_splits=min(CV, x.shape[0] // 2), random_state=None, shuffle=True)
        metrics = []

        # Work on a copy: max_epochs is a fit() argument, not a constructor one.
        params = {**params}
        max_epochs = params.pop('max_epochs')

        for train_index, test_index in kf.split(x):
            X_train, X_valid, y_train, y_valid = x[train_index], x[test_index], y[train_index], y[test_index]

            clf = clf_(**params)

            clf.fit(
                X_train, y_train,
                # eval_metric=[get_scoring_string(metric_used, multiclass=len(np.unique(y_train)) > 2, usage='tabnet')],
                # eval_set=[(X_valid, y_valid)],
                # patience=15,
                max_epochs=max_epochs
            )
            metrics.append(metric_used(y_valid, clf.predict_proba(X_valid)))

        # hyperopt minimizes, so negate the (higher-is-better) metric.
        return -np.nanmean(np.array(metrics))

    start_time = time.time()

    def stop(trial):
        return time.time() - start_time > max_time, []

    best = fmin(
        fn=lambda params: tabnet_eval_f(params, clf_, x, y, metric_used, start_time, max_time),
        space=param_grid_hyperopt['tabnet'],
        algo=rand.suggest,
        rstate=np.random.RandomState(int(y[:].sum())),
        early_stop_fn=stop,
        max_evals=1000)
    best = space_eval(param_grid_hyperopt['tabnet'], best)
    max_epochs = best['max_epochs']
    del best['max_epochs']

    clf = clf_(**best)
    clf.fit(x, y, max_epochs=max_epochs)

    pred = clf.predict_proba(test_x)
    metric = metric_used(test_y, pred)

    # The former trailing ``return metric, pred, params_used[best_idx]`` was
    # unreachable and referenced undefined names; it has been removed.
    return metric, pred, best
|
1292 |
+
|
1293 |
|
1294 |
# Catboost
|
1295 |
# Hyperparameter space: https://arxiv.org/pdf/2106.03253.pdf
|
|
|
1304 |
}
|
1305 |
|
1306 |
def catboost_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
|
|
|
|
1307 |
x, y, test_x, test_y = preprocess_impute(x, y, test_x, test_y
|
1308 |
, one_hot=False
|
1309 |
, cat_features=cat_features
|
|
|
1324 |
test_x = make_pd_from_np(test_x)
|
1325 |
|
1326 |
def clf_(**params):
|
1327 |
+
if is_classification(metric_used):
|
1328 |
+
return CatBoostClassifier(
|
1329 |
+
loss_function=get_scoring_string(metric_used, usage='catboost'),
|
1330 |
+
thread_count = MULTITHREAD,
|
1331 |
+
used_ram_limit='4gb',
|
1332 |
+
random_seed=int(y[:].sum()),
|
1333 |
+
logging_level='Silent',
|
1334 |
+
cat_features=cat_features,
|
1335 |
+
**params)
|
1336 |
+
else:
|
1337 |
+
return CatBoostRegressor(
|
1338 |
+
loss_function=get_scoring_string(metric_used, usage='catboost'),
|
1339 |
+
thread_count=MULTITHREAD,
|
1340 |
+
used_ram_limit='4gb',
|
1341 |
+
random_seed=int(y[:].sum()),
|
1342 |
+
logging_level='Silent',
|
1343 |
+
cat_features=cat_features,
|
1344 |
+
**params)
|
1345 |
|
1346 |
start_time = time.time()
|
1347 |
def stop(trial):
|
|
|
1359 |
|
1360 |
clf = clf_(**best)
|
1361 |
clf.fit(x, y)
|
1362 |
+
if is_classification(metric_used):
|
1363 |
+
pred = clf.predict_proba(test_x)
|
1364 |
+
else:
|
1365 |
+
pred = clf.predict(test_x)
|
1366 |
metric = metric_used(test_y, pred)
|
1367 |
|
1368 |
return metric, pred, best
|
|
|
1384 |
}
|
1385 |
|
1386 |
def xgb_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
|
1387 |
+
import xgboost as xgb
|
1388 |
# XGB Documentation:
|
1389 |
# XGB handles categorical data appropriately without using One Hot Encoding; categorical features are experimental
|
1390 |
# XGB handles missing values appropriately without imputation
|
|
|
1396 |
, standardize=False)
|
1397 |
|
1398 |
def clf_(**params):
|
1399 |
+
if is_classification(metric_used):
|
1400 |
+
return xgb.XGBClassifier(use_label_encoder=False
|
1401 |
+
, nthread=1
|
1402 |
+
, **params
|
1403 |
+
, eval_metric=get_scoring_string(metric_used, usage='xgb') # AUC not implemented
|
1404 |
+
)
|
1405 |
+
else:
|
1406 |
+
return xgb.XGBRegressor(use_label_encoder=False
|
1407 |
+
, nthread=1
|
1408 |
+
, **params
|
1409 |
+
, eval_metric=get_scoring_string(metric_used, usage='xgb') # AUC not implemented
|
1410 |
+
)
|
1411 |
|
1412 |
start_time = time.time()
|
1413 |
def stop(trial):
|
|
|
1426 |
clf = clf_(**best)
|
1427 |
clf.fit(x, y)
|
1428 |
|
1429 |
+
if is_classification(metric_used):
|
1430 |
+
pred = clf.predict_proba(test_x)
|
1431 |
+
else:
|
1432 |
+
pred = clf.predict(test_x)
|
1433 |
metric = metric_used(test_y, pred)
|
1434 |
|
1435 |
return metric, pred, best
|
1436 |
|
1437 |
+
"""
|
1438 |
+
LEGACY UNUSED
|
1439 |
+
"""
|
1440 |
+
|
1441 |
+
## Ridge
|
1442 |
+
from sklearn.linear_model import RidgeClassifier
|
1443 |
+
# Grid for the legacy RidgeClassifier baseline below.
param_grid['ridge'] = {'alpha': [0, 0.1, .5, 1.0, 2.0], 'fit_intercept': [True, False]}  # 'normalize': [False],
def ridge_metric(x, y, test_x, test_y, cat_features, metric_used, max_time=300):
    """Legacy grid-search ``RidgeClassifier`` baseline.

    NOTE(review): this definition has the same name as the hyperopt-based
    ``ridge_metric`` defined earlier in this module and therefore replaces it
    at import time; confirm which of the two ``clf_dict['ridge']`` should use.

    Args:
        x, y, test_x, test_y: torch tensors (moved to CPU here); NaNs in the
            features are replaced via ``torch.nan_to_num``.
        cat_features: unused.
        metric_used: callable ``metric_used(y_true, scores)``.
        max_time: accepted for signature parity with the other baselines;
            the grid search ignores it.

    Returns:
        Tuple ``(metric, pred)`` where ``pred`` is the decision function on
        the test split.
    """
    import warnings

    x, y, test_x, test_y = x.cpu(), y.cpu(), test_x.cpu(), test_y.cpu()
    x, test_x = torch.nan_to_num(x), torch.nan_to_num(test_x)

    # RidgeClassifier has no ``n_jobs`` parameter; passing it raised TypeError.
    clf = RidgeClassifier()

    # exhaustive search over the alpha / fit_intercept grid
    clf = GridSearchCV(clf, param_grid['ridge'], cv=min(CV, x.shape[0]//2)
                       , scoring=get_scoring_string(metric_used)
                       , n_jobs=MULTITHREAD)

    # Silence warnings only for the duration of the fit instead of
    # permanently replacing ``warnings.warn`` for the whole process.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # fit model to data
        clf.fit(x, y.long())

        pred = clf.decision_function(test_x)
    metric = metric_used(test_y, pred)

    return metric, pred
|
1468 |
+
|
1469 |
+
def mlp_acc(x, y, test_x, test_y, hyperparameters):
    """Train a two-output MLP with full-batch AdamW and return test accuracy.

    ``hyperparameters`` is an 8-tuple:
    ``(num_layers, hidden_dim, activation_module, fixed_dropout_prob,
    is_binary_classification, epochs, lr, weight_decay)``.

    The network ends in a 2-unit linear layer; a softmax is appended only
    when ``is_binary_classification`` is truthy. Training minimizes
    ``BCELoss`` between output column 1 and the float labels. Accuracy is the
    share of test rows where ``output[:, 1] > 0.5`` equals the label.
    """
    (num_layers, hidden_dim, activation_module, fixed_dropout_prob,
     is_binary_classification, epochs, lr, weight_decay) = hyperparameters
    num_features = x.shape[1]

    x, y = x.to(device), y.to(device)
    test_x, test_y = test_x.to(device), test_y.to(device)

    def build_network():
        output_idx = num_layers - 1
        modules = []
        for depth in range(num_layers):
            is_output = depth == output_idx
            in_dim = hidden_dim if depth > 0 else num_features
            out_dim = 2 if is_output else hidden_dim
            modules.append(nn.Linear(in_dim, out_dim))
            # The output layer gets neither activation nor dropout.
            modules.append(torch.nn.Identity() if is_output else activation_module())
            modules.append(torch.nn.Identity() if is_output
                           else torch.nn.Dropout(p=fixed_dropout_prob, inplace=False))
        if is_binary_classification:
            modules.append(torch.nn.Softmax(dim=1))  # TODO might also just do an round!?
        return nn.Sequential(*modules)

    net = build_network().to(device)
    loss_fn = torch.nn.BCELoss()
    optimizer = torch.optim.AdamW(net.parameters(), lr=lr, weight_decay=weight_decay)

    net.train()
    for _ in range(epochs):
        optimizer.zero_grad()
        positive_score = net(x)[:, 1]
        loss = loss_fn(positive_score.squeeze(), y.float())
        loss.backward()
        optimizer.step()

    net.eval()
    predicted_positive = net(test_x)[:, 1] > 0.5
    return (predicted_positive == test_y).float().mean()
|
1511 |
|
1512 |
# Registry mapping a baseline name to its ``*_metric`` evaluation function.
clf_dict = {
    'gp': gp_metric,
    'knn': knn_metric,
    'catboost': catboost_metric,
    'tabnet': tabnet_metric,
    'xgb': xgb_metric,
    'ridge': ridge_metric,
    'logistic': logistic_metric,
    'autosklearn': autosklearn_metric,
    'autosklearn2': autosklearn2_metric,
    'autogluon': autogluon_metric,
    'cocktail': well_tuned_simple_nets_metric,
}
|
TabPFN/scripts/tabular_baselines_deep.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pathlib
|
3 |
+
from argparse import Namespace
|
4 |
+
|
5 |
+
from sklearn.model_selection import GridSearchCV
|
6 |
+
import sys
|
7 |
+
|
8 |
+
CV = 5
|
9 |
+
param_grid = {}
|
10 |
+
|
11 |
+
# Grid for SAINT,
# as in https://github.com/kathrinse/TabSurvey/blob/main/models/saint.py#L268
param_grid['saint'] = dict(
    dim=[32, 64, 128, 256],
    depth=[1, 2, 3, 6, 12],
    heads=[2, 4, 8],
    dropout=[0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
)
|
18 |
+
|
19 |
+
def saint_metric(x, y, test_x, test_y, cat_features, metric_used):
    """Grid-search the SAINT model from the TabSurvey checkout and score it.

    Original implementation: https://github.com/somepago/saint
    Reimplementation used here: https://github.com/kathrinse/TabSurvey

    Setup::

        git clone git@github.com:kathrinse/TabSurvey.git
        cd TabSurvey
        # requirements: optuna, scikit-learn, pandas, configargparse,
        #               torch, einops

    The working directory is switched to the TabSurvey checkout for the
    duration of the call and always restored afterwards.

    Args:
        x, y, test_x, test_y: objects exposing ``.cpu()`` (presumably torch
            tensors), with features in the columns of ``x``.
        cat_features: list of categorical column indices.
        metric_used: callable ``metric_used(y_true, scores)``.

    Returns:
        Tuple ``(metric, pred)``.
    """
    import warnings

    pre_cwd = os.getcwd()

    # TabSurvey is expected next to the TabPFN directory, i.e. at
    # ../../TabSurvey relative to this file's directory.
    # NOTE(review): changing the cwd only makes ``models.saint`` importable if
    # the cwd is on ``sys.path`` — confirm for the intended entry point.
    dest_wd = pathlib.Path(__file__).absolute().parent.parent.joinpath("../TabSurvey")
    print(f"Change from {pre_cwd} to {dest_wd}")
    # ``sys`` has no ``chdir``; the previous ``sys.chdir`` raised AttributeError.
    os.chdir(dest_wd)

    try:
        from models.saint import SAINT

        # Cardinality of every categorical feature. Features are columns
        # (``num_features`` is ``x.shape[1]``), so index the column; convert
        # to plain Python values so ``set`` de-duplicates by value.
        cat_dims = [len(set(x[:, idx].tolist())) for idx in cat_features]
        model_args = Namespace(
            num_features=x.shape[1],
            cat_idx=cat_features,
            cat_dims=cat_dims,
        )

        x, y, test_x, test_y = x.cpu(), y.cpu(), test_x.cpu(), test_y.cpu()

        clf = SAINT(model_args)

        clf = GridSearchCV(clf, param_grid['saint'], cv=min(CV, x.shape[0]//2))

        # Silence warnings only while fitting/predicting instead of
        # permanently replacing ``warnings.warn`` for the whole process.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # fit model to data
            clf.fit(x, y.long())

            pred = clf.decision_function(test_x)
        metric = metric_used(test_y.cpu().numpy(), pred)
    finally:
        # Always restore the caller's working directory.
        os.chdir(pre_cwd)
    return metric, pred
|
TabPFN/scripts/tabular_evaluation.py
CHANGED
@@ -1,16 +1,17 @@
|
|
1 |
import time
|
2 |
import os
|
3 |
from pathlib import Path
|
|
|
4 |
|
|
|
5 |
from tqdm import tqdm
|
6 |
import random
|
7 |
import numpy as np
|
8 |
|
9 |
from torch import nn
|
10 |
|
11 |
-
from utils import
|
12 |
-
from
|
13 |
-
from model_builder import load_model
|
14 |
from scripts.tabular_baselines import get_scoring_string
|
15 |
from scripts import tabular_metrics
|
16 |
from scripts.transformer_prediction_interface import *
|
@@ -52,7 +53,6 @@ def eval_model_on_ds(i, e, valid_datasets, eval_positions, bptt, add_name, base_
|
|
52 |
model_file, model_path, results_file = check_file(e)
|
53 |
|
54 |
model, config_sample = load_model(base_path, model_file, device, None, verbose=False)
|
55 |
-
print(model[2].style_encoder)
|
56 |
|
57 |
params = {'max_features': config_sample['num_features']
|
58 |
, 'rescale_features': config_sample["normalize_by_used_features"]
|
@@ -79,7 +79,7 @@ def eval_model_on_ds(i, e, valid_datasets, eval_positions, bptt, add_name, base_
|
|
79 |
return metrics_valid, config_sample, model_path
|
80 |
|
81 |
|
82 |
-
def evaluate(datasets, bptt, eval_positions, metric_used, model
|
83 |
, verbose=False
|
84 |
, return_tensor=False
|
85 |
, **kwargs):
|
@@ -102,10 +102,10 @@ def evaluate(datasets, bptt, eval_positions, metric_used, model
|
|
102 |
aggregated_metric_datasets, num_datasets = torch.tensor(0.0), 0
|
103 |
|
104 |
# For each dataset
|
105 |
-
for [ds_name, X, y, categorical_feats, _, _] in
|
106 |
dataset_bptt = min(len(X), bptt)
|
107 |
-
#
|
108 |
-
# print(f'Dataset too small for given
|
109 |
|
110 |
aggregated_metric, num = torch.tensor(0.0), 0
|
111 |
ds_result = {}
|
@@ -121,9 +121,11 @@ def evaluate(datasets, bptt, eval_positions, metric_used, model
|
|
121 |
, ds_name=ds_name
|
122 |
, eval_position = eval_position_real
|
123 |
, metric_used = metric_used
|
|
|
124 |
,**kwargs)
|
125 |
|
126 |
if r is None:
|
|
|
127 |
continue
|
128 |
|
129 |
_, outputs, ys, best_configs, time_used = r
|
@@ -132,6 +134,17 @@ def evaluate(datasets, bptt, eval_positions, metric_used, model
|
|
132 |
outputs = outputs.to(outputs.device)
|
133 |
ys = ys.to(outputs.device)
|
134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
ys = ys.T
|
136 |
ds_result[f'{ds_name}_best_configs_at_{eval_position}'] = best_configs
|
137 |
ds_result[f'{ds_name}_outputs_at_{eval_position}'] = outputs
|
@@ -171,7 +184,7 @@ def check_file_exists(path):
|
|
171 |
return np.load(f, allow_pickle=True).tolist()
|
172 |
return None
|
173 |
|
174 |
-
def generate_valid_split(X, y, bptt, eval_position, split_number=1):
|
175 |
"""Generates a deterministic train-(test/valid) split. Both splits must contain the same classes and all classes in
|
176 |
the entire datasets. If no such split can be sampled in 7 passes, returns None.
|
177 |
|
@@ -187,7 +200,6 @@ def generate_valid_split(X, y, bptt, eval_position, split_number=1):
|
|
187 |
torch.manual_seed(split_number)
|
188 |
perm = torch.randperm(X.shape[0]) if split_number > 1 else torch.arange(0, X.shape[0])
|
189 |
X, y = X[perm], y[perm]
|
190 |
-
|
191 |
while not done:
|
192 |
if seed > 20:
|
193 |
return None, None # No split could be generated in 7 passes, return None
|
@@ -195,13 +207,16 @@ def generate_valid_split(X, y, bptt, eval_position, split_number=1):
|
|
195 |
i = random.randint(0, len(X) - bptt) if len(X) - bptt > 0 else 0
|
196 |
y_ = y[i:i + bptt]
|
197 |
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
|
|
|
|
|
|
205 |
|
206 |
eval_xs = torch.stack([X[i:i + bptt].clone()], 1)
|
207 |
eval_ys = torch.stack([y[i:i + bptt].clone()], 1)
|
@@ -211,7 +226,7 @@ def generate_valid_split(X, y, bptt, eval_position, split_number=1):
|
|
211 |
|
212 |
def evaluate_position(X, y, categorical_feats, model, bptt
|
213 |
, eval_position, overwrite, save, base_path, path_interfix, method, ds_name, fetch_only=False
|
214 |
-
, max_time=300, split_number=1
|
215 |
, per_step_normalization=False, **kwargs):
|
216 |
"""
|
217 |
Evaluates a dataset with a 'bptt' number of training samples.
|
@@ -250,24 +265,37 @@ def evaluate_position(X, y, categorical_feats, model, bptt
|
|
250 |
return None
|
251 |
|
252 |
## Generate data splits
|
253 |
-
eval_xs, eval_ys = generate_valid_split(X, y, bptt, eval_position
|
|
|
|
|
254 |
if eval_xs is None:
|
255 |
-
return None
|
256 |
print(f"No dataset could be generated {ds_name} {bptt}")
|
|
|
257 |
|
258 |
eval_ys = (eval_ys > torch.unique(eval_ys).unsqueeze(0)).sum(axis=1).unsqueeze(-1)
|
259 |
|
|
|
|
|
|
|
|
|
|
|
260 |
start_time = time.time()
|
261 |
|
262 |
if isinstance(model, nn.Module): # Two separate predict interfaces for transformer and baselines
|
263 |
-
outputs, best_configs = transformer_predict(model, eval_xs, eval_ys, eval_position,
|
|
|
|
|
|
|
|
|
|
|
264 |
else:
|
265 |
_, outputs, best_configs = baseline_predict(model, eval_xs, eval_ys, categorical_feats
|
266 |
, eval_pos=eval_position
|
267 |
-
,
|
268 |
-
|
269 |
eval_ys = eval_ys[eval_position:]
|
270 |
if outputs is None:
|
|
|
271 |
return None
|
272 |
|
273 |
if torch.is_tensor(outputs): # Transfers data to cpu for saving
|
|
|
1 |
import time
|
2 |
import os
|
3 |
from pathlib import Path
|
4 |
+
from contextlib import nullcontext
|
5 |
|
6 |
+
import torch
|
7 |
from tqdm import tqdm
|
8 |
import random
|
9 |
import numpy as np
|
10 |
|
11 |
from torch import nn
|
12 |
|
13 |
+
from torch.utils.checkpoint import checkpoint
|
14 |
+
from utils import normalize_data, torch_nanmean, to_ranking_low_mem, remove_outliers
|
|
|
15 |
from scripts.tabular_baselines import get_scoring_string
|
16 |
from scripts import tabular_metrics
|
17 |
from scripts.transformer_prediction_interface import *
|
|
|
53 |
model_file, model_path, results_file = check_file(e)
|
54 |
|
55 |
model, config_sample = load_model(base_path, model_file, device, None, verbose=False)
|
|
|
56 |
|
57 |
params = {'max_features': config_sample['num_features']
|
58 |
, 'rescale_features': config_sample["normalize_by_used_features"]
|
|
|
79 |
return metrics_valid, config_sample, model_path
|
80 |
|
81 |
|
82 |
+
def evaluate(datasets, bptt, eval_positions, metric_used, model, device='cpu'
|
83 |
, verbose=False
|
84 |
, return_tensor=False
|
85 |
, **kwargs):
|
|
|
102 |
aggregated_metric_datasets, num_datasets = torch.tensor(0.0), 0
|
103 |
|
104 |
# For each dataset
|
105 |
+
for [ds_name, X, y, categorical_feats, _, _] in datasets:
|
106 |
dataset_bptt = min(len(X), bptt)
|
107 |
+
#if verbose and dataset_bptt < bptt:
|
108 |
+
# print(f'Dataset too small for given bptt, reducing to {len(X)} ({bptt})')
|
109 |
|
110 |
aggregated_metric, num = torch.tensor(0.0), 0
|
111 |
ds_result = {}
|
|
|
121 |
, ds_name=ds_name
|
122 |
, eval_position = eval_position_real
|
123 |
, metric_used = metric_used
|
124 |
+
, device=device
|
125 |
,**kwargs)
|
126 |
|
127 |
if r is None:
|
128 |
+
print('Execution failed')
|
129 |
continue
|
130 |
|
131 |
_, outputs, ys, best_configs, time_used = r
|
|
|
134 |
outputs = outputs.to(outputs.device)
|
135 |
ys = ys.to(outputs.device)
|
136 |
|
137 |
+
# WARNING: This leaks information on the scaling of the labels
|
138 |
+
if isinstance(model, nn.Module) and "BarDistribution" in str(type(model.criterion)):
|
139 |
+
ys = (ys - torch.min(ys, axis=0)[0]) / (torch.max(ys, axis=0)[0] - torch.min(ys, axis=0)[0])
|
140 |
+
|
141 |
+
# If we use the bar distribution and the metric_used is r2 -> convert buckets
|
142 |
+
# metric used is prob -> keep
|
143 |
+
if isinstance(model, nn.Module) and "BarDistribution" in str(type(model.criterion)) and (
|
144 |
+
metric_used == tabular_metrics.r2_metric or metric_used == tabular_metrics.root_mean_squared_error_metric):
|
145 |
+
ds_result[f'{ds_name}_bar_dist_at_{eval_position}'] = outputs
|
146 |
+
outputs = model.criterion.mean(outputs)
|
147 |
+
|
148 |
ys = ys.T
|
149 |
ds_result[f'{ds_name}_best_configs_at_{eval_position}'] = best_configs
|
150 |
ds_result[f'{ds_name}_outputs_at_{eval_position}'] = outputs
|
|
|
184 |
return np.load(f, allow_pickle=True).tolist()
|
185 |
return None
|
186 |
|
187 |
+
def generate_valid_split(X, y, bptt, eval_position, is_classification, split_number=1):
|
188 |
"""Generates a deteministic train-(test/valid) split. Both splits must contain the same classes and all classes in
|
189 |
the entire datasets. If no such split can be sampled in 7 passes, returns None.
|
190 |
|
|
|
200 |
torch.manual_seed(split_number)
|
201 |
perm = torch.randperm(X.shape[0]) if split_number > 1 else torch.arange(0, X.shape[0])
|
202 |
X, y = X[perm], y[perm]
|
|
|
203 |
while not done:
|
204 |
if seed > 20:
|
205 |
return None, None # No split could be generated in 7 passes, return None
|
|
|
207 |
i = random.randint(0, len(X) - bptt) if len(X) - bptt > 0 else 0
|
208 |
y_ = y[i:i + bptt]
|
209 |
|
210 |
+
if is_classification:
|
211 |
+
# Checks if all classes from dataset are contained and classes in train and test are equal (contain same
|
212 |
+
# classes) and
|
213 |
+
done = len(torch.unique(y_)) == len(torch.unique(y))
|
214 |
+
done = done and torch.all(torch.unique(y_) == torch.unique(y))
|
215 |
+
done = done and len(torch.unique(y_[:eval_position])) == len(torch.unique(y_[eval_position:]))
|
216 |
+
done = done and torch.all(torch.unique(y_[:eval_position]) == torch.unique(y_[eval_position:]))
|
217 |
+
seed = seed + 1
|
218 |
+
else:
|
219 |
+
done = True
|
220 |
|
221 |
eval_xs = torch.stack([X[i:i + bptt].clone()], 1)
|
222 |
eval_ys = torch.stack([y[i:i + bptt].clone()], 1)
|
|
|
226 |
|
227 |
def evaluate_position(X, y, categorical_feats, model, bptt
|
228 |
, eval_position, overwrite, save, base_path, path_interfix, method, ds_name, fetch_only=False
|
229 |
+
, max_time=300, split_number=1, metric_used=None, device='cpu'
|
230 |
, per_step_normalization=False, **kwargs):
|
231 |
"""
|
232 |
Evaluates a dataset with a 'bptt' number of training samples.
|
|
|
265 |
return None
|
266 |
|
267 |
## Generate data splits
|
268 |
+
eval_xs, eval_ys = generate_valid_split(X, y, bptt, eval_position
|
269 |
+
, is_classification=tabular_metrics.is_classification(metric_used)
|
270 |
+
, split_number=split_number)
|
271 |
if eval_xs is None:
|
|
|
272 |
print(f"No dataset could be generated {ds_name} {bptt}")
|
273 |
+
return None
|
274 |
|
275 |
eval_ys = (eval_ys > torch.unique(eval_ys).unsqueeze(0)).sum(axis=1).unsqueeze(-1)
|
276 |
|
277 |
+
if isinstance(model, nn.Module):
|
278 |
+
model = model.to(device)
|
279 |
+
eval_xs = eval_xs.to(device)
|
280 |
+
eval_ys = eval_ys.to(device)
|
281 |
+
|
282 |
start_time = time.time()
|
283 |
|
284 |
if isinstance(model, nn.Module): # Two separate predict interfaces for transformer and baselines
|
285 |
+
outputs, best_configs = transformer_predict(model, eval_xs, eval_ys, eval_position, metric_used=metric_used
|
286 |
+
, categorical_feats=categorical_feats
|
287 |
+
, inference_mode=True
|
288 |
+
, device=device
|
289 |
+
, extend_features=True,
|
290 |
+
**kwargs), None
|
291 |
else:
|
292 |
_, outputs, best_configs = baseline_predict(model, eval_xs, eval_ys, categorical_feats
|
293 |
, eval_pos=eval_position
|
294 |
+
, device=device
|
295 |
+
, max_time=max_time, metric_used=metric_used, **kwargs)
|
296 |
eval_ys = eval_ys[eval_position:]
|
297 |
if outputs is None:
|
298 |
+
print('Execution failed')
|
299 |
return None
|
300 |
|
301 |
if torch.is_tensor(outputs): # Transfers data to cpu for saving
|
TabPFN/scripts/tabular_metrics.py
CHANGED
@@ -10,10 +10,25 @@ Includes a few metric as well as functions composing metrics on results files.
|
|
10 |
|
11 |
import numpy as np
|
12 |
import torch
|
13 |
-
from sklearn.metrics import roc_auc_score, accuracy_score, balanced_accuracy_score, average_precision_score
|
14 |
from scipy.stats import rankdata
|
15 |
import pandas as pd
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
"""
|
18 |
===============================
|
19 |
Metrics calculation
|
@@ -37,7 +52,7 @@ def auc_metric(target, pred, multi_class='ovo', numpy=False):
|
|
37 |
return roc_auc_score(target, pred)
|
38 |
except ValueError as e:
|
39 |
print(e)
|
40 |
-
return np.nan
|
41 |
|
42 |
def accuracy_metric(target, pred):
|
43 |
target = torch.tensor(target) if not torch.is_tensor(target) else target
|
@@ -73,6 +88,19 @@ def cross_entropy(target, pred):
|
|
73 |
bce = torch.nn.BCELoss()
|
74 |
return bce(pred[:, 1].float(), target.float())
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
def time_metric():
|
77 |
"""
|
78 |
Dummy function, will just be used as a handler.
|
@@ -90,7 +118,7 @@ def count_metric(x, y):
|
|
90 |
Metrics composition
|
91 |
===============================
|
92 |
"""
|
93 |
-
def calculate_score_per_method(metric, name:str, global_results:dict, ds:list, eval_positions:list, aggregator:str='mean'):
|
94 |
"""
|
95 |
Calculates the metric given by 'metric' and saves it under 'name' in the 'global_results'
|
96 |
|
@@ -156,15 +184,18 @@ def calculate_score(metric, name, global_results, ds, eval_positions, aggregator
|
|
156 |
def make_metric_matrix(global_results, methods, pos, name, ds):
|
157 |
result = []
|
158 |
for m in global_results:
|
159 |
-
|
|
|
|
|
|
|
160 |
result = np.array(result)
|
161 |
-
result = pd.DataFrame(result.T, index=[d[0] for d in ds], columns=[k
|
162 |
|
163 |
matrix_means, matrix_stds = [], []
|
164 |
|
165 |
for method in methods:
|
166 |
-
matrix_means += [result.iloc[:, [(method)
|
167 |
-
matrix_stds += [result.iloc[:, [(method)
|
168 |
|
169 |
matrix_means = pd.DataFrame(matrix_means, index=methods).T
|
170 |
matrix_stds = pd.DataFrame(matrix_stds, index=methods).T
|
|
|
10 |
|
11 |
import numpy as np
|
12 |
import torch
|
13 |
+
from sklearn.metrics import roc_auc_score, accuracy_score, balanced_accuracy_score, average_precision_score, mean_squared_error, mean_absolute_error, r2_score
|
14 |
from scipy.stats import rankdata
|
15 |
import pandas as pd
|
16 |
|
17 |
+
def root_mean_squared_error_metric(target, pred):
|
18 |
+
target = torch.tensor(target) if not torch.is_tensor(target) else target
|
19 |
+
pred = torch.tensor(pred) if not torch.is_tensor(pred) else pred
|
20 |
+
return torch.sqrt(torch.nn.functional.mse_loss(target, pred))
|
21 |
+
|
22 |
+
def mean_squared_error_metric(target, pred):
|
23 |
+
target = torch.tensor(target) if not torch.is_tensor(target) else target
|
24 |
+
pred = torch.tensor(pred) if not torch.is_tensor(pred) else pred
|
25 |
+
return torch.nn.functional.mse_loss(target, pred)
|
26 |
+
|
27 |
+
def mean_absolute_error_metric(target, pred):
|
28 |
+
target = torch.tensor(target) if not torch.is_tensor(target) else target
|
29 |
+
pred = torch.tensor(pred) if not torch.is_tensor(pred) else pred
|
30 |
+
return torch.tensor(mean_absolute_error(target, pred))
|
31 |
+
|
32 |
"""
|
33 |
===============================
|
34 |
Metrics calculation
|
|
|
52 |
return roc_auc_score(target, pred)
|
53 |
except ValueError as e:
|
54 |
print(e)
|
55 |
+
return np.nan if numpy else torch.tensor(np.nan)
|
56 |
|
57 |
def accuracy_metric(target, pred):
|
58 |
target = torch.tensor(target) if not torch.is_tensor(target) else target
|
|
|
88 |
bce = torch.nn.BCELoss()
|
89 |
return bce(pred[:, 1].float(), target.float())
|
90 |
|
91 |
+
def r2_metric(target, pred):
|
92 |
+
target = torch.tensor(target) if not torch.is_tensor(target) else target
|
93 |
+
pred = torch.tensor(pred) if not torch.is_tensor(pred) else pred
|
94 |
+
return torch.tensor(neg_r2(target, pred))
|
95 |
+
|
96 |
+
def neg_r2(target, pred):
|
97 |
+
return -r2_score(pred.float(), target.float())
|
98 |
+
|
99 |
+
def is_classification(metric_used):
|
100 |
+
if metric_used == auc_metric or metric_used == cross_entropy:
|
101 |
+
return True
|
102 |
+
return False
|
103 |
+
|
104 |
def time_metric():
|
105 |
"""
|
106 |
Dummy function, will just be used as a handler.
|
|
|
118 |
Metrics composition
|
119 |
===============================
|
120 |
"""
|
121 |
+
def calculate_score_per_method(metric, name:str, global_results:dict, ds:list, eval_positions:list[int], aggregator:str='mean'):
|
122 |
"""
|
123 |
Calculates the metric given by 'metric' and saves it under 'name' in the 'global_results'
|
124 |
|
|
|
184 |
def make_metric_matrix(global_results, methods, pos, name, ds):
|
185 |
result = []
|
186 |
for m in global_results:
|
187 |
+
try:
|
188 |
+
result += [[global_results[m][d[0] + '_' + name + '_at_' + str(pos)] for d in ds]]
|
189 |
+
except Exception as e:
|
190 |
+
result += [[np.nan]]
|
191 |
result = np.array(result)
|
192 |
+
result = pd.DataFrame(result.T, index=[d[0] for d in ds], columns=[k for k in list(global_results.keys())])
|
193 |
|
194 |
matrix_means, matrix_stds = [], []
|
195 |
|
196 |
for method in methods:
|
197 |
+
matrix_means += [result.iloc[:, [c.startswith(method+'_time') for c in result.columns]].mean(axis=1)]
|
198 |
+
matrix_stds += [result.iloc[:, [c.startswith(method+'_time') for c in result.columns]].std(axis=1)]
|
199 |
|
200 |
matrix_means = pd.DataFrame(matrix_means, index=methods).T
|
201 |
matrix_stds = pd.DataFrame(matrix_stds, index=methods).T
|
TabPFN/scripts/transformer_prediction_interface.py
CHANGED
@@ -94,7 +94,7 @@ class TabPFNClassifier(BaseEstimator, ClassifierMixin):
|
|
94 |
i, e = i, -1
|
95 |
|
96 |
# File which contains result of hyperparameter tuning run: style (i.e. hyperparameters) and a dataframe with results.
|
97 |
-
style_file = 'prior_tuning_result.pkl'
|
98 |
|
99 |
model, c, results_file = load_model_workflow(i, e, add_name=model_string, base_path=base_path, device=device,
|
100 |
eval_addition='')
|
|
|
94 |
i, e = i, -1
|
95 |
|
96 |
# File which contains result of hyperparameter tuning run: style (i.e. hyperparameters) and a dataframe with results.
|
97 |
+
#style_file = 'prior_tuning_result.pkl'
|
98 |
|
99 |
model, c, results_file = load_model_workflow(i, e, add_name=model_string, base_path=base_path, device=device,
|
100 |
eval_addition='')
|
TabPFN/tabular_evaluation.py
DELETED
@@ -1,283 +0,0 @@
|
|
1 |
-
import time
|
2 |
-
import os
|
3 |
-
from pathlib import Path
|
4 |
-
|
5 |
-
from tqdm import tqdm
|
6 |
-
import random
|
7 |
-
import numpy as np
|
8 |
-
|
9 |
-
from torch import nn
|
10 |
-
|
11 |
-
from utils import torch_nanmean
|
12 |
-
from datasets import *
|
13 |
-
from model_builder import load_model
|
14 |
-
from scripts.tabular_baselines import get_scoring_string
|
15 |
-
from scripts import tabular_metrics
|
16 |
-
from scripts.transformer_prediction_interface import *
|
17 |
-
from scripts.baseline_prediction_interface import *
|
18 |
-
"""
|
19 |
-
===============================
|
20 |
-
PUBLIC FUNCTIONS FOR EVALUATION
|
21 |
-
===============================
|
22 |
-
"""
|
23 |
-
|
24 |
-
|
25 |
-
def eval_model(i, e, valid_datasets, test_datasets, eval_positions, bptt, add_name, base_path, device='cpu', eval_addition='', **kwargs):
|
26 |
-
metrics_test, config_sample, model_path = eval_model_on_ds(i, e, test_datasets, eval_positions, bptt, add_name, base_path, device=device, eval_addition=eval_addition, **kwargs)
|
27 |
-
metrics_valid, _, _ = eval_model_on_ds(i, e, valid_datasets, eval_positions, bptt, add_name, base_path, device=device, eval_addition=eval_addition, **kwargs)
|
28 |
-
return {'mean_auc_test': metrics_test['mean_roc_at_1000'], 'mean_auc_valid': metrics_valid['mean_roc_at_1000'], 'mean_ce_test': metrics_test['mean_ce_at_1000'], 'mean_ce_valid': metrics_valid['mean_ce_at_1000'], 'config_sample': config_sample, 'model_path': model_path}
|
29 |
-
|
30 |
-
def eval_model_on_ds(i, e, valid_datasets, eval_positions, bptt, add_name, base_path, device='cpu', eval_addition='', **kwargs):
|
31 |
-
|
32 |
-
# How to use: evaluate_without_fitting(i,0,valid_datasets, [1024], 100000, add_name=model_string, base_path=base_path,)
|
33 |
-
def check_file(e):
|
34 |
-
model_file = f'models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{e}.cpkt'
|
35 |
-
model_path = os.path.join(base_path, model_file)
|
36 |
-
# print('Evaluate ', model_path)
|
37 |
-
results_file = os.path.join(base_path,
|
38 |
-
f'models_diff/prior_diff_real_results{add_name}_n_{i}_epoch_{e}_{eval_addition}.pkl')
|
39 |
-
if not Path(model_path).is_file(): # or Path(results_file).is_file():
|
40 |
-
# print('checkpoint exists: ', Path(model_file).is_file(), ', results are written:', Path(results_file).is_file())
|
41 |
-
return None, None, None
|
42 |
-
return model_file, model_path, results_file
|
43 |
-
|
44 |
-
if e == -1: # use last checkpoint, if e == -1
|
45 |
-
for e_ in range(100, -1, -1):
|
46 |
-
model_file_, model_path_, results_file_ = check_file(e_)
|
47 |
-
if model_file_ is not None:
|
48 |
-
e = e_
|
49 |
-
model_file, model_path, results_file = model_file_, model_path_, results_file_
|
50 |
-
break
|
51 |
-
else:
|
52 |
-
model_file, model_path, results_file = check_file(e)
|
53 |
-
|
54 |
-
model, config_sample = load_model(base_path, model_file, device, None, verbose=False)
|
55 |
-
|
56 |
-
params = {'max_features': config_sample['num_features']
|
57 |
-
, 'rescale_features': config_sample["normalize_by_used_features"]
|
58 |
-
, 'normalize_to_ranking': config_sample["normalize_to_ranking"]
|
59 |
-
, 'normalize_with_sqrt': config_sample.get("normalize_with_sqrt", False)
|
60 |
-
}
|
61 |
-
metrics_valid = evaluate(datasets=valid_datasets, model=model[2], method='transformer', device=device, overwrite=True,
|
62 |
-
extend_features=True
|
63 |
-
# just removed the style keyword but transformer is trained with style, just empty
|
64 |
-
, save=False
|
65 |
-
, metric_used=tabular_metrics.cross_entropy
|
66 |
-
, return_tensor=True
|
67 |
-
, verbose=False
|
68 |
-
, eval_positions=eval_positions
|
69 |
-
, bptt=bptt
|
70 |
-
, base_path=None
|
71 |
-
, inference_mode=True
|
72 |
-
, **params
|
73 |
-
, **kwargs)
|
74 |
-
|
75 |
-
tabular_metrics.calculate_score_per_method(tabular_metrics.auc_metric, 'roc', metrics_valid, valid_datasets, eval_positions)
|
76 |
-
tabular_metrics.calculate_score_per_method(tabular_metrics.cross_entropy, 'ce', metrics_valid, valid_datasets, eval_positions)
|
77 |
-
|
78 |
-
return metrics_valid, config_sample, model_path
|
79 |
-
|
80 |
-
|
81 |
-
def evaluate(datasets, bptt, eval_positions, metric_used, model
|
82 |
-
, verbose=False
|
83 |
-
, return_tensor=False
|
84 |
-
, **kwargs):
|
85 |
-
"""
|
86 |
-
Evaluates a list of datasets for a model function.
|
87 |
-
|
88 |
-
:param datasets: List of datasets
|
89 |
-
:param bptt: maximum sequence length
|
90 |
-
:param eval_positions: List of positions where to evaluate models
|
91 |
-
:param verbose: If True, is verbose.
|
92 |
-
:param metric_used: Which metric is optimized for.
|
93 |
-
:param return_tensor: Wheater to return results as a pytorch.tensor or numpy, this is only relevant for transformer.
|
94 |
-
:param kwargs:
|
95 |
-
:return:
|
96 |
-
"""
|
97 |
-
overall_result = {'metric_used': get_scoring_string(metric_used)
|
98 |
-
, 'bptt': bptt
|
99 |
-
, 'eval_positions': eval_positions}
|
100 |
-
|
101 |
-
aggregated_metric_datasets, num_datasets = torch.tensor(0.0), 0
|
102 |
-
|
103 |
-
# For each dataset
|
104 |
-
for [ds_name, X, y, categorical_feats, _, _] in tqdm.tqdm(datasets, desc='Iterate over datasets') if verbose else datasets:
|
105 |
-
dataset_bptt = min(len(X), bptt)
|
106 |
-
# if verbose and dataset_bptt < bptt:
|
107 |
-
# print(f'Dataset too small for given sequence length, reducing to {len(X)} ({bptt})')
|
108 |
-
|
109 |
-
aggregated_metric, num = torch.tensor(0.0), 0
|
110 |
-
ds_result = {}
|
111 |
-
|
112 |
-
for eval_position in (eval_positions if verbose else eval_positions):
|
113 |
-
eval_position_real = int(dataset_bptt * 0.5) if 2 * eval_position > dataset_bptt else eval_position
|
114 |
-
eval_position_bptt = int(eval_position_real * 2.0)
|
115 |
-
|
116 |
-
r = evaluate_position(X, y, model=model
|
117 |
-
, num_classes=len(torch.unique(y))
|
118 |
-
, categorical_feats = categorical_feats
|
119 |
-
, bptt = eval_position_bptt
|
120 |
-
, ds_name=ds_name
|
121 |
-
, eval_position = eval_position_real
|
122 |
-
, metric_used = metric_used
|
123 |
-
,**kwargs)
|
124 |
-
|
125 |
-
if r is None:
|
126 |
-
continue
|
127 |
-
|
128 |
-
_, outputs, ys, best_configs, time_used = r
|
129 |
-
|
130 |
-
if torch.is_tensor(outputs):
|
131 |
-
outputs = outputs.to(outputs.device)
|
132 |
-
ys = ys.to(outputs.device)
|
133 |
-
|
134 |
-
ys = ys.T
|
135 |
-
ds_result[f'{ds_name}_best_configs_at_{eval_position}'] = best_configs
|
136 |
-
ds_result[f'{ds_name}_outputs_at_{eval_position}'] = outputs
|
137 |
-
ds_result[f'{ds_name}_ys_at_{eval_position}'] = ys
|
138 |
-
ds_result[f'{ds_name}_time_at_{eval_position}'] = time_used
|
139 |
-
|
140 |
-
new_metric = torch_nanmean(torch.stack([metric_used(ys[i], outputs[i]) for i in range(ys.shape[0])]))
|
141 |
-
|
142 |
-
if not return_tensor:
|
143 |
-
make_scalar = lambda x: float(x.detach().cpu().numpy()) if (torch.is_tensor(x) and (len(x.shape) == 0)) else x
|
144 |
-
new_metric = make_scalar(new_metric)
|
145 |
-
ds_result = {k: make_scalar(ds_result[k]) for k in ds_result.keys()}
|
146 |
-
|
147 |
-
lib = torch if return_tensor else np
|
148 |
-
if not lib.isnan(new_metric).any():
|
149 |
-
aggregated_metric, num = aggregated_metric + new_metric, num + 1
|
150 |
-
|
151 |
-
overall_result.update(ds_result)
|
152 |
-
if num > 0:
|
153 |
-
aggregated_metric_datasets, num_datasets = (aggregated_metric_datasets + (aggregated_metric / num)), num_datasets + 1
|
154 |
-
|
155 |
-
overall_result['mean_metric'] = aggregated_metric_datasets / num_datasets
|
156 |
-
|
157 |
-
return overall_result
|
158 |
-
|
159 |
-
"""
|
160 |
-
===============================
|
161 |
-
INTERNAL HELPER FUNCTIONS
|
162 |
-
===============================
|
163 |
-
"""
|
164 |
-
|
165 |
-
def check_file_exists(path):
|
166 |
-
"""Checks if a pickle file exists. Returns None if not, else returns the unpickled file."""
|
167 |
-
if (os.path.isfile(path)):
|
168 |
-
print(f'loading results from {path}')
|
169 |
-
with open(path, 'rb') as f:
|
170 |
-
return np.load(f, allow_pickle=True).tolist()
|
171 |
-
return None
|
172 |
-
|
173 |
-
def generate_valid_split(X, y, bptt, eval_position, split_number=1):
|
174 |
-
"""Generates a deteministic train-(test/valid) split. Both splits must contain the same classes and all classes in
|
175 |
-
the entire datasets. If no such split can be sampled in 7 passes, returns None.
|
176 |
-
|
177 |
-
:param X: torch tensor, feature values
|
178 |
-
:param y: torch tensor, class values
|
179 |
-
:param bptt: Number of samples in train + test
|
180 |
-
:param eval_position: Number of samples in train, i.e. from which index values are in test
|
181 |
-
:param split_number: The split id
|
182 |
-
:return:
|
183 |
-
"""
|
184 |
-
done, seed = False, 13
|
185 |
-
|
186 |
-
torch.manual_seed(split_number)
|
187 |
-
perm = torch.randperm(X.shape[0]) if split_number > 1 else torch.arange(0, X.shape[0])
|
188 |
-
X, y = X[perm], y[perm]
|
189 |
-
|
190 |
-
while not done:
|
191 |
-
if seed > 20:
|
192 |
-
return None, None # No split could be generated in 7 passes, return None
|
193 |
-
random.seed(seed)
|
194 |
-
i = random.randint(0, len(X) - bptt) if len(X) - bptt > 0 else 0
|
195 |
-
y_ = y[i:i + bptt]
|
196 |
-
|
197 |
-
# Checks if all classes from dataset are contained and classes in train and test are equal (contain same
|
198 |
-
# classes) and
|
199 |
-
done = len(torch.unique(y_)) == len(torch.unique(y))
|
200 |
-
done = done and torch.all(torch.unique(y_) == torch.unique(y))
|
201 |
-
done = done and len(torch.unique(y_[:eval_position])) == len(torch.unique(y_[eval_position:]))
|
202 |
-
done = done and torch.all(torch.unique(y_[:eval_position]) == torch.unique(y_[eval_position:]))
|
203 |
-
seed = seed + 1
|
204 |
-
|
205 |
-
eval_xs = torch.stack([X[i:i + bptt].clone()], 1)
|
206 |
-
eval_ys = torch.stack([y[i:i + bptt].clone()], 1)
|
207 |
-
|
208 |
-
return eval_xs, eval_ys
|
209 |
-
|
210 |
-
|
211 |
-
def evaluate_position(X, y, categorical_feats, model, bptt
|
212 |
-
, eval_position, overwrite, save, base_path, path_interfix, method, ds_name, fetch_only=False
|
213 |
-
, max_time=300, split_number=1
|
214 |
-
, per_step_normalization=False, **kwargs):
|
215 |
-
"""
|
216 |
-
Evaluates a dataset with a 'bptt' number of training samples.
|
217 |
-
|
218 |
-
:param X: Dataset X
|
219 |
-
:param y: Dataset labels
|
220 |
-
:param categorical_feats: Indices of categorical features.
|
221 |
-
:param model: Model function
|
222 |
-
:param bptt: Sequence length.
|
223 |
-
:param eval_position: Number of training samples.
|
224 |
-
:param overwrite: Wheater to ove
|
225 |
-
:param overwrite: If True, results on disk are overwritten.
|
226 |
-
:param save:
|
227 |
-
:param path_interfix: Used for constructing path to write on disk.
|
228 |
-
:param method: Model name.
|
229 |
-
:param ds_name: Datset name.
|
230 |
-
:param fetch_only: Wheater to calculate or only fetch results.
|
231 |
-
:param per_step_normalization:
|
232 |
-
:param kwargs:
|
233 |
-
:return:
|
234 |
-
"""
|
235 |
-
|
236 |
-
if save:
|
237 |
-
path = os.path.join(base_path, f'results/tabular/{path_interfix}/results_{method}_{ds_name}_{eval_position}_{bptt}_{split_number}.npy')
|
238 |
-
#log_path =
|
239 |
-
|
240 |
-
## Load results if on disk
|
241 |
-
if not overwrite:
|
242 |
-
result = check_file_exists(path)
|
243 |
-
if result is not None:
|
244 |
-
if not fetch_only:
|
245 |
-
print(f'Loaded saved result for {path}')
|
246 |
-
return result
|
247 |
-
elif fetch_only:
|
248 |
-
print(f'Could not load saved result for {path}')
|
249 |
-
return None
|
250 |
-
|
251 |
-
## Generate data splits
|
252 |
-
eval_xs, eval_ys = generate_valid_split(X, y, bptt, eval_position, split_number=split_number)
|
253 |
-
if eval_xs is None:
|
254 |
-
return None
|
255 |
-
print(f"No dataset could be generated {ds_name} {bptt}")
|
256 |
-
|
257 |
-
eval_ys = (eval_ys > torch.unique(eval_ys).unsqueeze(0)).sum(axis=1).unsqueeze(-1)
|
258 |
-
|
259 |
-
start_time = time.time()
|
260 |
-
|
261 |
-
if isinstance(model, nn.Module): # Two separate predict interfaces for transformer and baselines
|
262 |
-
outputs, best_configs = transformer_predict(model, eval_xs, eval_ys, eval_position, categorical_feats=categorical_feats, **kwargs), None
|
263 |
-
else:
|
264 |
-
_, outputs, best_configs = baseline_predict(model, eval_xs, eval_ys, categorical_feats
|
265 |
-
, eval_pos=eval_position
|
266 |
-
, max_time=max_time, **kwargs)
|
267 |
-
|
268 |
-
eval_ys = eval_ys[eval_position:]
|
269 |
-
if outputs is None:
|
270 |
-
return None
|
271 |
-
|
272 |
-
if torch.is_tensor(outputs): # Transfers data to cpu for saving
|
273 |
-
outputs = outputs.cpu()
|
274 |
-
eval_ys = eval_ys.cpu()
|
275 |
-
|
276 |
-
ds_result = None, outputs, eval_ys, best_configs, time.time() - start_time
|
277 |
-
|
278 |
-
if save:
|
279 |
-
with open(path, 'wb') as f:
|
280 |
-
np.save(f, ds_result)
|
281 |
-
print(f'saved results to {path}')
|
282 |
-
|
283 |
-
return ds_result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
encoders.py
DELETED
@@ -1,243 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
|
3 |
-
import torch
|
4 |
-
import torch.nn as nn
|
5 |
-
from utils import normalize_data
|
6 |
-
import torch.nn.functional as F
|
7 |
-
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
8 |
-
|
9 |
-
|
10 |
-
class StyleEncoder(nn.Module):
|
11 |
-
def __init__(self, num_hyperparameters, em_size):
|
12 |
-
super().__init__()
|
13 |
-
self.em_size = em_size
|
14 |
-
self.embedding = nn.Linear(num_hyperparameters, self.em_size)
|
15 |
-
|
16 |
-
def forward(self, hyperparameters): # B x num_hps
|
17 |
-
return self.embedding(hyperparameters)
|
18 |
-
|
19 |
-
|
20 |
-
class StyleEmbEncoder(nn.Module):
|
21 |
-
def __init__(self, num_hyperparameters, em_size, num_embeddings=100):
|
22 |
-
super().__init__()
|
23 |
-
assert num_hyperparameters == 1
|
24 |
-
self.em_size = em_size
|
25 |
-
self.embedding = nn.Embedding(num_embeddings, self.em_size)
|
26 |
-
|
27 |
-
def forward(self, hyperparameters): # B x num_hps
|
28 |
-
return self.embedding(hyperparameters.squeeze(1))
|
29 |
-
|
30 |
-
|
31 |
-
class _PositionalEncoding(nn.Module):
|
32 |
-
def __init__(self, d_model, dropout=0.):
|
33 |
-
super().__init__()
|
34 |
-
self.dropout = nn.Dropout(p=dropout)
|
35 |
-
self.d_model = d_model
|
36 |
-
self.device_test_tensor = nn.Parameter(torch.tensor(1.))
|
37 |
-
|
38 |
-
def forward(self, x):# T x B x num_features
|
39 |
-
assert self.d_model % x.shape[-1]*2 == 0
|
40 |
-
d_per_feature = self.d_model // x.shape[-1]
|
41 |
-
pe = torch.zeros(*x.shape, d_per_feature, device=self.device_test_tensor.device)
|
42 |
-
#position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
43 |
-
interval_size = 10
|
44 |
-
div_term = (1./interval_size) * 2*math.pi*torch.exp(torch.arange(0, d_per_feature, 2, device=self.device_test_tensor.device).float()*math.log(math.sqrt(2)))
|
45 |
-
#print(div_term/2/math.pi)
|
46 |
-
pe[..., 0::2] = torch.sin(x.unsqueeze(-1) * div_term)
|
47 |
-
pe[..., 1::2] = torch.cos(x.unsqueeze(-1) * div_term)
|
48 |
-
return self.dropout(pe).view(x.shape[0],x.shape[1],self.d_model)
|
49 |
-
|
50 |
-
|
51 |
-
Positional = lambda _, emsize: _PositionalEncoding(d_model=emsize)
|
52 |
-
|
53 |
-
class EmbeddingEncoder(nn.Module):
|
54 |
-
def __init__(self, num_features, em_size, num_embs=100):
|
55 |
-
super().__init__()
|
56 |
-
self.num_embs = num_embs
|
57 |
-
self.embeddings = nn.Embedding(num_embs * num_features, em_size, max_norm=True)
|
58 |
-
self.init_weights(.1)
|
59 |
-
self.min_max = (-2,+2)
|
60 |
-
|
61 |
-
@property
|
62 |
-
def width(self):
|
63 |
-
return self.min_max[1] - self.min_max[0]
|
64 |
-
|
65 |
-
def init_weights(self, initrange):
|
66 |
-
self.embeddings.weight.data.uniform_(-initrange, initrange)
|
67 |
-
|
68 |
-
def discretize(self, x):
|
69 |
-
split_size = self.width / self.num_embs
|
70 |
-
return (x - self.min_max[0] // split_size).int().clamp(0, self.num_embs - 1)
|
71 |
-
|
72 |
-
def forward(self, x): # T x B x num_features
|
73 |
-
x_idxs = self.discretize(x)
|
74 |
-
x_idxs += torch.arange(x.shape[-1], device=x.device).view(1, 1, -1) * self.num_embs
|
75 |
-
# print(x_idxs,self.embeddings.weight.shape)
|
76 |
-
return self.embeddings(x_idxs).mean(-2)
|
77 |
-
|
78 |
-
|
79 |
-
class Normalize(nn.Module):
|
80 |
-
def __init__(self, mean, std):
|
81 |
-
super().__init__()
|
82 |
-
self.mean = mean
|
83 |
-
self.std = std
|
84 |
-
|
85 |
-
def forward(self, x):
|
86 |
-
return (x-self.mean)/self.std
|
87 |
-
|
88 |
-
|
89 |
-
def get_normalized_uniform_encoder(encoder_creator):
    """Wrap an encoder factory so that its input is standardized first.

    Meant for encoders that are fed uniform samples in [0, 1]: inputs are
    shifted by 0.5 and divided by sqrt(1/12) (mean and standard deviation of
    that distribution) before they reach the wrapped encoder.
    For example, ``encoder_creator = get_normalized_uniform_encoder(encoders.Linear)``
    can afterwards be initialized with ``encoder_creator(feature_dim, in_dim)``.

    :param encoder_creator: callable ``(in_dim, out_dim)`` returning a module.
    :return: callable ``(in_dim, out_dim)`` returning
        ``nn.Sequential(Normalize(...), encoder)``.
    """
    uniform_mean = .5
    uniform_std = math.sqrt(1/12)

    def create(in_dim, out_dim):
        return nn.Sequential(Normalize(uniform_mean, uniform_std), encoder_creator(in_dim, out_dim))

    return create
|
98 |
-
|
99 |
-
|
100 |
-
def get_normalized_encoder(encoder_creator, data_std):
    """Wrap an encoder factory so that its input is divided by ``data_std`` first.

    :param encoder_creator: callable ``(in_dim, out_dim)`` returning a module.
    :param data_std: value the inputs are divided by (the mean is taken as 0).
    :return: callable ``(in_dim, out_dim)`` returning
        ``nn.Sequential(Normalize(0., data_std), encoder)``.
    """
    def create(in_dim, out_dim):
        scaler = Normalize(0., data_std)
        return nn.Sequential(scaler, encoder_creator(in_dim, out_dim))

    return create
|
102 |
-
|
103 |
-
|
104 |
-
class ZNormalize(nn.Module):
    """Standardize the input along its last dimension."""

    def forward(self, x):
        """Subtract the last-dimension mean and divide by the last-dimension std."""
        mu = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        return (x - mu) / sigma
|
107 |
-
|
108 |
-
|
109 |
-
class AppendEmbeddingEncoder(nn.Module):
    """Wrap an encoder and optionally append a learned embedding to its output.

    The last entry of the input along dimension 0 is a flag rather than data:
    all ones requests that ``self.emb`` is appended as one extra entry along
    dimension 0 of the encoded output, all zeros requests that nothing is
    appended. The flag entry is removed before the wrapped encoder is called.
    """

    def __init__(self, base_encoder, num_features, emsize):
        super().__init__()
        self.num_features = num_features
        self.base_encoder = base_encoder
        self.emb = nn.Parameter(torch.zeros(emsize))

    def forward(self, x):
        """Encode ``x[:-1]`` and append the embedding if the flag entry is all ones."""
        flag = x[-1]
        append_embedding = bool((flag == 1.).all())
        if not append_embedding:
            assert (flag == 0.).all(), "You need to specify as last position whether to append embedding. " \
                                       "If you don't want this behavior, please use the wrapped encoder instead."

        encoded_x = self.base_encoder(x[:-1])
        if not append_embedding:
            return encoded_x

        # One copy of the embedding for every entry along dimension 1.
        extra = self.emb[None, None, :].repeat(1, encoded_x.shape[1], 1)
        return torch.cat([encoded_x, extra], 0)
|
128 |
-
|
129 |
-
def get_append_embedding_encoder(encoder_creator):
    """Wrap an encoder factory so its encoders are put inside ``AppendEmbeddingEncoder``.

    :param encoder_creator: callable ``(num_features, emsize)`` returning a module.
    :return: callable ``(num_features, emsize)`` returning the wrapped encoder.
    """
    def create(num_features, emsize):
        base_encoder = encoder_creator(num_features, emsize)
        return AppendEmbeddingEncoder(base_encoder, num_features, emsize)

    return create
|
131 |
-
|
132 |
-
|
133 |
-
class VariableNumFeaturesEncoder(nn.Module):
    """Adapt inputs with fewer features to an encoder expecting ``num_features``.

    The input is multiplied by ``num_features / <features present>`` and then
    padded with zeros along the last dimension up to ``num_features`` columns.
    """

    def __init__(self, base_encoder, num_features):
        super().__init__()
        self.base_encoder = base_encoder
        self.num_features = num_features

    def forward(self, x):
        """Rescale, zero-pad and encode ``x``."""
        present = x.shape[-1]
        scaled = x * (self.num_features/present)
        padding = torch.zeros(*scaled.shape[:-1], self.num_features - present, device=scaled.device)
        return self.base_encoder(torch.cat((scaled, padding), -1))
|
143 |
-
|
144 |
-
|
145 |
-
def get_variable_num_features_encoder(encoder_creator):
    """Wrap an encoder factory so its encoders are put inside ``VariableNumFeaturesEncoder``.

    :param encoder_creator: callable ``(num_features, emsize)`` returning a module.
    :return: callable ``(num_features, emsize)`` returning the wrapped encoder.
    """
    def create(num_features, emsize):
        base_encoder = encoder_creator(num_features, emsize)
        return VariableNumFeaturesEncoder(base_encoder, num_features)

    return create
|
147 |
-
|
148 |
-
class NoMeanEncoder(nn.Module):
    """
    Encoder wrapper that removes the mean over dimension 0 before encoding.

    This can be useful for any prior that is translation invariant in x or y.
    A standard GP for example is translation invariant in x.
    That is, GP(x_test+const,x_train+const,y_train) = GP(x_test,x_train,y_train).
    """

    def __init__(self, base_encoder):
        super().__init__()
        self.base_encoder = base_encoder

    def forward(self, x):
        """Encode ``x`` after subtracting its mean along dimension 0."""
        centered = x - x.mean(0, keepdim=True)
        return self.base_encoder(centered)
|
160 |
-
|
161 |
-
|
162 |
-
def get_no_mean_encoder(encoder_creator):
    """Wrap an encoder factory so its encoders are put inside ``NoMeanEncoder``.

    :param encoder_creator: callable ``(num_features, emsize)`` returning a module.
    :return: callable ``(num_features, emsize)`` returning the wrapped encoder.
    """
    def create(num_features, emsize):
        return NoMeanEncoder(encoder_creator(num_features, emsize))

    return create
|
164 |
-
|
165 |
-
# Plain linear encoder factory. NOTE: a class named ``Linear`` is defined
# further down in this module and rebinds this name once the module is loaded.
Linear = nn.Linear


def MLP(num_features, emsize):
    """Build a two-layer ReLU MLP: ``num_features + 1 -> 2 * emsize -> emsize``."""
    hidden = emsize*2
    # NOTE(review): the first layer takes num_features + 1 inputs; presumably
    # callers provide one extra input column -- confirm against the call sites.
    return nn.Sequential(
        nn.Linear(num_features+1, hidden),
        nn.ReLU(),
        nn.Linear(hidden, emsize),
    )
|
169 |
-
|
170 |
-
class NanHandlingEncoder(nn.Module):
    """Linear encoder that cleans non-finite inputs before projecting them.

    The input is always passed through ``torch.nan_to_num(x, nan=0.0)``.
    With ``keep_nans`` an indicator tensor (-1 where the input is NaN, 1 where
    it is +inf, 2 where it is -inf, 0 elsewhere) is run through
    ``normalize_data`` and concatenated to the cleaned input along the last
    dimension, which is why the linear layer then takes ``2 * num_features``
    inputs.
    """

    def __init__(self, num_features, emsize, keep_nans=True):
        super().__init__()
        self.num_features = 2 * num_features if keep_nans else num_features
        self.emsize = emsize
        self.keep_nans = keep_nans
        self.layer = nn.Linear(self.num_features, self.emsize)

    def forward(self, x):
        """Clean ``x`` (optionally adding the indicator columns) and project it."""
        cleaned = torch.nan_to_num(x, nan=0.0)
        if not self.keep_nans:
            return self.layer(cleaned)

        is_inf = torch.isinf(x)
        sign = torch.sign(x)
        indicator = (torch.isnan(x) * -1
                     + torch.logical_and(is_inf, sign == 1) * 1
                     + torch.logical_and(is_inf, sign == -1) * 2)
        return self.layer(torch.cat([cleaned, normalize_data(indicator)], -1))
|
187 |
-
|
188 |
-
|
189 |
-
class Linear(nn.Linear):
    """``nn.Linear`` encoder that can replace NaN inputs by zero.

    :param num_features: input width of the linear layer.
    :param emsize: output width of the linear layer.
    :param replace_nan_by_zero: if true, NaNs in the input are set to 0.0
        before the projection.
    """

    def __init__(self, num_features, emsize, replace_nan_by_zero=False):
        super().__init__(num_features, emsize)
        self.num_features = num_features
        self.emsize = emsize
        self.replace_nan_by_zero = replace_nan_by_zero

    def forward(self, x):
        """Apply the linear projection, zeroing NaNs first when configured."""
        inputs = torch.nan_to_num(x, nan=0.0) if self.replace_nan_by_zero else x
        return super().forward(inputs)

    def __setstate__(self, state):
        super().__setstate__(state)
        # State restored without the flag gets ``True`` (note that this
        # differs from the constructor default, which is ``False``).
        if 'replace_nan_by_zero' not in self.__dict__:
            self.__dict__['replace_nan_by_zero'] = True
|
204 |
-
|
205 |
-
|
206 |
-
class Conv(nn.Module):
    """Encode a flattened square image with a small stack of 3x3 convolutions.

    The last input dimension must hold ``size * size`` values. It is reshaped
    to a single-channel ``size x size`` image, passed through up to five
    unpadded 3x3 convolutions with ReLU (stopping once the image side drops
    below 4), average-pooled to one value per channel and projected linearly
    to ``emsize``.
    """

    def __init__(self, input_size, emsize):
        super().__init__()
        layers = []
        for i in range(5):
            in_channels = 1 if i == 0 else 64
            layers.append(nn.Conv2d(in_channels, 64, 3))
        self.convs = torch.nn.ModuleList(layers)
        self.linear = nn.Linear(64,emsize)

    def forward(self, x):
        """Run the convolution stack on ``x`` and return the projected features."""
        flat = x.shape[-1]
        size = math.isqrt(flat)
        assert size*size == flat
        # NOTE(review): the reshape keeps every leading dimension, so the
        # Conv2d layers presumably expect ``x`` to have a single leading
        # (batch) dimension -- confirm against callers.
        x = x.reshape(*x.shape[:-1], 1, size, size)
        for conv in self.convs:
            if x.shape[-1] < 4:
                break
            x = torch.relu(conv(x))
        pooled = nn.AdaptiveAvgPool2d((1,1))(x)
        return self.linear(pooled.squeeze(-1).squeeze(-1))
|
223 |
-
|
224 |
-
|
225 |
-
class CanEmb(nn.Embedding):
    """Embedding whose output width is split evenly across the input features.

    Each feature is embedded into ``embedding_dim // num_features`` values and
    the per-feature embeddings are laid out side by side in the last dimension.
    ``embedding_dim`` must be divisible by ``num_features``.
    """

    def __init__(self, num_features, num_embeddings: int, embedding_dim: int, *args, **kwargs):
        assert embedding_dim % num_features == 0
        per_feature_dim = embedding_dim // num_features
        super().__init__(num_embeddings, per_feature_dim, *args, **kwargs)

    def forward(self, x):
        """Embed whole-number inputs and merge the last two output dimensions."""
        indices = x.long()
        assert (indices == x).all(), "CanEmb only works with tensors of whole numbers"
        embedded = super().forward(indices)
        leading = embedded.shape[:-2]
        return embedded.view(*leading, -1)
|
236 |
-
|
237 |
-
|
238 |
-
def get_Canonical(num_classes):
    """Return an encoder factory producing ``CanEmb`` modules with ``num_classes`` embeddings.

    :param num_classes: number of rows of the embedding table.
    :return: callable ``(num_features, emsize)`` returning a ``CanEmb``.
    """
    def create(num_features, emsize):
        return CanEmb(num_features, num_classes, emsize)

    return create
|
240 |
-
|
241 |
-
|
242 |
-
def get_Embedding(num_embs_per_feature=100):
    """Return an encoder factory producing ``EmbeddingEncoder`` modules.

    :param num_embs_per_feature: number of bins (embeddings) used per feature.
    :return: callable ``(num_features, emsize)`` returning an ``EmbeddingEncoder``.
    """
    def create(num_features, emsize):
        return EmbeddingEncoder(num_features, emsize, num_embs=num_embs_per_feature)

    return create
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|