csaybar commited on
Commit
99a3901
1 Parent(s): 476803e

Upload 5 files

Browse files
satlas/__pycache__/model.cpython-310.pyc ADDED
Binary file (10.1 kB). View file
 
satlas/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.75 kB). View file
 
satlas/run.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import opensr_test
3
+ import matplotlib.pyplot as plt
4
+ from utils import load_satlas_sr, run_satlas
5
+
6
+ # Load the model
7
+ model = load_satlas_sr(device="cuda")
8
+
9
+ # Load the dataset
10
+ dataset = opensr_test.load("naip")
11
+ lr_dataset, hr_dataset = dataset["L1C"], dataset["HRharm"]
12
+
13
+ # Predict a image
14
+ index = 20
15
+ lr = torch.from_numpy(lr_dataset[index][[3, 2, 1]]/3558).float().to("cuda").clamp(0, 1)
16
+ sr = run_satlas(model=model, lr=lr, cropsize=32, overlap=0)
17
+
18
+ # Run the model
19
+ fig, ax = plt.subplots(1, 2, figsize=(10, 5))
20
+ ax[0].imshow(lr.cpu().numpy().transpose(1, 2, 0))
21
+ ax[1].imshow(sr.cpu().numpy().transpose(1, 2, 0))
22
+ plt.show()
satlas/utils.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from basicsr.archs.rrdbnet_arch import RRDBNet
3
+ from typing import Union
4
+ import itertools
5
+ import numpy as np
6
+
7
+
8
+ def load_satlas_sr(device: Union[str, torch.device] = "cuda") -> RRDBNet:
9
+ # Load the weights
10
+ weights_file = "weights/esrgan_1S2.pth"
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
+
13
+ # Create the model
14
+ model = RRDBNet(
15
+ num_in_ch=3,
16
+ num_out_ch=3,
17
+ num_feat=64,
18
+ num_block=23,
19
+ num_grow_ch=32,
20
+ scale=4
21
+ ).to(device)
22
+
23
+ # Setup the weights
24
+ state_dict = torch.load(weights_file)
25
+ model.load_state_dict(state_dict['params_ema'])
26
+ model.eval()
27
+
28
+ # no gradients
29
+ for param in model.parameters():
30
+ param.requires_grad = False
31
+
32
+ return model
33
+
34
+
35
+ def run_satlas(model, lr, cropsize: int = 32, overlap: int = 0):
36
+ # Select the raster with the lowest resolution
37
+ tshp = lr.shape
38
+
39
+ # if the image is too small, return (0, 0)
40
+ if (tshp[1] < cropsize) and (tshp[2] < cropsize):
41
+ return [(0, 0)]
42
+
43
+ # Define relative coordinates.
44
+ xmn, xmx, ymn, ymx = (0, tshp[1], 0, tshp[2])
45
+
46
+ if overlap > cropsize:
47
+ raise ValueError("The overlap must be smaller than the cropsize")
48
+
49
+ xrange = np.arange(xmn, xmx, (cropsize - overlap))
50
+ yrange = np.arange(ymn, ymx, (cropsize - overlap))
51
+
52
+ # If there is negative values in the range, change them by zero.
53
+ xrange[xrange < 0] = 0
54
+ yrange[yrange < 0] = 0
55
+
56
+ # Remove the last element if it is outside the tensor
57
+ xrange = xrange[xrange - (tshp[1] - cropsize) <= 0]
58
+ yrange = yrange[yrange - (tshp[2] - cropsize) <= 0]
59
+
60
+ # If the last element is not (tshp[1] - cropsize) add it!
61
+ if xrange[-1] != (tshp[1] - cropsize):
62
+ xrange = np.append(xrange, tshp[1] - cropsize)
63
+ if yrange[-1] != (tshp[2] - cropsize):
64
+ yrange = np.append(yrange, tshp[2] - cropsize)
65
+
66
+ # Create all the relative coordinates
67
+ mrs = list(itertools.product(xrange, yrange))
68
+
69
+ # Predict the image
70
+ sr = torch.zeros(3, tshp[1]*4, tshp[2]*4)
71
+ for x, y in mrs:
72
+ crop = lr[:, x:x+cropsize, y:y+cropsize]
73
+ sr_crop = model(crop[None])[0]
74
+ sr[:, x*4:(x+cropsize)*4, y*4:(y+cropsize)*4] = sr_crop
75
+
76
+ return sr
satlas/weights/esrgan_1S2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4478f38ccd2271467e77eb5a311aec99ff6796bf900ccfa88c85eea992537f2
3
+ size 134059342