MilesCranmer commited on
Commit
c6c8728
1 Parent(s): 8f218cc

Test all aspects of generated LaTeX table

Browse files
Files changed (1) hide show
  1. test/test.py +107 -19
test/test.py CHANGED
@@ -281,19 +281,38 @@ class TestPipeline(unittest.TestCase):
281
  self.assertLess(np.average((model.predict(X.values) - y.values) ** 2), 1e-4)
282
 
283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  class TestBest(unittest.TestCase):
285
  def setUp(self):
286
  self.rstate = np.random.RandomState(0)
287
  self.X = self.rstate.randn(10, 2)
288
  self.y = np.cos(self.X[:, 0]) ** 2
289
- self.model = PySRRegressor(
290
- progress=False,
291
- niterations=1,
292
- extra_sympy_mappings={},
293
- output_jax_format=False,
294
- model_selection="accuracy",
295
- equation_file="equation_file.csv",
296
- )
297
  equations = pd.DataFrame(
298
  {
299
  "equation": ["1.0", "cos(x0)", "square(cos(x0))"],
@@ -301,17 +320,7 @@ class TestBest(unittest.TestCase):
301
  "complexity": [1, 2, 3],
302
  }
303
  )
304
-
305
- # Set up internal parameters as if it had been fitted:
306
- self.model.equation_file_ = "equation_file.csv"
307
- self.model.nout_ = 1
308
- self.model.selection_mask_ = None
309
- self.model.feature_names_in_ = np.array(["x0", "x1"], dtype=object)
310
- equations["complexity loss equation".split(" ")].to_csv(
311
- "equation_file.csv.bkup", sep="|"
312
- )
313
-
314
- self.model.refresh()
315
  self.equations_ = self.model.equations_
316
 
317
  def test_best(self):
@@ -485,3 +494,82 @@ class TestMiscellaneous(unittest.TestCase):
485
  print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
486
  # If any checks failed don't let the test pass.
487
  self.assertEqual(len(exception_messages), 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  self.assertLess(np.average((model.predict(X.values) - y.values) ** 2), 1e-4)
282
 
283
 
284
+ def manually_create_model(equations, feature_names=None):
285
+ if feature_names is None:
286
+ feature_names = ["x0", "x1"]
287
+
288
+ model = PySRRegressor(
289
+ progress=False,
290
+ niterations=1,
291
+ extra_sympy_mappings={},
292
+ output_jax_format=False,
293
+ model_selection="accuracy",
294
+ equation_file="equation_file.csv",
295
+ )
296
+
297
+ # Set up internal parameters as if it had been fitted:
298
+ model.equation_file_ = "equation_file.csv"
299
+ model.nout_ = 1
300
+ model.selection_mask_ = None
301
+ model.feature_names_in_ = np.array(feature_names, dtype=object)
302
+ equations["complexity loss equation".split(" ")].to_csv(
303
+ "equation_file.csv.bkup", sep="|"
304
+ )
305
+
306
+ model.refresh()
307
+
308
+ return model
309
+
310
+
311
  class TestBest(unittest.TestCase):
312
  def setUp(self):
313
  self.rstate = np.random.RandomState(0)
314
  self.X = self.rstate.randn(10, 2)
315
  self.y = np.cos(self.X[:, 0]) ** 2
 
 
 
 
 
 
 
 
316
  equations = pd.DataFrame(
317
  {
318
  "equation": ["1.0", "cos(x0)", "square(cos(x0))"],
 
320
  "complexity": [1, 2, 3],
321
  }
322
  )
323
+ self.model = manually_create_model(equations)
 
 
 
 
 
 
 
 
 
 
324
  self.equations_ = self.model.equations_
325
 
326
  def test_best(self):
 
494
  print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
495
  # If any checks failed don't let the test pass.
496
  self.assertEqual(len(exception_messages), 0)
497
+
498
+
499
+ class TestLaTeXTable(unittest.TestCase):
500
+ def create_true_latex(self, middle_part, include_score=False):
501
+ if include_score:
502
+ true_latex_table_str = r"""
503
+ \begin{table}[h]
504
+ \begin{center}
505
+ \begin{tabular}{@{}clll@{}}
506
+ \toprule
507
+ Equation & Complexity & Loss & Score \\
508
+ \midrule"""
509
+ else:
510
+ true_latex_table_str = r"""
511
+ \begin{table}[h]
512
+ \begin{center}
513
+ \begin{tabular}{@{}cll@{}}
514
+ \toprule
515
+ Equation & Complexity & Loss \\
516
+ \midrule"""
517
+ true_latex_table_str += middle_part
518
+ true_latex_table_str += r"""\bottomrule
519
+ \end{tabular}
520
+ \end{center}
521
+ \end{table}
522
+ """
523
+ # First, remove empty lines:
524
+ true_latex_table_str = "\n".join(
525
+ [line.strip() for line in true_latex_table_str.split("\n") if len(line) > 0]
526
+ )
527
+ return true_latex_table_str.strip()
528
+
529
+ def test_simple_table(self):
530
+ equations = pd.DataFrame(
531
+ dict(
532
+ equation=["x0", "cos(x0)", "x0 + x1 - cos(x1 * x0)"],
533
+ loss=[1.052, 0.02315, 1.12347e-15],
534
+ complexity=[1, 2, 8],
535
+ )
536
+ )
537
+ model = manually_create_model(equations)
538
+
539
+ # Regular table:
540
+ latex_table_str = model.latex_table()
541
+ middle_part = r"""
542
+ $x_{0}$ & 1 & 1.05 \\
543
+ $\cos{\left(x_{0} \right)}$ & 2 & 0.0232 \\
544
+ $x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & 8 & 1.12e-15 \\
545
+ """
546
+ true_latex_table_str = self.create_true_latex(middle_part)
547
+ self.assertEqual(latex_table_str, true_latex_table_str)
548
+
549
+ # Different precision:
550
+ latex_table_str = model.latex_table(precision=5)
551
+ middle_part = r"""
552
+ $x_{0}$ & 1 & 1.052 \\
553
+ $\cos{\left(x_{0} \right)}$ & 2 & 0.02315 \\
554
+ $x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & 8 & 1.1235e-15 \\
555
+ """
556
+ true_latex_table_str = self.create_true_latex(middle_part)
557
+ self.assertEqual(latex_table_str, self.create_true_latex(middle_part))
558
+
559
+ # Including score:
560
+ latex_table_str = model.latex_table(include_score=True)
561
+ middle_part = r"""
562
+ $x_{0}$ & 1 & 1.05 & 0 \\
563
+ $\cos{\left(x_{0} \right)}$ & 2 & 0.0232 & 3.82 \\
564
+ $x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & 8 & 1.12e-15 & 5.11 \\
565
+ """
566
+ true_latex_table_str = self.create_true_latex(middle_part, include_score=True)
567
+ self.assertEqual(latex_table_str, true_latex_table_str)
568
+
569
+ # Only last equation:
570
+ latex_table_str = model.latex_table(indices=[2])
571
+ middle_part = r"""
572
+ $x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & 8 & 1.12e-15 \\
573
+ """
574
+ true_latex_table_str = self.create_true_latex(middle_part)
575
+ self.assertEqual(latex_table_str, true_latex_table_str)