# Hugging Face Spaces status at capture time: Runtime error
import streamlit as st | |
from streamlit_folium import st_folium | |
import folium | |
import logging | |
import sys | |
import hydra | |
from plot_functions import * | |
import hydra | |
import torch | |
from model import LitUnsupervisedSegmenter | |
from helper import inference_on_location_and_month, inference_on_location | |
# Initial centre and zoom of the folium map, used until the user clicks a location.
DEFAULT_LATITUDE, DEFAULT_LONGITUDE = 48.81, 2.98
DEFAULT_ZOOM = 5

# Year bounds for the TimeLapse selectors; app() uses MAX_YEAR as an exclusive
# upper bound when it builds the list of selectable years.
MIN_YEAR, MAX_YEAR = 2018, 2024

# Pixel size handed to st_folium for the map widget.
FOLIUM_WIDTH, FOLIUM_HEIGHT = 925, 300

# Use the full browser width; this is the first Streamlit call the script makes.
st.set_page_config(layout="wide")
def init_cfg(cfg_name):
    """Compose the Hydra config ``cfg_name`` from the ``configs`` directory.

    Hydra's global state may only be initialized once per process, while
    Streamlit re-executes this script -- and therefore this function -- on
    every widget interaction. Calling ``hydra.initialize`` as a plain function
    left that global state initialized, so the second call raised
    "GlobalHydra is already initialized". Used as a context manager,
    ``hydra.initialize`` restores the previous global state on exit, which
    makes repeated calls safe.

    Args:
        cfg_name: name of the config to compose (e.g. ``"my_train_config.yml"``);
            ``config_path`` is resolved by Hydra relative to this source file.

    Returns:
        The configuration object produced by ``hydra.compose``.
    """
    with hydra.initialize(config_path="configs", job_name="corine"):
        return hydra.compose(config_name=cfg_name)
def _configure_logging():
    """Route INFO-level records to ``biomap.log`` (UTF-8) and stdout, once per process.

    ``logging.basicConfig`` does nothing once the root logger has handlers, so the
    handlers are only created when the root logger has none yet. Otherwise every
    Streamlit rerun would open another ``biomap.log`` handle that is never attached
    to a logger nor closed.
    """
    if logging.getLogger().handlers:
        return
    # basicConfig ignores its ``encoding`` argument when ``handlers`` is given,
    # so the encoding must be set on the FileHandler itself.
    file_handler = logging.FileHandler(filename='biomap.log', encoding='utf-8')
    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    logging.basicConfig(
        handlers=[file_handler, stdout_handler],
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )


def init_app(cfg_name, model_path="biomap/checkpoint/model/model.pt") -> LitUnsupervisedSegmenter:
    """Set up logging, build the segmentation model from its Hydra config and load its weights.

    Args:
        cfg_name: Hydra config name forwarded to ``init_cfg``.
        model_path: path of the saved ``state_dict`` checkpoint; defaults to the
            path this app has always loaded.

    Returns:
        A ``LitUnsupervisedSegmenter`` moved to the CPU with the checkpoint weights loaded.
    """
    _configure_logging()

    cfg = init_cfg(cfg_name)
    logging.info("config : %s", cfg)

    nbclasses = cfg.dir_dataset_n_classes
    model = LitUnsupervisedSegmenter(nbclasses, cfg).cpu()
    logging.info("Model Initialiazed")

    # torch.load unpickles the file: only point model_path at trusted checkpoints.
    saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    logging.info("Model weights Loaded")
    model.load_state_dict(saved_state_dict)
    # NOTE(review): model.eval() is not called here; confirm the inference
    # helpers put the model in eval mode before predicting.
    return model
def _init_session_state():
    """Create the boolean session flags used by ``app`` on the first run of a session."""
    for key in ("infered", "submit", "submit2"):
        if key not in st.session_state:
            st.session_state[key] = False


def _render_header():
    """Render the title and the description of the GBS land-cover scoring scheme."""
    # NOTE(review): the characters around the title look like a mis-encoded emoji.
    st.markdown("<h1 style='text-align: center;'>π’ Biomap by Ekimetrics π’</h1>", unsafe_allow_html=True)
    st.markdown("<h2 style='text-align: center;'>Estimate Biodiversity in the world with the help of land cover.</h2>", unsafe_allow_html=True)
    st.markdown("<p style='text-align: center;'>The segmentation model is an association of UNet and DinoV1 trained on the dataset CORINE. Land use is divided into 6 differents classes : Each class is assigned a GBS score from 0 to 1</p>", unsafe_allow_html=True)
    st.markdown("<p style='text-align: center;'>Buildings : 0.1 | Infrastructure : 0.1 | Cultivation : 0.4 | Wetland : 0.9 | Water : 0.9 | Natural green : 1 </p>", unsafe_allow_html=True)
    st.markdown("<p style='text-align: center;'>The score is then averaged on the full image.</p>", unsafe_allow_html=True)


def _render_map():
    """Draw the clickable folium map and return the (lat, lng) used to prefill the inputs.

    Returns the last clicked point, or the default map centre if nothing was clicked.
    """
    m = folium.Map(location=[DEFAULT_LATITUDE, DEFAULT_LONGITUDE], zoom_start=DEFAULT_ZOOM)
    m.add_child(folium.LatLngPopup())
    f_map = st_folium(m, width=FOLIUM_WIDTH, height=FOLIUM_HEIGHT)
    clicked = f_map.get("last_clicked")
    if clicked:
        return clicked["lat"], clicked["lng"]
    return DEFAULT_LATITUDE, DEFAULT_LONGITUDE


def _render_timelapse_tab(default_lat, default_lng):
    """Render the TimeLapse controls, store their values in session state, return the button state."""
    submit = st.button("Predict TimeLapse", use_container_width=True, type="primary")
    st.session_state["submit"] = submit

    col_lat, col_long = st.columns(2)
    with col_lat:
        st.session_state["lat"] = st.text_input("latitude", value=default_lat)
    with col_long:
        st.session_state["long"] = st.text_input("longitude", value=default_lng)

    col_start, col_end = st.columns(2)
    years = list(range(MIN_YEAR, MAX_YEAR, 1))
    with col_start:
        # The last year has no later year to end on: offering it as a start made
        # the "End date" selectbox empty, so it returned None.
        start_date = st.selectbox("Start date", years[:-1])
        st.session_state["start_date"] = start_date
    end_years = [year for year in years if year > start_date]
    with col_end:
        st.session_state["end_date"] = st.selectbox("End date", end_years)

    st.session_state["segment_interval"] = st.radio(
        "Interval of time between two segmentation",
        options=['month', '2months', 'year'],
        horizontal=True,
    )
    return submit


def _render_single_image_tab(default_lat, default_lng):
    """Render the Single Image controls, store their values in session state, return the button state."""
    submit2 = st.button("Predict Single Image", use_container_width=True, type="primary")
    st.session_state["submit2"] = submit2

    col_lat, col_long = st.columns(2)
    with col_lat:
        st.session_state["lat_2"] = st.text_input("lat.", value=default_lat)
    with col_long:
        st.session_state["long_2"] = st.text_input("long.", value=default_lng)
    st.session_state["date_2"] = st.text_input("date", "2021-01-01", placeholder="2021-01-01")
    return submit2


def app(model):
    """Render the Biomap page and run inference when a Predict button is clicked.

    ``st.button`` is True only on the rerun caused by the click. The submit flags
    used to be read from session state *before* the buttons stored their current
    value, so inference fired one interaction late. Now the widgets are drawn
    first, inference runs in the same run as the click, and the figure is drawn
    into a slot reserved above the controls so the layout is unchanged.

    Args:
        model: the segmentation model forwarded to the inference helpers.
    """
    _init_session_state()
    _render_header()

    # Placeholder kept above the map/controls; filled at the end of this run.
    result_slot = st.container()

    col_1, col_2 = st.columns([0.5, 0.5])
    with col_1:
        selected_latitude, selected_longitude = _render_map()
    with col_2:
        tab_timelapse, tab_single = st.tabs(["TimeLapse", "Single Image"])
        with tab_timelapse:
            submit = _render_timelapse_tab(selected_latitude, selected_longitude)
        with tab_single:
            submit2 = _render_single_image_tab(selected_latitude, selected_longitude)

    if submit:
        fig = inference_on_location(
            model,
            st.session_state["lat"],
            st.session_state["long"],
            st.session_state["start_date"],
            st.session_state["end_date"],
            st.session_state["segment_interval"],
        )
        st.session_state["infered"] = True
        st.session_state["previous_fig"] = fig
    if submit2:
        fig = inference_on_location_and_month(
            model,
            st.session_state["lat_2"],
            st.session_state["long_2"],
            st.session_state["date_2"],
        )
        st.session_state["infered"] = True
        st.session_state["previous_fig"] = fig

    # Show the latest figure (new or from an earlier run of this session).
    if st.session_state["infered"]:
        result_slot.plotly_chart(st.session_state["previous_fig"], use_container_width=True)
if __name__ == "__main__":
    # Streamlit re-executes this whole script on every widget interaction.
    # Rebuilding the model each time re-reads the config and the checkpoint.
    # Where available (Streamlit >= 1.18), st.cache_resource keeps one model
    # per process, shared by all sessions. Older versions fall back to
    # loading it on every run, as before.
    if hasattr(st, "cache_resource"):
        model = st.cache_resource(init_app)("my_train_config.yml")
    else:
        model = init_app("my_train_config.yml")
    app(model)