rosenyu commited on
Commit
bbfe1cd
·
verified ·
1 Parent(s): 165ee00

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +468 -0
app.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from test_functions.Ackley10D import *
6
+ from test_functions.Ackley2D import *
7
+ from test_functions.Ackley6D import *
8
+ from test_functions.HeatExchanger import *
9
+ from test_functions.CantileverBeam import *
10
+ from test_functions.Car import *
11
+ from test_functions.CompressionSpring import *
12
+ from test_functions.GKXWC1 import *
13
+ from test_functions.GKXWC2 import *
14
+ from test_functions.HeatExchanger import *
15
+ from test_functions.JLH1 import *
16
+ from test_functions.JLH2 import *
17
+ from test_functions.KeaneBump import *
18
+ from test_functions.GKXWC1 import *
19
+ from test_functions.GKXWC2 import *
20
+ from test_functions.PressureVessel import *
21
+ from test_functions.ReinforcedConcreteBeam import *
22
+ from test_functions.SpeedReducer import *
23
+ from test_functions.ThreeTruss import *
24
+ from test_functions.WeldedBeam import *
25
+ # Import other objective functions as needed
26
+ import time
27
+
28
+ from Rosen_PFN4BO import *
29
+ from PIL import Image
30
+
31
+
32
+
33
+
34
+
35
+
36
+
37
+
38
+
39
+
40
+
41
+
42
+
43
+
44
+
45
+ def s(input_string):
46
+ return input_string
47
+
48
+
49
+
50
+
51
+ def optimize(objective_function, iteration_input, progress=gr.Progress()):
52
+
53
+ print(objective_function)
54
+
55
+ # Variable setup
56
+ Current_BEST = torch.tensor( -1e10 ) # Some arbitrary very small number
57
+ Prev_BEST = torch.tensor( -1e10 )
58
+
59
+ if objective_function=="CantileverBeam.png":
60
+ Current_BEST = torch.tensor( -82500 ) # Some arbitrary very small number
61
+ Prev_BEST = torch.tensor( -82500 )
62
+ elif objective_function=="CompressionSpring.png":
63
+ Current_BEST = torch.tensor( -8 ) # Some arbitrary very small number
64
+ Prev_BEST = torch.tensor( -8 )
65
+ elif objective_function=="HeatExchanger.png":
66
+ Current_BEST = torch.tensor( -30000 ) # Some arbitrary very small number
67
+ Prev_BEST = torch.tensor( -30000 )
68
+ elif objective_function=="ThreeTruss.png":
69
+ Current_BEST = torch.tensor( -300 ) # Some arbitrary very small number
70
+ Prev_BEST = torch.tensor( -300 )
71
+ elif objective_function=="Reinforcement.png":
72
+ Current_BEST = torch.tensor( -440 ) # Some arbitrary very small number
73
+ Prev_BEST = torch.tensor( -440 )
74
+ elif objective_function=="PressureVessel.png":
75
+ Current_BEST = torch.tensor( -40000 ) # Some arbitrary very small number
76
+ Prev_BEST = torch.tensor( -40000 )
77
+ elif objective_function=="SpeedReducer.png":
78
+ Current_BEST = torch.tensor( -3200 ) # Some arbitrary very small number
79
+ Prev_BEST = torch.tensor( -3200 )
80
+ elif objective_function=="WeldedBeam.png":
81
+ Current_BEST = torch.tensor( -35 ) # Some arbitrary very small number
82
+ Prev_BEST = torch.tensor( -35 )
83
+ elif objective_function=="Car.png":
84
+ Current_BEST = torch.tensor( -35 ) # Some arbitrary very small number
85
+ Prev_BEST = torch.tensor( -35 )
86
+
87
+ # Initial random samples
88
+ # print(objective_functions)
89
+ trained_X = torch.rand(20, objective_functions[objective_function]['dim'])
90
+
91
+ # Scale it to the domain of interest using the selected function
92
+ # print(objective_function)
93
+ X_Scaled = objective_functions[objective_function]['scaling'](trained_X)
94
+
95
+ # Get the constraints and objective
96
+ trained_gx, trained_Y = objective_functions[objective_function]['function'](X_Scaled)
97
+
98
+ # Convergence list to store best values
99
+ convergence = []
100
+ time_conv = []
101
+
102
+ START_TIME = time.time()
103
+
104
+
105
+ # with gr.Progress(track_tqdm=True) as progress:
106
+
107
+
108
+ # Optimization Loop
109
+ for ii in progress.tqdm(range(iteration_input)): # Example with 100 iterations
110
+
111
+ # (0) Get the updated data for this iteration
112
+ X_scaled = objective_functions[objective_function]['scaling'](trained_X)
113
+ trained_gx, trained_Y = objective_functions[objective_function]['function'](X_scaled)
114
+
115
+ # (1) Randomly sample Xpen
116
+ X_pen = torch.rand(1000,trained_X.shape[1])
117
+
118
+ # (2) PFN inference phase with EI
119
+ default_model = 'final_models/model_hebo_morebudget_9_unused_features_3.pt'
120
+
121
+ ei, p_feas = Rosen_PFN_Parallel(default_model,
122
+ trained_X,
123
+ trained_Y,
124
+ trained_gx,
125
+ X_pen,
126
+ 'power',
127
+ 'ei'
128
+ )
129
+
130
+ # Calculating CEI
131
+ CEI = ei
132
+ for jj in range(p_feas.shape[1]):
133
+ CEI = CEI*p_feas[:,jj]
134
+
135
+ # (4) Get the next search value
136
+ rec_idx = torch.argmax(CEI)
137
+ best_candidate = X_pen[rec_idx,:].unsqueeze(0)
138
+
139
+ # (5) Append the next search point
140
+ trained_X = torch.cat([trained_X, best_candidate])
141
+
142
+
143
+ ################################################################################
144
+ # This is just for visualizing the best value.
145
+ # This section can be remove for pure optimization purpose
146
+ Current_X = objective_functions[objective_function]['scaling'](trained_X)
147
+ Current_GX, Current_Y = objective_functions[objective_function]['function'](Current_X)
148
+ if ((Current_GX<=0).all(dim=1)).any():
149
+ Current_BEST = torch.max(Current_Y[(Current_GX<=0).all(dim=1)])
150
+ else:
151
+ Current_BEST = Prev_BEST
152
+ ################################################################################
153
+
154
+ # (ii) Convergence tracking (assuming the best Y is to be maximized)
155
+ # if Current_BEST != -1e10:
156
+ print(Current_BEST)
157
+ print(convergence)
158
+ convergence.append(Current_BEST.abs())
159
+ time_conv.append(time.time() - START_TIME)
160
+
161
+ # Timing
162
+ END_TIME = time.time()
163
+ TOTAL_TIME = END_TIME - START_TIME
164
+
165
+ # Website visualization
166
+ # (i) Radar chart for trained_X
167
+ radar_chart = None
168
+ # radar_chart = create_radar_chart(X_scaled)
169
+ # (ii) Convergence tracking (assuming the best Y is to be maximized)
170
+ convergence_plot = create_convergence_plot(objective_function, iteration_input,
171
+ time_conv,
172
+ convergence, TOTAL_TIME)
173
+
174
+
175
+ return convergence_plot
176
+ # return radar_chart, convergence_plot
177
+
178
+
179
+
180
+
181
+
182
+
183
+
184
+
185
+
186
+ def create_radar_chart(X_scaled):
187
+ fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
188
+ labels = [f'x{i+1}' for i in range(X_scaled.shape[1])]
189
+ values = X_scaled.mean(dim=0).numpy()
190
+
191
+ num_vars = len(labels)
192
+ angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
193
+ values = np.concatenate((values, [values[0]]))
194
+ angles += angles[:1]
195
+
196
+ ax.fill(angles, values, color='green', alpha=0.25)
197
+ ax.plot(angles, values, color='green', linewidth=2)
198
+ ax.set_yticklabels([])
199
+ ax.set_xticks(angles[:-1])
200
+ # ax.set_xticklabels(labels)
201
+ ax.set_xticklabels([f'{label}\n({value:.2f})' for label, value in zip(labels, values[:-1])]) # Show values
202
+ ax.set_title("Selected Design", size=15, color='black', y=1.1)
203
+
204
+ plt.close(fig)
205
+ return fig
206
+
207
+
208
+
209
+
210
+
211
+
212
+
213
+ def create_convergence_plot(objective_function, iteration_input, time_conv, convergence, TOTAL_TIME):
214
+ fig, ax = plt.subplots()
215
+
216
+ # Realtime optimization data
217
+ ax.plot(time_conv, convergence, '^-', label='PFN-CBO (Realtime)' )
218
+
219
+ # Stored GP data
220
+ if objective_function=="CantileverBeam.png":
221
+ GP_TIME = torch.load('CantileverBeam_CEI_Avg_Time.pt')
222
+ GP_OBJ = torch.load('CantileverBeam_CEI_Avg_Obj.pt')
223
+
224
+ elif objective_function=="CompressionSpring.png":
225
+ GP_TIME = torch.load('CompressionSpring_CEI_Avg_Time.pt')
226
+ GP_OBJ = torch.load('CompressionSpring_CEI_Avg_Obj.pt')
227
+
228
+ elif objective_function=="HeatExchanger.png":
229
+ GP_TIME = torch.load('HeatExchanger_CEI_Avg_Time.pt')
230
+ GP_OBJ = torch.load('HeatExchanger_CEI_Avg_Obj.pt')
231
+
232
+ elif objective_function=="ThreeTruss.png":
233
+ GP_TIME = torch.load('ThreeTruss_CEI_Avg_Time.pt')
234
+ GP_OBJ = torch.load('ThreeTruss_CEI_Avg_Obj.pt')
235
+
236
+ elif objective_function=="Reinforcement.png":
237
+ GP_TIME = torch.load('ReinforcedConcreteBeam_CEI_Avg_Time.pt')
238
+ GP_OBJ = torch.load('ReinforcedConcreteBeam_CEI_Avg_Obj.pt')
239
+
240
+ elif objective_function=="PressureVessel.png":
241
+ GP_TIME = torch.load('PressureVessel_CEI_Avg_Time.pt')
242
+ GP_OBJ = torch.load('PressureVessel_CEI_Avg_Obj.pt')
243
+
244
+ elif objective_function=="SpeedReducer.png":
245
+ GP_TIME = torch.load('SpeedReducer_CEI_Avg_Time.pt')
246
+ GP_OBJ = torch.load('SpeedReducer_CEI_Avg_Obj.pt')
247
+
248
+ elif objective_function=="WeldedBeam.png":
249
+ GP_TIME = torch.load('WeldedBeam_CEI_Avg_Time.pt')
250
+ GP_OBJ = torch.load('WeldedBeam_CEI_Avg_Obj.pt')
251
+
252
+ elif objective_function=="Car.png":
253
+ GP_TIME = torch.load('Car_CEI_Avg_Time.pt')
254
+ GP_OBJ = torch.load('Car_CEI_Avg_Obj.pt')
255
+
256
+ # Plot GP data
257
+ ax.plot(GP_TIME[:iteration_input], GP_OBJ[:iteration_input], '^-', label='GP-CBO (Data)' )
258
+
259
+
260
+ ax.set_xlabel('Time (seconds)')
261
+ ax.set_ylabel('Objective Value')
262
+ ax.set_title('Convergence Plot for {t} iterations'.format(t=iteration_input))
263
+ # ax.legend()
264
+
265
+ if objective_function=="CantileverBeam.png":
266
+ ax.axhline(y=50000, color='red', linestyle='--', label='Optimal Value')
267
+
268
+ elif objective_function=="CompressionSpring.png":
269
+ ax.axhline(y=0, color='red', linestyle='--', label='Optimal Value')
270
+
271
+ elif objective_function=="HeatExchanger.png":
272
+ ax.axhline(y=4700, color='red', linestyle='--', label='Optimal Value')
273
+
274
+ elif objective_function=="ThreeTruss.png":
275
+ ax.axhline(y=262, color='red', linestyle='--', label='Optimal Value')
276
+
277
+ elif objective_function=="Reinforcement.png":
278
+ ax.axhline(y=355, color='red', linestyle='--', label='Optimal Value')
279
+
280
+ elif objective_function=="PressureVessel.png":
281
+ ax.axhline(y=5000, color='red', linestyle='--', label='Optimal Value')
282
+
283
+ elif objective_function=="SpeedReducer.png":
284
+ ax.axhline(y=2650, color='red', linestyle='--', label='Optimal Value')
285
+
286
+ elif objective_function=="WeldedBeam.png":
287
+ ax.axhline(y=6, color='red', linestyle='--', label='Optimal Value')
288
+
289
+ elif objective_function=="Car.png":
290
+ ax.axhline(y=25, color='red', linestyle='--', label='Optimal Value')
291
+
292
+
293
+ ax.legend(loc='best')
294
+ # ax.legend(loc='lower left')
295
+
296
+
297
+ # Add text to the top right corner of the plot
298
+ if len(convergence) == 0:
299
+ ax.text(0.5, 0.5, 'No Feasible Design Found', transform=ax.transAxes, fontsize=12,
300
+ verticalalignment='top', horizontalalignment='right')
301
+
302
+
303
+ plt.close(fig)
304
+ return fig
305
+
306
+
307
+
308
+
309
+
310
+
311
+ # Define available objective functions
312
+ objective_functions = {
313
+ # "ThreeTruss.png": {"image": "ThreeTruss.png",
314
+ # "function": ThreeTruss,
315
+ # "scaling": ThreeTruss_Scaling,
316
+ # "dim": 2},
317
+ "CompressionSpring.png": {"image": "CompressionSpring.png",
318
+ "function": CompressionSpring,
319
+ "scaling": CompressionSpring_Scaling,
320
+ "dim": 3},
321
+ "Reinforcement.png": {"image": "Reinforcement.png", "function": ReinforcedConcreteBeam, "scaling": ReinforcedConcreteBeam_Scaling, "dim": 3},
322
+ "PressureVessel.png": {"image": "PressureVessel.png", "function": PressureVessel, "scaling": PressureVessel_Scaling, "dim": 4},
323
+ "SpeedReducer.png": {"image": "SpeedReducer.png", "function": SpeedReducer, "scaling": SpeedReducer_Scaling, "dim": 7},
324
+ "WeldedBeam.png": {"image": "WeldedBeam.png", "function": WeldedBeam, "scaling": WeldedBeam_Scaling, "dim": 4},
325
+ "HeatExchanger.png": {"image": "HeatExchanger.png", "function": HeatExchanger, "scaling": HeatExchanger_Scaling, "dim": 8},
326
+ "CantileverBeam.png": {"image": "CantileverBeam.png", "function": CantileverBeam, "scaling": CantileverBeam_Scaling, "dim": 10},
327
+ "Car.png": {"image": "Car.png", "function": Car, "scaling": Car_Scaling, "dim": 11},
328
+ }
329
+
330
+
331
+
332
+
333
+
334
+
335
+
336
+
337
+
338
+
339
+
340
+
341
+
342
+
343
+
344
+
345
+
346
+
347
+
348
+
349
+
350
+
351
+
352
+
353
+ # Extract just the image paths for the gallery
354
+ image_paths = [key for key in objective_functions]
355
+
356
+
357
+ def submit_action(objective_function_choices, iteration_input):
358
+ # print(iteration_input)
359
+ # print(len(objective_function_choices))
360
+ # print(objective_functions[objective_function_choices]['function'])
361
+ if len(objective_function_choices)>0:
362
+ selected_function = objective_functions[objective_function_choices]['function']
363
+ return optimize(objective_function_choices, iteration_input)
364
+ return None
365
+
366
+ # Function to clear the output
367
+ def clear_output():
368
+ # print(gallery.selected_index)
369
+
370
+ return gr.update(value=[], selected=None), None, 15, gr.Markdown(""), 'Test_formulation_default.png'
371
+
372
+ def reset_gallery():
373
+ return gr.update(value=image_paths)
374
+
375
+
376
+ with gr.Blocks() as demo:
377
+ # Centered Title and Description using gr.HTML
378
+ gr.HTML(
379
+ """
380
+ <div style="text-align: center;">
381
+ <h1>Pre-trained Transformer for Constrained Bayesian Optimization</h1>
382
+ <h4>Paper: <a href="https://arxiv.org/abs/2404.04495">
383
+ Fast and Accurate Bayesian Optimization with Pre-trained Transformers for Constrained Engineering Problems</a>
384
+ </h4>
385
+
386
+ <p style="text-align: left;">This is a demo for Bayesian Optimization using PFN (Prior-Data Fitted Networks).
387
+ Select your objective function by clicking on one of the check boxes below, then enter the iteration number to run the optimization process.
388
+ The results will be visualized in the radar chart and convergence plot.</p>
389
+
390
+
391
+
392
+
393
+ </div>
394
+ """
395
+ )
396
+
397
+
398
+ with gr.Row():
399
+
400
+
401
+ with gr.Column(variant='compact'):
402
+ # gr.Markdown("# Inputs: ")
403
+
404
+ with gr.Row():
405
+ gr.Markdown("## Select a problem (objective): ")
406
+ img_key = gr.Markdown(value="", visible=False)
407
+
408
+ gallery = gr.Gallery(value=image_paths, label="Objective Functions",
409
+ # height = 450,
410
+ object_fit='contain',
411
+ columns=3, rows=3, elem_id="gallery")
412
+
413
+ gr.Markdown("## Enter iteration Number: ")
414
+ iteration_input = gr.Slider(label="Iterations:", minimum=15, maximum=50, step=1, value=15)
415
+
416
+
417
+ # Row for the Clear and Submit buttons
418
+ with gr.Row():
419
+ clear_button = gr.Button("Clear")
420
+ submit_button = gr.Button("Submit", variant="primary")
421
+
422
+ with gr.Column():
423
+ # gr.Markdown("# Outputs: ")
424
+ gr.Markdown("## Problem Formulation: ")
425
+ formulation = gr.Image(value='Formulation_default.png', height=150)
426
+ gr.Markdown("## Results: ")
427
+ gr.Markdown("The graph will plot the best observed data v.s. the time for the algorithm to run up until the iteration. The PFN-CBO shows the result of the realtime optimization running in the backend while the GP-CBO shows the stored data from our previous experiments since running GP-CBO will take longer time.")
428
+ convergence_plot = gr.Plot(label="Convergence Plot")
429
+
430
+
431
+
432
+ def handle_select(evt: gr.SelectData):
433
+ selected_image = evt.value
434
+ key = evt.value['image']['orig_name']
435
+ formulation = 'Test_formulation.png'
436
+ print('here')
437
+ print(key)
438
+
439
+ return key, formulation
440
+
441
+ gallery.select(fn=handle_select, inputs=None, outputs=[img_key, formulation])
442
+
443
+
444
+
445
+ submit_button.click(
446
+ submit_action,
447
+ inputs=[img_key, iteration_input],
448
+ # outputs= [radar_plot, convergence_plot],
449
+ outputs= convergence_plot,
450
+
451
+ # progress=True # Enable progress tracking
452
+
453
+ )
454
+
455
+ clear_button.click(
456
+ clear_output,
457
+ inputs=None,
458
+ outputs=[gallery, convergence_plot, iteration_input, img_key, formulation]
459
+ ).then(
460
+ # Step 2: Reset the gallery to the original list
461
+ reset_gallery,
462
+ inputs=None,
463
+ outputs=gallery
464
+ )
465
+
466
+
467
+
468
+ demo.launch()