MilesCranmer committed · Commit 2ca2654 · Parent(s): 3a557a9

Add parameter for batching

Files changed:
- julia/sr.jl (+26 -4)
- pysr/sr.py (+8 -1)

julia/sr.jl
CHANGED
@@ -616,8 +616,11 @@ function iterate(member::PopMember, T::Float32)::PopMember
     prev = member.tree
     tree = copyNode(prev)
     #TODO - reconsider this
-
-
+    if batching
+        beforeLoss = scoreFuncBatch(member.tree)
+    else
+        beforeLoss = member.score
+    end

     mutationChoice = rand()
     weightAdjustmentMutateConstant = min(8, countConstants(tree))/8.0
@@ -648,7 +651,11 @@ function iterate(member::PopMember, T::Float32)::PopMember
         return PopMember(tree, beforeLoss)
     end

-
+    if batching
+        afterLoss = scoreFuncBatch(tree)
+    else
+        afterLoss = scoreFunc(tree)
+    end

     if annealing
         delta = afterLoss - beforeLoss
@@ -697,6 +704,16 @@ function bestOfSample(pop::Population)::PopMember
     return sample.members[best_idx]
 end

+function finalizeScores(pop::Population)::Population
+    need_recalculate = batching
+    if need_recalculate
+        @inbounds @simd for member=1:pop.n
+            pop.members[member].score = scoreFunc(pop.members[member].tree)
+        end
+    end
+    return pop
+end
+
 # Return best 10 examples
 function bestSubPop(pop::Population; topn::Integer=10)::Population
     best_idx = sortperm([pop.members[member].score for member=1:pop.n])
@@ -1000,7 +1017,7 @@ function fullRun(niterations::Integer;
        @async begin
            allPops[i] = @spawnat :any let
                tmp_pop = run(cur_pop, ncyclesperiteration, verbosity=verbosity)
-               for j=1:tmp_pop.n
+               @inbounds @simd for j=1:tmp_pop.n
                    if rand() < 0.1
                        tmp_pop.members[j].tree = simplifyTree(tmp_pop.members[j].tree)
                        tmp_pop.members[j].tree = combineOperators(tmp_pop.members[j].tree)
@@ -1009,6 +1026,11 @@ function fullRun(niterations::Integer;
                        end
                    end
                end
+               if shouldOptimizeConstants
+                   #pass #(We already calculate full scores in the optimizer)
+               else
+                   tmp_pop = finalizeScores(tmp_pop)
+               end
                tmp_pop
            end
            put!(channels[i], fetch(allPops[i]))
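Note: scoreFuncBatch, called in the new branches above, is defined elsewhere in sr.jl and is not shown in this diff. The snippet below is only an illustrative sketch of the underlying idea (score a candidate on a random mini-batch instead of the full dataset); the predict function and the data are made up for the example and are not PySR internals.

using Random: randperm
using Statistics: mean

X = randn(Float32, 1000, 5)        # stand-in dataset
y = X[:, 1] .^ 2 .+ cos.(X[:, 2])
batchSize = 50

predict(data) = data[:, 1] .^ 2    # stand-in for evaluating a candidate expression tree

# Noisy but cheap loss on a random mini-batch (the role scoreFuncBatch plays above).
function batchLoss(predict, X, y, batchSize)
    idx = randperm(size(X, 1))[1:batchSize]
    return mean(abs2, predict(X[idx, :]) .- y[idx])
end

# Exact loss on the full dataset (the role scoreFunc plays above).
fullLoss(predict, X, y) = mean(abs2, predict(X) .- y)

Because batch losses are noisy, the commit also adds finalizeScores, which recomputes each member's score with the full-data scoreFunc when a population returns from a cycle, unless shouldOptimizeConstants is set, in which case full scores are already computed in the optimizer (per the #pass comment above).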
pysr/sr.py
CHANGED
@@ -76,6 +76,8 @@ def pysr(X=None, y=None, weights=None,
         fast_cycle=False,
         maxdepth=None,
         variable_names=[],
+        batching=False,
+        batchSize=50,
         threads=None, #deprecated
         julia_optimization=3,
         ):
@@ -138,6 +140,10 @@ def pysr(X=None, y=None, weights=None,
        15% faster. May be algorithmically less efficient.
    :param variable_names: list, a list of names for the variables, other
        than "x0", "x1", etc.
+   :param batching: bool, whether to compare population members on small batches
+       during evolution. Still uses full dataset for comparing against
+       hall of fame.
+   :param batchSize: int, the amount of data to use if doing batching.
    :param julia_optimization: int, Optimization level (0, 1, 2, 3)
    :returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
        (as strings).
@@ -227,7 +233,8 @@ const nrestarts = {nrestarts:d}
const perturbationFactor = {perturbationFactor:f}f0
const annealing = {"true" if annealing else "false"}
const weighted = {"true" if weights is not None else "false"}
-const
+const batching = {"true" if batching else "false"}
+const batchSize = {min([batchSize, len(X)]) if batching else len(X):d}
const useVarMap = {"false" if len(variable_names) == 0 else "true"}
const mutationWeights = [
    {weightMutateConstant:f},
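For reference, a minimal usage sketch of the new keyword arguments from the Python side. It assumes NumPy arrays for X and y, leaves every other pysr argument at its default, and the toy dataset is made up for the example.

import numpy as np
from pysr import pysr

# Toy data: 500 rows, 5 features (made up for this example).
X = np.random.randn(500, 5)
y = X[:, 0] ** 2 + np.cos(X[:, 1])

# Compare population members on random batches of 50 rows during evolution;
# the hall of fame is still compared against the full dataset.
equations = pysr(X, y, batching=True, batchSize=50)

With batching=True, the generated Julia constants above become const batching = true and const batchSize = 50 (capped at len(X) by the min() in the template).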