blumenstiel commited on
Commit
7bfe9ea
·
1 Parent(s): 8ab0a02

Updated inference code

Browse files
.gitattributes CHANGED
@@ -36,3 +36,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
36
  Prithvi_training.png filter=lfs diff=lfs merge=lfs -text
37
  Prithvi_walkthrough_thumbnail.png filter=lfs diff=lfs merge=lfs -text
38
  GFM.png filter=lfs diff=lfs merge=lfs -text
 
 
36
  Prithvi_training.png filter=lfs diff=lfs merge=lfs -text
37
  Prithvi_walkthrough_thumbnail.png filter=lfs diff=lfs merge=lfs -text
38
  GFM.png filter=lfs diff=lfs merge=lfs -text
39
+ *.tif filter=lfs diff=lfs merge=lfs -text
Prithvi.py DELETED
@@ -1,319 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
-
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- # --------------------------------------------------------
7
- # References:
8
- # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
- # DeiT: https://github.com/facebookresearch/deit
10
- # --------------------------------------------------------
11
-
12
- from functools import partial
13
-
14
- import torch
15
- import torch.nn as nn
16
-
17
- from timm.models.vision_transformer import Block
18
- from timm.models.layers import to_2tuple
19
-
20
- import numpy as np
21
-
22
- from einops import rearrange
23
-
24
- def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
25
- """
26
- embed_dim: output dimension for each position
27
- pos: a list of positions to be encoded: size (M,)
28
- out: (M, D)
29
- """
30
- assert embed_dim % 2 == 0
31
- omega = np.arange(embed_dim // 2, dtype=np.float32)
32
- omega /= embed_dim / 2.
33
- omega = 1. / 10000**omega # (D/2,)
34
-
35
- pos = pos.reshape(-1) # (M,)
36
- out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
37
-
38
- emb_sin = np.sin(out) # (M, D/2)
39
- emb_cos = np.cos(out) # (M, D/2)
40
-
41
- emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
42
- return emb
43
-
44
- def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
45
- assert embed_dim % 2 == 0
46
-
47
- # use half of dimensions to encode grid_h
48
- emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
49
- emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
50
-
51
- emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
52
- return emb
53
-
54
- def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
55
- """
56
- grid_size: 3d tuple of grid size: t, h, w
57
- return:
58
- pos_embed: L, D
59
- """
60
-
61
- assert embed_dim % 16 == 0
62
-
63
- t_size, h_size, w_size = grid_size
64
-
65
- w_embed_dim = embed_dim // 16 * 6
66
- h_embed_dim = embed_dim // 16 * 6
67
- t_embed_dim = embed_dim // 16 * 4
68
-
69
- w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
70
- h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
71
- t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))
72
-
73
- w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
74
- h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
75
- t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)
76
-
77
- pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)
78
-
79
- if cls_token:
80
- pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
81
- return pos_embed
82
-
83
-
84
- class PatchEmbed(nn.Module):
85
- """ Frames of 2D Images to Patch Embedding
86
- The 3D version of timm.models.vision_transformer.PatchEmbed
87
- """
88
- def __init__(
89
- self,
90
- img_size=224,
91
- patch_size=16,
92
- num_frames=3,
93
- tubelet_size=1,
94
- in_chans=3,
95
- embed_dim=768,
96
- norm_layer=None,
97
- flatten=True,
98
- bias=True,
99
- ):
100
- super().__init__()
101
- img_size = to_2tuple(img_size)
102
- patch_size = to_2tuple(patch_size)
103
- self.img_size = img_size
104
- self.patch_size = patch_size
105
- self.num_frames = num_frames
106
- self.tubelet_size = tubelet_size
107
- self.grid_size = (num_frames // tubelet_size, img_size[0] // patch_size[0], img_size[1] // patch_size[1])
108
- self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
109
- self.flatten = flatten
110
-
111
- self.proj = nn.Conv3d(in_chans, embed_dim,
112
- kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
113
- stride=(tubelet_size, patch_size[0], patch_size[1]), bias=bias)
114
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
115
-
116
- def forward(self, x):
117
- B, C, T, H, W = x.shape
118
- x = self.proj(x)
119
- if self.flatten:
120
- x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C
121
- x = self.norm(x)
122
- return x
123
-
124
-
125
- class MaskedAutoencoderViT(nn.Module):
126
- """ Masked Autoencoder with VisionTransformer backbone
127
- """
128
- def __init__(self, img_size=224, patch_size=16,
129
- num_frames=3, tubelet_size=1,
130
- in_chans=3, embed_dim=1024, depth=24, num_heads=16,
131
- decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
132
- mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
133
- super().__init__()
134
-
135
- # --------------------------------------------------------------------------
136
- # MAE encoder specifics
137
- self.patch_embed = PatchEmbed(img_size, patch_size,num_frames, tubelet_size, in_chans, embed_dim)
138
- num_patches = self.patch_embed.num_patches
139
-
140
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
141
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding
142
-
143
- self.blocks = nn.ModuleList([
144
- Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
145
- for i in range(depth)])
146
- self.norm = norm_layer(embed_dim)
147
- # --------------------------------------------------------------------------
148
-
149
- # --------------------------------------------------------------------------
150
- # MAE decoder specifics
151
- self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
152
-
153
- self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
154
-
155
- self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding
156
-
157
- self.decoder_blocks = nn.ModuleList([
158
- Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
159
- for i in range(decoder_depth)])
160
-
161
- self.decoder_norm = norm_layer(decoder_embed_dim)
162
- self.decoder_pred = nn.Linear(decoder_embed_dim, tubelet_size * patch_size * patch_size * in_chans, bias=True) # decoder to patch
163
- # --------------------------------------------------------------------------
164
-
165
- self.norm_pix_loss = norm_pix_loss
166
-
167
- self.initialize_weights()
168
-
169
- def initialize_weights(self):
170
- # initialization
171
- # initialize (and freeze) pos_embed by sin-cos embedding
172
- pos_embed = get_3d_sincos_pos_embed(self.pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True)
173
- self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
174
-
175
- decoder_pos_embed = get_3d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True)
176
- self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
177
-
178
- # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
179
- w = self.patch_embed.proj.weight.data
180
- torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
181
-
182
- # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
183
- torch.nn.init.normal_(self.cls_token, std=.02)
184
- torch.nn.init.normal_(self.mask_token, std=.02)
185
-
186
- # initialize nn.Linear and nn.LayerNorm
187
- self.apply(self._init_weights)
188
-
189
- def _init_weights(self, m):
190
- if isinstance(m, nn.Linear):
191
- # we use xavier_uniform following official JAX ViT:
192
- torch.nn.init.xavier_uniform_(m.weight)
193
- if isinstance(m, nn.Linear) and m.bias is not None:
194
- nn.init.constant_(m.bias, 0)
195
- elif isinstance(m, nn.LayerNorm):
196
- nn.init.constant_(m.bias, 0)
197
- nn.init.constant_(m.weight, 1.0)
198
-
199
- def patchify(self, imgs):
200
- """
201
- imgs: B, C, T, H, W
202
- x: B, L, D
203
- """
204
- p = self.patch_embed.patch_size[0]
205
- tub = self.patch_embed.tubelet_size
206
- x = rearrange(imgs, 'b c (t tub) (h p) (w q) -> b (t h w) (tub p q c)', tub=tub, p=p, q=p)
207
-
208
- return x
209
-
210
- def unpatchify(self, x):
211
- """
212
- x: B, L, D
213
- imgs: B, C, T, H, W
214
- """
215
- p = self.patch_embed.patch_size[0]
216
- num_p = self.patch_embed.img_size[0] // p
217
- tub = self.patch_embed.tubelet_size
218
- imgs = rearrange(x, 'b (t h w) (tub p q c) -> b c (t tub) (h p) (w q)', h=num_p, w=num_p, tub=tub, p=p, q=p)
219
- return imgs
220
-
221
- def random_masking(self, x, mask_ratio):
222
- """
223
- Perform per-sample random masking by per-sample shuffling.
224
- Per-sample shuffling is done by argsort random noise.
225
- x: [N, L, D], sequence
226
- """
227
- N, L, D = x.shape # batch, length, dim
228
- len_keep = int(L * (1 - mask_ratio))
229
-
230
- noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
231
-
232
- # sort noise for each sample
233
- ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
234
- ids_restore = torch.argsort(ids_shuffle, dim=1)
235
-
236
- # keep the first subset
237
- ids_keep = ids_shuffle[:, :len_keep]
238
- x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
239
-
240
- # generate the binary mask: 0 is keep, 1 is remove
241
- mask = torch.ones([N, L], device=x.device)
242
- mask[:, :len_keep] = 0
243
- # unshuffle to get the binary mask
244
- mask = torch.gather(mask, dim=1, index=ids_restore)
245
-
246
- return x_masked, mask, ids_restore
247
-
248
- def forward_encoder(self, x, mask_ratio):
249
- # embed patches
250
- x = self.patch_embed(x)
251
-
252
- # add pos embed w/o cls token
253
- x = x + self.pos_embed[:, 1:, :]
254
-
255
- # masking: length -> length * mask_ratio
256
- x, mask, ids_restore = self.random_masking(x, mask_ratio)
257
-
258
- # append cls token
259
- cls_token = self.cls_token + self.pos_embed[:, :1, :]
260
- cls_tokens = cls_token.expand(x.shape[0], -1, -1)
261
- x = torch.cat((cls_tokens, x), dim=1)
262
-
263
- # apply Transformer blocks
264
- for blk in self.blocks:
265
- x = blk(x)
266
- x = self.norm(x)
267
-
268
- return x, mask, ids_restore
269
-
270
- def forward_decoder(self, x, ids_restore):
271
- # embed tokens
272
- x = self.decoder_embed(x)
273
-
274
- # append mask tokens to sequence
275
- mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
276
- x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
277
- x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
278
- x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
279
-
280
- # add pos embed
281
- x = x + self.decoder_pos_embed
282
-
283
- # apply Transformer blocks
284
- for blk in self.decoder_blocks:
285
- x = blk(x)
286
- x = self.decoder_norm(x)
287
-
288
- # predictor projection
289
- x = self.decoder_pred(x)
290
-
291
- # remove cls token
292
- x = x[:, 1:, :]
293
-
294
- return x
295
-
296
- def forward_loss(self, imgs, pred, mask):
297
- """
298
- imgs: B, C, T, H, W
299
- target: B, L, D
300
- pred: B, L, D
301
- mask: B, L. 0 is keep, 1 is remove,
302
- """
303
- target = self.patchify(imgs)
304
- if self.norm_pix_loss:
305
- mean = target.mean(dim=-1, keepdim=True)
306
- var = target.var(dim=-1, keepdim=True)
307
- target = (target - mean) / (var + 1.e-6)**.5
308
-
309
- loss = (pred - target) ** 2
310
- loss = loss.mean(dim=-1) # [N, L], mean loss per patch
311
-
312
- loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
313
- return loss
314
-
315
- def forward(self, imgs, mask_ratio=0.75):
316
- latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
317
- pred = self.forward_decoder(latent, ids_restore)
318
- loss = self.forward_loss(imgs, pred, mask)
319
- return loss, pred, mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -33,10 +33,10 @@ The model follows the [original MAE repo](https://github.com/facebookresearch/ma
33
  4. adding infrared bands besides RGB
34
 
35
  ### Inference and demo
36
- There is an inference script (`Prithvi_run_inference.py`) that allows to run the image reconstruction on a set of HLS images assumed to be from the same location at different time steps(see example below). These should be provided in chronological order in geotiff format, including the channels described above (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in reflectance units. There is also a **demo** that leverages the same code [here](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-demo).
37
 
38
  ```
39
- python Prithvi_run_inference.py --data_files t1.tif t2.tif t3.tif --yaml_file_path /path/to/yaml/config.yaml --checkpoint /path/to/checkpoint/Prithvi_EO_V1_100M.pth --output_dir /path/to/out/dir/ --input_indices <space separated 0-based indices of channels to select from input> --mask_ratio 0.5 --img_size <length of one side of square input shape>
40
  ```
41
 
42
  This demo is a starting point that can be used as a starting point to generalize to different input shapes / types.
 
33
  4. adding infrared bands besides RGB
34
 
35
  ### Inference and demo
36
+ There is an inference script (`inference.py`) that allows to run the image reconstruction on a set of HLS images assumed to be from the same location at different time steps(see example below). These should be provided in chronological order in geotiff format, including the channels described above (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in reflectance units. There is also a **demo** that leverages the same code [here](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-demo).
37
 
38
  ```
39
+ python inference.py --data_files t1.tif t2.tif t3.tif t4.tif --input_indices <optional, space separated 0-based indices of the six Prithvi channels in your input>
40
  ```
41
 
42
  This demo is a starting point that can be used as a starting point to generalize to different input shapes / types.
Prithvi_run_inference.py → inference.py RENAMED
@@ -2,21 +2,25 @@ import argparse
2
  import functools
3
  import os
4
  from typing import List, Union
5
-
 
6
  import numpy as np
 
7
  import rasterio
8
  import torch
9
  import yaml
10
  from einops import rearrange
11
 
12
- from Prithvi import MaskedAutoencoderViT
 
13
 
14
  NO_DATA = -9999
15
  NO_DATA_FLOAT = 0.0001
16
- PERCENTILES = (0.1, 99.9)
 
17
 
18
 
19
- def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
20
  """Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
21
  original range using *data_mean* and *data_std* and then lowest and highest percentiles are
22
  removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
@@ -25,43 +29,36 @@ def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
25
  orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
26
  new_img: torch.Tensor representing image with shape = (bands, H, W).
27
  channels: list of indices representing RGB channels.
28
- data_mean: list of mean values for each band.
29
- data_std: list of std values for each band.
30
 
31
  Returns:
32
  torch.Tensor with shape (num_channels, height, width) for original image
33
  torch.Tensor with shape (num_channels, height, width) for the other image
34
  """
35
 
36
- stack_c = [], []
37
-
38
- for c in channels:
39
- orig_ch = orig_img[c, ...]
40
- valid_mask = torch.ones_like(orig_ch, dtype=torch.bool)
41
- valid_mask[orig_ch == NO_DATA_FLOAT] = False
42
-
43
- # Back to original data range
44
- orig_ch = (orig_ch * data_std[c]) + data_mean[c]
45
- new_ch = (new_img[c, ...] * data_std[c]) + data_mean[c]
46
-
47
- # Rescale (enhancing contrast)
48
- min_value, max_value = np.percentile(orig_ch[valid_mask], PERCENTILES)
49
 
50
- orig_ch = torch.clamp((orig_ch - min_value) / (max_value - min_value), 0, 1)
51
- new_ch = torch.clamp((new_ch - min_value) / (max_value - min_value), 0, 1)
 
52
 
53
- # No data as zeros
54
- orig_ch[~valid_mask] = 0
55
- new_ch[~valid_mask] = 0
56
 
57
- stack_c[0].append(orig_ch)
58
- stack_c[1].append(new_ch)
59
 
60
- # Channels first
61
- stack_orig = torch.stack(stack_c[0], dim=0)
62
- stack_rec = torch.stack(stack_c[1], dim=0)
63
 
64
- return stack_orig, stack_rec
65
 
66
 
67
  def read_geotiff(file_path: str):
@@ -78,8 +75,13 @@ def read_geotiff(file_path: str):
78
  with rasterio.open(file_path) as src:
79
  img = src.read()
80
  meta = src.meta
 
 
 
 
 
81
 
82
- return img, meta
83
 
84
 
85
  def save_geotiff(image, output_path: str, meta: dict):
@@ -127,7 +129,7 @@ def load_example(
127
  metas = []
128
 
129
  for file in file_paths:
130
- img, meta = read_geotiff(file)
131
 
132
  # Rescaling (don't normalize on nodata)
133
  img = np.moveaxis(img, 0, -1) # channels last for rescaling
@@ -140,7 +142,7 @@ def load_example(
140
 
141
  imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
142
  imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W
143
- imgs = np.expand_dims(imgs, axis=0) # add batch dim
144
 
145
  return imgs, metas
146
 
@@ -166,7 +168,7 @@ def run_model(
166
  with torch.no_grad():
167
  x = input_data.to(device)
168
 
169
- _, pred, mask = model(x, mask_ratio)
170
 
171
  # Create mask and prediction images (un-patchify)
172
  mask_img = (
@@ -207,8 +209,8 @@ def save_rgb_imgs(
207
  orig_img=input_img[:, t, :, :],
208
  new_img=rec_img[:, t, :, :],
209
  channels=channels,
210
- data_mean=mean,
211
- data_std=std,
212
  )
213
 
214
  rgb_mask = mask_img[channels, t, :, :] * rgb_orig
@@ -272,11 +274,10 @@ def save_imgs(rec_img, mask_img, mean, std, output_dir, meta_data):
272
 
273
  def main(
274
  data_files: List[str],
275
- yaml_file_path: str,
276
  checkpoint: str,
277
  output_dir: str,
278
  rgb_outputs: bool,
279
- img_size: int,
280
  mask_ratio: float = None,
281
  input_indices: list[int] = None,
282
  ):
@@ -284,31 +285,17 @@ def main(
284
 
285
  # Get parameters --------
286
 
287
- with open(yaml_file_path, "r") as f:
288
- params = yaml.safe_load(f)
289
-
290
- # data related
291
- train_params = params["train_params"]
292
- num_frames = len(data_files)
293
- bands = train_params["bands"]
294
- mean = train_params["data_mean"]
295
- std = train_params["data_std"]
296
-
297
- # model related
298
- model_params = params["model_args"]
299
- img_size = model_params["img_size"] if img_size is None else img_size
300
- depth = model_params["depth"]
301
- patch_size = model_params["patch_size"]
302
- embed_dim = model_params["embed_dim"]
303
- num_heads = model_params["num_heads"]
304
- tubelet_size = model_params["tubelet_size"]
305
- decoder_embed_dim = model_params["decoder_embed_dim"]
306
- decoder_num_heads = model_params["decoder_num_heads"]
307
- decoder_depth = model_params["decoder_depth"]
308
 
309
  batch_size = 1
310
-
311
- mask_ratio = train_params["mask_ratio"] if mask_ratio is None else mask_ratio
 
 
 
 
312
 
313
  print(
314
  f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n"
@@ -333,23 +320,13 @@ def main(
333
 
334
  # Create model and load checkpoint -------------------------------------------------------------
335
 
336
- model = MaskedAutoencoderViT(
337
- img_size=img_size,
338
- patch_size=patch_size,
339
  num_frames=num_frames,
340
- tubelet_size=tubelet_size,
341
  in_chans=len(bands),
342
- embed_dim=embed_dim,
343
- depth=depth,
344
- num_heads=num_heads,
345
- decoder_embed_dim=decoder_embed_dim,
346
- decoder_depth=decoder_depth,
347
- decoder_num_heads=decoder_num_heads,
348
- mlp_ratio=4.0,
349
- norm_layer=functools.partial(torch.nn.LayerNorm, eps=1e-6),
350
- norm_pix_loss=False,
351
  )
352
 
 
 
353
  total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
354
  print(f"\n--> Model has {total_params:,} parameters.\n")
355
 
@@ -357,8 +334,9 @@ def main(
357
 
358
  state_dict = torch.load(checkpoint, map_location=device)
359
  # discard fixed pos_embedding weight
360
- del state_dict["pos_embed"]
361
- del state_dict["decoder_pos_embed"]
 
362
  model.load_state_dict(state_dict, strict=False)
363
  print(f"Loaded checkpoint from {checkpoint}")
364
 
@@ -463,42 +441,40 @@ if __name__ == "__main__":
463
 
464
  parser.add_argument(
465
  "--data_files",
466
- required=True,
467
  type=str,
468
  nargs="+",
 
 
 
 
469
  help="Path to the data files. Assumes multi-band files.",
470
  )
471
  parser.add_argument(
472
- "--yaml_file_path",
 
473
  type=str,
474
- required=True,
475
- help="Path to yaml file containing model training parameters.",
476
  )
477
  parser.add_argument(
478
  "--checkpoint",
479
- required=True,
480
  type=str,
 
481
  help="Path to a checkpoint file to load from.",
482
  )
483
  parser.add_argument(
484
  "--output_dir",
485
- required=True,
486
  type=str,
 
487
  help="Path to the directory where to save outputs.",
488
  )
489
  parser.add_argument(
490
  "--mask_ratio",
491
- default=None,
492
  type=float,
493
  help="Masking ratio (percentage of removed patches). "
494
  "If None (default) use same value used for pretraining.",
495
  )
496
- parser.add_argument(
497
- "--img_size",
498
- default=224,
499
- type=int,
500
- help="Image size to be used with model. Defaults to 224",
501
- )
502
  parser.add_argument(
503
  "--input_indices",
504
  default=None,
 
2
  import functools
3
  import os
4
  from typing import List, Union
5
+ import re
6
+ import datetime
7
  import numpy as np
8
+ import pandas as pd
9
  import rasterio
10
  import torch
11
  import yaml
12
  from einops import rearrange
13
 
14
+ from functools import partial
15
+ from prithvi_mae import PrithviMAE
16
 
17
  NO_DATA = -9999
18
  NO_DATA_FLOAT = 0.0001
19
+ OFFSET = 0
20
+ PERCENTILE = 99.9
21
 
22
 
23
+ def process_channel_group(orig_img, new_img, channels, mean, std):
24
  """Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
25
  original range using *data_mean* and *data_std* and then lowest and highest percentiles are
26
  removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
 
29
  orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
30
  new_img: torch.Tensor representing image with shape = (bands, H, W).
31
  channels: list of indices representing RGB channels.
32
+ mean: list of mean values for each band.
33
+ std: list of std values for each band.
34
 
35
  Returns:
36
  torch.Tensor with shape (num_channels, height, width) for original image
37
  torch.Tensor with shape (num_channels, height, width) for the other image
38
  """
39
 
40
+ mean = torch.tensor(np.asarray(mean)[:, None, None]) # C H W
41
+ std = torch.tensor(np.asarray(std)[:, None, None])
42
+ orig_img = orig_img[channels, ...]
43
+ valid_mask = torch.ones_like(orig_img, dtype=torch.bool)
44
+ valid_mask[orig_img == NO_DATA_FLOAT] = False
 
 
 
 
 
 
 
 
45
 
46
+ # Back to original data range
47
+ orig_img = (orig_img * std[channels]) + mean[channels]
48
+ new_img = (new_img[channels, ...] * std[channels]) + mean[channels]
49
 
50
+ # Rescale (enhancing contrast)
51
+ max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE))
52
+ min_value = OFFSET
53
 
54
+ orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, 1)
55
+ new_img = torch.clamp((new_img - min_value) / (max_value - min_value), 0, 1)
56
 
57
+ # No data as zeros
58
+ orig_img[~valid_mask] = 0
59
+ new_img[~valid_mask] = 0
60
 
61
+ return orig_img, new_img
62
 
63
 
64
  def read_geotiff(file_path: str):
 
75
  with rasterio.open(file_path) as src:
76
  img = src.read()
77
  meta = src.meta
78
+ try:
79
+ coords = src.lnglat()
80
+ except:
81
+ # Cannot read coords
82
+ coords = None
83
 
84
+ return img, meta, coords
85
 
86
 
87
  def save_geotiff(image, output_path: str, meta: dict):
 
129
  metas = []
130
 
131
  for file in file_paths:
132
+ img, meta, _ = read_geotiff(file)
133
 
134
  # Rescaling (don't normalize on nodata)
135
  img = np.moveaxis(img, 0, -1) # channels last for rescaling
 
142
 
143
  imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
144
  imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W
145
+ imgs = np.expand_dims(imgs, axis=0) # add batch di
146
 
147
  return imgs, metas
148
 
 
168
  with torch.no_grad():
169
  x = input_data.to(device)
170
 
171
+ _, pred, mask = model(x, mask_ratio=mask_ratio)
172
 
173
  # Create mask and prediction images (un-patchify)
174
  mask_img = (
 
209
  orig_img=input_img[:, t, :, :],
210
  new_img=rec_img[:, t, :, :],
211
  channels=channels,
212
+ mean=mean,
213
+ std=std,
214
  )
215
 
216
  rgb_mask = mask_img[channels, t, :, :] * rgb_orig
 
274
 
275
  def main(
276
  data_files: List[str],
277
+ config_path: str,
278
  checkpoint: str,
279
  output_dir: str,
280
  rgb_outputs: bool,
 
281
  mask_ratio: float = None,
282
  input_indices: list[int] = None,
283
  ):
 
285
 
286
  # Get parameters --------
287
 
288
+ import json
289
+ with open(config_path, "r") as f:
290
+ config = yaml.safe_load(f)['pretrained_cfg']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
  batch_size = 1
293
+ bands = config['bands']
294
+ num_frames = len(data_files)
295
+ mean = config['mean']
296
+ std = config['std']
297
+ img_size = config['img_size']
298
+ mask_ratio = mask_ratio or config['mask_ratio']
299
 
300
  print(
301
  f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n"
 
320
 
321
  # Create model and load checkpoint -------------------------------------------------------------
322
 
323
+ config.update(
 
 
324
  num_frames=num_frames,
 
325
  in_chans=len(bands),
 
 
 
 
 
 
 
 
 
326
  )
327
 
328
+ model = PrithviMAE(**config)
329
+
330
  total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
331
  print(f"\n--> Model has {total_params:,} parameters.\n")
332
 
 
334
 
335
  state_dict = torch.load(checkpoint, map_location=device)
336
  # discard fixed pos_embedding weight
337
+ for k in list(state_dict.keys()):
338
+ if 'pos_embed' in k:
339
+ del state_dict[k]
340
  model.load_state_dict(state_dict, strict=False)
341
  print(f"Loaded checkpoint from {checkpoint}")
342
 
 
441
 
442
  parser.add_argument(
443
  "--data_files",
 
444
  type=str,
445
  nargs="+",
446
+ default=["examples/HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif",
447
+ "examples/HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif",
448
+ "examples/HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"
449
+ ],
450
  help="Path to the data files. Assumes multi-band files.",
451
  )
452
  parser.add_argument(
453
+ "--config_path",
454
+ "-c",
455
  type=str,
456
+ default="config.json",
457
+ help="Path to json file containing model training parameters.",
458
  )
459
  parser.add_argument(
460
  "--checkpoint",
 
461
  type=str,
462
+ default="Prithvi_EO_V1_100M.pt",
463
  help="Path to a checkpoint file to load from.",
464
  )
465
  parser.add_argument(
466
  "--output_dir",
 
467
  type=str,
468
+ default="output",
469
  help="Path to the directory where to save outputs.",
470
  )
471
  parser.add_argument(
472
  "--mask_ratio",
473
+ default=0.75,
474
  type=float,
475
  help="Masking ratio (percentage of removed patches). "
476
  "If None (default) use same value used for pretraining.",
477
  )
 
 
 
 
 
 
478
  parser.add_argument(
479
  "--input_indices",
480
  default=None,
prithvi_mae.py ADDED
@@ -0,0 +1,736 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) IBM Corp. 2024. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------
15
+ # References:
16
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
17
+ # transformers: https://github.com/huggingface/transformers
18
+ # --------------------------------------------------------
19
+
20
+ from functools import partial
21
+ from typing import List, Tuple
22
+
23
+ import logging
24
+ import numpy as np
25
+ import torch
26
+ import torch.nn as nn
27
+ from einops import rearrange
28
+ from timm.layers import to_2tuple
29
+ from timm.models.vision_transformer import Block
30
+
31
+
32
+ def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
33
+ """
34
+ Create 3D sin/cos positional embeddings.
35
+
36
+ Args:
37
+ embed_dim (int):
38
+ Embedding dimension.
39
+ grid_size (tuple[int, int, int] | list[int]):
40
+ The grid depth, height and width.
41
+ add_cls_token (bool, *optional*, defaults to False):
42
+ Whether or not to add a classification (CLS) token.
43
+
44
+ Returns:
45
+ (`torch.FloatTensor` of shape (grid_size[0]*grid_size[1]*grid_size[2], embed_dim) or
46
+ (1+grid_size[0]*grid_size[1]*grid_size[2], embed_dim): the position embeddings (with or without cls token)
47
+ """
48
+
49
+ assert embed_dim % 16 == 0
50
+
51
+ t_size, h_size, w_size = grid_size
52
+
53
+ w_embed_dim = embed_dim // 16 * 6
54
+ h_embed_dim = embed_dim // 16 * 6
55
+ t_embed_dim = embed_dim // 16 * 4
56
+
57
+ w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
58
+ h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
59
+ t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))
60
+
61
+ w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
62
+ h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
63
+ t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)
64
+
65
+ pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)
66
+
67
+ if add_cls_token:
68
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
69
+ return pos_embed
70
+
71
+
72
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
73
+ """
74
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
75
+ """
76
+ if embed_dim % 2 != 0:
77
+ raise ValueError("embed_dim must be even")
78
+
79
+ omega = np.arange(embed_dim // 2, dtype=float)
80
+ omega /= embed_dim / 2.0
81
+ omega = 1.0 / 10000**omega # (D/2,)
82
+
83
+ pos = pos.reshape(-1) # (M,)
84
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
85
+
86
+ emb_sin = np.sin(out) # (M, D/2)
87
+ emb_cos = np.cos(out) # (M, D/2)
88
+
89
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
90
+ return emb
91
+
92
+
93
+ def _get_1d_sincos_embed_from_grid_torch(embed_dim: int, pos: torch.Tensor):
94
+ """ This is the torch version of *get_1d_sincos_pos_embed_from_grid()*. However,
95
+ it was modified to cast omega values to pos.dtype which must be float (and not int as in
96
+ regular positional embeddings). This was required in order to allow for native FSDP mixed
97
+ precision support: modify omega to appropriate dtype (pos carries the correct float dtype),
98
+ instead of manually forcing float32.
99
+
100
+ embed_dim: output dimension for each position
101
+ pos: a list of positions to be encoded: size (M,) - must be float dtype!
102
+ out: (M, D)
103
+ """
104
+ assert embed_dim % 2 == 0
105
+ assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16]
106
+
107
+ omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device)
108
+ omega /= embed_dim / 2.0
109
+ omega = 1.0 / 10000**omega # (D/2,)
110
+
111
+ pos = pos.reshape(-1) # (M,)
112
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
113
+
114
+ emb_sin = torch.sin(out) # (M, D/2)
115
+ emb_cos = torch.cos(out) # (M, D/2)
116
+
117
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
118
+
119
+ return emb
120
+
121
+
122
+ def _init_weights(module):
123
+ """Initialize the weights"""
124
+ if isinstance(module, nn.Linear):
125
+ nn.init.xavier_uniform_(module.weight)
126
+ if module.bias is not None:
127
+ module.bias.data.zero_()
128
+ elif isinstance(module, nn.LayerNorm):
129
+ module.bias.data.zero_()
130
+ module.weight.data.fill_(1.0)
131
+
132
+
133
+ class PatchEmbed(nn.Module):
134
+ """3D version of timm.models.vision_transformer.PatchEmbed"""
135
+ def __init__(
136
+ self,
137
+ input_size: Tuple[int, int, int] = (1, 224, 224),
138
+ patch_size: Tuple[int, int, int] = (1, 16, 16),
139
+ in_chans: int = 3,
140
+ embed_dim: int = 768,
141
+ norm_layer: nn.Module | None = None,
142
+ flatten: bool = True,
143
+ bias: bool = True,
144
+ ):
145
+ super().__init__()
146
+ self.input_size = input_size
147
+ self.patch_size = patch_size
148
+ self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
149
+ self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
150
+ self.flatten = flatten
151
+
152
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
153
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
154
+
155
+ def forward(self, x):
156
+ B, C, T, H, W = x.shape
157
+
158
+ if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1:
159
+ logging.warning(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}."
160
+ f"The border will be ignored, add backbone_padding for pixel-wise tasks.")
161
+
162
+ x = self.proj(x)
163
+ if self.flatten:
164
+ x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C
165
+ x = self.norm(x)
166
+ return x
167
+
168
+
169
+ class TemporalEncoder(nn.Module):
170
+ def __init__(self, embed_dim: int, trainable_scale: bool = False):
171
+ super().__init__()
172
+ self.embed_dim = embed_dim
173
+ self.year_embed_dim = embed_dim // 2
174
+ self.julian_day_embed_dim = embed_dim - self.year_embed_dim
175
+
176
+ # If trainable, initialize scale with small number
177
+ if trainable_scale:
178
+ self.scale = nn.Parameter(torch.full((1,), 0.1))
179
+ else:
180
+ self.register_buffer('scale', torch.ones(1))
181
+
182
+ def forward(self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None):
183
+ """
184
+ temporal_coords: year and day-of-year info with shape (B, T, 2).
185
+ tokens_per_frame: number of tokens for each frame in the sample. If provided, embeddings will be
186
+ repeated over T dimension, and final shape is (B, T*tokens_per_frame, embed_dim).
187
+ """
188
+ shape = temporal_coords.shape[:2] + (-1,) # B, T, -1
189
+
190
+ year = _get_1d_sincos_embed_from_grid_torch(
191
+ self.year_embed_dim, temporal_coords[:, :, 0].flatten()).reshape(shape)
192
+ julian_day = _get_1d_sincos_embed_from_grid_torch(
193
+ self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()).reshape(shape)
194
+
195
+ embedding = self.scale * torch.cat([year, julian_day], dim=-1)
196
+
197
+ if tokens_per_frame is not None:
198
+ embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1)
199
+
200
+ return embedding # B, T*tokens_per_frame, embed_dim
201
+
202
+
203
+ class LocationEncoder(nn.Module):
204
+ def __init__(self, embed_dim: int, trainable_scale: bool = False):
205
+ super().__init__()
206
+ self.embed_dim = embed_dim
207
+ self.lat_embed_dim = embed_dim // 2
208
+ self.lon_embed_dim = embed_dim - self.lat_embed_dim
209
+
210
+ # If trainable, initialize scale with small number
211
+ if trainable_scale:
212
+ self.scale = nn.Parameter(torch.full((1,), 0.1))
213
+ else:
214
+ self.register_buffer('scale', torch.ones(1))
215
+
216
+ def forward(self, location_coords: torch.Tensor):
217
+ """
218
+ location_coords: lat and lon info with shape (B, 2).
219
+ """
220
+ shape = location_coords.shape[:1] + (1, -1) # B, 1, -1
221
+
222
+ lat = _get_1d_sincos_embed_from_grid_torch(
223
+ self.lat_embed_dim, location_coords[:, 0].flatten()).reshape(shape)
224
+ lon = _get_1d_sincos_embed_from_grid_torch(
225
+ self.lon_embed_dim, location_coords[:, 1].flatten()).reshape(shape)
226
+
227
+ embedding = self.scale * torch.cat([lat, lon], dim=-1)
228
+
229
+ return embedding # B, 1, embed_dim
230
+
231
+
232
+ class PrithviViT(nn.Module):
233
+ """ Prithvi ViT Encoder"""
234
+ def __init__(self,
235
+ img_size: int | Tuple[int, int] = 224,
236
+ patch_size: int | Tuple[int, int, int] = (1, 16, 16),
237
+ num_frames: int = 1,
238
+ in_chans: int = 3,
239
+ embed_dim: int = 1024,
240
+ depth: int = 24,
241
+ num_heads: int = 16,
242
+ mlp_ratio: float = 4.,
243
+ norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6),
244
+ coords_encoding: List[str] | None = None,
245
+ coords_scale_learn: bool = False,
246
+ encoder_only: bool = True, # needed for timm
247
+ ** kwargs,
248
+ ):
249
+ super().__init__()
250
+
251
+ self.feature_info = []
252
+ self.encoder_only = encoder_only
253
+ self.in_chans = in_chans
254
+ self.num_frames = num_frames
255
+ self.embed_dim = embed_dim
256
+ self.img_size = to_2tuple(img_size)
257
+ if isinstance(patch_size, int):
258
+ patch_size = (1, patch_size, patch_size)
259
+
260
+ # 3D patch embedding
261
+ self.patch_embed = PatchEmbed(
262
+ input_size=(num_frames,) + self.img_size,
263
+ patch_size=patch_size,
264
+ in_chans=in_chans,
265
+ embed_dim=embed_dim,
266
+ )
267
+
268
+ # Optional temporal and location embedding
269
+ coords_encoding = coords_encoding or []
270
+ self.temporal_encoding = 'time' in coords_encoding
271
+ self.location_encoding = 'location' in coords_encoding
272
+ if self.temporal_encoding:
273
+ assert patch_size[0] == 1, f"With temporal encoding, patch_size[0] must be 1, received {patch_size[0]}"
274
+ self.temporal_embed_enc = TemporalEncoder(embed_dim, coords_scale_learn)
275
+ if self.location_encoding:
276
+ self.location_embed_enc = LocationEncoder(embed_dim, coords_scale_learn)
277
+
278
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
279
+ self.register_buffer("pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
280
+
281
+ # Transformer layers
282
+ self.blocks = []
283
+ for i in range(depth):
284
+ self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer))
285
+ self.feature_info.append(
286
+ {"num_chs": embed_dim * self.patch_embed.patch_size[0], "reduction": 1, "module": f"blocks.{i}"}
287
+ )
288
+ self.blocks = nn.ModuleList(self.blocks)
289
+
290
+ self.norm = norm_layer(embed_dim)
291
+
292
+ self.initialize_weights()
293
+
294
+ def initialize_weights(self):
295
+ # initialize (and freeze) position embeddings by sin-cos embedding
296
+ pos_embed = get_3d_sincos_pos_embed(
297
+ self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True
298
+ )
299
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
300
+
301
+ # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
302
+ w = self.patch_embed.proj.weight.data
303
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
304
+
305
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
306
+ torch.nn.init.normal_(self.cls_token, std=0.02)
307
+ self.apply(_init_weights)
308
+
309
+ def random_masking(self, sequence, mask_ratio, noise=None):
310
+ """
311
+ Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
312
+ noise.
313
+
314
+ Args:
315
+ sequence (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`)
316
+ mask_ratio (float): mask ratio to use.
317
+ noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
318
+ mainly used for testing purposes to control randomness and maintain the reproducibility
319
+ """
320
+ batch_size, seq_length, dim = sequence.shape
321
+ len_keep = int(seq_length * (1 - mask_ratio))
322
+
323
+ if noise is None:
324
+ noise = torch.rand(batch_size, seq_length, device=sequence.device) # noise in [0, 1]
325
+
326
+ # sort noise for each sample
327
+ ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device) # ascend: small is keep, large is remove
328
+ ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)
329
+
330
+ # keep the first subset
331
+ ids_keep = ids_shuffle[:, :len_keep]
332
+ sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))
333
+
334
+ # generate the binary mask: 0 is keep, 1 is remove
335
+ mask = torch.ones([batch_size, seq_length], device=sequence.device)
336
+ mask[:, :len_keep] = 0
337
+ # unshuffle to get the binary mask
338
+ mask = torch.gather(mask, dim=1, index=ids_restore)
339
+
340
+ return sequence_unmasked, mask, ids_restore
341
+
342
+ def _get_pos_embed(self, x):
343
+ t, h, w = x.shape[-3:]
344
+
345
+ pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(
346
+ self.embed_dim,
347
+ (
348
+ t // self.patch_embed.patch_size[0],
349
+ h // self.patch_embed.patch_size[1],
350
+ w // self.patch_embed.patch_size[2],
351
+ ),
352
+ add_cls_token=True,
353
+ )).float().unsqueeze(0).to(x)
354
+
355
+ return pos_embed
356
+
357
+
358
+ def forward(
359
+ self, x: torch.Tensor,
360
+ temporal_coords: None | torch.Tensor = None,
361
+ location_coords: None | torch.Tensor = None,
362
+ mask_ratio=0.75
363
+ ):
364
+ if x.shape[-3:] != self.patch_embed.input_size:
365
+ # changed input size
366
+ pos_embed = self._get_pos_embed(x)
367
+ else:
368
+ pos_embed = self.pos_embed
369
+
370
+ # embed patches
371
+ x = self.patch_embed(x)
372
+
373
+ # add pos embed w/o cls token
374
+ x = x + pos_embed[:, 1:, :]
375
+
376
+ if self.temporal_encoding:
377
+ num_tokens_per_frame = x.shape[1] // self.num_frames
378
+ temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
379
+ x = x + temporal_encoding
380
+ if self.location_encoding:
381
+ location_encoding = self.location_embed_enc(location_coords)
382
+ x = x + location_encoding
383
+
384
+ # masking: length -> length * mask_ratio
385
+ x, mask, ids_restore = self.random_masking(x, mask_ratio)
386
+
387
+ # append cls token
388
+ cls_token = self.cls_token + pos_embed[:, :1, :]
389
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
390
+ x = torch.cat((cls_tokens, x), dim=1)
391
+
392
+ # apply Transformer blocks
393
+ for block in self.blocks:
394
+ x = block(x)
395
+ x = self.norm(x)
396
+
397
+ return x, mask, ids_restore
398
+
399
+ def forward_features(
400
+ self,
401
+ x: torch.Tensor,
402
+ temporal_coords: None | torch.Tensor = None,
403
+ location_coords: None | torch.Tensor = None,
404
+ ) -> list[torch.Tensor]:
405
+ if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
406
+ # add time dim
407
+ x = x.unsqueeze(2)
408
+
409
+ if x.shape[-3:] != self.patch_embed.input_size:
410
+ pos_embed = self._get_pos_embed(x)
411
+ else:
412
+ pos_embed = self.pos_embed
413
+
414
+ # embed patches
415
+ x = self.patch_embed(x)
416
+
417
+ # add pos embed w/o cls token
418
+ x = x + pos_embed[:, 1:, :]
419
+
420
+ if self.temporal_encoding:
421
+ num_tokens_per_frame = x.shape[1] // self.patch_embed.num_frames
422
+ temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
423
+ x = x + temporal_encoding
424
+ if self.location_encoding:
425
+ location_encoding = self.location_embed_enc(location_coords)
426
+ x = x + location_encoding
427
+
428
+ # append cls token
429
+ cls_token = self.cls_token + pos_embed[:, :1, :]
430
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
431
+ x = torch.cat((cls_tokens, x), dim=1)
432
+
433
+ # apply Transformer blocks
434
+ out = []
435
+ for block in self.blocks:
436
+ x = block(x)
437
+ out.append(x.clone())
438
+
439
+ x = self.norm(x)
440
+ out[-1] = x
441
+ return out
442
+
443
+ def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
444
+ out = []
445
+ effective_time_dim = self.patch_embed.input_size[0] // self.patch_embed.patch_size[0]
446
+ for x in features:
447
+ x_no_token = x[:, 1:, :]
448
+ number_of_tokens = x_no_token.shape[1]
449
+ tokens_per_timestep = number_of_tokens // effective_time_dim
450
+ h = int(np.sqrt(tokens_per_timestep))
451
+ encoded = rearrange(
452
+ x_no_token,
453
+ "batch (t h w) e -> batch (t e) h w",
454
+ e=self.embed_dim,
455
+ t=effective_time_dim,
456
+ h=h,
457
+ )
458
+ out.append(encoded)
459
+ return out
460
+
461
+
462
+ class MAEDecoder(nn.Module):
463
+ """ Transformer Decoder used in the Prithvi MAE"""
464
+ def __init__(self,
465
+ patch_size: int | Tuple[int, int, int] = (1, 16, 16),
466
+ grid_size: List[int] | Tuple[int, int, int] = (3, 14, 14),
467
+ in_chans: int = 3,
468
+ encoder_embed_dim: int = 1024,
469
+ decoder_embed_dim: int = 512,
470
+ depth: int = 8,
471
+ num_heads: int = 16,
472
+ mlp_ratio: float = 4.,
473
+ norm_layer: nn.Module = nn.LayerNorm,
474
+ coords_encoding: List[str] | None = None,
475
+ coords_scale_learn: bool = False,
476
+ ):
477
+ super().__init__()
478
+
479
+ self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
480
+ self.decoder_embed_dim = decoder_embed_dim
481
+ self.grid_size = grid_size
482
+ if isinstance(patch_size, int):
483
+ patch_size = (1, patch_size, patch_size)
484
+ self.patch_size = patch_size
485
+ self.num_frames = self.grid_size[0] * patch_size[0]
486
+ num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
487
+
488
+ # Optional temporal and location embedding
489
+ coords_encoding = coords_encoding or []
490
+ self.temporal_encoding = 'time' in coords_encoding
491
+ self.location_encoding = 'location' in coords_encoding
492
+ if self.temporal_encoding:
493
+ self.temporal_embed_dec = TemporalEncoder(decoder_embed_dim, coords_scale_learn)
494
+ if self.location_encoding:
495
+ self.location_embed_dec = LocationEncoder(decoder_embed_dim, coords_scale_learn)
496
+
497
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
498
+
499
+ self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim))
500
+
501
+ self.decoder_blocks = nn.ModuleList(
502
+ [Block(decoder_embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) for _ in range(depth)]
503
+ )
504
+
505
+ self.decoder_norm = norm_layer(decoder_embed_dim)
506
+ self.decoder_pred = nn.Linear(decoder_embed_dim,
507
+ patch_size[0] * patch_size[1] * patch_size[2] * in_chans,
508
+ bias=True)
509
+
510
+ self.initialize_weights()
511
+
512
+ def initialize_weights(self):
513
+ # initialize (and freeze) position embeddings by sin-cos embedding
514
+ decoder_pos_embed = get_3d_sincos_pos_embed(
515
+ self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True
516
+ )
517
+ self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
518
+
519
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
520
+ torch.nn.init.normal_(self.mask_token, std=0.02)
521
+ self.apply(_init_weights)
522
+
523
+ def forward(
524
+ self,
525
+ hidden_states: torch.Tensor,
526
+ ids_restore: torch.Tensor,
527
+ temporal_coords: None | torch.Tensor = None,
528
+ location_coords: None | torch.Tensor = None,
529
+ input_size: list[int] = None,
530
+ ):
531
+ # embed tokens
532
+ x = self.decoder_embed(hidden_states)
533
+
534
+ t, h, w = input_size[-3:]
535
+ decoder_pos_embed = torch.from_numpy(
536
+ get_3d_sincos_pos_embed(
537
+ self.decoder_embed_dim,
538
+ (
539
+ t // self.patch_size[0],
540
+ h // self.patch_size[1],
541
+ w // self.patch_size[2],
542
+ ),
543
+ add_cls_token=True,
544
+ )
545
+ ).to(x)
546
+
547
+ # append mask tokens to sequence
548
+ mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
549
+ x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
550
+ # unshuffle
551
+ x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device))
552
+ x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
553
+ # add pos embed
554
+ x = x + decoder_pos_embed
555
+
556
+ # remove cls token
557
+ x_ = x[:, 1:, :]
558
+
559
+ if self.temporal_encoding:
560
+ num_tokens_per_frame = x_.shape[1] // self.num_frames
561
+ temporal_encoding = self.temporal_embed_dec(temporal_coords, num_tokens_per_frame)
562
+ # Add temporal encoding w/o cls token
563
+ x_ = x_ + temporal_encoding
564
+ if self.location_encoding:
565
+ location_encoding = self.location_embed_dec(location_coords)
566
+ # Add location encoding w/o cls token
567
+ x_ = x_ + location_encoding
568
+
569
+ # append cls token
570
+ x = torch.cat([x[:, :1, :], x_], dim=1)
571
+
572
+ # apply Transformer layers (blocks)
573
+ for block in self.decoder_blocks:
574
+ x = block(x)
575
+ x = self.decoder_norm(x)
576
+
577
+ # predictor projection
578
+ pred = self.decoder_pred(x)
579
+
580
+ # remove cls token
581
+ pred = pred[:, 1:, :]
582
+
583
+ return pred
584
+
585
+
586
+ class PrithviMAE(nn.Module):
587
+ """ Prithvi Masked Autoencoder"""
588
+
589
+ def __init__(self,
590
+ img_size: int | Tuple[int, int] = 224,
591
+ patch_size: int | Tuple[int, int, int] = (1, 16, 16),
592
+ num_frames: int = 3,
593
+ in_chans: int = 3,
594
+ embed_dim: int = 1024,
595
+ depth: int = 24,
596
+ num_heads: int = 16,
597
+ decoder_embed_dim: int = 512,
598
+ decoder_depth: int = 8,
599
+ decoder_num_heads: int = 16,
600
+ mlp_ratio: float = 4.,
601
+ norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6),
602
+ norm_pix_loss: bool = False,
603
+ coords_encoding: List[str] | None = None,
604
+ coords_scale_learn: bool = False,
605
+ encoder_only: bool = False,
606
+ **kwargs,
607
+ ):
608
+ super().__init__()
609
+
610
+ self.encoder = PrithviViT(
611
+ img_size=img_size,
612
+ num_frames=num_frames,
613
+ patch_size=patch_size,
614
+ in_chans=in_chans,
615
+ embed_dim=embed_dim,
616
+ depth=depth,
617
+ num_heads=num_heads,
618
+ mlp_ratio=mlp_ratio,
619
+ norm_layer=norm_layer,
620
+ coords_encoding=coords_encoding,
621
+ coords_scale_learn=coords_scale_learn,
622
+ )
623
+
624
+ self.encoder_only = encoder_only
625
+
626
+ if not encoder_only:
627
+ self.decoder = MAEDecoder(
628
+ patch_size=patch_size,
629
+ grid_size=self.encoder.patch_embed.grid_size,
630
+ in_chans=in_chans,
631
+ encoder_embed_dim=embed_dim,
632
+ decoder_embed_dim=decoder_embed_dim,
633
+ depth=decoder_depth,
634
+ num_heads=decoder_num_heads,
635
+ mlp_ratio=mlp_ratio,
636
+ norm_layer=norm_layer,
637
+ coords_encoding=coords_encoding,
638
+ coords_scale_learn=coords_scale_learn,
639
+ )
640
+ else:
641
+ self.decoder = nn.Identity()
642
+
643
+ self.norm_pix_loss = norm_pix_loss
644
+
645
+ def patchify(self, pixel_values):
646
+ """
647
+ Args:
648
+ pixel_values (torch.FloatTensor of shape `(batch_size, num_channels, time, height, width)`):
649
+ Pixel values.
650
+
651
+ Returns:
652
+ torch.FloatTensor of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
653
+ Patchified pixel values.
654
+ """
655
+ patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
656
+ num_channels = self.encoder.in_chans
657
+
658
+ # patchify
659
+ patchified_pixel_values = rearrange(pixel_values, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)',
660
+ c=num_channels, s=patch_size_t, p=patch_size_h, q=patch_size_w)
661
+
662
+
663
+ return patchified_pixel_values
664
+
665
+ def unpatchify(self, patchified_pixel_values, image_size: Tuple[int, int] | None = None):
666
+ """
667
+ Args:
668
+ patchified_pixel_values (`torch.FloatTensor` of shape
669
+ `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
670
+ Patchified pixel values.
671
+ image_size (`Tuple[int, int]`, *optional*):
672
+ Original image size.
673
+
674
+ Returns:
675
+ `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
676
+ Pixel values.
677
+ """
678
+ patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
679
+ image_size = to_2tuple(image_size) if image_size is not None else self.encoder.img_size
680
+ original_height, original_width = image_size
681
+ num_patches_h = original_height // patch_size_h
682
+ num_patches_w = original_width // patch_size_w
683
+ num_channels = self.encoder.in_chans
684
+
685
+ pixel_values = rearrange(patchified_pixel_values, 'b (t h w) (s p q c) -> b c (t s) (h p) (w q)',
686
+ c=num_channels, h=num_patches_h, w=num_patches_w,
687
+ s=patch_size_t, p=patch_size_h, q=patch_size_w)
688
+ return pixel_values
689
+
690
+ def forward_loss(self, pixel_values, pred, mask):
691
+ """
692
+ Args:
693
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
694
+ Pixel values.
695
+ pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
696
+ Predicted pixel values.
697
+ mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
698
+ Tensor indicating which patches are masked (1) and which are not (0).
699
+
700
+ Returns:
701
+ `torch.FloatTensor`: Pixel reconstruction loss.
702
+ """
703
+ target = self.patchify(pixel_values)
704
+ if self.norm_pix_loss:
705
+ mean = target.mean(dim=-1, keepdim=True)
706
+ var = target.var(dim=-1, keepdim=True)
707
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
708
+
709
+ loss = (pred - target) ** 2
710
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
711
+ loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
712
+ return loss
713
+
714
+ def forward(
715
+ self,
716
+ pixel_values: torch.Tensor,
717
+ temporal_coords: None | torch.Tensor = None,
718
+ location_coords: None | torch.Tensor = None,
719
+ mask_ratio: float = 0.75
720
+ ):
721
+ if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1:
722
+ # add time dim
723
+ pixel_values = pixel_values.unsqueeze(2)
724
+
725
+ latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio)
726
+ pred = self.decoder(latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape)
727
+ loss = self.forward_loss(pixel_values, pred, mask)
728
+ return loss, pred, mask
729
+
730
+ def forward_features(
731
+ self,
732
+ x: torch.Tensor,
733
+ temporal_coords: None | torch.Tensor = None,
734
+ location_coords: None | torch.Tensor = None,
735
+ ) -> List[torch.Tensor]:
736
+ return self.encoder.forward_features(x, temporal_coords, location_coords)