MilesCranmer commited on
Commit
b658d24
2 Parent(s): db44938 e84bed4

Merge pull request #670 from MilesCranmer/issue666

Browse files
pysr/__init__.py CHANGED
@@ -18,7 +18,6 @@ __all__ = [
18
  "sklearn_monkeypatch",
19
  "sympy2jax",
20
  "sympy2torch",
21
- "Problem",
22
  "install",
23
  "PySRRegressor",
24
  "best",
 
18
  "sklearn_monkeypatch",
19
  "sympy2jax",
20
  "sympy2torch",
 
21
  "install",
22
  "PySRRegressor",
23
  "best",
pysr/export_jax.py CHANGED
@@ -1,5 +1,5 @@
1
  import numpy as np # noqa: F401
2
- import sympy
3
 
4
  # Special since need to reduce arguments.
5
  MUL = 0
 
1
  import numpy as np # noqa: F401
2
+ import sympy # type: ignore
3
 
4
  # Special since need to reduce arguments.
5
  MUL = 0
pysr/export_latex.py CHANGED
@@ -3,8 +3,8 @@
3
  from typing import List, Optional, Tuple
4
 
5
  import pandas as pd
6
- import sympy
7
- from sympy.printing.latex import LatexPrinter
8
 
9
 
10
  class PreciseLatexPrinter(LatexPrinter):
 
3
  from typing import List, Optional, Tuple
4
 
5
  import pandas as pd
6
+ import sympy # type: ignore
7
+ from sympy.printing.latex import LatexPrinter # type: ignore
8
 
9
 
10
  class PreciseLatexPrinter(LatexPrinter):
pysr/export_numpy.py CHANGED
@@ -6,7 +6,7 @@ from typing import List, Union
6
  import numpy as np
7
  import pandas as pd
8
  from numpy.typing import NDArray
9
- from sympy import Expr, Symbol, lambdify
10
 
11
 
12
  def sympy2numpy(eqn, sympy_symbols, *, selection=None):
 
6
  import numpy as np
7
  import pandas as pd
8
  from numpy.typing import NDArray
9
+ from sympy import Expr, Symbol, lambdify # type: ignore
10
 
11
 
12
  def sympy2numpy(eqn, sympy_symbols, *, selection=None):
pysr/export_sympy.py CHANGED
@@ -2,7 +2,7 @@
2
 
3
  from typing import Callable, Dict, List, Optional
4
 
5
- import sympy
6
  from sympy import sympify
7
 
8
  from .utils import ArrayLike
 
2
 
3
  from typing import Callable, Dict, List, Optional
4
 
5
+ import sympy # type: ignore
6
  from sympy import sympify
7
 
8
  from .utils import ArrayLike
pysr/export_torch.py CHANGED
@@ -4,7 +4,7 @@ import collections as co
4
  import functools as ft
5
 
6
  import numpy as np # noqa: F401
7
- import sympy
8
 
9
 
10
  def _reduce(fn):
 
4
  import functools as ft
5
 
6
  import numpy as np # noqa: F401
7
+ import sympy # type: ignore
8
 
9
 
10
  def _reduce(fn):
pysr/test/test.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import pickle as pkl
3
  import tempfile
@@ -8,7 +9,7 @@ from pathlib import Path
8
 
9
  import numpy as np
10
  import pandas as pd
11
- import sympy
12
  from sklearn.utils.estimator_checks import check_estimator
13
 
14
  from pysr import PySRRegressor, install, jl
@@ -892,7 +893,7 @@ class TestHelpMessages(unittest.TestCase):
892
 
893
  # More complex, and with error
894
  with self.assertRaises(TypeError) as cm:
895
- model = PySRRegressor(ncyclesperiterationn=5)
896
 
897
  self.assertIn(
898
  "`ncyclesperiterationn` is not a valid keyword", str(cm.exception)
@@ -903,10 +904,18 @@ class TestHelpMessages(unittest.TestCase):
903
 
904
  # Farther matches (this might need to be changed)
905
  with self.assertRaises(TypeError) as cm:
906
- model = PySRRegressor(operators=["+", "-"])
907
 
908
  self.assertIn("`unary_operators`, `binary_operators`", str(cm.exception))
909
 
 
 
 
 
 
 
 
 
910
 
911
  TRUE_PREAMBLE = "\n".join(
912
  [
 
1
+ import importlib
2
  import os
3
  import pickle as pkl
4
  import tempfile
 
9
 
10
  import numpy as np
11
  import pandas as pd
12
+ import sympy # type: ignore
13
  from sklearn.utils.estimator_checks import check_estimator
14
 
15
  from pysr import PySRRegressor, install, jl
 
893
 
894
  # More complex, and with error
895
  with self.assertRaises(TypeError) as cm:
896
+ PySRRegressor(ncyclesperiterationn=5)
897
 
898
  self.assertIn(
899
  "`ncyclesperiterationn` is not a valid keyword", str(cm.exception)
 
904
 
905
  # Farther matches (this might need to be changed)
906
  with self.assertRaises(TypeError) as cm:
907
+ PySRRegressor(operators=["+", "-"])
908
 
909
  self.assertIn("`unary_operators`, `binary_operators`", str(cm.exception))
910
 
911
+ def test_issue_666(self):
912
+ # Try the equivalent of `from pysr import *`
913
+ pysr_module = importlib.import_module("pysr")
914
+ names_to_import = pysr_module.__all__
915
+
916
+ for name in names_to_import:
917
+ getattr(pysr_module, name)
918
+
919
 
920
  TRUE_PREAMBLE = "\n".join(
921
  [
pysr/test/test_jax.py CHANGED
@@ -3,7 +3,7 @@ from functools import partial
3
 
4
  import numpy as np
5
  import pandas as pd
6
- import sympy
7
 
8
  import pysr
9
  from pysr import PySRRegressor, sympy2jax
@@ -102,7 +102,7 @@ class TestJAX(unittest.TestCase):
102
  )
103
 
104
  def test_issue_656(self):
105
- import sympy
106
 
107
  E_plus_x1 = sympy.exp(1) + sympy.symbols("x1")
108
  f, params = pysr.export_jax.sympy2jax(E_plus_x1, [sympy.symbols("x1")])
 
3
 
4
  import numpy as np
5
  import pandas as pd
6
+ import sympy # type: ignore
7
 
8
  import pysr
9
  from pysr import PySRRegressor, sympy2jax
 
102
  )
103
 
104
  def test_issue_656(self):
105
+ import sympy # type: ignore
106
 
107
  E_plus_x1 = sympy.exp(1) + sympy.symbols("x1")
108
  f, params = pysr.export_jax.sympy2jax(E_plus_x1, [sympy.symbols("x1")])
pysr/test/test_torch.py CHANGED
@@ -2,7 +2,7 @@ import unittest
2
 
3
  import numpy as np
4
  import pandas as pd
5
- import sympy
6
 
7
  import pysr
8
  from pysr import PySRRegressor, sympy2torch
 
2
 
3
  import numpy as np
4
  import pandas as pd
5
+ import sympy # type: ignore
6
 
7
  import pysr
8
  from pysr import PySRRegressor, sympy2torch