MilesCranmer commited on
Commit
bb76c1f
1 Parent(s): 5a5a76f

Only run PySR in another process

Browse files
Files changed (1) hide show
  1. gui/app.py +58 -43
gui/app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  import numpy as np
3
  import pandas as pd
4
- import pysr
5
  import tempfile
6
  from typing import Optional
7
 
@@ -35,26 +35,22 @@ def generate_data(s: str, num_points: int, noise_level: float):
35
  return pd.DataFrame({"x": x}), y_noisy
36
 
37
 
38
- def greet(
39
- file_obj: Optional[tempfile._TemporaryFileWrapper],
40
- test_equation: str,
41
- num_points: int,
42
- noise_level: float,
43
- niterations: int,
44
- maxsize: int,
45
- binary_operators: list,
46
- unary_operators: list,
47
- seed: int,
48
- force_run: bool,
49
  ):
50
- if file_obj is not None:
51
- if len(binary_operators) == 0 and len(unary_operators) == 0:
52
- return (
53
- empty_df,
54
- "Please select at least one operator!",
55
- )
56
  # Look at some statistics of the file:
57
- df = pd.read_csv(file_obj)
58
  if len(df) == 0:
59
  return (
60
  empty_df,
@@ -78,10 +74,44 @@ def greet(
78
  y = np.array(df[col_to_fit])
79
  X = df.drop([col_to_fit], axis=1)
80
  else:
 
81
  X, y = generate_data(test_equation, num_points, noise_level)
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  model = pysr.PySRRegressor(
84
- bumper=True,
85
  maxsize=maxsize,
86
  niterations=niterations,
87
  binary_operators=binary_operators,
@@ -94,25 +124,11 @@ def greet(
94
  )
95
  model.fit(X, y)
96
 
97
- df = model.equations_[["equation", "loss", "complexity"]]
98
  # Convert all columns to string type:
99
- df = df.astype(str)
100
- msg = (
101
- "Success!\n"
102
- f"You may run the model locally (faster) with "
103
- f"the following parameters:"
104
- + f"""
105
- model = PySRRegressor(
106
- niterations={niterations},
107
- binary_operators={str(binary_operators)},
108
- unary_operators={str(unary_operators)},
109
- maxsize={maxsize},
110
- )
111
- model.fit(X, y)"""
112
- )
113
 
114
- df.to_csv("pysr_output.csv", index=False)
115
- return df, msg
116
 
117
 
118
  def _data_layout():
@@ -218,18 +234,18 @@ def main():
218
 
219
  with gr.Column():
220
  blocks["df"] = gr.Dataframe(
221
- headers=["Equation", "Loss", "Complexity"],
222
- datatype=["str", "number", "number"],
223
  )
224
  blocks["run"] = gr.Button()
225
- blocks["error_log"] = gr.Textbox(label="Error Log")
226
 
227
  blocks["run"].click(
228
- greet,
229
  inputs=[
230
  blocks[k]
231
  for k in [
232
  "file_input",
 
233
  "test_equation",
234
  "num_points",
235
  "noise_level",
@@ -238,10 +254,9 @@ def main():
238
  "binary_operators",
239
  "unary_operators",
240
  "seed",
241
- "force_run",
242
  ]
243
  ],
244
- outputs=[blocks["df"], blocks["error_log"]],
245
  )
246
 
247
  # Any update to the equation choice will trigger a replot:
 
1
  import gradio as gr
2
  import numpy as np
3
  import pandas as pd
4
+ import multiprocessing as mp
5
  import tempfile
6
  from typing import Optional
7
 
 
35
  return pd.DataFrame({"x": x}), y_noisy
36
 
37
 
38
+ def _greet_dispatch(
39
+ file_input,
40
+ force_run,
41
+ test_equation,
42
+ num_points,
43
+ noise_level,
44
+ niterations,
45
+ maxsize,
46
+ binary_operators,
47
+ unary_operators,
48
+ seed,
49
  ):
50
+ """Load data, then spawn a process to run the greet function."""
51
+ if file_input is not None:
 
 
 
 
52
  # Look at some statistics of the file:
53
+ df = pd.read_csv(file_input)
54
  if len(df) == 0:
55
  return (
56
  empty_df,
 
74
  y = np.array(df[col_to_fit])
75
  X = df.drop([col_to_fit], axis=1)
76
  else:
77
+ # X, y = generate_data(block["test_equation"], block["num_points"], block["noise_level"])
78
  X, y = generate_data(test_equation, num_points, noise_level)
79
 
80
+ queue = mp.Queue()
81
+ process = mp.Process(
82
+ target=greet,
83
+ kwargs=dict(
84
+ X=X,
85
+ y=y,
86
+ queue=queue,
87
+ niterations=niterations,
88
+ maxsize=maxsize,
89
+ binary_operators=binary_operators,
90
+ unary_operators=unary_operators,
91
+ seed=seed,
92
+ ),
93
+ )
94
+ process.start()
95
+ output = queue.get()
96
+ process.join()
97
+ return output
98
+
99
+
100
+ def greet(
101
+ *,
102
+ queue: mp.Queue,
103
+ X,
104
+ y,
105
+ niterations: int,
106
+ maxsize: int,
107
+ binary_operators: list,
108
+ unary_operators: list,
109
+ seed: int,
110
+ ):
111
+ import pysr
112
+
113
  model = pysr.PySRRegressor(
114
+ progress=False,
115
  maxsize=maxsize,
116
  niterations=niterations,
117
  binary_operators=binary_operators,
 
124
  )
125
  model.fit(X, y)
126
 
127
+ df = model.equations_[["complexity", "loss", "equation"]]
128
  # Convert all columns to string type:
129
+ queue.put(df)
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
+ return 0
 
132
 
133
 
134
  def _data_layout():
 
234
 
235
  with gr.Column():
236
  blocks["df"] = gr.Dataframe(
237
+ headers=["complexity", "loss", "equation"],
238
+ datatype=["number", "number", "str"],
239
  )
240
  blocks["run"] = gr.Button()
 
241
 
242
  blocks["run"].click(
243
+ _greet_dispatch,
244
  inputs=[
245
  blocks[k]
246
  for k in [
247
  "file_input",
248
+ "force_run",
249
  "test_equation",
250
  "num_points",
251
  "noise_level",
 
254
  "binary_operators",
255
  "unary_operators",
256
  "seed",
 
257
  ]
258
  ],
259
+ outputs=[blocks["df"]],
260
  )
261
 
262
  # Any update to the equation choice will trigger a replot: