MilesCranmer committed on
Commit
a2492c3
1 Parent(s): fd28328

Better support for live reading and predictions

Browse files
Files changed (1) hide show
  1. gui/processing.py +80 -82
gui/processing.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import tempfile
4
  import time
5
  from pathlib import Path
 
6
 
7
  import pandas as pd
8
  from data import generate_data, read_csv
@@ -37,8 +38,6 @@ def pysr_fit(queue: mp.Queue, out_queue: mp.Queue):
37
 
38
 
39
  def pysr_predict(queue: mp.Queue, out_queue: mp.Queue):
40
- import numpy as np
41
-
42
  import pysr
43
 
44
  while True:
@@ -49,7 +48,7 @@ def pysr_predict(queue: mp.Queue, out_queue: mp.Queue):
49
 
50
  X = args["X"]
51
  equation_file = str(args["equation_file"])
52
- complexity = args["complexity"]
53
 
54
  equation_file_pkl = equation_file.replace(".csv", ".pkl")
55
  equation_file_bkup = equation_file + ".bkup"
@@ -66,31 +65,29 @@ def pysr_predict(queue: mp.Queue, out_queue: mp.Queue):
66
  except pd.errors.EmptyDataError:
67
  continue
68
 
69
- index = np.abs(model.equations_.complexity - complexity).argmin
70
  ypred = model.predict(X, index)
71
 
72
- out_queue.put(ypred)
 
73
 
 
 
 
74
 
75
- class PySRProcess:
76
- def __init__(self):
77
- self.queue = mp.Queue()
78
- self.out_queue = mp.Queue()
79
- self.process = mp.Process(target=pysr_fit, args=(self.queue, self.out_queue))
80
- self.process.start()
81
 
82
 
83
- class PySRReaderProcess:
84
- def __init__(self):
85
- self.queue = mp.Queue()
86
- self.out_queue = mp.Queue()
87
- self.process = mp.Process(
88
- target=pysr_predict, args=(self.queue, self.out_queue)
89
- )
90
  self.process.start()
91
 
92
 
93
  PERSISTENT_WRITER = None
 
94
 
95
 
96
  def processing(
@@ -118,9 +115,15 @@ def processing(
118
  ):
119
  """Load data, then spawn a process to run the greet function."""
120
  global PERSISTENT_WRITER
 
 
121
  if PERSISTENT_WRITER is None:
122
- print("Starting PySR process")
123
- PERSISTENT_WRITER = PySRProcess()
 
 
 
 
124
 
125
  if file_input is not None:
126
  try:
@@ -130,67 +133,62 @@ def processing(
130
  else:
131
  X, y = generate_data(test_equation, num_points, noise_level, data_seed)
132
 
133
- with tempfile.TemporaryDirectory() as tmpdirname:
134
- base = Path(tmpdirname)
135
- equation_file = base / "hall_of_fame.csv"
136
- equation_file_bkup = base / "hall_of_fame.csv.bkup"
137
- # Check if queue is empty, if not, kill the process
138
- # and start a new one
139
- if not PERSISTENT_WRITER.queue.empty():
140
- print("Restarting PySR process")
141
- if PERSISTENT_WRITER.process.is_alive():
142
- PERSISTENT_WRITER.process.terminate()
143
- PERSISTENT_WRITER.process.join()
144
-
145
- PERSISTENT_WRITER = PySRProcess()
146
- # Write these to queue instead:
147
- PERSISTENT_WRITER.queue.put(
148
- dict(
149
- X=X,
150
- y=y,
151
- kwargs=dict(
152
- niterations=niterations,
153
- maxsize=maxsize,
154
- binary_operators=binary_operators,
155
- unary_operators=unary_operators,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  equation_file=equation_file,
157
- parsimony=parsimony,
158
- populations=populations,
159
- population_size=population_size,
160
- ncycles_per_iteration=ncycles_per_iteration,
161
- elementwise_loss=elementwise_loss,
162
- adaptive_parsimony_scaling=adaptive_parsimony_scaling,
163
- optimizer_algorithm=optimizer_algorithm,
164
- optimizer_iterations=optimizer_iterations,
165
- batching=batching,
166
- batch_size=batch_size,
167
- ),
168
  )
169
- )
170
- while PERSISTENT_WRITER.out_queue.empty():
171
- if equation_file_bkup.exists():
172
- # First, copy the file to the copy file
173
- equation_file_copy = base / "hall_of_fame_copy.csv"
174
- os.system(f"cp {equation_file_bkup} {equation_file_copy}")
175
- try:
176
- equations = pd.read_csv(equation_file_copy)
177
- except pd.errors.EmptyDataError:
178
- continue
179
-
180
- # Ensure it is pareto dominated, with more complex expressions
181
- # having higher loss. Otherwise remove those rows.
182
- # TODO: Not sure why this occurs; could be the result of a late copy?
183
- equations.sort_values("Complexity", ascending=True, inplace=True)
184
- equations.reset_index(inplace=True)
185
- bad_idx = []
186
- min_loss = None
187
- for i in equations.index:
188
- if min_loss is None or equations.loc[i, "Loss"] < min_loss:
189
- min_loss = float(equations.loc[i, "Loss"])
190
- else:
191
- bad_idx.append(i)
192
- equations.drop(index=bad_idx, inplace=True)
193
-
194
- yield equations[["Complexity", "Loss", "Equation"]]
195
-
196
- time.sleep(0.1)
 
3
  import tempfile
4
  import time
5
  from pathlib import Path
6
+ from typing import Callable
7
 
8
  import pandas as pd
9
  from data import generate_data, read_csv
 
38
 
39
 
40
  def pysr_predict(queue: mp.Queue, out_queue: mp.Queue):
 
 
41
  import pysr
42
 
43
  while True:
 
48
 
49
  X = args["X"]
50
  equation_file = str(args["equation_file"])
51
+ index = args["index"]
52
 
53
  equation_file_pkl = equation_file.replace(".csv", ".pkl")
54
  equation_file_bkup = equation_file + ".bkup"
 
65
  except pd.errors.EmptyDataError:
66
  continue
67
 
 
68
  ypred = model.predict(X, index)
69
 
70
+ # Copy the columns of interest; they are renamed to capitalized headers below
71
+ equations = model.equations_[["complexity", "loss", "equation"]].copy()
72
 
73
+ # Remove any row that has worse loss than previous row:
74
+ equations = equations[equations["loss"].cummin() == equations["loss"]]
75
+ # TODO: Why is this needed? Are rows not being removed?
76
 
77
+ equations.columns = ["Complexity", "Loss", "Equation"]
78
+ out_queue.put(dict(ypred=ypred, equations=equations))
 
 
 
 
79
 
80
 
81
+ class ProcessWrapper:
82
+ def __init__(self, target: Callable[[mp.Queue, mp.Queue], None]):
83
+ self.queue = mp.Queue(maxsize=1)
84
+ self.out_queue = mp.Queue(maxsize=1)
85
+ self.process = mp.Process(target=target, args=(self.queue, self.out_queue))
 
 
86
  self.process.start()
87
 
88
 
89
  PERSISTENT_WRITER = None
90
+ PERSISTENT_READER = None
91
 
92
 
93
  def processing(
 
115
  ):
116
  """Load data, then spawn a process to run the greet function."""
117
  global PERSISTENT_WRITER
118
+ global PERSISTENT_READER
119
+
120
  if PERSISTENT_WRITER is None:
121
+ print("Starting PySR fit process")
122
+ PERSISTENT_WRITER = ProcessWrapper(pysr_fit)
123
+
124
+ if PERSISTENT_READER is None:
125
+ print("Starting PySR predict process")
126
+ PERSISTENT_READER = ProcessWrapper(pysr_predict)
127
 
128
  if file_input is not None:
129
  try:
 
133
  else:
134
  X, y = generate_data(test_equation, num_points, noise_level, data_seed)
135
 
136
+ tmpdirname = tempfile.mkdtemp()
137
+ base = Path(tmpdirname)
138
+ equation_file = base / "hall_of_fame.csv"
139
+ # Check if queue is empty, if not, kill the process
140
+ # and start a new one
141
+ if not PERSISTENT_WRITER.queue.empty():
142
+ print("Restarting PySR fit process")
143
+ if PERSISTENT_WRITER.process.is_alive():
144
+ PERSISTENT_WRITER.process.terminate()
145
+ PERSISTENT_WRITER.process.join()
146
+
147
+ PERSISTENT_WRITER = ProcessWrapper(pysr_fit)
148
+
149
+ if not PERSISTENT_READER.queue.empty():
150
+ print("Restarting PySR predict process")
151
+ if PERSISTENT_READER.process.is_alive():
152
+ PERSISTENT_READER.process.terminate()
153
+ PERSISTENT_READER.process.join()
154
+
155
+ PERSISTENT_READER = ProcessWrapper(pysr_predict)
156
+
157
+ PERSISTENT_WRITER.queue.put(
158
+ dict(
159
+ X=X,
160
+ y=y,
161
+ kwargs=dict(
162
+ niterations=niterations,
163
+ maxsize=maxsize,
164
+ binary_operators=binary_operators,
165
+ unary_operators=unary_operators,
166
+ equation_file=equation_file,
167
+ parsimony=parsimony,
168
+ populations=populations,
169
+ population_size=population_size,
170
+ ncycles_per_iteration=ncycles_per_iteration,
171
+ elementwise_loss=elementwise_loss,
172
+ adaptive_parsimony_scaling=adaptive_parsimony_scaling,
173
+ optimizer_algorithm=optimizer_algorithm,
174
+ optimizer_iterations=optimizer_iterations,
175
+ batching=batching,
176
+ batch_size=batch_size,
177
+ ),
178
+ )
179
+ )
180
+ while PERSISTENT_WRITER.out_queue.empty():
181
+ if equation_file.exists():
182
+ # Hand the current equation file to the reader process for prediction
183
+ PERSISTENT_READER.queue.put(
184
+ dict(
185
+ X=X,
186
  equation_file=equation_file,
187
+ index=-1,
188
+ )
 
 
 
 
 
 
 
 
 
189
  )
190
+ out = PERSISTENT_READER.out_queue.get()
191
+ equations = out["equations"]
192
+ yield equations[["Complexity", "Loss", "Equation"]]
193
+
194
+ time.sleep(0.1)