MilesCranmer committed on
Commit
2ca2654
1 Parent(s): 3a557a9

Add parameter for batching

Browse files
Files changed (2) hide show
  1. julia/sr.jl +26 -4
  2. pysr/sr.py +8 -1
julia/sr.jl CHANGED
@@ -616,8 +616,11 @@ function iterate(member::PopMember, T::Float32)::PopMember
616
  prev = member.tree
617
  tree = copyNode(prev)
618
  #TODO - reconsider this
619
- # beforeLoss = member.score
620
- beforeLoss = scoreFuncBatch(member.tree)
 
 
 
621
 
622
  mutationChoice = rand()
623
  weightAdjustmentMutateConstant = min(8, countConstants(tree))/8.0
@@ -648,7 +651,11 @@ function iterate(member::PopMember, T::Float32)::PopMember
648
  return PopMember(tree, beforeLoss)
649
  end
650
 
651
- afterLoss = scoreFuncBatch(tree)
 
 
 
 
652
 
653
  if annealing
654
  delta = afterLoss - beforeLoss
@@ -697,6 +704,16 @@ function bestOfSample(pop::Population)::PopMember
697
  return sample.members[best_idx]
698
  end
699
 
 
 
 
 
 
 
 
 
 
 
700
  # Return best 10 examples
701
  function bestSubPop(pop::Population; topn::Integer=10)::Population
702
  best_idx = sortperm([pop.members[member].score for member=1:pop.n])
@@ -1000,7 +1017,7 @@ function fullRun(niterations::Integer;
1000
  @async begin
1001
  allPops[i] = @spawnat :any let
1002
  tmp_pop = run(cur_pop, ncyclesperiteration, verbosity=verbosity)
1003
- for j=1:tmp_pop.n
1004
  if rand() < 0.1
1005
  tmp_pop.members[j].tree = simplifyTree(tmp_pop.members[j].tree)
1006
  tmp_pop.members[j].tree = combineOperators(tmp_pop.members[j].tree)
@@ -1009,6 +1026,11 @@ function fullRun(niterations::Integer;
1009
  end
1010
  end
1011
  end
 
 
 
 
 
1012
  tmp_pop
1013
  end
1014
  put!(channels[i], fetch(allPops[i]))
 
616
  prev = member.tree
617
  tree = copyNode(prev)
618
  #TODO - reconsider this
619
+ if batching
620
+ beforeLoss = scoreFuncBatch(member.tree)
621
+ else
622
+ beforeLoss = member.score
623
+ end
624
 
625
  mutationChoice = rand()
626
  weightAdjustmentMutateConstant = min(8, countConstants(tree))/8.0
 
651
  return PopMember(tree, beforeLoss)
652
  end
653
 
654
+ if batching
655
+ afterLoss = scoreFuncBatch(tree)
656
+ else
657
+ afterLoss = scoreFunc(tree)
658
+ end
659
 
660
  if annealing
661
  delta = afterLoss - beforeLoss
 
704
  return sample.members[best_idx]
705
  end
706
 
707
+ function finalizeScores(pop::Population)::Population
708
+ need_recalculate = batching
709
+ if need_recalculate
710
+ @inbounds @simd for member=1:pop.n
711
+ pop.members[member].score = scoreFunc(pop.members[member].tree)
712
+ end
713
+ end
714
+ return pop
715
+ end
716
+
717
  # Return best 10 examples
718
  function bestSubPop(pop::Population; topn::Integer=10)::Population
719
  best_idx = sortperm([pop.members[member].score for member=1:pop.n])
 
1017
  @async begin
1018
  allPops[i] = @spawnat :any let
1019
  tmp_pop = run(cur_pop, ncyclesperiteration, verbosity=verbosity)
1020
+ @inbounds @simd for j=1:tmp_pop.n
1021
  if rand() < 0.1
1022
  tmp_pop.members[j].tree = simplifyTree(tmp_pop.members[j].tree)
1023
  tmp_pop.members[j].tree = combineOperators(tmp_pop.members[j].tree)
 
1026
  end
1027
  end
1028
  end
1029
+ if shouldOptimizeConstants
1030
+ #pass #(We already calculate full scores in the optimizer)
1031
+ else
1032
+ tmp_pop = finalizeScores(tmp_pop)
1033
+ end
1034
  tmp_pop
1035
  end
1036
  put!(channels[i], fetch(allPops[i]))
pysr/sr.py CHANGED
@@ -76,6 +76,8 @@ def pysr(X=None, y=None, weights=None,
76
  fast_cycle=False,
77
  maxdepth=None,
78
  variable_names=[],
 
 
79
  threads=None, #deprecated
80
  julia_optimization=3,
81
  ):
@@ -138,6 +140,10 @@ def pysr(X=None, y=None, weights=None,
138
  15% faster. May be algorithmically less efficient.
139
  :param variable_names: list, a list of names for the variables, other
140
  than "x0", "x1", etc.
 
 
 
 
141
  :param julia_optimization: int, Optimization level (0, 1, 2, 3)
142
  :returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
143
  (as strings).
@@ -227,7 +233,8 @@ const nrestarts = {nrestarts:d}
227
  const perturbationFactor = {perturbationFactor:f}f0
228
  const annealing = {"true" if annealing else "false"}
229
  const weighted = {"true" if weights is not None else "false"}
230
- const batchSize = {min([50, len(X)]):d}
 
231
  const useVarMap = {"false" if len(variable_names) == 0 else "true"}
232
  const mutationWeights = [
233
  {weightMutateConstant:f},
 
76
  fast_cycle=False,
77
  maxdepth=None,
78
  variable_names=[],
79
+ batching=False,
80
+ batchSize=50,
81
  threads=None, #deprecated
82
  julia_optimization=3,
83
  ):
 
140
  15% faster. May be algorithmically less efficient.
141
  :param variable_names: list, a list of names for the variables, other
142
  than "x0", "x1", etc.
143
+ :param batching: bool, whether to compare population members on small batches
144
+ during evolution. Still uses full dataset for comparing against
145
+ hall of fame.
146
+ :param batchSize: int, the amount of data to use if doing batching.
147
  :param julia_optimization: int, Optimization level (0, 1, 2, 3)
148
  :returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
149
  (as strings).
 
233
  const perturbationFactor = {perturbationFactor:f}f0
234
  const annealing = {"true" if annealing else "false"}
235
  const weighted = {"true" if weights is not None else "false"}
236
+ const batching = {"true" if batching else "false"}
237
+ const batchSize = {min([batchSize, len(X)]) if batching else len(X):d}
238
  const useVarMap = {"false" if len(variable_names) == 0 else "true"}
239
  const mutationWeights = [
240
  {weightMutateConstant:f},