|
import streamlit as st |
|
import torch |
|
import random |
|
import numpy as np |
|
import yaml |
|
import logging |
|
import os |
|
import matplotlib.pyplot as plt |
|
from pathlib import Path |
|
import tempfile |
|
import traceback |
|
|
|
from data_utils import ( |
|
save_uploaded_files, |
|
load_dataset, |
|
) |
|
|
|
from inference_utils import run_inference |
|
from config_utils import load_config |
|
from plot_utils import plot_prithvi_output, plot_aurora_output |
|
from prithvi_utils import ( |
|
prithvi_config_ui, |
|
initialize_prithvi_model, |
|
prepare_prithvi_batch |
|
) |
|
from aurora_utils import aurora_config_ui, prepare_aurora_batch, initialize_aurora_model |
|
|
|
from pangu_utils import ( |
|
pangu_config_data, |
|
inference_1hr, |
|
inference_3hrs, |
|
inference_6hrs, |
|
inference_24hrs, |
|
inference_custom_hrs, |
|
plot_pangu_output, |
|
) |
|
|
|
from fengwu_utils import (fengwu_config_data, inference_6hrs_fengwu, inference_12hrs_fengwu, inference_custom_hrs_fengwu, plot_fengwu_output) |
|
|
|
|
|
# Send INFO-and-above records to the root handler and expose a module-named
# logger for the rest of the app.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
|
|
|
|
|
# Browser-tab title and page layout, issued before any widget is rendered:
# a wide main area with the sidebar open on first load.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="Weather Data Processor",
)
|
|
|
# Page header: app title on the left (4/5 width), model picker on the right.
header_col1, header_col2 = st.columns([4, 1])

with header_col1:
    st.title("🌦️ Weather & Climate Data Processor and Forecaster")

with header_col2:
    st.markdown("### Select a Model")
    # The markdown heading above is the visible caption. The widget still gets a
    # real (collapsed) label because Streamlit discourages empty labels for
    # accessibility. The widget key is unchanged, so the selection is preserved.
    selected_model = st.selectbox(
        "Select a model",
        options=[
            "Pangu-Weather",
            "FengWu",
            "Aurora",
            "Climax",
            "Prithvi",
            "GEOS-Specific-LSTM",
            "GEOS-Finetuned-Climax",
        ],
        index=0,
        key="model_selector",
        label_visibility="collapsed",
        help="Select the model you want to use.",
    )

st.write("---")
|
|
|
|
|
# Selector entries that have no backend yet; choosing one halts the script run.
_UNAVAILABLE_MODELS = ("Climax", "GEOS-Specific-LSTM", "GEOS-Finetuned-Climax")

# Session-state flag set once a model's inference has completed successfully.
_DONE_KEYS = {
    "Prithvi": "prithvi_done",
    "Aurora": "aurora_done",
    "FengWu": "fengwu_done",
    "Pangu-Weather": "pangu_done",
}

# Seed applied to Python, NumPy and torch RNGs before every run.
_SEED = 42


def _init_device():
    """Pick the compute device, configure cuDNN and seed every RNG.

    Returns:
        ``torch.device("cuda")`` when CUDA is available, otherwise
        ``torch.device("cpu")``.
    """
    torch.jit.enable_onednn_fusion(True)
    if torch.cuda.is_available():
        selected = torch.device("cuda")
        st.write(f"Using device: **{torch.cuda.get_device_name()}**")
        # cudnn.benchmark autotunes kernels and may select non-deterministic
        # algorithms, which defeats deterministic=True and the fixed seeds.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    else:
        selected = torch.device("cpu")
        st.write("Using device: **CPU**")

    random.seed(_SEED)
    np.random.seed(_SEED)
    # torch.manual_seed seeds the CPU generator and all CUDA devices.
    torch.manual_seed(_SEED)
    return selected


def _forecast_fengwu(input1, input2, duration, custom_hrs):
    """Run FengWu for the chosen forecast duration.

    Args:
        input1, input2: Values returned by ``np.load`` on the two FengWu uploads.
        duration: Label picked in the "Forecast Duration" selectbox.
        custom_hrs: Hours from the custom input; only used for "Custom".

    Returns:
        The forecast, or ``None`` when FengWu does not support ``duration``.
        In that case a warning is shown instead.
    """
    if duration == "1 hour":
        st.warning("1hr inference is not yet available on this model.")
        return None
    if duration == "3 hours":
        st.warning("3hrs inference is not yet available on this model.")
        return None
    if duration == "6 hours":
        return inference_6hrs_fengwu(input1, input2)
    if duration == "12 hours":
        # Not currently offered by the duration selector.
        return inference_12hrs_fengwu(input1, input2)
    if duration == "24 hours":
        # No dedicated 24 h runner is imported. Pass the horizon explicitly:
        # custom_hrs is None unless "Custom" was chosen.
        return inference_custom_hrs_fengwu(input1, input2, 24)
    return inference_custom_hrs_fengwu(input1, input2, custom_hrs)


def _forecast_pangu(upper, surface, duration, custom_hrs):
    """Run Pangu-Weather for the chosen duration.

    Returns:
        ``(out_upper, out_surface)`` as produced by the pangu_utils runner.
    """
    fixed_runners = {
        "1 hour": inference_1hr,
        "3 hours": inference_3hrs,
        "6 hours": inference_6hrs,
        "24 hours": inference_24hrs,
    }
    runner = fixed_runners.get(duration)
    if runner is not None:
        return runner(upper, surface)
    return inference_custom_hrs(upper, surface, custom_hrs)


def _render_results(model_name, idle_message=None):
    """Plot stored results for ``model_name`` if its last run completed.

    When nothing can be shown and ``idle_message`` is given, display it.
    """
    state = st.session_state
    if model_name == "Prithvi" and state.get("prithvi_done"):
        plot_prithvi_output(state["prithvi_out"])
    elif model_name == "Aurora" and state.get("aurora_done"):
        plot_aurora_output(state["aurora_out"], state["aurora_ds_subset"])
    elif model_name == "FengWu" and state.get("fengwu_done"):
        plot_fengwu_output(state["input_fengwu"], state["output_fengwu"])
    elif model_name == "Pangu-Weather" and state.get("pangu_done"):
        plot_pangu_output(
            state["pangu_upper_data"],
            state["pangu_surface_data"],
            state["pangu_out_upper"],
            state["pangu_out_surface"],
        )
    elif idle_message:
        st.info(idle_message)


# Configuration (and the run button) on the left, results on the right.
left_col, right_col = st.columns([1, 2])

with left_col:
    st.header("🔧 Configuration")

    # Model-specific upload / configuration widgets.
    if selected_model in _UNAVAILABLE_MODELS:
        st.info(f"{selected_model} model is not yet available.")
        st.stop()
    elif selected_model == "Prithvi":
        (config, uploaded_surface_files, uploaded_vertical_files,
         clim_surf_path, clim_vert_path, config_path, weights_path) = prithvi_config_ui()
    elif selected_model == "Aurora":
        uploaded_files = aurora_config_ui()
    elif selected_model == "Pangu-Weather":
        input_surface_file, input_upper_file = pangu_config_data()
    elif selected_model == "FengWu":
        input_file1_fengwu, input_file2_fengwu = fengwu_config_data()
    else:
        # Generic NetCDF uploader for any selector entry without a dedicated UI.
        st.subheader(f"{selected_model} Model Data Upload")
        st.markdown("### Drag and Drop Your Data Files Here")
        uploaded_files = st.file_uploader(
            f"Upload Data Files for {selected_model}",
            accept_multiple_files=True,
            key=f"{selected_model.lower()}_uploader",
            type=["nc", "netcdf", "nc4"],
        )

    st.write("---")

    st.subheader("Forecast Duration")
    forecast_options = ["1 hour", "3 hours", "6 hours", "24 hours", "Custom"]
    selected_duration = st.selectbox(
        "Select forecast duration",
        forecast_options,
        index=3,
        help="Select how many hours to forecast.",
    )

    # Only populated (and only consulted) when "Custom" is selected.
    custom_hours = None
    if selected_duration == "Custom":
        custom_hours = st.number_input(
            "Enter custom forecast hours",
            min_value=24,
            max_value=480,
            value=48,
            step=24,
            help="Enter the number of hours you want to forecast.",
        )

    st.write("---")

    if st.button("🚀 Run Inference"):
        with right_col:
            st.header("📊 Inference Progress & Visualization")

            try:
                device = _init_device()
            except Exception:
                st.error("Error initializing device:")
                st.error(traceback.format_exc())
                st.stop()

            # Invalidate the previous result first, so a failed run cannot
            # leave an out-of-date forecast on screen.
            done_key = _DONE_KEYS.get(selected_model)
            if done_key is not None:
                st.session_state[done_key] = False

            with st.spinner("Running inference, please wait..."):
                if selected_model == "Prithvi":
                    model, _in_mu, _in_sig, _output_sig, _static_mu, _static_sig = (
                        initialize_prithvi_model(config, config_path, weights_path, device)
                    )
                    batch = prepare_prithvi_batch(
                        uploaded_surface_files,
                        uploaded_vertical_files,
                        clim_surf_path,
                        clim_vert_path,
                        device,
                    )
                    st.session_state["prithvi_out"] = run_inference(
                        selected_model, model, batch, device
                    )
                    st.session_state["prithvi_done"] = True

                elif selected_model == "Aurora":
                    if not uploaded_files:
                        st.error("Please upload data files for Aurora.")
                        st.stop()
                    save_uploaded_files(uploaded_files)
                    ds = load_dataset(st.session_state.temp_file_paths)
                    if ds is None:
                        st.error("Failed to load dataset for Aurora.")
                        st.stop()
                    batch = prepare_aurora_batch(ds)
                    model = initialize_aurora_model(device)
                    st.session_state["aurora_out"] = run_inference(
                        selected_model, model, batch, device
                    )
                    st.session_state["aurora_ds_subset"] = ds
                    st.session_state["aurora_done"] = True

                elif selected_model == "FengWu":
                    if not (input_file1_fengwu and input_file2_fengwu):
                        st.error("Please upload data files for FengWu.")
                        st.stop()
                    try:
                        input1 = np.load(input_file1_fengwu)
                        input2 = np.load(input_file2_fengwu)
                        output_fengwu = _forecast_fengwu(
                            input1, input2, selected_duration, custom_hours
                        )
                        # None means the duration is unsupported (already warned).
                        if output_fengwu is not None:
                            st.session_state["output_fengwu"] = output_fengwu
                            st.session_state["fengwu_done"] = True
                            # NOTE(review): this stores the uploaded file object
                            # (already consumed by np.load above), not the array.
                            # Confirm this is what plot_fengwu_output expects.
                            st.session_state["input_fengwu"] = input_file2_fengwu
                    except Exception as e:
                        logger.exception("FengWu inference failed")
                        st.error(f"An error occurred: {e}")

                elif selected_model == "Pangu-Weather":
                    if not (input_surface_file and input_upper_file):
                        st.error("Please upload data files for Pangu-Weather.")
                        st.stop()
                    try:
                        surface_data = np.load(input_surface_file)
                        upper_data = np.load(input_upper_file)
                        out_upper, out_surface = _forecast_pangu(
                            upper_data, surface_data, selected_duration, custom_hours
                        )

                        st.session_state["pangu_upper_data"] = upper_data
                        st.session_state["pangu_surface_data"] = surface_data
                        st.session_state["pangu_out_upper"] = out_upper
                        st.session_state["pangu_out_surface"] = out_surface
                        st.session_state["pangu_done"] = True

                        st.write("**Forecast Results:**")
                        st.write("Upper Data Forecast Shape:", out_upper.shape)
                        st.write("Surface Data Forecast Shape:", out_surface.shape)
                    except Exception as e:
                        logger.exception("Pangu-Weather inference failed")
                        st.error(f"An error occurred: {e}")

                else:
                    st.warning("Inference not implemented for this model.")
                    st.stop()

            _render_results(selected_model)

    else:
        with right_col:
            st.header("🖥️ Visualization & Progress")
            _render_results(
                selected_model, idle_message="Awaiting inference to display results."
            )
|
|
|
|