Spaces:
Running
Running
File size: 5,723 Bytes
603c5f4 02376fd 618a3f8 603c5f4 618a3f8 b7d54b1 6d3c900 618a3f8 92eb30b 618a3f8 a949e43 618a3f8 603c5f4 618a3f8 f729ba4 618a3f8 87880d1 618a3f8 603c5f4 618a3f8 92eb30b 7acebb6 92eb30b 603c5f4 92eb30b 603c5f4 92eb30b 65159ce b7d54b1 02376fd 65159ce ddadb22 65159ce 618a3f8 ef66f4a 618a3f8 ef66f4a 618a3f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import os
import platform
import subprocess
import sys
import tempfile
import textwrap
import unittest
from pathlib import Path
import numpy as np
from .. import PySRRegressor
from ..julia_import import jl_version
from .params import DEFAULT_NITERATIONS, DEFAULT_POPULATIONS
class TestStartup(unittest.TestCase):
    """Various tests related to starting up PySR."""

    def setUp(self):
        """Build shared keyword arguments and a seeded random design matrix."""
        # Use twice the package defaults from ``.params`` for the number of
        # iterations and populations.
        self.default_test_kwargs = dict(
            progress=False,
            model_selection="accuracy",
            niterations=DEFAULT_NITERATIONS * 2,
            populations=DEFAULT_POPULATIONS * 2,
            temp_equation_file=True,
        )
        self.rstate = np.random.RandomState(0)
        self.X = self.rstate.randn(100, 5)

    def test_warm_start_from_file(self):
        """Test that we can warm start in another process.

        A model is fit here and written to an equation file in a temporary
        directory. A fresh Python subprocess then reloads it with
        ``PySRRegressor.from_file``, clears the loaded equations, and refits
        with zero iterations/evaluations. The refit must reach a loss no worse
        than the one found in this process.
        """
        if platform.system() == "Windows":
            self.skipTest("Warm start test incompatible with Windows")
        with tempfile.TemporaryDirectory() as tmpdirname:
            tmpdir = Path(tmpdirname)
            model = PySRRegressor(
                **self.default_test_kwargs,
                unary_operators=["cos"],
            )
            model.warm_start = True
            model.temp_equation_file = False
            model.equation_file = tmpdir / "equations.csv"
            model.deterministic = True
            model.multithreading = False
            model.random_state = 0
            model.procs = 0
            model.early_stop_condition = 1e-10

            rstate = np.random.RandomState(0)
            X = rstate.randn(100, 2)
            y = np.cos(X[:, 0]) ** 2
            model.fit(X, y)

            best_loss = model.equations_.iloc[-1]["loss"]

            # Save X and y to a file so the subprocess can reload them:
            X_file = tmpdir / "X.npy"
            y_file = tmpdir / "y.npy"
            np.save(X_file, X)
            np.save(y_file, y)

            # Paths are embedded via ``!r`` so the generated script receives
            # correctly quoted/escaped string literals, even if the temporary
            # directory name contains quotes or backslashes.
            script = textwrap.dedent(
                f"""
                from pysr import PySRRegressor
                import numpy as np

                X = np.load({str(X_file)!r})
                y = np.load({str(y_file)!r})

                print("Loading model from file")
                model = PySRRegressor.from_file({str(model.equation_file)!r})

                assert model.julia_state_ is not None

                # Reset saved equations; should be loaded from state!
                model.equations_ = None
                model.equation_file_contents_ = None

                model.warm_start = True
                model.niterations = 0
                model.max_evals = 0
                model.ncycles_per_iteration = 0

                model.fit(X, y)

                best_loss = model.equations_.iloc[-1]["loss"]

                assert best_loss <= {best_loss}
                """
            )

            # Now, create a new process and warm start from the file:
            result = subprocess.run(
                [sys.executable, "-c", script],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                env=os.environ,
            )
            stdout = result.stdout.decode()
            stderr = result.stderr.decode()
            # Include the child's output; otherwise a crash in the subprocess
            # is reported only as "1 != 0".
            self.assertEqual(
                result.returncode,
                0,
                msg=f"subprocess stdout:\n{stdout}\nsubprocess stderr:\n{stderr}",
            )
            self.assertIn("Loading model from file", stdout)
            self.assertIn("Started!", stderr)

    def test_bad_startup_options(self):
        """Check that problematic startup configurations print a warning.

        Each case runs a one-line script (which ends with ``import pysr``) in
        a fresh interpreter. The expected message must appear on its stderr.
        """
        warning_tests = [
            dict(
                code='import os; os.environ["PYTHON_JULIACALL_HANDLE_SIGNALS"] = "no"; import pysr',
                msg="PYTHON_JULIACALL_HANDLE_SIGNALS environment variable is set",
            ),
            dict(
                code='import os; os.environ["PYTHON_JULIACALL_THREADS"] = "1"; import pysr',
                msg="PYTHON_JULIACALL_THREADS environment variable is set",
            ),
            dict(
                code="import juliacall; import pysr",
                msg="juliacall module already imported.",
            ),
            dict(
                code='import os; os.environ["PYSR_AUTOLOAD_EXTENSIONS"] = "foo"; import pysr',
                msg="PYSR_AUTOLOAD_EXTENSIONS environment variable is set",
            ),
        ]
        for warning_test in warning_tests:
            # subTest reports exactly which case failed and keeps checking the
            # remaining cases instead of stopping at the first mismatch.
            with self.subTest(code=warning_test["code"]):
                result = subprocess.run(
                    [sys.executable, "-c", warning_test["code"]],
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                    env=os.environ,
                )
                self.assertIn(warning_test["msg"], result.stderr.decode())

    def test_notebook(self):
        """Run ``test_nb.ipynb`` through ``pytest --nbval`` and require success."""
        if jl_version < (1, 9, 0):
            self.skipTest("Julia version too old")
        if platform.system() == "Windows":
            self.skipTest("Notebook test incompatible with Windows")
        test_dir = Path(__file__).parent
        result = subprocess.run(
            [
                sys.executable,
                "-m",
                "pytest",
                "--nbval",
                str(test_dir / "test_nb.ipynb"),
                "--nbval-sanitize-with",
                str(test_dir / "nb_sanitize.cfg"),
            ],
            env=os.environ,
        )
        self.assertEqual(result.returncode, 0)
def runtests(just_tests=False):
    """Run the startup test cases, or just hand back their classes.

    Args:
        just_tests: If true, skip running and return the list of
            ``unittest.TestCase`` subclasses in this module.

    Returns:
        The list of test case classes when ``just_tests`` is true; otherwise
        the result object produced by ``unittest.TextTestRunner().run``.
    """
    case_classes = [TestStartup]
    if just_tests:
        return case_classes

    loader = unittest.TestLoader()
    # Flatten each per-class suite so the top-level suite directly holds the
    # individual tests.
    suite = unittest.TestSuite(
        single_test
        for case_class in case_classes
        for single_test in loader.loadTestsFromTestCase(case_class)
    )
    return unittest.TextTestRunner().run(suite)
|