# PySR / gui/app.py
# Author: MilesCranmer
# Commit e4dfed6 (unverified): refactor(gui): put buttons in same row
from collections import Counter, OrderedDict

import gradio as gr
import numpy as np
from gradio.components.base import Component

from data import TEST_EQUATIONS
from plots import plot_example_data, plot_pareto_curve
from processing import processing, stop
class ExampleData:
    """Synthetic-data controls plus a preview plot for the "Example Data" tab.

    The preview is redrawn by ``plot_example_data`` whenever any of the four
    data controls changes, and once more when ``demo`` loads.

    NOTE: public attribute names are significant. The app flattens them into
    keyword arguments for ``processing``, so renaming one changes that call.
    """

    def __init__(self, demo: gr.Blocks) -> None:
        # First column: the preview plot.
        with gr.Column(scale=1):
            self.example_plot = gr.Plot()

        # Second column: controls that parameterise the generated data.
        with gr.Column(scale=1):
            self.test_equation = gr.Radio(
                TEST_EQUATIONS,
                label="Test Equation",
                value=TEST_EQUATIONS[0],
            )
            self.num_points = gr.Slider(
                label="Number of Data Points",
                minimum=10,
                maximum=1000,
                step=1,
                value=200,
            )
            self.noise_level = gr.Slider(
                label="Noise Level", minimum=0, maximum=1, value=0.05
            )
            self.data_seed = gr.Number(label="Random Seed", value=0)

        # Gradio hands these values to plot_example_data in this order.
        plot_inputs = [
            self.test_equation,
            self.num_points,
            self.noise_level,
            self.data_seed,
        ]
        for trigger in plot_inputs:
            trigger.change(
                plot_example_data, plot_inputs, self.example_plot, show_progress=False
            )

        # Draw an initial preview as soon as the page is opened.
        demo.load(plot_example_data, plot_inputs, self.example_plot)
class UploadData:
    """Widgets for the "Upload Data" tab: a CSV file picker and a usage hint.

    NOTE: public attribute names are significant. The app flattens them into
    keyword arguments for ``processing``.
    """

    def __init__(self) -> None:
        self.file_input = gr.File(label="Upload a CSV File")

        hint = (
            "The rightmost column of your CSV file will be used as the target variable."
        )
        self.label = gr.Markdown(hint)
class Data:
    """Group the two data-input tabs: generated example data and CSV upload."""

    def __init__(self, demo: gr.Blocks) -> None:
        # ExampleData needs `demo` to register its on-load preview.
        with gr.Tab(label="Example Data"):
            self.example_data = ExampleData(demo)
        with gr.Tab(label="Upload Data"):
            self.upload_data = UploadData()
class BasicSettings:
    """Search options shown on the "Basic Settings" tab.

    NOTE: public attribute names are significant. The app flattens them into
    keyword arguments for ``processing``.
    """

    def __init__(self) -> None:
        binary_choices = ["+", "-", "*", "/", "^", "max", "min", "mod", "cond"]
        unary_choices = [
            "sin",
            "cos",
            "tan",
            "exp",
            "log",
            "square",
            "cube",
            "sqrt",
            "abs",
            "erf",
            "relu",
            "round",
            "sign",
        ]

        self.binary_operators = gr.CheckboxGroup(
            label="Binary Operators",
            choices=binary_choices,
            value=["+", "-", "*", "/"],  # pre-selected: the four arithmetic ops
        )
        self.unary_operators = gr.CheckboxGroup(
            label="Unary Operators",
            choices=unary_choices,
            value=["sin"],
        )
        self.niterations = gr.Slider(
            label="Number of Iterations", minimum=1, maximum=1000, step=1, value=40
        )
        self.maxsize = gr.Slider(
            label="Maximum Complexity", minimum=7, maximum=100, step=1, value=20
        )
        self.parsimony = gr.Number(label="Parsimony Coefficient", value=0.0032)
class AdvancedSettings:
    """Additional search options shown on the "Advanced Settings" tab.

    NOTE: public attribute names are significant. The app flattens them into
    keyword arguments for ``processing``.
    """

    def __init__(self) -> None:
        loss_choices = [
            "L2DistLoss()",
            "L1DistLoss()",
            "LogitDistLoss()",
            "HuberLoss()",
        ]
        optimizer_choices = ["BFGS", "NelderMead"]

        self.populations = gr.Slider(
            label="Number of Populations", minimum=2, maximum=100, step=1, value=15
        )
        self.population_size = gr.Slider(
            label="Population Size", minimum=2, maximum=1000, step=1, value=33
        )
        self.ncycles_per_iteration = gr.Number(
            label="Cycles per Iteration", value=550
        )
        self.elementwise_loss = gr.Radio(
            loss_choices, label="Loss Function", value="L2DistLoss()"
        )
        self.adaptive_parsimony_scaling = gr.Number(
            label="Adaptive Parsimony Scaling", value=20.0
        )
        self.optimizer_algorithm = gr.Radio(
            optimizer_choices, label="Optimizer Algorithm", value="BFGS"
        )
        self.optimizer_iterations = gr.Slider(
            label="Optimizer Iterations", minimum=1, maximum=100, step=1, value=8
        )
        self.batching = gr.Checkbox(label="Batching", value=False)
        self.batch_size = gr.Slider(
            label="Batch Size", minimum=2, maximum=1000, step=1, value=50
        )
class GradioSettings:
    """Widgets on the "Gradio Settings" tab: plot update delay and a warnings override.

    NOTE: public attribute names are significant. The app flattens them into
    keyword arguments for ``processing``.
    """

    def __init__(self) -> None:
        # NOTE(review): the unit of this delay is not visible here; confirm in `processing`.
        self.plot_update_delay = gr.Slider(
            label="Plot Update Delay", minimum=1, maximum=100, value=3
        )
        self.force_run = gr.Checkbox(label="Ignore Warnings", value=False)
class Settings:
    """Group the configuration widgets into Basic / Advanced / Gradio tabs."""

    def __init__(self):
        with gr.Tab(label="Basic Settings"):
            self.basic_settings = BasicSettings()
        with gr.Tab(label="Advanced Settings"):
            self.advanced_settings = AdvancedSettings()
        with gr.Tab(label="Gradio Settings"):
            self.gradio_settings = GradioSettings()
class Results:
    """Output widgets: Pareto-front plot, predictions plot, equation table and messages.

    These attribute names are listed in AppInterface's ``ignore`` list, so they
    act only as outputs and are not sent to ``processing``.
    """

    def __init__(self):
        with gr.Tab(label="Pareto Front"):
            self.pareto = gr.Plot()
        with gr.Tab(label="Predictions"):
            self.predictions_plot = gr.Plot()

        # Read-only table with complexity / loss / equation columns.
        table_headers = ["complexity", "loss", "equation"]
        self.df = gr.Dataframe(
            headers=table_headers,
            datatype=["number", "number", "str"],
            column_widths=[75, 75, 200],
            wrap=True,
            interactive=False,
        )

        self.messages = gr.Textbox(value="", label="Messages", interactive=False)
def flatten_attributes(
    component_group, absolute_name: str, d: OrderedDict
) -> OrderedDict:
    """Collect every Gradio ``Component`` reachable from *component_group*.

    Public attributes (names not starting with ``_``) are walked recursively
    through ``__dict__``. Each component found is stored in *d* under its
    dotted path, for example ``"interface.settings.basic_settings.maxsize"``.

    Args:
        component_group: Root object to walk. Objects without ``__dict__``
            (str, int, numpy arrays, ...) contribute nothing.
        absolute_name: Dotted prefix for keys produced from this object.
        d: Mapping filled in place. Insertion order follows attribute
            definition order, so callers can rely on a stable ordering.

    Returns:
        *d* itself, after mutation.

    A component reachable via several paths is recorded only once, under the
    first path encountered. Duplicates are detected by identity, not ``==``.
    An overridden ``__eq__`` could otherwise drop distinct components, and for
    attributes such as numpy arrays ``elem == component`` returns an array
    whose truth value raises. Each object is walked at most once, so reference
    cycles terminate instead of recursing forever.
    """
    # Identities of everything already in d, so pre-populated entries are respected.
    seen = {id(value) for value in d.values()}
    # Identities of objects already walked (cycle / shared-subtree guard).
    visited = set()

    def _walk(group, prefix):
        if not hasattr(group, "__dict__") or id(group) in visited:
            return
        visited.add(id(group))
        for name, elem in group.__dict__.items():
            if name.startswith("_"):
                # Private attribute
                continue
            if id(elem) in seen:
                # Don't duplicate any items
                continue
            path = prefix + "." + name
            if isinstance(elem, Component):
                # Only add components to dict
                seen.add(id(elem))
                d[path] = elem
            else:
                _walk(elem, path)

    _walk(component_group, absolute_name)
    return d
class AppInterface:
    """Top-level layout of the GUI and the event wiring between its parts.

    Layout: a row with two equal-scale columns. The first holds the data
    tabs and the settings tabs. The second holds the results, followed by a
    row with the Stop and Run buttons.
    """

    def __init__(self, demo: gr.Blocks) -> None:
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Row():
                    self.data = Data(demo)
                with gr.Row():
                    self.settings = Settings()
            with gr.Column(scale=2):
                self.results = Results()
                with gr.Row():
                    with gr.Column(scale=1):
                        self.stop = gr.Button(value="Stop")
                    with gr.Column(scale=1, min_width=200):
                        self.run = gr.Button()

        results = self.results

        # Redraw the Pareto front whenever the equation table changes.
        results.df.change(
            plot_pareto_curve,
            inputs=[results.df, self.settings.basic_settings.maxsize],
            outputs=[results.pareto],
            show_progress=False,
        )

        # Output-only widgets are excluded from the values sent to `processing`.
        ignore = ["df", "predictions_plot", "pareto", "messages"]
        flat = flatten_attributes(self, "interface", OrderedDict())
        run_inputs = [
            component
            for path, component in flat.items()
            if last_part(path) not in ignore
        ]
        self.run.click(
            create_processing_function(self, ignore=ignore),
            inputs=run_inputs,
            outputs=[results.df, results.predictions_plot, results.messages],
            show_progress=True,
        )
        self.stop.click(stop)
def last_part(k: str) -> str:
    """Return the final dot-separated segment of *k* (all of *k* if it has no dot).

    >>> last_part("interface.settings.basic_settings.maxsize")
    'maxsize'
    """
    _, _, tail = k.rpartition(".")
    return tail
def create_processing_function(interface: AppInterface, ignore=()):
    """Build the generator callback wired to the Run button.

    Every Gradio component reachable from *interface* (see
    ``flatten_attributes``) is keyed by the last segment of its attribute
    path. The returned callback receives the components' values positionally,
    in that same flattened order. It forwards them to ``processing`` as
    keyword arguments and re-yields each update it produces.

    Args:
        interface: The constructed app whose components supply the inputs.
        ignore: Attribute names (last path segment) to leave out, typically
            output-only widgets.

    Returns:
        A generator function ``f(*components)``.

    Raises:
        AssertionError: If two non-ignored components share a last path
            segment, because they would collide as keyword arguments. ``f``
            itself raises AssertionError when called with a different number
            of values than there are keys.
    """
    d = flatten_attributes(interface, "interface", OrderedDict())
    keys = [name for name in map(last_part, d) if name not in ignore]

    duplicates = sorted(name for name, count in Counter(keys).items() if count > 1)
    if duplicates:
        raise AssertionError("Bad keys: " + ",".join(duplicates))

    def f(*components):
        # Explicit check instead of `assert`, so it survives `python -O`;
        # otherwise values would be silently dropped or mis-assigned.
        if len(components) != len(keys):
            raise AssertionError(
                f"Expected {len(keys)} component values, got {len(components)}"
            )
        yield from processing(**dict(zip(keys, components)))

    return f
def main():
    """Build the Gradio app and launch it with ``debug=True``."""
    with gr.Blocks(theme="default") as demo:
        # Constructing the interface registers its widgets inside `demo`;
        # the instance itself is not needed afterwards.
        AppInterface(demo)
    demo.launch(debug=True)
# Launch the GUI only when this file is run directly, not when it is imported.
if __name__ == "__main__":
    main()