Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
0ddc60f
1
Parent(s):
bd4f864
Specify types for all parameters
Browse files- pysr/sr.py +89 -85
- setup.py +1 -1
pysr/sr.py
CHANGED
@@ -11,7 +11,7 @@ from datetime import datetime
|
|
11 |
from io import StringIO
|
12 |
from multiprocessing import cpu_count
|
13 |
from pathlib import Path
|
14 |
-
from typing import List, Optional
|
15 |
|
16 |
import numpy as np
|
17 |
import pandas as pd
|
@@ -659,90 +659,92 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
659 |
|
660 |
def __init__(
|
661 |
self,
|
662 |
-
model_selection="best",
|
663 |
*,
|
664 |
-
binary_operators=None,
|
665 |
-
unary_operators=None,
|
666 |
-
niterations=40,
|
667 |
-
populations=15,
|
668 |
-
population_size=33,
|
669 |
-
max_evals=None,
|
670 |
-
maxsize=20,
|
671 |
-
maxdepth=None,
|
672 |
-
warmup_maxsize_by=
|
673 |
-
timeout_in_seconds=None,
|
674 |
-
constraints=None,
|
675 |
-
nested_constraints=None,
|
676 |
-
loss=None,
|
677 |
-
full_objective=None,
|
678 |
-
complexity_of_operators=None,
|
679 |
-
complexity_of_constants=1,
|
680 |
-
complexity_of_variables=1,
|
681 |
-
parsimony=0.0032,
|
682 |
-
dimensional_constraint_penalty=None,
|
683 |
-
use_frequency=True,
|
684 |
-
use_frequency_in_tournament=True,
|
685 |
-
adaptive_parsimony_scaling=20.0,
|
686 |
-
alpha=0.1,
|
687 |
-
annealing=False,
|
688 |
-
early_stop_condition=None,
|
689 |
-
ncyclesperiteration=550,
|
690 |
-
fraction_replaced=0.000364,
|
691 |
-
fraction_replaced_hof=0.035,
|
692 |
-
weight_add_node=0.79,
|
693 |
-
weight_insert_node=5.1,
|
694 |
-
weight_delete_node=1.7,
|
695 |
-
weight_do_nothing=0.21,
|
696 |
-
weight_mutate_constant=0.048,
|
697 |
-
weight_mutate_operator=0.47,
|
698 |
-
weight_randomize=0.00023,
|
699 |
-
weight_simplify=0.0020,
|
700 |
-
weight_optimize=0.0,
|
701 |
-
crossover_probability=0.066,
|
702 |
-
skip_mutation_failures=True,
|
703 |
-
migration=True,
|
704 |
-
hof_migration=True,
|
705 |
-
topn=12,
|
706 |
-
should_simplify=None,
|
707 |
-
should_optimize_constants=True,
|
708 |
-
optimizer_algorithm="BFGS",
|
709 |
-
optimizer_nrestarts=2,
|
710 |
-
optimize_probability=0.14,
|
711 |
-
optimizer_iterations=8,
|
712 |
-
perturbation_factor=0.076,
|
713 |
-
tournament_selection_n=10,
|
714 |
-
tournament_selection_p=0.86,
|
715 |
-
procs=cpu_count(),
|
716 |
-
multithreading=None,
|
717 |
-
cluster_manager
|
718 |
-
|
719 |
-
|
720 |
-
|
721 |
-
|
722 |
-
|
723 |
-
|
724 |
-
|
725 |
-
|
726 |
-
|
727 |
-
|
728 |
-
|
729 |
-
|
730 |
-
|
731 |
-
|
732 |
-
|
733 |
-
|
734 |
-
|
735 |
-
|
736 |
-
|
737 |
-
|
738 |
-
|
739 |
-
|
740 |
-
|
741 |
-
|
742 |
-
|
743 |
-
|
744 |
-
|
745 |
-
|
|
|
|
|
746 |
**kwargs,
|
747 |
):
|
748 |
# Hyperparameters
|
@@ -1645,7 +1647,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1645 |
fraction_replaced_hof=self.fraction_replaced_hof,
|
1646 |
should_simplify=self.should_simplify,
|
1647 |
should_optimize_constants=self.should_optimize_constants,
|
1648 |
-
warmup_maxsize_by=
|
|
|
|
|
1649 |
use_frequency=self.use_frequency,
|
1650 |
use_frequency_in_tournament=self.use_frequency_in_tournament,
|
1651 |
adaptive_parsimony_scaling=self.adaptive_parsimony_scaling,
|
|
|
11 |
from io import StringIO
|
12 |
from multiprocessing import cpu_count
|
13 |
from pathlib import Path
|
14 |
+
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
|
15 |
|
16 |
import numpy as np
|
17 |
import pandas as pd
|
|
|
659 |
|
660 |
def __init__(
|
661 |
self,
|
662 |
+
model_selection: Literal["best", "accuracy", "score"] = "best",
|
663 |
*,
|
664 |
+
binary_operators: Optional[List[str]] = None,
|
665 |
+
unary_operators: Optional[List[str]] = None,
|
666 |
+
niterations: int = 40,
|
667 |
+
populations: int = 15,
|
668 |
+
population_size: int = 33,
|
669 |
+
max_evals: Optional[int] = None,
|
670 |
+
maxsize: int = 20,
|
671 |
+
maxdepth: Optional[int] = None,
|
672 |
+
warmup_maxsize_by: Optional[float] = None,
|
673 |
+
timeout_in_seconds: Optional[float] = None,
|
674 |
+
constraints: Optional[Dict[str, Union[int, Tuple[int, int]]]] = None,
|
675 |
+
nested_constraints: Optional[Dict[str, Dict[str, int]]] = None,
|
676 |
+
loss: Optional[str] = None,
|
677 |
+
full_objective: Optional[str] = None,
|
678 |
+
complexity_of_operators: Optional[Dict[str, Union[int, float]]] = None,
|
679 |
+
complexity_of_constants: Union[int, float] = 1,
|
680 |
+
complexity_of_variables: Union[int, float] = 1,
|
681 |
+
parsimony: float = 0.0032,
|
682 |
+
dimensional_constraint_penalty: Optional[float] = None,
|
683 |
+
use_frequency: bool = True,
|
684 |
+
use_frequency_in_tournament: bool = True,
|
685 |
+
adaptive_parsimony_scaling: float = 20.0,
|
686 |
+
alpha: float = 0.1,
|
687 |
+
annealing: bool = False,
|
688 |
+
early_stop_condition: Optional[Union[float, str]] = None,
|
689 |
+
ncyclesperiteration: int = 550,
|
690 |
+
fraction_replaced: float = 0.000364,
|
691 |
+
fraction_replaced_hof: float = 0.035,
|
692 |
+
weight_add_node: float = 0.79,
|
693 |
+
weight_insert_node: float = 5.1,
|
694 |
+
weight_delete_node: float = 1.7,
|
695 |
+
weight_do_nothing: float = 0.21,
|
696 |
+
weight_mutate_constant: float = 0.048,
|
697 |
+
weight_mutate_operator: float = 0.47,
|
698 |
+
weight_randomize: float = 0.00023,
|
699 |
+
weight_simplify: float = 0.0020,
|
700 |
+
weight_optimize: float = 0.0,
|
701 |
+
crossover_probability: float = 0.066,
|
702 |
+
skip_mutation_failures: bool = True,
|
703 |
+
migration: bool = True,
|
704 |
+
hof_migration: bool = True,
|
705 |
+
topn: int = 12,
|
706 |
+
should_simplify: Optional[bool] = None,
|
707 |
+
should_optimize_constants: bool = True,
|
708 |
+
optimizer_algorithm: str = "BFGS",
|
709 |
+
optimizer_nrestarts: int = 2,
|
710 |
+
optimize_probability: float = 0.14,
|
711 |
+
optimizer_iterations: int = 8,
|
712 |
+
perturbation_factor: float = 0.076,
|
713 |
+
tournament_selection_n: int = 10,
|
714 |
+
tournament_selection_p: float = 0.86,
|
715 |
+
procs: int = cpu_count(),
|
716 |
+
multithreading: Optional[bool] = None,
|
717 |
+
cluster_manager: Optional[
|
718 |
+
Literal["slurm", "pbs", "lsf", "sge", "qrsh", "scyld", "htc"]
|
719 |
+
] = None,
|
720 |
+
heap_size_hint_in_bytes: Optional[int] = None,
|
721 |
+
batching: bool = False,
|
722 |
+
batch_size: int = 50,
|
723 |
+
fast_cycle: bool = False,
|
724 |
+
turbo: bool = False,
|
725 |
+
precision: int = 32,
|
726 |
+
enable_autodiff: bool = False,
|
727 |
+
random_state: Optional[Union[int, np.random.RandomState]] = None,
|
728 |
+
deterministic: bool = False,
|
729 |
+
warm_start: bool = False,
|
730 |
+
verbosity: int = 1,
|
731 |
+
update_verbosity: Optional[int] = None,
|
732 |
+
print_precision: int = 5,
|
733 |
+
progress: bool = True,
|
734 |
+
equation_file: Optional[str] = None,
|
735 |
+
temp_equation_file: bool = False,
|
736 |
+
tempdir: Optional[str] = None,
|
737 |
+
delete_tempfiles: bool = True,
|
738 |
+
julia_project: Optional[str] = None,
|
739 |
+
update: bool = False,
|
740 |
+
output_jax_format: bool = False,
|
741 |
+
output_torch_format: bool = False,
|
742 |
+
extra_sympy_mappings: Optional[Dict[str, Callable]] = None,
|
743 |
+
extra_torch_mappings: Optional[Dict[Callable, Callable]] = None,
|
744 |
+
extra_jax_mappings: Optional[Dict[Callable, str]] = None,
|
745 |
+
denoise: bool = False,
|
746 |
+
select_k_features: Optional[int] = None,
|
747 |
+
julia_kwargs: Optional[Dict] = None,
|
748 |
**kwargs,
|
749 |
):
|
750 |
# Hyperparameters
|
|
|
1647 |
fraction_replaced_hof=self.fraction_replaced_hof,
|
1648 |
should_simplify=self.should_simplify,
|
1649 |
should_optimize_constants=self.should_optimize_constants,
|
1650 |
+
warmup_maxsize_by=0.0
|
1651 |
+
if self.warmup_maxsize_by is None
|
1652 |
+
else self.warmup_maxsize_by,
|
1653 |
use_frequency=self.use_frequency,
|
1654 |
use_frequency_in_tournament=self.use_frequency_in_tournament,
|
1655 |
adaptive_parsimony_scaling=self.adaptive_parsimony_scaling,
|
setup.py
CHANGED
@@ -26,5 +26,5 @@ setuptools.setup(
|
|
26 |
"Programming Language :: Python :: 3",
|
27 |
"Operating System :: OS Independent",
|
28 |
],
|
29 |
-
python_requires=">=3.
|
30 |
)
|
|
|
26 |
"Programming Language :: Python :: 3",
|
27 |
"Operating System :: OS Independent",
|
28 |
],
|
29 |
+
python_requires=">=3.8",
|
30 |
)
|