Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
6d58816
1
Parent(s):
49b163d
Refactor backend loading
Browse files- pysr/julia_helpers.py +22 -0
- pysr/sr.py +7 -14
pysr/julia_helpers.py
CHANGED
@@ -4,6 +4,7 @@ import subprocess
|
|
4 |
import warnings
|
5 |
from pathlib import Path
|
6 |
import os
|
|
|
7 |
|
8 |
from .version import __version__, __symbolic_regression_jl_version__
|
9 |
|
@@ -230,3 +231,24 @@ def _version_assertion():
|
|
230 |
"PySR requires Julia 1.6.0 or greater. "
|
231 |
"Please update your Julia installation."
|
232 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import warnings
|
5 |
from pathlib import Path
|
6 |
import os
|
7 |
+
from julia.api import JuliaError
|
8 |
|
9 |
from .version import __version__, __symbolic_regression_jl_version__
|
10 |
|
|
|
231 |
"PySR requires Julia 1.6.0 or greater. "
|
232 |
"Please update your Julia installation."
|
233 |
)
|
234 |
+
|
235 |
+
|
236 |
+
def _load_cluster_manager(Main, cluster_manager):
|
237 |
+
Main.eval(f"import ClusterManagers: addprocs_{cluster_manager}")
|
238 |
+
return Main.eval(f"addprocs_{cluster_manager}")
|
239 |
+
|
240 |
+
|
241 |
+
def _update_julia_project(Main, julia_project, is_shared, io_arg):
|
242 |
+
try:
|
243 |
+
if is_shared:
|
244 |
+
_add_sr_to_julia_project(Main, io_arg)
|
245 |
+
Main.eval(f"Pkg.resolve({io_arg})")
|
246 |
+
except (JuliaError, RuntimeError) as e:
|
247 |
+
raise ImportError(_import_error_string(julia_project)) from e
|
248 |
+
|
249 |
+
|
250 |
+
def _load_backend(Main, julia_project):
|
251 |
+
try:
|
252 |
+
Main.eval("using SymbolicRegression")
|
253 |
+
except (JuliaError, RuntimeError) as e:
|
254 |
+
raise ImportError(_import_error_string(julia_project)) from e
|
pysr/sr.py
CHANGED
@@ -26,8 +26,9 @@ from .julia_helpers import (
|
|
26 |
_process_julia_project,
|
27 |
is_julia_version_greater_eq,
|
28 |
_escape_filename,
|
29 |
-
|
30 |
-
|
|
|
31 |
)
|
32 |
from .export_numpy import CallableEquation
|
33 |
from .export_latex import generate_single_table, generate_multiple_tables, to_latex
|
@@ -1453,8 +1454,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1453 |
Main = init_julia(self.julia_project)
|
1454 |
|
1455 |
if cluster_manager is not None:
|
1456 |
-
|
1457 |
-
cluster_manager = Main.eval(f"addprocs_{cluster_manager}")
|
1458 |
|
1459 |
if not already_ran:
|
1460 |
julia_project, is_shared = _process_julia_project(self.julia_project)
|
@@ -1470,16 +1470,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1470 |
from julia.api import JuliaError
|
1471 |
|
1472 |
if self.update:
|
1473 |
-
|
1474 |
-
|
1475 |
-
|
1476 |
-
Main.eval(f"Pkg.resolve({io_arg})")
|
1477 |
-
except (JuliaError, RuntimeError) as e:
|
1478 |
-
raise ImportError(_import_error_string(julia_project)) from e
|
1479 |
-
try:
|
1480 |
-
Main.eval("using SymbolicRegression")
|
1481 |
-
except (JuliaError, RuntimeError) as e:
|
1482 |
-
raise ImportError(_import_error_string(julia_project)) from e
|
1483 |
|
1484 |
Main.plus = Main.eval("(+)")
|
1485 |
Main.sub = Main.eval("(-)")
|
|
|
26 |
_process_julia_project,
|
27 |
is_julia_version_greater_eq,
|
28 |
_escape_filename,
|
29 |
+
_load_cluster_manager,
|
30 |
+
_update_julia_project,
|
31 |
+
_load_backend,
|
32 |
)
|
33 |
from .export_numpy import CallableEquation
|
34 |
from .export_latex import generate_single_table, generate_multiple_tables, to_latex
|
|
|
1454 |
Main = init_julia(self.julia_project)
|
1455 |
|
1456 |
if cluster_manager is not None:
|
1457 |
+
cluster_manager = _load_cluster_manager(cluster_manager)
|
|
|
1458 |
|
1459 |
if not already_ran:
|
1460 |
julia_project, is_shared = _process_julia_project(self.julia_project)
|
|
|
1470 |
from julia.api import JuliaError
|
1471 |
|
1472 |
if self.update:
|
1473 |
+
_update_julia_project(Main, julia_project, is_shared, io_arg)
|
1474 |
+
|
1475 |
+
_load_backend(Main, julia_project)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1476 |
|
1477 |
Main.plus = Main.eval("(+)")
|
1478 |
Main.sub = Main.eval("(-)")
|