"""Streamlit demo for the GAIA GrowSeg vineyard segmentation model.

The user either selects a precomputed example map (a static HTML file rendered
as-is) or uploads a ``.tif`` raster, which is processed together with the model
by ``process_raster_and_overlays``. The resulting overlays are shown on an
interactive folium map.
"""
import os
from datetime import datetime
from pathlib import Path

import torch
import folium
import streamlit as st
from loguru import logger
from tqdm import tqdm
from streamlit_folium import st_folium
from transformers import SegformerForSemanticSegmentation

from lib.folium import (
    get_clean_rendering_container,
    create_map,
    process_raster_and_overlays,
)
import streamlit.components.v1 as components

# Directory where uploaded rasters are written before processing.
UPLOAD_DIR = "temp"

# Page configs (must be the first Streamlit command in the script)
st.set_page_config(page_title="GrowSeg Demo", page_icon="🍇", layout="wide")

# BUGFIX (https://discuss.streamlit.io/t/message-error-about-torch/90886/6)
torch.classes.__path__ = []

# Interoperability with tqdm (https://loguru.readthedocs.io/en/stable/resources/recipes.html#interoperability-with-tqdm-iterations)
logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, format="{message}")


@st.cache_resource
def load_model(hf_path='links-ads/gaia-growseg'):
    """Load the GrowSeg SegFormer model from the Hugging Face Hub.

    Wrapped in ``st.cache_resource`` so the model is loaded once and reused
    across Streamlit reruns instead of being reloaded on every interaction.

    Args:
        hf_path: Hub repository id of the model.

    Returns:
        The model in eval mode, on CUDA when available, otherwise on CPU.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info("Loading GAIA GrowSeg on {}...", device)
    model = SegformerForSemanticSegmentation.from_pretrained(
        hf_path,
        num_labels=1,
        num_channels=3,
        id2label={1: 'vine'},
        label2id={'vine': 1},
        # Optional access token from the environment (None when unset).
        token=os.getenv('hf_read_access_token'),
    )
    return model.to(device).eval()


# Load GAIA GrowSeg model
model = load_model()


def change_key():
    """Store a fresh ``key_map`` in session state.

    The folium component below is keyed on ``st.session_state["key_map"]``.
    Giving it a new key makes Streamlit build a new component instance. That
    instance is rendered with the bounds set in this script, which recenters
    the map.
    """
    st.session_state["key_map"] = str(datetime.now())


# Create selection menu
container_predictions = st.container(border=True)
with container_predictions:
    col1, col2 = st.columns([0.3, 0.7])

    with col1:
        precomputed_map_path = None
        raster_path = None
        raster_selection = st.selectbox(
            "Select an example or your own raster...",
            options=[
                "Italy",
                "Portugal",
                "Spain",
                "Upload file...",
            ],
            key="raster_selection_block",
            index=None,
            placeholder="Choose an example or upload your own raster",
        )
        if raster_selection == "Italy":
            st.markdown("At this stage, only Portugal is available due to the WebSocket payload limit.")
            # TODO GEOSERVER
            # precomputed_map_path = "data/italy_2022-06-13_cropped.html"
        elif raster_selection == "Portugal":
            precomputed_map_path = "data/portugal_2023-08-01.html"
        elif raster_selection == "Spain":
            st.markdown("At this stage, only Portugal is available due to the WebSocket payload limit.")
            # precomputed_map_path = "data/spain_2022-07-29_cropped.html"
        elif raster_selection == "Upload file...":
            uploaded_file = st.file_uploader(
                "Upload a raster file",
                type=["tif"],
                key="uploaded_file_block",
            )
            if uploaded_file is not None:
                # Keep only the base name so a crafted filename cannot
                # point outside UPLOAD_DIR.
                fn = Path(uploaded_file.name).name
                logger.info("Received upload: {}", fn)
                # The upload directory may not exist on a fresh deployment.
                os.makedirs(UPLOAD_DIR, exist_ok=True)
                # NOTE(review): the path depends only on the file name, so
                # uploads sharing a name (e.g. from concurrent sessions)
                # overwrite each other. Confirm this is acceptable.
                raster_path = os.path.join(UPLOAD_DIR, fn)
                with open(raster_path, "wb") as f:
                    f.write(uploaded_file.getbuffer())

    is_raster_path_selected = raster_path is not None
    is_precomputed_map_selected = precomputed_map_path is not None

    with col2:
        with st.container():
            st.write("######")
            with st.expander("More info on the model"):
                st.write(
                    "Under the hood, this model is a SegFormer-b5, trained on "
                    "UAV-acquired vineyard orthoimages and their ground-truth "
                    "delineation masks. Paper will be available soon. Stay tuned!"
                )

# Process the uploaded raster (only when no precomputed map is selected).
overlays = []
if not is_precomputed_map_selected and is_raster_path_selected:
    progress_bar = st.progress(0, text="Begin processing...")
    overlays = process_raster_and_overlays(raster_path, model, _progress_bar=progress_bar)

container = st.empty()

# Draw map and add any overlays produced above
interactive_map = create_map()
for overlay in overlays:
    overlay.add_to(interactive_map)

with container.form(key="form1"):
    if is_precomputed_map_selected:
        # Precomputed maps are static HTML exports. Folium/branca save them
        # as UTF-8, so read them explicitly as UTF-8 instead of relying on
        # the platform default encoding.
        with open(precomputed_map_path, "r", encoding="utf-8") as f:
            html_content = f.read()
        components.html(html_content, height=500)
    else:
        if overlays:
            # Center map on the first overlay
            interactive_map.fit_bounds(overlays[0].get_bounds())
        else:
            # Nothing to show yet (or processing produced no overlays):
            # center map on Europe
            interactive_map.fit_bounds([[35, -10], [60, 40]])

        # Add Layer Control, first removing any existing one so the map does
        # not end up with duplicate controls.
        for key, child in list(interactive_map._children.items()):
            if isinstance(child, folium.map.LayerControl):
                del interactive_map._children[key]
        folium.LayerControl().add_to(interactive_map)

        # Folium Map component. The key comes from session state; changing it
        # (see change_key) is a workaround to force the map to recenter.
        output_map = st_folium(
            interactive_map,
            width=None,
            height=500,
            returned_objects=["all_drawings"],
            key=st.session_state.get("key_map", "key_map"),
        )

    # Recenter map: the callback assigns a new component key before the rerun.
    submit = st.form_submit_button("Recenter map", on_click=change_key)