# Page-scrape residue, kept as a comment so the module parses:
# rosenyu's picture
# Update app.py
# 02589a4 verified
import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
from test_functions.Ackley10D import *
from test_functions.Ackley2D import *
from test_functions.Ackley6D import *
from test_functions.HeatExchanger import *
from test_functions.CantileverBeam import *
from test_functions.Car import *
from test_functions.CompressionSpring import *
from test_functions.GKXWC1 import *
from test_functions.GKXWC2 import *
from test_functions.HeatExchanger import *
from test_functions.JLH1 import *
from test_functions.JLH2 import *
from test_functions.KeaneBump import *
from test_functions.GKXWC1 import *
from test_functions.GKXWC2 import *
from test_functions.PressureVessel import *
from test_functions.ReinforcedConcreteBeam import *
from test_functions.SpeedReducer import *
from test_functions.ThreeTruss import *
from test_functions.WeldedBeam import *
# Import other objective functions as needed
import time
from Rosen_PFN4BO import *
from PIL import Image
def s(input_string):
    """Identity helper: hand back *input_string* exactly as received."""
    result = input_string
    return result
# Placeholder "best so far" value per problem, reported until the first
# feasible design is found. The loop maximizes the value returned by each test
# function and the plot shows its absolute value, hence the negative numbers
# (presumably the test functions return the negated objective -- confirm).
_INITIAL_BEST = {
    "CantileverBeam.png": -82500,
    "CompressionSpring.png": -8,
    "HeatExchanger.png": -30000,
    "ThreeTruss.png": -300,
    "Reinforcement.png": -440,
    "PressureVessel.png": -40000,
    "SpeedReducer.png": -3200,
    "WeldedBeam.png": -35,
    "Car.png": -35,
}
# Fallback placeholder for problems without an entry above.
_DEFAULT_INITIAL_BEST = -1e10
# Pre-trained PFN checkpoint used for inference on every iteration.
_PFN_MODEL_PATH = 'final_models/model_hebo_morebudget_9_unused_features_3.pt'
_NUM_INITIAL_SAMPLES = 20  # random designs evaluated before the loop starts
_NUM_CANDIDATES = 1000  # random candidates scored by the PFN per iteration


def optimize(objective_function, iteration_input, progress=gr.Progress()):
    """Run PFN-based constrained Bayesian optimization on one benchmark problem.

    Each iteration draws random candidates in the unit hypercube, scores them
    with the PFN (expected improvement times the probability of satisfying
    every constraint), evaluates the best candidate, and records the best
    feasible value found so far together with the elapsed wall-clock time.

    Args:
        objective_function: Key into ``objective_functions`` (the gallery
            image file name, e.g. ``"WeldedBeam.png"``).
        iteration_input: Number of optimization iterations to run. Coerced
            to ``int`` because it comes from a UI slider.
        progress: Gradio progress tracker. Keep the ``gr.Progress()`` default
            so Gradio injects a live tracker for ``progress.tqdm``.

    Returns:
        The matplotlib figure built by ``create_convergence_plot``.

    Raises:
        KeyError: If ``objective_function`` is not a key of
            ``objective_functions``.
    """
    problem = objective_functions[objective_function]
    scale = problem['scaling']
    evaluate = problem['function']
    n_iterations = int(iteration_input)

    # Value reported while no feasible design has been found yet.
    current_best = torch.tensor(
        _INITIAL_BEST.get(objective_function, _DEFAULT_INITIAL_BEST))

    # Initial random designs in the unit hypercube, scaled to the problem
    # domain and evaluated for constraints (gx) and objective (Y).
    trained_X = torch.rand(_NUM_INITIAL_SAMPLES, problem['dim'])
    trained_gx, trained_Y = evaluate(scale(trained_X))

    convergence = []  # best feasible value (absolute) after each iteration
    time_conv = []  # elapsed seconds after each iteration
    start_time = time.time()

    for _ in progress.tqdm(range(n_iterations)):
        # (1) Random candidate pool in the unit hypercube.
        X_pen = torch.rand(_NUM_CANDIDATES, trained_X.shape[1])

        # (2) PFN inference: expected improvement plus one feasibility
        #     probability column per constraint.
        ei, p_feas = Rosen_PFN_Parallel(_PFN_MODEL_PATH,
                                        trained_X,
                                        trained_Y,
                                        trained_gx,
                                        X_pen,
                                        'power',
                                        'ei'
                                        )

        # (3) Constrained EI: EI times every constraint's feasibility
        #     probability, multiplied column by column.
        cei = ei
        for jj in range(p_feas.shape[1]):
            cei = cei * p_feas[:, jj]

        # (4) Append the highest-scoring candidate to the design set.
        best_candidate = X_pen[torch.argmax(cei), :].unsqueeze(0)
        trained_X = torch.cat([trained_X, best_candidate])

        # (5) Evaluate the grown design set once. The result feeds both the
        #     convergence tracking below and the PFN on the next iteration.
        #     The original code evaluated the identical set twice per
        #     iteration.
        trained_gx, trained_Y = evaluate(scale(trained_X))

        # (6) Best objective among designs that satisfy all constraints
        #     (gx <= 0). Points are only ever added, so once a feasible design
        #     exists it stays in the set.
        feasible = (trained_gx <= 0).all(dim=1)
        if feasible.any():
            current_best = torch.max(trained_Y[feasible])

        convergence.append(current_best.abs())
        time_conv.append(time.time() - start_time)

    total_time = time.time() - start_time

    # The radar chart (create_radar_chart) is currently not shown in the UI.
    return create_convergence_plot(objective_function, n_iterations,
                                   time_conv, convergence, total_time)
def create_radar_chart(X_scaled):
    """Draw a polar (radar) chart of the per-dimension mean of ``X_scaled``.

    Each spoke is one input dimension, labelled ``x<k>`` plus its mean value
    to two decimals. No visible caller uses this at present.

    Args:
        X_scaled: 2-D torch tensor of designs, one design per row.

    Returns:
        The matplotlib figure. It is closed with ``plt.close`` before being
        returned so pyplot does not keep it open.
    """
    fig, ax = plt.subplots(figsize=(6, 6), subplot_kw={'polar': True})

    n_dims = X_scaled.shape[1]
    axis_names = ['x' + str(k) for k in range(1, n_dims + 1)]
    mean_design = X_scaled.mean(dim=0).numpy()

    # Close the polygon by repeating the first vertex at the end.
    spokes = np.linspace(0, 2 * np.pi, len(axis_names), endpoint=False).tolist()
    spokes.append(spokes[0])
    radii = np.concatenate((mean_design, [mean_design[0]]))

    ax.fill(spokes, radii, color='green', alpha=0.25)
    ax.plot(spokes, radii, color='green', linewidth=2)

    ax.set_yticklabels([])
    ax.set_xticks(spokes[:-1])
    tick_text = [f'{name}\n({r:.2f})' for name, r in zip(axis_names, radii[:-1])]
    ax.set_xticklabels(tick_text)
    ax.set_title("Selected Design", size=15, color='black', y=1.1)

    plt.close(fig)
    return fig
# File-name stem of the pre-computed GP-CBO results for each problem key:
# "<stem>_CEI_Avg_Time.pt" and "<stem>_CEI_Avg_Obj.pt".
_GP_RESULT_STEM = {
    "CantileverBeam.png": "CantileverBeam",
    "CompressionSpring.png": "CompressionSpring",
    "HeatExchanger.png": "HeatExchanger",
    "ThreeTruss.png": "ThreeTruss",
    "Reinforcement.png": "ReinforcedConcreteBeam",
    "PressureVessel.png": "PressureVessel",
    "SpeedReducer.png": "SpeedReducer",
    "WeldedBeam.png": "WeldedBeam",
    "Car.png": "Car",
}

# Reference optimum drawn as a dashed red line for each problem key.
_OPTIMAL_VALUE = {
    "CantileverBeam.png": 50000,
    "CompressionSpring.png": 0,
    "HeatExchanger.png": 4700,
    "ThreeTruss.png": 262,
    "Reinforcement.png": 355,
    "PressureVessel.png": 5000,
    "SpeedReducer.png": 2650,
    "WeldedBeam.png": 3.3,
    "Car.png": 25,
}


def create_convergence_plot(objective_function, iteration_input, time_conv, convergence, TOTAL_TIME):
    """Plot best-so-far objective against wall-clock time.

    Overlays the real-time PFN-CBO run, the stored GP-CBO averages (when a
    results file is registered for the problem), and the known optimum.

    Args:
        objective_function: Problem key (gallery image name).
        iteration_input: Number of iterations. It is coerced to ``int`` and
            used to truncate the stored GP curve and in the title.
        time_conv: Elapsed seconds after each PFN-CBO iteration.
        convergence: Best feasible objective after each PFN-CBO iteration.
        TOTAL_TIME: Total run time. Accepted for interface compatibility;
            it is not used.

    Returns:
        The matplotlib figure, already closed via ``plt.close`` so pyplot
        does not keep it open.
    """
    n_iterations = int(iteration_input)
    fig, ax = plt.subplots()

    # Real-time optimization data.
    ax.plot(time_conv, convergence, '^-', label='PFN-CBO (Realtime)')

    # Stored GP data. Unknown problems simply get no GP curve instead of
    # failing on an unbound variable.
    stem = _GP_RESULT_STEM.get(objective_function)
    if stem is not None:
        gp_time = torch.load(f'{stem}_CEI_Avg_Time.pt')
        gp_obj = torch.load(f'{stem}_CEI_Avg_Obj.pt')
        ax.plot(gp_time[:n_iterations], gp_obj[:n_iterations], '^-', label='GP-CBO (Data)')

    ax.set_xlabel('Time (seconds)')
    ax.set_ylabel('Objective Value (Minimization)')
    ax.set_title('Convergence Plot for {t} iterations'.format(t=n_iterations))

    optimum = _OPTIMAL_VALUE.get(objective_function)
    if optimum is not None:
        ax.axhline(y=optimum, color='red', linestyle='--', label='Optimal Value')
    ax.legend(loc='best')

    # Message near the centre of the axes when there is no PFN-CBO data.
    # The current caller appends one entry per iteration, so this only fires
    # when zero iterations were run.
    if not convergence:
        ax.text(0.5, 0.5, 'No Feasible Design Found', transform=ax.transAxes, fontsize=12,
                verticalalignment='top', horizontalalignment='right')

    plt.close(fig)
    return fig
# Registry of the benchmark problems offered in the gallery. Each key is the
# gallery image file name. Each entry holds:
# - the image,
# - the evaluator returning (constraints, objective),
# - the scaling function mapping unit-cube samples to the problem domain,
# - the input dimension.
# ThreeTruss (function=ThreeTruss, scaling=ThreeTruss_Scaling, dim=2) is
# currently disabled and therefore not listed.
objective_functions = {
    "CompressionSpring.png": {
        "image": "CompressionSpring.png",
        "function": CompressionSpring,
        "scaling": CompressionSpring_Scaling,
        "dim": 3,
    },
    "Reinforcement.png": {
        "image": "Reinforcement.png",
        "function": ReinforcedConcreteBeam,
        "scaling": ReinforcedConcreteBeam_Scaling,
        "dim": 3,
    },
    "PressureVessel.png": {
        "image": "PressureVessel.png",
        "function": PressureVessel,
        "scaling": PressureVessel_Scaling,
        "dim": 4,
    },
    "SpeedReducer.png": {
        "image": "SpeedReducer.png",
        "function": SpeedReducer,
        "scaling": SpeedReducer_Scaling,
        "dim": 7,
    },
    "WeldedBeam.png": {
        "image": "WeldedBeam.png",
        "function": WeldedBeam,
        "scaling": WeldedBeam_Scaling,
        "dim": 4,
    },
    "HeatExchanger.png": {
        "image": "HeatExchanger.png",
        "function": HeatExchanger,
        "scaling": HeatExchanger_Scaling,
        "dim": 8,
    },
    "CantileverBeam.png": {
        "image": "CantileverBeam.png",
        "function": CantileverBeam,
        "scaling": CantileverBeam_Scaling,
        "dim": 10,
    },
    "Car.png": {
        "image": "Car.png",
        "function": Car,
        "scaling": Car_Scaling,
        "dim": 11,
    },
}

# Gallery entries: the registry keys, in insertion order.
image_paths = list(objective_functions)
def submit_action(objective_function_choices, iteration_input):
    """Gradio Submit handler: run the optimizer for the selected problem.

    Args:
        objective_function_choices: Key of the selected problem (value of the
            hidden ``img_key`` Markdown). It is empty or ``None`` when nothing
            has been selected.
        iteration_input: Number of optimization iterations from the slider.

    Returns:
        The convergence figure from ``optimize``, or ``None`` (which leaves
        the plot empty) when no problem is selected.
    """
    # Truthiness check covers both "" and None. The old len() call raised
    # TypeError on None.
    if not objective_function_choices:
        return None
    return optimize(objective_function_choices, iteration_input)
def clear_output():
    """Reset the UI to its initial state (first step of the Clear button).

    Returns one value per ``clear_button.click`` output, in order:
    - gallery: emptied, with no selection,
    - convergence_plot: cleared,
    - iteration_input: back to 15, the slider's initial value,
    - img_key: emptied selected-problem key,
    - formulation: back to the default formulation image.
    """
    gallery_reset = gr.update(value=[], selected=None)
    cleared_plot = None
    default_iterations = 15
    blank_key = gr.Markdown("")
    default_formulation = 'Formulation_default.png'
    return gallery_reset, cleared_plot, default_iterations, blank_key, default_formulation
def reset_gallery():
    """Repopulate the gallery with every problem image.

    Second step of the Clear button, run after ``clear_output`` empties it.
    """
    full_gallery = image_paths
    return gr.update(value=full_gallery)
# Problem-formulation image shown for each gallery key. Keys without an entry
# fall back to the default formulation image instead of raising.
_FORMULATION_IMAGES = {
    "CantileverBeam.png": 'Cantilever_formulation.png',
    "CompressionSpring.png": 'Compressed_Formulation.png',
    "HeatExchanger.png": 'Heat_Formulation.png',
    "Reinforcement.png": 'Reinforce_Formulation.png',
    "PressureVessel.png": 'Pressure_Formulation.png',
    "SpeedReducer.png": 'Speed_Formulation.png',
    "WeldedBeam.png": 'Welded_Formulation.png',
    "Car.png": 'Car_Formulation_2.png',
}
_DEFAULT_FORMULATION_IMAGE = 'Formulation_default.png'

with gr.Blocks() as demo:
    # Centered title, paper link, and usage instructions.
    gr.HTML(
        """
<div style="text-align: center;">
<p style="text-align: center; font-size:30px;"><b>
Constrained Bayesian Optimization with Pre-trained Transformers
</b></p>
<p style="text-align: center; font-size:18px;"><b>
Paper: <a href="https://arxiv.org/abs/2404.04495">
Fast and Accurate Bayesian Optimization with Pre-trained Transformers for Constrained Engineering Problems</a>
</b></p>
<p style="text-align: left;font-size:18px;">
Explore our interactive demo that uses PFN (Prior-Data Fitted Networks) for solving constrained Bayesian optimization problems!
</p>
<p style="text-align: left;font-size:24px;"><b>
Get Started:
</b> </p>
<p style="text-align: left;font-size:18px;">
<ol style="text-align: left;font-size:18px;text-indent: 30px;">
<li> <b>Select a Problem:</b> Click on an image from the problem gallery to choose your objective function. </li>
<li> <b>Set Iterations:</b> Adjust the slider to set the number of iterations for the optimization process. </li>
<li> <b>Run Optimization:</b> Click "Submit" to start the optimization. Use "Clear" if you need to reselect your parameters. </li>
</ol>
</p>
</div>
"""
    )
    # Explanation of the result panels.
    gr.HTML(
        """
<p style="text-align: left;font-size:24px;"><b>
Result Display:
</b> </p>
<p style="text-align: left;font-size:18px;">
<ol style="text-align: left;font-size:18px;text-indent: 30px;">
<li> <b>Panel Display:</b> Shows the problem formulation and the optimization results. </li>
<li> <b>Convergence Plot:</b> Visualizes the best observed objective against the algorithm's runtime over the chosen iterations. </li>
<ul>
<li> <b>PFN-CBO:</b> Displays results from real-time optimization. </li>
<li> <b>GP-CBO:</b> Provides pre-computed data from our past experiments, as GP real-time runs are impractical for a demo. </li>
</ul>
</ol>
</p>
"""
    )

    with gr.Row():
        # Left column: problem gallery, iteration slider, action buttons.
        with gr.Column(variant='compact'):
            with gr.Row():
                gr.Markdown("## Select a problem (objective): ")
            # Hidden holder for the selected gallery key (read by Submit).
            img_key = gr.Markdown(value="", visible=False)
            gallery = gr.Gallery(value=image_paths, label="Objectives",
                                 object_fit='contain',
                                 columns=3, rows=3, elem_id="gallery")
            gr.Markdown("## Enter iteration Number: ")
            iteration_input = gr.Slider(label="Iterations:", minimum=15, maximum=50, step=1, value=15)
            # Row for the Clear and Submit buttons.
            with gr.Row():
                clear_button = gr.Button("Clear")
                submit_button = gr.Button("Submit", variant="primary")

        # Right column: convergence plot and problem formulation.
        with gr.Column():
            gr.Markdown("""
## Convergence Plot:
""")
            convergence_plot = gr.Plot(label="Convergence Plot")
            gr.Markdown("")
            gr.Markdown("## Problem formulation: ")
            formulation = gr.Image(value='Formulation_default.png', label="Eq")

    def handle_select(evt: gr.SelectData):
        """Record the clicked gallery image's key and show its formulation.

        Returns:
            ``(key, formulation_image_path)`` for the ``img_key`` and
            ``formulation`` outputs. Keys with no registered formulation get
            the default image instead of raising UnboundLocalError.
        """
        key = evt.value['image']['orig_name']
        formulation_image = _FORMULATION_IMAGES.get(key, _DEFAULT_FORMULATION_IMAGE)
        return key, formulation_image

    gallery.select(fn=handle_select, inputs=None, outputs=[img_key, formulation])

    submit_button.click(
        submit_action,
        inputs=[img_key, iteration_input],
        outputs=convergence_plot,
    )

    # Step 1 resets every output. Step 2 repopulates the gallery, which step 1
    # emptied to drop the current selection.
    clear_button.click(
        clear_output,
        inputs=None,
        outputs=[gallery, convergence_plot, iteration_input, img_key, formulation],
    ).then(
        reset_gallery,
        inputs=None,
        outputs=gallery,
    )

demo.launch(share=True)