MilesCranmer commited on
Commit
eccee44
1 Parent(s): d0ea029

Allow user to specify whole loss function

Browse files
Files changed (2) hide show
  1. pysr/sr.py +29 -4
  2. pysr/version.py +1 -1
pysr/sr.py CHANGED
@@ -319,9 +319,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
319
  argument is constrained.
320
  Default is `None`.
321
  loss : str
322
- String of Julia code specifying the loss function. Can either
323
- be a loss from LossFunctions.jl, or your own loss written as a
324
- function. Examples of custom written losses include:
325
  `myloss(x, y) = abs(x-y)` for non-weighted, or
326
  `myloss(x, y, w) = w*abs(x-y)` for weighted.
327
  The included losses include:
@@ -334,6 +334,23 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
334
  `ModifiedHuberLoss()`, `L2MarginLoss()`, `ExpLoss()`,
335
  `SigmoidLoss()`, `DWDMarginLoss(q)`.
336
  Default is `"L2DistLoss()"`.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  complexity_of_operators : dict[str, float]
338
  If you would like to use a complexity other than 1 for an
339
  operator, specify the complexity here. For example,
@@ -675,7 +692,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
675
  timeout_in_seconds=None,
676
  constraints=None,
677
  nested_constraints=None,
678
- loss="L2DistLoss()",
 
679
  complexity_of_operators=None,
680
  complexity_of_constants=1,
681
  complexity_of_variables=1,
@@ -763,6 +781,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
763
  self.early_stop_condition = early_stop_condition
764
  # - Loss parameters
765
  self.loss = loss
 
766
  self.complexity_of_operators = complexity_of_operators
767
  self.complexity_of_constants = complexity_of_constants
768
  self.complexity_of_variables = complexity_of_variables
@@ -1217,6 +1236,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1217
  "to True and `procs` to 0 will result in non-deterministic searches. "
1218
  )
1219
 
 
 
 
1220
  # NotImplementedError - Values that could be supported at a later time
1221
  if self.optimizer_algorithm not in VALID_OPTIMIZER_ALGORITHMS:
1222
  raise NotImplementedError(
@@ -1546,6 +1568,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1546
  complexity_of_operators = Main.eval(complexity_of_operators_str)
1547
 
1548
  custom_loss = Main.eval(self.loss)
 
 
1549
  early_stop_condition = Main.eval(
1550
  str(self.early_stop_condition) if self.early_stop_condition else None
1551
  )
@@ -1574,6 +1598,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1574
  complexity_of_variables=self.complexity_of_variables,
1575
  nested_constraints=nested_constraints,
1576
  elementwise_loss=custom_loss,
 
1577
  maxsize=int(self.maxsize),
1578
  output_file=_escape_filename(self.equation_file_),
1579
  npopulations=int(self.populations),
 
319
  argument is constrained.
320
  Default is `None`.
321
  loss : str
322
+ String of Julia code specifying an elementwise loss function.
323
+ Can either be a loss from LossFunctions.jl, or your own loss
324
+ written as a function. Examples of custom written losses include:
325
  `myloss(x, y) = abs(x-y)` for non-weighted, or
326
  `myloss(x, y, w) = w*abs(x-y)` for weighted.
327
  The included losses include:
 
334
  `ModifiedHuberLoss()`, `L2MarginLoss()`, `ExpLoss()`,
335
  `SigmoidLoss()`, `DWDMarginLoss(q)`.
336
  Default is `"L2DistLoss()"`.
337
+ full_objective : str
338
+ Alternatively, you can specify the full objective function as
339
+ a snippet of Julia code, including any sort of custom evaluation
340
+ (including symbolic manipulations beforehand), and any sort
341
+ of loss function or regularizations. The default `full_objective`
342
+ used in SymbolicRegression.jl is roughly equal to:
343
+ ```julia
344
+ function eval_loss(tree, dataset::Dataset{T}, options) where T
345
+ prediction, flag = eval_tree_array(tree, dataset.X, options)
346
+ if !flag
347
+ return T(Inf)
348
+ end
349
+ sum((prediction .- dataset.y) .^ 2) / dataset.n
350
+ end
351
+ ```
352
+ where the example elementwise loss is mean-squared error.
353
+ Default is `None`.
354
  complexity_of_operators : dict[str, float]
355
  If you would like to use a complexity other than 1 for an
356
  operator, specify the complexity here. For example,
 
692
  timeout_in_seconds=None,
693
  constraints=None,
694
  nested_constraints=None,
695
+ loss=None,
696
+ full_objective=None,
697
  complexity_of_operators=None,
698
  complexity_of_constants=1,
699
  complexity_of_variables=1,
 
781
  self.early_stop_condition = early_stop_condition
782
  # - Loss parameters
783
  self.loss = loss
784
+ self.full_objective = full_objective
785
  self.complexity_of_operators = complexity_of_operators
786
  self.complexity_of_constants = complexity_of_constants
787
  self.complexity_of_variables = complexity_of_variables
 
1236
  "to True and `procs` to 0 will result in non-deterministic searches. "
1237
  )
1238
 
1239
+ if self.loss is not None and self.full_objective is not None:
1240
+ raise ValueError("You cannot set both `loss` and `objective`.")
1241
+
1242
  # NotImplementedError - Values that could be supported at a later time
1243
  if self.optimizer_algorithm not in VALID_OPTIMIZER_ALGORITHMS:
1244
  raise NotImplementedError(
 
1568
  complexity_of_operators = Main.eval(complexity_of_operators_str)
1569
 
1570
  custom_loss = Main.eval(self.loss)
1571
+ custom_full_objective = Main.eval(self.full_objective)
1572
+
1573
  early_stop_condition = Main.eval(
1574
  str(self.early_stop_condition) if self.early_stop_condition else None
1575
  )
 
1598
  complexity_of_variables=self.complexity_of_variables,
1599
  nested_constraints=nested_constraints,
1600
  elementwise_loss=custom_loss,
1601
+ loss_function=custom_full_objective,
1602
  maxsize=int(self.maxsize),
1603
  output_file=_escape_filename(self.equation_file_),
1604
  npopulations=int(self.populations),
pysr/version.py CHANGED
@@ -1,2 +1,2 @@
1
- __version__ = "0.11.17"
2
  __symbolic_regression_jl_version__ = "0.15.3"
 
1
+ __version__ = "0.11.18"
2
  __symbolic_regression_jl_version__ = "0.15.3"