import base64
import os
import zipfile

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import rasterio
import streamlit as st
import streamlit.components.v1 as components
import torch
from matplotlib import colors

from utils import *  # provides opt, read, resume, FPN, SimpleNN, dataset and label helpers

st.set_page_config(layout="wide")

def create_animation(images, pred_dates):
    print('Creating composition of images...')
    fps = 2
    fig_an, ax_an = plt.subplots()
    first_frame = images[0]
    im = ax_an.imshow(first_frame, interpolation='none', aspect='auto', vmin=0, vmax=1)

    title = ax_an.text(0.5, 0.85, "", bbox={'facecolor': 'w', 'alpha': 0.5, 'pad': 5},
                    transform=ax_an.transAxes, ha="center")

    def animate_func(idx):
        title.set_text("date: " + pred_dates[idx])
        im.set_array(images[idx])
        # return every artist updated per frame so blitting also redraws the title
        return [im, title]

    anima = animation.FuncAnimation(fig_an, animate_func, frames=len(images),
                                    interval=1000 / fps, blit=True, repeat=False)
    print('Done!')
    return anima


def load_daily_preds_as_animations(pred_full_paths, pred_dates):
    daily_preds = []
    for path in pred_full_paths:
        img, _ = read(path)
        img = np.squeeze(img)
        img = [classes_color_map[p] for p in img]
        daily_preds.append(img)
    anima = create_animation(daily_preds, pred_dates)
    return anima


def load_src_images_as_animations(img_paths, pred_dates):
    imgs = []
    for path in img_paths:
        img, _ = read(path)

        # https://custom-scripts.sentinel-hub.com/custom-scripts/sentinel-2/composites/
        # Channel order in each image: B2, B3, B4, B5, B6, B7, B8, B11, B12
        # IREA false-color composite (8, 4, 3): red-B8, green-B4, blue-B3
        # Simple RGB composite (4, 3, 2): red-B4, green-B3, blue-B2
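        # indices [2, 1, 0] below pick B4 (red), B3 (green), B2 (blue),
        # i.e. the simple RGB composite from the table above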
        rgb = img[[2, 1, 0], :, :]
        rgb = np.moveaxis(rgb, 0, -1)
        imgs.append(rgb/np.amax(rgb))
    anima = create_animation(imgs, pred_dates)
    return anima


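# Attach state to the `st` module object itself so that the models are built
# and loaded only once across Streamlit reruns.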
if not hasattr(st, 'paths'):
    st.paths = None
if not hasattr(st, 'daily_model'):
    best_model_daily_file_name = "best_model_daily.pth"
    best_model_annual_file_name = "best_model_annual.pth"

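    # dummy input used only to infer the FPN layer shapes; the dimensions are
    # likely (time steps, bands, window size, height, width), matching the
    # 9-band 48x48 patches and opt.win_size used elsewhere in this app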
    first_input_batch = torch.zeros(71, 9, 5, 48, 48)
    # first_input_batch = first_input_batch.view(-1, *first_input_batch.shape[2:])
    st.daily_model = FPN(opt, first_input_batch, opt.win_size)
    st.annual_model = SimpleNN(opt)

    if torch.cuda.is_available():
        st.daily_model = torch.nn.DataParallel(st.daily_model).cuda()
        st.annual_model = torch.nn.DataParallel(st.annual_model).cuda()
    else:
        st.daily_model = torch.nn.DataParallel(st.daily_model).cpu()
        st.annual_model = torch.nn.DataParallel(st.annual_model).cpu()

    print('trying to resume previously saved models...')
    resume(
        os.path.join(opt.resume_path, best_model_daily_file_name),
        model=st.daily_model, optimizer=None)
    resume(
        os.path.join(opt.resume_path, best_model_annual_file_name),
        model=st.annual_model, optimizer=None)
    st.daily_model = st.daily_model.eval()
    st.annual_model = st.annual_model.eval()



st.title('In-season and dynamic crop mapping using 3D convolution neural networks and sentinel-2 time series')
st.markdown(""" Demo App for the model presented in the [paper](https://www.sciencedirect.com/science/article/pii/S0924271622003203):
```
@article{gallo2022in_season,
  title = {In-season and dynamic crop mapping using 3D convolution neural networks and sentinel-2 time series},
  journal = {ISPRS Journal of Photogrammetry and Remote Sensing},
  volume = {195},
  pages = {335-352},
  year = {2023},
  issn = {0924-2716},
  doi = {https://doi.org/10.1016/j.isprsjprs.2022.12.005},
  url = {https://www.sciencedirect.com/science/article/pii/S0924271622003203},
  author = {Ignazio Gallo and Luigi Ranghetti and Nicola Landro and Riccardo {La Grassa} and Mirco Boschetti},
}
```
**NOTE: The demo doesn't work properly, we are working to fix the bugs!**
""")

file_uploaded = st.file_uploader(
    "Upload a zip file containing a sample",
    type=["zip"],
    accept_multiple_files=False,
)
sample_path = None
tileids = None
st.paths = None
if file_uploaded is not None:
    with zipfile.ZipFile(file_uploaded, "r") as z:
        z.extractall(os.path.join("uploaded_samples", opt.years[0]))
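    # the patch name is the uploaded file name without the ".zip" extension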
    tileids = [file_uploaded.name[:-4]]
    # sample_path = os.path.join("uploaded_samples", opt.years[0], tileids[0])
    sample_path = "uploaded_samples"

st.markdown('or use a demo sample')
col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
with col1:
    if st.button('sample 1'):
        sample_path = 'demo_data/lombardia'
        tileids = ['24']
with col2:
    if st.button('sample 2'):
        sample_path = 'demo_data/lombardia'
        tileids = ['712']
with col3:
    if st.button('sample 3'):
        sample_path = 'demo_data/lombardia'
        tileids = ['814']
with col4:
    if st.button('sample 4'):
        sample_path = 'demo_data/lombardia'
        tileids = ['1509']

# paths = None
if sample_path is not None:
    # st.markdown(f'elaborating {sample_path} ...')

    validationdataset = SentinelDailyAnnualDatasetNoLabel(
        sample_path,
        opt.years,
        opt.classes_path,
        opt.sample_duration,
        opt.win_size,
        tileids=tileids)
    validationdataloader = torch.utils.data.DataLoader(
        validationdataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.workers)

    st.markdown('Model prediction in progress ...')

    out_dir = os.path.join(opt.result_path, "seg_maps")
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    for i, (x_dailies, dates, dirs_path) in enumerate(validationdataloader):
        with torch.no_grad():
            # x_dailies, dates, dirs_path = next(iter(validationdataloader))
            # reshape merging the first two dimensions
            x_dailies = x_dailies.view(-1, *x_dailies.shape[2:])
            if torch.cuda.is_available():
                x_dailies = x_dailies.cuda()

            feat_daily, outs_daily = st.daily_model.forward(x_dailies)
            # return to original size of batch and year
            outs_daily = outs_daily.view(
                opt.batch_size, opt.sample_duration, *outs_daily.shape[1:])
            feat_daily = feat_daily.view(
                opt.batch_size, opt.sample_duration, *feat_daily.shape[1:])

            _, out_annual = st.annual_model.forward(feat_daily)
            pred_annual = torch.argmax(out_annual, dim=1).squeeze(1)
            pred_annual = pred_annual.cpu().numpy()
            # Remapping the labels
            pred_annual_nn = ids_to_labels(
                validationdataloader, pred_annual).astype(np.uint8)

            for batch in range(feat_daily.shape[0]):
                # derive the georeferencing profile from the last image of the time series
                _, tmp_path = get_patch_id(validationdataset.samples, 0)
                dates = get_all_dates(
                    tmp_path, validationdataset.max_seq_length)
                last_tif_path = os.path.join(tmp_path, dates[-1] + ".tif")
                _, profile = read(last_tif_path)
                profile["name"] = dirs_path[batch]

                pth = dirs_path[batch].split(os.path.sep)[-3:]
                full_pth_patch = os.path.join(
                    out_dir, pth[1] + '-' + pth[0], pth[2])

                if not os.path.exists(full_pth_patch):
                    os.makedirs(full_pth_patch)
                full_pth_pred = os.path.join(
                    full_pth_patch, 'patch-pred-nn.tif')
                profile.update({
                    'nodata': None,
                    'dtype': 'uint8',
                    'count': 1})
                with rasterio.open(full_pth_pred, 'w', **profile) as dst:
                    dst.write_band(1, pred_annual_nn[batch])

                # patch_predictions = None
                for ch in range(len(dates)):
                    soft_seg = outs_daily[batch, ch, :, :, :]
                    # transform probs into a hard segmentation
                    pred_daily = torch.argmax(soft_seg, dim=0)
                    pred_daily = pred_daily.cpu()
                    daily_pred = ids_to_labels(
                        validationdataloader, pred_daily).astype(np.uint8)
                    # if patch_predictions is None:
                    #     patch_predictions = numpy.expand_dims(daily_pred, axis=0)
                    # else:
                    #     patch_predictions = numpy.concatenate((patch_predictions, numpy.expand_dims(daily_pred, axis=0)),
                    #                                           axis=0)

                    # save the daily prediction as a GeoTIFF next to the annual one
                    full_pth_date = os.path.join(
                        full_pth_patch, dates[ch] + '-daily-pred.tif')
                    profile.update({
                        'nodata': None,
                        'dtype': 'uint8',
                        'count': 1})
                    with rasterio.open(full_pth_date, 'w', **profile) as dst:
                        dst.write_band(1, daily_pred)

    st.markdown('Prediction completed')

    # folder_out = "demo_data/results/seg_maps/example-lombardia/2"
    folder_out = full_pth_patch  # os.path.join("demo_data/results/seg_maps/"+opt.years[0]+"-lombardia/", tileids[0])
    st.paths = os.listdir(folder_out)
    st.paths.sort()

if st.paths is not None:
    # folder_out = os.path.join("demo_data/results/seg_maps/example-lombardia/", tileids[0])
    folder_src = os.path.join("demo_data/lombardia/", opt.years[0], tileids[0])
    st.markdown(""" 
        ### Predictions
        """)
    # file_picker = st.selectbox("Select day predict (annual is patch-pred-nn.tif)",
    #                            st.paths, index=st.paths.index('patch-pred-nn.tif'))

    file_path = os.path.join(folder_out, 'patch-pred-nn.tif')
    # print(file_path)
    target, profile = read(file_path)
    target = np.squeeze(target)
    target = [classes_color_map[p] for p in target]

    fig, ax = plt.subplots()
    ax.imshow(target)

    markdown_legend = ''
    for c, l in zip(color_labels, labels_map):
        # print(colors.to_hex(c))
        markdown_legend += f'<div style="color:gray;background-color: {colors.to_hex(c)};">{l}</div><br>'

    col1, col2 = st.columns([2,1])
    with col1:
        st.markdown("**Long-term (annual) prediction**")
        st.pyplot(fig)
    with col2:
        st.markdown("**Legend**")
        st.markdown(markdown_legend, unsafe_allow_html=True)

    st.markdown("**Short-term (daily) predictions**")
    img_full_paths = [os.path.join(folder_out, path) for path in st.paths if 'daily-pred' in path]
    pred_dates = [path[:8] for path in st.paths if 'daily-pred' in path]
    anim = load_daily_preds_as_animations(img_full_paths, pred_dates)
    components.html(anim.to_jshtml(), height=600)

    st.markdown("**Input time series**")
    list_dir = os.listdir(folder_src)
    list_dir.sort()
    img_full_paths = [os.path.join(folder_src, f) for f in list_dir if f.endswith(".tif")]
    pred_dates = [f[:8] for f in list_dir if f.endswith(".tif")]
    anim_src = load_src_images_as_animations(img_full_paths, pred_dates)
    components.html(anim_src.to_jshtml(), height=600)

# zip_url = hf_hub_url(repo_id="ARTeLab/DemoCropMapping", filename="demo_data/1509.zip")
# with open("demo_data/1509.zip", "rb") as f:
#     bytes = f.read()
#     b64 = base64.b64encode(bytes).decode()
#     href = f'<a href="data:file/zip;base64,{b64}" download=\'1509.zip\'>\
#         Click to download\
#     </a>'
# st.sidebar.markdown(href, unsafe_allow_html=True)
# download_button_str = download_button(s, filename, f'Click here to download {filename}')
# st.markdown(download_button_str, unsafe_allow_html=True)
# with open('demo_data/1509.zip') as f:
#     st.download_button('Download 1509.zip', f, file_name="demo_data/1509.zip")

st.markdown(f""" 
    ## Lombardia Dataset
    You can download other patches from the original dataset created and published on 
    [Kaggle](https://www.kaggle.com/datasets/ignazio/sentinel2-crop-mapping) and used in our paper.

    ## How to build an input file for the Demo 
    You can download the following zip example to better understand how to create a new sample to feed as input to the model. """)
with open("demo_data/1509.zip", "rb") as fp:
    btn = st.download_button(
        label="Download ZIP example",
        data=fp,
        file_name="1509.zip",
        mime="application/octet-stream"
    )
st.markdown(f""" 
    A sample is a time series of sentinel-2 images, 
    i.e. all images acquired by the satellite during a year.    
    A zip file must contain 
    - a geoTiff image of size _9 x 48 x 48_ for each date of the time series; 
    - the name of each geoTif must show the date like this example "20221225.tif" which represents the date 25 December 2022; 
    - each image must contain all sentinel-2 bands as reported in the [paper](https://www.sciencedirect.com/science/article/pii/S0924271622003203); 
    - all the images inside the zip file must be placed inside a directory (see ZIP example) where the name represents the name of the patch (for example "24"). ) 
""")