Spaces:
Running
Running
MilesCranmer
commited on
Commit
·
dd17964
1
Parent(s):
b5a1925
Switch to enum for functions
Browse files- julia/sr.jl +28 -28
julia/sr.jl
CHANGED
@@ -50,20 +50,20 @@ mutable struct Node
|
|
50 |
degree::Integer #0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
|
51 |
val::Union{Float32, Integer} #Either const value, or enumerates variable
|
52 |
constant::Bool #false if variable
|
53 |
-
op::
|
54 |
l::Union{Node, Nothing}
|
55 |
r::Union{Node, Nothing}
|
56 |
|
57 |
-
Node(val::Float32) = new(0, val, true,
|
58 |
-
Node(val::Integer) = new(0, val, false,
|
59 |
-
Node(op, l::Node) = new(1, 0.0f0, false, op, l, nothing)
|
60 |
-
Node(op, l::Union{Float32, Integer}) = new(1, 0.0f0, false, op, Node(l), nothing)
|
61 |
-
Node(op, l::Node, r::Node) = new(2, 0.0f0, false, op, l, r)
|
62 |
|
63 |
#Allow to pass the leaf value without additional node call:
|
64 |
-
Node(op, l::Union{Float32, Integer}, r::Node) = new(2, 0.0f0, false, op, Node(l), r)
|
65 |
-
Node(op, l::Node, r::Union{Float32, Integer}) = new(2, 0.0f0, false, op, l, Node(r))
|
66 |
-
Node(op, l::Union{Float32, Integer}, r::Union{Float32, Integer}) = new(2, 0.0f0, false, op, Node(l), Node(r))
|
67 |
end
|
68 |
|
69 |
# Copy an equation (faster than deepcopy)
|
@@ -87,10 +87,10 @@ function evalTree(tree::Node, x::Array{Float32, 1}=Float32[])::Float32
|
|
87 |
return x[tree.val]
|
88 |
end
|
89 |
elseif tree.degree == 1
|
90 |
-
return tree.op(evalTree(tree.l, x))
|
91 |
else
|
92 |
right = Threads.@spawn evalTree(tree.r, x)
|
93 |
-
return tree.op(evalTree(tree.l, x), fetch(right))
|
94 |
end
|
95 |
end
|
96 |
|
@@ -115,10 +115,10 @@ function stringTree(tree::Node)::String
|
|
115 |
return "x$(tree.val - 1)"
|
116 |
end
|
117 |
elseif tree.degree == 1
|
118 |
-
return "$(tree.op)($(stringTree(tree.l)))"
|
119 |
else
|
120 |
right = Threads.@spawn stringTree(tree.r)
|
121 |
-
return "$(tree.op)($(stringTree(tree.l)), $(fetch(right)))"
|
122 |
end
|
123 |
end
|
124 |
|
@@ -192,9 +192,9 @@ function mutateOperator(tree::Node)::Node
|
|
192 |
node = randomNode(tree)
|
193 |
end
|
194 |
if node.degree == 1
|
195 |
-
node.op =
|
196 |
else
|
197 |
-
node.op =
|
198 |
end
|
199 |
return tree
|
200 |
end
|
@@ -252,10 +252,10 @@ function evalTreeArray(tree::Node)::Array{Float32, 1}
|
|
252 |
return ones(Float32, len) .* X[:, tree.val]
|
253 |
end
|
254 |
elseif tree.degree == 1
|
255 |
-
return tree.op.(evalTreeArray(tree.l))
|
256 |
else
|
257 |
right = Threads.@spawn evalTreeArray(tree.r)
|
258 |
-
return tree.op.(evalTreeArray(tree.l), fetch(right))
|
259 |
end
|
260 |
end
|
261 |
|
@@ -294,13 +294,13 @@ function appendRandomOp(tree::Node)::Node
|
|
294 |
|
295 |
if makeNewBinOp
|
296 |
newnode = Node(
|
297 |
-
|
298 |
left,
|
299 |
right
|
300 |
)
|
301 |
else
|
302 |
newnode = Node(
|
303 |
-
|
304 |
left
|
305 |
)
|
306 |
end
|
@@ -323,13 +323,13 @@ function popRandomOp(tree::Node)::Node
|
|
323 |
if makeNewBinOp
|
324 |
right = randomConstantNode()
|
325 |
newnode = Node(
|
326 |
-
|
327 |
left,
|
328 |
right
|
329 |
)
|
330 |
else
|
331 |
newnode = Node(
|
332 |
-
|
333 |
left
|
334 |
)
|
335 |
end
|
@@ -352,13 +352,13 @@ function insertRandomOp(tree::Node)::Node
|
|
352 |
if makeNewBinOp
|
353 |
right = randomConstantNode()
|
354 |
newnode = Node(
|
355 |
-
|
356 |
left,
|
357 |
right
|
358 |
)
|
359 |
else
|
360 |
newnode = Node(
|
361 |
-
|
362 |
left
|
363 |
)
|
364 |
end
|
@@ -459,7 +459,7 @@ function combineOperators(tree::Node)::Node
|
|
459 |
# ((const + var) + const) => (const + var)
|
460 |
# ((const * var) * const) => (const * var)
|
461 |
# (anything commutative!)
|
462 |
-
if tree.degree == 2 && (tree.op == plus || tree.op == mult)
|
463 |
op = tree.op
|
464 |
if tree.l.constant || tree.r.constant
|
465 |
# Put the constant in r
|
@@ -475,10 +475,10 @@ function combineOperators(tree::Node)::Node
|
|
475 |
if below.degree == 2 && below.op == op
|
476 |
if below.l.constant
|
477 |
tree = below
|
478 |
-
tree.l.val = op(tree.l.val, topconstant)
|
479 |
elseif below.r.constant
|
480 |
tree = below
|
481 |
-
tree.r.val = op(tree.r.val, topconstant)
|
482 |
end
|
483 |
end
|
484 |
end
|
@@ -491,7 +491,7 @@ function simplifyTree(tree::Node)::Node
|
|
491 |
if tree.degree == 1
|
492 |
tree.l = simplifyTree(tree.l)
|
493 |
if tree.l.degree == 0 && tree.l.constant
|
494 |
-
return Node(tree.op(tree.l.val))
|
495 |
end
|
496 |
elseif tree.degree == 2
|
497 |
right = Threads.@spawn simplifyTree(tree.r)
|
@@ -502,7 +502,7 @@ function simplifyTree(tree::Node)::Node
|
|
502 |
tree.r.degree == 0 && tree.r.constant
|
503 |
)
|
504 |
if constantsBelow
|
505 |
-
return Node(tree.op(tree.l.val, tree.r.val))
|
506 |
end
|
507 |
end
|
508 |
return tree
|
|
|
50 |
degree::Integer #0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
|
51 |
val::Union{Float32, Integer} #Either const value, or enumerates variable
|
52 |
constant::Bool #false if variable
|
53 |
+
op::Integer #enumerates operator (separately for degree=1,2)
|
54 |
l::Union{Node, Nothing}
|
55 |
r::Union{Node, Nothing}
|
56 |
|
57 |
+
Node(val::Float32) = new(0, val, true, 1, nothing, nothing)
|
58 |
+
Node(val::Integer) = new(0, val, false, 1, nothing, nothing)
|
59 |
+
Node(op::Integer, l::Node) = new(1, 0.0f0, false, op, l, nothing)
|
60 |
+
Node(op::Integer, l::Union{Float32, Integer}) = new(1, 0.0f0, false, op, Node(l), nothing)
|
61 |
+
Node(op::Integer, l::Node, r::Node) = new(2, 0.0f0, false, op, l, r)
|
62 |
|
63 |
#Allow to pass the leaf value without additional node call:
|
64 |
+
Node(op::Integer, l::Union{Float32, Integer}, r::Node) = new(2, 0.0f0, false, op, Node(l), r)
|
65 |
+
Node(op::Integer, l::Node, r::Union{Float32, Integer}) = new(2, 0.0f0, false, op, l, Node(r))
|
66 |
+
Node(op::Integer, l::Union{Float32, Integer}, r::Union{Float32, Integer}) = new(2, 0.0f0, false, op, Node(l), Node(r))
|
67 |
end
|
68 |
|
69 |
# Copy an equation (faster than deepcopy)
|
|
|
87 |
return x[tree.val]
|
88 |
end
|
89 |
elseif tree.degree == 1
|
90 |
+
return unaops[tree.op](evalTree(tree.l, x))
|
91 |
else
|
92 |
right = Threads.@spawn evalTree(tree.r, x)
|
93 |
+
return binops[tree.op](evalTree(tree.l, x), fetch(right))
|
94 |
end
|
95 |
end
|
96 |
|
|
|
115 |
return "x$(tree.val - 1)"
|
116 |
end
|
117 |
elseif tree.degree == 1
|
118 |
+
return "$(unaops[tree.op])($(stringTree(tree.l)))"
|
119 |
else
|
120 |
right = Threads.@spawn stringTree(tree.r)
|
121 |
+
return "$(binops[tree.op])($(stringTree(tree.l)), $(fetch(right)))"
|
122 |
end
|
123 |
end
|
124 |
|
|
|
192 |
node = randomNode(tree)
|
193 |
end
|
194 |
if node.degree == 1
|
195 |
+
node.op = rand(1:length(unaops))
|
196 |
else
|
197 |
+
node.op = rand(1:length(binops))
|
198 |
end
|
199 |
return tree
|
200 |
end
|
|
|
252 |
return ones(Float32, len) .* X[:, tree.val]
|
253 |
end
|
254 |
elseif tree.degree == 1
|
255 |
+
return unaops[tree.op].(evalTreeArray(tree.l))
|
256 |
else
|
257 |
right = Threads.@spawn evalTreeArray(tree.r)
|
258 |
+
return binops[tree.op].(evalTreeArray(tree.l), fetch(right))
|
259 |
end
|
260 |
end
|
261 |
|
|
|
294 |
|
295 |
if makeNewBinOp
|
296 |
newnode = Node(
|
297 |
+
rand(1:length(binops)),
|
298 |
left,
|
299 |
right
|
300 |
)
|
301 |
else
|
302 |
newnode = Node(
|
303 |
+
rand(1:length(unaops)),
|
304 |
left
|
305 |
)
|
306 |
end
|
|
|
323 |
if makeNewBinOp
|
324 |
right = randomConstantNode()
|
325 |
newnode = Node(
|
326 |
+
rand(1:length(binops)),
|
327 |
left,
|
328 |
right
|
329 |
)
|
330 |
else
|
331 |
newnode = Node(
|
332 |
+
rand(1:length(unaops)),
|
333 |
left
|
334 |
)
|
335 |
end
|
|
|
352 |
if makeNewBinOp
|
353 |
right = randomConstantNode()
|
354 |
newnode = Node(
|
355 |
+
rand(1:length(binops)),
|
356 |
left,
|
357 |
right
|
358 |
)
|
359 |
else
|
360 |
newnode = Node(
|
361 |
+
rand(1:length(unaops)),
|
362 |
left
|
363 |
)
|
364 |
end
|
|
|
459 |
# ((const + var) + const) => (const + var)
|
460 |
# ((const * var) * const) => (const * var)
|
461 |
# (anything commutative!)
|
462 |
+
if tree.degree == 2 && (binops[tree.op] == plus || binops[tree.op] == mult)
|
463 |
op = tree.op
|
464 |
if tree.l.constant || tree.r.constant
|
465 |
# Put the constant in r
|
|
|
475 |
if below.degree == 2 && below.op == op
|
476 |
if below.l.constant
|
477 |
tree = below
|
478 |
+
tree.l.val = binops[op](tree.l.val, topconstant)
|
479 |
elseif below.r.constant
|
480 |
tree = below
|
481 |
+
tree.r.val = binops[op](tree.r.val, topconstant)
|
482 |
end
|
483 |
end
|
484 |
end
|
|
|
491 |
if tree.degree == 1
|
492 |
tree.l = simplifyTree(tree.l)
|
493 |
if tree.l.degree == 0 && tree.l.constant
|
494 |
+
return Node(unaops[tree.op](tree.l.val))
|
495 |
end
|
496 |
elseif tree.degree == 2
|
497 |
right = Threads.@spawn simplifyTree(tree.r)
|
|
|
502 |
tree.r.degree == 0 && tree.r.constant
|
503 |
)
|
504 |
if constantsBelow
|
505 |
+
return Node(binops[tree.op](tree.l.val, tree.r.val))
|
506 |
end
|
507 |
end
|
508 |
return tree
|