MilesCranmer commited on
Commit
79a7cfe
·
1 Parent(s): 1b92896

Add precision parameter

Browse files
Files changed (1) hide show
  1. pysr/sr.py +28 -11
pysr/sr.py CHANGED
@@ -132,6 +132,7 @@ def pysr(
132
  tournament_selection_p=1.0,
133
  denoise=False,
134
  Xresampled=None,
 
135
  ):
136
  """Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
137
  Note: most default parameters have been tuned over several example
@@ -250,6 +251,8 @@ def pysr(
250
  :type tournament_selection_p: float
251
  :param denoise: Whether to use a Gaussian Process to denoise the data before inputting to PySR. Can help PySR fit noisy data.
252
  :type denoise: bool
 
 
253
  :returns: Results dataframe, giving complexity, MSE, and equations (as strings), as well as functional forms. If list, each element corresponds to a dataframe of equations for each output.
254
  :type: pd.DataFrame/list
255
  """
@@ -427,6 +430,7 @@ def pysr(
427
  tournament_selection_n=tournament_selection_n,
428
  tournament_selection_p=tournament_selection_p,
429
  denoise=denoise,
 
430
  )
431
 
432
  kwargs = {**_set_paths(tempdir), **kwargs}
@@ -582,40 +586,53 @@ def _create_julia_files(
582
 
583
 
584
  def _make_datasets_julia_str(
585
- X, X_filename, weights, weights_filename, y, y_filename, multioutput, **kwargs
 
 
 
 
 
 
 
 
586
  ):
587
  def_datasets = """using DelimitedFiles"""
588
- np.savetxt(X_filename, X.astype(np.float32), delimiter=",")
 
 
 
589
  if multioutput:
590
- np.savetxt(y_filename, y.astype(np.float32), delimiter=",")
591
  else:
592
- np.savetxt(y_filename, y.reshape(-1, 1).astype(np.float32), delimiter=",")
 
593
  if weights is not None:
594
  if multioutput:
595
- np.savetxt(weights_filename, weights.astype(np.float32), delimiter=",")
596
  else:
597
  np.savetxt(
598
  weights_filename,
599
- weights.reshape(-1, 1).astype(np.float32),
600
  delimiter=",",
601
  )
 
602
  def_datasets += f"""
603
- X = copy(transpose(readdlm("{_escape_filename(X_filename)}", ',', Float32, '\\n')))"""
604
 
605
  if multioutput:
606
  def_datasets += f"""
607
- y = copy(transpose(readdlm("{_escape_filename(y_filename)}", ',', Float32, '\\n')))"""
608
  else:
609
  def_datasets += f"""
610
- y = readdlm("{_escape_filename(y_filename)}", ',', Float32, '\\n')[:, 1]"""
611
 
612
  if weights is not None:
613
  if multioutput:
614
  def_datasets += f"""
615
- weights = copy(transpose(readdlm("{_escape_filename(weights_filename)}", ',', Float32, '\\n')))"""
616
  else:
617
  def_datasets += f"""
618
- weights = readdlm("{_escape_filename(weights_filename)}", ',', Float32, '\\n')[:, 1]"""
619
  return def_datasets
620
 
621
 
 
132
  tournament_selection_p=1.0,
133
  denoise=False,
134
  Xresampled=None,
135
+ precision=32,
136
  ):
137
  """Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
138
  Note: most default parameters have been tuned over several example
 
251
  :type tournament_selection_p: float
252
  :param denoise: Whether to use a Gaussian Process to denoise the data before inputting to PySR. Can help PySR fit noisy data.
253
  :type denoise: bool
254
+ :param precision: What precision to use for the data. By default this is 32 (float32), but you can select 64 or 16 as well.
255
+ :type precision: int
256
  :returns: Results dataframe, giving complexity, MSE, and equations (as strings), as well as functional forms. If list, each element corresponds to a dataframe of equations for each output.
257
  :type: pd.DataFrame/list
258
  """
 
430
  tournament_selection_n=tournament_selection_n,
431
  tournament_selection_p=tournament_selection_p,
432
  denoise=denoise,
433
+ precision=precision,
434
  )
435
 
436
  kwargs = {**_set_paths(tempdir), **kwargs}
 
586
 
587
 
588
  def _make_datasets_julia_str(
589
+ X,
590
+ X_filename,
591
+ weights,
592
+ weights_filename,
593
+ y,
594
+ y_filename,
595
+ multioutput,
596
+ precision,
597
+ **kwargs,
598
  ):
599
  def_datasets = """using DelimitedFiles"""
600
+ julia_dtype = {16: "Float16", 32: "Float32", 64: "Float64"}[precision]
601
+ np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[precision]
602
+
603
+ np.savetxt(X_filename, X.astype(np_dtype), delimiter=",")
604
  if multioutput:
605
+ np.savetxt(y_filename, y.astype(np_dtype), delimiter=",")
606
  else:
607
+ np.savetxt(y_filename, y.reshape(-1, 1).astype(np_dtype), delimiter=",")
608
+
609
  if weights is not None:
610
  if multioutput:
611
+ np.savetxt(weights_filename, weights.astype(np_dtype), delimiter=",")
612
  else:
613
  np.savetxt(
614
  weights_filename,
615
+ weights.reshape(-1, 1).astype(np_dtype),
616
  delimiter=",",
617
  )
618
+
619
  def_datasets += f"""
620
+ X = copy(transpose(readdlm("{_escape_filename(X_filename)}", ',', {julia_dtype}, '\\n')))"""
621
 
622
  if multioutput:
623
  def_datasets += f"""
624
+ y = copy(transpose(readdlm("{_escape_filename(y_filename)}", ',', {julia_dtype}, '\\n')))"""
625
  else:
626
  def_datasets += f"""
627
+ y = readdlm("{_escape_filename(y_filename)}", ',', {julia_dtype}, '\\n')[:, 1]"""
628
 
629
  if weights is not None:
630
  if multioutput:
631
  def_datasets += f"""
632
+ weights = copy(transpose(readdlm("{_escape_filename(weights_filename)}", ',', {julia_dtype}, '\\n')))"""
633
  else:
634
  def_datasets += f"""
635
+ weights = readdlm("{_escape_filename(weights_filename)}", ',', {julia_dtype}, '\\n')[:, 1]"""
636
  return def_datasets
637
 
638