Spaces:
Sleeping
Sleeping
MilesCranmer
committed on
Commit
•
1efb6f4
1
Parent(s):
b158e1f
Initial working version with PyJulia
Browse files
- pysr/sr.py +62 -8
pysr/sr.py
CHANGED
@@ -12,6 +12,7 @@ from datetime import datetime
|
|
12 |
import warnings
|
13 |
from multiprocessing import cpu_count
|
14 |
|
|
|
15 |
global_state = dict(
|
16 |
equation_file="hall_of_fame.csv",
|
17 |
n_features=None,
|
@@ -132,6 +133,7 @@ def pysr(
|
|
132 |
Xresampled=None,
|
133 |
precision=32,
|
134 |
multithreading=None,
|
|
|
135 |
):
|
136 |
"""Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
|
137 |
Note: most default parameters have been tuned over several example
|
@@ -254,6 +256,8 @@ def pysr(
|
|
254 |
:type precision: int
|
255 |
:param multithreading: Use multithreading instead of distributed backend. Default is yes. Using procs=0 will turn off both.
|
256 |
:type multithreading: bool
|
|
|
|
|
257 |
:returns: Results dataframe, giving complexity, MSE, and equations (as strings), as well as functional forms. If list, each element corresponds to a dataframe of equations for each output.
|
258 |
:type: pd.DataFrame/list
|
259 |
"""
|
@@ -272,7 +276,18 @@ def pysr(
|
|
272 |
# or procs is set to 0 (serial mode).
|
273 |
multithreading = procs != 0
|
274 |
|
275 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
276 |
|
277 |
if progress is not None:
|
278 |
if progress and not buffer_available:
|
@@ -280,6 +295,11 @@ def pysr(
|
|
280 |
"Note: it looks like you are running in Jupyter. The progress bar will be turned off."
|
281 |
)
|
282 |
progress = False
|
|
|
|
|
|
|
|
|
|
|
283 |
else:
|
284 |
progress = buffer_available
|
285 |
|
@@ -321,7 +341,8 @@ def pysr(
|
|
321 |
weights,
|
322 |
y,
|
323 |
)
|
324 |
-
|
|
|
325 |
|
326 |
if len(X) > 10000 and not batching:
|
327 |
warnings.warn(
|
@@ -437,6 +458,7 @@ def pysr(
|
|
437 |
denoise=denoise,
|
438 |
precision=precision,
|
439 |
multithreading=multithreading,
|
|
|
440 |
)
|
441 |
|
442 |
kwargs = {**_set_paths(tempdir), **kwargs}
|
@@ -457,7 +479,7 @@ def pysr(
|
|
457 |
|
458 |
kwargs["need_install"] = False
|
459 |
|
460 |
-
if not (manifest_filepath).is_file():
|
461 |
kwargs["need_install"] = (not user_input) or _yesno(
|
462 |
"I will install Julia packages using PySR's Project.toml file. OK?"
|
463 |
)
|
@@ -471,10 +493,35 @@ def pysr(
|
|
471 |
|
472 |
kwargs["constraints_str"] = _make_constraints_str(**kwargs)
|
473 |
kwargs["def_hyperparams"] = _make_hyperparams_julia_str(**kwargs)
|
474 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
475 |
|
476 |
_create_julia_files(**kwargs)
|
477 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
478 |
_set_globals(**kwargs)
|
479 |
|
480 |
equations = get_hof(**kwargs)
|
@@ -558,12 +605,16 @@ def _create_julia_files(
|
|
558 |
need_install,
|
559 |
update,
|
560 |
multithreading,
|
|
|
561 |
**kwargs,
|
562 |
):
|
563 |
with open(hyperparam_filename, "w") as f:
|
564 |
print(def_hyperparams, file=f)
|
565 |
-
|
566 |
-
|
|
|
|
|
|
|
567 |
with open(runfile_filename, "w") as f:
|
568 |
if julia_project is None:
|
569 |
julia_project = pkg_directory
|
@@ -579,7 +630,10 @@ def _create_julia_files(
|
|
579 |
print(f"Pkg.update()", file=f)
|
580 |
print(f"using SymbolicRegression", file=f)
|
581 |
print(f'include("{_escape_filename(hyperparam_filename)}")', file=f)
|
582 |
-
|
|
|
|
|
|
|
583 |
if len(variable_names) == 0:
|
584 |
varMap = "[" + ",".join([f'"x{i}"' for i in range(X.shape[1])]) + "]"
|
585 |
else:
|
|
|
12 |
import warnings
|
13 |
from multiprocessing import cpu_count
|
14 |
|
15 |
+
Main = None
|
16 |
global_state = dict(
|
17 |
equation_file="hall_of_fame.csv",
|
18 |
n_features=None,
|
|
|
133 |
Xresampled=None,
|
134 |
precision=32,
|
135 |
multithreading=None,
|
136 |
+
pyjulia=False,
|
137 |
):
|
138 |
"""Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
|
139 |
Note: most default parameters have been tuned over several example
|
|
|
256 |
:type precision: int
|
257 |
:param multithreading: Use multithreading instead of distributed backend. Default is yes. Using procs=0 will turn off both.
|
258 |
:type multithreading: bool
|
259 |
+
:param pyjulia: Whether to use PyJulia instead of julia binary. PyJulia should reduce startup time for repeat calls.
|
260 |
+
:type pyjulia: bool
|
261 |
:returns: Results dataframe, giving complexity, MSE, and equations (as strings), as well as functional forms. If list, each element corresponds to a dataframe of equations for each output.
|
262 |
:type: pd.DataFrame/list
|
263 |
"""
|
|
|
276 |
# or procs is set to 0 (serial mode).
|
277 |
multithreading = procs != 0
|
278 |
|
279 |
+
# Start up Julia:
|
280 |
+
global Main
|
281 |
+
if pyjulia and Main is None:
|
282 |
+
if not multithreading:
|
283 |
+
raise AssertionError(
|
284 |
+
"PyJulia does not support multiprocessing. Turn multithreading=True."
|
285 |
+
)
|
286 |
+
|
287 |
+
os.environ["JULIA_NUM_THREADS"] = str(procs)
|
288 |
+
from julia import Main
|
289 |
+
|
290 |
+
buffer_available = "buffer" in sys.stdout.__dir__() and not pyjulia
|
291 |
|
292 |
if progress is not None:
|
293 |
if progress and not buffer_available:
|
|
|
295 |
"Note: it looks like you are running in Jupyter. The progress bar will be turned off."
|
296 |
)
|
297 |
progress = False
|
298 |
+
if progress and pyjulia:
|
299 |
+
warnings.warn(
|
300 |
+
"Note: it looks like you are using PyJulia. The progress bar will be turned off."
|
301 |
+
)
|
302 |
+
progress = False
|
303 |
else:
|
304 |
progress = buffer_available
|
305 |
|
|
|
341 |
weights,
|
342 |
y,
|
343 |
)
|
344 |
+
if not pyjulia:
|
345 |
+
_check_for_julia_installation()
|
346 |
|
347 |
if len(X) > 10000 and not batching:
|
348 |
warnings.warn(
|
|
|
458 |
denoise=denoise,
|
459 |
precision=precision,
|
460 |
multithreading=multithreading,
|
461 |
+
pyjulia=pyjulia,
|
462 |
)
|
463 |
|
464 |
kwargs = {**_set_paths(tempdir), **kwargs}
|
|
|
479 |
|
480 |
kwargs["need_install"] = False
|
481 |
|
482 |
+
if not (manifest_filepath).is_file() and not pyjulia:
|
483 |
kwargs["need_install"] = (not user_input) or _yesno(
|
484 |
"I will install Julia packages using PySR's Project.toml file. OK?"
|
485 |
)
|
|
|
493 |
|
494 |
kwargs["constraints_str"] = _make_constraints_str(**kwargs)
|
495 |
kwargs["def_hyperparams"] = _make_hyperparams_julia_str(**kwargs)
|
496 |
+
|
497 |
+
if pyjulia:
|
498 |
+
np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[precision]
|
499 |
+
|
500 |
+
Main.X = np.array(X, dtype=np_dtype).T
|
501 |
+
if len(y.shape) == 1:
|
502 |
+
Main.y = np.array(y, dtype=np_dtype)
|
503 |
+
else:
|
504 |
+
Main.y = np.array(y, dtype=np_dtype).T
|
505 |
+
if weights is not None:
|
506 |
+
if len(weights.shape) == 1:
|
507 |
+
Main.weights = np.array(weights, dtype=np_dtype)
|
508 |
+
else:
|
509 |
+
Main.weights = np.array(weights, dtype=np_dtype).T
|
510 |
+
|
511 |
+
kwargs["def_datasets"] = ""
|
512 |
+
else:
|
513 |
+
kwargs["def_datasets"] = _make_datasets_julia_str(**kwargs)
|
514 |
|
515 |
_create_julia_files(**kwargs)
|
516 |
+
if pyjulia:
|
517 |
+
# Read entire file as a single string:
|
518 |
+
with open(kwargs["runfile_filename"], "r") as f:
|
519 |
+
runfile_string = f.read()
|
520 |
+
print("Running main runfile in PyJulia!")
|
521 |
+
Main.eval(runfile_string)
|
522 |
+
else:
|
523 |
+
_final_pysr_process(**kwargs)
|
524 |
+
|
525 |
_set_globals(**kwargs)
|
526 |
|
527 |
equations = get_hof(**kwargs)
|
|
|
605 |
need_install,
|
606 |
update,
|
607 |
multithreading,
|
608 |
+
pyjulia,
|
609 |
**kwargs,
|
610 |
):
|
611 |
with open(hyperparam_filename, "w") as f:
|
612 |
print(def_hyperparams, file=f)
|
613 |
+
|
614 |
+
if not pyjulia:
|
615 |
+
with open(dataset_filename, "w") as f:
|
616 |
+
print(def_datasets, file=f)
|
617 |
+
|
618 |
with open(runfile_filename, "w") as f:
|
619 |
if julia_project is None:
|
620 |
julia_project = pkg_directory
|
|
|
630 |
print(f"Pkg.update()", file=f)
|
631 |
print(f"using SymbolicRegression", file=f)
|
632 |
print(f'include("{_escape_filename(hyperparam_filename)}")', file=f)
|
633 |
+
|
634 |
+
if not pyjulia:
|
635 |
+
print(f'include("{_escape_filename(dataset_filename)}")', file=f)
|
636 |
+
|
637 |
if len(variable_names) == 0:
|
638 |
varMap = "[" + ",".join([f'"x{i}"' for i in range(X.shape[1])]) + "]"
|
639 |
else:
|