MilesCranmer commited on
Commit
dd17964
·
1 Parent(s): b5a1925

Switch to enum for functions

Browse files
Files changed (1) hide show
  1. julia/sr.jl +28 -28
julia/sr.jl CHANGED
@@ -50,20 +50,20 @@ mutable struct Node
50
  degree::Integer #0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
51
  val::Union{Float32, Integer} #Either const value, or enumerates variable
52
  constant::Bool #false if variable
53
- op::Function #enumerates operator (for degree=1,2)
54
  l::Union{Node, Nothing}
55
  r::Union{Node, Nothing}
56
 
57
- Node(val::Float32) = new(0, val, true, id, nothing, nothing)
58
- Node(val::Integer) = new(0, val, false, id, nothing, nothing)
59
- Node(op, l::Node) = new(1, 0.0f0, false, op, l, nothing)
60
- Node(op, l::Union{Float32, Integer}) = new(1, 0.0f0, false, op, Node(l), nothing)
61
- Node(op, l::Node, r::Node) = new(2, 0.0f0, false, op, l, r)
62
 
63
  #Allow to pass the leaf value without additional node call:
64
- Node(op, l::Union{Float32, Integer}, r::Node) = new(2, 0.0f0, false, op, Node(l), r)
65
- Node(op, l::Node, r::Union{Float32, Integer}) = new(2, 0.0f0, false, op, l, Node(r))
66
- Node(op, l::Union{Float32, Integer}, r::Union{Float32, Integer}) = new(2, 0.0f0, false, op, Node(l), Node(r))
67
  end
68
 
69
  # Copy an equation (faster than deepcopy)
@@ -87,10 +87,10 @@ function evalTree(tree::Node, x::Array{Float32, 1}=Float32[])::Float32
87
  return x[tree.val]
88
  end
89
  elseif tree.degree == 1
90
- return tree.op(evalTree(tree.l, x))
91
  else
92
  right = Threads.@spawn evalTree(tree.r, x)
93
- return tree.op(evalTree(tree.l, x), fetch(right))
94
  end
95
  end
96
 
@@ -115,10 +115,10 @@ function stringTree(tree::Node)::String
115
  return "x$(tree.val - 1)"
116
  end
117
  elseif tree.degree == 1
118
- return "$(tree.op)($(stringTree(tree.l)))"
119
  else
120
  right = Threads.@spawn stringTree(tree.r)
121
- return "$(tree.op)($(stringTree(tree.l)), $(fetch(right)))"
122
  end
123
  end
124
 
@@ -192,9 +192,9 @@ function mutateOperator(tree::Node)::Node
192
  node = randomNode(tree)
193
  end
194
  if node.degree == 1
195
- node.op = unaops[rand(1:length(unaops))]
196
  else
197
- node.op = binops[rand(1:length(binops))]
198
  end
199
  return tree
200
  end
@@ -252,10 +252,10 @@ function evalTreeArray(tree::Node)::Array{Float32, 1}
252
  return ones(Float32, len) .* X[:, tree.val]
253
  end
254
  elseif tree.degree == 1
255
- return tree.op.(evalTreeArray(tree.l))
256
  else
257
  right = Threads.@spawn evalTreeArray(tree.r)
258
- return tree.op.(evalTreeArray(tree.l), fetch(right))
259
  end
260
  end
261
 
@@ -294,13 +294,13 @@ function appendRandomOp(tree::Node)::Node
294
 
295
  if makeNewBinOp
296
  newnode = Node(
297
- binops[rand(1:length(binops))],
298
  left,
299
  right
300
  )
301
  else
302
  newnode = Node(
303
- unaops[rand(1:length(unaops))],
304
  left
305
  )
306
  end
@@ -323,13 +323,13 @@ function popRandomOp(tree::Node)::Node
323
  if makeNewBinOp
324
  right = randomConstantNode()
325
  newnode = Node(
326
- binops[rand(1:length(binops))],
327
  left,
328
  right
329
  )
330
  else
331
  newnode = Node(
332
- unaops[rand(1:length(unaops))],
333
  left
334
  )
335
  end
@@ -352,13 +352,13 @@ function insertRandomOp(tree::Node)::Node
352
  if makeNewBinOp
353
  right = randomConstantNode()
354
  newnode = Node(
355
- binops[rand(1:length(binops))],
356
  left,
357
  right
358
  )
359
  else
360
  newnode = Node(
361
- unaops[rand(1:length(unaops))],
362
  left
363
  )
364
  end
@@ -459,7 +459,7 @@ function combineOperators(tree::Node)::Node
459
  # ((const + var) + const) => (const + var)
460
  # ((const * var) * const) => (const * var)
461
  # (anything commutative!)
462
- if tree.degree == 2 && (tree.op == plus || tree.op == mult)
463
  op = tree.op
464
  if tree.l.constant || tree.r.constant
465
  # Put the constant in r
@@ -475,10 +475,10 @@ function combineOperators(tree::Node)::Node
475
  if below.degree == 2 && below.op == op
476
  if below.l.constant
477
  tree = below
478
- tree.l.val = op(tree.l.val, topconstant)
479
  elseif below.r.constant
480
  tree = below
481
- tree.r.val = op(tree.r.val, topconstant)
482
  end
483
  end
484
  end
@@ -491,7 +491,7 @@ function simplifyTree(tree::Node)::Node
491
  if tree.degree == 1
492
  tree.l = simplifyTree(tree.l)
493
  if tree.l.degree == 0 && tree.l.constant
494
- return Node(tree.op(tree.l.val))
495
  end
496
  elseif tree.degree == 2
497
  right = Threads.@spawn simplifyTree(tree.r)
@@ -502,7 +502,7 @@ function simplifyTree(tree::Node)::Node
502
  tree.r.degree == 0 && tree.r.constant
503
  )
504
  if constantsBelow
505
- return Node(tree.op(tree.l.val, tree.r.val))
506
  end
507
  end
508
  return tree
 
50
  degree::Integer #0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
51
  val::Union{Float32, Integer} #Either const value, or enumerates variable
52
  constant::Bool #false if variable
53
+ op::Integer #enumerates operator (separately for degree=1,2)
54
  l::Union{Node, Nothing}
55
  r::Union{Node, Nothing}
56
 
57
+ Node(val::Float32) = new(0, val, true, 1, nothing, nothing)
58
+ Node(val::Integer) = new(0, val, false, 1, nothing, nothing)
59
+ Node(op::Integer, l::Node) = new(1, 0.0f0, false, op, l, nothing)
60
+ Node(op::Integer, l::Union{Float32, Integer}) = new(1, 0.0f0, false, op, Node(l), nothing)
61
+ Node(op::Integer, l::Node, r::Node) = new(2, 0.0f0, false, op, l, r)
62
 
63
  #Allow to pass the leaf value without additional node call:
64
+ Node(op::Integer, l::Union{Float32, Integer}, r::Node) = new(2, 0.0f0, false, op, Node(l), r)
65
+ Node(op::Integer, l::Node, r::Union{Float32, Integer}) = new(2, 0.0f0, false, op, l, Node(r))
66
+ Node(op::Integer, l::Union{Float32, Integer}, r::Union{Float32, Integer}) = new(2, 0.0f0, false, op, Node(l), Node(r))
67
  end
68
 
69
  # Copy an equation (faster than deepcopy)
 
87
  return x[tree.val]
88
  end
89
  elseif tree.degree == 1
90
+ return unaops[tree.op](evalTree(tree.l, x))
91
  else
92
  right = Threads.@spawn evalTree(tree.r, x)
93
+ return binops[tree.op](evalTree(tree.l, x), fetch(right))
94
  end
95
  end
96
 
 
115
  return "x$(tree.val - 1)"
116
  end
117
  elseif tree.degree == 1
118
+ return "$(unaops[tree.op])($(stringTree(tree.l)))"
119
  else
120
  right = Threads.@spawn stringTree(tree.r)
121
+ return "$(binops[tree.op])($(stringTree(tree.l)), $(fetch(right)))"
122
  end
123
  end
124
 
 
192
  node = randomNode(tree)
193
  end
194
  if node.degree == 1
195
+ node.op = rand(1:length(unaops))
196
  else
197
+ node.op = rand(1:length(binops))
198
  end
199
  return tree
200
  end
 
252
  return ones(Float32, len) .* X[:, tree.val]
253
  end
254
  elseif tree.degree == 1
255
+ return unaops[tree.op].(evalTreeArray(tree.l))
256
  else
257
  right = Threads.@spawn evalTreeArray(tree.r)
258
+ return binops[tree.op].(evalTreeArray(tree.l), fetch(right))
259
  end
260
  end
261
 
 
294
 
295
  if makeNewBinOp
296
  newnode = Node(
297
+ rand(1:length(binops)),
298
  left,
299
  right
300
  )
301
  else
302
  newnode = Node(
303
+ rand(1:length(unaops)),
304
  left
305
  )
306
  end
 
323
  if makeNewBinOp
324
  right = randomConstantNode()
325
  newnode = Node(
326
+ rand(1:length(binops)),
327
  left,
328
  right
329
  )
330
  else
331
  newnode = Node(
332
+ rand(1:length(unaops)),
333
  left
334
  )
335
  end
 
352
  if makeNewBinOp
353
  right = randomConstantNode()
354
  newnode = Node(
355
+ rand(1:length(binops)),
356
  left,
357
  right
358
  )
359
  else
360
  newnode = Node(
361
+ rand(1:length(unaops)),
362
  left
363
  )
364
  end
 
459
  # ((const + var) + const) => (const + var)
460
  # ((const * var) * const) => (const * var)
461
  # (anything commutative!)
462
+ if tree.degree == 2 && (binops[tree.op] == plus || binops[tree.op] == mult)
463
  op = tree.op
464
  if tree.l.constant || tree.r.constant
465
  # Put the constant in r
 
475
  if below.degree == 2 && below.op == op
476
  if below.l.constant
477
  tree = below
478
+ tree.l.val = binops[op](tree.l.val, topconstant)
479
  elseif below.r.constant
480
  tree = below
481
+ tree.r.val = binops[op](tree.r.val, topconstant)
482
  end
483
  end
484
  end
 
491
  if tree.degree == 1
492
  tree.l = simplifyTree(tree.l)
493
  if tree.l.degree == 0 && tree.l.constant
494
+ return Node(unaops[tree.op](tree.l.val))
495
  end
496
  elseif tree.degree == 2
497
  right = Threads.@spawn simplifyTree(tree.r)
 
502
  tree.r.degree == 0 && tree.r.constant
503
  )
504
  if constantsBelow
505
+ return Node(binops[tree.op](tree.l.val, tree.r.val))
506
  end
507
  end
508
  return tree