Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
1f4e612
1
Parent(s):
eccca5d
Reset to last working copy
Browse files- julia/sr.jl +65 -59
- pysr/sr.py +12 -20
julia/sr.jl
CHANGED
@@ -96,9 +96,9 @@ end
|
|
96 |
|
97 |
# Copy an equation (faster than deepcopy)
|
98 |
function copyNode(tree::Node)::Node
|
99 |
-
if tree.degree
|
100 |
return Node(tree.val)
|
101 |
-
elseif tree.degree
|
102 |
return Node(tree.op, copyNode(tree.l))
|
103 |
else
|
104 |
return Node(tree.op, copyNode(tree.l), copyNode(tree.r))
|
@@ -107,9 +107,9 @@ end
|
|
107 |
|
108 |
# Count the operators, constants, variables in an equation
|
109 |
function countNodes(tree::Node)::Integer
|
110 |
-
if tree.degree
|
111 |
return 1
|
112 |
-
elseif tree.degree
|
113 |
return 1 + countNodes(tree.l)
|
114 |
else
|
115 |
return 1 + countNodes(tree.l) + countNodes(tree.r)
|
@@ -118,9 +118,9 @@ end
|
|
118 |
|
119 |
# Count the max depth of a tree
|
120 |
function countDepth(tree::Node)::Integer
|
121 |
-
if tree.degree
|
122 |
return 1
|
123 |
-
elseif tree.degree
|
124 |
return 1 + countDepth(tree.l)
|
125 |
else
|
126 |
return 1 + max(countDepth(tree.l), countDepth(tree.r))
|
@@ -129,7 +129,7 @@ end
|
|
129 |
|
130 |
# Convert an equation to a string
|
131 |
function stringTree(tree::Node)::String
|
132 |
-
if tree.degree
|
133 |
if tree.constant
|
134 |
return string(tree.val)
|
135 |
else
|
@@ -139,7 +139,7 @@ function stringTree(tree::Node)::String
|
|
139 |
return "x$(tree.val - 1)"
|
140 |
end
|
141 |
end
|
142 |
-
elseif tree.degree
|
143 |
return "$(unaops[tree.op])($(stringTree(tree.l)))"
|
144 |
else
|
145 |
return "$(binops[tree.op])($(stringTree(tree.l)), $(stringTree(tree.r)))"
|
@@ -153,7 +153,7 @@ end
|
|
153 |
|
154 |
# Return a random node from the tree
|
155 |
function randomNode(tree::Node)::Node
|
156 |
-
if tree.degree
|
157 |
return tree
|
158 |
end
|
159 |
a = countNodes(tree)
|
@@ -162,14 +162,14 @@ function randomNode(tree::Node)::Node
|
|
162 |
if tree.degree >= 1
|
163 |
b = countNodes(tree.l)
|
164 |
end
|
165 |
-
if tree.degree
|
166 |
c = countNodes(tree.r)
|
167 |
end
|
168 |
|
169 |
i = rand(1:1+b+c)
|
170 |
if i <= b
|
171 |
return randomNode(tree.l)
|
172 |
-
elseif i
|
173 |
return tree
|
174 |
end
|
175 |
|
@@ -178,9 +178,9 @@ end
|
|
178 |
|
179 |
# Count the number of unary operators in the equation
|
180 |
function countUnaryOperators(tree::Node)::Integer
|
181 |
-
if tree.degree
|
182 |
return 0
|
183 |
-
elseif tree.degree
|
184 |
return 1 + countUnaryOperators(tree.l)
|
185 |
else
|
186 |
return 0 + countUnaryOperators(tree.l) + countUnaryOperators(tree.r)
|
@@ -189,9 +189,9 @@ end
|
|
189 |
|
190 |
# Count the number of binary operators in the equation
|
191 |
function countBinaryOperators(tree::Node)::Integer
|
192 |
-
if tree.degree
|
193 |
return 0
|
194 |
-
elseif tree.degree
|
195 |
return 0 + countBinaryOperators(tree.l)
|
196 |
else
|
197 |
return 1 + countBinaryOperators(tree.l) + countBinaryOperators(tree.r)
|
@@ -206,14 +206,14 @@ end
|
|
206 |
# Randomly convert an operator into another one (binary->binary;
|
207 |
# unary->unary)
|
208 |
function mutateOperator(tree::Node)::Node
|
209 |
-
if countOperators(tree)
|
210 |
return tree
|
211 |
end
|
212 |
node = randomNode(tree)
|
213 |
-
while node.degree
|
214 |
node = randomNode(tree)
|
215 |
end
|
216 |
-
if node.degree
|
217 |
node.op = rand(1:length(unaops))
|
218 |
else
|
219 |
node.op = rand(1:length(binops))
|
@@ -223,9 +223,9 @@ end
|
|
223 |
|
224 |
# Count the number of constants in an equation
|
225 |
function countConstants(tree::Node)::Integer
|
226 |
-
if tree.degree
|
227 |
return convert(Integer, tree.constant)
|
228 |
-
elseif tree.degree
|
229 |
return 0 + countConstants(tree.l)
|
230 |
else
|
231 |
return 0 + countConstants(tree.l) + countConstants(tree.r)
|
@@ -238,11 +238,11 @@ function mutateConstant(
|
|
238 |
probNegate::Float32=0.01f0)::Node
|
239 |
# T is between 0 and 1.
|
240 |
|
241 |
-
if countConstants(tree)
|
242 |
return tree
|
243 |
end
|
244 |
node = randomNode(tree)
|
245 |
-
while node.degree
|
246 |
node = randomNode(tree)
|
247 |
end
|
248 |
|
@@ -273,19 +273,21 @@ end
|
|
273 |
# Evaluate an equation over an array of datapoints
|
274 |
function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32, 1}, Nothing}
|
275 |
clen = size(cX)[1]
|
276 |
-
if tree.degree
|
277 |
if tree.constant
|
278 |
return fill(tree.val, clen)
|
279 |
else
|
280 |
return copy(cX[:, tree.val])
|
281 |
end
|
282 |
-
elseif tree.degree
|
283 |
cumulator = evalTreeArray(tree.l, cX)
|
284 |
if cumulator === nothing
|
285 |
return nothing
|
286 |
end
|
287 |
op_idx = tree.op
|
288 |
-
|
|
|
|
|
289 |
@inbounds for i=1:clen
|
290 |
if isinf(cumulator[i]) || isnan(cumulator[i])
|
291 |
return nothing
|
@@ -301,8 +303,12 @@ function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32,
|
|
301 |
if array2 === nothing
|
302 |
return nothing
|
303 |
end
|
|
|
304 |
op_idx = tree.op
|
305 |
-
|
|
|
|
|
|
|
306 |
@inbounds for i=1:clen
|
307 |
if isinf(cumulator[i]) || isnan(cumulator[i])
|
308 |
return nothing
|
@@ -350,7 +356,7 @@ end
|
|
350 |
# Add a random unary/binary operation to the end of a tree
|
351 |
function appendRandomOp(tree::Node)::Node
|
352 |
node = randomNode(tree)
|
353 |
-
while node.degree
|
354 |
node = randomNode(tree)
|
355 |
end
|
356 |
|
@@ -458,7 +464,7 @@ end
|
|
458 |
|
459 |
# Return a random node from the tree with parent
|
460 |
function randomNodeAndParent(tree::Node, parent::Union{Node, Nothing})::Tuple{Node, Union{Node, Nothing}}
|
461 |
-
if tree.degree
|
462 |
return tree, parent
|
463 |
end
|
464 |
a = countNodes(tree)
|
@@ -467,14 +473,14 @@ function randomNodeAndParent(tree::Node, parent::Union{Node, Nothing})::Tuple{No
|
|
467 |
if tree.degree >= 1
|
468 |
b = countNodes(tree.l)
|
469 |
end
|
470 |
-
if tree.degree
|
471 |
c = countNodes(tree.r)
|
472 |
end
|
473 |
|
474 |
i = rand(1:1+b+c)
|
475 |
if i <= b
|
476 |
return randomNodeAndParent(tree.l, tree)
|
477 |
-
elseif i
|
478 |
return tree, parent
|
479 |
end
|
480 |
|
@@ -487,7 +493,7 @@ function deleteRandomOp(tree::Node)::Node
|
|
487 |
node, parent = randomNodeAndParent(tree, nothing)
|
488 |
isroot = (parent === nothing)
|
489 |
|
490 |
-
if node.degree
|
491 |
# Replace with new constant
|
492 |
newnode = randomConstantNode()
|
493 |
node.l = newnode.l
|
@@ -496,7 +502,7 @@ function deleteRandomOp(tree::Node)::Node
|
|
496 |
node.degree = newnode.degree
|
497 |
node.val = newnode.val
|
498 |
node.constant = newnode.constant
|
499 |
-
elseif node.degree
|
500 |
# Join one of the children with the parent
|
501 |
if isroot
|
502 |
return node.l
|
@@ -536,17 +542,17 @@ function combineOperators(tree::Node)::Node
|
|
536 |
# ((const - var) - const) => (const - var)
|
537 |
# (want to add anything commutative!)
|
538 |
# TODO - need to combine plus/sub if they are both there.
|
539 |
-
if tree.degree
|
540 |
return tree
|
541 |
-
elseif tree.degree
|
542 |
tree.l = combineOperators(tree.l)
|
543 |
-
elseif tree.degree
|
544 |
tree.l = combineOperators(tree.l)
|
545 |
tree.r = combineOperators(tree.r)
|
546 |
end
|
547 |
|
548 |
-
top_level_constant = tree.degree
|
549 |
-
if tree.degree
|
550 |
op = tree.op
|
551 |
# Put the constant in r
|
552 |
if tree.l.constant
|
@@ -557,7 +563,7 @@ function combineOperators(tree::Node)::Node
|
|
557 |
topconstant = tree.r.val
|
558 |
# Simplify down first
|
559 |
below = tree.l
|
560 |
-
if below.degree
|
561 |
if below.l.constant
|
562 |
tree = below
|
563 |
tree.l.val = binops[op](tree.l.val, topconstant)
|
@@ -568,11 +574,11 @@ function combineOperators(tree::Node)::Node
|
|
568 |
end
|
569 |
end
|
570 |
|
571 |
-
if tree.degree
|
572 |
# Currently just simplifies subtraction. (can't assume both plus and sub are operators)
|
573 |
# Not commutative, so use different op.
|
574 |
if tree.l.constant
|
575 |
-
if tree.r.degree
|
576 |
if tree.r.l.constant
|
577 |
#(const - (const - var)) => (var - const)
|
578 |
l = tree.l
|
@@ -591,7 +597,7 @@ function combineOperators(tree::Node)::Node
|
|
591 |
end
|
592 |
end
|
593 |
else #tree.r.constant is true
|
594 |
-
if tree.l.degree
|
595 |
if tree.l.l.constant
|
596 |
#((const - var) - const) => (const - var)
|
597 |
l = tree.l
|
@@ -616,17 +622,17 @@ end
|
|
616 |
|
617 |
# Simplify tree
|
618 |
function simplifyTree(tree::Node)::Node
|
619 |
-
if tree.degree
|
620 |
tree.l = simplifyTree(tree.l)
|
621 |
-
if tree.l.degree
|
622 |
return Node(unaops[tree.op](tree.l.val))
|
623 |
end
|
624 |
-
elseif tree.degree
|
625 |
tree.l = simplifyTree(tree.l)
|
626 |
tree.r = simplifyTree(tree.r)
|
627 |
constantsBelow = (
|
628 |
-
tree.l.degree
|
629 |
-
tree.r.degree
|
630 |
)
|
631 |
if constantsBelow
|
632 |
return Node(binops[tree.op](tree.l.val, tree.r.val))
|
@@ -648,9 +654,9 @@ end
|
|
648 |
|
649 |
# Check if any power operator is to the power of a complex expression
|
650 |
function deepPow(tree::Node)::Integer
|
651 |
-
if tree.degree
|
652 |
return 0
|
653 |
-
elseif tree.degree
|
654 |
return 0 + deepPow(tree.l)
|
655 |
else
|
656 |
if binops[tree.op] === pow
|
@@ -857,7 +863,7 @@ function run(
|
|
857 |
pop = regEvolCycle(pop, 1.0f0, curmaxsize)
|
858 |
end
|
859 |
|
860 |
-
if verbosity > 0 && (iT % verbosity
|
861 |
bestPops = bestSubPop(pop)
|
862 |
bestCurScoreIdx = argmin([bestPops.members[member].score for member=1:bestPops.n])
|
863 |
bestCurScore = bestPops.members[bestCurScoreIdx].score
|
@@ -870,13 +876,13 @@ end
|
|
870 |
|
871 |
# Get all the constants from a tree
|
872 |
function getConstants(tree::Node)::Array{Float32, 1}
|
873 |
-
if tree.degree
|
874 |
if tree.constant
|
875 |
return [tree.val]
|
876 |
else
|
877 |
return Float32[]
|
878 |
end
|
879 |
-
elseif tree.degree
|
880 |
return getConstants(tree.l)
|
881 |
else
|
882 |
both = [getConstants(tree.l), getConstants(tree.r)]
|
@@ -886,11 +892,11 @@ end
|
|
886 |
|
887 |
# Set all the constants inside a tree
|
888 |
function setConstants(tree::Node, constants::Array{Float32, 1})
|
889 |
-
if tree.degree
|
890 |
if tree.constant
|
891 |
tree.val = constants[1]
|
892 |
end
|
893 |
-
elseif tree.degree
|
894 |
setConstants(tree.l, constants)
|
895 |
else
|
896 |
numberLeft = countConstants(tree.l)
|
@@ -909,12 +915,12 @@ end
|
|
909 |
# Use Nelder-Mead to optimize the constants in an equation
|
910 |
function optimizeConstants(member::PopMember)::PopMember
|
911 |
nconst = countConstants(member.tree)
|
912 |
-
if nconst
|
913 |
return member
|
914 |
end
|
915 |
x0 = getConstants(member.tree)
|
916 |
f(x::Array{Float32,1})::Float32 = optFunc(x, member.tree)
|
917 |
-
if size(x0)[1]
|
918 |
algorithm = Optim.Newton
|
919 |
else
|
920 |
algorithm = Optim.NelderMead
|
@@ -998,7 +1004,7 @@ function fullRun(niterations::Integer;
|
|
998 |
bestSubPops = [Population(1) for j=1:npopulations]
|
999 |
hallOfFame = HallOfFame()
|
1000 |
curmaxsize = 3
|
1001 |
-
if warmupMaxsize
|
1002 |
curmaxsize = maxsize
|
1003 |
end
|
1004 |
|
@@ -1067,7 +1073,7 @@ function fullRun(niterations::Integer;
|
|
1067 |
numberSmallerAndBetter += 1
|
1068 |
end
|
1069 |
end
|
1070 |
-
betterThanAllSmaller = (numberSmallerAndBetter
|
1071 |
if betterThanAllSmaller
|
1072 |
println(io, "$size|$(curMSE)|$(stringTree(member.tree))")
|
1073 |
push!(dominating, member)
|
@@ -1117,7 +1123,7 @@ function fullRun(niterations::Integer;
|
|
1117 |
|
1118 |
cycles_complete -= 1
|
1119 |
cycles_elapsed = npopulations * niterations - cycles_complete
|
1120 |
-
if warmupMaxsize
|
1121 |
curmaxsize += 1
|
1122 |
if curmaxsize > maxsize
|
1123 |
curmaxsize = maxsize
|
@@ -1167,7 +1173,7 @@ function fullRun(niterations::Integer;
|
|
1167 |
numberSmallerAndBetter += 1
|
1168 |
end
|
1169 |
end
|
1170 |
-
betterThanAllSmaller = (numberSmallerAndBetter
|
1171 |
if betterThanAllSmaller
|
1172 |
delta_c = size - lastComplexity
|
1173 |
delta_l_mse = log(curMSE/lastMSE)
|
|
|
96 |
|
97 |
# Copy an equation (faster than deepcopy)
|
98 |
function copyNode(tree::Node)::Node
|
99 |
+
if tree.degree == 0
|
100 |
return Node(tree.val)
|
101 |
+
elseif tree.degree == 1
|
102 |
return Node(tree.op, copyNode(tree.l))
|
103 |
else
|
104 |
return Node(tree.op, copyNode(tree.l), copyNode(tree.r))
|
|
|
107 |
|
108 |
# Count the operators, constants, variables in an equation
|
109 |
function countNodes(tree::Node)::Integer
|
110 |
+
if tree.degree == 0
|
111 |
return 1
|
112 |
+
elseif tree.degree == 1
|
113 |
return 1 + countNodes(tree.l)
|
114 |
else
|
115 |
return 1 + countNodes(tree.l) + countNodes(tree.r)
|
|
|
118 |
|
119 |
# Count the max depth of a tree
|
120 |
function countDepth(tree::Node)::Integer
|
121 |
+
if tree.degree == 0
|
122 |
return 1
|
123 |
+
elseif tree.degree == 1
|
124 |
return 1 + countDepth(tree.l)
|
125 |
else
|
126 |
return 1 + max(countDepth(tree.l), countDepth(tree.r))
|
|
|
129 |
|
130 |
# Convert an equation to a string
|
131 |
function stringTree(tree::Node)::String
|
132 |
+
if tree.degree == 0
|
133 |
if tree.constant
|
134 |
return string(tree.val)
|
135 |
else
|
|
|
139 |
return "x$(tree.val - 1)"
|
140 |
end
|
141 |
end
|
142 |
+
elseif tree.degree == 1
|
143 |
return "$(unaops[tree.op])($(stringTree(tree.l)))"
|
144 |
else
|
145 |
return "$(binops[tree.op])($(stringTree(tree.l)), $(stringTree(tree.r)))"
|
|
|
153 |
|
154 |
# Return a random node from the tree
|
155 |
function randomNode(tree::Node)::Node
|
156 |
+
if tree.degree == 0
|
157 |
return tree
|
158 |
end
|
159 |
a = countNodes(tree)
|
|
|
162 |
if tree.degree >= 1
|
163 |
b = countNodes(tree.l)
|
164 |
end
|
165 |
+
if tree.degree == 2
|
166 |
c = countNodes(tree.r)
|
167 |
end
|
168 |
|
169 |
i = rand(1:1+b+c)
|
170 |
if i <= b
|
171 |
return randomNode(tree.l)
|
172 |
+
elseif i == b + 1
|
173 |
return tree
|
174 |
end
|
175 |
|
|
|
178 |
|
179 |
# Count the number of unary operators in the equation
|
180 |
function countUnaryOperators(tree::Node)::Integer
|
181 |
+
if tree.degree == 0
|
182 |
return 0
|
183 |
+
elseif tree.degree == 1
|
184 |
return 1 + countUnaryOperators(tree.l)
|
185 |
else
|
186 |
return 0 + countUnaryOperators(tree.l) + countUnaryOperators(tree.r)
|
|
|
189 |
|
190 |
# Count the number of binary operators in the equation
|
191 |
function countBinaryOperators(tree::Node)::Integer
|
192 |
+
if tree.degree == 0
|
193 |
return 0
|
194 |
+
elseif tree.degree == 1
|
195 |
return 0 + countBinaryOperators(tree.l)
|
196 |
else
|
197 |
return 1 + countBinaryOperators(tree.l) + countBinaryOperators(tree.r)
|
|
|
206 |
# Randomly convert an operator into another one (binary->binary;
|
207 |
# unary->unary)
|
208 |
function mutateOperator(tree::Node)::Node
|
209 |
+
if countOperators(tree) == 0
|
210 |
return tree
|
211 |
end
|
212 |
node = randomNode(tree)
|
213 |
+
while node.degree == 0
|
214 |
node = randomNode(tree)
|
215 |
end
|
216 |
+
if node.degree == 1
|
217 |
node.op = rand(1:length(unaops))
|
218 |
else
|
219 |
node.op = rand(1:length(binops))
|
|
|
223 |
|
224 |
# Count the number of constants in an equation
|
225 |
function countConstants(tree::Node)::Integer
|
226 |
+
if tree.degree == 0
|
227 |
return convert(Integer, tree.constant)
|
228 |
+
elseif tree.degree == 1
|
229 |
return 0 + countConstants(tree.l)
|
230 |
else
|
231 |
return 0 + countConstants(tree.l) + countConstants(tree.r)
|
|
|
238 |
probNegate::Float32=0.01f0)::Node
|
239 |
# T is between 0 and 1.
|
240 |
|
241 |
+
if countConstants(tree) == 0
|
242 |
return tree
|
243 |
end
|
244 |
node = randomNode(tree)
|
245 |
+
while node.degree != 0 || node.constant == false
|
246 |
node = randomNode(tree)
|
247 |
end
|
248 |
|
|
|
273 |
# Evaluate an equation over an array of datapoints
|
274 |
function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32, 1}, Nothing}
|
275 |
clen = size(cX)[1]
|
276 |
+
if tree.degree == 0
|
277 |
if tree.constant
|
278 |
return fill(tree.val, clen)
|
279 |
else
|
280 |
return copy(cX[:, tree.val])
|
281 |
end
|
282 |
+
elseif tree.degree == 1
|
283 |
cumulator = evalTreeArray(tree.l, cX)
|
284 |
if cumulator === nothing
|
285 |
return nothing
|
286 |
end
|
287 |
op_idx = tree.op
|
288 |
+
@inbounds @simd for i=1:clen
|
289 |
+
cumulator[i] = UNAOP(op_idx, cumulator[i])
|
290 |
+
end
|
291 |
@inbounds for i=1:clen
|
292 |
if isinf(cumulator[i]) || isnan(cumulator[i])
|
293 |
return nothing
|
|
|
303 |
if array2 === nothing
|
304 |
return nothing
|
305 |
end
|
306 |
+
|
307 |
op_idx = tree.op
|
308 |
+
|
309 |
+
@inbounds @simd for i=1:clen
|
310 |
+
cumulator[i] = BINOP(op_idx, cumulator[i], array2[i])
|
311 |
+
end
|
312 |
@inbounds for i=1:clen
|
313 |
if isinf(cumulator[i]) || isnan(cumulator[i])
|
314 |
return nothing
|
|
|
356 |
# Add a random unary/binary operation to the end of a tree
|
357 |
function appendRandomOp(tree::Node)::Node
|
358 |
node = randomNode(tree)
|
359 |
+
while node.degree != 0
|
360 |
node = randomNode(tree)
|
361 |
end
|
362 |
|
|
|
464 |
|
465 |
# Return a random node from the tree with parent
|
466 |
function randomNodeAndParent(tree::Node, parent::Union{Node, Nothing})::Tuple{Node, Union{Node, Nothing}}
|
467 |
+
if tree.degree == 0
|
468 |
return tree, parent
|
469 |
end
|
470 |
a = countNodes(tree)
|
|
|
473 |
if tree.degree >= 1
|
474 |
b = countNodes(tree.l)
|
475 |
end
|
476 |
+
if tree.degree == 2
|
477 |
c = countNodes(tree.r)
|
478 |
end
|
479 |
|
480 |
i = rand(1:1+b+c)
|
481 |
if i <= b
|
482 |
return randomNodeAndParent(tree.l, tree)
|
483 |
+
elseif i == b + 1
|
484 |
return tree, parent
|
485 |
end
|
486 |
|
|
|
493 |
node, parent = randomNodeAndParent(tree, nothing)
|
494 |
isroot = (parent === nothing)
|
495 |
|
496 |
+
if node.degree == 0
|
497 |
# Replace with new constant
|
498 |
newnode = randomConstantNode()
|
499 |
node.l = newnode.l
|
|
|
502 |
node.degree = newnode.degree
|
503 |
node.val = newnode.val
|
504 |
node.constant = newnode.constant
|
505 |
+
elseif node.degree == 1
|
506 |
# Join one of the children with the parent
|
507 |
if isroot
|
508 |
return node.l
|
|
|
542 |
# ((const - var) - const) => (const - var)
|
543 |
# (want to add anything commutative!)
|
544 |
# TODO - need to combine plus/sub if they are both there.
|
545 |
+
if tree.degree == 0
|
546 |
return tree
|
547 |
+
elseif tree.degree == 1
|
548 |
tree.l = combineOperators(tree.l)
|
549 |
+
elseif tree.degree == 2
|
550 |
tree.l = combineOperators(tree.l)
|
551 |
tree.r = combineOperators(tree.r)
|
552 |
end
|
553 |
|
554 |
+
top_level_constant = tree.degree == 2 && (tree.l.constant || tree.r.constant)
|
555 |
+
if tree.degree == 2 && (binops[tree.op] === mult || binops[tree.op] === plus) && top_level_constant
|
556 |
op = tree.op
|
557 |
# Put the constant in r
|
558 |
if tree.l.constant
|
|
|
563 |
topconstant = tree.r.val
|
564 |
# Simplify down first
|
565 |
below = tree.l
|
566 |
+
if below.degree == 2 && below.op == op
|
567 |
if below.l.constant
|
568 |
tree = below
|
569 |
tree.l.val = binops[op](tree.l.val, topconstant)
|
|
|
574 |
end
|
575 |
end
|
576 |
|
577 |
+
if tree.degree == 2 && binops[tree.op] === sub && top_level_constant
|
578 |
# Currently just simplifies subtraction. (can't assume both plus and sub are operators)
|
579 |
# Not commutative, so use different op.
|
580 |
if tree.l.constant
|
581 |
+
if tree.r.degree == 2 && binops[tree.r.op] === sub
|
582 |
if tree.r.l.constant
|
583 |
#(const - (const - var)) => (var - const)
|
584 |
l = tree.l
|
|
|
597 |
end
|
598 |
end
|
599 |
else #tree.r.constant is true
|
600 |
+
if tree.l.degree == 2 && binops[tree.l.op] === sub
|
601 |
if tree.l.l.constant
|
602 |
#((const - var) - const) => (const - var)
|
603 |
l = tree.l
|
|
|
622 |
|
623 |
# Simplify tree
|
624 |
function simplifyTree(tree::Node)::Node
|
625 |
+
if tree.degree == 1
|
626 |
tree.l = simplifyTree(tree.l)
|
627 |
+
if tree.l.degree == 0 && tree.l.constant
|
628 |
return Node(unaops[tree.op](tree.l.val))
|
629 |
end
|
630 |
+
elseif tree.degree == 2
|
631 |
tree.l = simplifyTree(tree.l)
|
632 |
tree.r = simplifyTree(tree.r)
|
633 |
constantsBelow = (
|
634 |
+
tree.l.degree == 0 && tree.l.constant &&
|
635 |
+
tree.r.degree == 0 && tree.r.constant
|
636 |
)
|
637 |
if constantsBelow
|
638 |
return Node(binops[tree.op](tree.l.val, tree.r.val))
|
|
|
654 |
|
655 |
# Check if any power operator is to the power of a complex expression
|
656 |
function deepPow(tree::Node)::Integer
|
657 |
+
if tree.degree == 0
|
658 |
return 0
|
659 |
+
elseif tree.degree == 1
|
660 |
return 0 + deepPow(tree.l)
|
661 |
else
|
662 |
if binops[tree.op] === pow
|
|
|
863 |
pop = regEvolCycle(pop, 1.0f0, curmaxsize)
|
864 |
end
|
865 |
|
866 |
+
if verbosity > 0 && (iT % verbosity == 0)
|
867 |
bestPops = bestSubPop(pop)
|
868 |
bestCurScoreIdx = argmin([bestPops.members[member].score for member=1:bestPops.n])
|
869 |
bestCurScore = bestPops.members[bestCurScoreIdx].score
|
|
|
876 |
|
877 |
# Get all the constants from a tree
|
878 |
function getConstants(tree::Node)::Array{Float32, 1}
|
879 |
+
if tree.degree == 0
|
880 |
if tree.constant
|
881 |
return [tree.val]
|
882 |
else
|
883 |
return Float32[]
|
884 |
end
|
885 |
+
elseif tree.degree == 1
|
886 |
return getConstants(tree.l)
|
887 |
else
|
888 |
both = [getConstants(tree.l), getConstants(tree.r)]
|
|
|
892 |
|
893 |
# Set all the constants inside a tree
|
894 |
function setConstants(tree::Node, constants::Array{Float32, 1})
|
895 |
+
if tree.degree == 0
|
896 |
if tree.constant
|
897 |
tree.val = constants[1]
|
898 |
end
|
899 |
+
elseif tree.degree == 1
|
900 |
setConstants(tree.l, constants)
|
901 |
else
|
902 |
numberLeft = countConstants(tree.l)
|
|
|
915 |
# Use Nelder-Mead to optimize the constants in an equation
|
916 |
function optimizeConstants(member::PopMember)::PopMember
|
917 |
nconst = countConstants(member.tree)
|
918 |
+
if nconst == 0
|
919 |
return member
|
920 |
end
|
921 |
x0 = getConstants(member.tree)
|
922 |
f(x::Array{Float32,1})::Float32 = optFunc(x, member.tree)
|
923 |
+
if size(x0)[1] == 1
|
924 |
algorithm = Optim.Newton
|
925 |
else
|
926 |
algorithm = Optim.NelderMead
|
|
|
1004 |
bestSubPops = [Population(1) for j=1:npopulations]
|
1005 |
hallOfFame = HallOfFame()
|
1006 |
curmaxsize = 3
|
1007 |
+
if warmupMaxsize == 0
|
1008 |
curmaxsize = maxsize
|
1009 |
end
|
1010 |
|
|
|
1073 |
numberSmallerAndBetter += 1
|
1074 |
end
|
1075 |
end
|
1076 |
+
betterThanAllSmaller = (numberSmallerAndBetter == 0)
|
1077 |
if betterThanAllSmaller
|
1078 |
println(io, "$size|$(curMSE)|$(stringTree(member.tree))")
|
1079 |
push!(dominating, member)
|
|
|
1123 |
|
1124 |
cycles_complete -= 1
|
1125 |
cycles_elapsed = npopulations * niterations - cycles_complete
|
1126 |
+
if warmupMaxsize != 0 && cycles_elapsed % warmupMaxsize == 0
|
1127 |
curmaxsize += 1
|
1128 |
if curmaxsize > maxsize
|
1129 |
curmaxsize = maxsize
|
|
|
1173 |
numberSmallerAndBetter += 1
|
1174 |
end
|
1175 |
end
|
1176 |
+
betterThanAllSmaller = (numberSmallerAndBetter == 0)
|
1177 |
if betterThanAllSmaller
|
1178 |
delta_c = size - lastComplexity
|
1179 |
delta_l_mse = log(curMSE/lastMSE)
|
pysr/sr.py
CHANGED
@@ -286,35 +286,27 @@ const limitPowComplexity = {"true" if limitPowComplexity else "false"}
|
|
286 |
|
287 |
op_runner = ""
|
288 |
if len(binary_operators) > 0:
|
289 |
-
op_runner += """
|
290 |
-
@inline function BINOP
|
291 |
-
if i
|
292 |
-
|
293 |
-
x[j] = """f"{binary_operators[0]}""""(x[j], y[j])
|
294 |
-
end"""
|
295 |
for i in range(1, len(binary_operators)):
|
296 |
op_runner += f"""
|
297 |
-
elseif i
|
298 |
-
|
299 |
-
x[j] = {binary_operators[i]}(x[j], y[j])
|
300 |
-
end"""
|
301 |
op_runner += """
|
302 |
end
|
303 |
end"""
|
304 |
|
305 |
if len(unary_operators) > 0:
|
306 |
-
op_runner += """
|
307 |
-
@inline function UNAOP
|
308 |
-
if i
|
309 |
-
|
310 |
-
x[j] = """f"{unary_operators[0]}(x[j])""""
|
311 |
-
end"""
|
312 |
for i in range(1, len(unary_operators)):
|
313 |
op_runner += f"""
|
314 |
-
elseif i
|
315 |
-
|
316 |
-
x[j] = {unary_operators[i]}(x[j])
|
317 |
-
end"""
|
318 |
op_runner += """
|
319 |
end
|
320 |
end"""
|
|
|
286 |
|
287 |
op_runner = ""
|
288 |
if len(binary_operators) > 0:
|
289 |
+
op_runner += f"""
|
290 |
+
@inline function BINOP(i::Int, x::Float32, y::Float32)::Float32
|
291 |
+
if i == 1
|
292 |
+
return @fastmath {binary_operators[0]}(x, y)"""
|
|
|
|
|
293 |
for i in range(1, len(binary_operators)):
|
294 |
op_runner += f"""
|
295 |
+
elseif i == {i+1}
|
296 |
+
return @fastmath {binary_operators[i]}(x, y)"""
|
|
|
|
|
297 |
op_runner += """
|
298 |
end
|
299 |
end"""
|
300 |
|
301 |
if len(unary_operators) > 0:
|
302 |
+
op_runner += f"""
|
303 |
+
@inline function UNAOP(i::Int, x::Float32)::Float32
|
304 |
+
if i == 1
|
305 |
+
return @fastmath {unary_operators[0]}(x)"""
|
|
|
|
|
306 |
for i in range(1, len(unary_operators)):
|
307 |
op_runner += f"""
|
308 |
+
elseif i == {i+1}
|
309 |
+
return @fastmath {unary_operators[i]}(x)"""
|
|
|
|
|
310 |
op_runner += """
|
311 |
end
|
312 |
end"""
|