Spaces:
Running
Running
MilesCranmer
committed on
Commit
·
e68c63f
1
Parent(s):
a06de5e
Add feature for operator-level size constraints
Browse files- docs/options.md +14 -8
- julia/sr.jl +110 -45
- pysr/sr.py +43 -4
docs/options.md
CHANGED
@@ -14,7 +14,7 @@ may find useful include:
|
|
14 |
- `maxsize`, `maxdepth`
|
15 |
- `batching`, `batchSize`
|
16 |
- `variable_names` (or pandas input)
|
17 |
-
-
|
18 |
- LaTeX, SymPy, and callable equation output
|
19 |
|
20 |
These are described below
|
@@ -129,13 +129,19 @@ alphabetical characters and `_` are used in these names.
|
|
129 |
|
130 |
## Limiting pow complexity
|
131 |
|
132 |
-
One can limit the complexity of
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
|
140 |
## LaTeX, SymPy, callables
|
141 |
|
|
|
14 |
- `maxsize`, `maxdepth`
|
15 |
- `batching`, `batchSize`
|
16 |
- `variable_names` (or pandas input)
|
17 |
+
- Constraining operator complexity
|
18 |
- LaTeX, SymPy, and callable equation output
|
19 |
|
20 |
These are described below
|
|
|
129 |
|
130 |
## Limiting pow complexity
|
131 |
|
132 |
+
One can limit the complexity of specific operators with the `constraints` parameter.
|
133 |
+
There is a "maxsize" parameter to PySR, but there is also an operator-level
|
134 |
+
"constraints" parameter. One supplies a dict, like so:
|
135 |
+
|
136 |
+
```python
|
137 |
+
constraints={'pow': (-1, 1), 'mult': (3, 3), 'cos': 5}
|
138 |
+
```
|
139 |
+
|
140 |
+
This says that a power law x^y can have an expression of arbitrary complexity in x (-1 means no limit), but only complexity 1 (e.g., a single constant or variable) in y. So (x0 + 3)^5.5 is allowed, but 5.5^(x0 + 3) is not.
|
141 |
+
I find this helps a lot for getting more interpretable equations.
|
142 |
+
The other terms say that each multiplication can only have sub-expressions
|
143 |
+
of up to complexity 3 (e.g., 5.0 + x2) in each side, and cosine can only operate on
|
144 |
+
expressions of up to complexity 5 (e.g., 5.0 + x2*x3).
|
145 |
|
146 |
## LaTeX, SymPy, callables
|
147 |
|
julia/sr.jl
CHANGED
@@ -646,24 +646,46 @@ mutable struct PopMember
|
|
646 |
|
647 |
end
|
648 |
|
649 |
-
# Check if any
|
650 |
-
function
|
651 |
if tree.degree == 0
|
652 |
-
return
|
653 |
elseif tree.degree == 1
|
654 |
-
return
|
655 |
else
|
656 |
-
if
|
657 |
-
|
658 |
-
|
659 |
-
|
660 |
-
|
661 |
-
|
662 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
663 |
end
|
664 |
-
else
|
665 |
-
return 0 + deepPow(tree.l) + deepPow(tree.r)
|
666 |
end
|
|
|
|
|
|
|
667 |
end
|
668 |
end
|
669 |
|
@@ -671,61 +693,104 @@ end
|
|
671 |
# exp(-delta/T) defines probability of accepting a change
|
672 |
function iterate(member::PopMember, T::Float32, curmaxsize::Integer)::PopMember
|
673 |
prev = member.tree
|
674 |
-
tree =
|
675 |
#TODO - reconsider this
|
676 |
if batching
|
677 |
-
beforeLoss = scoreFuncBatch(
|
678 |
else
|
679 |
beforeLoss = member.score
|
680 |
end
|
681 |
|
682 |
mutationChoice = rand()
|
683 |
-
weightAdjustmentMutateConstant = min(8, countConstants(tree))/8.0
|
684 |
-
cur_weights = copy(mutationWeights) .* 1.0
|
685 |
#More constants => more likely to do constant mutation
|
|
|
|
|
686 |
cur_weights[1] *= weightAdjustmentMutateConstant
|
687 |
-
n = countNodes(
|
688 |
-
depth = countDepth(
|
689 |
|
690 |
# If equation too big, don't add new operators
|
691 |
if n >= curmaxsize || depth >= maxdepth
|
692 |
cur_weights[3] = 0.0
|
693 |
cur_weights[4] = 0.0
|
694 |
end
|
695 |
-
|
696 |
cur_weights /= sum(cur_weights)
|
697 |
cweights = cumsum(cur_weights)
|
698 |
|
699 |
-
|
700 |
-
|
701 |
-
|
702 |
-
|
703 |
-
|
704 |
-
|
705 |
-
|
706 |
-
|
707 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
708 |
end
|
709 |
-
elseif mutationChoice < cweights[4]
|
710 |
-
tree = insertRandomOp(tree)
|
711 |
-
elseif mutationChoice < cweights[5]
|
712 |
-
tree = deleteRandomOp(tree)
|
713 |
-
elseif mutationChoice < cweights[6]
|
714 |
-
tree = simplifyTree(tree) # Sometimes we simplify tree
|
715 |
-
tree = combineOperators(tree) # See if repeated constants at outer levels
|
716 |
-
return PopMember(tree, beforeLoss)
|
717 |
-
elseif mutationChoice < cweights[7]
|
718 |
-
tree = genRandomTree(5) # Sometimes we generate a new tree completely tree
|
719 |
-
else
|
720 |
-
return PopMember(tree, beforeLoss)
|
721 |
-
end
|
722 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
723 |
|
724 |
-
|
725 |
-
if limitPowComplexity && (deepPow(tree) > 0)
|
726 |
-
return PopMember(copyNode(prev), beforeLoss)
|
727 |
end
|
|
|
728 |
|
|
|
|
|
|
|
729 |
|
730 |
if batching
|
731 |
afterLoss = scoreFuncBatch(tree)
|
|
|
646 |
|
647 |
end
|
648 |
|
649 |
+
# Check whether any binary operator node is overly complex.
#
# Walks `tree` and returns true as soon as a degree-2 node whose `op` field
# equals `op` has a left or right argument whose `countNodes` size exceeds the
# matching entry of `bin_constraints[op]`. A limit that is not greater than -1
# disables the check for that side. Returns false when no such node is found.
function flagBinOperatorComplexity(tree::Node, op::Int)::Bool
    # Leaves carry no operator, so they can never violate a constraint.
    tree.degree == 0 && return false
    # A unary node is not checked here; only its single child is searched.
    tree.degree == 1 && return flagBinOperatorComplexity(tree.l, op)

    if tree.op == op
        left_too_big = (bin_constraints[op][1] > -1) &&
                       (countNodes(tree.l) > bin_constraints[op][1])
        # The right side is inspected only when the left side passed.
        if left_too_big ||
           ((bin_constraints[op][2] > -1) &&
            (countNodes(tree.r) > bin_constraints[op][2]))
            return true
        end
    end

    # This node is acceptable; keep searching both subtrees.
    return flagBinOperatorComplexity(tree.l, op) || flagBinOperatorComplexity(tree.r, op)
end
|
671 |
+
|
672 |
+
# Check whether any unary operator node is overly complex.
#
# Walks `tree` and returns true as soon as a degree-1 node whose `op` field
# equals `op` is applied to an argument whose `countNodes` size exceeds
# `una_constraints[op]`. A limit that is not greater than -1 disables the
# check. Returns false when no such node is found.
function flagUnaOperatorComplexity(tree::Node, op::Int)::Bool
    # Leaves carry no operator, so they can never violate a constraint.
    tree.degree == 0 && return false

    # Binary nodes are not checked here; only their children are searched.
    if tree.degree != 1
        return flagUnaOperatorComplexity(tree.l, op) || flagUnaOperatorComplexity(tree.r, op)
    end

    # The constraint table is consulted only for nodes that use `op`.
    if tree.op == op &&
       (una_constraints[op] > -1) &&
       (countNodes(tree.l) > una_constraints[op])
        return true
    end

    # This node is acceptable; keep searching its single child.
    return flagUnaOperatorComplexity(tree.l, op)
end
|
691 |
|
|
|
693 |
# exp(-delta/T) defines probability of accepting a change
|
694 |
function iterate(member::PopMember, T::Float32, curmaxsize::Integer)::PopMember
|
695 |
prev = member.tree
|
696 |
+
tree = prev
|
697 |
#TODO - reconsider this
|
698 |
if batching
|
699 |
+
beforeLoss = scoreFuncBatch(prev)
|
700 |
else
|
701 |
beforeLoss = member.score
|
702 |
end
|
703 |
|
704 |
mutationChoice = rand()
|
|
|
|
|
705 |
#More constants => more likely to do constant mutation
|
706 |
+
weightAdjustmentMutateConstant = min(8, countConstants(prev))/8.0
|
707 |
+
cur_weights = copy(mutationWeights) .* 1.0
|
708 |
cur_weights[1] *= weightAdjustmentMutateConstant
|
709 |
+
n = countNodes(prev)
|
710 |
+
depth = countDepth(prev)
|
711 |
|
712 |
# If equation too big, don't add new operators
|
713 |
if n >= curmaxsize || depth >= maxdepth
|
714 |
cur_weights[3] = 0.0
|
715 |
cur_weights[4] = 0.0
|
716 |
end
|
|
|
717 |
cur_weights /= sum(cur_weights)
|
718 |
cweights = cumsum(cur_weights)
|
719 |
|
720 |
+
successful_mutation = false
|
721 |
+
#TODO: Currently we dont take this \/ into account
|
722 |
+
is_success_always_possible = true
|
723 |
+
attempts = 0
|
724 |
+
max_attempts = 10
|
725 |
+
|
726 |
+
#############################################
|
727 |
+
# Mutations
|
728 |
+
#############################################
|
729 |
+
while (!successful_mutation) && attempts < max_attempts
|
730 |
+
tree = copyNode(prev)
|
731 |
+
successful_mutation = true
|
732 |
+
if mutationChoice < cweights[1]
|
733 |
+
tree = mutateConstant(tree, T)
|
734 |
+
|
735 |
+
is_success_always_possible = true
|
736 |
+
# Mutating a constant shouldn't invalidate an already-valid function
|
737 |
+
|
738 |
+
elseif mutationChoice < cweights[2]
|
739 |
+
tree = mutateOperator(tree)
|
740 |
+
|
741 |
+
is_success_always_possible = true
|
742 |
+
# Can always mutate to the same operator
|
743 |
+
|
744 |
+
elseif mutationChoice < cweights[3]
|
745 |
+
if rand() < 0.5
|
746 |
+
tree = appendRandomOp(tree)
|
747 |
+
else
|
748 |
+
tree = prependRandomOp(tree)
|
749 |
+
end
|
750 |
+
is_success_always_possible = false
|
751 |
+
# Can potentially have a situation without success
|
752 |
+
elseif mutationChoice < cweights[4]
|
753 |
+
tree = insertRandomOp(tree)
|
754 |
+
is_success_always_possible = false
|
755 |
+
elseif mutationChoice < cweights[5]
|
756 |
+
tree = deleteRandomOp(tree)
|
757 |
+
is_success_always_possible = true
|
758 |
+
elseif mutationChoice < cweights[6]
|
759 |
+
tree = simplifyTree(tree) # Sometimes we simplify tree
|
760 |
+
tree = combineOperators(tree) # See if repeated constants at outer levels
|
761 |
+
return PopMember(tree, beforeLoss)
|
762 |
+
|
763 |
+
is_success_always_possible = true
|
764 |
+
# Simplification shouldn't hurt complexity; unless some non-symmetric constraint
|
765 |
+
# to commutative operator...
|
766 |
+
|
767 |
+
elseif mutationChoice < cweights[7]
|
768 |
+
tree = genRandomTree(5) # Sometimes we generate a new tree completely tree
|
769 |
+
|
770 |
+
is_success_always_possible = true
|
771 |
+
else # no mutation applied
|
772 |
+
return PopMember(tree, beforeLoss)
|
773 |
end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
774 |
|
775 |
+
# Check for illegal equations
|
776 |
+
for i=1:nbin
|
777 |
+
if successful_mutation && flagBinOperatorComplexity(tree, i)
|
778 |
+
successful_mutation = false
|
779 |
+
end
|
780 |
+
end
|
781 |
+
for i=1:nuna
|
782 |
+
if successful_mutation && flagUnaOperatorComplexity(tree, i)
|
783 |
+
successful_mutation = false
|
784 |
+
end
|
785 |
+
end
|
786 |
|
787 |
+
attempts += 1
|
|
|
|
|
788 |
end
|
789 |
+
#############################################
|
790 |
|
791 |
+
if !successful_mutation
|
792 |
+
return PopMember(copyNode(prev), beforeLoss)
|
793 |
+
end
|
794 |
|
795 |
if batching
|
796 |
afterLoss = scoreFuncBatch(tree)
|
pysr/sr.py
CHANGED
@@ -89,7 +89,8 @@ def pysr(X=None, y=None, weights=None,
|
|
89 |
batchSize=50,
|
90 |
select_k_features=None,
|
91 |
warmupMaxsize=0,
|
92 |
-
|
|
|
93 |
threads=None, #deprecated
|
94 |
julia_optimization=3,
|
95 |
):
|
@@ -166,9 +167,11 @@ def pysr(X=None, y=None, weights=None,
|
|
166 |
a small number up to the maxsize (if greater than 0).
|
167 |
If greater than 0, says how many cycles before the maxsize
|
168 |
is increased.
|
169 |
-
:param
|
170 |
-
|
171 |
-
|
|
|
|
|
172 |
:param julia_optimization: int, Optimization level (0, 1, 2, 3)
|
173 |
:returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
|
174 |
(as strings).
|
@@ -176,6 +179,8 @@ def pysr(X=None, y=None, weights=None,
|
|
176 |
"""
|
177 |
if threads is not None:
|
178 |
raise ValueError("The threads kwarg is deprecated. Use procs.")
|
|
|
|
|
179 |
if maxdepth is None:
|
180 |
maxdepth = maxsize
|
181 |
|
@@ -207,6 +212,17 @@ def pysr(X=None, y=None, weights=None,
|
|
207 |
if populations is None:
|
208 |
populations = procs
|
209 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
rand_string = f'{"".join([str(np.random.rand())[2] for i in range(20)])}'
|
211 |
|
212 |
if isinstance(binary_operators, str): binary_operators = [binary_operators]
|
@@ -247,7 +263,30 @@ def pysr(X=None, y=None, weights=None,
|
|
247 |
function_name = op[:first_non_char]
|
248 |
op_list[i] = function_name
|
249 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
def_hyperparams += f"""include("{pkg_directory}/operators.jl")
|
|
|
251 |
const binops = {'[' + ', '.join(binary_operators) + ']'}
|
252 |
const unaops = {'[' + ', '.join(unary_operators) + ']'}
|
253 |
const ns=10;
|
|
|
89 |
batchSize=50,
|
90 |
select_k_features=None,
|
91 |
warmupMaxsize=0,
|
92 |
+
constraints={},
|
93 |
+
limitPowComplexity=False, #deprecated
|
94 |
threads=None, #deprecated
|
95 |
julia_optimization=3,
|
96 |
):
|
|
|
167 |
a small number up to the maxsize (if greater than 0).
|
168 |
If greater than 0, says how many cycles before the maxsize
|
169 |
is increased.
|
170 |
+
:param constraints: dict of int (unary) or 2-tuples (binary),
|
171 |
+
this enforces maxsize constraints on the individual
|
172 |
+
arguments of operators. E.g., `'pow': (-1, 1)`
|
173 |
+
says that power laws can have any complexity left argument, but only
|
174 |
+
1 complexity exponent. Use this to force more interpretable solutions.
|
175 |
:param julia_optimization: int, Optimization level (0, 1, 2, 3)
|
176 |
:returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
|
177 |
(as strings).
|
|
|
179 |
"""
|
180 |
if threads is not None:
|
181 |
raise ValueError("The threads kwarg is deprecated. Use procs.")
|
182 |
+
if limitPowComplexity:
|
183 |
+
raise ValueError("The limitPowComplexity kwarg is deprecated. Use constraints.")
|
184 |
if maxdepth is None:
|
185 |
maxdepth = maxsize
|
186 |
|
|
|
212 |
if populations is None:
|
213 |
populations = procs
|
214 |
|
215 |
+
#arbitrary complexity by default
|
216 |
+
for op in unary_operators:
|
217 |
+
if op not in constraints:
|
218 |
+
constraints[op] = -1
|
219 |
+
for op in binary_operators:
|
220 |
+
if op not in constraints:
|
221 |
+
constraints[op] = (-1, -1)
|
222 |
+
if op in ['mult', 'plus', 'sub']:
|
223 |
+
if constraints[op][0] != constraints[op][1]:
|
224 |
+
raise NotImplementedError("You need equal constraints on both sides for +, -, and *, due to simplification strategies.")
|
225 |
+
|
226 |
rand_string = f'{"".join([str(np.random.rand())[2] for i in range(20)])}'
|
227 |
|
228 |
if isinstance(binary_operators, str): binary_operators = [binary_operators]
|
|
|
263 |
function_name = op[:first_non_char]
|
264 |
op_list[i] = function_name
|
265 |
|
266 |
+
constraints_str = "const una_constraints = ["
|
267 |
+
first = True
|
268 |
+
for op in unary_operators:
|
269 |
+
val = constraints[op]
|
270 |
+
if not first:
|
271 |
+
constraints_str += ", "
|
272 |
+
constraints_str += f"{val:d}"
|
273 |
+
first = False
|
274 |
+
|
275 |
+
constraints_str += """]
|
276 |
+
const bin_constraints = ["""
|
277 |
+
|
278 |
+
first = True
|
279 |
+
for op in binary_operators:
|
280 |
+
tup = constraints[op]
|
281 |
+
if not first:
|
282 |
+
constraints_str += ", "
|
283 |
+
constraints_str += f"({tup[0]:d}, {tup[1]:d})"
|
284 |
+
first = False
|
285 |
+
constraints_str += "]"
|
286 |
+
|
287 |
+
|
288 |
def_hyperparams += f"""include("{pkg_directory}/operators.jl")
|
289 |
+
{constraints_str}
|
290 |
const binops = {'[' + ', '.join(binary_operators) + ']'}
|
291 |
const unaops = {'[' + ', '.join(unary_operators) + ']'}
|
292 |
const ns=10;
|