MilesCranmer commited on
Commit
7847c48
1 Parent(s): 5617815

Make all best_ equations return list if multi output

Browse files
Files changed (1) hide show
  1. pysr/sr.py +14 -10
pysr/sr.py CHANGED
@@ -821,35 +821,39 @@ def best_row(equations=None):
821
  """Return the best row of a hall of fame file using the score column.
822
  By default this uses the last equation file.
823
  """
824
- if equations is None: all_eqs = get_hof()
825
- if isinstance(all_eqs, list):
826
- return [equations[j].iloc[np.argmax(equations[j]['score'])] for j in range(len(all_eqs))]
827
  else:
828
- best_idx = np.argmax(equations['score'])
829
- return equations.iloc[best_idx]
830
 
831
  def best_tex(equations=None):
832
  """Return the equation with the best score, in latex format
833
  By default this uses the last equation file.
834
  """
835
  if equations is None: equations = get_hof()
836
- best_sympy = best_row(equations)['sympy_format']
837
- return sympy.latex(best_sympy.simplify())
 
 
838
 
839
  def best(equations=None):
840
  """Return the equation with the best score, in sympy format.
841
  By default this uses the last equation file.
842
  """
843
  if equations is None: equations = get_hof()
844
- best_sympy = best_row(equations)['sympy_format']
845
- return best_sympy.simplify()
 
846
 
847
  def best_callable(equations=None):
848
  """Return the equation with the best score, in callable format.
849
  By default this uses the last equation file.
850
  """
851
  if equations is None: equations = get_hof()
852
- return best_row(equations)['lambda_format']
 
 
853
 
854
  def _escape_filename(filename):
855
  """Turns a file into a string representation with correctly escaped backslashes"""
 
821
  """Return the best row of a hall of fame file using the score column.
822
  By default this uses the last equation file.
823
  """
824
+ if equations is None: equations = get_hof()
825
+ if isinstance(equations, list):
826
+ return [eq.iloc[np.argmax(eq['score'])] for eq in equations]
827
  else:
828
+ return equations.iloc[np.argmax(equations['score'])]
 
829
 
830
  def best_tex(equations=None):
831
  """Return the equation with the best score, in latex format
832
  By default this uses the last equation file.
833
  """
834
  if equations is None: equations = get_hof()
835
+ if isinstance(equations, list):
836
+ return [sympy.latex(best_row(eq)['sympy_format'].simplify()) for eq in equations]
837
+ else:
838
+ return sympy.latex(best_row(equations)['sympy_format'].simplify())
839
 
840
  def best(equations=None):
841
  """Return the equation with the best score, in sympy format.
842
  By default this uses the last equation file.
843
  """
844
  if equations is None: equations = get_hof()
845
+ return [best_row(eq)['sympy_format'].simplify() for eq in equations]
846
+ else:
847
+ return best_row(equations)['sympy_format'].simplify()
848
 
849
  def best_callable(equations=None):
850
  """Return the equation with the best score, in callable format.
851
  By default this uses the last equation file.
852
  """
853
  if equations is None: equations = get_hof()
854
+ return [best_row(eq)['lambda_format'] for eq in equations]
855
+ else:
856
+ return best_row(equations)['lambda_format']
857
 
858
  def _escape_filename(filename):
859
  """Turns a file into a string representation with correctly escaped backslashes"""