Biomap / biomap /streamlit_app.py
jeremyLE-Ekimetrics's picture
fix water
4debc65
import streamlit as st
from streamlit_folium import st_folium
import folium
import logging
import sys
import hydra
from plot_functions import *
import hydra
import torch
from model import LitUnsupervisedSegmenter
from helper import inference_on_location_and_month, inference_on_location
# Initial map centre; also the coordinates pre-filled in the latitude/longitude
# inputs until the user clicks on the map.
DEFAULT_LATITUDE = 48.81
DEFAULT_LONGITUDE = 2.98
# Initial zoom level handed to folium.Map(zoom_start=...).
DEFAULT_ZOOM = 5
# Bounds of the selectable years; consumed as range(MIN_YEAR, MAX_YEAR), so
# MAX_YEAR itself is excluded from the choices.
MIN_YEAR = 2018
MAX_YEAR = 2024
# Width and height passed to st_folium() when embedding the map.
FOLIUM_WIDTH = 925
FOLIUM_HEIGHT = 300
# Use Streamlit's "wide" page layout for the whole app.
st.set_page_config(layout="wide")
@st.cache_resource
def init_cfg(cfg_name):
    """Compose and return the Hydra configuration named *cfg_name*.

    The configuration is looked up in the ``configs`` directory (relative to
    this file) under the Hydra job name ``corine``. The result is cached by
    ``st.cache_resource``, so the body normally runs once per ``cfg_name``.

    Args:
        cfg_name: Name of the Hydra config to compose
            (e.g. ``"my_train_config.yml"``).

    Returns:
        The composed configuration object returned by ``hydra.compose``.
    """
    # Local import: only this function needs access to Hydra's global state.
    from hydra.core.global_hydra import GlobalHydra

    # hydra.initialize() refuses to run when Hydra's global state already
    # exists. That happens whenever this body is executed again in the same
    # process (different cfg_name, cleared Streamlit cache, edited source), so
    # reset the state first to make the call repeatable.
    if GlobalHydra.instance().is_initialized():
        GlobalHydra.instance().clear()

    # Keep this call directly in this function: Hydra resolves config_path
    # relative to the file of the caller of initialize().
    hydra.initialize(config_path="configs", job_name="corine")
    return hydra.compose(config_name=cfg_name)
@st.cache_resource
def init_app(cfg_name) -> LitUnsupervisedSegmenter:
    """Configure logging, build the segmentation model and load its weights.

    Cached with ``st.cache_resource`` so that logging handlers are installed
    and the checkpoint is read only once per ``cfg_name`` rather than on every
    Streamlit rerun.

    Args:
        cfg_name: Name of the Hydra config forwarded to ``init_cfg``.

    Returns:
        A ``LitUnsupervisedSegmenter`` on CPU with the weights from
        ``biomap/checkpoint/model/model.pt`` loaded.
    """
    # The encoding must be given to the FileHandler itself: basicConfig() only
    # applies its ``encoding`` argument when it creates the handler from
    # ``filename``; it is not applied to handlers passed in explicitly.
    log_handlers = [
        logging.FileHandler(filename="biomap.log", encoding="utf-8"),
        logging.StreamHandler(stream=sys.stdout),
    ]
    logging.basicConfig(
        handlers=log_handlers,
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    cfg = init_cfg(cfg_name)
    logging.info("config : %s", cfg)

    nbclasses = cfg.dir_dataset_n_classes
    model = LitUnsupervisedSegmenter(nbclasses, cfg)
    model = model.cpu()
    logging.info("Model Initialiazed")

    model_path = "biomap/checkpoint/model/model.pt"
    # NOTE(review): torch.load() unpickles the file, so the checkpoint must
    # come from a trusted source.
    saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    logging.info("Model weights Loaded")
    model.load_state_dict(saved_state_dict)
    return model
def _render_intro():
    """Render the centred page title and the description of the scoring."""
    centered_html_blocks = (
        "<h1 style='text-align: center;'>🐒 Biomap by Ekimetrics 🐒</h1>",
        "<h2 style='text-align: center;'>Estimate Biodiversity in the world with the help of land cover.</h2>",
        "<p style='text-align: center;'>The segmentation model is an association of UNet and DinoV1 trained on the dataset CORINE. Land use is divided into 6 differents classes : Each class is assigned a GBS score from 0 to 1</p>",
        "<p style='text-align: center;'>Buildings : 0.1 | Infrastructure : 0.1 | Cultivation : 0.4 | Wetland : 0.9 | Water : 0.9 | Natural green : 1 </p>",
        "<p style='text-align: center;'>The score is then averaged on the full image.</p>",
    )
    for html in centered_html_blocks:
        st.markdown(html, unsafe_allow_html=True)


def _pick_location_on_map():
    """Render the clickable map and return the chosen (latitude, longitude).

    Falls back to the default coordinates until the user has clicked the map.
    """
    m = folium.Map(location=[DEFAULT_LATITUDE, DEFAULT_LONGITUDE], zoom_start=DEFAULT_ZOOM)
    m.add_child(folium.LatLngPopup())
    f_map = st_folium(m, width=FOLIUM_WIDTH, height=FOLIUM_HEIGHT)
    last_clicked = f_map.get("last_clicked")
    if last_clicked:
        return last_clicked["lat"], last_clicked["lng"]
    return DEFAULT_LATITUDE, DEFAULT_LONGITUDE


def app(model):
    """Render the Biomap page and run a prediction when a button is clicked.

    The page shows an introduction, a slot for the latest prediction figure,
    a clickable map (left) and two tabs of inputs (right): "TimeLapse" for a
    range of years and "Single Image" for one date.

    Every widget value is mirrored into ``st.session_state`` under the same
    keys as before (``submit``, ``lat``, ``long``, ``start_date``,
    ``end_date``, ``segment_interval``, ``submit2``, ``lat_2``, ``long_2``,
    ``date_2``), and the last figure is kept under ``previous_fig`` so it
    stays visible across reruns.

    Args:
        model: The segmentation model forwarded to the inference helpers.
    """
    for flag in ("infered", "submit", "submit2"):
        if flag not in st.session_state:
            st.session_state[flag] = False

    _render_intro()

    # Reserve the chart's position above the inputs. The inputs must be read
    # before inference can run, yet the figure belongs at the top of the page;
    # a container lets us fill this slot later in the same run. Driving the
    # inference from session flags instead would only react one rerun after
    # the click, because the flags are written at the end of a run.
    results_area = st.container()

    col_1, col_2 = st.columns([0.5, 0.5])
    with col_1:
        selected_latitude, selected_longitude = _pick_location_on_map()

    with col_2:
        tabs1, tabs2 = st.tabs(["TimeLapse", "Single Image"])

        with tabs1:
            submit = st.button("Predict TimeLapse", use_container_width=True, type="primary")
            st.session_state["submit"] = submit

            col_tab1_1, col_tab1_2 = st.columns(2)
            with col_tab1_1:
                lat = st.text_input("latitude", value=selected_latitude)
                st.session_state["lat"] = lat
            with col_tab1_2:
                long = st.text_input("longitude", value=selected_longitude)
                st.session_state["long"] = long

            col_tab1_11, col_tab1_22 = st.columns(2)
            years = list(range(MIN_YEAR, MAX_YEAR, 1))
            # The last year cannot be a start date: no later year would remain
            # for "End date", leaving that selectbox empty and end_date None.
            start_years = years[:-1]
            with col_tab1_11:
                start_date = st.selectbox("Start date", start_years)
                st.session_state["start_date"] = start_date
            end_years = [year for year in years if year > start_date]
            with col_tab1_22:
                end_date = st.selectbox("End date", end_years)
                st.session_state["end_date"] = end_date

            segment_interval = st.radio("Interval of time between two segmentation", options=['month', '2months', 'year'], horizontal=True)
            st.session_state["segment_interval"] = segment_interval

        with tabs2:
            submit2 = st.button("Predict Single Image", use_container_width=True, type="primary")
            st.session_state["submit2"] = submit2

            col_tab2_1, col_tab2_2 = st.columns(2)
            with col_tab2_1:
                lat_2 = st.text_input("lat.", value=selected_latitude)
                st.session_state["lat_2"] = lat_2
            with col_tab2_2:
                long_2 = st.text_input("long.", value=selected_longitude)
                st.session_state["long_2"] = long_2

            date_2 = st.text_input("date", "2021-01-01", placeholder="2021-01-01")
            st.session_state["date_2"] = date_2

    # Run the requested prediction in the same run as the click, using the
    # values the user currently sees in the widgets.
    if submit:
        fig = inference_on_location(model, lat, long, start_date, end_date, segment_interval)
        st.session_state["infered"] = True
        st.session_state["previous_fig"] = fig
    if submit2:
        fig = inference_on_location_and_month(model, lat_2, long_2, date_2)
        st.session_state["infered"] = True
        st.session_state["previous_fig"] = fig

    # Show the newest figure, or the one kept from an earlier run.
    if st.session_state["infered"]:
        results_area.plotly_chart(st.session_state["previous_fig"], use_container_width=True)
if __name__ == "__main__":
    # Build the model (cached by st.cache_resource on init_app) and hand it
    # straight to the page renderer.
    app(init_app("my_train_config.yml"))