tommonopoli commited on
Commit
03e7460
·
1 Parent(s): 674e446

load app & the rest

Browse files
Files changed (8) hide show
  1. .streamlit/config.toml +10 -0
  2. README.md +32 -6
  3. app.py +173 -0
  4. lib/folium.py +246 -0
  5. lib/utils.py +587 -0
  6. lib/viz_utils.py +125 -0
  7. precompute_examples.ipynb +358 -0
  8. requirements.txt +14 -0
.streamlit/config.toml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ [server]
2
+
3
+ # Max size, in megabytes, for files uploaded with the file_uploader.
4
+ # Default: 200
5
+ maxUploadSize = 1024
6
+
7
+ # Max size, in megabytes, of messages that can be sent via the WebSocket
8
+ # connection.
9
+ # Default: 200
10
+ maxMessageSize = 1024
README.md CHANGED
@@ -1,14 +1,40 @@
1
  ---
2
- title: Gaia Growseg Demo
3
- emoji: 🌖
4
- colorFrom: pink
5
- colorTo: pink
6
  sdk: streamlit
7
  sdk_version: 1.43.2
 
8
  app_file: app.py
 
9
  pinned: false
10
  license: mit
11
- short_description: Vineyard row segmentation from UAV imagery
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: GRowSeg demo
3
+ emoji: 🍇
4
+ colorFrom: indigo
5
+ colorTo: gray
6
  sdk: streamlit
7
  sdk_version: 1.43.2
8
+ suggested_hardware: t4-small
9
  app_file: app.py
10
+ short_description: Vineyard row segmentation from UAV imagery
11
  pinned: false
12
  license: mit
13
+ models:
14
+ - links-ads/gaia-growseg
15
+ datasets:
16
+ - links-ads/gaia-vineyard-uav-dataset
17
+ preload_from_hub:
18
+ - links-ads/gaia-growseg
19
+ tags:
20
+ - agriculture
21
+ - viticulture
22
+ - remote-sensing
23
+ - image-segmentation
24
+ - segmentation
25
+ - semantic-segmentation
26
+ - grapevines
27
+ - grapes
28
+ - vineyard
29
+ - uav
30
+ - drone
31
+ - aerial-imagery
32
+ - aerial-photography
33
+ - aerial-photos
34
+ - aerial-images
35
+ - crop
36
+ - field
37
+ - links-ads
38
  ---
39
 
40
+ GRowSeg, a deep learning model for vineyard row segmentation from UAV imagery
app.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from datetime import datetime
3
+ from pathlib import Path
4
+ import torch
5
+ import folium
6
+ import streamlit as st
7
+ from loguru import logger
8
+ from tqdm import tqdm
9
+ from streamlit_folium import st_folium
10
+ from transformers import SegformerForSemanticSegmentation
11
+ from lib.folium import (
12
+ get_clean_rendering_container,
13
+ create_map,
14
+ process_raster_and_overlays,
15
+ )
16
+ import streamlit.components.v1 as components
17
+
18
+ # Page configs
19
+ st.set_page_config(page_title="GrowSeg Demo", page_icon="🍇", layout="wide")
20
+
21
+ # BUGFIX (https://discuss.streamlit.io/t/message-error-about-torch/90886/6)
22
+ torch.classes.__path__ = []
23
+
24
+ # Interoperability with tqdm (https://loguru.readthedocs.io/en/stable/resources/recipes.html#interoperability-with-tqdm-iterations)
25
+ logger.remove()
26
+ logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, format="<green>{message}</green>")
27
+
28
+ @st.cache_resource
29
+ def load_model(hf_path='links-ads/gaia-growseg'):
30
+ # logger.info(f'Loading GAIA GRowSeg on {device}...')
31
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
32
+ model = SegformerForSemanticSegmentation.from_pretrained(
33
+ hf_path,
34
+ num_labels=1,
35
+ num_channels=3,
36
+ id2label={1: 'vine'},
37
+ label2id={'vine': 1},
38
+ token=os.getenv('hf_read_access_token')
39
+ )
40
+ return model.to(device).eval()
41
+
42
+ # Load GAIA GRowSeg model
43
+ model = load_model()
44
+
45
+
46
+ def change_key():
47
+ st.session_state["key_map"] = str(datetime.now())
48
+
49
+
50
+ # Create selection menu
51
+ container_predictions = st.container(border=True)
52
+ with container_predictions:
53
+ col1, col2 = st.columns([0.3, 0.7])
54
+ with col1:
55
+ # raster_path = st.text_input(
56
+ # "Enter the path to your local file: ",
57
+ # key="raster_path_block",
58
+ # )
59
+ # raster_path = st.file_uploader(
60
+ # "Upload a raster file",
61
+ # type=["tif", "tiff"],
62
+ # key="raster_path_block",
63
+ # )
64
+ precomputed_map_path = None
65
+ raster_path = None
66
+ raster_selection = st.selectbox(
67
+ "Select an example or your own raster...",
68
+ options=[
69
+ "Italy",
70
+ "Portugal",
71
+ "Spain",
72
+ "Upload file...",
73
+ ],
74
+ key="raster_selection_block",
75
+ index=None,
76
+ placeholder="Choose an example or upload your own raster",
77
+ )
78
+ if raster_selection == "Italy":
79
+ st.markdown("At this stage, only Portugal is available due to the WebSocket payload limit.")
80
+ # TODO GEOSERVER
81
+ #precomputed_map_path = "data/italy_2022-06-13_cropped.html"
82
+ elif raster_selection == "Portugal":
83
+ precomputed_map_path = "data/portugal_2023-08-01.html"
84
+ elif raster_selection == "Spain":
85
+ st.markdown("At this stage, only Portugal is available due to the WebSocket payload limit.")
86
+ #precomputed_map_path = "data/spain_2022-07-29_cropped.html"
87
+ elif raster_selection == "Upload file...":
88
+ uploaded_file = st.file_uploader(
89
+ "Upload a raster file",
90
+ type=["tif"],
91
+ key="uploaded_file_block",
92
+ )
93
+ if uploaded_file is not None:
94
+ fn = Path(uploaded_file.name).name
95
+ print(fn)
96
+ raster_path = os.path.join("temp", fn)
97
+ with open(raster_path, "wb") as f:
98
+ f.write(uploaded_file.getbuffer())
99
+
100
+ is_raster_path_selected = raster_path is not None
101
+ is_precomputed_map_selected = precomputed_map_path is not None
102
+
103
+ with col2:
104
+ with st.container():
105
+ st.write("######")
106
+ with st.expander("More info on the model"):
107
+ st.write("""
108
+ Under the hood, this model is a SegFormer-b5, trained on
109
+ UAV-acquired vineyard orthoimages and their ground-truth
110
+ delineation masks. Paper will be available soon. Stay tuned!
111
+ """)
112
+
113
+ if not is_precomputed_map_selected and is_raster_path_selected:
114
+ progress_bar = st.progress(0, text="Begin processing...")
115
+ # Process raster and get overlays
116
+ overlays = process_raster_and_overlays(raster_path, model, _progress_bar=progress_bar)
117
+ #progress_bar.empty()
118
+
119
+
120
+ #container = get_clean_rendering_container(raster_path)
121
+ container = st.empty()
122
+
123
+ # draw map
124
+ interactive_map = create_map()
125
+
126
+ if is_raster_path_selected:
127
+ # Add overlays to map
128
+ for overlay in overlays:
129
+ overlay.add_to(interactive_map)
130
+
131
+ with container.form(key="form1"):
132
+
133
+ if is_precomputed_map_selected:
134
+ # Load precomputed map
135
+ # interactive_map = folium.Map(location=[35, -10], zoom_start=6)
136
+ # folium.IFrame(
137
+ # precomputed_map_path,
138
+ # width=1000,
139
+ # height=500,
140
+ # ).add_to(interactive_map)
141
+
142
+ with open(precomputed_map_path, 'r') as f:
143
+ html_content = f.read()
144
+ interactive_map = components.html(html_content, height=500)
145
+
146
+
147
+ else:
148
+
149
+ if is_raster_path_selected:
150
+ # Center map on overlays
151
+ bounds = overlays[0].get_bounds()
152
+ interactive_map.fit_bounds(bounds)
153
+ else:
154
+ # Center map on Europe
155
+ interactive_map.fit_bounds([[35, -10], [60, 40]])
156
+
157
+ # Add Layer Control (first remove existing one)
158
+ for key, child in list(interactive_map._children.items()):
159
+ if isinstance(child, folium.map.LayerControl):
160
+ del interactive_map._children[key]
161
+ folium.LayerControl().add_to(interactive_map)
162
+
163
+ # Folium Map component
164
+ output_map = st_folium(
165
+ interactive_map,
166
+ width=None,
167
+ height=500,
168
+ returned_objects=["all_drawings"],
169
+ key=st.session_state.get("key_map", "key_map"), # This is a workaround to force the map to recenter
170
+ )
171
+
172
+ # Recenter map
173
+ submit = st.form_submit_button("Recenter map")
lib/folium.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import branca
2
+ import folium
3
+ import geopandas as gpd
4
+ import streamlit as st
5
+ from loguru import logger
6
+ import rioxarray as rxr
7
+ import numpy as np
8
+ import xarray as xr
9
+ import torch
10
+ from .utils import compute_mask, compute_vndvi, compute_vdi
11
+ import os
12
+
13
+
14
+
15
+ @st.cache_resource
16
+ def create_map(location=[41.9099533, 12.3711879], zoom_start=5, crs=3857, max_zoom=23):
17
+ """Create a folium map with OpenStreetMap tiles and optional Esri.WorldImagery basemap."""
18
+ if isinstance(crs, int):
19
+ crs = f"EPSG{crs}"
20
+ assert crs in ["EPSG3857"], f"Only EPSG:3857 supported for now. Got {crs}."
21
+
22
+ m = folium.Map(
23
+ location=location,
24
+ zoom_start=zoom_start,
25
+ crs=crs,
26
+ max_zoom=max_zoom,
27
+ tiles="OpenStreetMap", # Esri.WorldImagery
28
+ attributionControl=False,
29
+ prefer_canvas=True,
30
+ )
31
+
32
+ # Add Esri.WorldImagery as optional basemap (radio button)
33
+ folium.TileLayer(
34
+ tiles="Esri.WorldImagery",
35
+ show=False,
36
+ overlay=False,
37
+ control=True,
38
+ ).add_to(m)
39
+
40
+ return m
41
+
42
+
43
+
44
+ def get_clean_rendering_container(app_state: str):
45
+ """Makes sure we can render from a clean slate on state changes."""
46
+ slot_in_use = st.session_state.slot_in_use = st.session_state.get(
47
+ "slot_in_use", "a"
48
+ )
49
+ if app_state != st.session_state.get("previous_state", app_state):
50
+ if slot_in_use == "a":
51
+ slot_in_use = st.session_state.slot_in_use = "b"
52
+ else:
53
+ slot_in_use = st.session_state.slot_in_use = "a"
54
+
55
+ st.session_state.previous_state = app_state
56
+
57
+ slot = {
58
+ "a": st.empty(),
59
+ "b": st.empty(),
60
+ }[slot_in_use]
61
+
62
+ return slot.container()
63
+
64
+
65
+
66
+ def create_image_overlay(raster_path_or_array, name="Raster", opacity=1.0, to_crs=4326, show=True):
67
+ """ Create a folium image overlay from a raster filepath or xarray.DataArray. """
68
+ if isinstance(raster_path_or_array, str):
69
+ # Open the raster and its metadata
70
+ r = rxr.open_rasterio(raster_path_or_array)
71
+ else:
72
+ r = raster_path_or_array
73
+ nodata = r.rio.nodata or 0
74
+ if r.rio.crs.to_epsg() != to_crs:
75
+ r = r.rio.reproject(to_crs, nodata=nodata) # nodata default: 255
76
+ r = r.transpose("y", "x", "band")
77
+ bounds = r.rio.bounds() # (left, bottom, right, top)
78
+
79
+ # Create a folium image overlay
80
+ overlay = folium.raster_layers.ImageOverlay(
81
+ image=r.to_numpy(),
82
+ name=name,
83
+ bounds=[[bounds[1], bounds[0]], [bounds[3], bounds[2]]], # format for folium: ((bottom,left),(top,right))
84
+ opacity=opacity,
85
+ interactive=True,
86
+ cross_origin=False,
87
+ zindex=1,
88
+ show=show,
89
+ )
90
+
91
+ return overlay
92
+
93
+
94
+
95
+ @st.cache_resource
96
+ def process_raster_and_overlays(
97
+ raster_path: str,
98
+ _model: torch.nn.Module,
99
+ patch_size=512,
100
+ stride=256,
101
+ scaling_factor=None,
102
+ rotate=False,
103
+ batch_size=16,
104
+ window_size=360,
105
+ dilate_rows=False,
106
+ _progress_bar=None,
107
+ ):
108
+
109
+ # Define paths for mask, vNDVI, and VDI
110
+ mask_path = raster_path.replace('.tif', '_mask.tif')
111
+ vndvi_rows_path = raster_path.replace('.tif', '_vndvi_rows.tif')
112
+ vndvi_interrows_path = raster_path.replace('.tif', '_vndvi_interrows.tif')
113
+ vdi_path = raster_path.replace('.tif', '_vdi.tif')
114
+ if os.path.exists(mask_path):
115
+ assert os.path.exists(vndvi_rows_path)
116
+ assert os.path.exists(vndvi_interrows_path)
117
+ assert os.path.exists(vdi_path)
118
+ logger.info(f"Found mask at {mask_path!r}, vNDVI at {vndvi_rows_path!r} and {vndvi_interrows_path!r}, and VDI at {vdi_path!r}. Loading...")
119
+
120
+ # Read raster
121
+ logger.info(f'Reading raster image {raster_path!r}...')
122
+ if _progress_bar: _progress_bar.progress(0, text=f'Reading raster image {raster_path!r}...')
123
+ raster = rxr.open_rasterio(raster_path)
124
+
125
+ # Compute mask
126
+ logger.info('### Computing mask...')
127
+ if _progress_bar: _progress_bar.progress(10, text='### Computing mask...')
128
+
129
+
130
+ if os.path.exists(mask_path):
131
+ mask_raster = rxr.open_rasterio(mask_path) # mask is RGBA (red for vine)
132
+ else:
133
+ mask = compute_mask(
134
+ raster.to_numpy(),
135
+ _model,
136
+ patch_size=patch_size,
137
+ stride=stride,
138
+ scaling_factor=scaling_factor,
139
+ rotate=rotate,
140
+ batch_size=batch_size
141
+ ) # mask is a HxW uint8 array in with 0=background, 255=vine, 1=nodata
142
+
143
+ # Convert mask from grayscale to RGBA, with red pixels for vine
144
+ alpha = ((mask != 1)*255).astype(np.uint8)
145
+ mask_colored = np.stack([mask, np.zeros_like(mask), np.zeros_like(mask), alpha], axis=0) # now, mask is a 4xHxW uint8 array in with 0=background, 255=vine
146
+
147
+ # Georef mask like raster
148
+ logger.info('Georeferencing mask...')
149
+ if _progress_bar: _progress_bar.progress(30, text='Georeferencing mask...')
150
+ mask_raster = xr.DataArray(
151
+ mask_colored,
152
+ dims=('band', 'y', 'x'),
153
+ coords={'x': raster.x, 'y': raster.y, 'band': raster.band}
154
+ )
155
+ mask_raster.rio.write_crs(raster.rio.crs, inplace=True) # Copy CRS
156
+ mask_raster.rio.write_transform(raster.rio.transform(), inplace=True) # Copy affine transform
157
+
158
+ # Compute vNDVI
159
+ logger.info('### Computing vNDVI...')
160
+ if _progress_bar: _progress_bar.progress(35, text='### Computing vNDVI...')
161
+
162
+ if os.path.exists(vndvi_rows_path) and os.path.exists(vndvi_interrows_path):
163
+ vndvi_rows_raster = rxr.open_rasterio(vndvi_rows_path) # vNDVI is RGBA
164
+ vndvi_interrows_raster = rxr.open_rasterio(vndvi_interrows_path) # vNDVI is RGBA
165
+ else:
166
+ vndvi_rows, vndvi_interrows = compute_vndvi(
167
+ raster.to_numpy(),
168
+ mask,
169
+ dilate_rows=dilate_rows,
170
+ window_size=window_size
171
+ ) # vNDVI is RGBA
172
+
173
+ # Georef vNDVI like raster
174
+ logger.info('Georeferencing vNDVI...')
175
+ if _progress_bar: _progress_bar.progress(55, text='Georeferencing vNDVI...')
176
+ vndvi_rows_raster = xr.DataArray(
177
+ vndvi_rows.transpose(2, 0, 1),
178
+ dims=('band', 'y', 'x'),
179
+ coords={'x': raster.x, 'y': raster.y, 'band': raster.band}
180
+ )
181
+ vndvi_rows_raster.rio.write_crs(raster.rio.crs, inplace=True)
182
+ vndvi_rows_raster.rio.write_transform(raster.rio.transform(), inplace=True)
183
+
184
+ vndvi_interrows_raster = xr.DataArray(
185
+ vndvi_interrows.transpose(2, 0, 1),
186
+ dims=('band', 'y', 'x'),
187
+ coords={'x': raster.x, 'y': raster.y, 'band': raster.band}
188
+ )
189
+ vndvi_interrows_raster.rio.write_crs(raster.rio.crs, inplace=True)
190
+ vndvi_interrows_raster.rio.write_transform(raster.rio.transform(), inplace=True)
191
+
192
+ # Compute VDI
193
+ logger.info('### Computing VDI...')
194
+ if _progress_bar: _progress_bar.progress(60, text='### Computing VDI...')
195
+
196
+ if os.path.exists(vdi_path):
197
+ vdi_raster = rxr.open_rasterio(vdi_path) # VDI is RGBA
198
+ else:
199
+ vdi = compute_vdi(
200
+ raster.to_numpy(),
201
+ mask,
202
+ window_size=window_size
203
+ ) # VDI is RGBA
204
+
205
+ # Georef VDI like raster
206
+ logger.info('Georeferencing VDI...')
207
+ if _progress_bar: _progress_bar.progress(80, text='Georeferencing VDI...')
208
+ vdi_raster = xr.DataArray(
209
+ vdi.transpose(2, 0, 1),
210
+ dims=('band', 'y', 'x'),
211
+ coords={'x': raster.x, 'y': raster.y, 'band': raster.band}
212
+ )
213
+ vdi_raster.rio.write_crs(raster.rio.crs, inplace=True)
214
+ vdi_raster.rio.write_transform(raster.rio.transform(), inplace=True)
215
+
216
+ # Reproject all rasters to EPSG:4326
217
+ if raster.rio.crs.to_epsg() != 4326:
218
+ logger.info(f"Reprojecting rasters to EPSG:4326 with NODATA value 0...")
219
+ if _progress_bar: _progress_bar.progress(82, text=f"Reprojecting rasters to EPSG:4326 with NODATA value 0...")
220
+ raster = raster.rio.reproject("EPSG:4326", nodata=0) # nodata default: 255
221
+ mask_raster = mask_raster.rio.reproject("EPSG:4326", nodata=0)
222
+ vndvi_rows_raster = vndvi_rows_raster.rio.reproject("EPSG:4326", nodata=0)
223
+ vndvi_interrows_raster = vndvi_interrows_raster.rio.reproject("EPSG:4326", nodata=0)
224
+ vdi_raster = vdi_raster.rio.reproject("EPSG:4326", nodata=0)
225
+
226
+ # Create overlays
227
+ logger.info(f'Creating RGB raster overlay...')
228
+ if _progress_bar: _progress_bar.progress(85, text='Creating overlays: drone image...')
229
+ raster_overlay = create_image_overlay(raster, name="Orthoimage", opacity=1.0, show=True)
230
+ logger.info(f'Creating mask overlay...')
231
+ if _progress_bar: _progress_bar.progress(88, text='Creating overlays: mask...')
232
+ mask_overlay = create_image_overlay(mask_raster, name="Mask", opacity=1.0, show=False)
233
+ logger.info(f'Creating vNDVI rows overlay...')
234
+ if _progress_bar: _progress_bar.progress(91, text='Creating overlays: vNDVI (rows)...')
235
+ vndvi_rows_overlay = create_image_overlay(vndvi_rows_raster, name="vNDVI Rows", opacity=1.0, show=False)
236
+ logger.info(f'Creating vNDVI interrows overlay...')
237
+ if _progress_bar: _progress_bar.progress(94, text='Creating overlays: vNDVI (interrows)...')
238
+ vndvi_interrows_overlay = create_image_overlay(vndvi_interrows_raster, name="vNDVI Interrows", opacity=1.0, show=False)
239
+ logger.info(f'Creating VDI overlay...')
240
+ if _progress_bar: _progress_bar.progress(97, text='Creating overlays: VDI...')
241
+ vdi_overlay = create_image_overlay(vdi_raster, name="VDI", opacity=1.0, show=False)
242
+
243
+ logger.info('Done!')
244
+ if _progress_bar: _progress_bar.progress(100, text='Done!')
245
+
246
+ return [raster_overlay, mask_overlay, vndvi_rows_overlay, vndvi_interrows_overlay, vdi_overlay]
lib/utils.py ADDED
@@ -0,0 +1,587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import rasterio
4
+ import xarray as xr
5
+ import rioxarray as rxr
6
+ import cv2
7
+ from transformers import SegformerForSemanticSegmentation
8
+ from tqdm import tqdm
9
+ from scipy.ndimage import grey_dilation
10
+ import matplotlib as mpl
11
+ import matplotlib.pyplot as plt
12
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
13
+ from .viz_utils import alpha_composite
14
+ from loguru import logger
15
+
16
+
17
+
18
+ def resize(img, shape=None, scaling_factor=1., order='CHW'):
19
+ """Resize an image by a given scaling factor"""
20
+ assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
21
+ assert shape is None or scaling_factor == 1., "Got both shape and scaling_factor. Please provide only one of them"
22
+
23
+ # resize image
24
+ if order == 'CHW':
25
+ img = np.moveaxis(img, 0, -1) # CHW -> HWC
26
+
27
+ if shape is not None:
28
+ img = cv2.resize(img, shape[::-1], interpolation=cv2.INTER_LINEAR)
29
+ else:
30
+ img = cv2.resize(img, None, fx=scaling_factor, fy=scaling_factor, interpolation=cv2.INTER_LINEAR)
31
+
32
+ # NB: cv2.resize returns a HW image if the input image is HW1: restore the C dimension
33
+ if len(img.shape) == 2:
34
+ img = img[..., None]
35
+
36
+ if order == 'CHW':
37
+ img = np.moveaxis(img, -1, 0) # HWC -> CHW
38
+
39
+ return img
40
+
41
+
42
+ def minimum_needed_padding(img_size, patch_size: int, stride: int):
43
+ """
44
+ Compute the minimum padding needed to make an image divisible by a patch size with a given stride.
45
+ Args:
46
+ image_shape (tuple): the shape (H,W) of the image tensor
47
+ patch_size (int): the size of the patches to extract
48
+ stride (int): the stride to use when extracting patches
49
+ Returns:
50
+ tuple: the padding needed to make the image tensor divisible by the patch size with the given stride
51
+ """
52
+
53
+ img_size = np.array(img_size)
54
+ pad = np.where(
55
+ img_size <= patch_size,
56
+ (patch_size - img_size) % patch_size, # the % patch_size is to handle the case img_size = (0,0)
57
+ (stride - (img_size - patch_size)) % stride
58
+ )
59
+ pad_t, pad_l = pad // 2
60
+ pad_b, pad_r = pad[0] - pad_t, pad[1] - pad_l
61
+
62
+ return pad_t, pad_b, pad_l, pad_r
63
+
64
+
65
+ def pad(img, pad, order='CHW'):
66
+ """Pad an image by the given pad values, in the format (pad_t, pad_b, pad_l, pad_r)"""
67
+ assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
68
+
69
+ pad_t, pad_b, pad_l, pad_r = pad
70
+
71
+ # pad image
72
+ if order == 'HWC':
73
+ padded_img = np.pad(img, ((pad_t,pad_b), (pad_l,pad_r), (0,0)), mode='constant', constant_values=0) # can also try mode='reflect'
74
+ else:
75
+ padded_img = np.pad(img, ((0,0), (pad_t,pad_b), (pad_l,pad_r)), mode='constant', constant_values=0) # can also try mode='reflect'
76
+
77
+ if isinstance(img, torch.Tensor):
78
+ padded_img = torch.tensor(padded_img)
79
+
80
+ return padded_img
81
+
82
+
83
+ def extract_patches(img, patch_size=512, stride=256, order='CHW', only_return_idx=True):
84
+ """Extract patches from an image, in the format (h_start, h_end, w_start, w_end)"""
85
+ assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
86
+
87
+ if order == 'HWC':
88
+ H, W = img.shape[:2]
89
+ else:
90
+ H, W = img.shape[1:]
91
+
92
+ # compute the number of patches
93
+ n_patches = ((H - patch_size) // stride + 1) * ((W - patch_size) // stride + 1)
94
+
95
+ # extract patches
96
+ patches = []
97
+ patches_idx = []
98
+ for i in range(0, H-patch_size+1, stride):
99
+ for j in range(0, W-patch_size+1, stride):
100
+
101
+ patches_idx.append((i, i+patch_size, j, j+patch_size))
102
+
103
+ if not only_return_idx:
104
+ if order == 'HWC':
105
+ patch = img[i:i+patch_size, j:j+patch_size, :]
106
+ else:
107
+ patch = img[:, i:i+patch_size, j:j+patch_size]
108
+ patches.append(patch)
109
+
110
+ if only_return_idx:
111
+ return patches_idx
112
+ return patches, patches_idx
113
+
114
+
115
+ def segment_batch(batch, model):
116
+
117
+ # perform prediction
118
+ with torch.no_grad():
119
+ out = model(batch) # (n_patches, 1, H, W) logits
120
+ if isinstance(model, SegformerForSemanticSegmentation):
121
+ out = upsample(out.logits, size=batch.shape[-2:])
122
+
123
+ # apply sigmoid
124
+ out = torch.sigmoid(out) # logits -> confidence scores
125
+
126
+ return out
127
+
128
+
129
+ def upsample(x, size):
130
+ """Upsample a 3D/4D/5D tensor"""
131
+ return torch.nn.functional.interpolate(x, size=size, mode='bilinear', align_corners=False)
132
+
133
+
134
+ def merge_patches(patches, patches_idx, rotate=False, canvas_shape=None, order='CHW'): # TODO
135
+ """Merge patches into a single image"""
136
+ assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
137
+ if rotate:
138
+ axes_to_rotate = (0,1) if order == 'HWC' else (1,2)
139
+ patches = [np.rot90(p, -i, axes=axes_to_rotate) for i,p in enumerate(patches)]
140
+ else:
141
+ assert len(patches) == len(patches_idx), f"Got {len(patches)} patches and {len(patches_idx)} indexes"
142
+
143
+ # if canvas_shape is None, infer it from patches_idx
144
+ if canvas_shape is None:
145
+ patches_idx_zipped = list(zip(*patches_idx))
146
+ canvas_H = max(patches_idx_zipped[1])
147
+ canvas_W = max(patches_idx_zipped[3])
148
+ else:
149
+ canvas_H, canvas_W = canvas_shape
150
+
151
+ # initialize canvas
152
+ dtype = patches[0].dtype
153
+ if order == 'HWC':
154
+ canvas_C = patches[0].shape[-1]
155
+ canvas = np.zeros((canvas_H, canvas_W, canvas_C), dtype=dtype) # HWC
156
+ n_overlapping_patches = np.zeros((canvas_H, canvas_W, 1))
157
+ else:
158
+ canvas_C = patches[0].shape[0]
159
+ canvas = np.zeros((canvas_C, canvas_H, canvas_W, ), dtype=dtype) # CHW
160
+ n_overlapping_patches = np.zeros((1, canvas_H, canvas_W))
161
+
162
+ # merge patches
163
+ for p, (t,b,l,r) in zip(patches, patches_idx):
164
+ if order == 'HWC':
165
+ canvas[t:b, l:r, :] += p
166
+ n_overlapping_patches[t:b, l:r, 0] += 1
167
+ else:
168
+ canvas[:, t:b, l:r] += p
169
+ n_overlapping_patches[0, t:b, l:r] += 1
170
+
171
+
172
+ # compute average
173
+ canvas = np.divide(canvas, n_overlapping_patches, where=(n_overlapping_patches != 0))
174
+
175
+ return canvas
176
+
177
+
178
+ def segment(img, model, patch_size=512, stride=256, scaling_factor=1., rotate=False, device=None, batch_size=16, verbose=False):
179
+ """Segment an RGB image by using a segmentation model. Returns a probability
180
+ map (and performance metrics, if requested)"""
181
+
182
+ # some checks
183
+ assert isinstance(img, np.ndarray), f"Input must be a numpy array. Got {type(img)}"
184
+ assert img.shape[0] in [3,4], f"Input image must be formatted as CHW, with C = 3,4. Got a shape of {img.shape}"
185
+ assert img.dtype == np.uint8, f"Input image must be a numpy array with dtype np.uint8. Got {img.dtype}"
186
+
187
+ # prepare model for evaluation
188
+ model = model.to(device)
189
+ model.eval()
190
+
191
+ # prepare alpha channel
192
+ original_shape = img.shape
193
+ if img.shape[0] == 3:
194
+ # create dummy alpha channel
195
+ alpha = np.full(original_shape[1:], 255, dtype=np.uint8)
196
+ else:
197
+ # extract alpha channel
198
+ img, alpha = img[:3], img[3]
199
+
200
+ # resize image
201
+ img = resize(img, scaling_factor=scaling_factor)
202
+
203
+ # pad image
204
+ pad_t, pad_b, pad_l, pad_r = minimum_needed_padding(img.shape[1:], patch_size, stride)
205
+ padded_img = pad(img, pad=(pad_t, pad_b, pad_l, pad_r))
206
+ padded_shape = padded_img.shape
207
+
208
+ # extract patches indexes
209
+ patches_idx = extract_patches(padded_img, patch_size=patch_size, stride=stride)
210
+
211
+ ### segment
212
+ masks = []
213
+ masks_idx = []
214
+
215
+ batch = []
216
+ for i, p_idx in enumerate(tqdm(patches_idx, disable=not verbose, desc="Predicting...", total=len(patches_idx))):
217
+ t, b, l, r = p_idx
218
+
219
+ # extract patch
220
+ patch = padded_img[:, t:b, l:r]
221
+
222
+ # consider patch only if it is valid (i.e. not all black or all white)
223
+ if np.any(patch != 0) and np.any(patch != 255):
224
+
225
+ # convert patch to torch.tensor with float32 values in [0,1] (as required by torch)
226
+ patch = torch.tensor(patch).float() / 255.
227
+
228
+ # normalize patch with ImageNet mean and std
229
+ patch = (patch - torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)) / torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
230
+
231
+ # add patch to batch
232
+ batch.append(patch)
233
+ masks_idx.append(p_idx)
234
+
235
+ # (optional) for each patch extracted, consider also its rotated versions
236
+ if rotate:
237
+ for rot in range(1,4):
238
+ patch = torch.rot90(patch, rot, dims=[1,2])
239
+ batch.append(patch)
240
+ masks_idx.append(p_idx)
241
+
242
+ # if the batch is full, perform prediction
243
+ if len(batch) >= batch_size or i == len(patches_idx)-1:
244
+
245
+ # move batch to GPU
246
+ batch = torch.stack(batch).to(device)
247
+
248
+ # perform prediction
249
+ out = segment_batch(batch, model)
250
+
251
+ # append predictions to masks
252
+ masks.append(out.cpu().numpy())
253
+
254
+ # reset batch
255
+ batch = []
256
+
257
+ # concatenate predictions
258
+ masks = np.concatenate(masks) # (n_patches, 1, H, W)
259
+
260
+ # merge patches
261
+ mask = merge_patches(masks, masks_idx, rotate=rotate, canvas_shape=padded_shape[1:]) # (1, H, W)
262
+
263
+ # undo padding
264
+ mask = mask[:, pad_t:padded_shape[1]-pad_b, pad_l:padded_shape[2]-pad_r]
265
+
266
+ # resize mask to original shape
267
+ mask = resize(mask, shape=original_shape[1:])
268
+
269
+ # apply alpha channel, i.e. set to -1 the pixels where alpha is 0
270
+ mask = np.where(alpha == 0, -1, mask)
271
+
272
+ return mask.squeeze()
273
+
274
+
275
+
276
+
277
+
278
+
279
+
280
+
281
+
282
+
283
+
284
+
285
+
286
+
287
+ def sliding_window_avg_pooling(img, window, granularity, alpha=None, min_nonblank_pixels=0., order="HWC", normalize=False, return_min_max=False, verbose=False):
288
+ assert isinstance(img, np.ndarray), f'Input image must be a numpy array. Got {type(img)}'
289
+ if order == "HWC":
290
+ assert img.shape[2] == 1, f'Input image must be formatted as HWC, with C = 1. Got a shape of {img.shape}'
291
+ elif order == "CHW":
292
+ assert img.shape[0] == 1, f'Input image must be formatted as CHW, with C = 1. Got a shape of {img.shape}'
293
+
294
+ # check if alpha channel was given, and cast it to np.float32 with values in [0,1]
295
+ if alpha is not None:
296
+ assert img.shape == alpha.shape, f'The shape of input image {img.shape} and alpha channel {alpha.shape} do not match'
297
+ if alpha.dtype == np.uint8:
298
+ alpha = (alpha / 255).astype(np.float32)
299
+ elif alpha.dtype == bool:
300
+ alpha = alpha.astype(np.float32)
301
+ else:
302
+ alpha = np.ones_like(img, dtype=np.float32)
303
+
304
+ # compute threshold
305
+ thresh = min_nonblank_pixels * window**2
306
+
307
+ # extract patches idxs
308
+ patches_idx = extract_patches(img, patch_size=window, stride=granularity, order=order, only_return_idx=True)
309
+
310
+ # initialize canvas
311
+ canvas = np.zeros_like(img, dtype=np.float32)
312
+ n_overlapping_patches = np.zeros_like(img, dtype=np.float32)
313
+
314
+ # cycle through patches idxs
315
+ for t,b,l,r in tqdm(patches_idx, disable=not verbose):
316
+ p_a = alpha[t:b,l:r]
317
+ n_valid_pixels = p_a.sum()
318
+ # keep only if it has more than min_nonblank_pixels
319
+ if n_valid_pixels <= thresh:
320
+ continue
321
+
322
+ # compute average patch value (i.e. density inside the patch)
323
+ p = img[t:b,l:r]
324
+ p_density = (p * p_a).sum() / n_valid_pixels
325
+
326
+ # add to canvas
327
+ canvas[t:b,l:r] += p_density
328
+ n_overlapping_patches[t:b,l:r] += 1
329
+
330
+ # compute average density
331
+ density_map = np.divide(canvas, n_overlapping_patches, where=(n_overlapping_patches != 0))
332
+
333
+ # apply alpha
334
+ density_map = density_map * alpha
335
+
336
+ if normalize:
337
+ # [0,1]-normalize
338
+ density_map_min = density_map.min()
339
+ density_map_max = density_map.max()
340
+ density_map = (density_map - density_map_min) / (density_map_max - density_map_min)
341
+
342
+ if return_min_max:
343
+ return density_map, density_map_min, density_map_max
344
+
345
+ return density_map
346
+
347
+
348
+
349
+ def compute_vndvi(
350
+ raster: np.ndarray,
351
+ mask: np.ndarray,
352
+ dilate_rows=True,
353
+ window_size=360,
354
+ granularity=45,
355
+ ):
356
+ assert isinstance(raster, np.ndarray)
357
+ assert isinstance(mask, np.ndarray)
358
+ assert len(raster.shape) == 3 # CHW
359
+ assert len(mask.shape) == 2 # HW
360
+ assert raster.shape[0] in [3,4] # RGB or RGBA
361
+
362
+ # CHW -> HWC
363
+ raster = raster.transpose(1,2,0)
364
+
365
+ # Extract channels
366
+ _raster = raster.astype(np.float32) / 255 # convert to float32 in [0,1]
367
+ R, G, B = _raster[:,:,0], _raster[:,:,1], _raster[:,:,2]
368
+
369
+ # To avoid division by 0 due to negative power, we replace 0 with 1 in R and B channels
370
+ R = np.where(R == 0, 1, R)
371
+ B = np.where(B == 0, 1, B)
372
+
373
+ # Mask has values: 0=interrows, 255=rows, 1=nodata
374
+ # Get mask for the rows and interrows
375
+ mask_rows = (mask == 255)
376
+ mask_interrows = (mask == 0)
377
+ mask_valid = mask_rows | mask_interrows
378
+
379
+ # Compute vndvi
380
+ vndvi = 0.5268 * (R**(-0.1294) * G**(0.3389) * B**(-0.3118))
381
+
382
+ # Clip values to [0,1]
383
+ vndvi = np.clip(vndvi, 0, 1)
384
+
385
+ # Compute 10th and 90th percentile on whole vineyard vndvi heatmap
386
+ vndvi_perc10, vndvi_perc90 = np.percentile(vndvi[mask_valid], [10,90])
387
+
388
+ # Clip values between 10th and 90th percentile
389
+ vndvi_clipped = np.clip(vndvi, vndvi_perc10, vndvi_perc90)
390
+
391
+ # Perform sliding window average pooling to smooth the heatmap
392
+ # NB: the window takes into account only the rows
393
+ vndvi_rows_clipped_pooled = sliding_window_avg_pooling(
394
+ np.where(mask_rows, vndvi_clipped, 0)[..., None],
395
+ window = int(window_size / 4),
396
+ granularity = granularity,
397
+ alpha = mask_rows[..., None],
398
+ min_nonblank_pixels = 0.0,
399
+ verbose=True,
400
+ )
401
+ # Same, but for interrows
402
+ vndvi_interrows_clipped_pooled = sliding_window_avg_pooling(
403
+ np.where(mask_interrows, vndvi_clipped, 0)[..., None],
404
+ window = int(window_size / 4),
405
+ granularity = granularity,
406
+ alpha = mask_interrows[..., None],
407
+ min_nonblank_pixels = 0.0,
408
+ verbose=True,
409
+ )
410
+
411
+ # Apply dilation to rows mask
412
+ dil_factor = int(window_size / 60)
413
+ mask_rows_dilated = grey_dilation(mask_rows, size=(dil_factor, dil_factor))
414
+ vndvi_rows_clipped_pooled_dilated = grey_dilation(vndvi_rows_clipped_pooled, size=(dil_factor, dil_factor, 1))
415
+
416
+ # For visualization purposes, normalize with vndvi_perc10 and
417
+ # vndvi_perc90 (because we want vndvi_perc10 to be the first color of
418
+ # the colormap and vndvi_perc90 to be the last)
419
+ vndvi_rows_clipped_pooled_normalized = (vndvi_rows_clipped_pooled - vndvi_perc10) / (vndvi_perc90 - vndvi_perc10)
420
+ vndvi_rows_clipped_pooled_dilated_normalized = (vndvi_rows_clipped_pooled_dilated - vndvi_perc10) / (vndvi_perc90 - vndvi_perc10)
421
+ vndvi_interrows_clipped_pooled_normalized = (vndvi_interrows_clipped_pooled - vndvi_perc10) / (vndvi_perc90 - vndvi_perc10)
422
+
423
+ # for visualization
424
+ vndvi_rows_img = alpha_composite(
425
+ raster,
426
+ vndvi_rows_clipped_pooled_dilated_normalized if dilate_rows else vndvi_rows_clipped_pooled_normalized,
427
+ opacity = 1.0,
428
+ colormap = 'RdYlGn',
429
+ alpha_image = np.zeros_like(raster[:,:,[0]]),
430
+ alpha_mask = mask_rows_dilated[...,None] if dilate_rows else mask_rows[...,None],
431
+ ) # HW4 RGBA
432
+
433
+ vndvi_interrows_img = alpha_composite(
434
+ raster,
435
+ vndvi_interrows_clipped_pooled_normalized,
436
+ opacity = 1.0,
437
+ colormap = 'RdYlGn',
438
+ alpha_image = np.zeros_like(raster[:,:,[0]]),
439
+ alpha_mask = mask_interrows[...,None],
440
+ ) # HW4 RGBA
441
+
442
+ # add colorbar
443
+ # fig_rows, ax = plt.subplots(1, 1, figsize=(10, 10))
444
+ # divider = make_axes_locatable(ax)
445
+ # cax = divider.append_axes('right', size='5%', pad=0.15)
446
+ # ax.imshow(vndvi_rows_img)
447
+ # fig_rows.colorbar(
448
+ # mappable = mpl.cm.ScalarMappable(
449
+ # norm = mpl.colors.Normalize(
450
+ # vmin = vndvi_perc10,
451
+ # vmax = vndvi_perc90),
452
+ # cmap = 'RdYlGn'),
453
+ # cax = cax,
454
+ # orientation = 'vertical',
455
+ # label = 'vNDVI',
456
+ # shrink = 1)
457
+
458
+ # fig_interrows, ax = plt.subplots(1, 1, figsize=(10, 10))
459
+ # divider = make_axes_locatable(ax)
460
+ # cax = divider.append_axes('right', size='5%', pad=0.15)
461
+ # ax.imshow(vndvi_interrows_img)
462
+ # fig_interrows.colorbar(
463
+ # mappable = mpl.cm.ScalarMappable(
464
+ # norm = mpl.colors.Normalize(
465
+ # vmin = vndvi_perc10,
466
+ # vmax = vndvi_perc90),
467
+ # cmap = 'RdYlGn'),
468
+ # cax = cax,
469
+ # orientation = 'vertical',
470
+ # label = 'vNDVI',
471
+ # shrink = 1)
472
+
473
+ # return fig_rows, fig_interrows
474
+ return vndvi_rows_img, vndvi_interrows_img
475
+
476
+
477
+
478
+ def compute_vdi(
479
+ raster: np.ndarray,
480
+ mask: np.ndarray,
481
+ window_size=360,
482
+ granularity=40,
483
+ ):
484
+
485
+ # CHW -> HWC
486
+ raster = raster.transpose(1,2,0)
487
+
488
+ # Mask has values: 0=interrows, 255=rows, 1=nodata
489
+ # Get mask for the rows and interrows
490
+ mask_rows = (mask == 255)
491
+ mask_interrows = (mask == 0)
492
+ mask_valid = mask_rows | mask_interrows
493
+
494
+ # compute vdi
495
+ vdi, vdi_min, vdi_max = sliding_window_avg_pooling(
496
+ mask_rows[...,None],
497
+ window=window_size,
498
+ granularity=granularity,
499
+ alpha=mask_valid[...,None],
500
+ min_nonblank_pixels=0.9,
501
+ normalize=True,
502
+ return_min_max=True,
503
+ verbose=True,
504
+ )
505
+
506
+ # for visualization
507
+ vdi_img = alpha_composite(
508
+ raster,
509
+ vdi,
510
+ opacity = 1,
511
+ colormap = 'jet_r',
512
+ alpha_image = mask_valid[...,None],
513
+ alpha_mask = mask_valid[...,None],
514
+ )
515
+
516
+ # add colorbar
517
+ # fig, ax = plt.subplots(1, 1, figsize=(10, 10))
518
+ # divider = make_axes_locatable(ax)
519
+ # cax = divider.append_axes('right', size='5%', pad=0.15)
520
+ # ax.imshow(vdi_img)
521
+ # fig.colorbar(
522
+ # mappable = mpl.cm.ScalarMappable(
523
+ # norm = mpl.colors.Normalize(
524
+ # vmin = vdi_min,
525
+ # vmax = vdi_max),
526
+ # cmap = 'jet_r'),
527
+ # cax = cax,
528
+ # orientation = 'vertical',
529
+ # label = 'VDI',
530
+ # shrink = 1)
531
+
532
+ # return fig
533
+ return vdi_img
534
+
535
+
536
+
537
+ def compute_mask(
538
+ raster: np.ndarray,
539
+ model: torch.nn.Module,
540
+ patch_size=512,
541
+ stride=256,
542
+ scaling_factor=None,
543
+ rotate=False,
544
+ batch_size=16
545
+ ):
546
+ assert isinstance(raster, np.ndarray), f'Input raster must be a numpy array. Got {type(raster)}'
547
+ assert len(raster.shape) == 3, f'Input raster must have 3 dimensions (bands, rows, cols). Got shape {raster.shape}'
548
+ assert raster.shape[0] in [3,4], f'Input raster must have 3 bands (RGB) or 4 bands (RGBA). Got {raster.shape[0]} bands'
549
+ assert isinstance(model, torch.nn.Module), 'Model must be a torch.nn.Module'
550
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
551
+
552
+ # Infer GSD
553
+ #gsd = abs(raster.rio.transform()[0]) # ground sampling distance (NB: valid only if image is a GeoTIFF)
554
+
555
+ # Growseg works best on orthoimages with gsd in [1, 1.7] cm/px. You may want to
556
+ # specify a scaling factor different from 1 if your image has a different gsd.
557
+ # E.g.: SCALING_FACTOR = gsd / 0.015
558
+ # logger.info(f'Image GSD: {gsd*100:.2f} cm/px')
559
+ # scaling_factor = scaling_factor or (gsd / 0.015)
560
+ scaling_factor = scaling_factor or 1
561
+ logger.info(f'Applying scaling factor: {scaling_factor:.2f}')
562
+
563
+ # segment
564
+ logger.info('Segmenting image...')
565
+ score_map = segment(
566
+ raster,
567
+ model,
568
+ patch_size=patch_size,
569
+ stride=stride,
570
+ scaling_factor=scaling_factor,
571
+ rotate=rotate,
572
+ device=device,
573
+ batch_size=batch_size,
574
+ verbose=True
575
+ ) # mask is a HxW float32 array in [0, 1]
576
+
577
+ # apply threshold on confidence scores
578
+ alpha = (score_map == -1)
579
+ mask = (score_map > 0.5)
580
+
581
+ # convert to uint8
582
+ mask = (mask * 255).astype(np.uint8)
583
+
584
+ # set nodata pixels to 1
585
+ mask[alpha] = 1
586
+
587
+ return mask
lib/viz_utils.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import functools
3
+ import numpy as np
4
+ import cv2
5
+ import cmapy
6
+ from PIL import Image
7
+ import matplotlib
8
+
9
+
10
+
11
+ # BUGFIX in cmapy.py
12
+ def cmap(cmap_name, rgb_order=False):
13
+ """
14
+ Extract colormap color information as a LUT compatible with cv2.applyColormap().
15
+ Default channel order is BGR.
16
+
17
+ Args:
18
+ cmap_name: string, name of the colormap.
19
+ rgb_order: boolean, if false or not set, the returned array will be in
20
+ BGR order (standard OpenCV format). If true, the order
21
+ will be RGB.
22
+
23
+ Returns:
24
+ A numpy array of type uint8 containing the colormap.
25
+ """
26
+
27
+ c_map = matplotlib.colormaps.get_cmap(cmap_name)
28
+ rgba_data = matplotlib.cm.ScalarMappable(cmap=c_map).to_rgba(
29
+ np.arange(0, 1.0, 1.0 / 256.0), bytes=True
30
+ )
31
+ rgba_data = rgba_data[:, 0:-1].reshape((256, 1, 3))
32
+
33
+ # Convert to BGR (or RGB), uint8, for OpenCV.
34
+ cmap = np.zeros((256, 1, 3), np.uint8)
35
+
36
+ if not rgb_order:
37
+ cmap[:, :, :] = rgba_data[:, :, ::-1]
38
+ else:
39
+ cmap[:, :, :] = rgba_data[:, :, :]
40
+
41
+ return cmap
42
+
43
+ # If python 3, redefine cmap() to use lru_cache.
44
+ if sys.version_info > (3, 0):
45
+ cmap = functools.lru_cache(maxsize=200)(cmap)
46
+
47
+
48
+
49
+ def alpha_composite(img, msk, opacity=0.5, colormap=None, alpha_image=None, alpha_mask=None, red_mask=False):
50
+ """Alpha composite an RGBA image (img) and a grayscale mask (msk).
51
+ - If alpha_image is None, img's alpha channel is used (or, if not present,
52
+ initialized to all 255).
53
+ - If alpha_mask is None, msk is overlaid on img only where img's alpha
54
+ channel is not 0.
55
+ - If alpha_mask is not None, the above behavior is overridden and msk is
56
+ overlaid on img only where alpha_mask is not 0."""
57
+ # only HWC numpy arrays allowed
58
+ assert isinstance(img, np.ndarray), f'Input image must be a numpy array. Got {type(img)}'
59
+ assert isinstance(msk, np.ndarray), f'Input mask must be a numpy array. Got {type(msk)}'
60
+ if alpha_mask is not None:
61
+ assert isinstance(alpha_mask, np.ndarray), f'Alpha mask must be a numpy array. Got {type(alpha_mask)}'
62
+ assert alpha_mask.dtype in [np.float32, bool], f'Alpha mask must be of type np.float32 or bool. Got {alpha_mask.dtype}'
63
+ assert alpha_mask.shape[2] == 1, f'Alpha mask must be formatted as HWC, with C = 1. Got a shape of {msk.shape}'
64
+ assert img.shape[2] in [3,4], f'Input image must be formatted as HWC, with C = 3,4. Got a shape of {img.shape}'
65
+ assert msk.shape[2] == 1, f'Input mask must be formatted as HWC, with C = 1. Got a shape of {msk.shape}'
66
+ assert (opacity >= 0) and (opacity <= 1), f'Mask opacity must be between 0 and 1. Got {opacity}'
67
+
68
+ # to avoid modifying the original arrays
69
+ img = img.copy()
70
+ msk = msk.copy()
71
+
72
+ if img.shape[2] == 3:
73
+ # add alpha channel to img
74
+ img = np.concatenate([
75
+ img,
76
+ np.full((img.shape[0], img.shape[1], 1), 255, dtype=np.uint8)
77
+ ], axis=-1)
78
+
79
+ if alpha_image is None:
80
+ # initialize alpha_image to all Trues
81
+ alpha_image = img[:,:,[3]]
82
+ # convert alpha image to bool
83
+ alpha_image = alpha_image.astype(bool)
84
+
85
+ if alpha_mask is None:
86
+ # initialize alpha_mask to alpha_image
87
+ alpha_mask = alpha_image # so that alpha_mask is AT LEAST as restrictive as alpha_image
88
+ # convert alpha mask to bool
89
+ alpha_mask = alpha_mask.astype(bool)
90
+
91
+
92
+ if msk.dtype != np.uint8:
93
+ # convert mask to a uint8 grayscale image ([0,1] -> [0,255])
94
+ # NB: normalize the pixels of the mask we are interested in to [0,1]
95
+ # before passing it as input!!!
96
+ msk = (msk * 255).astype(np.uint8)
97
+
98
+ # convert mask from grayscale to RGBA
99
+ msk = cv2.cvtColor(msk, cv2.COLOR_GRAY2RGBA)
100
+
101
+ if colormap is not None:
102
+ # apply specified colormap to msk
103
+ # NB: values near 0 will be converted to the first colors of the chosen
104
+ # colormap, whereas values near 255 will be converted to the last colors
105
+ msk[:,:,:3] = cmapy.colorize(msk[:,:,:3], colormap, rgb_order=True)
106
+ elif red_mask:
107
+ # convert white to red
108
+ msk[:,:,[1,2]] = 0
109
+
110
+
111
+ # apply alpha_image to img's alpha channel
112
+ img[:,:,[3]] = (alpha_image * img[:,:,[3]]).astype(np.uint8)
113
+
114
+ # apply alpha_mask and opacity to msk's alpha channel
115
+ msk[:,:,[3]] = (alpha_mask * opacity * msk[:,:,[3]]).astype(np.uint8)
116
+
117
+ # alpha compositing
118
+ img_pil = Image.fromarray(img)
119
+ msk_pil = Image.fromarray(msk)
120
+ img_pil.alpha_composite(msk_pil)
121
+
122
+ return np.array(img_pil)
123
+
124
+
125
+
precompute_examples.ipynb ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import pandas as pd\n",
10
+ "import geopandas as gpd\n",
11
+ "import rioxarray as rxr\n",
12
+ "import xarray as xr\n",
13
+ "import numpy as np\n",
14
+ "import os\n",
15
+ "import torch\n",
16
+ "from transformers import SegformerForSemanticSegmentation\n",
17
+ "from lib.utils import compute_mask, compute_vndvi, compute_vdi"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "code",
22
+ "execution_count": 2,
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "# # Read raster data\n",
27
+ "# raster_path = \"data/spain_2022-07-29.tif\"\n",
28
+ "# raster = rxr.open_rasterio(raster_path)\n",
29
+ "\n",
30
+ "# # Crop raster with GeoJSON geometry, if available\n",
31
+ "# geom_path = raster_path.replace(\".tif\", \".geojson\")\n",
32
+ "# if os.path.exists(geom_path):\n",
33
+ "# geom = gpd.read_file(geom_path)\n",
34
+ "# raster = raster.rio.clip(geom.geometry)\n",
35
+ "# raster.rio.to_raster(raster_path.replace(\".tif\", \"_cropped.tif\"))"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": 3,
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "def load_model(hf_path='links-ads/gaia-growseg'):\n",
45
+ " # logger.info(f'Loading GAIA GRowSeg on {device}...')\n",
46
+ " device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
47
+ " model = SegformerForSemanticSegmentation.from_pretrained(\n",
48
+ " hf_path,\n",
49
+ " num_labels=1,\n",
50
+ " num_channels=3,\n",
51
+ " id2label={1: 'vine'},\n",
52
+ " label2id={'vine': 1},\n",
53
+ " token=os.getenv('hf_read_access_token')\n",
54
+ " )\n",
55
+ " return model.to(device).eval()\n",
56
+ "\n",
57
+ "# Load GAIA GRowSeg model\n",
58
+ "model = load_model()"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": 10,
64
+ "metadata": {},
65
+ "outputs": [
66
+ {
67
+ "name": "stderr",
68
+ "output_type": "stream",
69
+ "text": [
70
+ "\u001b[32m2025-03-20 12:39:09.921\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mlib.utils\u001b[0m:\u001b[36msliding_window_avg_pooling\u001b[0m:\u001b[36m308\u001b[0m - \u001b[1mExtracting patches idx...\u001b[0m\n",
71
+ "100%|█████████████████████████████████████████████| 67848/67848 [00:03<00:00, 20745.29it/s]\n",
72
+ "\u001b[32m2025-03-20 12:39:14.795\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mlib.utils\u001b[0m:\u001b[36msliding_window_avg_pooling\u001b[0m:\u001b[36m308\u001b[0m - \u001b[1mExtracting patches idx...\u001b[0m\n",
73
+ "100%|█████████████████████████████████████████████| 67848/67848 [00:03<00:00, 19329.36it/s]\n",
74
+ "\u001b[32m2025-03-20 12:39:56.011\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mlib.utils\u001b[0m:\u001b[36msliding_window_avg_pooling\u001b[0m:\u001b[36m308\u001b[0m - \u001b[1mExtracting patches idx...\u001b[0m\n",
75
+ "100%|██████████████████████████████████████████████| 64758/64758 [00:20<00:00, 3203.45it/s]\n"
76
+ ]
77
+ }
78
+ ],
79
+ "source": [
80
+ "raster_path = \"data/italy_2022-06-13_cropped.tif\"\n",
81
+ "patch_size = 512\n",
82
+ "stride = 256\n",
83
+ "scaling_factor = 1.0\n",
84
+ "dilate_rows = False\n",
85
+ "window_size = 360\n",
86
+ "granularity = int(window_size/8)\n",
87
+ "\n",
88
+ "# raster_path = \"data/spain_2022-07-29_cropped.tif\"\n",
89
+ "# patch_size = 512\n",
90
+ "# stride = 256\n",
91
+ "# scaling_factor = 1.0\n",
92
+ "# dilate_rows = False\n",
93
+ "# window_size = 400\n",
94
+ "# granularity = int(window_size/8)\n",
95
+ "\n",
96
+ "# raster_path = \"data/portugal_2023-08-01.tif\"\n",
97
+ "# patch_size = 512\n",
98
+ "# stride = 256\n",
99
+ "# scaling_factor = 1.25\n",
100
+ "# dilate_rows = False\n",
101
+ "# window_size = 80\n",
102
+ "# granularity = int(window_size/8)\n",
103
+ "\n",
104
+ "raster = rxr.open_rasterio(raster_path)\n",
105
+ "\n",
106
+ "# Compute mask\n",
107
+ "mask_path = raster_path.replace(\".tif\", \"_mask.tif\")\n",
108
+ "if not os.path.exists(mask_path):\n",
109
+ " mask = compute_mask(\n",
110
+ " raster.to_numpy(),\n",
111
+ " model,\n",
112
+ " patch_size=patch_size,\n",
113
+ " stride=stride,\n",
114
+ " scaling_factor=scaling_factor,\n",
115
+ " rotate=False,\n",
116
+ " batch_size=16,\n",
117
+ " ) # mask is a HxW uint8 array in with 0=background, 255=vine, 1=nodata\n",
118
+ "\n",
119
+ " # Convert mask from grayscale to RGBA, with red pixels for vine\n",
120
+ " alpha = ((mask != 1)*255).astype(np.uint8)\n",
121
+ " mask_colored = np.stack([mask, np.zeros_like(mask), np.zeros_like(mask), alpha], axis=0) # now, mask is a 4xHxW uint8 array in with 0=background, 255=vine\n",
122
+ "\n",
123
+ " # Georef mask like raster\n",
124
+ " mask_raster = xr.DataArray(\n",
125
+ " mask_colored,\n",
126
+ " dims=('band', 'y', 'x'),\n",
127
+ " coords={'x': raster.x, 'y': raster.y, 'band': raster.band}\n",
128
+ " )\n",
129
+ " mask_raster.rio.write_crs(raster.rio.crs, inplace=True) # Copy CRS\n",
130
+ " mask_raster.rio.write_transform(raster.rio.transform(), inplace=True) # Copy affine transform\n",
131
+ "\n",
132
+ " # Save mask\n",
133
+ " mask_raster.rio.to_raster(raster_path.replace(\".tif\", \"_mask.tif\"), compress='lzw')\n",
134
+ "else:\n",
135
+ " mask = rxr.open_rasterio(mask_path).sel(band=1).squeeze().to_numpy()\n",
136
+ "\n",
137
+ "# Compute vNDVI\n",
138
+ "vndvi_rows_path = raster_path.replace(\".tif\", \"_vndvi_rows.tif\")\n",
139
+ "vndvi_interrows_path = raster_path.replace(\".tif\", \"_vndvi_interrows.tif\")\n",
140
+ "if not os.path.exists(vndvi_rows_path) or not os.path.exists(vndvi_interrows_path):\n",
141
+ " vndvi_rows, vndvi_interrows = compute_vndvi(\n",
142
+ " raster.to_numpy(),\n",
143
+ " mask,\n",
144
+ " dilate_rows=dilate_rows,\n",
145
+ " window_size=window_size,\n",
146
+ " granularity=granularity,\n",
147
+ " ) # vNDVI is RGBA\n",
148
+ "\n",
149
+ " # Georef vNDVI like raster\n",
150
+ " vndvi_rows_raster = xr.DataArray(\n",
151
+ " vndvi_rows.transpose(2, 0, 1),\n",
152
+ " dims=('band', 'y', 'x'),\n",
153
+ " coords={'x': raster.x, 'y': raster.y, 'band': raster.band}\n",
154
+ " )\n",
155
+ " vndvi_rows_raster.rio.write_crs(raster.rio.crs, inplace=True)\n",
156
+ " vndvi_rows_raster.rio.write_transform(raster.rio.transform(), inplace=True)\n",
157
+ "\n",
158
+ " vndvi_interrows_raster = xr.DataArray(\n",
159
+ " vndvi_interrows.transpose(2, 0, 1),\n",
160
+ " dims=('band', 'y', 'x'),\n",
161
+ " coords={'x': raster.x, 'y': raster.y, 'band': raster.band}\n",
162
+ " )\n",
163
+ " vndvi_interrows_raster.rio.write_crs(raster.rio.crs, inplace=True)\n",
164
+ " vndvi_interrows_raster.rio.write_transform(raster.rio.transform(), inplace=True)\n",
165
+ "\n",
166
+ " # Save vNDVI\n",
167
+ " vndvi_rows_raster.rio.to_raster(raster_path.replace(\".tif\", \"_vndvi_rows.tif\"), compress='lzw')\n",
168
+ " vndvi_interrows_raster.rio.to_raster(raster_path.replace(\".tif\", \"_vndvi_interrows.tif\"), compress='lzw')\n",
169
+ "\n",
170
+ "# Compute VDI\n",
171
+ "vdi_path = raster_path.replace(\".tif\", \"_vdi.tif\")\n",
172
+ "if not os.path.exists(vdi_path):\n",
173
+ " vdi = compute_vdi(\n",
174
+ " raster.to_numpy(),\n",
175
+ " mask,\n",
176
+ " window_size=window_size,\n",
177
+ " granularity=granularity,\n",
178
+ " ) # VDI is RGBA\n",
179
+ "\n",
180
+ " # Georef VDI like raster\n",
181
+ " vdi_raster = xr.DataArray(\n",
182
+ " vdi.transpose(2, 0, 1),\n",
183
+ " dims=('band', 'y', 'x'),\n",
184
+ " coords={'x': raster.x, 'y': raster.y, 'band': raster.band}\n",
185
+ " )\n",
186
+ " vdi_raster.rio.write_crs(raster.rio.crs, inplace=True)\n",
187
+ " vdi_raster.rio.write_transform(raster.rio.transform(), inplace=True)\n",
188
+ "\n",
189
+ " # Save results\n",
190
+ " vdi_raster.rio.to_raster(raster_path.replace(\".tif\", \"_vdi.tif\"), compress='lzw')\n"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "code",
195
+ "execution_count": 11,
196
+ "metadata": {},
197
+ "outputs": [
198
+ {
199
+ "name": "stderr",
200
+ "output_type": "stream",
201
+ "text": [
202
+ "\u001b[32m2025-03-20 12:40:30.816\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m76\u001b[0m - \u001b[1mReprojecting rasters to EPSG:4326 with NODATA value 0...\u001b[0m\n",
203
+ "\u001b[32m2025-03-20 12:40:52.371\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m84\u001b[0m - \u001b[1mCreating RGB raster overlay...\u001b[0m\n",
204
+ "\u001b[32m2025-03-20 12:40:52.373\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36mcreate_image_overlay\u001b[0m:\u001b[36m46\u001b[0m - \u001b[1mCreating overlay: 'Orthoimage'...\u001b[0m\n",
205
+ "\u001b[32m2025-03-20 12:40:58.801\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m86\u001b[0m - \u001b[1mCreating mask overlay...\u001b[0m\n",
206
+ "\u001b[32m2025-03-20 12:40:58.806\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36mcreate_image_overlay\u001b[0m:\u001b[36m46\u001b[0m - \u001b[1mCreating overlay: 'Mask'...\u001b[0m\n",
207
+ "\u001b[32m2025-03-20 12:41:05.006\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m88\u001b[0m - \u001b[1mCreating vNDVI rows overlay...\u001b[0m\n",
208
+ "\u001b[32m2025-03-20 12:41:05.008\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36mcreate_image_overlay\u001b[0m:\u001b[36m46\u001b[0m - \u001b[1mCreating overlay: 'vNDVI Rows'...\u001b[0m\n",
209
+ "\u001b[32m2025-03-20 12:41:10.988\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m90\u001b[0m - \u001b[1mCreating vNDVI interrows overlay...\u001b[0m\n",
210
+ "\u001b[32m2025-03-20 12:41:10.990\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36mcreate_image_overlay\u001b[0m:\u001b[36m46\u001b[0m - \u001b[1mCreating overlay: 'vNDVI Interrows'...\u001b[0m\n",
211
+ "\u001b[32m2025-03-20 12:41:16.558\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m92\u001b[0m - \u001b[1mCreating VDI overlay...\u001b[0m\n",
212
+ "\u001b[32m2025-03-20 12:41:16.560\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36mcreate_image_overlay\u001b[0m:\u001b[36m46\u001b[0m - \u001b[1mCreating overlay: 'VDI'...\u001b[0m\n"
213
+ ]
214
+ }
215
+ ],
216
+ "source": [
217
+ "import folium\n",
218
+ "from loguru import logger\n",
219
+ "\n",
220
+ "def create_map(location=[41.9099533, 12.3711879], zoom_start=5, crs=3857, max_zoom=23):\n",
221
+ " \"\"\"Create a folium map with OpenStreetMap tiles and optional Esri.WorldImagery basemap.\"\"\"\n",
222
+ " if isinstance(crs, int):\n",
223
+ " crs = f\"EPSG{crs}\"\n",
224
+ " assert crs in [\"EPSG3857\"], f\"Only EPSG:3857 supported for now. Got {crs}.\"\n",
225
+ " \n",
226
+ " m = folium.Map(\n",
227
+ " location=location,\n",
228
+ " zoom_start=zoom_start,\n",
229
+ " crs=crs,\n",
230
+ " max_zoom=max_zoom,\n",
231
+ " tiles=\"OpenStreetMap\", # Esri.WorldImagery\n",
232
+ " attributionControl=False,\n",
233
+ " prefer_canvas=True,\n",
234
+ " )\n",
235
+ "\n",
236
+ " # Add Esri.WorldImagery as optional basemap (radio button)\n",
237
+ " folium.TileLayer(\n",
238
+ " tiles=\"Esri.WorldImagery\",\n",
239
+ " show=False,\n",
240
+ " overlay=False,\n",
241
+ " control=True,\n",
242
+ " ).add_to(m)\n",
243
+ "\n",
244
+ " return m\n",
245
+ "\n",
246
+ "def create_image_overlay(raster_path_or_array, name=\"Raster\", opacity=1.0, to_crs=4326, show=True):\n",
247
+ " \"\"\" Create a folium image overlay from a raster filepath or xarray.DataArray. \"\"\"\n",
248
+ " if isinstance(raster_path_or_array, str):\n",
249
+ " # Open the raster and its metadata\n",
250
+ " logger.info(f\"Opening raster: {raster_path_or_array!r}...\")\n",
251
+ " r = rxr.open_rasterio(raster_path_or_array)\n",
252
+ " else:\n",
253
+ " r = raster_path_or_array\n",
254
+ " nodata = r.rio.nodata or 0\n",
255
+ " if r.rio.crs.to_epsg() != to_crs:\n",
256
+ " logger.info(f\"Reprojecting raster to EPSG:{to_crs} with NODATA value {nodata}...\")\n",
257
+ " r = r.rio.reproject(to_crs, nodata=nodata) # nodata default: 255\n",
258
+ " r = r.transpose(\"y\", \"x\", \"band\")\n",
259
+ " bounds = r.rio.bounds() # (left, bottom, right, top)\n",
260
+ "\n",
261
+ " # Create a folium image overlay\n",
262
+ " logger.info(f\"Creating overlay: {name!r}...\")\n",
263
+ " overlay = folium.raster_layers.ImageOverlay(\n",
264
+ " image=r.to_numpy(),\n",
265
+ " name=name,\n",
266
+ " bounds=[[bounds[1], bounds[0]], [bounds[3], bounds[2]]], # format for folium: ((bottom,left),(top,right))\n",
267
+ " opacity=opacity,\n",
268
+ " interactive=True,\n",
269
+ " cross_origin=False,\n",
270
+ " zindex=1,\n",
271
+ " show=show,\n",
272
+ " )\n",
273
+ "\n",
274
+ " return overlay\n",
275
+ "\n",
276
+ "# Define paths\n",
277
+ "raster_path = \"data/portugal_2023-08-01.tif\"\n",
278
+ "mask_path = raster_path.replace('.tif', '_mask.tif')\n",
279
+ "vndvi_rows_path = raster_path.replace('.tif', '_vndvi_rows.tif')\n",
280
+ "vndvi_interrows_path = raster_path.replace('.tif', '_vndvi_interrows.tif')\n",
281
+ "vdi_path = raster_path.replace('.tif', '_vdi.tif')\n",
282
+ "\n",
283
+ "# Load rasters\n",
284
+ "raster = rxr.open_rasterio(raster_path)\n",
285
+ "mask_raster = rxr.open_rasterio(mask_path)\n",
286
+ "vndvi_rows_raster = rxr.open_rasterio(vndvi_rows_path)\n",
287
+ "vndvi_interrows_raster = rxr.open_rasterio(vndvi_interrows_path)\n",
288
+ "vdi_raster = rxr.open_rasterio(vdi_path)\n",
289
+ "\n",
290
+ "# Reproject all rasters to EPSG:4326\n",
291
+ "if raster.rio.crs.to_epsg() != 4326:\n",
292
+ " logger.info(f\"Reprojecting rasters to EPSG:4326 with NODATA value 0...\")\n",
293
+ " raster = raster.rio.reproject(\"EPSG:4326\", nodata=0) # nodata default: 255\n",
294
+ " mask_raster = mask_raster.rio.reproject(\"EPSG:4326\", nodata=0)\n",
295
+ " vndvi_rows_raster = vndvi_rows_raster.rio.reproject(\"EPSG:4326\", nodata=0)\n",
296
+ " vndvi_interrows_raster = vndvi_interrows_raster.rio.reproject(\"EPSG:4326\", nodata=0)\n",
297
+ " vdi_raster = vdi_raster.rio.reproject(\"EPSG:4326\", nodata=0)\n",
298
+ "\n",
299
+ "# Create overlays\n",
300
+ "logger.info(f'Creating RGB raster overlay...')\n",
301
+ "raster_overlay = create_image_overlay(raster, name=\"Orthoimage\", opacity=1.0, show=True)\n",
302
+ "logger.info(f'Creating mask overlay...')\n",
303
+ "mask_overlay = create_image_overlay(mask_raster, name=\"Mask\", opacity=1.0, show=False)\n",
304
+ "logger.info(f'Creating vNDVI rows overlay...')\n",
305
+ "vndvi_rows_overlay = create_image_overlay(vndvi_rows_raster, name=\"vNDVI Rows\", opacity=1.0, show=False)\n",
306
+ "logger.info(f'Creating vNDVI interrows overlay...')\n",
307
+ "vndvi_interrows_overlay = create_image_overlay(vndvi_interrows_raster, name=\"vNDVI Interrows\", opacity=1.0, show=False)\n",
308
+ "logger.info(f'Creating VDI overlay...')\n",
309
+ "vdi_overlay = create_image_overlay(vdi_raster, name=\"VDI\", opacity=1.0, show=False)"
310
+ ]
311
+ },
312
+ {
313
+ "cell_type": "code",
314
+ "execution_count": 12,
315
+ "metadata": {},
316
+ "outputs": [],
317
+ "source": [
318
+ "m = create_map()\n",
319
+ "raster_overlay.add_to(m)\n",
320
+ "mask_overlay.add_to(m)\n",
321
+ "vndvi_rows_overlay.add_to(m)\n",
322
+ "vndvi_interrows_overlay.add_to(m)\n",
323
+ "vdi_overlay.add_to(m)\n",
324
+ "\n",
325
+ "# Add layer control\n",
326
+ "folium.LayerControl().add_to(m)\n",
327
+ "\n",
328
+ "# Fit map to bounds\n",
329
+ "m.fit_bounds(raster_overlay.get_bounds())\n",
330
+ "\n",
331
+ "# Save map\n",
332
+ "map_path = raster_path.replace('.tif', '.html')\n",
333
+ "m.save(map_path)"
334
+ ]
335
+ }
336
+ ],
337
+ "metadata": {
338
+ "kernelspec": {
339
+ "display_name": "Python 3 (ipykernel)",
340
+ "language": "python",
341
+ "name": "python3"
342
+ },
343
+ "language_info": {
344
+ "codemirror_mode": {
345
+ "name": "ipython",
346
+ "version": 3
347
+ },
348
+ "file_extension": ".py",
349
+ "mimetype": "text/x-python",
350
+ "name": "python",
351
+ "nbconvert_exporter": "python",
352
+ "pygments_lexer": "ipython3",
353
+ "version": "3.10.12"
354
+ }
355
+ },
356
+ "nbformat": 4,
357
+ "nbformat_minor": 2
358
+ }
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy
2
+ scipy
3
+ rasterio
4
+ torch
5
+ transformers
6
+ tqdm
7
+ loguru
8
+ opencv-python-headless
9
+ pillow
10
+ matplotlib
11
+ cmapy
12
+ python-dotenv
13
+ rioxarray
14
+ geopandas