Spaces:
Running
Running
MilesCranmer
committed on
Commit
•
d72c643
1
Parent(s):
7a42396
Save raw bytes so the model can warm-restart in a new Python session
Browse files
- pysr/julia_helpers.py +10 -2
- pysr/sr.py +34 -22
pysr/julia_helpers.py
CHANGED
@@ -22,8 +22,7 @@ import juliapkg
|
|
22 |
from juliacall import Main as jl
|
23 |
from juliacall import convert as jl_convert
|
24 |
|
25 |
-
jl.seval("using
|
26 |
-
PythonCall = jl.PythonCall
|
27 |
|
28 |
juliainfo = None
|
29 |
julia_initialized = False
|
@@ -63,3 +62,12 @@ def jl_array(x):
|
|
63 |
if x is None:
|
64 |
return None
|
65 |
return jl_convert(jl.Array, x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
from juliacall import Main as jl
|
23 |
from juliacall import convert as jl_convert
|
24 |
|
25 |
+
jl.seval("using Serialization: Serialization")
|
|
|
26 |
|
27 |
juliainfo = None
|
28 |
julia_initialized = False
|
|
|
62 |
if x is None:
|
63 |
return None
|
64 |
return jl_convert(jl.Array, x)
|
65 |
+
|
66 |
+
|
67 |
+
def jl_deserialize_s(s):
    """Rebuild a Julia object from its serialized byte stream.

    Parameters
    ----------
    s : array-like or None
        Bytes to deserialize; they are converted with `jl_array` before
        being written to Julia. Presumably the output of Julia's
        `Serialization.serialize` held as a byte array -- confirm against
        the caller.

    Returns
    -------
    The object returned by `Serialization.deserialize`, or `None` when
    `s` is `None`.
    """
    if s is not None:
        # Write the bytes into a fresh Julia IOBuffer, then rewind it so
        # that deserialization starts reading from the first byte.
        stream = jl.IOBuffer()
        jl.write(stream, jl_array(s))
        jl.seekstart(stream)
        return jl.Serialization.deserialize(stream)
    return None
|
pysr/sr.py
CHANGED
@@ -34,12 +34,11 @@ from .export_sympy import assert_valid_sympy_symbol, create_sympy_symbols, pysr2
|
|
34 |
from .export_torch import sympy2torch
|
35 |
from .feature_selection import run_feature_selection
|
36 |
from .julia_helpers import (
|
37 |
-
PythonCall,
|
38 |
_escape_filename,
|
39 |
_load_cluster_manager,
|
40 |
jl,
|
41 |
jl_array,
|
42 |
-
|
43 |
)
|
44 |
from .utils import (
|
45 |
_csv_filename_to_pkl_filename,
|
@@ -614,8 +613,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
614 |
Path to the temporary equations directory.
|
615 |
equation_file_ : str
|
616 |
Output equation file name produced by the julia backend.
|
617 |
-
|
618 |
-
The state for the julia SymbolicRegression.jl backend
|
619 |
equation_file_contents_ : list[pandas.DataFrame]
|
620 |
Contents of the equation file output by the Julia backend.
|
621 |
show_pickle_warnings_ : bool
|
@@ -1048,22 +1047,13 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1048 |
serialization.
|
1049 |
|
1050 |
Thus, for `PySRRegressor` to support pickle serialization, the
|
1051 |
-
`
|
1052 |
prevent the `warm_start` of any model that is loaded via `pickle.loads()`,
|
1053 |
but does allow all other attributes of a fitted `PySRRegressor` estimator
|
1054 |
to be serialized. Note: Jax and Torch format equations are also removed
|
1055 |
from the pickled instance.
|
1056 |
"""
|
1057 |
state = self.__dict__
|
1058 |
-
show_pickle_warning = not (
|
1059 |
-
"show_pickle_warnings_" in state and not state["show_pickle_warnings_"]
|
1060 |
-
)
|
1061 |
-
if "raw_julia_state_" in state and show_pickle_warning:
|
1062 |
-
warnings.warn(
|
1063 |
-
"raw_julia_state_ cannot be pickled and will be removed from the "
|
1064 |
-
"serialized instance. This will prevent a `warm_start` fit of any "
|
1065 |
-
"model that is deserialized via `pickle.load()`."
|
1066 |
-
)
|
1067 |
state_keys_containing_lambdas = ["extra_sympy_mappings", "extra_torch_mappings"]
|
1068 |
for state_key in state_keys_containing_lambdas:
|
1069 |
if state[state_key] is not None and show_pickle_warning:
|
@@ -1072,7 +1062,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1072 |
"serialized instance. When loading the model, please redefine "
|
1073 |
f"`{state_key}` at runtime."
|
1074 |
)
|
1075 |
-
state_keys_to_clear =
|
1076 |
pickled_state = {
|
1077 |
key: (None if key in state_keys_to_clear else value)
|
1078 |
for key, value in state.items()
|
@@ -1122,6 +1112,20 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1122 |
)
|
1123 |
return self.equations_
|
1124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1125 |
def get_best(self, index=None):
|
1126 |
"""
|
1127 |
Get best equation using `model_selection`.
|
@@ -1724,7 +1728,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1724 |
# Python's garbage collection is unaware of them.
|
1725 |
jl._equation_search_args = (jl_X, jl_y)
|
1726 |
jl._equation_search_kwargs = namedtuple(
|
1727 |
-
"
|
1728 |
(
|
1729 |
"weights",
|
1730 |
"niterations",
|
@@ -1754,18 +1758,26 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1754 |
options=options,
|
1755 |
numprocs=cprocs,
|
1756 |
parallelism=parallelism,
|
1757 |
-
saved_state=self.
|
1758 |
return_state=True,
|
1759 |
addprocs_function=cluster_manager,
|
1760 |
heap_size_hint_in_bytes=self.heap_size_hint_in_bytes,
|
1761 |
progress=progress and self.verbosity > 0 and len(y.shape) == 1,
|
1762 |
verbosity=int(self.verbosity),
|
1763 |
)
|
1764 |
-
|
1765 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1766 |
)
|
1767 |
jl._equation_search_args = None
|
1768 |
jl._equation_search_kwargs = None
|
|
|
1769 |
|
1770 |
# Set attributes
|
1771 |
self.equations_ = self.get_hof()
|
@@ -1829,10 +1841,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1829 |
Fitted estimator.
|
1830 |
"""
|
1831 |
# Init attributes that are not specified in BaseEstimator
|
1832 |
-
if self.warm_start and hasattr(self, "
|
1833 |
pass
|
1834 |
else:
|
1835 |
-
if hasattr(self, "
|
1836 |
warnings.warn(
|
1837 |
"The discovered expressions are being reset. "
|
1838 |
"Please set `warm_start=True` if you wish to continue "
|
@@ -1842,7 +1854,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1842 |
self.equations_ = None
|
1843 |
self.nout_ = 1
|
1844 |
self.selection_mask_ = None
|
1845 |
-
self.
|
1846 |
self.X_units_ = None
|
1847 |
self.y_units_ = None
|
1848 |
|
|
|
34 |
from .export_torch import sympy2torch
|
35 |
from .feature_selection import run_feature_selection
|
36 |
from .julia_helpers import (
|
|
|
37 |
_escape_filename,
|
38 |
_load_cluster_manager,
|
39 |
jl,
|
40 |
jl_array,
|
41 |
+
jl_deserialize_s,
|
42 |
)
|
43 |
from .utils import (
|
44 |
_csv_filename_to_pkl_filename,
|
|
|
613 |
Path to the temporary equations directory.
|
614 |
equation_file_ : str
|
615 |
Output equation file name produced by the julia backend.
|
616 |
+
raw_julia_state_stream_ : ndarray
|
617 |
+
The serialized state for the julia SymbolicRegression.jl backend (after fitting).
|
618 |
equation_file_contents_ : list[pandas.DataFrame]
|
619 |
Contents of the equation file output by the Julia backend.
|
620 |
show_pickle_warnings_ : bool
|
|
|
1047 |
serialization.
|
1048 |
|
1049 |
Thus, for `PySRRegressor` to support pickle serialization, the
|
1050 |
+
`raw_julia_state_stream_` attribute must be hidden from pickle. This will
|
1051 |
prevent the `warm_start` of any model that is loaded via `pickle.loads()`,
|
1052 |
but does allow all other attributes of a fitted `PySRRegressor` estimator
|
1053 |
to be serialized. Note: Jax and Torch format equations are also removed
|
1054 |
from the pickled instance.
|
1055 |
"""
|
1056 |
state = self.__dict__
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1057 |
state_keys_containing_lambdas = ["extra_sympy_mappings", "extra_torch_mappings"]
|
1058 |
for state_key in state_keys_containing_lambdas:
|
1059 |
if state[state_key] is not None and show_pickle_warning:
|
|
|
1062 |
"serialized instance. When loading the model, please redefine "
|
1063 |
f"`{state_key}` at runtime."
|
1064 |
)
|
1065 |
+
state_keys_to_clear = state_keys_containing_lambdas
|
1066 |
pickled_state = {
|
1067 |
key: (None if key in state_keys_to_clear else value)
|
1068 |
for key, value in state.items()
|
|
|
1112 |
)
|
1113 |
return self.equations_
|
1114 |
|
1115 |
+
@property
def julia_state(self):
    """Julia search state deserialized from `raw_julia_state_stream_`.

    The stored stream is deserialized anew on every access. Yields `None`
    when `raw_julia_state_stream_` is `None`.
    """
    serialized = self.raw_julia_state_stream_
    return jl_deserialize_s(serialized)
|
1118 |
+
|
1119 |
+
@property
def raw_julia_state_(self):
    """Deprecated alias for `julia_state`.

    Emits a `FutureWarning` on every access and then returns
    `self.julia_state`.
    """
    warnings.warn(
        "PySRRegressor.raw_julia_state_ is now deprecated. "
        "Please use PySRRegressor.julia_state instead, or `raw_julia_state_stream_` "
        "for the raw stream of bytes.",
        FutureWarning,
        # Attribute the warning to the code that accessed the property
        # rather than to this method, so users can locate the call site.
        stacklevel=2,
    )
    return self.julia_state
|
1128 |
+
|
1129 |
def get_best(self, index=None):
|
1130 |
"""
|
1131 |
Get best equation using `model_selection`.
|
|
|
1728 |
# Python's garbage collection is unaware of them.
|
1729 |
jl._equation_search_args = (jl_X, jl_y)
|
1730 |
jl._equation_search_kwargs = namedtuple(
|
1731 |
+
"equation_search_kwargs",
|
1732 |
(
|
1733 |
"weights",
|
1734 |
"niterations",
|
|
|
1758 |
options=options,
|
1759 |
numprocs=cprocs,
|
1760 |
parallelism=parallelism,
|
1761 |
+
saved_state=self.julia_state,
|
1762 |
return_state=True,
|
1763 |
addprocs_function=cluster_manager,
|
1764 |
heap_size_hint_in_bytes=self.heap_size_hint_in_bytes,
|
1765 |
progress=progress and self.verbosity > 0 and len(y.shape) == 1,
|
1766 |
verbosity=int(self.verbosity),
|
1767 |
)
|
1768 |
+
output_stream = jl.seval(
|
1769 |
+
"""
|
1770 |
+
let args = deepcopy(_equation_search_args), kwargs=deepcopy(_equation_search_kwargs)
|
1771 |
+
out = SymbolicRegression.equation_search(args...; kwargs...)
|
1772 |
+
buf = IOBuffer()
|
1773 |
+
Serialization.serialize(buf, out)
|
1774 |
+
take!(buf)
|
1775 |
+
end
|
1776 |
+
"""
|
1777 |
)
|
1778 |
jl._equation_search_args = None
|
1779 |
jl._equation_search_kwargs = None
|
1780 |
+
self.raw_julia_state_stream_ = np.array(output_stream).copy()
|
1781 |
|
1782 |
# Set attributes
|
1783 |
self.equations_ = self.get_hof()
|
|
|
1841 |
Fitted estimator.
|
1842 |
"""
|
1843 |
# Init attributes that are not specified in BaseEstimator
|
1844 |
+
if self.warm_start and hasattr(self, "raw_julia_state_stream_"):
|
1845 |
pass
|
1846 |
else:
|
1847 |
+
if hasattr(self, "raw_julia_state_stream_"):
|
1848 |
warnings.warn(
|
1849 |
"The discovered expressions are being reset. "
|
1850 |
"Please set `warm_start=True` if you wish to continue "
|
|
|
1854 |
self.equations_ = None
|
1855 |
self.nout_ = 1
|
1856 |
self.selection_mask_ = None
|
1857 |
+
self.raw_julia_state_stream_ = None
|
1858 |
self.X_units_ = None
|
1859 |
self.y_units_ = None
|
1860 |
|