MilesCranmer commited on
Commit
0ddc60f
1 Parent(s): bd4f864

Specify types for all parameters

Browse files
Files changed (2) hide show
  1. pysr/sr.py +89 -85
  2. setup.py +1 -1
pysr/sr.py CHANGED
@@ -11,7 +11,7 @@ from datetime import datetime
11
  from io import StringIO
12
  from multiprocessing import cpu_count
13
  from pathlib import Path
14
- from typing import List, Optional
15
 
16
  import numpy as np
17
  import pandas as pd
@@ -659,90 +659,92 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
659
 
660
  def __init__(
661
  self,
662
- model_selection="best",
663
  *,
664
- binary_operators=None,
665
- unary_operators=None,
666
- niterations=40,
667
- populations=15,
668
- population_size=33,
669
- max_evals=None,
670
- maxsize=20,
671
- maxdepth=None,
672
- warmup_maxsize_by=0.0,
673
- timeout_in_seconds=None,
674
- constraints=None,
675
- nested_constraints=None,
676
- loss=None,
677
- full_objective=None,
678
- complexity_of_operators=None,
679
- complexity_of_constants=1,
680
- complexity_of_variables=1,
681
- parsimony=0.0032,
682
- dimensional_constraint_penalty=None,
683
- use_frequency=True,
684
- use_frequency_in_tournament=True,
685
- adaptive_parsimony_scaling=20.0,
686
- alpha=0.1,
687
- annealing=False,
688
- early_stop_condition=None,
689
- ncyclesperiteration=550,
690
- fraction_replaced=0.000364,
691
- fraction_replaced_hof=0.035,
692
- weight_add_node=0.79,
693
- weight_insert_node=5.1,
694
- weight_delete_node=1.7,
695
- weight_do_nothing=0.21,
696
- weight_mutate_constant=0.048,
697
- weight_mutate_operator=0.47,
698
- weight_randomize=0.00023,
699
- weight_simplify=0.0020,
700
- weight_optimize=0.0,
701
- crossover_probability=0.066,
702
- skip_mutation_failures=True,
703
- migration=True,
704
- hof_migration=True,
705
- topn=12,
706
- should_simplify=None,
707
- should_optimize_constants=True,
708
- optimizer_algorithm="BFGS",
709
- optimizer_nrestarts=2,
710
- optimize_probability=0.14,
711
- optimizer_iterations=8,
712
- perturbation_factor=0.076,
713
- tournament_selection_n=10,
714
- tournament_selection_p=0.86,
715
- procs=cpu_count(),
716
- multithreading=None,
717
- cluster_manager=None,
718
- heap_size_hint_in_bytes=None,
719
- batching=False,
720
- batch_size=50,
721
- fast_cycle=False,
722
- turbo=False,
723
- precision=32,
724
- enable_autodiff=False,
725
- random_state=None,
726
- deterministic=False,
727
- warm_start=False,
728
- verbosity=1,
729
- update_verbosity=None,
730
- print_precision=5,
731
- progress=True,
732
- equation_file=None,
733
- temp_equation_file=False,
734
- tempdir=None,
735
- delete_tempfiles=True,
736
- julia_project=None,
737
- update=False,
738
- output_jax_format=False,
739
- output_torch_format=False,
740
- extra_sympy_mappings=None,
741
- extra_torch_mappings=None,
742
- extra_jax_mappings=None,
743
- denoise=False,
744
- select_k_features=None,
745
- julia_kwargs=None,
 
 
746
  **kwargs,
747
  ):
748
  # Hyperparameters
@@ -1645,7 +1647,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1645
  fraction_replaced_hof=self.fraction_replaced_hof,
1646
  should_simplify=self.should_simplify,
1647
  should_optimize_constants=self.should_optimize_constants,
1648
- warmup_maxsize_by=self.warmup_maxsize_by,
 
 
1649
  use_frequency=self.use_frequency,
1650
  use_frequency_in_tournament=self.use_frequency_in_tournament,
1651
  adaptive_parsimony_scaling=self.adaptive_parsimony_scaling,
 
11
  from io import StringIO
12
  from multiprocessing import cpu_count
13
  from pathlib import Path
14
+ from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
15
 
16
  import numpy as np
17
  import pandas as pd
 
659
 
660
  def __init__(
661
  self,
662
+ model_selection: Literal["best", "accuracy", "score"] = "best",
663
  *,
664
+ binary_operators: Optional[List[str]] = None,
665
+ unary_operators: Optional[List[str]] = None,
666
+ niterations: int = 40,
667
+ populations: int = 15,
668
+ population_size: int = 33,
669
+ max_evals: Optional[int] = None,
670
+ maxsize: int = 20,
671
+ maxdepth: Optional[int] = None,
672
+ warmup_maxsize_by: Optional[float] = None,
673
+ timeout_in_seconds: Optional[float] = None,
674
+ constraints: Optional[Dict[str, Union[int, Tuple[int, int]]]] = None,
675
+ nested_constraints: Optional[Dict[str, Dict[str, int]]] = None,
676
+ loss: Optional[str] = None,
677
+ full_objective: Optional[str] = None,
678
+ complexity_of_operators: Optional[Dict[str, Union[int, float]]] = None,
679
+ complexity_of_constants: Union[int, float] = 1,
680
+ complexity_of_variables: Union[int, float] = 1,
681
+ parsimony: float = 0.0032,
682
+ dimensional_constraint_penalty: Optional[float] = None,
683
+ use_frequency: bool = True,
684
+ use_frequency_in_tournament: bool = True,
685
+ adaptive_parsimony_scaling: float = 20.0,
686
+ alpha: float = 0.1,
687
+ annealing: bool = False,
688
+ early_stop_condition: Optional[Union[float, str]] = None,
689
+ ncyclesperiteration: int = 550,
690
+ fraction_replaced: float = 0.000364,
691
+ fraction_replaced_hof: float = 0.035,
692
+ weight_add_node: float = 0.79,
693
+ weight_insert_node: float = 5.1,
694
+ weight_delete_node: float = 1.7,
695
+ weight_do_nothing: float = 0.21,
696
+ weight_mutate_constant: float = 0.048,
697
+ weight_mutate_operator: float = 0.47,
698
+ weight_randomize: float = 0.00023,
699
+ weight_simplify: float = 0.0020,
700
+ weight_optimize: float = 0.0,
701
+ crossover_probability: float = 0.066,
702
+ skip_mutation_failures: bool = True,
703
+ migration: bool = True,
704
+ hof_migration: bool = True,
705
+ topn: int = 12,
706
+ should_simplify: Optional[bool] = None,
707
+ should_optimize_constants: bool = True,
708
+ optimizer_algorithm: str = "BFGS",
709
+ optimizer_nrestarts: int = 2,
710
+ optimize_probability: float = 0.14,
711
+ optimizer_iterations: int = 8,
712
+ perturbation_factor: float = 0.076,
713
+ tournament_selection_n: int = 10,
714
+ tournament_selection_p: float = 0.86,
715
+ procs: int = cpu_count(),
716
+ multithreading: Optional[bool] = None,
717
+ cluster_manager: Optional[
718
+ Literal["slurm", "pbs", "lsf", "sge", "qrsh", "scyld", "htc"]
719
+ ] = None,
720
+ heap_size_hint_in_bytes: Optional[int] = None,
721
+ batching: bool = False,
722
+ batch_size: int = 50,
723
+ fast_cycle: bool = False,
724
+ turbo: bool = False,
725
+ precision: int = 32,
726
+ enable_autodiff: bool = False,
727
+ random_state: Optional[Union[int, np.random.RandomState]] = None,
728
+ deterministic: bool = False,
729
+ warm_start: bool = False,
730
+ verbosity: int = 1,
731
+ update_verbosity: Optional[int] = None,
732
+ print_precision: int = 5,
733
+ progress: bool = True,
734
+ equation_file: Optional[str] = None,
735
+ temp_equation_file: bool = False,
736
+ tempdir: Optional[str] = None,
737
+ delete_tempfiles: bool = True,
738
+ julia_project: Optional[str] = None,
739
+ update: bool = False,
740
+ output_jax_format: bool = False,
741
+ output_torch_format: bool = False,
742
+ extra_sympy_mappings: Optional[Dict[str, Callable]] = None,
743
+ extra_torch_mappings: Optional[Dict[Callable, Callable]] = None,
744
+ extra_jax_mappings: Optional[Dict[Callable, str]] = None,
745
+ denoise: bool = False,
746
+ select_k_features: Optional[int] = None,
747
+ julia_kwargs: Optional[Dict] = None,
748
  **kwargs,
749
  ):
750
  # Hyperparameters
 
1647
  fraction_replaced_hof=self.fraction_replaced_hof,
1648
  should_simplify=self.should_simplify,
1649
  should_optimize_constants=self.should_optimize_constants,
1650
+ warmup_maxsize_by=0.0
1651
+ if self.warmup_maxsize_by is None
1652
+ else self.warmup_maxsize_by,
1653
  use_frequency=self.use_frequency,
1654
  use_frequency_in_tournament=self.use_frequency_in_tournament,
1655
  adaptive_parsimony_scaling=self.adaptive_parsimony_scaling,
setup.py CHANGED
@@ -26,5 +26,5 @@ setuptools.setup(
26
  "Programming Language :: Python :: 3",
27
  "Operating System :: OS Independent",
28
  ],
29
- python_requires=">=3.7",
30
  )
 
26
  "Programming Language :: Python :: 3",
27
  "Operating System :: OS Independent",
28
  ],
29
+ python_requires=">=3.8",
30
  )