{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "4453c5ad-ec87-42e0-a6d5-e3fd3593aec2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7891\n", "Running on public URL: https://f714b6f956fb581264.gradio.live\n", "\n", "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import gradio as gr\n", "import torch\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from test_functions.Ackley10D import *\n", "from test_functions.Ackley2D import *\n", "from test_functions.Ackley6D import *\n", "from test_functions.HeatExchanger import *\n", "from test_functions.CantileverBeam import *\n", "from test_functions.Car import *\n", "from test_functions.CompressionSpring import *\n", "from test_functions.GKXWC1 import *\n", "from test_functions.GKXWC2 import *\n", "from test_functions.HeatExchanger import *\n", "from test_functions.JLH1 import *\n", "from test_functions.JLH2 import *\n", "from test_functions.KeaneBump import *\n", "from test_functions.GKXWC1 import *\n", "from test_functions.GKXWC2 import *\n", "from test_functions.PressureVessel import *\n", "from test_functions.ReinforcedConcreteBeam import *\n", "from test_functions.SpeedReducer import *\n", "from test_functions.ThreeTruss import *\n", "from test_functions.WeldedBeam import *\n", "# Import other objective functions as needed\n", "import time\n", "\n", "from Rosen_PFN4BO import *\n", "\n", "def optimize(objective_function, iteration_input):\n", "\n", " # Variable setup\n", " Current_BEST = -1e10 # Some arbitrary very small number\n", " Prev_BEST = -1e10\n", "\n", " # Initial random samples\n", " # print(objective_functions)\n", " trained_X = torch.rand(20, objective_functions[objective_function]['dim'])\n", "\n", " # Scale it to the domain of interest using the selected function\n", " # print(objective_function)\n", " X_Scaled = objective_functions[objective_function]['scaling'](trained_X)\n", "\n", " # Get the constraints and objective\n", " trained_gx, trained_Y = objective_functions[objective_function]['function'](X_Scaled)\n", "\n", " # Convergence list to store best values\n", " convergence = []\n", "\n", " START_TIME = time.time()\n", "\n", " # Optimization Loop\n", " for ii in range(iteration_input): # Example with 100 iterations\n", "\n", " # (0) Get the updated data for this iteration\n", " X_scaled = objective_functions[objective_function]['scaling'](trained_X)\n", " trained_gx, trained_Y = objective_functions[objective_function]['function'](X_scaled)\n", "\n", " # (1) Randomly sample Xpen \n", " X_pen = torch.rand(1000,trained_X.shape[1])\n", "\n", " # (2) PFN inference phase with EI\n", " default_model = 'final_models/Cyril_500features_800epoch_cpu.pt'\n", " \n", " ei, p_feas = Rosen_PFN_Parallel(default_model,\n", " trained_X, \n", " trained_Y, \n", " trained_gx,\n", " X_pen,\n", " 'power',\n", " 'ei'\n", " )\n", "\n", " # Calculating CEI\n", " CEI = ei\n", " for jj in range(p_feas.shape[1]):\n", " CEI = CEI*p_feas[:,jj]\n", "\n", " # (4) Get the next search value\n", " rec_idx = torch.argmax(CEI)\n", " best_candidate = X_pen[rec_idx,:].unsqueeze(0)\n", "\n", " # (5) Append the next search point\n", " trained_X = torch.cat([trained_X, best_candidate])\n", "\n", "\n", " ################################################################################\n", " # This is just for visualizing the best value. \n", " # This section can be remove for pure optimization purpose\n", " Current_X = objective_functions[objective_function]['scaling'](trained_X)\n", " Current_GX, Current_Y = objective_functions[objective_function]['function'](Current_X)\n", " if ((Current_GX<=0).all(dim=1)).any():\n", " Current_BEST = torch.max(Current_Y[(Current_GX<=0).all(dim=1)])\n", " else:\n", " Current_BEST = Prev_BEST\n", " ################################################################################\n", " \n", " # (ii) Convergence tracking (assuming the best Y is to be maximized)\n", " if Current_BEST != -1e10:\n", " convergence.append(Current_BEST.abs())\n", "\n", " # Timing\n", " END_TIME = time.time()\n", " TOTAL_TIME = END_TIME - START_TIME\n", " \n", " # Website visualization\n", " # (i) Radar chart for trained_X\n", " radar_chart = create_radar_chart(X_scaled)\n", " # (ii) Convergence tracking (assuming the best Y is to be maximized)\n", " convergence_plot = create_convergence_plot(convergence, TOTAL_TIME)\n", " \n", " return radar_chart, convergence_plot\n", "\n", "def create_radar_chart(X_scaled):\n", " fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))\n", " labels = [f'x{i+1}' for i in range(X_scaled.shape[1])]\n", " values = X_scaled.mean(dim=0).numpy()\n", " \n", " num_vars = len(labels)\n", " angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()\n", " values = np.concatenate((values, [values[0]]))\n", " angles += angles[:1]\n", "\n", " ax.fill(angles, values, color='green', alpha=0.25)\n", " ax.plot(angles, values, color='green', linewidth=2)\n", " ax.set_yticklabels([])\n", " ax.set_xticks(angles[:-1])\n", " # ax.set_xticklabels(labels)\n", " ax.set_xticklabels([f'{label}\\n({value:.2f})' for label, value in zip(labels, values[:-1])]) # Show values\n", " ax.set_title(\"Selected Design\", size=15, color='black', y=1.1)\n", " \n", " plt.close(fig)\n", " return fig\n", "\n", "def create_convergence_plot(convergence, TOTAL_TIME):\n", " fig, ax = plt.subplots()\n", " # print(len(convergence))\n", " ax.plot(convergence, label='Best Objective Value')\n", " ax.set_xlabel('Iteration')\n", " ax.set_ylabel('Objective Value')\n", " ax.set_title('Convergence Plot (Opt Runtime: {t} sec)'.format(t=round(TOTAL_TIME, 2)))\n", " ax.legend()\n", "\n", " # Add text to the top right corner of the plot\n", " if len(convergence) == 0:\n", " ax.text(0.5, 0.5, 'No Feasible Design Found', transform=ax.transAxes, fontsize=12,\n", " verticalalignment='top', horizontalalignment='right')\n", " \n", " plt.close(fig)\n", " return fig\n", "\n", "# Define available objective functions\n", "objective_functions = {\n", " \"Ackley2D\": {\"function\": Ackley2D, \"scaling\": Ackley2D_Scaling, \"dim\": 2},\n", " \"Ackley6D\": {\"function\": Ackley6D, \"scaling\": Ackley6D_Scaling, \"dim\": 6},\n", " \"Ackley10D\": {\"function\": Ackley10D, \"scaling\": Ackley10D_Scaling, \"dim\": 10},\n", " \"GKXWC1\": {\"function\": GKXWC1, \"scaling\": GKXWC1_Scaling, \"dim\": 2},\n", " \"GKXWC2\": {\"function\": GKXWC2, \"scaling\": GKXWC2_Scaling, \"dim\": 2},\n", " \"JLH1\": {\"function\": JLH1, \"scaling\": JLH1_Scaling, \"dim\": 2},\n", " \"JLH2\": {\"function\": JLH2, \"scaling\": JLH2_Scaling, \"dim\": 2},\n", " \"Keane Bump\": {\"function\": KeaneBump, \"scaling\": KeaneBump_Scaling, \"dim\": 18},\n", " \"Three Truss\": {\"function\": ThreeTruss, \"scaling\": ThreeTruss_Scaling, \"dim\": 2},\n", " \"Compression Spring\": {\"function\": CompressionSpring, \"scaling\": CompressionSpring_Scaling, \"dim\": 3},\n", " \"Reinforced Concrete Beam\": {\"function\": ReinforcedConcreteBeam, \"scaling\": ReinforcedConcreteBeam_Scaling, \"dim\": 3},\n", " \"Pressure Vessel\": {\"function\": PressureVessel, \"scaling\": PressureVessel_Scaling, \"dim\": 4},\n", " \"Speed Reducer\": {\"function\": SpeedReducer, \"scaling\": SpeedReducer_Scaling, \"dim\": 4},\n", " \"Welded Beam\": {\"function\": WeldedBeam, \"scaling\": WeldedBeam_Scaling, \"dim\": 4},\n", " \"Heat Exchanger\": {\"function\": HeatExchanger, \"scaling\": HeatExchanger_Scaling, \"dim\": 8},\n", " \"Cantilever Beam\": {\"function\": CantileverBeam, \"scaling\": CantileverBeam_Scaling, \"dim\": 10},\n", " \"Car\": {\"function\": Car, \"scaling\": Car_Scaling, \"dim\": 11},\n", " \n", " # Add more functions here\n", "}\n", "\n", "\n", "\n", "\n", "\n", "with gr.Blocks(theme=gr.themes.Default()) as demo:\n", " # Centered Title and Description using gr.HTML\n", " gr.HTML(\n", " \"\"\"\n", "
\n", "

Pre-trained Transformer for Constrained Bayesian Optimization

\n", "

This is a demo for Bayesian Optimization using PFN (Prior-Data Fitted Networks). \n", " Select your objective function by clicking on one of the check boxes below, then enter the iteration number to run the optimization process. \n", " The results will be visualized in the radar chart and convergence plot.

\n", " \"Example\n", "\n", "
\n", " \"\"\"\n", " )\n", "\n", " selected_objective = gr.State(None) # To store the selected objective function\n", "\n", "\n", "\n", "\n", "\n", " \n", " with gr.Row():\n", " \n", " objective_checkbox_group = gr.CheckboxGroup(\n", " choices=[\"JLH1\", \"JLH2\", \"GKXWC1\", \"GKXWC2\", \"Ackley2D\", \"Ackley6D\", \"Ackley10D\", \"Keane Bump\", \"Three Truss\", \"Reinforced Concrete Beam\", \"Pressure Vessel\", \"Welded Beam\", \"Speed Reducer\", \"Car\"],\n", " label=\"Select the design problem:\"\n", " )\n", " with gr.Row():\n", " iteration_input = gr.Number(label=\"Enter Iteration Number:\", value=10)\n", " \n", "\n", " # Row for the Clear and Submit buttons\n", " with gr.Row():\n", " clear_button = gr.Button(\"Clear\")\n", " submit_button = gr.Button(\"Submit\", variant=\"primary\")\n", " \n", " \n", " with gr.Row():\n", " with gr.Column():\n", " radar_plot = gr.Plot(label=\"Resulting Design\")\n", " with gr.Column():\n", " convergence_plot = gr.Plot(label=\"Convergence Plot\")\n", "\n", "\n", "\n", " # Define actions for buttons\n", " def clear_action():\n", " return None, None, None\n", "\n", " def submit_action(objective_function_choices, iteration_input):\n", " # Handle the case where multiple choices are selected\n", " if len(objective_function_choices) > 0:\n", " selected_function = objective_function_choices[0] # Assuming using the first selected function\n", " return optimize(selected_function, iteration_input)\n", " return None, None\n", "\n", " # Button click actions\n", " clear_button.click(clear_action, outputs=[objective_checkbox_group, radar_plot, convergence_plot])\n", " submit_button.click(\n", " submit_action, \n", " inputs=[objective_checkbox_group, iteration_input], \n", " outputs=[radar_plot, convergence_plot]\n", " )\n", "\n", "demo.launch(share=True)\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "352d0291-93b4-43eb-b683-3d48776dc670", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "92ecbbe6-dea6-4e7f-aae1-f0d442dbda3b", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "ba69b5f9-c52c-4c23-8645-c81c27f7a815", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 1, "id": "05789fba-2099-46b7-8675-64b7969427a1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7899\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import gradio as gr\n", "\n", "def calculator(num1, operation, num2):\n", " if operation == \"add\":\n", " return num1 + num2\n", " elif operation == \"subtract\":\n", " return num1 - num2\n", " elif operation == \"multiply\":\n", " return num1 * num2\n", " elif operation == \"divide\":\n", " return num1 / num2\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Row():\n", " with gr.Column():\n", " num_1 = gr.Number(value=4)\n", " operation = gr.Radio([\"add\", \"subtract\", \"multiply\", \"divide\"])\n", " num_2 = gr.Number(value=0)\n", " submit_btn = gr.Button(value=\"Calculate\")\n", " with gr.Column():\n", " result = gr.Number()\n", "\n", " submit_btn.click(\n", " calculator, inputs=[num_1, operation, num_2], outputs=[result], api_name=False\n", " )\n", " examples = gr.Examples(\n", " examples=[\n", " [5, \"add\", 3],\n", " [4, \"divide\", 2],\n", " [-4, \"multiply\", 2.5],\n", " [0, \"subtract\", 1.2],\n", " ],\n", " inputs=[num_1, operation, num_2],\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch(show_api=False)" ] }, { "cell_type": "code", "execution_count": null, "id": "a4bf709a-ff0a-4aac-a4b4-fd98cd5948bb", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "679f7647-ca68-46f9-a1da-81d6c96267c9", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "ea40bfac-e090-4cd5-9caa-99b06db3ea8d", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 50, "id": "928ac99a-af8f-401c-8c0b-ef83cfef5ba9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7890\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import gradio as gr\n", "\n", "def calculator(num1, operation, num2):\n", " if operation == \"add\":\n", " return num1 + num2\n", " elif operation == \"subtract\":\n", " return num1 - num2\n", " elif operation == \"multiply\":\n", " return num1 * num2\n", " elif operation == \"divide\":\n", " return num1 / num2\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Row():\n", " with gr.Column():\n", " num_1 = gr.Number(value=4)\n", " operation = gr.Radio([\"add\", \"subtract\", \"multiply\", \"divide\"])\n", " num_2 = gr.Number(value=0)\n", " submit_btn = gr.Button(value=\"Calculate\")\n", " with gr.Column():\n", " result = gr.Number()\n", "\n", " submit_btn.click(\n", " calculator, inputs=[num_1, operation, num_2], outputs=[result], api_name=False\n", " )\n", " examples = gr.Examples(\n", " examples=[\n", " [5, \"add\", 3],\n", " [4, \"divide\", 2],\n", " [-4, \"multiply\", 2.5],\n", " [0, \"subtract\", 1.2],\n", " ],\n", " inputs=[num_1, operation, num_2],\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch(show_api=False)" ] }, { "cell_type": "code", "execution_count": 36, "id": "09a251df-4076-4925-8799-9a2a59cb8246", "metadata": {}, "outputs": [], "source": [ "# import gradio as gr\n", "\n", "# def greet(selected_options):\n", "# return f\"You selected: {', '.join(selected_options)}\"\n", "\n", "# with gr.Blocks() as demo:\n", "# with gr.Row():\n", "# checkbox_group = gr.CheckboxGroup(\n", "# choices=[\"Option 1\", \"Option 2\"],\n", "# label=\"Select your options\",\n", "# elem_id=\"custom_checkbox_group\"\n", "# )\n", "# output = gr.Textbox(label=\"Output\")\n", " \n", "# checkbox_group.change(greet, checkbox_group, output)\n", "\n", "# gr.HTML(\n", "# f\"\"\"\n", "# \n", "# \n", "# \"\"\"\n", "# )\n", "\n", "# demo.launch()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "f52549d5-4be0-4672-be6d-df462957cb56", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 5 }