blumenstiel
commited on
Commit
·
7bfe9ea
1
Parent(s):
8ab0a02
Updated inference code
Browse files- .gitattributes +1 -0
- Prithvi.py +0 -319
- README.md +2 -2
- Prithvi_run_inference.py → inference.py +65 -89
- prithvi_mae.py +736 -0
.gitattributes
CHANGED
@@ -36,3 +36,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
36 |
Prithvi_training.png filter=lfs diff=lfs merge=lfs -text
|
37 |
Prithvi_walkthrough_thumbnail.png filter=lfs diff=lfs merge=lfs -text
|
38 |
GFM.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
36 |
Prithvi_training.png filter=lfs diff=lfs merge=lfs -text
|
37 |
Prithvi_walkthrough_thumbnail.png filter=lfs diff=lfs merge=lfs -text
|
38 |
GFM.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
*.tif filter=lfs diff=lfs merge=lfs -text
|
Prithvi.py
DELETED
@@ -1,319 +0,0 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
# --------------------------------------------------------
|
7 |
-
# References:
|
8 |
-
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
9 |
-
# DeiT: https://github.com/facebookresearch/deit
|
10 |
-
# --------------------------------------------------------
|
11 |
-
|
12 |
-
from functools import partial
|
13 |
-
|
14 |
-
import torch
|
15 |
-
import torch.nn as nn
|
16 |
-
|
17 |
-
from timm.models.vision_transformer import Block
|
18 |
-
from timm.models.layers import to_2tuple
|
19 |
-
|
20 |
-
import numpy as np
|
21 |
-
|
22 |
-
from einops import rearrange
|
23 |
-
|
24 |
-
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
25 |
-
"""
|
26 |
-
embed_dim: output dimension for each position
|
27 |
-
pos: a list of positions to be encoded: size (M,)
|
28 |
-
out: (M, D)
|
29 |
-
"""
|
30 |
-
assert embed_dim % 2 == 0
|
31 |
-
omega = np.arange(embed_dim // 2, dtype=np.float32)
|
32 |
-
omega /= embed_dim / 2.
|
33 |
-
omega = 1. / 10000**omega # (D/2,)
|
34 |
-
|
35 |
-
pos = pos.reshape(-1) # (M,)
|
36 |
-
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
37 |
-
|
38 |
-
emb_sin = np.sin(out) # (M, D/2)
|
39 |
-
emb_cos = np.cos(out) # (M, D/2)
|
40 |
-
|
41 |
-
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
42 |
-
return emb
|
43 |
-
|
44 |
-
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
45 |
-
assert embed_dim % 2 == 0
|
46 |
-
|
47 |
-
# use half of dimensions to encode grid_h
|
48 |
-
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
49 |
-
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
50 |
-
|
51 |
-
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
52 |
-
return emb
|
53 |
-
|
54 |
-
def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
55 |
-
"""
|
56 |
-
grid_size: 3d tuple of grid size: t, h, w
|
57 |
-
return:
|
58 |
-
pos_embed: L, D
|
59 |
-
"""
|
60 |
-
|
61 |
-
assert embed_dim % 16 == 0
|
62 |
-
|
63 |
-
t_size, h_size, w_size = grid_size
|
64 |
-
|
65 |
-
w_embed_dim = embed_dim // 16 * 6
|
66 |
-
h_embed_dim = embed_dim // 16 * 6
|
67 |
-
t_embed_dim = embed_dim // 16 * 4
|
68 |
-
|
69 |
-
w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
|
70 |
-
h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
|
71 |
-
t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))
|
72 |
-
|
73 |
-
w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
|
74 |
-
h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
|
75 |
-
t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)
|
76 |
-
|
77 |
-
pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)
|
78 |
-
|
79 |
-
if cls_token:
|
80 |
-
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
81 |
-
return pos_embed
|
82 |
-
|
83 |
-
|
84 |
-
class PatchEmbed(nn.Module):
|
85 |
-
""" Frames of 2D Images to Patch Embedding
|
86 |
-
The 3D version of timm.models.vision_transformer.PatchEmbed
|
87 |
-
"""
|
88 |
-
def __init__(
|
89 |
-
self,
|
90 |
-
img_size=224,
|
91 |
-
patch_size=16,
|
92 |
-
num_frames=3,
|
93 |
-
tubelet_size=1,
|
94 |
-
in_chans=3,
|
95 |
-
embed_dim=768,
|
96 |
-
norm_layer=None,
|
97 |
-
flatten=True,
|
98 |
-
bias=True,
|
99 |
-
):
|
100 |
-
super().__init__()
|
101 |
-
img_size = to_2tuple(img_size)
|
102 |
-
patch_size = to_2tuple(patch_size)
|
103 |
-
self.img_size = img_size
|
104 |
-
self.patch_size = patch_size
|
105 |
-
self.num_frames = num_frames
|
106 |
-
self.tubelet_size = tubelet_size
|
107 |
-
self.grid_size = (num_frames // tubelet_size, img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
108 |
-
self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
|
109 |
-
self.flatten = flatten
|
110 |
-
|
111 |
-
self.proj = nn.Conv3d(in_chans, embed_dim,
|
112 |
-
kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
|
113 |
-
stride=(tubelet_size, patch_size[0], patch_size[1]), bias=bias)
|
114 |
-
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
115 |
-
|
116 |
-
def forward(self, x):
|
117 |
-
B, C, T, H, W = x.shape
|
118 |
-
x = self.proj(x)
|
119 |
-
if self.flatten:
|
120 |
-
x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C
|
121 |
-
x = self.norm(x)
|
122 |
-
return x
|
123 |
-
|
124 |
-
|
125 |
-
class MaskedAutoencoderViT(nn.Module):
|
126 |
-
""" Masked Autoencoder with VisionTransformer backbone
|
127 |
-
"""
|
128 |
-
def __init__(self, img_size=224, patch_size=16,
|
129 |
-
num_frames=3, tubelet_size=1,
|
130 |
-
in_chans=3, embed_dim=1024, depth=24, num_heads=16,
|
131 |
-
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
|
132 |
-
mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
|
133 |
-
super().__init__()
|
134 |
-
|
135 |
-
# --------------------------------------------------------------------------
|
136 |
-
# MAE encoder specifics
|
137 |
-
self.patch_embed = PatchEmbed(img_size, patch_size,num_frames, tubelet_size, in_chans, embed_dim)
|
138 |
-
num_patches = self.patch_embed.num_patches
|
139 |
-
|
140 |
-
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
141 |
-
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding
|
142 |
-
|
143 |
-
self.blocks = nn.ModuleList([
|
144 |
-
Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
|
145 |
-
for i in range(depth)])
|
146 |
-
self.norm = norm_layer(embed_dim)
|
147 |
-
# --------------------------------------------------------------------------
|
148 |
-
|
149 |
-
# --------------------------------------------------------------------------
|
150 |
-
# MAE decoder specifics
|
151 |
-
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
|
152 |
-
|
153 |
-
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
|
154 |
-
|
155 |
-
self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding
|
156 |
-
|
157 |
-
self.decoder_blocks = nn.ModuleList([
|
158 |
-
Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
|
159 |
-
for i in range(decoder_depth)])
|
160 |
-
|
161 |
-
self.decoder_norm = norm_layer(decoder_embed_dim)
|
162 |
-
self.decoder_pred = nn.Linear(decoder_embed_dim, tubelet_size * patch_size * patch_size * in_chans, bias=True) # decoder to patch
|
163 |
-
# --------------------------------------------------------------------------
|
164 |
-
|
165 |
-
self.norm_pix_loss = norm_pix_loss
|
166 |
-
|
167 |
-
self.initialize_weights()
|
168 |
-
|
169 |
-
def initialize_weights(self):
|
170 |
-
# initialization
|
171 |
-
# initialize (and freeze) pos_embed by sin-cos embedding
|
172 |
-
pos_embed = get_3d_sincos_pos_embed(self.pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True)
|
173 |
-
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
174 |
-
|
175 |
-
decoder_pos_embed = get_3d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True)
|
176 |
-
self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
|
177 |
-
|
178 |
-
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
|
179 |
-
w = self.patch_embed.proj.weight.data
|
180 |
-
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
181 |
-
|
182 |
-
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
|
183 |
-
torch.nn.init.normal_(self.cls_token, std=.02)
|
184 |
-
torch.nn.init.normal_(self.mask_token, std=.02)
|
185 |
-
|
186 |
-
# initialize nn.Linear and nn.LayerNorm
|
187 |
-
self.apply(self._init_weights)
|
188 |
-
|
189 |
-
def _init_weights(self, m):
|
190 |
-
if isinstance(m, nn.Linear):
|
191 |
-
# we use xavier_uniform following official JAX ViT:
|
192 |
-
torch.nn.init.xavier_uniform_(m.weight)
|
193 |
-
if isinstance(m, nn.Linear) and m.bias is not None:
|
194 |
-
nn.init.constant_(m.bias, 0)
|
195 |
-
elif isinstance(m, nn.LayerNorm):
|
196 |
-
nn.init.constant_(m.bias, 0)
|
197 |
-
nn.init.constant_(m.weight, 1.0)
|
198 |
-
|
199 |
-
def patchify(self, imgs):
|
200 |
-
"""
|
201 |
-
imgs: B, C, T, H, W
|
202 |
-
x: B, L, D
|
203 |
-
"""
|
204 |
-
p = self.patch_embed.patch_size[0]
|
205 |
-
tub = self.patch_embed.tubelet_size
|
206 |
-
x = rearrange(imgs, 'b c (t tub) (h p) (w q) -> b (t h w) (tub p q c)', tub=tub, p=p, q=p)
|
207 |
-
|
208 |
-
return x
|
209 |
-
|
210 |
-
def unpatchify(self, x):
|
211 |
-
"""
|
212 |
-
x: B, L, D
|
213 |
-
imgs: B, C, T, H, W
|
214 |
-
"""
|
215 |
-
p = self.patch_embed.patch_size[0]
|
216 |
-
num_p = self.patch_embed.img_size[0] // p
|
217 |
-
tub = self.patch_embed.tubelet_size
|
218 |
-
imgs = rearrange(x, 'b (t h w) (tub p q c) -> b c (t tub) (h p) (w q)', h=num_p, w=num_p, tub=tub, p=p, q=p)
|
219 |
-
return imgs
|
220 |
-
|
221 |
-
def random_masking(self, x, mask_ratio):
|
222 |
-
"""
|
223 |
-
Perform per-sample random masking by per-sample shuffling.
|
224 |
-
Per-sample shuffling is done by argsort random noise.
|
225 |
-
x: [N, L, D], sequence
|
226 |
-
"""
|
227 |
-
N, L, D = x.shape # batch, length, dim
|
228 |
-
len_keep = int(L * (1 - mask_ratio))
|
229 |
-
|
230 |
-
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
|
231 |
-
|
232 |
-
# sort noise for each sample
|
233 |
-
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
|
234 |
-
ids_restore = torch.argsort(ids_shuffle, dim=1)
|
235 |
-
|
236 |
-
# keep the first subset
|
237 |
-
ids_keep = ids_shuffle[:, :len_keep]
|
238 |
-
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
|
239 |
-
|
240 |
-
# generate the binary mask: 0 is keep, 1 is remove
|
241 |
-
mask = torch.ones([N, L], device=x.device)
|
242 |
-
mask[:, :len_keep] = 0
|
243 |
-
# unshuffle to get the binary mask
|
244 |
-
mask = torch.gather(mask, dim=1, index=ids_restore)
|
245 |
-
|
246 |
-
return x_masked, mask, ids_restore
|
247 |
-
|
248 |
-
def forward_encoder(self, x, mask_ratio):
|
249 |
-
# embed patches
|
250 |
-
x = self.patch_embed(x)
|
251 |
-
|
252 |
-
# add pos embed w/o cls token
|
253 |
-
x = x + self.pos_embed[:, 1:, :]
|
254 |
-
|
255 |
-
# masking: length -> length * mask_ratio
|
256 |
-
x, mask, ids_restore = self.random_masking(x, mask_ratio)
|
257 |
-
|
258 |
-
# append cls token
|
259 |
-
cls_token = self.cls_token + self.pos_embed[:, :1, :]
|
260 |
-
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
|
261 |
-
x = torch.cat((cls_tokens, x), dim=1)
|
262 |
-
|
263 |
-
# apply Transformer blocks
|
264 |
-
for blk in self.blocks:
|
265 |
-
x = blk(x)
|
266 |
-
x = self.norm(x)
|
267 |
-
|
268 |
-
return x, mask, ids_restore
|
269 |
-
|
270 |
-
def forward_decoder(self, x, ids_restore):
|
271 |
-
# embed tokens
|
272 |
-
x = self.decoder_embed(x)
|
273 |
-
|
274 |
-
# append mask tokens to sequence
|
275 |
-
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
|
276 |
-
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
|
277 |
-
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
|
278 |
-
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
|
279 |
-
|
280 |
-
# add pos embed
|
281 |
-
x = x + self.decoder_pos_embed
|
282 |
-
|
283 |
-
# apply Transformer blocks
|
284 |
-
for blk in self.decoder_blocks:
|
285 |
-
x = blk(x)
|
286 |
-
x = self.decoder_norm(x)
|
287 |
-
|
288 |
-
# predictor projection
|
289 |
-
x = self.decoder_pred(x)
|
290 |
-
|
291 |
-
# remove cls token
|
292 |
-
x = x[:, 1:, :]
|
293 |
-
|
294 |
-
return x
|
295 |
-
|
296 |
-
def forward_loss(self, imgs, pred, mask):
|
297 |
-
"""
|
298 |
-
imgs: B, C, T, H, W
|
299 |
-
target: B, L, D
|
300 |
-
pred: B, L, D
|
301 |
-
mask: B, L. 0 is keep, 1 is remove,
|
302 |
-
"""
|
303 |
-
target = self.patchify(imgs)
|
304 |
-
if self.norm_pix_loss:
|
305 |
-
mean = target.mean(dim=-1, keepdim=True)
|
306 |
-
var = target.var(dim=-1, keepdim=True)
|
307 |
-
target = (target - mean) / (var + 1.e-6)**.5
|
308 |
-
|
309 |
-
loss = (pred - target) ** 2
|
310 |
-
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
|
311 |
-
|
312 |
-
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
|
313 |
-
return loss
|
314 |
-
|
315 |
-
def forward(self, imgs, mask_ratio=0.75):
|
316 |
-
latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
|
317 |
-
pred = self.forward_decoder(latent, ids_restore)
|
318 |
-
loss = self.forward_loss(imgs, pred, mask)
|
319 |
-
return loss, pred, mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
@@ -33,10 +33,10 @@ The model follows the [original MAE repo](https://github.com/facebookresearch/ma
|
|
33 |
4. adding infrared bands besides RGB
|
34 |
|
35 |
### Inference and demo
|
36 |
-
There is an inference script (`
|
37 |
|
38 |
```
|
39 |
-
python
|
40 |
```
|
41 |
|
42 |
This demo is a starting point that can be used as a starting point to generalize to different input shapes / types.
|
|
|
33 |
4. adding infrared bands besides RGB
|
34 |
|
35 |
### Inference and demo
|
36 |
+
There is an inference script (`inference.py`) that allows to run the image reconstruction on a set of HLS images assumed to be from the same location at different time steps(see example below). These should be provided in chronological order in geotiff format, including the channels described above (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in reflectance units. There is also a **demo** that leverages the same code [here](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-demo).
|
37 |
|
38 |
```
|
39 |
+
python inference.py --data_files t1.tif t2.tif t3.tif t4.tif --input_indices <optional, space separated 0-based indices of the six Prithvi channels in your input>
|
40 |
```
|
41 |
|
42 |
This demo is a starting point that can be used as a starting point to generalize to different input shapes / types.
|
Prithvi_run_inference.py → inference.py
RENAMED
@@ -2,21 +2,25 @@ import argparse
|
|
2 |
import functools
|
3 |
import os
|
4 |
from typing import List, Union
|
5 |
-
|
|
|
6 |
import numpy as np
|
|
|
7 |
import rasterio
|
8 |
import torch
|
9 |
import yaml
|
10 |
from einops import rearrange
|
11 |
|
12 |
-
from
|
|
|
13 |
|
14 |
NO_DATA = -9999
|
15 |
NO_DATA_FLOAT = 0.0001
|
16 |
-
|
|
|
17 |
|
18 |
|
19 |
-
def process_channel_group(orig_img, new_img, channels,
|
20 |
"""Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
|
21 |
original range using *data_mean* and *data_std* and then lowest and highest percentiles are
|
22 |
removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
|
@@ -25,43 +29,36 @@ def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
|
|
25 |
orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
|
26 |
new_img: torch.Tensor representing image with shape = (bands, H, W).
|
27 |
channels: list of indices representing RGB channels.
|
28 |
-
|
29 |
-
|
30 |
|
31 |
Returns:
|
32 |
torch.Tensor with shape (num_channels, height, width) for original image
|
33 |
torch.Tensor with shape (num_channels, height, width) for the other image
|
34 |
"""
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
valid_mask[orig_ch == NO_DATA_FLOAT] = False
|
42 |
-
|
43 |
-
# Back to original data range
|
44 |
-
orig_ch = (orig_ch * data_std[c]) + data_mean[c]
|
45 |
-
new_ch = (new_img[c, ...] * data_std[c]) + data_mean[c]
|
46 |
-
|
47 |
-
# Rescale (enhancing contrast)
|
48 |
-
min_value, max_value = np.percentile(orig_ch[valid_mask], PERCENTILES)
|
49 |
|
50 |
-
|
51 |
-
|
|
|
52 |
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
|
57 |
-
|
58 |
-
|
59 |
|
60 |
-
#
|
61 |
-
|
62 |
-
|
63 |
|
64 |
-
return
|
65 |
|
66 |
|
67 |
def read_geotiff(file_path: str):
|
@@ -78,8 +75,13 @@ def read_geotiff(file_path: str):
|
|
78 |
with rasterio.open(file_path) as src:
|
79 |
img = src.read()
|
80 |
meta = src.meta
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
-
return img, meta
|
83 |
|
84 |
|
85 |
def save_geotiff(image, output_path: str, meta: dict):
|
@@ -127,7 +129,7 @@ def load_example(
|
|
127 |
metas = []
|
128 |
|
129 |
for file in file_paths:
|
130 |
-
img, meta = read_geotiff(file)
|
131 |
|
132 |
# Rescaling (don't normalize on nodata)
|
133 |
img = np.moveaxis(img, 0, -1) # channels last for rescaling
|
@@ -140,7 +142,7 @@ def load_example(
|
|
140 |
|
141 |
imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
|
142 |
imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W
|
143 |
-
imgs = np.expand_dims(imgs, axis=0) # add batch
|
144 |
|
145 |
return imgs, metas
|
146 |
|
@@ -166,7 +168,7 @@ def run_model(
|
|
166 |
with torch.no_grad():
|
167 |
x = input_data.to(device)
|
168 |
|
169 |
-
_, pred, mask = model(x, mask_ratio)
|
170 |
|
171 |
# Create mask and prediction images (un-patchify)
|
172 |
mask_img = (
|
@@ -207,8 +209,8 @@ def save_rgb_imgs(
|
|
207 |
orig_img=input_img[:, t, :, :],
|
208 |
new_img=rec_img[:, t, :, :],
|
209 |
channels=channels,
|
210 |
-
|
211 |
-
|
212 |
)
|
213 |
|
214 |
rgb_mask = mask_img[channels, t, :, :] * rgb_orig
|
@@ -272,11 +274,10 @@ def save_imgs(rec_img, mask_img, mean, std, output_dir, meta_data):
|
|
272 |
|
273 |
def main(
|
274 |
data_files: List[str],
|
275 |
-
|
276 |
checkpoint: str,
|
277 |
output_dir: str,
|
278 |
rgb_outputs: bool,
|
279 |
-
img_size: int,
|
280 |
mask_ratio: float = None,
|
281 |
input_indices: list[int] = None,
|
282 |
):
|
@@ -284,31 +285,17 @@ def main(
|
|
284 |
|
285 |
# Get parameters --------
|
286 |
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
# data related
|
291 |
-
train_params = params["train_params"]
|
292 |
-
num_frames = len(data_files)
|
293 |
-
bands = train_params["bands"]
|
294 |
-
mean = train_params["data_mean"]
|
295 |
-
std = train_params["data_std"]
|
296 |
-
|
297 |
-
# model related
|
298 |
-
model_params = params["model_args"]
|
299 |
-
img_size = model_params["img_size"] if img_size is None else img_size
|
300 |
-
depth = model_params["depth"]
|
301 |
-
patch_size = model_params["patch_size"]
|
302 |
-
embed_dim = model_params["embed_dim"]
|
303 |
-
num_heads = model_params["num_heads"]
|
304 |
-
tubelet_size = model_params["tubelet_size"]
|
305 |
-
decoder_embed_dim = model_params["decoder_embed_dim"]
|
306 |
-
decoder_num_heads = model_params["decoder_num_heads"]
|
307 |
-
decoder_depth = model_params["decoder_depth"]
|
308 |
|
309 |
batch_size = 1
|
310 |
-
|
311 |
-
|
|
|
|
|
|
|
|
|
312 |
|
313 |
print(
|
314 |
f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n"
|
@@ -333,23 +320,13 @@ def main(
|
|
333 |
|
334 |
# Create model and load checkpoint -------------------------------------------------------------
|
335 |
|
336 |
-
|
337 |
-
img_size=img_size,
|
338 |
-
patch_size=patch_size,
|
339 |
num_frames=num_frames,
|
340 |
-
tubelet_size=tubelet_size,
|
341 |
in_chans=len(bands),
|
342 |
-
embed_dim=embed_dim,
|
343 |
-
depth=depth,
|
344 |
-
num_heads=num_heads,
|
345 |
-
decoder_embed_dim=decoder_embed_dim,
|
346 |
-
decoder_depth=decoder_depth,
|
347 |
-
decoder_num_heads=decoder_num_heads,
|
348 |
-
mlp_ratio=4.0,
|
349 |
-
norm_layer=functools.partial(torch.nn.LayerNorm, eps=1e-6),
|
350 |
-
norm_pix_loss=False,
|
351 |
)
|
352 |
|
|
|
|
|
353 |
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
354 |
print(f"\n--> Model has {total_params:,} parameters.\n")
|
355 |
|
@@ -357,8 +334,9 @@ def main(
|
|
357 |
|
358 |
state_dict = torch.load(checkpoint, map_location=device)
|
359 |
# discard fixed pos_embedding weight
|
360 |
-
|
361 |
-
|
|
|
362 |
model.load_state_dict(state_dict, strict=False)
|
363 |
print(f"Loaded checkpoint from {checkpoint}")
|
364 |
|
@@ -463,42 +441,40 @@ if __name__ == "__main__":
|
|
463 |
|
464 |
parser.add_argument(
|
465 |
"--data_files",
|
466 |
-
required=True,
|
467 |
type=str,
|
468 |
nargs="+",
|
|
|
|
|
|
|
|
|
469 |
help="Path to the data files. Assumes multi-band files.",
|
470 |
)
|
471 |
parser.add_argument(
|
472 |
-
"--
|
|
|
473 |
type=str,
|
474 |
-
|
475 |
-
help="Path to
|
476 |
)
|
477 |
parser.add_argument(
|
478 |
"--checkpoint",
|
479 |
-
required=True,
|
480 |
type=str,
|
|
|
481 |
help="Path to a checkpoint file to load from.",
|
482 |
)
|
483 |
parser.add_argument(
|
484 |
"--output_dir",
|
485 |
-
required=True,
|
486 |
type=str,
|
|
|
487 |
help="Path to the directory where to save outputs.",
|
488 |
)
|
489 |
parser.add_argument(
|
490 |
"--mask_ratio",
|
491 |
-
default=
|
492 |
type=float,
|
493 |
help="Masking ratio (percentage of removed patches). "
|
494 |
"If None (default) use same value used for pretraining.",
|
495 |
)
|
496 |
-
parser.add_argument(
|
497 |
-
"--img_size",
|
498 |
-
default=224,
|
499 |
-
type=int,
|
500 |
-
help="Image size to be used with model. Defaults to 224",
|
501 |
-
)
|
502 |
parser.add_argument(
|
503 |
"--input_indices",
|
504 |
default=None,
|
|
|
2 |
import functools
|
3 |
import os
|
4 |
from typing import List, Union
|
5 |
+
import re
|
6 |
+
import datetime
|
7 |
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
import rasterio
|
10 |
import torch
|
11 |
import yaml
|
12 |
from einops import rearrange
|
13 |
|
14 |
+
from functools import partial
|
15 |
+
from prithvi_mae import PrithviMAE
|
16 |
|
17 |
NO_DATA = -9999
|
18 |
NO_DATA_FLOAT = 0.0001
|
19 |
+
OFFSET = 0
|
20 |
+
PERCENTILE = 99.9
|
21 |
|
22 |
|
23 |
+
def process_channel_group(orig_img, new_img, channels, mean, std):
|
24 |
"""Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
|
25 |
original range using *data_mean* and *data_std* and then lowest and highest percentiles are
|
26 |
removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
|
|
|
29 |
orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
|
30 |
new_img: torch.Tensor representing image with shape = (bands, H, W).
|
31 |
channels: list of indices representing RGB channels.
|
32 |
+
mean: list of mean values for each band.
|
33 |
+
std: list of std values for each band.
|
34 |
|
35 |
Returns:
|
36 |
torch.Tensor with shape (num_channels, height, width) for original image
|
37 |
torch.Tensor with shape (num_channels, height, width) for the other image
|
38 |
"""
|
39 |
|
40 |
+
mean = torch.tensor(np.asarray(mean)[:, None, None]) # C H W
|
41 |
+
std = torch.tensor(np.asarray(std)[:, None, None])
|
42 |
+
orig_img = orig_img[channels, ...]
|
43 |
+
valid_mask = torch.ones_like(orig_img, dtype=torch.bool)
|
44 |
+
valid_mask[orig_img == NO_DATA_FLOAT] = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
+
# Back to original data range
|
47 |
+
orig_img = (orig_img * std[channels]) + mean[channels]
|
48 |
+
new_img = (new_img[channels, ...] * std[channels]) + mean[channels]
|
49 |
|
50 |
+
# Rescale (enhancing contrast)
|
51 |
+
max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE))
|
52 |
+
min_value = OFFSET
|
53 |
|
54 |
+
orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, 1)
|
55 |
+
new_img = torch.clamp((new_img - min_value) / (max_value - min_value), 0, 1)
|
56 |
|
57 |
+
# No data as zeros
|
58 |
+
orig_img[~valid_mask] = 0
|
59 |
+
new_img[~valid_mask] = 0
|
60 |
|
61 |
+
return orig_img, new_img
|
62 |
|
63 |
|
64 |
def read_geotiff(file_path: str):
|
|
|
75 |
with rasterio.open(file_path) as src:
|
76 |
img = src.read()
|
77 |
meta = src.meta
|
78 |
+
try:
|
79 |
+
coords = src.lnglat()
|
80 |
+
except:
|
81 |
+
# Cannot read coords
|
82 |
+
coords = None
|
83 |
|
84 |
+
return img, meta, coords
|
85 |
|
86 |
|
87 |
def save_geotiff(image, output_path: str, meta: dict):
|
|
|
129 |
metas = []
|
130 |
|
131 |
for file in file_paths:
|
132 |
+
img, meta, _ = read_geotiff(file)
|
133 |
|
134 |
# Rescaling (don't normalize on nodata)
|
135 |
img = np.moveaxis(img, 0, -1) # channels last for rescaling
|
|
|
142 |
|
143 |
imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
|
144 |
imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W
|
145 |
+
imgs = np.expand_dims(imgs, axis=0) # add batch di
|
146 |
|
147 |
return imgs, metas
|
148 |
|
|
|
168 |
with torch.no_grad():
|
169 |
x = input_data.to(device)
|
170 |
|
171 |
+
_, pred, mask = model(x, mask_ratio=mask_ratio)
|
172 |
|
173 |
# Create mask and prediction images (un-patchify)
|
174 |
mask_img = (
|
|
|
209 |
orig_img=input_img[:, t, :, :],
|
210 |
new_img=rec_img[:, t, :, :],
|
211 |
channels=channels,
|
212 |
+
mean=mean,
|
213 |
+
std=std,
|
214 |
)
|
215 |
|
216 |
rgb_mask = mask_img[channels, t, :, :] * rgb_orig
|
|
|
274 |
|
275 |
def main(
|
276 |
data_files: List[str],
|
277 |
+
config_path: str,
|
278 |
checkpoint: str,
|
279 |
output_dir: str,
|
280 |
rgb_outputs: bool,
|
|
|
281 |
mask_ratio: float = None,
|
282 |
input_indices: list[int] = None,
|
283 |
):
|
|
|
285 |
|
286 |
# Get parameters --------
|
287 |
|
288 |
+
import json
|
289 |
+
with open(config_path, "r") as f:
|
290 |
+
config = yaml.safe_load(f)['pretrained_cfg']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
291 |
|
292 |
batch_size = 1
|
293 |
+
bands = config['bands']
|
294 |
+
num_frames = len(data_files)
|
295 |
+
mean = config['mean']
|
296 |
+
std = config['std']
|
297 |
+
img_size = config['img_size']
|
298 |
+
mask_ratio = mask_ratio or config['mask_ratio']
|
299 |
|
300 |
print(
|
301 |
f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n"
|
|
|
320 |
|
321 |
# Create model and load checkpoint -------------------------------------------------------------
|
322 |
|
323 |
+
config.update(
|
|
|
|
|
324 |
num_frames=num_frames,
|
|
|
325 |
in_chans=len(bands),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
326 |
)
|
327 |
|
328 |
+
model = PrithviMAE(**config)
|
329 |
+
|
330 |
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
331 |
print(f"\n--> Model has {total_params:,} parameters.\n")
|
332 |
|
|
|
334 |
|
335 |
state_dict = torch.load(checkpoint, map_location=device)
|
336 |
# discard fixed pos_embedding weight
|
337 |
+
for k in list(state_dict.keys()):
|
338 |
+
if 'pos_embed' in k:
|
339 |
+
del state_dict[k]
|
340 |
model.load_state_dict(state_dict, strict=False)
|
341 |
print(f"Loaded checkpoint from {checkpoint}")
|
342 |
|
|
|
441 |
|
442 |
parser.add_argument(
|
443 |
"--data_files",
|
|
|
444 |
type=str,
|
445 |
nargs="+",
|
446 |
+
default=["examples/HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif",
|
447 |
+
"examples/HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif",
|
448 |
+
"examples/HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"
|
449 |
+
],
|
450 |
help="Path to the data files. Assumes multi-band files.",
|
451 |
)
|
452 |
parser.add_argument(
|
453 |
+
"--config_path",
|
454 |
+
"-c",
|
455 |
type=str,
|
456 |
+
default="config.json",
|
457 |
+
help="Path to json file containing model training parameters.",
|
458 |
)
|
459 |
parser.add_argument(
|
460 |
"--checkpoint",
|
|
|
461 |
type=str,
|
462 |
+
default="Prithvi_EO_V1_100M.pt",
|
463 |
help="Path to a checkpoint file to load from.",
|
464 |
)
|
465 |
parser.add_argument(
|
466 |
"--output_dir",
|
|
|
467 |
type=str,
|
468 |
+
default="output",
|
469 |
help="Path to the directory where to save outputs.",
|
470 |
)
|
471 |
parser.add_argument(
|
472 |
"--mask_ratio",
|
473 |
+
default=0.75,
|
474 |
type=float,
|
475 |
help="Masking ratio (percentage of removed patches). "
|
476 |
"If None (default) use same value used for pretraining.",
|
477 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
478 |
parser.add_argument(
|
479 |
"--input_indices",
|
480 |
default=None,
|
prithvi_mae.py
ADDED
@@ -0,0 +1,736 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) IBM Corp. 2024. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# --------------------------------------------------------
|
15 |
+
# References:
|
16 |
+
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
17 |
+
# transformers: https://github.com/huggingface/transformers
|
18 |
+
# --------------------------------------------------------
|
19 |
+
|
20 |
+
from functools import partial
|
21 |
+
from typing import List, Tuple
|
22 |
+
|
23 |
+
import logging
|
24 |
+
import numpy as np
|
25 |
+
import torch
|
26 |
+
import torch.nn as nn
|
27 |
+
from einops import rearrange
|
28 |
+
from timm.layers import to_2tuple
|
29 |
+
from timm.models.vision_transformer import Block
|
30 |
+
|
31 |
+
|
32 |
+
def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
|
33 |
+
"""
|
34 |
+
Create 3D sin/cos positional embeddings.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
embed_dim (int):
|
38 |
+
Embedding dimension.
|
39 |
+
grid_size (tuple[int, int, int] | list[int]):
|
40 |
+
The grid depth, height and width.
|
41 |
+
add_cls_token (bool, *optional*, defaults to False):
|
42 |
+
Whether or not to add a classification (CLS) token.
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
(`torch.FloatTensor` of shape (grid_size[0]*grid_size[1]*grid_size[2], embed_dim) or
|
46 |
+
(1+grid_size[0]*grid_size[1]*grid_size[2], embed_dim): the position embeddings (with or without cls token)
|
47 |
+
"""
|
48 |
+
|
49 |
+
assert embed_dim % 16 == 0
|
50 |
+
|
51 |
+
t_size, h_size, w_size = grid_size
|
52 |
+
|
53 |
+
w_embed_dim = embed_dim // 16 * 6
|
54 |
+
h_embed_dim = embed_dim // 16 * 6
|
55 |
+
t_embed_dim = embed_dim // 16 * 4
|
56 |
+
|
57 |
+
w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
|
58 |
+
h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
|
59 |
+
t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))
|
60 |
+
|
61 |
+
w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
|
62 |
+
h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
|
63 |
+
t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)
|
64 |
+
|
65 |
+
pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)
|
66 |
+
|
67 |
+
if add_cls_token:
|
68 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
69 |
+
return pos_embed
|
70 |
+
|
71 |
+
|
72 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
73 |
+
"""
|
74 |
+
embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
|
75 |
+
"""
|
76 |
+
if embed_dim % 2 != 0:
|
77 |
+
raise ValueError("embed_dim must be even")
|
78 |
+
|
79 |
+
omega = np.arange(embed_dim // 2, dtype=float)
|
80 |
+
omega /= embed_dim / 2.0
|
81 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
82 |
+
|
83 |
+
pos = pos.reshape(-1) # (M,)
|
84 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
85 |
+
|
86 |
+
emb_sin = np.sin(out) # (M, D/2)
|
87 |
+
emb_cos = np.cos(out) # (M, D/2)
|
88 |
+
|
89 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
90 |
+
return emb
|
91 |
+
|
92 |
+
|
93 |
+
def _get_1d_sincos_embed_from_grid_torch(embed_dim: int, pos: torch.Tensor):
|
94 |
+
""" This is the torch version of *get_1d_sincos_pos_embed_from_grid()*. However,
|
95 |
+
it was modified to cast omega values to pos.dtype which must be float (and not int as in
|
96 |
+
regular positional embeddings). This was required in order to allow for native FSDP mixed
|
97 |
+
precision support: modify omega to appropriate dtype (pos carries the correct float dtype),
|
98 |
+
instead of manually forcing float32.
|
99 |
+
|
100 |
+
embed_dim: output dimension for each position
|
101 |
+
pos: a list of positions to be encoded: size (M,) - must be float dtype!
|
102 |
+
out: (M, D)
|
103 |
+
"""
|
104 |
+
assert embed_dim % 2 == 0
|
105 |
+
assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16]
|
106 |
+
|
107 |
+
omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device)
|
108 |
+
omega /= embed_dim / 2.0
|
109 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
110 |
+
|
111 |
+
pos = pos.reshape(-1) # (M,)
|
112 |
+
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
113 |
+
|
114 |
+
emb_sin = torch.sin(out) # (M, D/2)
|
115 |
+
emb_cos = torch.cos(out) # (M, D/2)
|
116 |
+
|
117 |
+
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
118 |
+
|
119 |
+
return emb
|
120 |
+
|
121 |
+
|
122 |
+
def _init_weights(module):
|
123 |
+
"""Initialize the weights"""
|
124 |
+
if isinstance(module, nn.Linear):
|
125 |
+
nn.init.xavier_uniform_(module.weight)
|
126 |
+
if module.bias is not None:
|
127 |
+
module.bias.data.zero_()
|
128 |
+
elif isinstance(module, nn.LayerNorm):
|
129 |
+
module.bias.data.zero_()
|
130 |
+
module.weight.data.fill_(1.0)
|
131 |
+
|
132 |
+
|
133 |
+
class PatchEmbed(nn.Module):
|
134 |
+
"""3D version of timm.models.vision_transformer.PatchEmbed"""
|
135 |
+
def __init__(
|
136 |
+
self,
|
137 |
+
input_size: Tuple[int, int, int] = (1, 224, 224),
|
138 |
+
patch_size: Tuple[int, int, int] = (1, 16, 16),
|
139 |
+
in_chans: int = 3,
|
140 |
+
embed_dim: int = 768,
|
141 |
+
norm_layer: nn.Module | None = None,
|
142 |
+
flatten: bool = True,
|
143 |
+
bias: bool = True,
|
144 |
+
):
|
145 |
+
super().__init__()
|
146 |
+
self.input_size = input_size
|
147 |
+
self.patch_size = patch_size
|
148 |
+
self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
|
149 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
|
150 |
+
self.flatten = flatten
|
151 |
+
|
152 |
+
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
153 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
154 |
+
|
155 |
+
def forward(self, x):
|
156 |
+
B, C, T, H, W = x.shape
|
157 |
+
|
158 |
+
if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1:
|
159 |
+
logging.warning(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}."
|
160 |
+
f"The border will be ignored, add backbone_padding for pixel-wise tasks.")
|
161 |
+
|
162 |
+
x = self.proj(x)
|
163 |
+
if self.flatten:
|
164 |
+
x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C
|
165 |
+
x = self.norm(x)
|
166 |
+
return x
|
167 |
+
|
168 |
+
|
169 |
+
class TemporalEncoder(nn.Module):
|
170 |
+
def __init__(self, embed_dim: int, trainable_scale: bool = False):
|
171 |
+
super().__init__()
|
172 |
+
self.embed_dim = embed_dim
|
173 |
+
self.year_embed_dim = embed_dim // 2
|
174 |
+
self.julian_day_embed_dim = embed_dim - self.year_embed_dim
|
175 |
+
|
176 |
+
# If trainable, initialize scale with small number
|
177 |
+
if trainable_scale:
|
178 |
+
self.scale = nn.Parameter(torch.full((1,), 0.1))
|
179 |
+
else:
|
180 |
+
self.register_buffer('scale', torch.ones(1))
|
181 |
+
|
182 |
+
def forward(self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None):
|
183 |
+
"""
|
184 |
+
temporal_coords: year and day-of-year info with shape (B, T, 2).
|
185 |
+
tokens_per_frame: number of tokens for each frame in the sample. If provided, embeddings will be
|
186 |
+
repeated over T dimension, and final shape is (B, T*tokens_per_frame, embed_dim).
|
187 |
+
"""
|
188 |
+
shape = temporal_coords.shape[:2] + (-1,) # B, T, -1
|
189 |
+
|
190 |
+
year = _get_1d_sincos_embed_from_grid_torch(
|
191 |
+
self.year_embed_dim, temporal_coords[:, :, 0].flatten()).reshape(shape)
|
192 |
+
julian_day = _get_1d_sincos_embed_from_grid_torch(
|
193 |
+
self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()).reshape(shape)
|
194 |
+
|
195 |
+
embedding = self.scale * torch.cat([year, julian_day], dim=-1)
|
196 |
+
|
197 |
+
if tokens_per_frame is not None:
|
198 |
+
embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1)
|
199 |
+
|
200 |
+
return embedding # B, T*tokens_per_frame, embed_dim
|
201 |
+
|
202 |
+
|
203 |
+
class LocationEncoder(nn.Module):
|
204 |
+
def __init__(self, embed_dim: int, trainable_scale: bool = False):
|
205 |
+
super().__init__()
|
206 |
+
self.embed_dim = embed_dim
|
207 |
+
self.lat_embed_dim = embed_dim // 2
|
208 |
+
self.lon_embed_dim = embed_dim - self.lat_embed_dim
|
209 |
+
|
210 |
+
# If trainable, initialize scale with small number
|
211 |
+
if trainable_scale:
|
212 |
+
self.scale = nn.Parameter(torch.full((1,), 0.1))
|
213 |
+
else:
|
214 |
+
self.register_buffer('scale', torch.ones(1))
|
215 |
+
|
216 |
+
def forward(self, location_coords: torch.Tensor):
|
217 |
+
"""
|
218 |
+
location_coords: lat and lon info with shape (B, 2).
|
219 |
+
"""
|
220 |
+
shape = location_coords.shape[:1] + (1, -1) # B, 1, -1
|
221 |
+
|
222 |
+
lat = _get_1d_sincos_embed_from_grid_torch(
|
223 |
+
self.lat_embed_dim, location_coords[:, 0].flatten()).reshape(shape)
|
224 |
+
lon = _get_1d_sincos_embed_from_grid_torch(
|
225 |
+
self.lon_embed_dim, location_coords[:, 1].flatten()).reshape(shape)
|
226 |
+
|
227 |
+
embedding = self.scale * torch.cat([lat, lon], dim=-1)
|
228 |
+
|
229 |
+
return embedding # B, 1, embed_dim
|
230 |
+
|
231 |
+
|
232 |
+
class PrithviViT(nn.Module):
|
233 |
+
""" Prithvi ViT Encoder"""
|
234 |
+
def __init__(self,
|
235 |
+
img_size: int | Tuple[int, int] = 224,
|
236 |
+
patch_size: int | Tuple[int, int, int] = (1, 16, 16),
|
237 |
+
num_frames: int = 1,
|
238 |
+
in_chans: int = 3,
|
239 |
+
embed_dim: int = 1024,
|
240 |
+
depth: int = 24,
|
241 |
+
num_heads: int = 16,
|
242 |
+
mlp_ratio: float = 4.,
|
243 |
+
norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6),
|
244 |
+
coords_encoding: List[str] | None = None,
|
245 |
+
coords_scale_learn: bool = False,
|
246 |
+
encoder_only: bool = True, # needed for timm
|
247 |
+
** kwargs,
|
248 |
+
):
|
249 |
+
super().__init__()
|
250 |
+
|
251 |
+
self.feature_info = []
|
252 |
+
self.encoder_only = encoder_only
|
253 |
+
self.in_chans = in_chans
|
254 |
+
self.num_frames = num_frames
|
255 |
+
self.embed_dim = embed_dim
|
256 |
+
self.img_size = to_2tuple(img_size)
|
257 |
+
if isinstance(patch_size, int):
|
258 |
+
patch_size = (1, patch_size, patch_size)
|
259 |
+
|
260 |
+
# 3D patch embedding
|
261 |
+
self.patch_embed = PatchEmbed(
|
262 |
+
input_size=(num_frames,) + self.img_size,
|
263 |
+
patch_size=patch_size,
|
264 |
+
in_chans=in_chans,
|
265 |
+
embed_dim=embed_dim,
|
266 |
+
)
|
267 |
+
|
268 |
+
# Optional temporal and location embedding
|
269 |
+
coords_encoding = coords_encoding or []
|
270 |
+
self.temporal_encoding = 'time' in coords_encoding
|
271 |
+
self.location_encoding = 'location' in coords_encoding
|
272 |
+
if self.temporal_encoding:
|
273 |
+
assert patch_size[0] == 1, f"With temporal encoding, patch_size[0] must be 1, received {patch_size[0]}"
|
274 |
+
self.temporal_embed_enc = TemporalEncoder(embed_dim, coords_scale_learn)
|
275 |
+
if self.location_encoding:
|
276 |
+
self.location_embed_enc = LocationEncoder(embed_dim, coords_scale_learn)
|
277 |
+
|
278 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
279 |
+
self.register_buffer("pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
|
280 |
+
|
281 |
+
# Transformer layers
|
282 |
+
self.blocks = []
|
283 |
+
for i in range(depth):
|
284 |
+
self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer))
|
285 |
+
self.feature_info.append(
|
286 |
+
{"num_chs": embed_dim * self.patch_embed.patch_size[0], "reduction": 1, "module": f"blocks.{i}"}
|
287 |
+
)
|
288 |
+
self.blocks = nn.ModuleList(self.blocks)
|
289 |
+
|
290 |
+
self.norm = norm_layer(embed_dim)
|
291 |
+
|
292 |
+
self.initialize_weights()
|
293 |
+
|
294 |
+
def initialize_weights(self):
|
295 |
+
# initialize (and freeze) position embeddings by sin-cos embedding
|
296 |
+
pos_embed = get_3d_sincos_pos_embed(
|
297 |
+
self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True
|
298 |
+
)
|
299 |
+
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
300 |
+
|
301 |
+
# initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
|
302 |
+
w = self.patch_embed.proj.weight.data
|
303 |
+
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
304 |
+
|
305 |
+
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
|
306 |
+
torch.nn.init.normal_(self.cls_token, std=0.02)
|
307 |
+
self.apply(_init_weights)
|
308 |
+
|
309 |
+
def random_masking(self, sequence, mask_ratio, noise=None):
|
310 |
+
"""
|
311 |
+
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
|
312 |
+
noise.
|
313 |
+
|
314 |
+
Args:
|
315 |
+
sequence (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`)
|
316 |
+
mask_ratio (float): mask ratio to use.
|
317 |
+
noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
|
318 |
+
mainly used for testing purposes to control randomness and maintain the reproducibility
|
319 |
+
"""
|
320 |
+
batch_size, seq_length, dim = sequence.shape
|
321 |
+
len_keep = int(seq_length * (1 - mask_ratio))
|
322 |
+
|
323 |
+
if noise is None:
|
324 |
+
noise = torch.rand(batch_size, seq_length, device=sequence.device) # noise in [0, 1]
|
325 |
+
|
326 |
+
# sort noise for each sample
|
327 |
+
ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device) # ascend: small is keep, large is remove
|
328 |
+
ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)
|
329 |
+
|
330 |
+
# keep the first subset
|
331 |
+
ids_keep = ids_shuffle[:, :len_keep]
|
332 |
+
sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))
|
333 |
+
|
334 |
+
# generate the binary mask: 0 is keep, 1 is remove
|
335 |
+
mask = torch.ones([batch_size, seq_length], device=sequence.device)
|
336 |
+
mask[:, :len_keep] = 0
|
337 |
+
# unshuffle to get the binary mask
|
338 |
+
mask = torch.gather(mask, dim=1, index=ids_restore)
|
339 |
+
|
340 |
+
return sequence_unmasked, mask, ids_restore
|
341 |
+
|
342 |
+
def _get_pos_embed(self, x):
|
343 |
+
t, h, w = x.shape[-3:]
|
344 |
+
|
345 |
+
pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(
|
346 |
+
self.embed_dim,
|
347 |
+
(
|
348 |
+
t // self.patch_embed.patch_size[0],
|
349 |
+
h // self.patch_embed.patch_size[1],
|
350 |
+
w // self.patch_embed.patch_size[2],
|
351 |
+
),
|
352 |
+
add_cls_token=True,
|
353 |
+
)).float().unsqueeze(0).to(x)
|
354 |
+
|
355 |
+
return pos_embed
|
356 |
+
|
357 |
+
|
358 |
+
def forward(
|
359 |
+
self, x: torch.Tensor,
|
360 |
+
temporal_coords: None | torch.Tensor = None,
|
361 |
+
location_coords: None | torch.Tensor = None,
|
362 |
+
mask_ratio=0.75
|
363 |
+
):
|
364 |
+
if x.shape[-3:] != self.patch_embed.input_size:
|
365 |
+
# changed input size
|
366 |
+
pos_embed = self._get_pos_embed(x)
|
367 |
+
else:
|
368 |
+
pos_embed = self.pos_embed
|
369 |
+
|
370 |
+
# embed patches
|
371 |
+
x = self.patch_embed(x)
|
372 |
+
|
373 |
+
# add pos embed w/o cls token
|
374 |
+
x = x + pos_embed[:, 1:, :]
|
375 |
+
|
376 |
+
if self.temporal_encoding:
|
377 |
+
num_tokens_per_frame = x.shape[1] // self.num_frames
|
378 |
+
temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
|
379 |
+
x = x + temporal_encoding
|
380 |
+
if self.location_encoding:
|
381 |
+
location_encoding = self.location_embed_enc(location_coords)
|
382 |
+
x = x + location_encoding
|
383 |
+
|
384 |
+
# masking: length -> length * mask_ratio
|
385 |
+
x, mask, ids_restore = self.random_masking(x, mask_ratio)
|
386 |
+
|
387 |
+
# append cls token
|
388 |
+
cls_token = self.cls_token + pos_embed[:, :1, :]
|
389 |
+
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
|
390 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
391 |
+
|
392 |
+
# apply Transformer blocks
|
393 |
+
for block in self.blocks:
|
394 |
+
x = block(x)
|
395 |
+
x = self.norm(x)
|
396 |
+
|
397 |
+
return x, mask, ids_restore
|
398 |
+
|
399 |
+
def forward_features(
|
400 |
+
self,
|
401 |
+
x: torch.Tensor,
|
402 |
+
temporal_coords: None | torch.Tensor = None,
|
403 |
+
location_coords: None | torch.Tensor = None,
|
404 |
+
) -> list[torch.Tensor]:
|
405 |
+
if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
|
406 |
+
# add time dim
|
407 |
+
x = x.unsqueeze(2)
|
408 |
+
|
409 |
+
if x.shape[-3:] != self.patch_embed.input_size:
|
410 |
+
pos_embed = self._get_pos_embed(x)
|
411 |
+
else:
|
412 |
+
pos_embed = self.pos_embed
|
413 |
+
|
414 |
+
# embed patches
|
415 |
+
x = self.patch_embed(x)
|
416 |
+
|
417 |
+
# add pos embed w/o cls token
|
418 |
+
x = x + pos_embed[:, 1:, :]
|
419 |
+
|
420 |
+
if self.temporal_encoding:
|
421 |
+
num_tokens_per_frame = x.shape[1] // self.patch_embed.num_frames
|
422 |
+
temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
|
423 |
+
x = x + temporal_encoding
|
424 |
+
if self.location_encoding:
|
425 |
+
location_encoding = self.location_embed_enc(location_coords)
|
426 |
+
x = x + location_encoding
|
427 |
+
|
428 |
+
# append cls token
|
429 |
+
cls_token = self.cls_token + pos_embed[:, :1, :]
|
430 |
+
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
|
431 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
432 |
+
|
433 |
+
# apply Transformer blocks
|
434 |
+
out = []
|
435 |
+
for block in self.blocks:
|
436 |
+
x = block(x)
|
437 |
+
out.append(x.clone())
|
438 |
+
|
439 |
+
x = self.norm(x)
|
440 |
+
out[-1] = x
|
441 |
+
return out
|
442 |
+
|
443 |
+
def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
|
444 |
+
out = []
|
445 |
+
effective_time_dim = self.patch_embed.input_size[0] // self.patch_embed.patch_size[0]
|
446 |
+
for x in features:
|
447 |
+
x_no_token = x[:, 1:, :]
|
448 |
+
number_of_tokens = x_no_token.shape[1]
|
449 |
+
tokens_per_timestep = number_of_tokens // effective_time_dim
|
450 |
+
h = int(np.sqrt(tokens_per_timestep))
|
451 |
+
encoded = rearrange(
|
452 |
+
x_no_token,
|
453 |
+
"batch (t h w) e -> batch (t e) h w",
|
454 |
+
e=self.embed_dim,
|
455 |
+
t=effective_time_dim,
|
456 |
+
h=h,
|
457 |
+
)
|
458 |
+
out.append(encoded)
|
459 |
+
return out
|
460 |
+
|
461 |
+
|
462 |
+
class MAEDecoder(nn.Module):
|
463 |
+
""" Transformer Decoder used in the Prithvi MAE"""
|
464 |
+
def __init__(self,
|
465 |
+
patch_size: int | Tuple[int, int, int] = (1, 16, 16),
|
466 |
+
grid_size: List[int] | Tuple[int, int, int] = (3, 14, 14),
|
467 |
+
in_chans: int = 3,
|
468 |
+
encoder_embed_dim: int = 1024,
|
469 |
+
decoder_embed_dim: int = 512,
|
470 |
+
depth: int = 8,
|
471 |
+
num_heads: int = 16,
|
472 |
+
mlp_ratio: float = 4.,
|
473 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
474 |
+
coords_encoding: List[str] | None = None,
|
475 |
+
coords_scale_learn: bool = False,
|
476 |
+
):
|
477 |
+
super().__init__()
|
478 |
+
|
479 |
+
self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
|
480 |
+
self.decoder_embed_dim = decoder_embed_dim
|
481 |
+
self.grid_size = grid_size
|
482 |
+
if isinstance(patch_size, int):
|
483 |
+
patch_size = (1, patch_size, patch_size)
|
484 |
+
self.patch_size = patch_size
|
485 |
+
self.num_frames = self.grid_size[0] * patch_size[0]
|
486 |
+
num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
|
487 |
+
|
488 |
+
# Optional temporal and location embedding
|
489 |
+
coords_encoding = coords_encoding or []
|
490 |
+
self.temporal_encoding = 'time' in coords_encoding
|
491 |
+
self.location_encoding = 'location' in coords_encoding
|
492 |
+
if self.temporal_encoding:
|
493 |
+
self.temporal_embed_dec = TemporalEncoder(decoder_embed_dim, coords_scale_learn)
|
494 |
+
if self.location_encoding:
|
495 |
+
self.location_embed_dec = LocationEncoder(decoder_embed_dim, coords_scale_learn)
|
496 |
+
|
497 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
|
498 |
+
|
499 |
+
self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim))
|
500 |
+
|
501 |
+
self.decoder_blocks = nn.ModuleList(
|
502 |
+
[Block(decoder_embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) for _ in range(depth)]
|
503 |
+
)
|
504 |
+
|
505 |
+
self.decoder_norm = norm_layer(decoder_embed_dim)
|
506 |
+
self.decoder_pred = nn.Linear(decoder_embed_dim,
|
507 |
+
patch_size[0] * patch_size[1] * patch_size[2] * in_chans,
|
508 |
+
bias=True)
|
509 |
+
|
510 |
+
self.initialize_weights()
|
511 |
+
|
512 |
+
def initialize_weights(self):
|
513 |
+
# initialize (and freeze) position embeddings by sin-cos embedding
|
514 |
+
decoder_pos_embed = get_3d_sincos_pos_embed(
|
515 |
+
self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True
|
516 |
+
)
|
517 |
+
self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
|
518 |
+
|
519 |
+
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
|
520 |
+
torch.nn.init.normal_(self.mask_token, std=0.02)
|
521 |
+
self.apply(_init_weights)
|
522 |
+
|
523 |
+
def forward(
|
524 |
+
self,
|
525 |
+
hidden_states: torch.Tensor,
|
526 |
+
ids_restore: torch.Tensor,
|
527 |
+
temporal_coords: None | torch.Tensor = None,
|
528 |
+
location_coords: None | torch.Tensor = None,
|
529 |
+
input_size: list[int] = None,
|
530 |
+
):
|
531 |
+
# embed tokens
|
532 |
+
x = self.decoder_embed(hidden_states)
|
533 |
+
|
534 |
+
t, h, w = input_size[-3:]
|
535 |
+
decoder_pos_embed = torch.from_numpy(
|
536 |
+
get_3d_sincos_pos_embed(
|
537 |
+
self.decoder_embed_dim,
|
538 |
+
(
|
539 |
+
t // self.patch_size[0],
|
540 |
+
h // self.patch_size[1],
|
541 |
+
w // self.patch_size[2],
|
542 |
+
),
|
543 |
+
add_cls_token=True,
|
544 |
+
)
|
545 |
+
).to(x)
|
546 |
+
|
547 |
+
# append mask tokens to sequence
|
548 |
+
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
|
549 |
+
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
|
550 |
+
# unshuffle
|
551 |
+
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device))
|
552 |
+
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
|
553 |
+
# add pos embed
|
554 |
+
x = x + decoder_pos_embed
|
555 |
+
|
556 |
+
# remove cls token
|
557 |
+
x_ = x[:, 1:, :]
|
558 |
+
|
559 |
+
if self.temporal_encoding:
|
560 |
+
num_tokens_per_frame = x_.shape[1] // self.num_frames
|
561 |
+
temporal_encoding = self.temporal_embed_dec(temporal_coords, num_tokens_per_frame)
|
562 |
+
# Add temporal encoding w/o cls token
|
563 |
+
x_ = x_ + temporal_encoding
|
564 |
+
if self.location_encoding:
|
565 |
+
location_encoding = self.location_embed_dec(location_coords)
|
566 |
+
# Add location encoding w/o cls token
|
567 |
+
x_ = x_ + location_encoding
|
568 |
+
|
569 |
+
# append cls token
|
570 |
+
x = torch.cat([x[:, :1, :], x_], dim=1)
|
571 |
+
|
572 |
+
# apply Transformer layers (blocks)
|
573 |
+
for block in self.decoder_blocks:
|
574 |
+
x = block(x)
|
575 |
+
x = self.decoder_norm(x)
|
576 |
+
|
577 |
+
# predictor projection
|
578 |
+
pred = self.decoder_pred(x)
|
579 |
+
|
580 |
+
# remove cls token
|
581 |
+
pred = pred[:, 1:, :]
|
582 |
+
|
583 |
+
return pred
|
584 |
+
|
585 |
+
|
586 |
+
class PrithviMAE(nn.Module):
|
587 |
+
""" Prithvi Masked Autoencoder"""
|
588 |
+
|
589 |
+
def __init__(self,
|
590 |
+
img_size: int | Tuple[int, int] = 224,
|
591 |
+
patch_size: int | Tuple[int, int, int] = (1, 16, 16),
|
592 |
+
num_frames: int = 3,
|
593 |
+
in_chans: int = 3,
|
594 |
+
embed_dim: int = 1024,
|
595 |
+
depth: int = 24,
|
596 |
+
num_heads: int = 16,
|
597 |
+
decoder_embed_dim: int = 512,
|
598 |
+
decoder_depth: int = 8,
|
599 |
+
decoder_num_heads: int = 16,
|
600 |
+
mlp_ratio: float = 4.,
|
601 |
+
norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6),
|
602 |
+
norm_pix_loss: bool = False,
|
603 |
+
coords_encoding: List[str] | None = None,
|
604 |
+
coords_scale_learn: bool = False,
|
605 |
+
encoder_only: bool = False,
|
606 |
+
**kwargs,
|
607 |
+
):
|
608 |
+
super().__init__()
|
609 |
+
|
610 |
+
self.encoder = PrithviViT(
|
611 |
+
img_size=img_size,
|
612 |
+
num_frames=num_frames,
|
613 |
+
patch_size=patch_size,
|
614 |
+
in_chans=in_chans,
|
615 |
+
embed_dim=embed_dim,
|
616 |
+
depth=depth,
|
617 |
+
num_heads=num_heads,
|
618 |
+
mlp_ratio=mlp_ratio,
|
619 |
+
norm_layer=norm_layer,
|
620 |
+
coords_encoding=coords_encoding,
|
621 |
+
coords_scale_learn=coords_scale_learn,
|
622 |
+
)
|
623 |
+
|
624 |
+
self.encoder_only = encoder_only
|
625 |
+
|
626 |
+
if not encoder_only:
|
627 |
+
self.decoder = MAEDecoder(
|
628 |
+
patch_size=patch_size,
|
629 |
+
grid_size=self.encoder.patch_embed.grid_size,
|
630 |
+
in_chans=in_chans,
|
631 |
+
encoder_embed_dim=embed_dim,
|
632 |
+
decoder_embed_dim=decoder_embed_dim,
|
633 |
+
depth=decoder_depth,
|
634 |
+
num_heads=decoder_num_heads,
|
635 |
+
mlp_ratio=mlp_ratio,
|
636 |
+
norm_layer=norm_layer,
|
637 |
+
coords_encoding=coords_encoding,
|
638 |
+
coords_scale_learn=coords_scale_learn,
|
639 |
+
)
|
640 |
+
else:
|
641 |
+
self.decoder = nn.Identity()
|
642 |
+
|
643 |
+
self.norm_pix_loss = norm_pix_loss
|
644 |
+
|
645 |
+
def patchify(self, pixel_values):
|
646 |
+
"""
|
647 |
+
Args:
|
648 |
+
pixel_values (torch.FloatTensor of shape `(batch_size, num_channels, time, height, width)`):
|
649 |
+
Pixel values.
|
650 |
+
|
651 |
+
Returns:
|
652 |
+
torch.FloatTensor of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
|
653 |
+
Patchified pixel values.
|
654 |
+
"""
|
655 |
+
patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
|
656 |
+
num_channels = self.encoder.in_chans
|
657 |
+
|
658 |
+
# patchify
|
659 |
+
patchified_pixel_values = rearrange(pixel_values, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)',
|
660 |
+
c=num_channels, s=patch_size_t, p=patch_size_h, q=patch_size_w)
|
661 |
+
|
662 |
+
|
663 |
+
return patchified_pixel_values
|
664 |
+
|
665 |
+
def unpatchify(self, patchified_pixel_values, image_size: Tuple[int, int] | None = None):
|
666 |
+
"""
|
667 |
+
Args:
|
668 |
+
patchified_pixel_values (`torch.FloatTensor` of shape
|
669 |
+
`(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
|
670 |
+
Patchified pixel values.
|
671 |
+
image_size (`Tuple[int, int]`, *optional*):
|
672 |
+
Original image size.
|
673 |
+
|
674 |
+
Returns:
|
675 |
+
`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
|
676 |
+
Pixel values.
|
677 |
+
"""
|
678 |
+
patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
|
679 |
+
image_size = to_2tuple(image_size) if image_size is not None else self.encoder.img_size
|
680 |
+
original_height, original_width = image_size
|
681 |
+
num_patches_h = original_height // patch_size_h
|
682 |
+
num_patches_w = original_width // patch_size_w
|
683 |
+
num_channels = self.encoder.in_chans
|
684 |
+
|
685 |
+
pixel_values = rearrange(patchified_pixel_values, 'b (t h w) (s p q c) -> b c (t s) (h p) (w q)',
|
686 |
+
c=num_channels, h=num_patches_h, w=num_patches_w,
|
687 |
+
s=patch_size_t, p=patch_size_h, q=patch_size_w)
|
688 |
+
return pixel_values
|
689 |
+
|
690 |
+
def forward_loss(self, pixel_values, pred, mask):
|
691 |
+
"""
|
692 |
+
Args:
|
693 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
|
694 |
+
Pixel values.
|
695 |
+
pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
|
696 |
+
Predicted pixel values.
|
697 |
+
mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
698 |
+
Tensor indicating which patches are masked (1) and which are not (0).
|
699 |
+
|
700 |
+
Returns:
|
701 |
+
`torch.FloatTensor`: Pixel reconstruction loss.
|
702 |
+
"""
|
703 |
+
target = self.patchify(pixel_values)
|
704 |
+
if self.norm_pix_loss:
|
705 |
+
mean = target.mean(dim=-1, keepdim=True)
|
706 |
+
var = target.var(dim=-1, keepdim=True)
|
707 |
+
target = (target - mean) / (var + 1.0e-6) ** 0.5
|
708 |
+
|
709 |
+
loss = (pred - target) ** 2
|
710 |
+
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
|
711 |
+
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
|
712 |
+
return loss
|
713 |
+
|
714 |
+
def forward(
|
715 |
+
self,
|
716 |
+
pixel_values: torch.Tensor,
|
717 |
+
temporal_coords: None | torch.Tensor = None,
|
718 |
+
location_coords: None | torch.Tensor = None,
|
719 |
+
mask_ratio: float = 0.75
|
720 |
+
):
|
721 |
+
if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1:
|
722 |
+
# add time dim
|
723 |
+
pixel_values = pixel_values.unsqueeze(2)
|
724 |
+
|
725 |
+
latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio)
|
726 |
+
pred = self.decoder(latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape)
|
727 |
+
loss = self.forward_loss(pixel_values, pred, mask)
|
728 |
+
return loss, pred, mask
|
729 |
+
|
730 |
+
def forward_features(
|
731 |
+
self,
|
732 |
+
x: torch.Tensor,
|
733 |
+
temporal_coords: None | torch.Tensor = None,
|
734 |
+
location_coords: None | torch.Tensor = None,
|
735 |
+
) -> List[torch.Tensor]:
|
736 |
+
return self.encoder.forward_features(x, temporal_coords, location_coords)
|