Paolo-Fraccaro commited on
Commit
c599a73
1 Parent(s): c1a9a1c

Upload 2 files

Browse files
burn_scars_Prithvi_100M.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a6f72909bbf28917e3afa2ba00f948c470a52f825d7353ece1a785f7ca805a8
3
+ size 1198548899
burn_scars_Prithvi_100M.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dist_params = dict(backend='nccl')
2
+ log_level = 'INFO'
3
+ load_from = None
4
+ resume_from = None
5
+ cudnn_benchmark = True
6
+ custom_imports = dict(imports=['geospatial_fm'])
7
+ dataset_type = 'GeospatialDataset'
8
+ data_root = '/dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended'
9
+ num_frames = 1
10
+ img_size = 224
11
+ num_workers = 4
12
+ samples_per_gpu = 4
13
+ img_norm_cfg = dict(
14
+ means=[
15
+ 0.033349706741586264, 0.05701185520536176, 0.05889748132001316,
16
+ 0.2323245113436119, 0.1972854853760658, 0.11944914225186566
17
+ ],
18
+ stds=[
19
+ 0.02269135568823774, 0.026807560223070237, 0.04004109844362779,
20
+ 0.07791732423672691, 0.08708738838140137, 0.07241979477437814
21
+ ])
22
+ bands = [0, 1, 2, 3, 4, 5]
23
+ tile_size = 224
24
+ orig_nsize = 512
25
+ crop_size = (224, 224)
26
+ img_suffix = '_merged.tif'
27
+ seg_map_suffix = '.mask.tif'
28
+ ignore_index = -1
29
+ image_nodata = -9999
30
+ image_nodata_replace = 0
31
+ image_to_float32 = True
32
+ # pretrained_weights_path = '/dccstor/geofm-finetuning/pretrain_ckpts/mae_weights/2023-04-29_21-50-47/epoch-725-loss-0.0365.pt'
33
+ pretrained_weights_path = None
34
+ num_layers = 12
35
+ patch_size = 16
36
+ embed_dim = 768
37
+ num_heads = 12
38
+ tubelet_size = 1
39
+ epochs = 50
40
+ eval_epoch_interval = 5
41
+ experiment = 'test2'
42
+ project_dir = '/dccstor/geofm-finetuning/fire-scars/os'
43
+ work_dir = '/dccstor/geofm-finetuning/fire-scars/os/test2'
44
+ save_path = '/dccstor/geofm-finetuning/fire-scars/os/test2'
45
+ train_pipeline = [
46
+ dict(type='LoadGeospatialImageFromFile', to_float32=True),
47
+ dict(type='LoadGeospatialAnnotations', reduce_zero_label=False),
48
+ dict(type='BandsExtract', bands=[0, 1, 2, 3, 4, 5]),
49
+ dict(type='RandomFlip', prob=0.5),
50
+ dict(type='ToTensor', keys=['img', 'gt_semantic_seg']),
51
+ dict(
52
+ type='TorchNormalize',
53
+ means=[
54
+ 0.033349706741586264, 0.05701185520536176, 0.05889748132001316,
55
+ 0.2323245113436119, 0.1972854853760658, 0.11944914225186566
56
+ ],
57
+ stds=[
58
+ 0.02269135568823774, 0.026807560223070237, 0.04004109844362779,
59
+ 0.07791732423672691, 0.08708738838140137, 0.07241979477437814
60
+ ]),
61
+ dict(type='TorchRandomCrop', crop_size=(224, 224)),
62
+ dict(type='Reshape', keys=['img'], new_shape=(6, 1, 224, 224)),
63
+ dict(type='Reshape', keys=['gt_semantic_seg'], new_shape=(1, 224, 224)),
64
+ dict(
65
+ type='CastTensor',
66
+ keys=['gt_semantic_seg'],
67
+ new_type='torch.LongTensor'),
68
+ dict(type='Collect', keys=['img', 'gt_semantic_seg'])
69
+ ]
70
+ test_pipeline = [
71
+ dict(type='LoadGeospatialImageFromFile', to_float32=True),
72
+ dict(type='BandsExtract', bands=[0, 1, 2, 3, 4, 5]),
73
+ dict(type='ToTensor', keys=['img']),
74
+ dict(
75
+ type='TorchNormalize',
76
+ means=[
77
+ 0.033349706741586264, 0.05701185520536176, 0.05889748132001316,
78
+ 0.2323245113436119, 0.1972854853760658, 0.11944914225186566
79
+ ],
80
+ stds=[
81
+ 0.02269135568823774, 0.026807560223070237, 0.04004109844362779,
82
+ 0.07791732423672691, 0.08708738838140137, 0.07241979477437814
83
+ ]),
84
+ dict(
85
+ type='Reshape',
86
+ keys=['img'],
87
+ new_shape=(6, 1, -1, -1),
88
+ look_up=dict({
89
+ '2': 1,
90
+ '3': 2
91
+ })),
92
+ dict(type='CastTensor', keys=['img'], new_type='torch.FloatTensor'),
93
+ dict(
94
+ type='CollectTestList',
95
+ keys=['img'],
96
+ meta_keys=[
97
+ 'img_info', 'seg_fields', 'img_prefix', 'seg_prefix', 'filename',
98
+ 'ori_filename', 'img', 'img_shape', 'ori_shape', 'pad_shape',
99
+ 'scale_factor', 'img_norm_cfg'
100
+ ])
101
+ ]
102
+ data = dict(
103
+ samples_per_gpu=4,
104
+ workers_per_gpu=4,
105
+ train=dict(
106
+ type='FireScars',
107
+ data_root=
108
+ '/dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended',
109
+ img_dir='training',
110
+ ann_dir='training',
111
+ img_suffix='_merged.tif',
112
+ seg_map_suffix='.mask.tif',
113
+ pipeline=[
114
+ dict(type='LoadGeospatialImageFromFile', to_float32=True),
115
+ dict(type='LoadGeospatialAnnotations', reduce_zero_label=False),
116
+ dict(type='BandsExtract', bands=[0, 1, 2, 3, 4, 5]),
117
+ dict(type='RandomFlip', prob=0.5),
118
+ dict(type='ToTensor', keys=['img', 'gt_semantic_seg']),
119
+ dict(
120
+ type='TorchNormalize',
121
+ means=[
122
+ 0.033349706741586264, 0.05701185520536176,
123
+ 0.05889748132001316, 0.2323245113436119,
124
+ 0.1972854853760658, 0.11944914225186566
125
+ ],
126
+ stds=[
127
+ 0.02269135568823774, 0.026807560223070237,
128
+ 0.04004109844362779, 0.07791732423672691,
129
+ 0.08708738838140137, 0.07241979477437814
130
+ ]),
131
+ dict(type='TorchRandomCrop', crop_size=(224, 224)),
132
+ dict(type='Reshape', keys=['img'], new_shape=(6, 1, 224, 224)),
133
+ dict(
134
+ type='Reshape',
135
+ keys=['gt_semantic_seg'],
136
+ new_shape=(1, 224, 224)),
137
+ dict(
138
+ type='CastTensor',
139
+ keys=['gt_semantic_seg'],
140
+ new_type='torch.LongTensor'),
141
+ dict(type='Collect', keys=['img', 'gt_semantic_seg'])
142
+ ],
143
+ ignore_index=-1),
144
+ val=dict(
145
+ type='FireScars',
146
+ data_root=
147
+ '/dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended',
148
+ img_dir='validation',
149
+ ann_dir='validation',
150
+ img_suffix='_merged.tif',
151
+ seg_map_suffix='.mask.tif',
152
+ pipeline=[
153
+ dict(type='LoadGeospatialImageFromFile', to_float32=True),
154
+ dict(type='BandsExtract', bands=[0, 1, 2, 3, 4, 5]),
155
+ dict(type='ToTensor', keys=['img']),
156
+ dict(
157
+ type='TorchNormalize',
158
+ means=[
159
+ 0.033349706741586264, 0.05701185520536176,
160
+ 0.05889748132001316, 0.2323245113436119,
161
+ 0.1972854853760658, 0.11944914225186566
162
+ ],
163
+ stds=[
164
+ 0.02269135568823774, 0.026807560223070237,
165
+ 0.04004109844362779, 0.07791732423672691,
166
+ 0.08708738838140137, 0.07241979477437814
167
+ ]),
168
+ dict(
169
+ type='Reshape',
170
+ keys=['img'],
171
+ new_shape=(6, 1, -1, -1),
172
+ look_up=dict({
173
+ '2': 1,
174
+ '3': 2
175
+ })),
176
+ dict(
177
+ type='CastTensor', keys=['img'], new_type='torch.FloatTensor'),
178
+ dict(
179
+ type='CollectTestList',
180
+ keys=['img'],
181
+ meta_keys=[
182
+ 'img_info', 'seg_fields', 'img_prefix', 'seg_prefix',
183
+ 'filename', 'ori_filename', 'img', 'img_shape',
184
+ 'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg'
185
+ ])
186
+ ],
187
+ ignore_index=-1),
188
+ test=dict(
189
+ type='FireScars',
190
+ data_root=
191
+ '/dccstor/geofm-finetuning/fire-scars/finetune-data/6_bands_no_replant_extended',
192
+ img_dir='validation',
193
+ ann_dir='validation',
194
+ img_suffix='_merged.tif',
195
+ seg_map_suffix='.mask.tif',
196
+ pipeline=[
197
+ dict(type='LoadGeospatialImageFromFile', to_float32=True),
198
+ dict(type='BandsExtract', bands=[0, 1, 2, 3, 4, 5]),
199
+ dict(type='ToTensor', keys=['img']),
200
+ dict(
201
+ type='TorchNormalize',
202
+ means=[
203
+ 0.033349706741586264, 0.05701185520536176,
204
+ 0.05889748132001316, 0.2323245113436119,
205
+ 0.1972854853760658, 0.11944914225186566
206
+ ],
207
+ stds=[
208
+ 0.02269135568823774, 0.026807560223070237,
209
+ 0.04004109844362779, 0.07791732423672691,
210
+ 0.08708738838140137, 0.07241979477437814
211
+ ]),
212
+ dict(
213
+ type='Reshape',
214
+ keys=['img'],
215
+ new_shape=(6, 1, -1, -1),
216
+ look_up=dict({
217
+ '2': 1,
218
+ '3': 2
219
+ })),
220
+ dict(
221
+ type='CastTensor', keys=['img'], new_type='torch.FloatTensor'),
222
+ dict(
223
+ type='CollectTestList',
224
+ keys=['img'],
225
+ meta_keys=[
226
+ 'img_info', 'seg_fields', 'img_prefix', 'seg_prefix',
227
+ 'filename', 'ori_filename', 'img', 'img_shape',
228
+ 'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg'
229
+ ])
230
+ ],
231
+ ignore_index=-1))
232
+ optimizer = dict(type='Adam', lr=1.3e-05, betas=(0.9, 0.999))
233
+ optimizer_config = dict(grad_clip=None)
234
+ lr_config = dict(
235
+ policy='poly',
236
+ warmup='linear',
237
+ warmup_iters=1500,
238
+ warmup_ratio=1e-06,
239
+ power=1.0,
240
+ min_lr=0.0,
241
+ by_epoch=False)
242
+ log_config = dict(
243
+ interval=20,
244
+ hooks=[
245
+ dict(type='TextLoggerHook', by_epoch=False),
246
+ dict(type='TensorboardLoggerHook', by_epoch=False)
247
+ ])
248
+ checkpoint_config = dict(
249
+ by_epoch=True,
250
+ interval=10,
251
+ out_dir=
252
+ '/dccstor/geofm-finetuning/carlosgomes/fire_scars/carlos_replicate_experiment_fixed_lr'
253
+ )
254
+ evaluation = dict(
255
+ interval=1180,
256
+ metric='mIoU',
257
+ pre_eval=True,
258
+ save_best='mIoU',
259
+ by_epoch=False)
260
+ runner = dict(type='IterBasedRunner', max_iters=6300)
261
+ workflow = [('train', 1)]
262
+ norm_cfg = dict(type='BN', requires_grad=True)
263
+ model = dict(
264
+ type='TemporalEncoderDecoder',
265
+ frozen_backbone=False,
266
+ backbone=dict(
267
+ type='TemporalViTEncoder',
268
+ pretrained=None,
269
+ # '/dccstor/geofm-finetuning/pretrain_ckpts/mae_weights/2023-04-29_21-50-47/epoch-725-loss-0.0365.pt',
270
+ img_size=224,
271
+ patch_size=16,
272
+ num_frames=1,
273
+ tubelet_size=1,
274
+ in_chans=6,
275
+ embed_dim=768,
276
+ depth=12,
277
+ num_heads=12,
278
+ mlp_ratio=4.0,
279
+ norm_pix_loss=False),
280
+ neck=dict(
281
+ type='ConvTransformerTokensToEmbeddingNeck',
282
+ embed_dim=768,
283
+ output_embed_dim=768,
284
+ drop_cls_token=True,
285
+ Hp=14,
286
+ Wp=14),
287
+ decode_head=dict(
288
+ num_classes=2,
289
+ in_channels=768,
290
+ type='FCNHead',
291
+ in_index=-1,
292
+ channels=256,
293
+ num_convs=1,
294
+ concat_input=False,
295
+ dropout_ratio=0.1,
296
+ norm_cfg=dict(type='BN', requires_grad=True),
297
+ align_corners=False,
298
+ loss_decode=dict(
299
+ type='DiceLoss', use_sigmoid=False, loss_weight=1,
300
+ ignore_index=-1)),
301
+ auxiliary_head=dict(
302
+ num_classes=2,
303
+ in_channels=768,
304
+ type='FCNHead',
305
+ in_index=-1,
306
+ channels=256,
307
+ num_convs=2,
308
+ concat_input=False,
309
+ dropout_ratio=0.1,
310
+ norm_cfg=dict(type='BN', requires_grad=True),
311
+ align_corners=False,
312
+ loss_decode=dict(
313
+ type='DiceLoss', use_sigmoid=False, loss_weight=1,
314
+ ignore_index=-1)),
315
+ train_cfg=dict(),
316
+ test_cfg=dict(mode='slide', stride=(112, 112), crop_size=(224, 224)))
317
+ gpu_ids = range(0, 1)
318
+ auto_resume = False