Guillermo Uribe Vicencio committed on
Commit
b057d7c
·
1 Parent(s): d73326a

Referencia a config local

Browse files
app.py CHANGED
@@ -1,12 +1,15 @@
1
  ######### pull files
2
  import os
3
  from huggingface_hub import hf_hub_download
4
- config_path=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
5
- filename="multi_temporal_crop_classification_Prithvi_100M.py",
6
- token=os.environ.get("token"))
 
7
  ckpt=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
8
  filename='multi_temporal_crop_classification_Prithvi_100M.pth',
9
  token=os.environ.get("token"))
 
 
10
  ##########
11
  import argparse
12
  from mmcv import Config
@@ -205,6 +208,8 @@ def inference_on_file(target_image, model, custom_test_pipeline):
205
 
206
  output=result[0][0] + 1
207
  output = np.vstack([output[None], output[None], output[None]]).astype(np.uint8)
 
 
208
  output=apply_color_map(output).transpose((1,2,0))
209
 
210
  return rgb1,rgb2,rgb3,output
 
1
  ######### pull files
2
  import os
3
  from huggingface_hub import hf_hub_download
4
+ #config_path=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
5
+ # filename="multi_temporal_crop_classification_Prithvi_100M.py",
6
+ # token=os.environ.get("token"))
7
+
8
  ckpt=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification",
9
  filename='multi_temporal_crop_classification_Prithvi_100M.pth',
10
  token=os.environ.get("token"))
11
+
12
+ config_path="multi_temporal_crop_classification_Prithvi_100M.py"
13
  ##########
14
  import argparse
15
  from mmcv import Config
 
208
 
209
  output=result[0][0] + 1
210
  output = np.vstack([output[None], output[None], output[None]]).astype(np.uint8)
211
+
212
+
213
  output=apply_color_map(output).transpose((1,2,0))
214
 
215
  return rgb1,rgb2,rgb3,output
multi_temporal_crop_classification_Prithvi_100M.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ dist_params = dict(backend='nccl')
4
+ log_level = 'INFO'
5
+ load_from = None
6
+ resume_from = None
7
+ cudnn_benchmark = True
8
+ custom_imports = dict(imports=['geospatial_fm'])
9
+ num_frames = 3
10
+ img_size = 224
11
+ num_workers = 2
12
+
13
+ # model
14
+ # TO BE DEFINED BY USER: model path
15
+ pretrained_weights_path = '<path to pretrained weights>'
16
+ num_layers = 6
17
+ patch_size = 16
18
+ embed_dim = 768
19
+ num_heads = 8
20
+ tubelet_size = 1
21
+ max_epochs = 80
22
+ eval_epoch_interval = 5
23
+
24
+ loss_weights_multi = [
25
+ 0.386375, 0.661126, 0.548184, 0.640482, 0.876862, 0.925186, 3.249462,
26
+ 1.542289, 2.175141, 2.272419, 3.062762, 3.626097, 1.198702
27
+ ]
28
+ loss_func = dict(
29
+ type='CrossEntropyLoss',
30
+ use_sigmoid=False,
31
+ class_weight=loss_weights_multi,
32
+ avg_non_ignore=True)
33
+ output_embed_dim = embed_dim*num_frames
34
+
35
+
36
+ # TO BE DEFINED BY USER: Save directory
37
+ experiment = '<experiment name>'
38
+ project_dir = '<project directory name>'
39
+ work_dir = os.path.join(project_dir, experiment)
40
+ save_path = work_dir
41
+
42
+
43
+ gpu_ids = range(0, 1)
44
+ dataset_type = 'GeospatialDataset'
45
+
46
+ # TO BE DEFINED BY USER: data directory
47
+ data_root = '<path to data root>'
48
+
49
+ splits = dict(
50
+ train='<path to train split>',
51
+ val= '<path to val split>',
52
+ test= '<path to test split>'
53
+ )
54
+
55
+
56
+ img_norm_cfg = dict(
57
+ means=[
58
+ 494.905781, 815.239594, 924.335066, 2968.881459, 2634.621962,
59
+ 1739.579917, 494.905781, 815.239594, 924.335066, 2968.881459,
60
+ 2634.621962, 1739.579917, 494.905781, 815.239594, 924.335066,
61
+ 2968.881459, 2634.621962, 1739.579917
62
+ ],
63
+ stds=[
64
+ 284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808,
65
+ 284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808,
66
+ 284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808
67
+ ])
68
+
69
+ bands = [0, 1, 2, 3, 4, 5]
70
+
71
+ tile_size = 224
72
+ orig_nsize = 512
73
+ crop_size = (tile_size, tile_size)
74
+ train_pipeline = [
75
+ dict(type='LoadGeospatialImageFromFile', to_float32=True),
76
+ dict(type='LoadGeospatialAnnotations', reduce_zero_label=True),
77
+ dict(type='RandomFlip', prob=0.5),
78
+ dict(type='ToTensor', keys=['img', 'gt_semantic_seg']),
79
+ # to channels first
80
+ dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
81
+ dict(type='TorchNormalize', **img_norm_cfg),
82
+ dict(type='TorchRandomCrop', crop_size=crop_size),
83
+ dict(type='Reshape', keys=['img'], new_shape=(len(bands), num_frames, tile_size, tile_size)),
84
+ dict(type='Reshape', keys=['gt_semantic_seg'], new_shape=(1, tile_size, tile_size)),
85
+ dict(type='CastTensor', keys=['gt_semantic_seg'], new_type="torch.LongTensor"),
86
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
87
+ ]
88
+
89
+ test_pipeline = [
90
+ dict(type='LoadGeospatialImageFromFile', to_float32=True),
91
+ dict(type='ToTensor', keys=['img']),
92
+ # to channels first
93
+ dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
94
+ dict(type='TorchNormalize', **img_norm_cfg),
95
+ dict(type='Reshape', keys=['img'], new_shape=(len(bands), num_frames, -1, -1), look_up = {'2': 1, '3': 2}),
96
+ dict(type='CastTensor', keys=['img'], new_type="torch.FloatTensor"),
97
+ dict(type='CollectTestList', keys=['img'],
98
+ meta_keys=['img_info', 'seg_fields', 'img_prefix', 'seg_prefix', 'filename', 'ori_filename', 'img',
99
+ 'img_shape', 'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg']),
100
+ ]
101
+
102
+ CLASSES = ('Natural Vegetation',
103
+ 'Forest',
104
+ 'Corn',
105
+ 'Soybeans',
106
+ 'Wetlands',
107
+ 'Developed/Barren',
108
+ 'Open Water',
109
+ 'Winter Wheat',
110
+ 'Alfalfa',
111
+ 'Fallow/Idle Cropland',
112
+ 'Cotton',
113
+ 'Sorghum',
114
+ 'Other')
115
+
116
+ dataset = 'GeospatialDataset'
117
+
118
+ data = dict(
119
+ samples_per_gpu=8,
120
+ workers_per_gpu=4,
121
+ train=dict(
122
+ type=dataset,
123
+ CLASSES=CLASSES,
124
+ reduce_zero_label=True,
125
+ data_root=data_root,
126
+ img_dir='training_chips',
127
+ ann_dir='training_chips',
128
+ pipeline=train_pipeline,
129
+ img_suffix='_merged.tif',
130
+ seg_map_suffix='.mask.tif',
131
+ split=splits['train']),
132
+ val=dict(
133
+ type=dataset,
134
+ CLASSES=CLASSES,
135
+ reduce_zero_label=True,
136
+ data_root=data_root,
137
+ img_dir='validation_chips',
138
+ ann_dir='validation_chips',
139
+ pipeline=test_pipeline,
140
+ img_suffix='_merged.tif',
141
+ seg_map_suffix='.mask.tif',
142
+ split=splits['val']
143
+ ),
144
+ test=dict(
145
+ type=dataset,
146
+ CLASSES=CLASSES,
147
+ reduce_zero_label=True,
148
+ data_root=data_root,
149
+ img_dir='validation_chips',
150
+ ann_dir='validation_chips',
151
+ pipeline=test_pipeline,
152
+ img_suffix='_merged.tif',
153
+ seg_map_suffix='.mask.tif',
154
+ split=splits['val']
155
+ ))
156
+
157
+ optimizer = dict(
158
+ type='Adam', lr=1.5e-05, betas=(0.9, 0.999), weight_decay=0.05)
159
+ optimizer_config = dict(grad_clip=None)
160
+ lr_config = dict(
161
+ policy='poly',
162
+ warmup='linear',
163
+ warmup_iters=1500,
164
+ warmup_ratio=1e-06,
165
+ power=1.0,
166
+ min_lr=0.0,
167
+ by_epoch=False)
168
+ log_config = dict(
169
+ interval=10,
170
+ hooks=[dict(type='TextLoggerHook'),
171
+ dict(type='TensorboardLoggerHook')])
172
+
173
+ checkpoint_config = dict(
174
+ by_epoch=True,
175
+ interval=100,
176
+ out_dir=save_path)
177
+
178
+ evaluation = dict(interval=eval_epoch_interval, metric='mIoU', pre_eval=True, save_best='mIoU', by_epoch=True)
179
+ reduce_train_set = dict(reduce_train_set=False)
180
+ reduce_factor = dict(reduce_factor=1)
181
+ runner = dict(type='EpochBasedRunner', max_epochs=max_epochs)
182
+ workflow = [('train', 1)]
183
+ norm_cfg = dict(type='BN', requires_grad=True)
184
+
185
+ model = dict(
186
+ type='TemporalEncoderDecoder',
187
+ frozen_backbone=False,
188
+ backbone=dict(
189
+ type='TemporalViTEncoder',
190
+ pretrained=pretrained_weights_path,
191
+ img_size=img_size,
192
+ patch_size=patch_size,
193
+ num_frames=num_frames,
194
+ tubelet_size=1,
195
+ in_chans=len(bands),
196
+ embed_dim=embed_dim,
197
+ depth=6,
198
+ num_heads=num_heads,
199
+ mlp_ratio=4.0,
200
+ norm_pix_loss=False),
201
+ neck=dict(
202
+ type='ConvTransformerTokensToEmbeddingNeck',
203
+ embed_dim=embed_dim*num_frames,
204
+ output_embed_dim=output_embed_dim,
205
+ drop_cls_token=True,
206
+ Hp=14,
207
+ Wp=14),
208
+ decode_head=dict(
209
+ num_classes=len(CLASSES),
210
+ in_channels=output_embed_dim,
211
+ type='FCNHead',
212
+ in_index=-1,
213
+ channels=256,
214
+ num_convs=1,
215
+ concat_input=False,
216
+ dropout_ratio=0.1,
217
+ norm_cfg=dict(type='BN', requires_grad=True),
218
+ align_corners=False,
219
+ loss_decode=loss_func),
220
+ auxiliary_head=dict(
221
+ num_classes=len(CLASSES),
222
+ in_channels=output_embed_dim,
223
+ type='FCNHead',
224
+ in_index=-1,
225
+ channels=256,
226
+ num_convs=2,
227
+ concat_input=False,
228
+ dropout_ratio=0.1,
229
+ norm_cfg=dict(type='BN', requires_grad=True),
230
+ align_corners=False,
231
+ loss_decode=loss_func),
232
+ train_cfg=dict(),
233
+ test_cfg=dict(mode='slide', stride=(int(tile_size/2), int(tile_size/2)), crop_size=(tile_size, tile_size)))
234
+ auto_resume = False