MilesCranmer committed • Commit ba6e296 • 1 parent: ebce55f

Clean up and format example

examples/pysr_demo.ipynb  +87 -69
examples/pysr_demo.ipynb
CHANGED
@@ -33,9 +33,7 @@
     "id": "COndi88gbDgO"
    },
    "source": [
-    "**Run the following code
-    "\n",
-    "**(select all lines -> Option-/)**"
+    "**Run the following code to install Julia**"
    ]
   },
   {
@@ -95,7 +93,14 @@
    },
    "outputs": [],
    "source": [
-    "
+    "%pip install -Uq pysr pytorch_lightning"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The following step is not normally required, but colab's printing is non-standard and we need to manually set it up PyJulia:\n"
    ]
   },
   {
@@ -106,13 +111,13 @@
    },
    "outputs": [],
    "source": [
-    "# Required to get printing from Julia working in colab \n",
-    "# (you don't need to normally do this)\n",
     "from julia import Julia\n",
+    "\n",
     "julia = Julia(compiled_modules=False)\n",
     "from julia import Main\n",
     "from julia.tools import redirect_output_streams\n",
-    "
+    "\n",
+    "redirect_output_streams()\n"
    ]
   },
   {
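For readers skimming the diff, the reworked cell decodes from the JSON strings above to roughly the following Python (a sketch; it only matters on Colab, where Julia's printing has to be redirected by hand):

# Decoded from the notebook JSON above (Colab-only setup).
from julia import Julia

julia = Julia(compiled_modules=False)  # start PyJulia without precompiled modules
from julia import Main
from julia.tools import redirect_output_streams

redirect_output_streams()  # forward Julia's stdout/stderr into the notebook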
@@ -123,7 +128,7 @@
    "source": [
     "Let's install the backend of PySR, and all required libraries. We will also precompile them so they are faster at startup.\n",
     "\n",
-    "**This
+    "**(This may take some time)**"
    ]
   },
   {
@@ -139,7 +144,8 @@
    "outputs": [],
    "source": [
     "import pysr\n",
-    "
+    "\n",
+    "pysr.install()\n"
    ]
   },
   {
@@ -159,7 +165,7 @@
     "from torch.nn import functional as F\n",
     "from torch.utils.data import DataLoader, TensorDataset\n",
     "import pytorch_lightning as pl\n",
-    "from sklearn.model_selection import train_test_split"
+    "from sklearn.model_selection import train_test_split\n"
    ]
   },
   {
@@ -192,8 +198,8 @@
    "source": [
     "# Dataset\n",
     "np.random.seed(0)\n",
-    "X = 2*np.random.randn(100, 5)\n",
-    "y = 2.5382*np.cos(X[:, 3]) + X[:, 0]**2 - 2"
+    "X = 2 * np.random.randn(100, 5)\n",
+    "y = 2.5382 * np.cos(X[:, 3]) + X[:, 0] ** 2 - 2\n"
    ]
   },
   {
@@ -217,7 +223,7 @@
     "    populations=30,\n",
     "    procs=4,\n",
     "    model_selection=\"best\",\n",
-    ")"
+    ")\n"
    ]
   },
   {
@@ -246,12 +252,13 @@
    "source": [
     "# Learn equations\n",
     "model = PySRRegressor(\n",
-    "
-    "
-    "
-    "
+    "    niterations=30,\n",
+    "    binary_operators=[\"plus\", \"mult\"],\n",
+    "    unary_operators=[\"cos\", \"exp\", \"sin\"],\n",
+    "    **default_pysr_params\n",
+    ")\n",
     "\n",
-    "model.fit(X, y)"
+    "model.fit(X, y)\n"
    ]
   },
   {
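Decoded from the JSON above, the updated cell reads roughly as follows. `default_pysr_params` is defined in an earlier cell that this diff does not show; the dict below is an assumption based on the context lines at old lines 217-219 (populations, procs, model_selection):

import numpy as np
from pysr import PySRRegressor

np.random.seed(0)
X = 2 * np.random.randn(100, 5)
y = 2.5382 * np.cos(X[:, 3]) + X[:, 0] ** 2 - 2

default_pysr_params = dict(populations=30, procs=4, model_selection="best")  # assumed

model = PySRRegressor(
    niterations=30,
    binary_operators=["plus", "mult"],
    unary_operators=["cos", "exp", "sin"],
    **default_pysr_params,
)
model.fit(X, y)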
@@ -275,7 +282,7 @@
    },
    "outputs": [],
    "source": [
-    "model"
+    "model\n"
    ]
   },
   {
@@ -300,7 +307,7 @@
    },
    "outputs": [],
    "source": [
-    "model.sympy()"
+    "model.sympy()\n"
    ]
   },
   {
@@ -309,7 +316,7 @@
     "id": "EHIIPlmClltn"
    },
    "source": [
-    "We can also view the SymPy of any other expression in the list, using the index of it in `model.
+    "We can also view the SymPy of any other expression in the list, using the index of it in `model.equations_`."
    ]
   },
   {
@@ -325,7 +332,7 @@
    },
    "outputs": [],
    "source": [
-    "model.sympy(2)"
+    "model.sympy(2)\n"
    ]
   },
   {
@@ -359,7 +366,7 @@
    },
    "outputs": [],
    "source": [
-    "model.latex()"
+    "model.latex()\n"
    ]
   },
   {
@@ -389,7 +396,7 @@
     "ypredict_simpler = model.predict(X, 2)\n",
     "\n",
     "print(\"Default selection MSE:\", np.power(ypredict - y, 2).mean())\n",
-    "print(\"Manual selection MSE for index 2:\", np.power(ypredict_simpler - y, 2).mean())"
+    "print(\"Manual selection MSE for index 2:\", np.power(ypredict_simpler - y, 2).mean())\n"
    ]
   },
   {
@@ -423,7 +430,7 @@
    },
    "outputs": [],
    "source": [
-    "y = X[:, 0]**4 - 2"
+    "y = X[:, 0] ** 4 - 2\n"
    ]
   },
   {
@@ -451,12 +458,13 @@
    "outputs": [],
    "source": [
     "model = PySRRegressor(\n",
-    "
-    "
-    "
-    "
-    "
-    "
+    "    niterations=5,\n",
+    "    populations=40,\n",
+    "    binary_operators=[\"plus\", \"mult\"],\n",
+    "    unary_operators=[\"cos\", \"exp\", \"sin\", \"quart(x) = x^4\"],\n",
+    "    extra_sympy_mappings={\"quart\": lambda x: x**4},\n",
+    ")\n",
+    "model.fit(X, y)\n"
    ]
   },
   {
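The custom-operator cell decodes to roughly the sketch below: `quart(x) = x^4` is defined on the Julia side, and `extra_sympy_mappings` tells PySR how to translate it back into SymPy. The dataset lines repeat earlier cells so the sketch stands alone:

import numpy as np
from pysr import PySRRegressor

np.random.seed(0)
X = 2 * np.random.randn(100, 5)
y = X[:, 0] ** 4 - 2  # redefined target that needs a quartic

model = PySRRegressor(
    niterations=5,
    populations=40,
    binary_operators=["plus", "mult"],
    unary_operators=["cos", "exp", "sin", "quart(x) = x^4"],  # Julia-side operator
    extra_sympy_mappings={"quart": lambda x: x**4},           # SymPy translation
)
model.fit(X, y)
model.sympy()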
@@ -472,7 +480,7 @@
    },
    "outputs": [],
    "source": [
-    "model.sympy()"
+    "model.sympy()\n"
    ]
   },
   {
@@ -571,10 +579,10 @@
     "np.random.seed(0)\n",
     "N = 3000\n",
     "upper_sigma = 5\n",
-    "X = 2*np.random.rand(N, 5)\n",
-    "sigma = np.random.rand(N)*(5-0.1) + 0.1\n",
-    "eps = sigma*np.random.randn(N)\n",
-    "y = 5*np.cos(3.5*X[:, 0]) - 1.3 + eps"
+    "X = 2 * np.random.rand(N, 5)\n",
+    "sigma = np.random.rand(N) * (5 - 0.1) + 0.1\n",
+    "eps = sigma * np.random.randn(N)\n",
+    "y = 5 * np.cos(3.5 * X[:, 0]) - 1.3 + eps\n"
    ]
   },
   {
@@ -601,7 +609,7 @@
    "source": [
     "plt.scatter(X[:, 0], y, alpha=0.2)\n",
     "plt.xlabel(\"$x_0$\")\n",
-    "plt.ylabel(\"$y$\")"
+    "plt.ylabel(\"$y$\")\n"
    ]
   },
   {
@@ -621,7 +629,7 @@
    },
    "outputs": [],
    "source": [
-    "weights = 1/sigma**2"
+    "weights = 1 / sigma[:, None] ** 2\n"
    ]
   },
   {
@@ -636,7 +644,7 @@
    },
    "outputs": [],
    "source": [
-    "weights[:5]"
+    "weights[:5, 0]\n"
    ]
   },
   {
@@ -662,13 +670,13 @@
    "outputs": [],
    "source": [
     "model = PySRRegressor(\n",
-    "    loss
+    "    loss=\"myloss(x, y, w) = w * abs(x - y)\", # Custom loss function with weights.\n",
     "    niterations=20,\n",
     "    populations=20, # Use more populations\n",
     "    binary_operators=[\"plus\", \"mult\"],\n",
-    "    unary_operators=[\"cos\"]
+    "    unary_operators=[\"cos\"],\n",
     ")\n",
-    "model.fit(X, y, weights=weights)"
+    "model.fit(X, y, weights=weights)\n"
    ]
   },
   {
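Decoded, the weighted fit reads roughly as below. The `loss` string is Julia code run by the backend, with `w` the per-point weight; the noisy dataset and inverse-variance weights are repeated from the earlier cells in this diff so the sketch runs on its own:

import numpy as np
from pysr import PySRRegressor

np.random.seed(0)
N = 3000
X = 2 * np.random.rand(N, 5)
sigma = np.random.rand(N) * (5 - 0.1) + 0.1       # known per-point noise scale
y = 5 * np.cos(3.5 * X[:, 0]) - 1.3 + sigma * np.random.randn(N)
weights = 1 / sigma[:, None] ** 2                 # inverse-variance weights

model = PySRRegressor(
    loss="myloss(x, y, w) = w * abs(x - y)",      # custom weighted loss (Julia)
    niterations=20,
    populations=20,
    binary_operators=["plus", "mult"],
    unary_operators=["cos"],
)
model.fit(X, y, weights=weights)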
@@ -692,7 +700,7 @@
    },
    "outputs": [],
    "source": [
-    "model"
+    "model\n"
    ]
   },
   {
@@ -717,8 +725,10 @@
    },
    "outputs": [],
    "source": [
-    "best_idx = model.equations_.query(
-    "model.
+    "best_idx = model.equations_.query(\n",
+    "    f\"loss < {2 * model.equations_.loss.min()}\"\n",
+    ").score.idxmax()\n",
+    "model.sympy(best_idx)\n"
    ]
   },
   {
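The new selection cell, decoded, keeps only equations whose loss is within 2x of the best loss and then takes the one with the highest `score` (PySR's accuracy-versus-complexity column). It assumes `model` is the fitted regressor from the cell above:

# Assumes `model` is the fitted PySRRegressor from the weighted fit above.
best_idx = model.equations_.query(
    f"loss < {2 * model.equations_.loss.min()}"
).score.idxmax()
model.sympy(best_idx)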
@@ -753,7 +763,8 @@
    "outputs": [],
    "source": [
     "plt.scatter(X[:, 0], y, alpha=0.1)\n",
-    "
+    "y_prediction = model.predict(X, index=best_idx)\n",
+    "plt.scatter(X[:, 0], y_prediction)\n"
    ]
   },
   {
@@ -800,10 +811,10 @@
     "N = 100000\n",
     "Nt = 100\n",
     "X = 6 * np.random.rand(N, Nt, 5) - 3\n",
-    "y_i = X[..., 0]**2 + 6*np.cos(2*X[..., 2])\n",
-    "y = np.sum(y_i, axis=1)/y_i.shape[1]\n",
+    "y_i = X[..., 0] ** 2 + 6 * np.cos(2 * X[..., 2])\n",
+    "y = np.sum(y_i, axis=1) / y_i.shape[1]\n",
     "z = y**2\n",
-    "X.shape, y.shape"
+    "X.shape, y.shape\n"
    ]
   },
   {
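The dataset cell for the neural-network section decodes to the sketch below: each sample is a set of Nt five-vectors, a latent per-element quantity y_i is averaged over the set, and only z = y^2 is observed by the network:

import numpy as np

N = 100_000
Nt = 100
X = 6 * np.random.rand(N, Nt, 5) - 3
y_i = X[..., 0] ** 2 + 6 * np.cos(2 * X[..., 2])
y = np.sum(y_i, axis=1) / y_i.shape[1]   # average over the set dimension
z = y**2                                 # the network only ever sees z
print(X.shape, y.shape)                  # (100000, 100, 5), (100000,)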
@@ -843,6 +854,7 @@
     "hidden = 128\n",
     "total_steps = 50000\n",
     "\n",
+    "\n",
     "def mlp(size_in, size_out, act=nn.ReLU):\n",
     "    return nn.Sequential(\n",
     "        nn.Linear(size_in, hidden),\n",
@@ -851,13 +863,14 @@
     "        act(),\n",
     "        nn.Linear(hidden, hidden),\n",
     "        act(),\n",
-    "        nn.Linear(hidden, size_out)
+    "        nn.Linear(hidden, size_out),\n",
+    "    )\n",
     "\n",
     "\n",
     "class SumNet(pl.LightningModule):\n",
     "    def __init__(self):\n",
     "        super().__init__()\n",
-    "
+    "\n",
     "        ########################################################\n",
     "        # The same inductive bias as above!\n",
     "        self.g = mlp(5, 1)\n",
@@ -865,11 +878,12 @@
     "\n",
     "    def forward(self, x):\n",
     "        y_i = self.g(x)[:, :, 0]\n",
-    "        y = torch.sum(y_i, dim=1, keepdim=True)/y_i.shape[1]\n",
+    "        y = torch.sum(y_i, dim=1, keepdim=True) / y_i.shape[1]\n",
     "        z = self.f(y)\n",
     "        return z[:, 0]\n",
+    "\n",
     "    ########################################################\n",
-    "
+    "\n",
     "    # PyTorch Lightning bookkeeping:\n",
     "    def training_step(self, batch, batch_idx):\n",
     "        x, z = batch\n",
@@ -882,14 +896,18 @@
     "\n",
     "    def configure_optimizers(self):\n",
     "        self.trainer.reset_train_dataloader()\n",
-    "        # self.train_dataloader.loaders # access it here.\n",
     "\n",
     "        optimizer = torch.optim.Adam(self.parameters(), lr=self.max_lr)\n",
-    "        scheduler = {
-    "
-    "
-    "
-    "
+    "        scheduler = {\n",
+    "            \"scheduler\": torch.optim.lr_scheduler.OneCycleLR(\n",
+    "                optimizer,\n",
+    "                max_lr=self.max_lr,\n",
+    "                total_steps=self.total_steps,\n",
+    "                final_div_factor=1e4,\n",
+    "            ),\n",
+    "            \"interval\": \"step\",\n",
+    "        }\n",
+    "        return [optimizer], [scheduler]\n"
    ]
   },
   {
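The reformatted `configure_optimizers` pairs Adam with a OneCycleLR schedule stepped every batch. A standalone sketch of the same pattern, with `max_lr` and `total_steps` taken from the cell and a stand-in module in place of SumNet:

import torch
from torch import nn

net = nn.Linear(5, 1)              # stand-in for the SumNet parameters
max_lr, total_steps = 1e-2, 50_000

optimizer = torch.optim.Adam(net.parameters(), lr=max_lr)
scheduler = {
    "scheduler": torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=max_lr,
        total_steps=total_steps,
        final_div_factor=1e4,
    ),
    "interval": "step",            # Lightning steps the schedule every batch
}
# In the notebook, configure_optimizers returns: [optimizer], [scheduler]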
@@ -924,7 +942,7 @@
     "train_set = TensorDataset(X_train, z_train)\n",
     "train = DataLoader(train_set, batch_size=128, num_workers=2)\n",
     "test_set = TensorDataset(X_test, z_test)\n",
-    "test = DataLoader(test_set, batch_size=256, num_workers=2)"
+    "test = DataLoader(test_set, batch_size=256, num_workers=2)\n"
    ]
   },
   {
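A sketch of the data pipeline around this cell. The exact `train_test_split` call lives in an earlier cell not shown in the diff, so its arguments here are an assumption; the TensorDataset and DataLoader lines mirror the cell above:

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset

# Small stand-in dataset with the same (N, Nt, 5) structure as the notebook's.
X = 6 * np.random.rand(1000, 100, 5) - 3
z = np.mean(X[..., 0] ** 2 + 6 * np.cos(2 * X[..., 2]), axis=1) ** 2

X_train, X_test, z_train, z_test = train_test_split(X, z, random_state=0)  # assumed arguments
X_train, X_test, z_train, z_test = (
    torch.tensor(a, dtype=torch.float32) for a in (X_train, X_test, z_train, z_test)
)

train_set = TensorDataset(X_train, z_train)
train = DataLoader(train_set, batch_size=128, num_workers=2)
test_set = TensorDataset(X_test, z_test)
test = DataLoader(test_set, batch_size=256, num_workers=2)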
@@ -960,7 +978,7 @@
     "pl.seed_everything(0)\n",
     "model = SumNet()\n",
     "model.total_steps = total_steps\n",
-    "model.max_lr = 1e-2"
+    "model.max_lr = 1e-2\n"
    ]
   },
   {
@@ -984,7 +1002,7 @@
    },
    "outputs": [],
    "source": [
-    "trainer = pl.Trainer(max_steps=total_steps, gpus=1, benchmark=True)"
+    "trainer = pl.Trainer(max_steps=total_steps, gpus=1, benchmark=True)\n"
    ]
   },
   {
@@ -1033,7 +1051,7 @@
    },
    "outputs": [],
    "source": [
-    "trainer.fit(model, train_dataloaders=train, val_dataloaders=test)"
+    "trainer.fit(model, train_dataloaders=train, val_dataloaders=test)\n"
    ]
   },
   {
@@ -1064,10 +1082,10 @@
     "\n",
     "X_for_pysr = Xt[idx]\n",
     "y_i_for_pysr = model.g(X_for_pysr)[:, :, 0]\n",
-    "y_for_pysr = torch.sum(y_i_for_pysr, dim=1)/y_i_for_pysr.shape[1]\n",
-    "z_for_pysr = zt[idx]
+    "y_for_pysr = torch.sum(y_i_for_pysr, dim=1) / y_i_for_pysr.shape[1]\n",
+    "z_for_pysr = zt[idx] # Use true values.\n",
     "\n",
-    "X_for_pysr.shape, y_i_for_pysr.shape"
+    "X_for_pysr.shape, y_i_for_pysr.shape\n"
    ]
   },
   {
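The feature-extraction cell pulls the trained network's per-element outputs so PySR can fit them symbolically. `Xt`, `zt`, `idx` and the trained `model.g` come from cells not shown in this diff; the sketch below substitutes labelled stand-ins so it can run on its own:

import torch
from torch import nn

# Stand-ins for objects defined elsewhere in the notebook (assumptions):
Xt = 6 * torch.rand(5000, 100, 5) - 3                                # held-out inputs
zt = torch.rand(5000)                                                # held-out targets
g = nn.Sequential(nn.Linear(5, 128), nn.ReLU(), nn.Linear(128, 1))   # stands in for model.g

idx = torch.randperm(Xt.shape[0])[:1000]             # assumed: random subset of rows
X_for_pysr = Xt[idx]
y_i_for_pysr = g(X_for_pysr)[:, :, 0]                # learned per-element output
y_for_pysr = torch.sum(y_i_for_pysr, dim=1) / y_i_for_pysr.shape[1]
z_for_pysr = zt[idx]                                 # true values

print(X_for_pysr.shape, y_i_for_pysr.shape)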
@@ -1102,9 +1120,9 @@
     "model = PySRRegressor(\n",
     "    niterations=20,\n",
     "    binary_operators=[\"plus\", \"sub\", \"mult\"],\n",
-    "    unary_operators=[\"cos\", \"square\", \"neg\"]
+    "    unary_operators=[\"cos\", \"square\", \"neg\"],\n",
     ")\n",
-    "model.fit(X=tmpX[idx2], y=tmpy[idx2])"
+    "model.fit(X=tmpX[idx2], y=tmpy[idx2])\n"
    ]
   },
   {
@@ -1135,7 +1153,7 @@
    },
    "outputs": [],
    "source": [
-    "model"
+    "model\n"
    ]
   },
   {