MilesCranmer commited on
Commit
1f4e612
1 Parent(s): eccca5d

Reset to last working copy

Browse files
Files changed (2) hide show
  1. julia/sr.jl +65 -59
  2. pysr/sr.py +12 -20
julia/sr.jl CHANGED
@@ -96,9 +96,9 @@ end
96
 
97
  # Copy an equation (faster than deepcopy)
98
  function copyNode(tree::Node)::Node
99
- if tree.degree === 0
100
  return Node(tree.val)
101
- elseif tree.degree === 1
102
  return Node(tree.op, copyNode(tree.l))
103
  else
104
  return Node(tree.op, copyNode(tree.l), copyNode(tree.r))
@@ -107,9 +107,9 @@ end
107
 
108
  # Count the operators, constants, variables in an equation
109
  function countNodes(tree::Node)::Integer
110
- if tree.degree === 0
111
  return 1
112
- elseif tree.degree === 1
113
  return 1 + countNodes(tree.l)
114
  else
115
  return 1 + countNodes(tree.l) + countNodes(tree.r)
@@ -118,9 +118,9 @@ end
118
 
119
  # Count the max depth of a tree
120
  function countDepth(tree::Node)::Integer
121
- if tree.degree === 0
122
  return 1
123
- elseif tree.degree === 1
124
  return 1 + countDepth(tree.l)
125
  else
126
  return 1 + max(countDepth(tree.l), countDepth(tree.r))
@@ -129,7 +129,7 @@ end
129
 
130
  # Convert an equation to a string
131
  function stringTree(tree::Node)::String
132
- if tree.degree === 0
133
  if tree.constant
134
  return string(tree.val)
135
  else
@@ -139,7 +139,7 @@ function stringTree(tree::Node)::String
139
  return "x$(tree.val - 1)"
140
  end
141
  end
142
- elseif tree.degree === 1
143
  return "$(unaops[tree.op])($(stringTree(tree.l)))"
144
  else
145
  return "$(binops[tree.op])($(stringTree(tree.l)), $(stringTree(tree.r)))"
@@ -153,7 +153,7 @@ end
153
 
154
  # Return a random node from the tree
155
  function randomNode(tree::Node)::Node
156
- if tree.degree === 0
157
  return tree
158
  end
159
  a = countNodes(tree)
@@ -162,14 +162,14 @@ function randomNode(tree::Node)::Node
162
  if tree.degree >= 1
163
  b = countNodes(tree.l)
164
  end
165
- if tree.degree === 2
166
  c = countNodes(tree.r)
167
  end
168
 
169
  i = rand(1:1+b+c)
170
  if i <= b
171
  return randomNode(tree.l)
172
- elseif i === b + 1
173
  return tree
174
  end
175
 
@@ -178,9 +178,9 @@ end
178
 
179
  # Count the number of unary operators in the equation
180
  function countUnaryOperators(tree::Node)::Integer
181
- if tree.degree === 0
182
  return 0
183
- elseif tree.degree === 1
184
  return 1 + countUnaryOperators(tree.l)
185
  else
186
  return 0 + countUnaryOperators(tree.l) + countUnaryOperators(tree.r)
@@ -189,9 +189,9 @@ end
189
 
190
  # Count the number of binary operators in the equation
191
  function countBinaryOperators(tree::Node)::Integer
192
- if tree.degree === 0
193
  return 0
194
- elseif tree.degree === 1
195
  return 0 + countBinaryOperators(tree.l)
196
  else
197
  return 1 + countBinaryOperators(tree.l) + countBinaryOperators(tree.r)
@@ -206,14 +206,14 @@ end
206
  # Randomly convert an operator into another one (binary->binary;
207
  # unary->unary)
208
  function mutateOperator(tree::Node)::Node
209
- if countOperators(tree) === 0
210
  return tree
211
  end
212
  node = randomNode(tree)
213
- while node.degree === 0
214
  node = randomNode(tree)
215
  end
216
- if node.degree === 1
217
  node.op = rand(1:length(unaops))
218
  else
219
  node.op = rand(1:length(binops))
@@ -223,9 +223,9 @@ end
223
 
224
  # Count the number of constants in an equation
225
  function countConstants(tree::Node)::Integer
226
- if tree.degree === 0
227
  return convert(Integer, tree.constant)
228
- elseif tree.degree === 1
229
  return 0 + countConstants(tree.l)
230
  else
231
  return 0 + countConstants(tree.l) + countConstants(tree.r)
@@ -238,11 +238,11 @@ function mutateConstant(
238
  probNegate::Float32=0.01f0)::Node
239
  # T is between 0 and 1.
240
 
241
- if countConstants(tree) === 0
242
  return tree
243
  end
244
  node = randomNode(tree)
245
- while node.degree !== 0 || node.constant === false
246
  node = randomNode(tree)
247
  end
248
 
@@ -273,19 +273,21 @@ end
273
  # Evaluate an equation over an array of datapoints
274
  function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32, 1}, Nothing}
275
  clen = size(cX)[1]
276
- if tree.degree === 0
277
  if tree.constant
278
  return fill(tree.val, clen)
279
  else
280
  return copy(cX[:, tree.val])
281
  end
282
- elseif tree.degree === 1
283
  cumulator = evalTreeArray(tree.l, cX)
284
  if cumulator === nothing
285
  return nothing
286
  end
287
  op_idx = tree.op
288
- UNAOP!(cumulator, op_idx, clen)
 
 
289
  @inbounds for i=1:clen
290
  if isinf(cumulator[i]) || isnan(cumulator[i])
291
  return nothing
@@ -301,8 +303,12 @@ function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32,
301
  if array2 === nothing
302
  return nothing
303
  end
 
304
  op_idx = tree.op
305
- BINOP!(cumulator, array2, op_idx, clen)
 
 
 
306
  @inbounds for i=1:clen
307
  if isinf(cumulator[i]) || isnan(cumulator[i])
308
  return nothing
@@ -350,7 +356,7 @@ end
350
  # Add a random unary/binary operation to the end of a tree
351
  function appendRandomOp(tree::Node)::Node
352
  node = randomNode(tree)
353
- while node.degree !== 0
354
  node = randomNode(tree)
355
  end
356
 
@@ -458,7 +464,7 @@ end
458
 
459
  # Return a random node from the tree with parent
460
  function randomNodeAndParent(tree::Node, parent::Union{Node, Nothing})::Tuple{Node, Union{Node, Nothing}}
461
- if tree.degree === 0
462
  return tree, parent
463
  end
464
  a = countNodes(tree)
@@ -467,14 +473,14 @@ function randomNodeAndParent(tree::Node, parent::Union{Node, Nothing})::Tuple{No
467
  if tree.degree >= 1
468
  b = countNodes(tree.l)
469
  end
470
- if tree.degree === 2
471
  c = countNodes(tree.r)
472
  end
473
 
474
  i = rand(1:1+b+c)
475
  if i <= b
476
  return randomNodeAndParent(tree.l, tree)
477
- elseif i === b + 1
478
  return tree, parent
479
  end
480
 
@@ -487,7 +493,7 @@ function deleteRandomOp(tree::Node)::Node
487
  node, parent = randomNodeAndParent(tree, nothing)
488
  isroot = (parent === nothing)
489
 
490
- if node.degree === 0
491
  # Replace with new constant
492
  newnode = randomConstantNode()
493
  node.l = newnode.l
@@ -496,7 +502,7 @@ function deleteRandomOp(tree::Node)::Node
496
  node.degree = newnode.degree
497
  node.val = newnode.val
498
  node.constant = newnode.constant
499
- elseif node.degree === 1
500
  # Join one of the children with the parent
501
  if isroot
502
  return node.l
@@ -536,17 +542,17 @@ function combineOperators(tree::Node)::Node
536
  # ((const - var) - const) => (const - var)
537
  # (want to add anything commutative!)
538
  # TODO - need to combine plus/sub if they are both there.
539
- if tree.degree === 0
540
  return tree
541
- elseif tree.degree === 1
542
  tree.l = combineOperators(tree.l)
543
- elseif tree.degree === 2
544
  tree.l = combineOperators(tree.l)
545
  tree.r = combineOperators(tree.r)
546
  end
547
 
548
- top_level_constant = tree.degree === 2 && (tree.l.constant || tree.r.constant)
549
- if tree.degree === 2 && (binops[tree.op] === mult || binops[tree.op] === plus) && top_level_constant
550
  op = tree.op
551
  # Put the constant in r
552
  if tree.l.constant
@@ -557,7 +563,7 @@ function combineOperators(tree::Node)::Node
557
  topconstant = tree.r.val
558
  # Simplify down first
559
  below = tree.l
560
- if below.degree === 2 && below.op === op
561
  if below.l.constant
562
  tree = below
563
  tree.l.val = binops[op](tree.l.val, topconstant)
@@ -568,11 +574,11 @@ function combineOperators(tree::Node)::Node
568
  end
569
  end
570
 
571
- if tree.degree === 2 && binops[tree.op] === sub && top_level_constant
572
  # Currently just simplifies subtraction. (can't assume both plus and sub are operators)
573
  # Not commutative, so use different op.
574
  if tree.l.constant
575
- if tree.r.degree === 2 && binops[tree.r.op] === sub
576
  if tree.r.l.constant
577
  #(const - (const - var)) => (var - const)
578
  l = tree.l
@@ -591,7 +597,7 @@ function combineOperators(tree::Node)::Node
591
  end
592
  end
593
  else #tree.r.constant is true
594
- if tree.l.degree === 2 && binops[tree.l.op] === sub
595
  if tree.l.l.constant
596
  #((const - var) - const) => (const - var)
597
  l = tree.l
@@ -616,17 +622,17 @@ end
616
 
617
  # Simplify tree
618
  function simplifyTree(tree::Node)::Node
619
- if tree.degree === 1
620
  tree.l = simplifyTree(tree.l)
621
- if tree.l.degree === 0 && tree.l.constant
622
  return Node(unaops[tree.op](tree.l.val))
623
  end
624
- elseif tree.degree === 2
625
  tree.l = simplifyTree(tree.l)
626
  tree.r = simplifyTree(tree.r)
627
  constantsBelow = (
628
- tree.l.degree === 0 && tree.l.constant &&
629
- tree.r.degree === 0 && tree.r.constant
630
  )
631
  if constantsBelow
632
  return Node(binops[tree.op](tree.l.val, tree.r.val))
@@ -648,9 +654,9 @@ end
648
 
649
  # Check if any power operator is to the power of a complex expression
650
  function deepPow(tree::Node)::Integer
651
- if tree.degree === 0
652
  return 0
653
- elseif tree.degree === 1
654
  return 0 + deepPow(tree.l)
655
  else
656
  if binops[tree.op] === pow
@@ -857,7 +863,7 @@ function run(
857
  pop = regEvolCycle(pop, 1.0f0, curmaxsize)
858
  end
859
 
860
- if verbosity > 0 && (iT % verbosity === 0)
861
  bestPops = bestSubPop(pop)
862
  bestCurScoreIdx = argmin([bestPops.members[member].score for member=1:bestPops.n])
863
  bestCurScore = bestPops.members[bestCurScoreIdx].score
@@ -870,13 +876,13 @@ end
870
 
871
  # Get all the constants from a tree
872
  function getConstants(tree::Node)::Array{Float32, 1}
873
- if tree.degree === 0
874
  if tree.constant
875
  return [tree.val]
876
  else
877
  return Float32[]
878
  end
879
- elseif tree.degree === 1
880
  return getConstants(tree.l)
881
  else
882
  both = [getConstants(tree.l), getConstants(tree.r)]
@@ -886,11 +892,11 @@ end
886
 
887
  # Set all the constants inside a tree
888
  function setConstants(tree::Node, constants::Array{Float32, 1})
889
- if tree.degree === 0
890
  if tree.constant
891
  tree.val = constants[1]
892
  end
893
- elseif tree.degree === 1
894
  setConstants(tree.l, constants)
895
  else
896
  numberLeft = countConstants(tree.l)
@@ -909,12 +915,12 @@ end
909
  # Use Nelder-Mead to optimize the constants in an equation
910
  function optimizeConstants(member::PopMember)::PopMember
911
  nconst = countConstants(member.tree)
912
- if nconst === 0
913
  return member
914
  end
915
  x0 = getConstants(member.tree)
916
  f(x::Array{Float32,1})::Float32 = optFunc(x, member.tree)
917
- if size(x0)[1] === 1
918
  algorithm = Optim.Newton
919
  else
920
  algorithm = Optim.NelderMead
@@ -998,7 +1004,7 @@ function fullRun(niterations::Integer;
998
  bestSubPops = [Population(1) for j=1:npopulations]
999
  hallOfFame = HallOfFame()
1000
  curmaxsize = 3
1001
- if warmupMaxsize === 0
1002
  curmaxsize = maxsize
1003
  end
1004
 
@@ -1067,7 +1073,7 @@ function fullRun(niterations::Integer;
1067
  numberSmallerAndBetter += 1
1068
  end
1069
  end
1070
- betterThanAllSmaller = (numberSmallerAndBetter === 0)
1071
  if betterThanAllSmaller
1072
  println(io, "$size|$(curMSE)|$(stringTree(member.tree))")
1073
  push!(dominating, member)
@@ -1117,7 +1123,7 @@ function fullRun(niterations::Integer;
1117
 
1118
  cycles_complete -= 1
1119
  cycles_elapsed = npopulations * niterations - cycles_complete
1120
- if warmupMaxsize !== 0 && cycles_elapsed % warmupMaxsize === 0
1121
  curmaxsize += 1
1122
  if curmaxsize > maxsize
1123
  curmaxsize = maxsize
@@ -1167,7 +1173,7 @@ function fullRun(niterations::Integer;
1167
  numberSmallerAndBetter += 1
1168
  end
1169
  end
1170
- betterThanAllSmaller = (numberSmallerAndBetter === 0)
1171
  if betterThanAllSmaller
1172
  delta_c = size - lastComplexity
1173
  delta_l_mse = log(curMSE/lastMSE)
 
96
 
97
  # Copy an equation (faster than deepcopy)
98
  function copyNode(tree::Node)::Node
99
+ if tree.degree == 0
100
  return Node(tree.val)
101
+ elseif tree.degree == 1
102
  return Node(tree.op, copyNode(tree.l))
103
  else
104
  return Node(tree.op, copyNode(tree.l), copyNode(tree.r))
 
107
 
108
  # Count the operators, constants, variables in an equation
109
  function countNodes(tree::Node)::Integer
110
+ if tree.degree == 0
111
  return 1
112
+ elseif tree.degree == 1
113
  return 1 + countNodes(tree.l)
114
  else
115
  return 1 + countNodes(tree.l) + countNodes(tree.r)
 
118
 
119
  # Count the max depth of a tree
120
  function countDepth(tree::Node)::Integer
121
+ if tree.degree == 0
122
  return 1
123
+ elseif tree.degree == 1
124
  return 1 + countDepth(tree.l)
125
  else
126
  return 1 + max(countDepth(tree.l), countDepth(tree.r))
 
129
 
130
  # Convert an equation to a string
131
  function stringTree(tree::Node)::String
132
+ if tree.degree == 0
133
  if tree.constant
134
  return string(tree.val)
135
  else
 
139
  return "x$(tree.val - 1)"
140
  end
141
  end
142
+ elseif tree.degree == 1
143
  return "$(unaops[tree.op])($(stringTree(tree.l)))"
144
  else
145
  return "$(binops[tree.op])($(stringTree(tree.l)), $(stringTree(tree.r)))"
 
153
 
154
  # Return a random node from the tree
155
  function randomNode(tree::Node)::Node
156
+ if tree.degree == 0
157
  return tree
158
  end
159
  a = countNodes(tree)
 
162
  if tree.degree >= 1
163
  b = countNodes(tree.l)
164
  end
165
+ if tree.degree == 2
166
  c = countNodes(tree.r)
167
  end
168
 
169
  i = rand(1:1+b+c)
170
  if i <= b
171
  return randomNode(tree.l)
172
+ elseif i == b + 1
173
  return tree
174
  end
175
 
 
178
 
179
  # Count the number of unary operators in the equation
180
  function countUnaryOperators(tree::Node)::Integer
181
+ if tree.degree == 0
182
  return 0
183
+ elseif tree.degree == 1
184
  return 1 + countUnaryOperators(tree.l)
185
  else
186
  return 0 + countUnaryOperators(tree.l) + countUnaryOperators(tree.r)
 
189
 
190
  # Count the number of binary operators in the equation
191
  function countBinaryOperators(tree::Node)::Integer
192
+ if tree.degree == 0
193
  return 0
194
+ elseif tree.degree == 1
195
  return 0 + countBinaryOperators(tree.l)
196
  else
197
  return 1 + countBinaryOperators(tree.l) + countBinaryOperators(tree.r)
 
206
  # Randomly convert an operator into another one (binary->binary;
207
  # unary->unary)
208
  function mutateOperator(tree::Node)::Node
209
+ if countOperators(tree) == 0
210
  return tree
211
  end
212
  node = randomNode(tree)
213
+ while node.degree == 0
214
  node = randomNode(tree)
215
  end
216
+ if node.degree == 1
217
  node.op = rand(1:length(unaops))
218
  else
219
  node.op = rand(1:length(binops))
 
223
 
224
  # Count the number of constants in an equation
225
  function countConstants(tree::Node)::Integer
226
+ if tree.degree == 0
227
  return convert(Integer, tree.constant)
228
+ elseif tree.degree == 1
229
  return 0 + countConstants(tree.l)
230
  else
231
  return 0 + countConstants(tree.l) + countConstants(tree.r)
 
238
  probNegate::Float32=0.01f0)::Node
239
  # T is between 0 and 1.
240
 
241
+ if countConstants(tree) == 0
242
  return tree
243
  end
244
  node = randomNode(tree)
245
+ while node.degree != 0 || node.constant == false
246
  node = randomNode(tree)
247
  end
248
 
 
273
  # Evaluate an equation over an array of datapoints
274
  function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32, 1}, Nothing}
275
  clen = size(cX)[1]
276
+ if tree.degree == 0
277
  if tree.constant
278
  return fill(tree.val, clen)
279
  else
280
  return copy(cX[:, tree.val])
281
  end
282
+ elseif tree.degree == 1
283
  cumulator = evalTreeArray(tree.l, cX)
284
  if cumulator === nothing
285
  return nothing
286
  end
287
  op_idx = tree.op
288
+ @inbounds @simd for i=1:clen
289
+ cumulator[i] = UNAOP(op_idx, cumulator[i])
290
+ end
291
  @inbounds for i=1:clen
292
  if isinf(cumulator[i]) || isnan(cumulator[i])
293
  return nothing
 
303
  if array2 === nothing
304
  return nothing
305
  end
306
+
307
  op_idx = tree.op
308
+
309
+ @inbounds @simd for i=1:clen
310
+ cumulator[i] = BINOP(op_idx, cumulator[i], array2[i])
311
+ end
312
  @inbounds for i=1:clen
313
  if isinf(cumulator[i]) || isnan(cumulator[i])
314
  return nothing
 
356
  # Add a random unary/binary operation to the end of a tree
357
  function appendRandomOp(tree::Node)::Node
358
  node = randomNode(tree)
359
+ while node.degree != 0
360
  node = randomNode(tree)
361
  end
362
 
 
464
 
465
  # Return a random node from the tree with parent
466
  function randomNodeAndParent(tree::Node, parent::Union{Node, Nothing})::Tuple{Node, Union{Node, Nothing}}
467
+ if tree.degree == 0
468
  return tree, parent
469
  end
470
  a = countNodes(tree)
 
473
  if tree.degree >= 1
474
  b = countNodes(tree.l)
475
  end
476
+ if tree.degree == 2
477
  c = countNodes(tree.r)
478
  end
479
 
480
  i = rand(1:1+b+c)
481
  if i <= b
482
  return randomNodeAndParent(tree.l, tree)
483
+ elseif i == b + 1
484
  return tree, parent
485
  end
486
 
 
493
  node, parent = randomNodeAndParent(tree, nothing)
494
  isroot = (parent === nothing)
495
 
496
+ if node.degree == 0
497
  # Replace with new constant
498
  newnode = randomConstantNode()
499
  node.l = newnode.l
 
502
  node.degree = newnode.degree
503
  node.val = newnode.val
504
  node.constant = newnode.constant
505
+ elseif node.degree == 1
506
  # Join one of the children with the parent
507
  if isroot
508
  return node.l
 
542
  # ((const - var) - const) => (const - var)
543
  # (want to add anything commutative!)
544
  # TODO - need to combine plus/sub if they are both there.
545
+ if tree.degree == 0
546
  return tree
547
+ elseif tree.degree == 1
548
  tree.l = combineOperators(tree.l)
549
+ elseif tree.degree == 2
550
  tree.l = combineOperators(tree.l)
551
  tree.r = combineOperators(tree.r)
552
  end
553
 
554
+ top_level_constant = tree.degree == 2 && (tree.l.constant || tree.r.constant)
555
+ if tree.degree == 2 && (binops[tree.op] === mult || binops[tree.op] === plus) && top_level_constant
556
  op = tree.op
557
  # Put the constant in r
558
  if tree.l.constant
 
563
  topconstant = tree.r.val
564
  # Simplify down first
565
  below = tree.l
566
+ if below.degree == 2 && below.op == op
567
  if below.l.constant
568
  tree = below
569
  tree.l.val = binops[op](tree.l.val, topconstant)
 
574
  end
575
  end
576
 
577
+ if tree.degree == 2 && binops[tree.op] === sub && top_level_constant
578
  # Currently just simplifies subtraction. (can't assume both plus and sub are operators)
579
  # Not commutative, so use different op.
580
  if tree.l.constant
581
+ if tree.r.degree == 2 && binops[tree.r.op] === sub
582
  if tree.r.l.constant
583
  #(const - (const - var)) => (var - const)
584
  l = tree.l
 
597
  end
598
  end
599
  else #tree.r.constant is true
600
+ if tree.l.degree == 2 && binops[tree.l.op] === sub
601
  if tree.l.l.constant
602
  #((const - var) - const) => (const - var)
603
  l = tree.l
 
622
 
623
  # Simplify tree
624
  function simplifyTree(tree::Node)::Node
625
+ if tree.degree == 1
626
  tree.l = simplifyTree(tree.l)
627
+ if tree.l.degree == 0 && tree.l.constant
628
  return Node(unaops[tree.op](tree.l.val))
629
  end
630
+ elseif tree.degree == 2
631
  tree.l = simplifyTree(tree.l)
632
  tree.r = simplifyTree(tree.r)
633
  constantsBelow = (
634
+ tree.l.degree == 0 && tree.l.constant &&
635
+ tree.r.degree == 0 && tree.r.constant
636
  )
637
  if constantsBelow
638
  return Node(binops[tree.op](tree.l.val, tree.r.val))
 
654
 
655
  # Check if any power operator is to the power of a complex expression
656
  function deepPow(tree::Node)::Integer
657
+ if tree.degree == 0
658
  return 0
659
+ elseif tree.degree == 1
660
  return 0 + deepPow(tree.l)
661
  else
662
  if binops[tree.op] === pow
 
863
  pop = regEvolCycle(pop, 1.0f0, curmaxsize)
864
  end
865
 
866
+ if verbosity > 0 && (iT % verbosity == 0)
867
  bestPops = bestSubPop(pop)
868
  bestCurScoreIdx = argmin([bestPops.members[member].score for member=1:bestPops.n])
869
  bestCurScore = bestPops.members[bestCurScoreIdx].score
 
876
 
877
  # Get all the constants from a tree
878
  function getConstants(tree::Node)::Array{Float32, 1}
879
+ if tree.degree == 0
880
  if tree.constant
881
  return [tree.val]
882
  else
883
  return Float32[]
884
  end
885
+ elseif tree.degree == 1
886
  return getConstants(tree.l)
887
  else
888
  both = [getConstants(tree.l), getConstants(tree.r)]
 
892
 
893
  # Set all the constants inside a tree
894
  function setConstants(tree::Node, constants::Array{Float32, 1})
895
+ if tree.degree == 0
896
  if tree.constant
897
  tree.val = constants[1]
898
  end
899
+ elseif tree.degree == 1
900
  setConstants(tree.l, constants)
901
  else
902
  numberLeft = countConstants(tree.l)
 
915
  # Use Nelder-Mead to optimize the constants in an equation
916
  function optimizeConstants(member::PopMember)::PopMember
917
  nconst = countConstants(member.tree)
918
+ if nconst == 0
919
  return member
920
  end
921
  x0 = getConstants(member.tree)
922
  f(x::Array{Float32,1})::Float32 = optFunc(x, member.tree)
923
+ if size(x0)[1] == 1
924
  algorithm = Optim.Newton
925
  else
926
  algorithm = Optim.NelderMead
 
1004
  bestSubPops = [Population(1) for j=1:npopulations]
1005
  hallOfFame = HallOfFame()
1006
  curmaxsize = 3
1007
+ if warmupMaxsize == 0
1008
  curmaxsize = maxsize
1009
  end
1010
 
 
1073
  numberSmallerAndBetter += 1
1074
  end
1075
  end
1076
+ betterThanAllSmaller = (numberSmallerAndBetter == 0)
1077
  if betterThanAllSmaller
1078
  println(io, "$size|$(curMSE)|$(stringTree(member.tree))")
1079
  push!(dominating, member)
 
1123
 
1124
  cycles_complete -= 1
1125
  cycles_elapsed = npopulations * niterations - cycles_complete
1126
+ if warmupMaxsize != 0 && cycles_elapsed % warmupMaxsize == 0
1127
  curmaxsize += 1
1128
  if curmaxsize > maxsize
1129
  curmaxsize = maxsize
 
1173
  numberSmallerAndBetter += 1
1174
  end
1175
  end
1176
+ betterThanAllSmaller = (numberSmallerAndBetter == 0)
1177
  if betterThanAllSmaller
1178
  delta_c = size - lastComplexity
1179
  delta_l_mse = log(curMSE/lastMSE)
pysr/sr.py CHANGED
@@ -286,35 +286,27 @@ const limitPowComplexity = {"true" if limitPowComplexity else "false"}
286
 
287
  op_runner = ""
288
  if len(binary_operators) > 0:
289
- op_runner += """
290
- @inline function BINOP!(x::Array{Float32, 1}, y::Array{Float32, 1}, i::Int, clen::Int)
291
- if i === 1
292
- @inbounds @simd for j=1:clen
293
- x[j] = """f"{binary_operators[0]}""""(x[j], y[j])
294
- end"""
295
  for i in range(1, len(binary_operators)):
296
  op_runner += f"""
297
- elseif i === {i+1}
298
- @inbounds @simd for j=1:clen
299
- x[j] = {binary_operators[i]}(x[j], y[j])
300
- end"""
301
  op_runner += """
302
  end
303
  end"""
304
 
305
  if len(unary_operators) > 0:
306
- op_runner += """
307
- @inline function UNAOP!(x::Array{Float32, 1}, i::Int, clen::Int)
308
- if i === 1
309
- @inbounds @simd for j=1:clen
310
- x[j] = """f"{unary_operators[0]}(x[j])""""
311
- end"""
312
  for i in range(1, len(unary_operators)):
313
  op_runner += f"""
314
- elseif i === {i+1}
315
- @inbounds @simd for j=1:clen
316
- x[j] = {unary_operators[i]}(x[j])
317
- end"""
318
  op_runner += """
319
  end
320
  end"""
 
286
 
287
  op_runner = ""
288
  if len(binary_operators) > 0:
289
+ op_runner += f"""
290
+ @inline function BINOP(i::Int, x::Float32, y::Float32)::Float32
291
+ if i == 1
292
+ return @fastmath {binary_operators[0]}(x, y)"""
 
 
293
  for i in range(1, len(binary_operators)):
294
  op_runner += f"""
295
+ elseif i == {i+1}
296
+ return @fastmath {binary_operators[i]}(x, y)"""
 
 
297
  op_runner += """
298
  end
299
  end"""
300
 
301
  if len(unary_operators) > 0:
302
+ op_runner += f"""
303
+ @inline function UNAOP(i::Int, x::Float32)::Float32
304
+ if i == 1
305
+ return @fastmath {unary_operators[0]}(x)"""
 
 
306
  for i in range(1, len(unary_operators)):
307
  op_runner += f"""
308
+ elseif i == {i+1}
309
+ return @fastmath {unary_operators[i]}(x)"""
 
 
310
  op_runner += """
311
  end
312
  end"""