Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from test_functions.Ackley10D import * | |
from test_functions.Ackley2D import * | |
from test_functions.Ackley6D import * | |
from test_functions.HeatExchanger import * | |
from test_functions.CantileverBeam import * | |
from test_functions.Car import * | |
from test_functions.CompressionSpring import * | |
from test_functions.GKXWC1 import * | |
from test_functions.GKXWC2 import * | |
from test_functions.HeatExchanger import * | |
from test_functions.JLH1 import * | |
from test_functions.JLH2 import * | |
from test_functions.KeaneBump import * | |
from test_functions.GKXWC1 import * | |
from test_functions.GKXWC2 import * | |
from test_functions.PressureVessel import * | |
from test_functions.ReinforcedConcreteBeam import * | |
from test_functions.SpeedReducer import * | |
from test_functions.ThreeTruss import * | |
from test_functions.WeldedBeam import * | |
# Import other objective functions as needed | |
import time | |
from Rosen_PFN4BO import * | |
from PIL import Image | |
def s(input_string):
    """Identity helper: hand back ``input_string`` exactly as received.

    Not called anywhere in this file.
    """
    passthrough = input_string
    return passthrough
# Placeholder "best observed" value per problem, used until a feasible design
# is found. The loop maximizes Y (the plot shows abs() of it), so these are
# negative; they keep the convergence plot's y-range readable. Problems not
# listed fall back to -1e10.
_INITIAL_BEST = {
    "CantileverBeam.png": -82500,
    "CompressionSpring.png": -8,
    "HeatExchanger.png": -30000,
    "ThreeTruss.png": -300,
    "Reinforcement.png": -440,
    "PressureVessel.png": -40000,
    "SpeedReducer.png": -3200,
    "WeldedBeam.png": -35,
    "Car.png": -35,
}

# PFN checkpoint passed to Rosen_PFN_Parallel for acquisition inference.
_PFN_MODEL_PATH = 'final_models/model_hebo_morebudget_9_unused_features_3.pt'


def optimize(objective_function, iteration_input, progress=gr.Progress()):
    """Run PFN-based constrained Bayesian optimization (CEI) on one benchmark.

    Starts from 20 uniform random designs in the unit hypercube. Each
    iteration scores 1000 random candidates with the PFN (EI multiplied by
    every column of ``p_feas``) and appends the best-scoring one.

    Args:
        objective_function: Key into ``objective_functions`` (the gallery
            image name, e.g. ``"WeldedBeam.png"``).
        iteration_input: Number of optimization iterations (slider value).
        progress: Gradio progress tracker; its ``tqdm`` wrapper drives the
            UI progress bar.

    Returns:
        The matplotlib figure built by ``create_convergence_plot``.

    Raises:
        KeyError: If ``objective_function`` is not a registered problem.
    """
    problem = objective_functions[objective_function]
    scale = problem['scaling']
    evaluate = problem['function']
    # Slider values may arrive as floats; range() and slicing need an int.
    n_iter = int(iteration_input)

    best = torch.tensor(_INITIAL_BEST.get(objective_function, -1e10))

    # Initial random samples in [0, 1)^dim; `scale` maps them to the design domain.
    trained_X = torch.rand(20, problem['dim'])
    # Constraints (a design is feasible when all are <= 0) and objective.
    trained_gx, trained_Y = evaluate(scale(trained_X))

    convergence = []
    time_conv = []
    start_time = time.time()
    for _ in progress.tqdm(range(n_iter)):
        # (1) Random candidate pool in the unit hypercube.
        X_pen = torch.rand(1000, trained_X.shape[1])
        # (2) PFN inference: EI plus feasibility probabilities (p_feas).
        ei, p_feas = Rosen_PFN_Parallel(_PFN_MODEL_PATH,
                                        trained_X,
                                        trained_Y,
                                        trained_gx,
                                        X_pen,
                                        'power',
                                        'ei'
                                        )
        # (3) Constrained EI: multiply EI by every column of p_feas.
        CEI = ei
        for jj in range(p_feas.shape[1]):
            CEI = CEI * p_feas[:, jj]
        # (4) Append the best-scoring candidate.
        best_candidate = X_pen[torch.argmax(CEI), :].unsqueeze(0)
        trained_X = torch.cat([trained_X, best_candidate])

        # (5) Evaluate the augmented data once. The result serves both for
        # convergence tracking and as the next iteration's training data
        # (the data set was previously evaluated twice per iteration).
        # NOTE(review): assumes the benchmark functions are deterministic.
        trained_gx, trained_Y = evaluate(scale(trained_X))
        feasible = (trained_gx <= 0).all(dim=1)
        if feasible.any():
            best = torch.max(trained_Y[feasible])
        convergence.append(best.abs())
        time_conv.append(time.time() - start_time)

    total_time = time.time() - start_time
    return create_convergence_plot(objective_function, n_iter,
                                   time_conv,
                                   convergence, total_time)
def create_radar_chart(X_scaled):
    """Draw a polar ("radar") chart of the column means of ``X_scaled``.

    Args:
        X_scaled: 2-D tensor of designs (one column per variable); must
            support ``.mean(dim=0).numpy()``.

    Returns:
        The (already closed) matplotlib figure. Each spoke is labelled
        ``x<i>`` together with its mean value.
    """
    fig, ax = plt.subplots(figsize=(6, 6), subplot_kw={'polar': True})
    var_names = [f'x{idx}' for idx in range(1, X_scaled.shape[1] + 1)]
    means = X_scaled.mean(dim=0).numpy()
    spokes = np.linspace(0, 2 * np.pi, len(var_names), endpoint=False).tolist()

    # Repeat the first point so the polygon closes on itself.
    ring_values = np.concatenate((means, [means[0]]))
    ring_angles = spokes + spokes[:1]

    ax.fill(ring_angles, ring_values, color='green', alpha=0.25)
    ax.plot(ring_angles, ring_values, color='green', linewidth=2)
    ax.set_yticklabels([])
    ax.set_xticks(spokes)
    tick_text = [f'{name}\n({val:.2f})' for name, val in zip(var_names, ring_values[:-1])]
    ax.set_xticklabels(tick_text)
    ax.set_title("Selected Design", size=15, color='black', y=1.1)
    plt.close(fig)
    return fig
# Problem key -> file-name prefix of the stored GP-CBO results
# ("<prefix>_CEI_Avg_Time.pt" / "<prefix>_CEI_Avg_Obj.pt").
_GP_DATA_PREFIX = {
    "CantileverBeam.png": "CantileverBeam",
    "CompressionSpring.png": "CompressionSpring",
    "HeatExchanger.png": "HeatExchanger",
    "ThreeTruss.png": "ThreeTruss",
    "Reinforcement.png": "ReinforcedConcreteBeam",
    "PressureVessel.png": "PressureVessel",
    "SpeedReducer.png": "SpeedReducer",
    "WeldedBeam.png": "WeldedBeam",
    "Car.png": "Car",
}

# Problem key -> value drawn as the dashed "Optimal Value" reference line.
_OPTIMAL_VALUE = {
    "CantileverBeam.png": 50000,
    "CompressionSpring.png": 0,
    "HeatExchanger.png": 4700,
    "ThreeTruss.png": 262,
    "Reinforcement.png": 355,
    "PressureVessel.png": 5000,
    "SpeedReducer.png": 2650,
    "WeldedBeam.png": 3.3,
    "Car.png": 25,
}


def create_convergence_plot(objective_function, iteration_input, time_conv, convergence, TOTAL_TIME):
    """Plot best objective vs. wall-clock time for PFN-CBO against stored GP-CBO data.

    Args:
        objective_function: Problem key (gallery image name).
        iteration_input: Number of iterations; also truncates the GP curve.
        time_conv: Elapsed seconds after each PFN-CBO iteration.
        convergence: Best objective value recorded after each iteration.
        TOTAL_TIME: Total run time. Currently unused; kept for caller
            compatibility.

    Returns:
        The (already closed) matplotlib figure.

    Problems with no stored GP data or reference value are plotted without
    those curves instead of failing.
    """
    n_iter = int(iteration_input)  # slicing requires an int
    fig, ax = plt.subplots()

    # Realtime optimization data.
    ax.plot(time_conv, convergence, '^-', label='PFN-CBO (Realtime)')

    # Stored GP data.
    prefix = _GP_DATA_PREFIX.get(objective_function)
    if prefix is not None:
        gp_time = torch.load(f'{prefix}_CEI_Avg_Time.pt')
        gp_obj = torch.load(f'{prefix}_CEI_Avg_Obj.pt')
        ax.plot(gp_time[:n_iter], gp_obj[:n_iter], '^-', label='GP-CBO (Data)')

    ax.set_xlabel('Time (seconds)')
    ax.set_ylabel('Objective Value (Minimization)')
    ax.set_title('Convergence Plot for {t} iterations'.format(t=n_iter))

    optimal = _OPTIMAL_VALUE.get(objective_function)
    if optimal is not None:
        ax.axhline(y=optimal, color='red', linestyle='--', label='Optimal Value')
    ax.legend(loc='best')

    # Centered notice when there is nothing to show.
    # NOTE(review): optimize appends one entry per iteration even before a
    # feasible design exists, so this only triggers for zero iterations.
    if not convergence:
        ax.text(0.5, 0.5, 'No Feasible Design Found', transform=ax.transAxes, fontsize=12,
                verticalalignment='center', horizontalalignment='center')
    plt.close(fig)
    return fig
# Benchmark problems offered in the demo, keyed by gallery image file name.
# Each entry holds:
#   "image"    - image shown in the problem gallery
#   "function" - evaluates scaled designs, returning (constraints, objective);
#                optimize treats a design as feasible when all constraints <= 0
#   "scaling"  - maps unit-hypercube samples onto the problem's design domain
#   "dim"      - number of design variables
objective_functions = {
    # ThreeTruss is currently disabled:
    # "ThreeTruss.png": {"image": "ThreeTruss.png",
    #                    "function": ThreeTruss,
    #                    "scaling": ThreeTruss_Scaling,
    #                    "dim": 2},
    "CompressionSpring.png": {"image": "CompressionSpring.png",
                              "function": CompressionSpring,
                              "scaling": CompressionSpring_Scaling,
                              "dim": 3},
    "Reinforcement.png": {"image": "Reinforcement.png",
                          "function": ReinforcedConcreteBeam,
                          "scaling": ReinforcedConcreteBeam_Scaling,
                          "dim": 3},
    "PressureVessel.png": {"image": "PressureVessel.png",
                           "function": PressureVessel,
                           "scaling": PressureVessel_Scaling,
                           "dim": 4},
    "SpeedReducer.png": {"image": "SpeedReducer.png",
                         "function": SpeedReducer,
                         "scaling": SpeedReducer_Scaling,
                         "dim": 7},
    "WeldedBeam.png": {"image": "WeldedBeam.png",
                       "function": WeldedBeam,
                       "scaling": WeldedBeam_Scaling,
                       "dim": 4},
    "HeatExchanger.png": {"image": "HeatExchanger.png",
                          "function": HeatExchanger,
                          "scaling": HeatExchanger_Scaling,
                          "dim": 8},
    "CantileverBeam.png": {"image": "CantileverBeam.png",
                           "function": CantileverBeam,
                           "scaling": CantileverBeam_Scaling,
                           "dim": 10},
    "Car.png": {"image": "Car.png",
                "function": Car,
                "scaling": Car_Scaling,
                "dim": 11},
}

# Gallery items (the selectable problem keys), in registry order.
image_paths = list(objective_functions)
def submit_action(objective_function_choices, iteration_input):
    """Handle the "Submit" button by running the optimizer on the selected problem.

    Args:
        objective_function_choices: Selected problem key, as held in the
            hidden ``img_key`` Markdown. Empty or ``None`` when nothing has
            been selected.
        iteration_input: Iteration count from the slider.

    Returns:
        The convergence-plot figure, or ``None`` (no plot) when no problem
        is selected.
    """
    # Nothing selected yet: leave the plot empty instead of raising.
    if not objective_function_choices:
        return None
    return optimize(objective_function_choices, iteration_input)
# Handler for the "Clear" button.
def clear_output():
    """Produce the reset values for the "Clear" button's outputs.

    The tuple follows the order of the click handler's outputs:
    - an emptied gallery with no selection;
    - no convergence plot;
    - the slider back at 15;
    - an empty problem key;
    - the default formulation image.
    """
    empty_gallery = gr.update(value=[], selected=None)
    no_plot = None
    default_iterations = 15
    blank_key = gr.Markdown("")
    default_formulation = 'Formulation_default.png'
    return empty_gallery, no_plot, default_iterations, blank_key, default_formulation
def reset_gallery():
    """Return a Gradio update that repopulates the gallery with ``image_paths``."""
    refreshed = gr.update(value=image_paths)
    return refreshed
# Problem key -> formulation image shown after a gallery selection.
_FORMULATION_IMAGES = {
    "CantileverBeam.png": 'Cantilever_formulation.png',
    "CompressionSpring.png": 'Compressed_Formulation.png',
    "HeatExchanger.png": 'Heat_Formulation.png',
    "Reinforcement.png": 'Reinforce_Formulation.png',
    "PressureVessel.png": 'Pressure_Formulation.png',
    "SpeedReducer.png": 'Speed_Formulation.png',
    "WeldedBeam.png": 'Welded_Formulation.png',
    "Car.png": 'Car_Formulation_2.png',
}
# Shown initially and for any key without a dedicated formulation image.
_DEFAULT_FORMULATION = 'Formulation_default.png'

with gr.Blocks() as demo:
    # Centered title, paper link and "Get Started" instructions.
    gr.HTML(
        """
        <div style="text-align: center;">
        <p style="text-align: center; font-size:30px;"><b>
        Constrained Bayesian Optimization with Pre-trained Transformers
        </b></p>
        <p style="text-align: center; font-size:18px;"><b>
        Paper: <a href="https://arxiv.org/abs/2404.04495">
        Fast and Accurate Bayesian Optimization with Pre-trained Transformers for Constrained Engineering Problems</a>
        </b></p>
        <p style="text-align: left;font-size:18px;">
        Explore our interactive demo that uses PFN (Prior-Data Fitted Networks) for solving constrained Bayesian optimization problems!
        </p>
        <p style="text-align: left;font-size:24px;"><b>
        Get Started:
        </b> </p>
        <p style="text-align: left;font-size:18px;">
        <ol style="text-align: left;font-size:18px;text-indent: 30px;">
        <li> <b>Select a Problem:</b> Click on an image from the problem gallery to choose your objective function. </li>
        <li> <b>Set Iterations:</b> Adjust the slider to set the number of iterations for the optimization process. </li>
        <li> <b>Run Optimization:</b> Click "Submit" to start the optimization. Use "Clear" if you need to reselect your parameters. </li>
        </ol>
        </p>
        </div>
        """
    )
    # Explanation of the result panels.
    gr.HTML(
        """
        <p style="text-align: left;font-size:24px;"><b>
        Result Display:
        </b> </p>
        <p style="text-align: left;font-size:18px;">
        <ol style="text-align: left;font-size:18px;text-indent: 30px;">
        <li> <b>Panel Display:</b> Shows the problem formulation and the optimization results. </li>
        <li> <b>Convergence Plot:</b> Visualizes the best observed objective against the algorithm's runtime over the chosen iterations. </li>
        <ul>
        <li> <b>PFN-CBO:</b> Displays results from real-time optimization. </li>
        <li> <b>GP-CBO:</b> Provides pre-computed data from our past experiments, as GP real-time runs are impractical for a demo. </li>
        </ul>
        </ol>
        </p>
        """
    )

    with gr.Row():
        # Left column: inputs.
        with gr.Column(variant='compact'):
            with gr.Row():
                gr.Markdown("## Select a problem (objective): ")
                # Hidden holder for the selected problem key (gallery image name).
                img_key = gr.Markdown(value="", visible=False)
            gallery = gr.Gallery(value=image_paths, label="Objectives",
                                 object_fit='contain',
                                 columns=3, rows=3, elem_id="gallery")
            gr.Markdown("## Enter iteration Number: ")
            iteration_input = gr.Slider(label="Iterations:", minimum=15, maximum=50, step=1, value=15)
            # Row for the Clear and Submit buttons.
            with gr.Row():
                clear_button = gr.Button("Clear")
                submit_button = gr.Button("Submit", variant="primary")

        # Right column: outputs.
        with gr.Column():
            gr.Markdown("## Convergence Plot:")
            convergence_plot = gr.Plot(label="Convergence Plot")
            gr.Markdown("")
            gr.Markdown("## Problem formulation: ")
            formulation = gr.Image(value=_DEFAULT_FORMULATION, label="Eq")

    def handle_select(evt: gr.SelectData):
        """Gallery select handler: store the chosen problem key and show its formulation.

        Returns:
            ``(key, image_path)``, written to ``img_key`` and ``formulation``.
        """
        # The gallery item's original file name doubles as the problem key.
        key = evt.value['image']['orig_name']
        # .get() avoids an unbound result for keys without a formulation image.
        return key, _FORMULATION_IMAGES.get(key, _DEFAULT_FORMULATION)

    gallery.select(fn=handle_select, inputs=None, outputs=[img_key, formulation])

    submit_button.click(
        submit_action,
        inputs=[img_key, iteration_input],
        outputs=convergence_plot,
    )

    # Step 1 empties the outputs; step 2 repopulates the gallery.
    clear_button.click(
        clear_output,
        inputs=None,
        outputs=[gallery, convergence_plot, iteration_input, img_key, formulation]
    ).then(
        reset_gallery,
        inputs=None,
        outputs=gallery
    )

# share=True additionally requests a public share link when run locally.
demo.launch(share=True)