ajstewart commited on
Commit
4c557dd
1 Parent(s): 0c252eb

Upload extract.py

Browse files
Files changed (1) hide show
  1. extract.py +2 -3
extract.py CHANGED
@@ -7,11 +7,10 @@ import torch
7
  from torchgeo.models import dofa_base_patch16_224
8
 
9
  # Load the checkpoint
10
- in_filename = "ofa_base_checkpoint_e99.pth"
11
- checkpoint = torch.load(in_filename, map_location=torch.device("cpu"))
12
 
13
  # Remove extra keys
14
- weights = checkpoint["model"]
15
  del weights["mask_token"]
16
  del weights["norm.weight"], weights["norm.bias"]
17
  del weights["projector.weight"], weights["projector.bias"]
 
7
  from torchgeo.models import dofa_base_patch16_224
8
 
9
  # Load the checkpoint
10
+ in_filename = "DOFA_ViT_base_e100.pth"
11
+ weights = torch.load(in_filename, map_location=torch.device("cpu"))
12
 
13
  # Remove extra keys
 
14
  del weights["mask_token"]
15
  del weights["norm.weight"], weights["norm.bias"]
16
  del weights["projector.weight"], weights["projector.bias"]