MilesCranmer committed on
Commit
0779a61
·
unverified ·
2 Parent(s): 7c35b4e 96d6ea9

Merge remote-tracking branch 'origin/master' into var-complexity

Browse files
.github/workflows/CI.yml CHANGED
@@ -52,7 +52,7 @@ jobs:
52
  with:
53
  version: ${{ matrix.julia-version }}
54
  - name: "Cache Julia"
55
- uses: julia-actions/cache@v1
56
  with:
57
  cache-name: ${{ matrix.os }}-test-${{ matrix.julia-version }}-${{ matrix.python-version }}
58
  cache-packages: false
@@ -144,7 +144,7 @@ jobs:
144
  activate-environment: pysr-test
145
  environment-file: environment.yml
146
  - name: "Cache Julia"
147
- uses: julia-actions/cache@v1
148
  with:
149
  cache-name: ${{ matrix.os }}-conda-${{ matrix.python-version }}
150
  cache-packages: false
 
52
  with:
53
  version: ${{ matrix.julia-version }}
54
  - name: "Cache Julia"
55
+ uses: julia-actions/cache@v2
56
  with:
57
  cache-name: ${{ matrix.os }}-test-${{ matrix.julia-version }}-${{ matrix.python-version }}
58
  cache-packages: false
 
144
  activate-environment: pysr-test
145
  environment-file: environment.yml
146
  - name: "Cache Julia"
147
+ uses: julia-actions/cache@v2
148
  with:
149
  cache-name: ${{ matrix.os }}-conda-${{ matrix.python-version }}
150
  cache-packages: false
.github/workflows/CI_Windows.yml CHANGED
@@ -40,7 +40,7 @@ jobs:
40
  with:
41
  version: ${{ matrix.julia-version }}
42
  - name: "Cache Julia"
43
- uses: julia-actions/cache@v1
44
  with:
45
  cache-name: ${{ matrix.os }}-test-${{ matrix.julia-version }}-${{ matrix.python-version }}
46
  cache-packages: false
 
40
  with:
41
  version: ${{ matrix.julia-version }}
42
  - name: "Cache Julia"
43
+ uses: julia-actions/cache@v2
44
  with:
45
  cache-name: ${{ matrix.os }}-test-${{ matrix.julia-version }}-${{ matrix.python-version }}
46
  cache-packages: false
.github/workflows/CI_mac.yml CHANGED
@@ -40,7 +40,7 @@ jobs:
40
  with:
41
  version: ${{ matrix.julia-version }}
42
  - name: "Cache Julia"
43
- uses: julia-actions/cache@v1
44
  with:
45
  cache-name: ${{ matrix.os }}-test-${{ matrix.julia-version }}-${{ matrix.python-version }}
46
  cache-packages: false
 
40
  with:
41
  version: ${{ matrix.julia-version }}
42
  - name: "Cache Julia"
43
+ uses: julia-actions/cache@v2
44
  with:
45
  cache-name: ${{ matrix.os }}-test-${{ matrix.julia-version }}-${{ matrix.python-version }}
46
  cache-packages: false
.github/workflows/docker_deploy.yml CHANGED
@@ -24,13 +24,13 @@ jobs:
24
  - name: Checkout
25
  uses: actions/checkout@v4
26
  - name: Login to Docker Hub
27
- uses: docker/login-action@v2
28
  if: github.event_name != 'pull_request'
29
  with:
30
  username: ${{ secrets.DOCKERHUB_USERNAME }}
31
  password: ${{ secrets.DOCKERHUB_TOKEN }}
32
  - name: Login to GitHub registry
33
- uses: docker/login-action@v2
34
  if: github.event_name != 'pull_request'
35
  with:
36
  registry: ghcr.io
 
24
  - name: Checkout
25
  uses: actions/checkout@v4
26
  - name: Login to Docker Hub
27
+ uses: docker/login-action@v3
28
  if: github.event_name != 'pull_request'
29
  with:
30
  username: ${{ secrets.DOCKERHUB_USERNAME }}
31
  password: ${{ secrets.DOCKERHUB_TOKEN }}
32
  - name: Login to GitHub registry
33
+ uses: docker/login-action@v3
34
  if: github.event_name != 'pull_request'
35
  with:
36
  registry: ghcr.io
.pre-commit-config.yaml CHANGED
@@ -9,7 +9,7 @@ repos:
9
  - id: check-added-large-files
10
  # General formatting
11
  - repo: https://github.com/psf/black
12
- rev: 24.4.0
13
  hooks:
14
  - id: black
15
  - id: black-jupyter
 
9
  - id: check-added-large-files
10
  # General formatting
11
  - repo: https://github.com/psf/black
12
+ rev: 24.4.2
13
  hooks:
14
  - id: black
15
  - id: black-jupyter
examples/pysr_demo.ipynb CHANGED
@@ -396,7 +396,7 @@
396
  "id": "wbWHyOjl2_kX"
397
  },
398
  "source": [
399
- "Since `quart` is arguably more complex than the other operators, you can also give it a different complexity, using, e.g., `complexity_of_operators={\"quart\": 2}` to give it a complexity of 2 (instead of the default 2). You can also define custom complexities for variables and constants (`complexity_of_variables` and `complexity_of_constants`, respectively - both take a single number).\n",
400
  "\n",
401
  "\n",
402
  "One can also add a binary operator, with, e.g., `\"myoperator(x, y) = x^2 * y\"`. All Julia operators that work on scalar 32-bit floating point values are available.\n",
 
396
  "id": "wbWHyOjl2_kX"
397
  },
398
  "source": [
399
+ "Since `quart` is arguably more complex than the other operators, you can also give it a different complexity, using, e.g., `complexity_of_operators={\"quart\": 2}` to give it a complexity of 2 (instead of the default 1). You can also define custom complexities for variables and constants (`complexity_of_variables` and `complexity_of_constants`, respectively - both take a single number).\n",
400
  "\n",
401
  "\n",
402
  "One can also add a binary operator, with, e.g., `\"myoperator(x, y) = x^2 * y\"`. All Julia operators that work on scalar 32-bit floating point values are available.\n",
pysr/sr.py CHANGED
@@ -1,6 +1,8 @@
1
  """Define the PySRRegressor scikit-learn interface."""
2
 
3
  import copy
 
 
4
  import os
5
  import pickle as pkl
6
  import re
@@ -912,15 +914,15 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
912
  updated_kwarg_name = DEPRECATED_KWARGS[k]
913
  setattr(self, updated_kwarg_name, v)
914
  warnings.warn(
915
- f"{k} has been renamed to {updated_kwarg_name} in PySRRegressor. "
916
  "Please use that instead.",
917
  FutureWarning,
918
  )
919
  # Handle kwargs that have been moved to the fit method
920
  elif k in ["weights", "variable_names", "Xresampled"]:
921
  warnings.warn(
922
- f"{k} is a data dependant parameter so should be passed when fit is called. "
923
- f"Ignoring parameter; please pass {k} during the call to fit instead.",
924
  FutureWarning,
925
  )
926
  elif k == "julia_project":
@@ -937,9 +939,13 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
937
  FutureWarning,
938
  )
939
  else:
940
- raise TypeError(
941
- f"{k} is not a valid keyword argument for PySRRegressor."
 
942
  )
 
 
 
943
 
944
  @classmethod
945
  def from_file(
@@ -2545,6 +2551,16 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2545
  return with_preamble(table_string)
2546
 
2547
 
 
 
 
 
 
 
 
 
 
 
2548
  def idx_model_selection(equations: pd.DataFrame, model_selection: str):
2549
  """Select an expression and return its index."""
2550
  if model_selection == "accuracy":
 
1
  """Define the PySRRegressor scikit-learn interface."""
2
 
3
  import copy
4
+ import difflib
5
+ import inspect
6
  import os
7
  import pickle as pkl
8
  import re
 
914
  updated_kwarg_name = DEPRECATED_KWARGS[k]
915
  setattr(self, updated_kwarg_name, v)
916
  warnings.warn(
917
+ f"`{k}` has been renamed to `{updated_kwarg_name}` in PySRRegressor. "
918
  "Please use that instead.",
919
  FutureWarning,
920
  )
921
  # Handle kwargs that have been moved to the fit method
922
  elif k in ["weights", "variable_names", "Xresampled"]:
923
  warnings.warn(
924
+ f"`{k}` is a data-dependent parameter and should be passed when fit is called. "
925
+ f"Ignoring parameter; please pass `{k}` during the call to fit instead.",
926
  FutureWarning,
927
  )
928
  elif k == "julia_project":
 
939
  FutureWarning,
940
  )
941
  else:
942
+ suggested_keywords = _suggest_keywords(PySRRegressor, k)
943
+ err_msg = (
944
+ f"`{k}` is not a valid keyword argument for PySRRegressor."
945
  )
946
+ if len(suggested_keywords) > 0:
947
+ err_msg += f" Did you mean {', '.join(map(lambda s: f'`{s}`', suggested_keywords))}?"
948
+ raise TypeError(err_msg)
949
 
950
  @classmethod
951
  def from_file(
 
2551
  return with_preamble(table_string)
2552
 
2553
 
2554
+ def _suggest_keywords(cls, k: str) -> List[str]:
2555
+ valid_keywords = [
2556
+ param
2557
+ for param in inspect.signature(cls.__init__).parameters
2558
+ if param not in ["self", "kwargs"]
2559
+ ]
2560
+ suggestions = difflib.get_close_matches(k, valid_keywords, n=3)
2561
+ return suggestions
2562
+
2563
+
2564
  def idx_model_selection(equations: pd.DataFrame, model_selection: str):
2565
  """Select an expression and return its index."""
2566
  if model_selection == "accuracy":
pysr/test/test.py CHANGED
@@ -15,9 +15,8 @@ from pysr import PySRRegressor, install, jl
15
  from pysr.export_latex import sympy2latex
16
  from pysr.feature_selection import _handle_feature_selection, run_feature_selection
17
  from pysr.julia_helpers import init_julia
18
- from pysr.sr import _check_assertions, _process_constraints, idx_model_selection
19
  from pysr.utils import _csv_filename_to_pkl_filename
20
-
21
  from .params import (
22
  DEFAULT_NCYCLES,
23
  DEFAULT_NITERATIONS,
@@ -596,6 +595,105 @@ class TestMiscellaneous(unittest.TestCase):
596
  test_pkl_file = _csv_filename_to_pkl_filename(str(equation_file))
597
  self.assertEqual(test_pkl_file, str(expected_pkl_file))
598
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
599
  def test_deprecation(self):
600
  """Ensure that deprecation works as expected.
601
 
@@ -738,100 +836,28 @@ class TestMiscellaneous(unittest.TestCase):
738
  model.get_best()
739
  print("Failed", opt["kwargs"])
740
 
741
- def test_pickle_with_temp_equation_file(self):
742
- """If we have a temporary equation file, unpickle the estimator."""
743
- model = PySRRegressor(
744
- populations=int(1 + DEFAULT_POPULATIONS / 5),
745
- temp_equation_file=True,
746
- procs=0,
747
- multithreading=False,
748
  )
749
- nout = 3
750
- X = np.random.randn(100, 2)
751
- y = np.random.randn(100, nout)
752
- model.fit(X, y)
753
- contents = model.equation_file_contents_.copy()
754
-
755
- y_predictions = model.predict(X)
756
-
757
- equation_file_base = model.equation_file_
758
- for i in range(1, nout + 1):
759
- assert not os.path.exists(str(equation_file_base) + f".out{i}.bkup")
760
-
761
- with tempfile.NamedTemporaryFile() as pickle_file:
762
- pkl.dump(model, pickle_file)
763
- pickle_file.seek(0)
764
- model2 = pkl.load(pickle_file)
765
-
766
- contents2 = model2.equation_file_contents_
767
- cols_to_check = ["equation", "loss", "complexity"]
768
- for frame1, frame2 in zip(contents, contents2):
769
- pd.testing.assert_frame_equal(frame1[cols_to_check], frame2[cols_to_check])
770
 
771
- y_predictions2 = model2.predict(X)
772
- np.testing.assert_array_equal(y_predictions, y_predictions2)
 
773
 
774
- def test_scikit_learn_compatibility(self):
775
- """Test PySRRegressor compatibility with scikit-learn."""
776
- model = PySRRegressor(
777
- niterations=int(1 + DEFAULT_NITERATIONS / 10),
778
- populations=int(1 + DEFAULT_POPULATIONS / 3),
779
- ncycles_per_iteration=int(2 + DEFAULT_NCYCLES / 10),
780
- verbosity=0,
781
- progress=False,
782
- random_state=0,
783
- deterministic=True, # Deterministic as tests require this.
784
- procs=0,
785
- multithreading=False,
786
- warm_start=False,
787
- temp_equation_file=True,
788
- ) # Return early.
789
-
790
- check_generator = check_estimator(model, generate_only=True)
791
- exception_messages = []
792
- for _, check in check_generator:
793
- if check.func.__name__ == "check_complex_data":
794
- # We can use complex data, so avoid this check.
795
- continue
796
- try:
797
- with warnings.catch_warnings():
798
- warnings.simplefilter("ignore")
799
- check(model)
800
- print("Passed", check.func.__name__)
801
- except Exception:
802
- error_message = str(traceback.format_exc())
803
- exception_messages.append(
804
- f"{check.func.__name__}:\n" + error_message + "\n"
805
- )
806
- print("Failed", check.func.__name__, "with:")
807
- # Add a leading tab to error message, which
808
- # might be multi-line:
809
- print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
810
- # If any checks failed don't let the test pass.
811
- self.assertEqual(len(exception_messages), 0)
812
-
813
- def test_param_groupings(self):
814
- """Test that param_groupings are complete"""
815
- param_groupings_file = Path(__file__).parent.parent / "param_groupings.yml"
816
- if not param_groupings_file.exists():
817
- return
818
-
819
- # Read the file, discarding lines ending in ":",
820
- # and removing leading "\s*-\s*":
821
- params = []
822
- with open(param_groupings_file, "r") as f:
823
- for line in f.readlines():
824
- if line.strip().endswith(":"):
825
- continue
826
- if line.strip().startswith("-"):
827
- params.append(line.strip()[1:].strip())
828
 
829
- regressor_params = [
830
- p for p in DEFAULT_PARAMS.keys() if p not in ["self", "kwargs"]
831
- ]
832
 
833
- # Check the sets are equal:
834
- self.assertSetEqual(set(params), set(regressor_params))
835
 
836
 
837
  TRUE_PREAMBLE = "\n".join(
@@ -1187,6 +1213,7 @@ def runtests(just_tests=False):
1187
  TestBest,
1188
  TestFeatureSelection,
1189
  TestMiscellaneous,
 
1190
  TestLaTeXTable,
1191
  TestDimensionalConstraints,
1192
  ]
 
15
  from pysr.export_latex import sympy2latex
16
  from pysr.feature_selection import _handle_feature_selection, run_feature_selection
17
  from pysr.julia_helpers import init_julia
18
+ from pysr.sr import _check_assertions, _process_constraints, _suggest_keywords, idx_model_selection
19
  from pysr.utils import _csv_filename_to_pkl_filename
 
20
  from .params import (
21
  DEFAULT_NCYCLES,
22
  DEFAULT_NITERATIONS,
 
595
  test_pkl_file = _csv_filename_to_pkl_filename(str(equation_file))
596
  self.assertEqual(test_pkl_file, str(expected_pkl_file))
597
 
598
+ def test_pickle_with_temp_equation_file(self):
599
+ """If we have a temporary equation file, unpickle the estimator."""
600
+ model = PySRRegressor(
601
+ populations=int(1 + DEFAULT_POPULATIONS / 5),
602
+ temp_equation_file=True,
603
+ procs=0,
604
+ multithreading=False,
605
+ )
606
+ nout = 3
607
+ X = np.random.randn(100, 2)
608
+ y = np.random.randn(100, nout)
609
+ model.fit(X, y)
610
+ contents = model.equation_file_contents_.copy()
611
+
612
+ y_predictions = model.predict(X)
613
+
614
+ equation_file_base = model.equation_file_
615
+ for i in range(1, nout + 1):
616
+ assert not os.path.exists(str(equation_file_base) + f".out{i}.bkup")
617
+
618
+ with tempfile.NamedTemporaryFile() as pickle_file:
619
+ pkl.dump(model, pickle_file)
620
+ pickle_file.seek(0)
621
+ model2 = pkl.load(pickle_file)
622
+
623
+ contents2 = model2.equation_file_contents_
624
+ cols_to_check = ["equation", "loss", "complexity"]
625
+ for frame1, frame2 in zip(contents, contents2):
626
+ pd.testing.assert_frame_equal(frame1[cols_to_check], frame2[cols_to_check])
627
+
628
+ y_predictions2 = model2.predict(X)
629
+ np.testing.assert_array_equal(y_predictions, y_predictions2)
630
+
631
+ def test_scikit_learn_compatibility(self):
632
+ """Test PySRRegressor compatibility with scikit-learn."""
633
+ model = PySRRegressor(
634
+ niterations=int(1 + DEFAULT_NITERATIONS / 10),
635
+ populations=int(1 + DEFAULT_POPULATIONS / 3),
636
+ ncycles_per_iteration=int(2 + DEFAULT_NCYCLES / 10),
637
+ verbosity=0,
638
+ progress=False,
639
+ random_state=0,
640
+ deterministic=True, # Deterministic as tests require this.
641
+ procs=0,
642
+ multithreading=False,
643
+ warm_start=False,
644
+ temp_equation_file=True,
645
+ ) # Return early.
646
+
647
+ check_generator = check_estimator(model, generate_only=True)
648
+ exception_messages = []
649
+ for _, check in check_generator:
650
+ if check.func.__name__ == "check_complex_data":
651
+ # We can use complex data, so avoid this check.
652
+ continue
653
+ try:
654
+ with warnings.catch_warnings():
655
+ warnings.simplefilter("ignore")
656
+ check(model)
657
+ print("Passed", check.func.__name__)
658
+ except Exception:
659
+ error_message = str(traceback.format_exc())
660
+ exception_messages.append(
661
+ f"{check.func.__name__}:\n" + error_message + "\n"
662
+ )
663
+ print("Failed", check.func.__name__, "with:")
664
+ # Add a leading tab to error message, which
665
+ # might be multi-line:
666
+ print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
667
+ # If any checks failed don't let the test pass.
668
+ self.assertEqual(len(exception_messages), 0)
669
+
670
+ def test_param_groupings(self):
671
+ """Test that param_groupings are complete"""
672
+ param_groupings_file = Path(__file__).parent.parent / "param_groupings.yml"
673
+ if not param_groupings_file.exists():
674
+ return
675
+
676
+ # Read the file, discarding lines ending in ":",
677
+ # and removing leading "\s*-\s*":
678
+ params = []
679
+ with open(param_groupings_file, "r") as f:
680
+ for line in f.readlines():
681
+ if line.strip().endswith(":"):
682
+ continue
683
+ if line.strip().startswith("-"):
684
+ params.append(line.strip()[1:].strip())
685
+
686
+ regressor_params = [
687
+ p for p in DEFAULT_PARAMS.keys() if p not in ["self", "kwargs"]
688
+ ]
689
+
690
+ # Check the sets are equal:
691
+ self.assertSetEqual(set(params), set(regressor_params))
692
+
693
+
694
+ class TestHelpMessages(unittest.TestCase):
695
+ """Test user help messages."""
696
+
697
  def test_deprecation(self):
698
  """Ensure that deprecation works as expected.
699
 
 
836
  model.get_best()
837
  print("Failed", opt["kwargs"])
838
 
839
+ def test_suggest_keywords(self):
840
+ # Easy
841
+ self.assertEqual(
842
+ _suggest_keywords(PySRRegressor, "loss_function"), ["loss_function"]
 
 
 
843
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
844
 
845
+ # More complex, and with error
846
+ with self.assertRaises(TypeError) as cm:
847
+ model = PySRRegressor(ncyclesperiterationn=5)
848
 
849
+ self.assertIn(
850
+ "`ncyclesperiterationn` is not a valid keyword", str(cm.exception)
851
+ )
852
+ self.assertIn("Did you mean", str(cm.exception))
853
+ self.assertIn("`ncycles_per_iteration`, ", str(cm.exception))
854
+ self.assertIn("`niterations`", str(cm.exception))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
855
 
856
+ # Farther matches (this might need to be changed)
857
+ with self.assertRaises(TypeError) as cm:
858
+ model = PySRRegressor(operators=["+", "-"])
859
 
860
+ self.assertIn("`unary_operators`, `binary_operators`", str(cm.exception))
 
861
 
862
 
863
  TRUE_PREAMBLE = "\n".join(
 
1213
  TestBest,
1214
  TestFeatureSelection,
1215
  TestMiscellaneous,
1216
+ TestHelpMessages,
1217
  TestLaTeXTable,
1218
  TestDimensionalConstraints,
1219
  ]