MilesCranmer commited on
Commit
f06ee71
·
1 Parent(s): 43bc86a

Create function to setup equation file during fit

Browse files
Files changed (1) hide show
  1. pysr/sr.py +21 -11
pysr/sr.py CHANGED
@@ -894,14 +894,6 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
894
  if self.maxdepth is None:
895
  self.maxdepth = self.maxsize
896
 
897
- # Cast tempdir string as a Path object
898
- self.tempdir_ = Path(tempfile.mkdtemp(dir=self.tempdir))
899
- if self.temp_equation_file:
900
- self.equation_file = self.tempdir_ / "hall_of_fame.csv"
901
- elif self.equation_file is None:
902
- date_time = datetime.now().strftime("%Y-%m-%d_%H%M%S.%f")[:-3]
903
- self.equation_file = "hall_of_fame_" + date_time + ".csv"
904
-
905
  # Handle type conversion for instance parameters:
906
  if isinstance(self.binary_operators, str):
907
  self.binary_operators = [self.binary_operators]
@@ -967,6 +959,22 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
967
 
968
  return self
969
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
970
  def _validate_fit_params(self, X, y, Xresampled, variable_names):
971
  """
972
  Validates the parameters passed to the :term`fit` method.
@@ -1267,7 +1275,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1267
  nested_constraints=self.nested_constraints,
1268
  loss=Main.custom_loss,
1269
  maxsize=int(self.maxsize),
1270
- hofFile=_escape_filename(self.equation_file),
1271
  npopulations=int(self.populations),
1272
  batching=self.batching,
1273
  batchSize=int(min([self.batch_size, len(X)]) if self.batching else len(X)),
@@ -1399,6 +1407,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1399
  self.selection_mask_ = None
1400
  self.raw_julia_state_ = None
1401
 
 
 
1402
  # Parameter input validation (for parameters defined in __init__)
1403
  X, y, Xresampled, variable_names = self._validate_fit_params(
1404
  X, y, Xresampled, variable_names
@@ -1654,7 +1664,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1654
  all_outputs = []
1655
  for i in range(1, self.nout_ + 1):
1656
  df = pd.read_csv(
1657
- str(self.equation_file) + f".out{i}" + ".bkup",
1658
  sep="|",
1659
  )
1660
  # Rename Complexity column to complexity:
@@ -1669,7 +1679,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1669
 
1670
  all_outputs.append(df)
1671
  else:
1672
- all_outputs = [pd.read_csv(str(self.equation_file) + ".bkup", sep="|")]
1673
  all_outputs[-1].rename(
1674
  columns={
1675
  "Complexity": "complexity",
 
894
  if self.maxdepth is None:
895
  self.maxdepth = self.maxsize
896
 
 
 
 
 
 
 
 
 
897
  # Handle type conversion for instance parameters:
898
  if isinstance(self.binary_operators, str):
899
  self.binary_operators = [self.binary_operators]
 
959
 
960
  return self
961
 
962
+ def _setup_equation_file(self):
963
+ """
964
+ Sets the full pathname of the equation file, using :param`tempdir` and
965
+ :param`equation_file`.
966
+ """
967
+ # Cast tempdir string as a Path object
968
+ self.tempdir_ = Path(tempfile.mkdtemp(dir=self.tempdir))
969
+ if self.temp_equation_file:
970
+ self.equation_file_ = self.tempdir_ / "hall_of_fame.csv"
971
+ elif self.equation_file is None:
972
+ date_time = datetime.now().strftime("%Y-%m-%d_%H%M%S.%f")[:-3]
973
+ self.equation_file_ = "hall_of_fame_" + date_time + ".csv"
974
+ else:
975
+ self.equation_file_ = self.equation_file
976
+
977
+
978
  def _validate_fit_params(self, X, y, Xresampled, variable_names):
979
  """
980
  Validates the parameters passed to the :term`fit` method.
 
1275
  nested_constraints=self.nested_constraints,
1276
  loss=Main.custom_loss,
1277
  maxsize=int(self.maxsize),
1278
+ hofFile=_escape_filename(self.equation_file_),
1279
  npopulations=int(self.populations),
1280
  batching=self.batching,
1281
  batchSize=int(min([self.batch_size, len(X)]) if self.batching else len(X)),
 
1407
  self.selection_mask_ = None
1408
  self.raw_julia_state_ = None
1409
 
1410
+ self._setup_equation_file()
1411
+
1412
  # Parameter input validation (for parameters defined in __init__)
1413
  X, y, Xresampled, variable_names = self._validate_fit_params(
1414
  X, y, Xresampled, variable_names
 
1664
  all_outputs = []
1665
  for i in range(1, self.nout_ + 1):
1666
  df = pd.read_csv(
1667
+ str(self.equation_file_) + f".out{i}" + ".bkup",
1668
  sep="|",
1669
  )
1670
  # Rename Complexity column to complexity:
 
1679
 
1680
  all_outputs.append(df)
1681
  else:
1682
+ all_outputs = [pd.read_csv(str(self.equation_file_) + ".bkup", sep="|")]
1683
  all_outputs[-1].rename(
1684
  columns={
1685
  "Complexity": "complexity",