helboukkouri's picture
adjust plot size
1862393
raw
history blame contribute delete
No virus
6.24 kB
import gradio as gr
import numpy as np
import sympy as sp
import seaborn as sns
from matplotlib import pyplot as plt
# Global seaborn theme: applied to every figure created after this point.
sns.set_style("darkgrid")
sns.set_context("notebook", font_scale=0.7)

# "Noise Level" slider: upper bound, initial value and increment.
MAX_NOISE = 20
DEFAULT_NOISE = 6
SLIDE_NOISE_STEP = 2

# "Number of Points" slider: upper bound, initial value and increment.
# MAX_POINTS is also the resolution of the dense noise-free curve.
MAX_POINTS = 100
DEFAULT_POINTS = 20
SLIDE_POINTS_STEP = 5
def generate_equation(process_params):
    """Render the current polynomial coefficients as Markdown with LaTeX.

    Args:
        process_params: Coefficient table with a pandas-style
            ``.astype(float).values`` interface. Only its first row is read,
            ordered from the ``x**4`` coefficient down to the constant term.
            Values beyond the fifth are ignored (``zip`` truncates).

    Returns:
        The string ``"Underlying process $$<latex>$$"``.
    """
    coefficient_values = process_params.astype(float).values.tolist()[0]

    x = sp.symbols('x')
    coefficients = sp.symbols('a b c d e')

    # a*x**4 + b*x**3 + c*x**2 + d*x + e. Summing from an explicit SymPy zero
    # avoids relying on the truth value of SymPy expressions to spot the
    # first term.
    polynomial_expression = sum(
        (coef * x**power for power, coef in enumerate(reversed(coefficients))),
        sp.S.Zero,
    )

    # Map each symbolic coefficient to its numeric value, then substitute.
    parameters = dict(zip(coefficients, coefficient_values))
    polynomial_with_values = polynomial_expression.subs(parameters)

    latex_representation = sp.latex(polynomial_with_values)
    return fr"Underlying process $${latex_representation}$$"
def true_process(x, process_params):
    """The true process we want to model.

    Evaluates the quartic ``c4*x**4 + c3*x**3 + c2*x**2 + c1*x + c0``.
    The coefficients are read, highest power first, from the first row of
    ``process_params`` (a pandas DataFrame); columns past the fifth are
    ignored. ``x`` may be a scalar or a NumPy array, because only ``*``,
    ``+`` and ``**`` are applied to it.
    """
    first_row = process_params.astype(float).values.tolist()[0]
    c4, c3, c2, c1, c0 = (first_row[idx] for idx in range(5))
    return c4 * (x ** 4) + c3 * (x ** 3) + c2 * (x ** 2) + c1 * x + c0
def generate_data(num_points, noise_level, process_params):
    """Sample the underlying process and a noisy observation of it.

    Args:
        num_points: Number of evenly spaced inputs in [-5, 2]. Coerced to
            ``int`` because ``np.linspace`` rejects non-integral counts
            (e.g. a float delivered by a UI slider).
        noise_level: Standard deviation of the zero-mean Gaussian noise added
            to each sampled value.
        process_params: Coefficient table forwarded to ``true_process``.

    Returns:
        Tuple ``(x, x_dense, y, y_dense, y_noisy)``:
        - ``x``: the ``num_points`` sample inputs.
        - ``x_dense``: ``MAX_POINTS`` inputs over the same range, used to draw
          a smooth curve.
        - ``y`` / ``y_dense``: the noise-free process values at those inputs.
        - ``y_noisy``: ``y`` plus fresh Gaussian noise (new on every call).
    """
    num_points = int(num_points)

    # x is the list of input values
    input_values = np.linspace(-5, 2, num_points)
    input_values_dense = np.linspace(-5, 2, MAX_POINTS)

    # y = f(x) is the underlying process we want to model. true_process only
    # uses +, * and **, so it evaluates a whole array in one call. This
    # converts the coefficient table once instead of once per point.
    y = true_process(input_values, process_params)
    y_dense = true_process(input_values_dense, process_params)

    # however, we can only observe a noisy version of f(x)
    noise = np.random.normal(0, noise_level, len(input_values))
    y_noisy = y + noise
    return input_values, input_values_dense, y, y_dense, y_noisy
def make_plot(
    num_points, noise_level, process_params,
    show_true_process, show_original_points, show_added_noise, show_noisy_points,
):
    """Draw the true process, the sampled points and their noisy versions.

    Args:
        num_points: Number of sample points (forwarded to ``generate_data``).
        noise_level: Standard deviation of the added noise (forwarded to
            ``generate_data``).
        process_params: Polynomial coefficient table (forwarded to
            ``generate_data``).
        show_true_process: Draw the dense noise-free curve.
        show_original_points: Draw the noise-free samples as white-filled
            markers.
        show_added_noise: Draw a dashed vertical segment from each noise-free
            sample to its noisy counterpart.
        show_noisy_points: Draw the noisy samples as filled markers.

    Returns:
        A matplotlib ``Figure``. It is already closed in pyplot, so repeated
        callbacks do not accumulate open figures. The caller can still save
        or render it.
    """
    x, x_dense, y, y_dense, y_noisy = generate_data(num_points, noise_level, process_params)

    # Draw on this figure's own Axes rather than pyplot's implicit "current"
    # figure, which is global state shared by every callback.
    fig, ax = plt.subplots(dpi=300)

    if show_true_process:
        ax.plot(
            x_dense, y_dense, "-", color="#363A4F",
            label="True Process",
            lw=1.5,
        )
    if show_added_noise:
        ax.vlines(
            x, y, y_noisy, color="#556D9A",
            linestyles="dashed",
            alpha=0.75,
            lw=1,
            label="Added Noise",
        )
    if show_original_points:
        # "-o" with color="none" hides the connecting line: markers only.
        ax.plot(
            x, y, "-o", color="none",
            ms=6,
            markerfacecolor="white",
            markeredgecolor="#556D9A",
            markeredgewidth=1.2,
            label="Original Points",
        )
    if show_noisy_points:
        ax.plot(
            x, y_noisy, "-o", color="none",
            ms=6.5,
            markerfacecolor="#556D9A",
            markeredgecolor="none",
            markeredgewidth=1.5,
            alpha=1,
            label="Noisy Points",
        )

    ax.set_xlabel("\nx")
    ax.set_ylabel("y")
    ax.legend(fontsize=7.5)
    fig.tight_layout()

    # No plt.show() in a server callback: it is a no-op or opens a blocking
    # window. Closing drops pyplot's reference to the figure, which would
    # otherwise leak one figure per redraw. The Figure object stays usable.
    plt.close(fig)
    return fig
# Constrain the app container to min(1000px, 50% of its parent's width), but
# never narrower than 800px (min-width takes precedence over width). The
# parent is assumed to be a flex container with column direction.
# `.main-plot` matches the elem_classes of the gr.Plot component below and is
# currently an empty placeholder rule.
css = """
.gradio-container {
width: min(1000px, 50%)!important;
min-width: 800px;
}
.main-plot {
}
"""
# UI layout and event wiring.
with gr.Blocks(css=css) as demo:
    # Coefficient editor with the rendered equation underneath.
    with gr.Row():
        with gr.Column():
            with gr.Row():
                # One-row table of polynomial coefficients (x**4 down to the
                # constant term). type="pandas" hands it to callbacks as a
                # pandas DataFrame.
                process_params = gr.DataFrame(
                    value=[[0.5, 2, -0.5, -2, 1]],
                    label="Underlying Process Coefficients",
                    type="pandas",
                    # NOTE(review): these widths have no CSS unit and "1w" looks
                    # like a typo; gradio documents values such as "100px" or
                    # "20%" -- confirm the intended proportions.
                    column_widths=("2", "1", "1", "1", "1w"),
                    headers=["x ** 4", "x ** 3", "x ** 2", "x", "1"],
                    interactive=True
                )
            # LaTeX rendering of the polynomial; filled by generate_equation.
            equation = gr.Markdown()
    # Sampling controls, side by side.
    with gr.Row():
        with gr.Column():
            num_points = gr.Slider(
                minimum=5,
                maximum=MAX_POINTS,
                value=DEFAULT_POINTS,
                step=SLIDE_POINTS_STEP,
                label="Number of Points"
            )
        with gr.Column():
            noise_level = gr.Slider(
                minimum=0,
                maximum=MAX_NOISE,
                value=DEFAULT_NOISE,
                step=SLIDE_NOISE_STEP,
                label="Noise Level"
            )
    # Visibility toggles. Their order must match the show_* parameters of
    # make_plot, because the list is unpacked positionally into its inputs.
    show_params = []
    with gr.Row():
        with gr.Column():
            show_params.append(gr.Checkbox(label="Show Underlying Process", value=True))
            show_params.append(gr.Checkbox(label="Show Original Points", value=True))
        with gr.Column():
            show_params.append(gr.Checkbox(label="Show Added Noise", value=True))
            show_params.append(gr.Checkbox(label="Show Noisy Points", value=True))
    scatter_plot = gr.Plot(elem_classes=["main-plot"])
    # Redraw on any input change. generate_data samples new noise on every
    # call, so each redraw shows a fresh noise realisation.
    num_points.change(fn=make_plot, inputs=[num_points, noise_level, process_params, *show_params], outputs=scatter_plot)
    noise_level.change(fn=make_plot, inputs=[num_points, noise_level, process_params, *show_params], outputs=scatter_plot)
    process_params.change(fn=make_plot, inputs=[num_points, noise_level, process_params, *show_params], outputs=scatter_plot)
    # Keep the rendered equation in sync with the coefficient table.
    process_params.change(fn=generate_equation, inputs=[process_params], outputs=equation)
    for component in show_params:
        component.change(fn=make_plot, inputs=[num_points, noise_level, process_params, *show_params], outputs=scatter_plot)
    # Initial render of the plot and the equation when the page loads.
    demo.load(fn=make_plot, inputs=[num_points, noise_level, process_params, *show_params], outputs=scatter_plot)
    demo.load(fn=generate_equation, inputs=[process_params], outputs=equation)

if __name__ == "__main__":
    demo.launch()