MilesCranmer commited on
Commit
358f0ab
·
1 Parent(s): 40f498c

Remove additional changes to internal parameters

Browse files
Files changed (1) hide show
  1. pysr/sr.py +16 -14
pysr/sr.py CHANGED
@@ -80,7 +80,8 @@ def pysr(X, y, weights=None, **kwargs): # pragma: no cover
80
  return model.equations
81
 
82
 
83
- def _handle_constraints(binary_operators, unary_operators, constraints):
 
84
  for op in unary_operators:
85
  if op not in constraints:
86
  constraints[op] = -1
@@ -101,10 +102,13 @@ def _handle_constraints(binary_operators, unary_operators, constraints):
101
  constraints[op][1],
102
  constraints[op][0],
103
  )
 
104
 
105
 
106
- def _create_inline_operators(binary_operators, unary_operators):
107
  global Main
 
 
108
  for op_list in [binary_operators, unary_operators]:
109
  for i, op in enumerate(op_list):
110
  is_user_defined_operator = "(" in op
@@ -123,6 +127,7 @@ def _create_inline_operators(binary_operators, unary_operators):
123
  "Only alphanumeric characters, numbers, and underscores are allowed."
124
  )
125
  op_list[i] = function_name
 
126
 
127
 
128
  def _check_assertions(
@@ -1214,17 +1219,18 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1214
  Main.pow = Main.eval("(^)")
1215
  Main.div = Main.eval("(/)")
1216
 
1217
- _create_inline_operators(
 
1218
  binary_operators=self.binary_operators, unary_operators=self.unary_operators
1219
  )
1220
- _handle_constraints(
1221
- binary_operators=self.binary_operators,
1222
- unary_operators=self.unary_operators,
1223
  constraints=self.constraints,
1224
  )
1225
 
1226
- una_constraints = [self.constraints[op] for op in self.unary_operators]
1227
- bin_constraints = [self.constraints[op] for op in self.binary_operators]
1228
 
1229
  # Parse dict into Julia Dict for nested constraints::
1230
  if self.nested_constraints is not None:
@@ -1265,12 +1271,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1265
  # Call to Julia backend.
1266
  # See https://github.com/search?q=%22function+Options%22+repo%3AMilesCranmer%2FSymbolicRegression.jl+path%3A%2Fsrc%2F+filename%3AOptions.jl+language%3AJulia&type=Code
1267
  options = Main.Options(
1268
- binary_operators=Main.eval(
1269
- str(tuple(self.binary_operators)).replace("'", "")
1270
- ),
1271
- unary_operators=Main.eval(
1272
- str(tuple(self.unary_operators)).replace("'", "")
1273
- ),
1274
  bin_constraints=bin_constraints,
1275
  una_constraints=una_constraints,
1276
  complexity_of_operators=complexity_of_operators,
 
80
  return model.equations
81
 
82
 
83
+ def _process_constraints(binary_operators, unary_operators, constraints):
84
+ constraints = constraints.copy()
85
  for op in unary_operators:
86
  if op not in constraints:
87
  constraints[op] = -1
 
102
  constraints[op][1],
103
  constraints[op][0],
104
  )
105
+ return constraints
106
 
107
 
108
+ def _maybe_create_inline_operators(binary_operators, unary_operators):
109
  global Main
110
+ binary_operators = binary_operators.copy()
111
+ unary_operators = unary_operators.copy()
112
  for op_list in [binary_operators, unary_operators]:
113
  for i, op in enumerate(op_list):
114
  is_user_defined_operator = "(" in op
 
127
  "Only alphanumeric characters, numbers, and underscores are allowed."
128
  )
129
  op_list[i] = function_name
130
+ return binary_operators, unary_operators
131
 
132
 
133
  def _check_assertions(
 
1219
  Main.pow = Main.eval("(^)")
1220
  Main.div = Main.eval("(/)")
1221
 
1222
+ # TODO(mcranmer): These functions should be part of this class.
1223
+ binary_operators, unary_operators = _maybe_create_inline_operators(
1224
  binary_operators=self.binary_operators, unary_operators=self.unary_operators
1225
  )
1226
+ constraints = _process_constraints(
1227
+ binary_operators=binary_operators,
1228
+ unary_operators=unary_operators,
1229
  constraints=self.constraints,
1230
  )
1231
 
1232
+ una_constraints = [constraints[op] for op in unary_operators]
1233
+ bin_constraints = [constraints[op] for op in binary_operators]
1234
 
1235
  # Parse dict into Julia Dict for nested constraints::
1236
  if self.nested_constraints is not None:
 
1271
  # Call to Julia backend.
1272
  # See https://github.com/search?q=%22function+Options%22+repo%3AMilesCranmer%2FSymbolicRegression.jl+path%3A%2Fsrc%2F+filename%3AOptions.jl+language%3AJulia&type=Code
1273
  options = Main.Options(
1274
+ binary_operators=Main.eval(str(tuple(binary_operators)).replace("'", "")),
1275
+ unary_operators=Main.eval(str(tuple(unary_operators)).replace("'", "")),
 
 
 
 
1276
  bin_constraints=bin_constraints,
1277
  una_constraints=una_constraints,
1278
  complexity_of_operators=complexity_of_operators,