Johann Brehmer commited on
Commit
0a0cfdc
·
1 Parent(s): bdd2ad4

Deleting temporary folder optional

Browse files
Files changed (1) hide show
  1. pysr/sr.py +61 -53
pysr/sr.py CHANGED
@@ -8,6 +8,7 @@ import sympy
8
  from sympy import sympify, Symbol, lambdify
9
  import subprocess
10
  import tempfile
 
11
  from pathlib import Path
12
 
13
  global_equation_file = 'hall_of_fame.csv'
@@ -97,6 +98,8 @@ def pysr(X=None, y=None, weights=None,
97
  limitPowComplexity=False, #deprecated
98
  threads=None, #deprecated
99
  julia_optimization=3,
 
 
100
  ):
101
  """Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
102
  Note: most default parameters have been tuned over several example
@@ -180,6 +183,8 @@ def pysr(X=None, y=None, weights=None,
180
  and use that instead of parsimony to explore equation space. Will
181
  naturally find equations of all complexities.
182
  :param julia_optimization: int, Optimization level (0, 1, 2, 3)
 
 
183
  :returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
184
  (as strings).
185
 
@@ -396,59 +401,62 @@ const weights = convert(Array{Float32, 1}, """f"{weight_str})"
396
  const varMap = {'["' + '", "'.join(variable_names) + '"]'}"""
397
 
398
  # Get temporary directory in a system-independent way
399
- with tempfile.TemporaryDirectory() as tmpdirname:
400
- tmpdir = Path(tmpdirname)
401
- hyperparam_filename = str(tmpdir / f'.hyperparams_{rand_string}.jl')
402
- dataset_filename = str(tmpdir / f'.dataset_{rand_string}.jl')
403
- runfile_filename = str(tmpdir / f'.runfile_{rand_string}.jl')
404
-
405
- print(tmpdir)
406
-
407
- with open(hyperparam_filename, 'w') as f:
408
- print(def_hyperparams, file=f)
409
-
410
- with open(dataset_filename, 'w') as f:
411
- print(def_datasets, file=f)
412
-
413
- with open(tmpdir / f'.runfile_{rand_string}.jl', 'w') as f:
414
- print(f'@everywhere include("{hyperparam_filename}")', file=f)
415
- print(f'@everywhere include("{dataset_filename}")', file=f)
416
- print(f'@everywhere include("{pkg_directory}/sr.jl")', file=f)
417
- print(f'fullRun({niterations:d}, npop={npop:d}, ncyclesperiteration={ncyclesperiteration:d}, fractionReplaced={fractionReplaced:f}f0, verbosity=round(Int32, {verbosity:f}), topn={topn:d})', file=f)
418
- print(f'rmprocs(nprocs)', file=f)
419
-
420
-
421
- command = [
422
- f'julia', f'-O{julia_optimization:d}',
423
- f'-p', f'{procs}',
424
- runfile_filename,
425
- ]
426
- if timeout is not None:
427
- command = [f'timeout', f'{timeout}'] + command
428
-
429
- global global_n_features
430
- global global_equation_file
431
- global global_variable_names
432
- global global_extra_sympy_mappings
433
-
434
- global_n_features = X.shape[1]
435
- global_equation_file = equation_file
436
- global_variable_names = variable_names
437
- global_extra_sympy_mappings = extra_sympy_mappings
438
-
439
- print("Running on", ' '.join(command))
440
- process = subprocess.Popen(command, stdout=subprocess.PIPE, bufsize=1)
441
- try:
442
- while True:
443
- line = process.stdout.readline()
444
- if not line: break
445
- print(line.decode('utf-8').replace('\n', ''))
446
-
447
- process.stdout.close()
448
- process.wait()
449
- except KeyboardInterrupt:
450
- print("Killing process... will return when done.")
451
- process.kill()
 
 
 
452
 
453
  return get_hof()
454
 
 
8
  from sympy import sympify, Symbol, lambdify
9
  import subprocess
10
  import tempfile
11
+ import shutil
12
  from pathlib import Path
13
 
14
  global_equation_file = 'hall_of_fame.csv'
 
98
  limitPowComplexity=False, #deprecated
99
  threads=None, #deprecated
100
  julia_optimization=3,
101
+ tempdir=None,
102
+ delete_tempfiles=False,
103
  ):
104
  """Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
105
  Note: most default parameters have been tuned over several example
 
183
  and use that instead of parsimony to explore equation space. Will
184
  naturally find equations of all complexities.
185
  :param julia_optimization: int, Optimization level (0, 1, 2, 3)
186
+ :param tempdir: str or None, directory for the temporary files
187
+ :param delete_tempfiles: bool, whether to delete the temporary files after finishing
188
  :returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
189
  (as strings).
190
 
 
401
  const varMap = {'["' + '", "'.join(variable_names) + '"]'}"""
402
 
403
  # Get temporary directory in a system-independent way
404
+ tmpdirname = tempfile.mkdtemp(dir=tempdir)
405
+ #with tempfile.TemporaryDirectory(dir=tempdir) as tmpdirname:
406
+ tmpdir = Path(tmpdirname)
407
+
408
+ hyperparam_filename = str(tmpdir / f'.hyperparams_{rand_string}.jl')
409
+ dataset_filename = str(tmpdir / f'.dataset_{rand_string}.jl')
410
+ runfile_filename = str(tmpdir / f'.runfile_{rand_string}.jl')
411
+
412
+ with open(hyperparam_filename, 'w') as f:
413
+ print(def_hyperparams, file=f)
414
+
415
+ with open(dataset_filename, 'w') as f:
416
+ print(def_datasets, file=f)
417
+
418
+ with open(tmpdir / f'.runfile_{rand_string}.jl', 'w') as f:
419
+ print(f'@everywhere include("{hyperparam_filename}")', file=f)
420
+ print(f'@everywhere include("{dataset_filename}")', file=f)
421
+ print(f'@everywhere include("{pkg_directory}/sr.jl")', file=f)
422
+ print(f'fullRun({niterations:d}, npop={npop:d}, ncyclesperiteration={ncyclesperiteration:d}, fractionReplaced={fractionReplaced:f}f0, verbosity=round(Int32, {verbosity:f}), topn={topn:d})', file=f)
423
+ print(f'rmprocs(nprocs)', file=f)
424
+
425
+
426
+ command = [
427
+ f'julia', f'-O{julia_optimization:d}',
428
+ f'-p', f'{procs}',
429
+ runfile_filename,
430
+ ]
431
+ if timeout is not None:
432
+ command = [f'timeout', f'{timeout}'] + command
433
+
434
+ global global_n_features
435
+ global global_equation_file
436
+ global global_variable_names
437
+ global global_extra_sympy_mappings
438
+
439
+ global_n_features = X.shape[1]
440
+ global_equation_file = equation_file
441
+ global_variable_names = variable_names
442
+ global_extra_sympy_mappings = extra_sympy_mappings
443
+
444
+ print("Running on", ' '.join(command))
445
+ process = subprocess.Popen(command, stdout=subprocess.PIPE, bufsize=1)
446
+ try:
447
+ while True:
448
+ line = process.stdout.readline()
449
+ if not line: break
450
+ print(line.decode('utf-8').replace('\n', ''))
451
+
452
+ process.stdout.close()
453
+ process.wait()
454
+ except KeyboardInterrupt:
455
+ print("Killing process... will return when done.")
456
+ process.kill()
457
+
458
+ if delete_tempfiles:
459
+ shutil.rmtree(tmpdir)
460
 
461
  return get_hof()
462