Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
04454ac
1
Parent(s):
2d025c2
Add unittests for units checks
Browse files- pysr/test/test.py +39 -3
pysr/test/test.py
CHANGED
@@ -19,6 +19,7 @@ from ..sr import (
|
|
19 |
_handle_feature_selection,
|
20 |
_csv_filename_to_pkl_filename,
|
21 |
idx_model_selection,
|
|
|
22 |
)
|
23 |
from ..export_latex import to_latex
|
24 |
|
@@ -932,12 +933,47 @@ class TestDimensionalConstraints(unittest.TestCase):
|
|
932 |
self.assertLess(model.get_best()["loss"], 1e-6)
|
933 |
self.assertGreater(model.equations_.query("complexity <= 2").loss.min(), 1e-6)
|
934 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
935 |
|
936 |
# TODO: add tests for:
|
937 |
# - custom operators + dimensions
|
938 |
-
# - invalid number of dimensions
|
939 |
-
# - X
|
940 |
-
# - y
|
941 |
# - no constants, so that it needs to find the right fraction
|
942 |
# - custom dimensional_constraint_penalty
|
943 |
|
|
|
19 |
_handle_feature_selection,
|
20 |
_csv_filename_to_pkl_filename,
|
21 |
idx_model_selection,
|
22 |
+
_check_assertions,
|
23 |
)
|
24 |
from ..export_latex import to_latex
|
25 |
|
|
|
933 |
self.assertLess(model.get_best()["loss"], 1e-6)
|
934 |
self.assertGreater(model.equations_.query("complexity <= 2").loss.min(), 1e-6)
|
935 |
|
936 |
+
def test_unit_checks(self):
|
937 |
+
"""This just checks the number of units passed"""
|
938 |
+
use_custom_variable_names = False
|
939 |
+
variable_names = None
|
940 |
+
weights = None
|
941 |
+
args = (use_custom_variable_names, variable_names, weights)
|
942 |
+
valid_units = [
|
943 |
+
(np.ones((10, 2)), np.ones(10), ["m/s", "s"], "m"),
|
944 |
+
(np.ones((10, 1)), np.ones(10), ["m/s"], None),
|
945 |
+
(np.ones((10, 1)), np.ones(10), None, "m/s"),
|
946 |
+
(np.ones((10, 1)), np.ones(10), None, ["m/s"]),
|
947 |
+
(np.ones((10, 1)), np.ones((10, 1)), None, ["m/s"]),
|
948 |
+
(np.ones((10, 1)), np.ones((10, 2)), None, ["m/s", "km"]),
|
949 |
+
]
|
950 |
+
for X, y, X_units, y_units in valid_units:
|
951 |
+
_check_assertions(
|
952 |
+
X,
|
953 |
+
*args,
|
954 |
+
y,
|
955 |
+
X_units,
|
956 |
+
y_units,
|
957 |
+
)
|
958 |
+
invalid_units = [
|
959 |
+
(np.ones((10, 2)), np.ones(10), ["m/s", "s", "s^2"], None),
|
960 |
+
(np.ones((10, 2)), np.ones(10), ["m/s", "s", "s^2"], "m"),
|
961 |
+
(np.ones((10, 2)), np.ones((10, 2)), ["m/s", "s"], ["m"]),
|
962 |
+
(np.ones((10, 1)), np.ones((10, 1)), "m/s", ["m"]),
|
963 |
+
]
|
964 |
+
for X, y, X_units, y_units in invalid_units:
|
965 |
+
with self.assertRaises(ValueError):
|
966 |
+
_check_assertions(
|
967 |
+
X,
|
968 |
+
*args,
|
969 |
+
y,
|
970 |
+
X_units,
|
971 |
+
y_units,
|
972 |
+
)
|
973 |
+
|
974 |
|
975 |
# TODO: add tests for:
|
976 |
# - custom operators + dimensions
|
|
|
|
|
|
|
977 |
# - no constants, so that it needs to find the right fraction
|
978 |
# - custom dimensional_constraint_penalty
|
979 |
|