import os

import folium
import numpy as np
import rioxarray as rxr
import streamlit as st
import torch
import xarray as xr
from loguru import logger

from .utils import compute_mask, compute_vdi, compute_vndvi


@st.cache_resource
def create_map(location=(41.9099533, 12.3711879), zoom_start=5, crs=3857, max_zoom=23):
    """Create a folium map with OpenStreetMap tiles and an optional Esri.WorldImagery basemap.

    Args:
        location: ``(lat, lon)`` of the initial map centre. The default is a tuple so it
            cannot be mutated between calls.
        zoom_start: Initial zoom level.
        crs: Map CRS, either an EPSG code (``3857``) or a folium CRS name such as
            ``"EPSG3857"`` (no colon). Only Web Mercator is supported for now.
        max_zoom: Maximum zoom level of the map.

    Returns:
        folium.Map: Map with OpenStreetMap as the initially active base layer and
        Esri.WorldImagery registered as an alternative, initially hidden, base layer.

    Raises:
        AssertionError: If ``crs`` is not Web Mercator.
    """
    if isinstance(crs, int):
        # folium identifies its built-in CRSs by names like "EPSG3857" (no colon).
        crs = f"EPSG{crs}"
    assert crs in ["EPSG3857"], f"Only EPSG:3857 supported for now. Got {crs}."

    m = folium.Map(
        location=location,
        zoom_start=zoom_start,
        crs=crs,
        max_zoom=max_zoom,
        tiles="OpenStreetMap",
        attributionControl=False,
        prefer_canvas=True,
    )
    # overlay=False registers this as a base layer (a radio button in a LayerControl,
    # which is not added here); show=False keeps OpenStreetMap active when the map opens.
    folium.TileLayer(
        tiles="Esri.WorldImagery",
        show=False,
        overlay=False,
        control=True,
    ).add_to(m)
    return m


def get_clean_rendering_container(app_state: str):
    """Return a container that starts from a clean slate whenever ``app_state`` changes.

    Two ``st.empty()`` placeholders ("a" and "b") are created on every run, and only one
    of them is handed out. When ``app_state`` differs from the value seen on the previous
    call, the active slot is flipped. New content then goes into the placeholder that was
    unused last run, while the previously used one stays empty on this run. This is
    presumably so that elements from the previous state are not reused.

    Args:
        app_state: Identifier of the current app state; any change flips the slot.

    Returns:
        The ``.container()`` of the active placeholder.
    """
    slot_in_use = st.session_state.get("slot_in_use", "a")
    st.session_state.slot_in_use = slot_in_use
    # On the very first call "previous_state" is missing; defaulting to app_state
    # makes the comparison False, so no flip happens.
    if app_state != st.session_state.get("previous_state", app_state):
        slot_in_use = "b" if slot_in_use == "a" else "a"
        st.session_state.slot_in_use = slot_in_use
    st.session_state.previous_state = app_state
    # Both placeholders are created on every run, always in the same order, so their
    # positions on the page stay stable regardless of which one is active.
    slots = {
        "a": st.empty(),
        "b": st.empty(),
    }
    return slots[slot_in_use].container()


def create_image_overlay(raster_path_or_array, name="Raster", opacity=1.0, to_crs=4326, show=True):
    """Create a folium image overlay from a raster filepath or xarray.DataArray.

    Args:
        raster_path_or_array: Path (``str`` or ``os.PathLike``) of a raster readable by
            ``rioxarray.open_rasterio``, or an already opened DataArray with ``band``,
            ``y`` and ``x`` dims and a CRS written on it.
        name: Layer name shown in folium's layer control.
        opacity: Overlay opacity.
        to_crs: EPSG code (int) to reproject the raster to when its CRS differs.
        show: Whether the layer is visible when the map opens.

    Returns:
        folium.raster_layers.ImageOverlay: Overlay whose image is the raster's pixels in
        ``(y, x, band)`` order, placed at the raster's bounds.
    """
    if isinstance(raster_path_or_array, (str, os.PathLike)):
        r = rxr.open_rasterio(raster_path_or_array)
    else:
        r = raster_path_or_array

    # Fill value for pixels not covered after reprojection: the raster's own nodata, or 0.
    # Passing it explicitly avoids rioxarray's dtype-based default (255 for uint8).
    # NOTE(review): presumably 0 is wanted so RGBA rasters get alpha 0 (transparent) there.
    nodata = r.rio.nodata or 0
    if r.rio.crs.to_epsg() != to_crs:
        r = r.rio.reproject(to_crs, nodata=nodata)

    # folium expects image arrays as (rows, cols, channels).
    r = r.transpose("y", "x", "band")
    left, bottom, right, top = r.rio.bounds()

    return folium.raster_layers.ImageOverlay(
        image=r.to_numpy(),
        name=name,
        # folium bounds format: [[bottom, left], [top, right]]
        bounds=[[bottom, left], [top, right]],
        opacity=opacity,
        interactive=True,
        cross_origin=False,
        zindex=1,
        show=show,
    )


def _derived_path(raster_path, suffix):
    """Return the path of a product stored next to ``raster_path``: ``a.tif`` -> ``a<suffix>.tif``.

    Only the file extension is split off (rather than using ``str.replace('.tif', ...)``).
    The result therefore always differs from ``raster_path``, even for ``.TIF`` or
    extension-less names, and directory names containing ``.tif`` are left untouched.
    """
    stem, ext = os.path.splitext(raster_path)
    return f"{stem}{suffix}{ext}"


def _georeference_like(data, reference):
    """Wrap a ``(band, y, x)`` array as a DataArray on ``reference``'s grid, CRS and transform.

    Band coordinates are numbered 1..N from ``data`` itself. The product may therefore
    have a different band count than ``reference`` (e.g. an RGBA product over an RGB image).
    """
    georeferenced = xr.DataArray(
        data,
        dims=('band', 'y', 'x'),
        coords={'x': reference.x, 'y': reference.y, 'band': np.arange(1, data.shape[0] + 1)},
    )
    georeferenced.rio.write_crs(reference.rio.crs, inplace=True)  # Copy CRS
    georeferenced.rio.write_transform(reference.rio.transform(), inplace=True)  # Copy affine transform
    return georeferenced


def _report_progress(progress_bar, percent, text):
    """Advance the optional Streamlit progress bar; does nothing when ``progress_bar`` is falsy."""
    if progress_bar:
        progress_bar.progress(percent, text=text)


@st.cache_resource
def process_raster_and_overlays(
    raster_path: str,
    _model: torch.nn.Module,
    patch_size=512,
    stride=256,
    scaling_factor=None,
    rotate=False,
    batch_size=16,
    window_size=360,
    dilate_rows=False,
    _progress_bar=None,
):
    """Build folium overlays for a drone orthoimage and its vineyard products.

    Each product (vine mask, vNDVI on rows, vNDVI on inter-rows, VDI) is handled the same
    way. If a raster exists next to ``raster_path`` (``<stem>_mask<ext>``,
    ``<stem>_vndvi_rows<ext>``, ``<stem>_vndvi_interrows<ext>``, ``<stem>_vdi<ext>``), it is
    loaded. Otherwise the product is computed with ``compute_mask`` / ``compute_vndvi`` /
    ``compute_vdi`` and georeferenced like the orthoimage. This function does not itself
    write computed products to disk. All layers are then reprojected to EPSG:4326 when the
    orthoimage is in another CRS, and turned into image overlays.

    Args:
        raster_path: Path to the orthoimage GeoTIFF.
        _model: Segmentation model passed to ``compute_mask``. (With ``st.cache_resource``,
            underscore-prefixed arguments are not hashed into the cache key.)
        patch_size, stride, scaling_factor, rotate, batch_size: Forwarded to ``compute_mask``.
        window_size: Forwarded to ``compute_vndvi`` and ``compute_vdi``.
        dilate_rows: Forwarded to ``compute_vndvi``.
        _progress_bar: Optional Streamlit progress bar, updated with values from 0 to 100.

    Returns:
        list: ``[orthoimage, mask, vNDVI rows, vNDVI interrows, VDI]`` overlays; only the
        orthoimage overlay is shown initially.

    Raises:
        AssertionError: If a stored mask exists but one of the other stored products does not.
    """
    # Define paths for mask, vNDVI, and VDI
    mask_path = _derived_path(raster_path, '_mask')
    vndvi_rows_path = _derived_path(raster_path, '_vndvi_rows')
    vndvi_interrows_path = _derived_path(raster_path, '_vndvi_interrows')
    vdi_path = _derived_path(raster_path, '_vdi')

    if os.path.exists(mask_path):
        # vNDVI and VDI can only be computed from a freshly computed `mask` array (the stored
        # RGBA mask is not converted back). A stored mask therefore requires the other
        # products too; otherwise `mask` would be unbound below.
        missing = [p for p in (vndvi_rows_path, vndvi_interrows_path, vdi_path) if not os.path.exists(p)]
        assert not missing, f"Found mask at {mask_path!r} but not the other products: {missing}"
        logger.info(f"Found mask at {mask_path!r}, vNDVI at {vndvi_rows_path!r} and {vndvi_interrows_path!r}, and VDI at {vdi_path!r}. Loading...")

    # Read raster
    logger.info(f'Reading raster image {raster_path!r}...')
    _report_progress(_progress_bar, 0, f'Reading raster image {raster_path!r}...')
    raster = rxr.open_rasterio(raster_path)

    # Compute mask
    logger.info('### Computing mask...')
    _report_progress(_progress_bar, 10, '### Computing mask...')
    if os.path.exists(mask_path):
        mask_raster = rxr.open_rasterio(mask_path)  # mask is RGBA (red for vine)
    else:
        mask = compute_mask(
            raster.to_numpy(),
            _model,
            patch_size=patch_size,
            stride=stride,
            scaling_factor=scaling_factor,
            rotate=rotate,
            batch_size=batch_size,
        )
        # mask is a HxW uint8 array with 0=background, 255=vine, 1=nodata.
        # Convert it to RGBA: red channel = mask, alpha = 0 on nodata, 255 elsewhere.
        alpha = ((mask != 1) * 255).astype(np.uint8)
        zeros = np.zeros_like(mask)
        mask_colored = np.stack([mask, zeros, zeros, alpha], axis=0)  # 4xHxW

        logger.info('Georeferencing mask...')
        _report_progress(_progress_bar, 30, 'Georeferencing mask...')
        mask_raster = _georeference_like(mask_colored, raster)

    # Compute vNDVI
    logger.info('### Computing vNDVI...')
    _report_progress(_progress_bar, 35, '### Computing vNDVI...')
    if os.path.exists(vndvi_rows_path) and os.path.exists(vndvi_interrows_path):
        vndvi_rows_raster = rxr.open_rasterio(vndvi_rows_path)  # vNDVI is RGBA
        vndvi_interrows_raster = rxr.open_rasterio(vndvi_interrows_path)  # vNDVI is RGBA
    else:
        # Both results are indexed as HxWxC (transposed to CxHxW below).
        vndvi_rows, vndvi_interrows = compute_vndvi(
            raster.to_numpy(), mask, dilate_rows=dilate_rows, window_size=window_size
        )
        logger.info('Georeferencing vNDVI...')
        _report_progress(_progress_bar, 55, 'Georeferencing vNDVI...')
        vndvi_rows_raster = _georeference_like(vndvi_rows.transpose(2, 0, 1), raster)
        vndvi_interrows_raster = _georeference_like(vndvi_interrows.transpose(2, 0, 1), raster)

    # Compute VDI
    logger.info('### Computing VDI...')
    _report_progress(_progress_bar, 60, '### Computing VDI...')
    if os.path.exists(vdi_path):
        vdi_raster = rxr.open_rasterio(vdi_path)  # VDI is RGBA
    else:
        vdi = compute_vdi(raster.to_numpy(), mask, window_size=window_size)  # HxWxC
        logger.info('Georeferencing VDI...')
        _report_progress(_progress_bar, 80, 'Georeferencing VDI...')
        vdi_raster = _georeference_like(vdi.transpose(2, 0, 1), raster)

    # Reproject all rasters to EPSG:4326. nodata=0 overrides rioxarray's dtype-based
    # default fill value (255 for uint8).
    if raster.rio.crs.to_epsg() != 4326:
        reproject_message = "Reprojecting rasters to EPSG:4326 with NODATA value 0..."
        logger.info(reproject_message)
        _report_progress(_progress_bar, 82, reproject_message)
        raster, mask_raster, vndvi_rows_raster, vndvi_interrows_raster, vdi_raster = (
            layer.rio.reproject("EPSG:4326", nodata=0)
            for layer in (raster, mask_raster, vndvi_rows_raster, vndvi_interrows_raster, vdi_raster)
        )

    # Create overlays; only the orthoimage is visible when the map opens.
    logger.info('Creating RGB raster overlay...')
    _report_progress(_progress_bar, 85, 'Creating overlays: drone image...')
    raster_overlay = create_image_overlay(raster, name="Orthoimage", opacity=1.0, show=True)

    logger.info('Creating mask overlay...')
    _report_progress(_progress_bar, 88, 'Creating overlays: mask...')
    mask_overlay = create_image_overlay(mask_raster, name="Mask", opacity=1.0, show=False)

    logger.info('Creating vNDVI rows overlay...')
    _report_progress(_progress_bar, 91, 'Creating overlays: vNDVI (rows)...')
    vndvi_rows_overlay = create_image_overlay(vndvi_rows_raster, name="vNDVI Rows", opacity=1.0, show=False)

    logger.info('Creating vNDVI interrows overlay...')
    _report_progress(_progress_bar, 94, 'Creating overlays: vNDVI (interrows)...')
    vndvi_interrows_overlay = create_image_overlay(vndvi_interrows_raster, name="vNDVI Interrows", opacity=1.0, show=False)

    logger.info('Creating VDI overlay...')
    _report_progress(_progress_bar, 97, 'Creating overlays: VDI...')
    vdi_overlay = create_image_overlay(vdi_raster, name="VDI", opacity=1.0, show=False)

    logger.info('Done!')
    _report_progress(_progress_bar, 100, 'Done!')

    return [raster_overlay, mask_overlay, vndvi_rows_overlay, vndvi_interrows_overlay, vdi_overlay]