Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
6fa78c9
1
Parent(s):
1d23dc2
Use egal instead of equal for many ops
Browse files- julia/sr.jl +59 -59
- pysr/sr.py +18 -10
julia/sr.jl
CHANGED
@@ -96,9 +96,9 @@ end
|
|
96 |
|
97 |
# Copy an equation (faster than deepcopy)
|
98 |
function copyNode(tree::Node)::Node
|
99 |
-
if tree.degree
|
100 |
return Node(tree.val)
|
101 |
-
elseif tree.degree
|
102 |
return Node(tree.op, copyNode(tree.l))
|
103 |
else
|
104 |
return Node(tree.op, copyNode(tree.l), copyNode(tree.r))
|
@@ -107,9 +107,9 @@ end
|
|
107 |
|
108 |
# Count the operators, constants, variables in an equation
|
109 |
function countNodes(tree::Node)::Integer
|
110 |
-
if tree.degree
|
111 |
return 1
|
112 |
-
elseif tree.degree
|
113 |
return 1 + countNodes(tree.l)
|
114 |
else
|
115 |
return 1 + countNodes(tree.l) + countNodes(tree.r)
|
@@ -118,9 +118,9 @@ end
|
|
118 |
|
119 |
# Count the max depth of a tree
|
120 |
function countDepth(tree::Node)::Integer
|
121 |
-
if tree.degree
|
122 |
return 1
|
123 |
-
elseif tree.degree
|
124 |
return 1 + countDepth(tree.l)
|
125 |
else
|
126 |
return 1 + max(countDepth(tree.l), countDepth(tree.r))
|
@@ -129,7 +129,7 @@ end
|
|
129 |
|
130 |
# Convert an equation to a string
|
131 |
function stringTree(tree::Node)::String
|
132 |
-
if tree.degree
|
133 |
if tree.constant
|
134 |
return string(tree.val)
|
135 |
else
|
@@ -139,7 +139,7 @@ function stringTree(tree::Node)::String
|
|
139 |
return "x$(tree.val - 1)"
|
140 |
end
|
141 |
end
|
142 |
-
elseif tree.degree
|
143 |
return "$(unaops[tree.op])($(stringTree(tree.l)))"
|
144 |
else
|
145 |
return "$(binops[tree.op])($(stringTree(tree.l)), $(stringTree(tree.r)))"
|
@@ -153,7 +153,7 @@ end
|
|
153 |
|
154 |
# Return a random node from the tree
|
155 |
function randomNode(tree::Node)::Node
|
156 |
-
if tree.degree
|
157 |
return tree
|
158 |
end
|
159 |
a = countNodes(tree)
|
@@ -162,14 +162,14 @@ function randomNode(tree::Node)::Node
|
|
162 |
if tree.degree >= 1
|
163 |
b = countNodes(tree.l)
|
164 |
end
|
165 |
-
if tree.degree
|
166 |
c = countNodes(tree.r)
|
167 |
end
|
168 |
|
169 |
i = rand(1:1+b+c)
|
170 |
if i <= b
|
171 |
return randomNode(tree.l)
|
172 |
-
elseif i
|
173 |
return tree
|
174 |
end
|
175 |
|
@@ -178,9 +178,9 @@ end
|
|
178 |
|
179 |
# Count the number of unary operators in the equation
|
180 |
function countUnaryOperators(tree::Node)::Integer
|
181 |
-
if tree.degree
|
182 |
return 0
|
183 |
-
elseif tree.degree
|
184 |
return 1 + countUnaryOperators(tree.l)
|
185 |
else
|
186 |
return 0 + countUnaryOperators(tree.l) + countUnaryOperators(tree.r)
|
@@ -189,9 +189,9 @@ end
|
|
189 |
|
190 |
# Count the number of binary operators in the equation
|
191 |
function countBinaryOperators(tree::Node)::Integer
|
192 |
-
if tree.degree
|
193 |
return 0
|
194 |
-
elseif tree.degree
|
195 |
return 0 + countBinaryOperators(tree.l)
|
196 |
else
|
197 |
return 1 + countBinaryOperators(tree.l) + countBinaryOperators(tree.r)
|
@@ -206,14 +206,14 @@ end
|
|
206 |
# Randomly convert an operator into another one (binary->binary;
|
207 |
# unary->unary)
|
208 |
function mutateOperator(tree::Node)::Node
|
209 |
-
if countOperators(tree)
|
210 |
return tree
|
211 |
end
|
212 |
node = randomNode(tree)
|
213 |
-
while node.degree
|
214 |
node = randomNode(tree)
|
215 |
end
|
216 |
-
if node.degree
|
217 |
node.op = rand(1:length(unaops))
|
218 |
else
|
219 |
node.op = rand(1:length(binops))
|
@@ -223,9 +223,9 @@ end
|
|
223 |
|
224 |
# Count the number of constants in an equation
|
225 |
function countConstants(tree::Node)::Integer
|
226 |
-
if tree.degree
|
227 |
return convert(Integer, tree.constant)
|
228 |
-
elseif tree.degree
|
229 |
return 0 + countConstants(tree.l)
|
230 |
else
|
231 |
return 0 + countConstants(tree.l) + countConstants(tree.r)
|
@@ -238,11 +238,11 @@ function mutateConstant(
|
|
238 |
probNegate::Float32=0.01f0)::Node
|
239 |
# T is between 0 and 1.
|
240 |
|
241 |
-
if countConstants(tree)
|
242 |
return tree
|
243 |
end
|
244 |
node = randomNode(tree)
|
245 |
-
while node.degree
|
246 |
node = randomNode(tree)
|
247 |
end
|
248 |
|
@@ -273,19 +273,19 @@ end
|
|
273 |
# Evaluate an equation over an array of datapoints
|
274 |
function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32, 1}, Nothing}
|
275 |
clen = size(cX)[1]
|
276 |
-
if tree.degree
|
277 |
if tree.constant
|
278 |
return fill(tree.val, clen)
|
279 |
else
|
280 |
return copy(cX[:, tree.val])
|
281 |
end
|
282 |
-
elseif tree.degree
|
283 |
cumulator = evalTreeArray(tree.l, cX)
|
284 |
if cumulator === nothing
|
285 |
return nothing
|
286 |
end
|
287 |
op_idx = tree.op
|
288 |
-
UNAOP!(op_idx,
|
289 |
@inbounds for i=1:clen
|
290 |
if isinf(cumulator[i]) || isnan(cumulator[i])
|
291 |
return nothing
|
@@ -302,7 +302,7 @@ function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32,
|
|
302 |
return nothing
|
303 |
end
|
304 |
op_idx = tree.op
|
305 |
-
BINOP!(
|
306 |
@inbounds for i=1:clen
|
307 |
if isinf(cumulator[i]) || isnan(cumulator[i])
|
308 |
return nothing
|
@@ -350,7 +350,7 @@ end
|
|
350 |
# Add a random unary/binary operation to the end of a tree
|
351 |
function appendRandomOp(tree::Node)::Node
|
352 |
node = randomNode(tree)
|
353 |
-
while node.degree
|
354 |
node = randomNode(tree)
|
355 |
end
|
356 |
|
@@ -458,7 +458,7 @@ end
|
|
458 |
|
459 |
# Return a random node from the tree with parent
|
460 |
function randomNodeAndParent(tree::Node, parent::Union{Node, Nothing})::Tuple{Node, Union{Node, Nothing}}
|
461 |
-
if tree.degree
|
462 |
return tree, parent
|
463 |
end
|
464 |
a = countNodes(tree)
|
@@ -467,14 +467,14 @@ function randomNodeAndParent(tree::Node, parent::Union{Node, Nothing})::Tuple{No
|
|
467 |
if tree.degree >= 1
|
468 |
b = countNodes(tree.l)
|
469 |
end
|
470 |
-
if tree.degree
|
471 |
c = countNodes(tree.r)
|
472 |
end
|
473 |
|
474 |
i = rand(1:1+b+c)
|
475 |
if i <= b
|
476 |
return randomNodeAndParent(tree.l, tree)
|
477 |
-
elseif i
|
478 |
return tree, parent
|
479 |
end
|
480 |
|
@@ -487,7 +487,7 @@ function deleteRandomOp(tree::Node)::Node
|
|
487 |
node, parent = randomNodeAndParent(tree, nothing)
|
488 |
isroot = (parent === nothing)
|
489 |
|
490 |
-
if node.degree
|
491 |
# Replace with new constant
|
492 |
newnode = randomConstantNode()
|
493 |
node.l = newnode.l
|
@@ -496,7 +496,7 @@ function deleteRandomOp(tree::Node)::Node
|
|
496 |
node.degree = newnode.degree
|
497 |
node.val = newnode.val
|
498 |
node.constant = newnode.constant
|
499 |
-
elseif node.degree
|
500 |
# Join one of the children with the parent
|
501 |
if isroot
|
502 |
return node.l
|
@@ -536,17 +536,17 @@ function combineOperators(tree::Node)::Node
|
|
536 |
# ((const - var) - const) => (const - var)
|
537 |
# (want to add anything commutative!)
|
538 |
# TODO - need to combine plus/sub if they are both there.
|
539 |
-
if tree.degree
|
540 |
return tree
|
541 |
-
elseif tree.degree
|
542 |
tree.l = combineOperators(tree.l)
|
543 |
-
elseif tree.degree
|
544 |
tree.l = combineOperators(tree.l)
|
545 |
tree.r = combineOperators(tree.r)
|
546 |
end
|
547 |
|
548 |
-
top_level_constant = tree.degree
|
549 |
-
if tree.degree
|
550 |
op = tree.op
|
551 |
# Put the constant in r
|
552 |
if tree.l.constant
|
@@ -557,7 +557,7 @@ function combineOperators(tree::Node)::Node
|
|
557 |
topconstant = tree.r.val
|
558 |
# Simplify down first
|
559 |
below = tree.l
|
560 |
-
if below.degree
|
561 |
if below.l.constant
|
562 |
tree = below
|
563 |
tree.l.val = binops[op](tree.l.val, topconstant)
|
@@ -568,11 +568,11 @@ function combineOperators(tree::Node)::Node
|
|
568 |
end
|
569 |
end
|
570 |
|
571 |
-
if tree.degree
|
572 |
# Currently just simplifies subtraction. (can't assume both plus and sub are operators)
|
573 |
# Not commutative, so use different op.
|
574 |
if tree.l.constant
|
575 |
-
if tree.r.degree
|
576 |
if tree.r.l.constant
|
577 |
#(const - (const - var)) => (var - const)
|
578 |
l = tree.l
|
@@ -591,7 +591,7 @@ function combineOperators(tree::Node)::Node
|
|
591 |
end
|
592 |
end
|
593 |
else #tree.r.constant is true
|
594 |
-
if tree.l.degree
|
595 |
if tree.l.l.constant
|
596 |
#((const - var) - const) => (const - var)
|
597 |
l = tree.l
|
@@ -616,17 +616,17 @@ end
|
|
616 |
|
617 |
# Simplify tree
|
618 |
function simplifyTree(tree::Node)::Node
|
619 |
-
if tree.degree
|
620 |
tree.l = simplifyTree(tree.l)
|
621 |
-
if tree.l.degree
|
622 |
return Node(unaops[tree.op](tree.l.val))
|
623 |
end
|
624 |
-
elseif tree.degree
|
625 |
tree.l = simplifyTree(tree.l)
|
626 |
tree.r = simplifyTree(tree.r)
|
627 |
constantsBelow = (
|
628 |
-
tree.l.degree
|
629 |
-
tree.r.degree
|
630 |
)
|
631 |
if constantsBelow
|
632 |
return Node(binops[tree.op](tree.l.val, tree.r.val))
|
@@ -648,9 +648,9 @@ end
|
|
648 |
|
649 |
# Check if any power operator is to the power of a complex expression
|
650 |
function deepPow(tree::Node)::Integer
|
651 |
-
if tree.degree
|
652 |
return 0
|
653 |
-
elseif tree.degree
|
654 |
return 0 + deepPow(tree.l)
|
655 |
else
|
656 |
if binops[tree.op] === pow
|
@@ -857,7 +857,7 @@ function run(
|
|
857 |
pop = regEvolCycle(pop, 1.0f0, curmaxsize)
|
858 |
end
|
859 |
|
860 |
-
if verbosity > 0 && (iT % verbosity
|
861 |
bestPops = bestSubPop(pop)
|
862 |
bestCurScoreIdx = argmin([bestPops.members[member].score for member=1:bestPops.n])
|
863 |
bestCurScore = bestPops.members[bestCurScoreIdx].score
|
@@ -870,13 +870,13 @@ end
|
|
870 |
|
871 |
# Get all the constants from a tree
|
872 |
function getConstants(tree::Node)::Array{Float32, 1}
|
873 |
-
if tree.degree
|
874 |
if tree.constant
|
875 |
return [tree.val]
|
876 |
else
|
877 |
return Float32[]
|
878 |
end
|
879 |
-
elseif tree.degree
|
880 |
return getConstants(tree.l)
|
881 |
else
|
882 |
both = [getConstants(tree.l), getConstants(tree.r)]
|
@@ -886,11 +886,11 @@ end
|
|
886 |
|
887 |
# Set all the constants inside a tree
|
888 |
function setConstants(tree::Node, constants::Array{Float32, 1})
|
889 |
-
if tree.degree
|
890 |
if tree.constant
|
891 |
tree.val = constants[1]
|
892 |
end
|
893 |
-
elseif tree.degree
|
894 |
setConstants(tree.l, constants)
|
895 |
else
|
896 |
numberLeft = countConstants(tree.l)
|
@@ -909,12 +909,12 @@ end
|
|
909 |
# Use Nelder-Mead to optimize the constants in an equation
|
910 |
function optimizeConstants(member::PopMember)::PopMember
|
911 |
nconst = countConstants(member.tree)
|
912 |
-
if nconst
|
913 |
return member
|
914 |
end
|
915 |
x0 = getConstants(member.tree)
|
916 |
f(x::Array{Float32,1})::Float32 = optFunc(x, member.tree)
|
917 |
-
if size(x0)[1]
|
918 |
algorithm = Optim.Newton
|
919 |
else
|
920 |
algorithm = Optim.NelderMead
|
@@ -998,7 +998,7 @@ function fullRun(niterations::Integer;
|
|
998 |
bestSubPops = [Population(1) for j=1:npopulations]
|
999 |
hallOfFame = HallOfFame()
|
1000 |
curmaxsize = 3
|
1001 |
-
if warmupMaxsize
|
1002 |
curmaxsize = maxsize
|
1003 |
end
|
1004 |
|
@@ -1067,7 +1067,7 @@ function fullRun(niterations::Integer;
|
|
1067 |
numberSmallerAndBetter += 1
|
1068 |
end
|
1069 |
end
|
1070 |
-
betterThanAllSmaller = (numberSmallerAndBetter
|
1071 |
if betterThanAllSmaller
|
1072 |
println(io, "$size|$(curMSE)|$(stringTree(member.tree))")
|
1073 |
push!(dominating, member)
|
@@ -1117,7 +1117,7 @@ function fullRun(niterations::Integer;
|
|
1117 |
|
1118 |
cycles_complete -= 1
|
1119 |
cycles_elapsed = npopulations * niterations - cycles_complete
|
1120 |
-
if warmupMaxsize
|
1121 |
curmaxsize += 1
|
1122 |
if curmaxsize > maxsize
|
1123 |
curmaxsize = maxsize
|
@@ -1167,7 +1167,7 @@ function fullRun(niterations::Integer;
|
|
1167 |
numberSmallerAndBetter += 1
|
1168 |
end
|
1169 |
end
|
1170 |
-
betterThanAllSmaller = (numberSmallerAndBetter
|
1171 |
if betterThanAllSmaller
|
1172 |
delta_c = size - lastComplexity
|
1173 |
delta_l_mse = log(curMSE/lastMSE)
|
|
|
96 |
|
97 |
# Copy an equation (faster than deepcopy)
|
98 |
function copyNode(tree::Node)::Node
|
99 |
+
if tree.degree === 0
|
100 |
return Node(tree.val)
|
101 |
+
elseif tree.degree === 1
|
102 |
return Node(tree.op, copyNode(tree.l))
|
103 |
else
|
104 |
return Node(tree.op, copyNode(tree.l), copyNode(tree.r))
|
|
|
107 |
|
108 |
# Count the operators, constants, variables in an equation
|
109 |
function countNodes(tree::Node)::Integer
|
110 |
+
if tree.degree === 0
|
111 |
return 1
|
112 |
+
elseif tree.degree === 1
|
113 |
return 1 + countNodes(tree.l)
|
114 |
else
|
115 |
return 1 + countNodes(tree.l) + countNodes(tree.r)
|
|
|
118 |
|
119 |
# Count the max depth of a tree
|
120 |
function countDepth(tree::Node)::Integer
|
121 |
+
if tree.degree === 0
|
122 |
return 1
|
123 |
+
elseif tree.degree === 1
|
124 |
return 1 + countDepth(tree.l)
|
125 |
else
|
126 |
return 1 + max(countDepth(tree.l), countDepth(tree.r))
|
|
|
129 |
|
130 |
# Convert an equation to a string
|
131 |
function stringTree(tree::Node)::String
|
132 |
+
if tree.degree === 0
|
133 |
if tree.constant
|
134 |
return string(tree.val)
|
135 |
else
|
|
|
139 |
return "x$(tree.val - 1)"
|
140 |
end
|
141 |
end
|
142 |
+
elseif tree.degree === 1
|
143 |
return "$(unaops[tree.op])($(stringTree(tree.l)))"
|
144 |
else
|
145 |
return "$(binops[tree.op])($(stringTree(tree.l)), $(stringTree(tree.r)))"
|
|
|
153 |
|
154 |
# Return a random node from the tree
|
155 |
function randomNode(tree::Node)::Node
|
156 |
+
if tree.degree === 0
|
157 |
return tree
|
158 |
end
|
159 |
a = countNodes(tree)
|
|
|
162 |
if tree.degree >= 1
|
163 |
b = countNodes(tree.l)
|
164 |
end
|
165 |
+
if tree.degree === 2
|
166 |
c = countNodes(tree.r)
|
167 |
end
|
168 |
|
169 |
i = rand(1:1+b+c)
|
170 |
if i <= b
|
171 |
return randomNode(tree.l)
|
172 |
+
elseif i === b + 1
|
173 |
return tree
|
174 |
end
|
175 |
|
|
|
178 |
|
179 |
# Count the number of unary operators in the equation
|
180 |
function countUnaryOperators(tree::Node)::Integer
|
181 |
+
if tree.degree === 0
|
182 |
return 0
|
183 |
+
elseif tree.degree === 1
|
184 |
return 1 + countUnaryOperators(tree.l)
|
185 |
else
|
186 |
return 0 + countUnaryOperators(tree.l) + countUnaryOperators(tree.r)
|
|
|
189 |
|
190 |
# Count the number of binary operators in the equation
|
191 |
function countBinaryOperators(tree::Node)::Integer
|
192 |
+
if tree.degree === 0
|
193 |
return 0
|
194 |
+
elseif tree.degree === 1
|
195 |
return 0 + countBinaryOperators(tree.l)
|
196 |
else
|
197 |
return 1 + countBinaryOperators(tree.l) + countBinaryOperators(tree.r)
|
|
|
206 |
# Randomly convert an operator into another one (binary->binary;
|
207 |
# unary->unary)
|
208 |
function mutateOperator(tree::Node)::Node
|
209 |
+
if countOperators(tree) === 0
|
210 |
return tree
|
211 |
end
|
212 |
node = randomNode(tree)
|
213 |
+
while node.degree === 0
|
214 |
node = randomNode(tree)
|
215 |
end
|
216 |
+
if node.degree === 1
|
217 |
node.op = rand(1:length(unaops))
|
218 |
else
|
219 |
node.op = rand(1:length(binops))
|
|
|
223 |
|
224 |
# Count the number of constants in an equation
|
225 |
function countConstants(tree::Node)::Integer
|
226 |
+
if tree.degree === 0
|
227 |
return convert(Integer, tree.constant)
|
228 |
+
elseif tree.degree === 1
|
229 |
return 0 + countConstants(tree.l)
|
230 |
else
|
231 |
return 0 + countConstants(tree.l) + countConstants(tree.r)
|
|
|
238 |
probNegate::Float32=0.01f0)::Node
|
239 |
# T is between 0 and 1.
|
240 |
|
241 |
+
if countConstants(tree) === 0
|
242 |
return tree
|
243 |
end
|
244 |
node = randomNode(tree)
|
245 |
+
while node.degree !== 0 || node.constant === false
|
246 |
node = randomNode(tree)
|
247 |
end
|
248 |
|
|
|
273 |
# Evaluate an equation over an array of datapoints
|
274 |
function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32, 1}, Nothing}
|
275 |
clen = size(cX)[1]
|
276 |
+
if tree.degree === 0
|
277 |
if tree.constant
|
278 |
return fill(tree.val, clen)
|
279 |
else
|
280 |
return copy(cX[:, tree.val])
|
281 |
end
|
282 |
+
elseif tree.degree === 1
|
283 |
cumulator = evalTreeArray(tree.l, cX)
|
284 |
if cumulator === nothing
|
285 |
return nothing
|
286 |
end
|
287 |
op_idx = tree.op
|
288 |
+
UNAOP!(cumulator, op_idx, clen)
|
289 |
@inbounds for i=1:clen
|
290 |
if isinf(cumulator[i]) || isnan(cumulator[i])
|
291 |
return nothing
|
|
|
302 |
return nothing
|
303 |
end
|
304 |
op_idx = tree.op
|
305 |
+
BINOP!(cumulator, array2, op_idx, clen)
|
306 |
@inbounds for i=1:clen
|
307 |
if isinf(cumulator[i]) || isnan(cumulator[i])
|
308 |
return nothing
|
|
|
350 |
# Add a random unary/binary operation to the end of a tree
|
351 |
function appendRandomOp(tree::Node)::Node
|
352 |
node = randomNode(tree)
|
353 |
+
while node.degree !== 0
|
354 |
node = randomNode(tree)
|
355 |
end
|
356 |
|
|
|
458 |
|
459 |
# Return a random node from the tree with parent
|
460 |
function randomNodeAndParent(tree::Node, parent::Union{Node, Nothing})::Tuple{Node, Union{Node, Nothing}}
|
461 |
+
if tree.degree === 0
|
462 |
return tree, parent
|
463 |
end
|
464 |
a = countNodes(tree)
|
|
|
467 |
if tree.degree >= 1
|
468 |
b = countNodes(tree.l)
|
469 |
end
|
470 |
+
if tree.degree === 2
|
471 |
c = countNodes(tree.r)
|
472 |
end
|
473 |
|
474 |
i = rand(1:1+b+c)
|
475 |
if i <= b
|
476 |
return randomNodeAndParent(tree.l, tree)
|
477 |
+
elseif i === b + 1
|
478 |
return tree, parent
|
479 |
end
|
480 |
|
|
|
487 |
node, parent = randomNodeAndParent(tree, nothing)
|
488 |
isroot = (parent === nothing)
|
489 |
|
490 |
+
if node.degree === 0
|
491 |
# Replace with new constant
|
492 |
newnode = randomConstantNode()
|
493 |
node.l = newnode.l
|
|
|
496 |
node.degree = newnode.degree
|
497 |
node.val = newnode.val
|
498 |
node.constant = newnode.constant
|
499 |
+
elseif node.degree === 1
|
500 |
# Join one of the children with the parent
|
501 |
if isroot
|
502 |
return node.l
|
|
|
536 |
# ((const - var) - const) => (const - var)
|
537 |
# (want to add anything commutative!)
|
538 |
# TODO - need to combine plus/sub if they are both there.
|
539 |
+
if tree.degree === 0
|
540 |
return tree
|
541 |
+
elseif tree.degree === 1
|
542 |
tree.l = combineOperators(tree.l)
|
543 |
+
elseif tree.degree === 2
|
544 |
tree.l = combineOperators(tree.l)
|
545 |
tree.r = combineOperators(tree.r)
|
546 |
end
|
547 |
|
548 |
+
top_level_constant = tree.degree === 2 && (tree.l.constant || tree.r.constant)
|
549 |
+
if tree.degree === 2 && (binops[tree.op] === mult || binops[tree.op] === plus) && top_level_constant
|
550 |
op = tree.op
|
551 |
# Put the constant in r
|
552 |
if tree.l.constant
|
|
|
557 |
topconstant = tree.r.val
|
558 |
# Simplify down first
|
559 |
below = tree.l
|
560 |
+
if below.degree === 2 && below.op === op
|
561 |
if below.l.constant
|
562 |
tree = below
|
563 |
tree.l.val = binops[op](tree.l.val, topconstant)
|
|
|
568 |
end
|
569 |
end
|
570 |
|
571 |
+
if tree.degree === 2 && binops[tree.op] === sub && top_level_constant
|
572 |
# Currently just simplifies subtraction. (can't assume both plus and sub are operators)
|
573 |
# Not commutative, so use different op.
|
574 |
if tree.l.constant
|
575 |
+
if tree.r.degree === 2 && binops[tree.r.op] === sub
|
576 |
if tree.r.l.constant
|
577 |
#(const - (const - var)) => (var - const)
|
578 |
l = tree.l
|
|
|
591 |
end
|
592 |
end
|
593 |
else #tree.r.constant is true
|
594 |
+
if tree.l.degree === 2 && binops[tree.l.op] === sub
|
595 |
if tree.l.l.constant
|
596 |
#((const - var) - const) => (const - var)
|
597 |
l = tree.l
|
|
|
616 |
|
617 |
# Simplify tree
|
618 |
function simplifyTree(tree::Node)::Node
|
619 |
+
if tree.degree === 1
|
620 |
tree.l = simplifyTree(tree.l)
|
621 |
+
if tree.l.degree === 0 && tree.l.constant
|
622 |
return Node(unaops[tree.op](tree.l.val))
|
623 |
end
|
624 |
+
elseif tree.degree === 2
|
625 |
tree.l = simplifyTree(tree.l)
|
626 |
tree.r = simplifyTree(tree.r)
|
627 |
constantsBelow = (
|
628 |
+
tree.l.degree === 0 && tree.l.constant &&
|
629 |
+
tree.r.degree === 0 && tree.r.constant
|
630 |
)
|
631 |
if constantsBelow
|
632 |
return Node(binops[tree.op](tree.l.val, tree.r.val))
|
|
|
648 |
|
649 |
# Check if any power operator is to the power of a complex expression
|
650 |
function deepPow(tree::Node)::Integer
|
651 |
+
if tree.degree === 0
|
652 |
return 0
|
653 |
+
elseif tree.degree === 1
|
654 |
return 0 + deepPow(tree.l)
|
655 |
else
|
656 |
if binops[tree.op] === pow
|
|
|
857 |
pop = regEvolCycle(pop, 1.0f0, curmaxsize)
|
858 |
end
|
859 |
|
860 |
+
if verbosity > 0 && (iT % verbosity === 0)
|
861 |
bestPops = bestSubPop(pop)
|
862 |
bestCurScoreIdx = argmin([bestPops.members[member].score for member=1:bestPops.n])
|
863 |
bestCurScore = bestPops.members[bestCurScoreIdx].score
|
|
|
870 |
|
871 |
# Get all the constants from a tree
|
872 |
function getConstants(tree::Node)::Array{Float32, 1}
|
873 |
+
if tree.degree === 0
|
874 |
if tree.constant
|
875 |
return [tree.val]
|
876 |
else
|
877 |
return Float32[]
|
878 |
end
|
879 |
+
elseif tree.degree === 1
|
880 |
return getConstants(tree.l)
|
881 |
else
|
882 |
both = [getConstants(tree.l), getConstants(tree.r)]
|
|
|
886 |
|
887 |
# Set all the constants inside a tree
|
888 |
function setConstants(tree::Node, constants::Array{Float32, 1})
|
889 |
+
if tree.degree === 0
|
890 |
if tree.constant
|
891 |
tree.val = constants[1]
|
892 |
end
|
893 |
+
elseif tree.degree === 1
|
894 |
setConstants(tree.l, constants)
|
895 |
else
|
896 |
numberLeft = countConstants(tree.l)
|
|
|
909 |
# Use Nelder-Mead to optimize the constants in an equation
|
910 |
function optimizeConstants(member::PopMember)::PopMember
|
911 |
nconst = countConstants(member.tree)
|
912 |
+
if nconst === 0
|
913 |
return member
|
914 |
end
|
915 |
x0 = getConstants(member.tree)
|
916 |
f(x::Array{Float32,1})::Float32 = optFunc(x, member.tree)
|
917 |
+
if size(x0)[1] === 1
|
918 |
algorithm = Optim.Newton
|
919 |
else
|
920 |
algorithm = Optim.NelderMead
|
|
|
998 |
bestSubPops = [Population(1) for j=1:npopulations]
|
999 |
hallOfFame = HallOfFame()
|
1000 |
curmaxsize = 3
|
1001 |
+
if warmupMaxsize === 0
|
1002 |
curmaxsize = maxsize
|
1003 |
end
|
1004 |
|
|
|
1067 |
numberSmallerAndBetter += 1
|
1068 |
end
|
1069 |
end
|
1070 |
+
betterThanAllSmaller = (numberSmallerAndBetter === 0)
|
1071 |
if betterThanAllSmaller
|
1072 |
println(io, "$size|$(curMSE)|$(stringTree(member.tree))")
|
1073 |
push!(dominating, member)
|
|
|
1117 |
|
1118 |
cycles_complete -= 1
|
1119 |
cycles_elapsed = npopulations * niterations - cycles_complete
|
1120 |
+
if warmupMaxsize !== 0 && cycles_elapsed % warmupMaxsize === 0
|
1121 |
curmaxsize += 1
|
1122 |
if curmaxsize > maxsize
|
1123 |
curmaxsize = maxsize
|
|
|
1167 |
numberSmallerAndBetter += 1
|
1168 |
end
|
1169 |
end
|
1170 |
+
betterThanAllSmaller = (numberSmallerAndBetter === 0)
|
1171 |
if betterThanAllSmaller
|
1172 |
delta_c = size - lastComplexity
|
1173 |
delta_l_mse = log(curMSE/lastMSE)
|
pysr/sr.py
CHANGED
@@ -287,26 +287,34 @@ const limitPowComplexity = {"true" if limitPowComplexity else "false"}
|
|
287 |
op_runner = ""
|
288 |
if len(binary_operators) > 0:
|
289 |
op_runner += """
|
290 |
-
function BINOP!(
|
291 |
-
if i
|
292 |
-
|
|
|
|
|
293 |
for i in range(1, len(binary_operators)):
|
294 |
op_runner += f"""
|
295 |
-
elseif i
|
296 |
-
|
|
|
|
|
297 |
op_runner += """
|
298 |
end
|
299 |
end"""
|
300 |
|
301 |
if len(unary_operators) > 0:
|
302 |
op_runner += """
|
303 |
-
function UNAOP!(
|
304 |
-
if i
|
305 |
-
|
|
|
|
|
306 |
for i in range(1, len(unary_operators)):
|
307 |
op_runner += """
|
308 |
-
elseif i
|
309 |
-
|
|
|
|
|
310 |
op_runner += """
|
311 |
end
|
312 |
end"""
|
|
|
287 |
op_runner = ""
|
288 |
if len(binary_operators) > 0:
|
289 |
op_runner += """
|
290 |
+
@inline function BINOP!(x::Array{Float32, 1}, y::Array{Float32, 1}, i::Int, clen::Int)
|
291 |
+
if i === 1
|
292 |
+
@inbounds @simd for j=1:clen
|
293 |
+
x[j] = """f"{binary_operators[0]}""""(x[j], y[j])
|
294 |
+
end"""
|
295 |
for i in range(1, len(binary_operators)):
|
296 |
op_runner += f"""
|
297 |
+
elseif i === {i+1}
|
298 |
+
@inbounds @simd for j=1:clen
|
299 |
+
x[j] = {binary_operators[i]}(x[j], y[j])
|
300 |
+
end"""
|
301 |
op_runner += """
|
302 |
end
|
303 |
end"""
|
304 |
|
305 |
if len(unary_operators) > 0:
|
306 |
op_runner += """
|
307 |
+
@inline function UNAOP!(x::Array{Float32, 1}, i::Int, clen::Int)
|
308 |
+
if i === 1
|
309 |
+
@inbounds @simd for j=1:clen
|
310 |
+
x[j] = """f"{unary_operators[0]}(x[j])""""
|
311 |
+
end"""
|
312 |
for i in range(1, len(unary_operators)):
|
313 |
op_runner += """
|
314 |
+
elseif i === {i+1}
|
315 |
+
@inbounds @simd for j=1:clen
|
316 |
+
x[j] = {unary_operators[i]}(x[j])
|
317 |
+
end"""
|
318 |
op_runner += """
|
319 |
end
|
320 |
end"""
|