from typing import cast import uuid import os import gradio as gr import numpy as np import pandas as pd import torch import wildtorch as wt USE_OFFLINE_DATA = False ENABLE_DOWNLOAD_SNAPSHOTS = False if USE_OFFLINE_DATA: wildfire_sim_maps = torch.load('wildfire_sim_maps.pt') else: wildfire_sim_maps = wt.dataset.load_wildfire_sim_maps() # torch.save(wildfire_sim_maps, 'wildfire_sim_maps.pt') DEFAULT_SHAPE = (512, 512) DEFAULT_STATE = { 'ds': { 'name': None, 'shape': None, 'data': None, }, 'constants': { 'p_h': 0.58, 'c_1': 0.045, 'c_2': 0.131, 'a': 0.078, 'theta_w': 0, 'v': 10, 'p_firebreak': 0.9, 'p_continue_burn': 0.6, 'device': torch.device('cpu'), 'dtype': torch.float32, }, 'ignition': None, 'out_video_path': None, 'snapshots_path': None, 'checkpoint': None, 'logger': None, } with (gr.Blocks() as demo): def remove_state_files(in_state): if in_state['out_video_path'] is not None: os.remove(in_state['out_video_path']) if in_state['snapshots_path'] is not None: os.remove(in_state['snapshots_path']) state_var = gr.State(DEFAULT_STATE, delete_callback=remove_state_files) with gr.Tabs(selected='tab_1') as tabs: with gr.Tab("1. Datasets", interactive=True, id='tab_1') as tab_1: sel_dataset = gr.Dropdown(cast(list, wildfire_sim_maps['name']) + ['empty'], label='Dataset') with gr.Row() as shape_row: sel_shape_h = gr.Number(label="Map Height", visible=False) sel_shape_w = gr.Number(label="Map Width", visible=False) with gr.Row() as preview_row: canopy_img = gr.Image(label="canopy") density_img = gr.Image(label="density") slope_img = gr.Image(label="slope") tab_1_confirm_btn = gr.Button("Confirm", interactive=True) @tab_1_confirm_btn.click(inputs=[state_var], outputs=[state_var, tabs]) def jump_to_tab_2(in_state): return in_state, gr.Tabs(selected='tab_2') with gr.Tab("2. Simulation Constants and Initial Ignition", interactive=False, id='tab_2') as tab_2: with gr.Row(): sel_p_h = gr.Slider(label="p_h", info="The probability that a burnable cell adjacent to a burning cell will " "catch fire at the next time step under normal conditions", value=DEFAULT_STATE['constants']['p_h'], minimum=0, maximum=1, step=0.01, interactive=True) sel_p_continue_burn = gr.Slider(label="p_continue_burn", info="The probability that a burning cell will continue to burn " "at the next time step", value=DEFAULT_STATE['constants']['p_continue_burn'], minimum=0, maximum=1, step=0.01, interactive=True) with gr.Row(): sel_a = gr.Slider(label="a", info="The coefficient of ground elevation", value=DEFAULT_STATE['constants']['a'], minimum=0, maximum=1, step=0.001, interactive=True) sel_p_firebreak = gr.Slider(label="p_firebreak", info="The probability that a burnable cell will not catch fire even " "if it is adjacent to a burning cell", value=DEFAULT_STATE['constants']['p_firebreak'], minimum=0, maximum=1, step=0.01, interactive=True) with gr.Row(): sel_c_1 = gr.Slider(label="c_1", info="The coefficient of wind velocity", value=DEFAULT_STATE['constants']['c_1'], minimum=0, maximum=1, step=0.001, interactive=True) sel_c_2 = gr.Slider(label="c_2", info="The coefficient of wind direction", value=DEFAULT_STATE['constants']['c_2'], minimum=0, maximum=1, step=0.001, interactive=True) with gr.Row(): sel_theta_w = gr.Slider(label="theta_w", info="The direction of the wind in degrees, measured clockwise from north", value=DEFAULT_STATE['constants']['theta_w'], minimum=0, maximum=360, step=1, interactive=True) sel_v = gr.Slider(label="v", info="The wind velocity, unit in m/s", value=DEFAULT_STATE['constants']['v'], minimum=0, maximum=60, step=1, interactive=True) with gr.Row(): sel_device = gr.Dropdown(label="device", choices=['cpu', 'cuda', 'mps'], info="The device to use", value='cpu', allow_custom_value=True, interactive=True) sel_dtype = gr.Dropdown(label="data type", choices=['float16', 'float32', 'float64'], info="The data type to use", value='float32', interactive=True) @gr.on(triggers=[sel_p_h.input, sel_c_1.input, sel_c_2.input, sel_a.input, sel_theta_w.input, sel_v.input, sel_p_firebreak.input, sel_p_continue_burn.input, sel_device.input, sel_dtype.input], inputs=[state_var, sel_p_h, sel_c_1, sel_c_2, sel_a, sel_theta_w, sel_v, sel_p_firebreak, sel_p_continue_burn, sel_device, sel_dtype], outputs=[state_var]) def update_constants_state(in_state, in_p_h, in_c_1, in_c_2, in_a, in_theta_w, in_v, in_p_firebreak, in_p_continue_burn, in_device, in_dtype): in_state['constants']['p_h'] = in_p_h in_state['constants']['c_1'] = in_c_1 in_state['constants']['c_2'] = in_c_2 in_state['constants']['a'] = in_a in_state['constants']['theta_w'] = in_theta_w in_state['constants']['v'] = in_v in_state['constants']['p_firebreak'] = in_p_firebreak in_state['constants']['p_continue_burn'] = in_p_continue_burn in_state['constants']['device'] = torch.device(in_device) in_state['constants']['dtype'] = { 'float16': torch.float16, 'float32': torch.float32, 'float64': torch.float64, }[in_dtype] return in_state sel_ignition_mode = gr.Dropdown(label="Initial Ignition", choices=['random', 'center', 'custom'], interactive=True) gr.Markdown( 'to use custom ignition, please use the crop to fix the size, and then draw on the image. Please ' 'click on the green button once done. Drawing on the black will be good choices.') with gr.Row(): custom_ignition_paint = gr.Paint(label="custom ignition", image_mode='L', interactive=True, brush=gr.Brush(default_size=3, color_mode='fixed')) ignition_img_over_map = gr.Image(label="ignition over map") @gr.on(triggers=[sel_shape_h.input, sel_shape_w.input], inputs=[state_var, sel_shape_h, sel_shape_w], outputs=[state_var, canopy_img, density_img, slope_img, tab_2, custom_ignition_paint]) def update_preview_row(in_state, in_h, in_w): shape = in_h, in_w data = wt.dataset.generate_empty_dataset(shape) in_state['ds']['shape'] = shape in_state['ds']['data'] = data in_state['ignition'] = None return in_state, gr.Image(wt.utils.colorize_array(np.array(data[0]))), gr.Image( wt.utils.colorize_array(np.array(data[1]))), gr.Image( wt.utils.colorize_array(np.array(data[2]))), gr.Tab(interactive=True), gr.ImageEditor( crop_size=(shape[1], shape[0])) @sel_dataset.change(inputs=[state_var, sel_dataset], outputs=[state_var, sel_shape_h, sel_shape_w, tab_2, custom_ignition_paint, canopy_img, density_img, slope_img]) def update_shape_row(in_state, in_dataset): if in_dataset == 'empty': shape = DEFAULT_SHAPE data = wt.dataset.generate_empty_dataset(shape) editable = True else: idx_dict = {item['name']: index for index, item in enumerate(wildfire_sim_maps)} shape = tuple(cast(torch.Tensor, wildfire_sim_maps[idx_dict[in_dataset]]['shape']).tolist()) data = wt.dataset.transform_wildfire_sim_map(wildfire_sim_maps[idx_dict[in_dataset]]) editable = False in_state['ds']['name'] = in_dataset in_state['ds']['shape'] = shape in_state['ds']['data'] = data return in_state, gr.Number(value=shape[0], interactive=editable, visible=True), gr.Number( value=shape[1], interactive=editable, visible=True), gr.Tab(interactive=True), gr.ImageEditor(interactive=True, crop_size=(shape[1], shape[0])), gr.Image( wt.utils.colorize_array(np.array(data[0]))), gr.Image( wt.utils.colorize_array(np.array(data[1]))), gr.Image( wt.utils.colorize_array(np.array(data[2]))) tab_2_confirm_btn = gr.Button("Confirm", interactive=False) @sel_ignition_mode.input( inputs=[state_var, sel_ignition_mode, custom_ignition_paint], outputs=[state_var, ignition_img_over_map, tab_2_confirm_btn]) def update_ignition_img(in_state, in_mode, in_custom): ignition = torch.zeros(in_state['ds']['shape'], dtype=torch.bool) if in_mode == 'random': ignition = wt.utils.create_ignition(shape=in_state['ds']['shape'], mode='random-single') elif in_mode == 'center': ignition = wt.utils.create_ignition(shape=in_state['ds']['shape'], mode='center') elif in_mode == 'custom': if in_custom['composite'] is not None: ignition_ndarray = in_custom['composite'] != 0 ignition = torch.tensor(ignition_ndarray) else: return in_state, gr.Image( wt.utils.colorize_array(wt.utils.compose_vis_wildfire_map(in_state['ds']['data']), cmap='grey')), gr.Button(interactive=False) in_state['ignition'] = ignition ignition_ndarray = wt.utils.to_ndarray(ignition) ignition__over_map = wt.utils.overlay_arrays( wt.utils.colorize_array(ignition_ndarray), wt.utils.colorize_array(wt.utils.compose_vis_wildfire_map(in_state['ds']['data']), cmap='grey'), 0.5 ) return in_state, gr.Image((ignition__over_map * 255).astype(np.uint8)), gr.Button(interactive=True) @custom_ignition_paint.change( inputs=[state_var, custom_ignition_paint], outputs=[state_var, sel_ignition_mode, ignition_img_over_map, tab_2_confirm_btn]) def update_ignition_img_over_map(in_state, in_custom): if in_custom['composite'] is not None: ignition_ndarray = in_custom['composite'] != 0 ignition = torch.tensor(ignition_ndarray) else: return in_state, gr.Dropdown(), gr.Image(), gr.Button() in_state['ignition'] = ignition ignition__over_map = wt.utils.overlay_arrays( wt.utils.colorize_array(ignition_ndarray), wt.utils.colorize_array(wt.utils.compose_vis_wildfire_map(in_state['ds']['data']), cmap='grey'), 0.5 ) return in_state, gr.Dropdown(value='custom'), gr.Image( (ignition__over_map * 255).astype(np.uint8)), gr.Button(interactive=True) with gr.Tab("3. Simulation Control", interactive=False, id='tab_3') as tab_3: @tab_2_confirm_btn.click(inputs=[state_var], outputs=[state_var, tabs, tab_3]) def update_tab_34_components(in_state): return in_state, gr.Tabs(selected='tab_3'), gr.Tab(interactive=True) with gr.Row(): with gr.Column(): gr.Markdown("## Memory Control") checkpoint_cb = gr.Checkbox(label="Checkpoint -> Memory", value=False, interactive=True) run_from_cp_cb = gr.Checkbox(label="Begin from Memory", value=False, interactive=True) reset_btn = gr.Button("Reset Memory", interactive=True) with gr.Column(): gr.Markdown("## Misc Control") sel_steps = gr.Number(label="Number of Steps", value=200, minimum=1, step=1, interactive=True) auto_run_cb = gr.Checkbox(label="Auto Run", value=False, interactive=True) auto_reseed_cb = gr.Checkbox(label="Auto Regenerate Seed when open Tab", value=False, interactive=True) track_p_burn_cb = gr.Checkbox(label="Track p(burn), slow", value=False, interactive=True) with gr.Column(): gr.Markdown("## Random Seed Control") sel_seed = gr.Number(label="Random Seed", value=torch.Generator().seed(), minimum=0, step=1, interactive=True) random_seed_btn = gr.Button("Randomize Seed", interactive=True) @random_seed_btn.click(inputs=[state_var], outputs=[state_var, sel_seed]) def randomize_seed(in_state): return in_state, torch.Generator().seed() with gr.Row(): run_btn = gr.Button("Run Simulation", interactive=True) download_snap_btn = gr.DownloadButton(label="Download Snapshots", interactive=False, visible=False) progress_bar = gr.Progress(track_tqdm=True) with gr.Row(): output_video = gr.Video(label="Simulation Video", interactive=False, autoplay=True) stats_plot = gr.LinePlot(title="Simulation Stats", interactive=True, height=600, width=600, ) with gr.Tab("4. Advanced Simulation", interactive=False, id='tab_4') as tab_4: sel_tab4_step = gr.Slider(label='Step', minimum=0, step=1, value=0, interactive=True) with gr.Row(): cof_tb = gr.Textbox(label='cell_on_fire', interactive=False) cbo_tb = gr.Textbox(label='cell_burned_out', interactive=False) with gr.Row(): fire_state_img = gr.Image(label="Fire State", interactive=False) p_burn_plot = gr.Image(label="p(burn)", interactive=False) # gr.Plot is bad at presenting stats_df = gr.DataFrame() @sel_tab4_step.input(inputs=[state_var, sel_tab4_step], outputs=[state_var, fire_state_img, p_burn_plot, cof_tb, cbo_tb]) def update_tab4_step(in_state, in_user_step): o_fsi, o_pbp = gr.Image(), gr.Image() o_cof_tb, o_cbo_tb = gr.Textbox(), gr.Textbox() if in_state['logger'] is not None: snapshot = in_state['logger'].snapshots[in_user_step] log = in_state['logger'].logs[in_user_step] o_fsi = gr.Image( value=wt.utils.colorize_array(wt.utils.compose_vis_fire_state(snapshot['fire_state']))) if len(in_state['logger'].p_burns) > 0: p_burn_arr = in_state['logger'].p_burns[in_user_step].cpu().numpy() o_pbp = gr.Image( value=wt.utils.colorize_array(p_burn_arr)) o_cof_tb = gr.Textbox(value=str(log['num_cells_on_fire'])) o_cbo_tb = gr.Textbox(value=str(log['num_cells_burned_out'])) return in_state, o_fsi, o_pbp, o_cof_tb, o_cbo_tb @tab_3.select( inputs=[state_var, auto_run_cb, sel_steps, sel_seed, auto_reseed_cb, checkpoint_cb, run_from_cp_cb, track_p_burn_cb], outputs=[state_var, output_video, tab_4, stats_plot, download_snap_btn, sel_seed, sel_tab4_step, stats_df]) def auto_run_simulation(in_state, in_auto_run, in_steps, in_seed, in_auto_reseed, in_checkpoint_cb, in_run_from_cp_cb, in_track_p_burn_cb): o_s = in_state o_v = gr.Video() o_t = gr.Tab() o_lp = gr.LinePlot() o_dsb = gr.DownloadButton() o_ts = gr.Slider() o_sdf = pd.DataFrame() if in_auto_reseed: in_seed = torch.Generator().seed() if in_auto_run: o_s, o_v, o_t, o_lp, o_dsb, o_ts, o_sdf = run_simulation(in_state, in_steps, in_seed, in_checkpoint_cb, in_run_from_cp_cb, in_track_p_burn_cb) return o_s, o_v, o_t, o_lp, o_dsb, in_seed, o_ts, o_sdf @reset_btn.click(inputs=[state_var], outputs=[state_var]) def reset_simulation(in_state): if in_state['checkpoint'] is not None: in_state['checkpoint'] = None gr.Info('Checkpoint Cleared.') return in_state @run_btn.click(inputs=[state_var, sel_steps, sel_seed, checkpoint_cb, run_from_cp_cb, track_p_burn_cb], outputs=[state_var, output_video, tab_4, stats_plot, download_snap_btn, sel_tab4_step, stats_df]) def run_simulation(in_state, in_steps, in_seed, in_checkpoint_cb, in_run_from_cp_cb, in_track_p_burn_cb, in_progress=gr.Progress(track_tqdm=True)): if in_state['out_video_path'] is None: in_state['out_video_path'] = f'runs/{str(uuid.uuid4())}.mp4' simulator = wt.WildTorchSimulator( wildfire_map=in_state['ds']['data'], simulator_constants=wt.SimulatorConstants( p_h=in_state['constants']['p_h'], c_1=in_state['constants']['c_1'], c_2=in_state['constants']['c_2'], a=in_state['constants']['a'], theta_w=in_state['constants']['theta_w'], v=in_state['constants']['v'], p_firebreak=in_state['constants']['p_firebreak'], p_continue_burn=in_state['constants']['p_continue_burn'], device=in_state['constants']['device'], dtype=in_state['constants']['dtype'], ), maximum_step=in_steps, initial_ignition=in_state['ignition'], seed=in_seed, ) if in_state['checkpoint'] is not None and in_run_from_cp_cb: simulator.load_checkpoint(in_state['checkpoint'], restore_seed=False) logger = wt.logger.Logger(disable_writing=True, verbose=False) for i in in_progress.tqdm(range(in_steps)): simulator.step() logger.snapshot_simulation(simulator) logger.log_stats( step=i, num_cells_on_fire=wt.metrics.cell_on_fire(simulator.fire_state).item(), num_cells_burned_out=wt.metrics.cell_burned_out(simulator.fire_state).item(), ) if in_track_p_burn_cb: logger.log_p_burn(simulator) gr.Info('Simulation Completed. Generating Video ...') in_state['logger'] = logger if in_checkpoint_cb: in_state['checkpoint'] = simulator.checkpoint if ENABLE_DOWNLOAD_SNAPSHOTS: logger.snapshots_filepath = in_state['snapshots_path'] = f'runs/{str(uuid.uuid4())}.pt' logger.save_snapshots() can_download_snapshots = True else: can_download_snapshots = False wt.utils.animate_snapshots(logger.snapshots, simulator.wildfire_map, out_filename=in_state['out_video_path']) m_stats_df = pd.DataFrame(logger.logs) m_stats_df = m_stats_df.melt(id_vars=["step"], var_name="key", value_name="value") o_stats_df = pd.DataFrame(logger.logs) return in_state, gr.Video(value=in_state['out_video_path']), gr.Tab(interactive=True), gr.LinePlot( m_stats_df, x='step', y='value', color="key", color_legend_position="bottom", tooltip=["step", "key", "value"], container=False, ), gr.DownloadButton( value=in_state['snapshots_path'], interactive=can_download_snapshots, visible=can_download_snapshots), gr.Slider(maximum=in_steps - 1), o_stats_df demo.queue().launch()