tttc3 commited on
Commit
bd90cfc
·
1 Parent(s): 3ef5500

Added pickle support

Browse files
Files changed (3) hide show
  1. pysr/export_numpy.py +4 -1
  2. pysr/sr.py +36 -0
  3. test/test.py +4 -4
pysr/export_numpy.py CHANGED
@@ -13,7 +13,6 @@ class CallableEquation:
13
  self._sympy_symbols = sympy_symbols
14
  self._selection = selection
15
  self._variable_names = variable_names
16
- self._lambda = lambdify(sympy_symbols, eqn)
17
 
18
  def __repr__(self):
19
  return f"PySRFunction(X=>{self._sympy})"
@@ -35,3 +34,7 @@ class CallableEquation:
35
  )
36
  X = X[:, self._selection]
37
  return self._lambda(*X.T) * np.ones(expected_shape)
 
 
 
 
 
13
  self._sympy_symbols = sympy_symbols
14
  self._selection = selection
15
  self._variable_names = variable_names
 
16
 
17
  def __repr__(self):
18
  return f"PySRFunction(X=>{self._sympy})"
 
34
  )
35
  X = X[:, self._selection]
36
  return self._lambda(*X.T) * np.ones(expected_shape)
37
+
38
+ @property
39
+ def _lambda(self):
40
+ return lambdify(self._sympy_symbols, self._sympy)
pysr/sr.py CHANGED
@@ -816,6 +816,42 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
816
  output += "]"
817
  return output
818
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
819
  @property
820
  def equations(self): # pragma: no cover
821
  warnings.warn(
 
816
  output += "]"
817
  return output
818
 
819
+ def __getstate__(self):
820
+ """
821
+ Handles pickle serialization for PySRRegressor.
822
+
823
+ The Scikit-learn standard requires estimators to be serializable via
824
+ `pickle.dumps()`. However, `PyCall.jlwrap` does not support pickle
825
+ serialization.
826
+
827
+ Thus, for `PySRRegressor` to support pickle serialization, the
828
+ `raw_julia_state_` attribute must be hidden from pickle. This will
829
+ prevent the `warm_start` of any model that is loaded via `pickle.loads()`,
830
+ but does allow all other attributes of a fitted `PySRRegressor` estimator
831
+ to be serialized. Note: Jax and Torch format equations are also removed
832
+ from the pickled instance.
833
+ """
834
+ warnings.warn(
835
+ "raw_julia_state_ cannot be pickled and will be removed from the "
836
+ "serialized instance. This will prevent a `warm_start` fit of any "
837
+ "model that is deserialized via `pickle.loads()`."
838
+ )
839
+ state = self.__dict__
840
+ pickled_state = {
841
+ key: None if key == "raw_julia_state_" else value
842
+ for key, value in state.items()
843
+ }
844
+ if "equations_" in pickled_state:
845
+ pickled_state["output_torch_format"] = False
846
+ pickled_state["output_jax_format"] = False
847
+ pickled_columns = ~pickled_state["equations_"].columns.isin(
848
+ ["jax_format", "torch_format"]
849
+ )
850
+ pickled_state["equations_"] = (
851
+ pickled_state["equations_"].loc[:, pickled_columns].copy()
852
+ )
853
+ return pickled_state
854
+
855
  @property
856
  def equations(self): # pragma: no cover
857
  warnings.warn(
test/test.py CHANGED
@@ -348,18 +348,18 @@ class TestMiscellaneous(unittest.TestCase):
348
  max_evals=10000, verbosity=0, progress=False
349
  ) # Return early.
350
  check_generator = check_estimator(model, generate_only=True)
 
351
  for (_, check) in check_generator:
352
- if "pickle" in check.func.__name__:
353
- # Skip pickling tests.
354
- continue
355
-
356
  try:
357
  with warnings.catch_warnings():
358
  warnings.simplefilter("ignore")
359
  check(model)
360
  print("Passed", check.func.__name__)
361
  except Exception as e:
 
362
  print("Failed", check.func.__name__, "with:")
363
  # Add a leading tab to error message, which
364
  # might be multi-line:
365
  print("\n".join([(" " * 4) + row for row in str(e).split("\n")]))
 
 
 
348
  max_evals=10000, verbosity=0, progress=False
349
  ) # Return early.
350
  check_generator = check_estimator(model, generate_only=True)
351
+ exception_messages = []
352
  for (_, check) in check_generator:
 
 
 
 
353
  try:
354
  with warnings.catch_warnings():
355
  warnings.simplefilter("ignore")
356
  check(model)
357
  print("Passed", check.func.__name__)
358
  except Exception as e:
359
+ exception_messages.append(f"{check.func.__name__}: {e}\n")
360
  print("Failed", check.func.__name__, "with:")
361
  # Add a leading tab to error message, which
362
  # might be multi-line:
363
  print("\n".join([(" " * 4) + row for row in str(e).split("\n")]))
364
+ # If any checks failed don't let the test pass.
365
+ self.assertEqual([], exception_messages)