MilesCranmer commited on
Commit
b2fc69c
·
unverified ·
1 Parent(s): ab9ae60

Enable complex numbers

Browse files
Files changed (2) hide show
  1. pysr/sr.py +11 -1
  2. pysr/version.py +2 -2
pysr/sr.py CHANGED
@@ -498,6 +498,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
498
  What precision to use for the data. By default this is `32`
499
  (float32), but you can select `64` or `16` as well, giving
500
  you 64 or 16 bits of floating point precision, respectively.
 
 
501
  Default is `32`.
502
  random_state : int, Numpy RandomState instance or None
503
  Pass an int for reproducible results across multiple function calls.
@@ -1619,7 +1621,15 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1619
  )
1620
 
1621
  # Convert data to desired precision
1622
- np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self.precision]
 
 
 
 
 
 
 
 
1623
 
1624
  # This converts the data into a Julia array:
1625
  Main.X = np.array(X, dtype=np_dtype).T
 
498
  What precision to use for the data. By default this is `32`
499
  (float32), but you can select `64` or `16` as well, giving
500
  you 64 or 16 bits of floating point precision, respectively.
501
+ If you pass complex data, the corresponding complex precision
502
+ will be used (i.e., `64` for complex128, `32` for complex64).
503
  Default is `32`.
504
  random_state : int, Numpy RandomState instance or None
505
  Pass an int for reproducible results across multiple function calls.
 
1621
  )
1622
 
1623
  # Convert data to desired precision
1624
+ test_X = np.array(X)
1625
+ is_real = np.issubdtype(test_X.dtype, np.floating)
1626
+ is_complex = np.issubdtype(test_X.dtype, np.complexfloating)
1627
+ if is_real:
1628
+ np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self.precision]
1629
+ elif is_complex:
1630
+ np_dtype = {32: np.complex64, 64: np.complex128}[self.precision]
1631
+ else:
1632
+ np_dtype = None
1633
 
1634
  # This converts the data into a Julia array:
1635
  Main.X = np.array(X, dtype=np_dtype).T
pysr/version.py CHANGED
@@ -1,2 +1,2 @@
1
- __version__ = "0.11.17"
2
- __symbolic_regression_jl_version__ = "0.15.3"
 
1
+ __version__ = "0.12.0"
2
+ __symbolic_regression_jl_version__ = "0.16.0"