MilesCranmer commited on
Commit
dde0ef7
1 Parent(s): 85371bb

Remove extra_sympy_mappings from pickle file

Browse files
Files changed (1) hide show
  1. pysr/sr.py +17 -2
pysr/sr.py CHANGED
@@ -562,6 +562,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
562
  equation_file_contents_ : list[pandas.DataFrame]
563
  Contents of the equation file output by the Julia backend.
564
 
 
 
 
565
  Notes
566
  -----
567
  Most default parameters have been tuned over several example equations,
@@ -873,14 +876,26 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
873
  from the pickled instance.
874
  """
875
  state = self.__dict__
876
- if "raw_julia_state_" in state:
 
 
 
877
  warnings.warn(
878
  "raw_julia_state_ cannot be pickled and will be removed from the "
879
  "serialized instance. This will prevent a `warm_start` fit of any "
880
  "model that is deserialized via `pickle.load()`."
881
  )
 
 
 
 
 
 
 
 
 
882
  pickled_state = {
883
- key: None if key == "raw_julia_state_" else value
884
  for key, value in state.items()
885
  }
886
  if ("equations_" in pickled_state) and (
 
562
  equation_file_contents_ : list[pandas.DataFrame]
563
  Contents of the equation file output by the Julia backend.
564
 
565
+ show_pickle_warnings_ : bool
566
+ Whether to show warnings about what attributes can be pickled.
567
+
568
  Notes
569
  -----
570
  Most default parameters have been tuned over several example equations,
 
876
  from the pickled instance.
877
  """
878
  state = self.__dict__
879
+ show_pickle_warning = not (
880
+ "show_pickle_warnings_" in state and not state["show_pickle_warnings_"]
881
+ )
882
+ if "raw_julia_state_" in state and show_pickle_warning:
883
  warnings.warn(
884
  "raw_julia_state_ cannot be pickled and will be removed from the "
885
  "serialized instance. This will prevent a `warm_start` fit of any "
886
  "model that is deserialized via `pickle.load()`."
887
  )
888
+ state_keys_containing_lambdas = ["extra_sympy_mappings", "extra_torch_mappings"]
889
+ for state_key in state_keys_containing_lambdas:
890
+ if state[state_key] is not None and show_pickle_warning:
891
+ warnings.warn(
892
+ f"`{state_key}` cannot be pickled and will be removed from the "
893
+ "serialized instance. When loading the model, please redefine "
894
+ f"`{state_key}` at runtime."
895
+ )
896
+ state_keys_to_clear = ["raw_julia_state_"] + state_keys_containing_lambdas
897
  pickled_state = {
898
+ key: (None if key in state_keys_to_clear else value)
899
  for key, value in state.items()
900
  }
901
  if ("equations_" in pickled_state) and (