import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from typing import Union
import itertools
import numpy as np


def load_satlas_sr(device: Union[str, torch.device] = "cuda") -> RRDBNet:
    # Path to the pretrained ESRGAN weights
    weights_file = "satlas/weights/esrgan_1S2.pth"

    # Fall back to CPU if CUDA was requested but is not available
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        device = "cpu"

    # Create the model
    model = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=4
    ).to(device)

    # Load the EMA weights onto the target device
    state_dict = torch.load(weights_file, map_location=device)
    model.load_state_dict(state_dict['params_ema'])
    model.eval()

    # Freeze the parameters: the model is used for inference only
    for param in model.parameters():
        param.requires_grad = False

    return model


def run_satlas(
    model: RRDBNet,
    lr: np.ndarray,
    hr: np.ndarray,
    cropsize: int = 32,
    overlap: int = 0,
    device: Union[str, torch.device] = "cuda"
) -> dict:
    # Select the RGB bands, rescale the reflectance by 1/3558 and clip to [0, 1]
    lr = torch.from_numpy(lr[[3, 2, 1]] / 3558).float().to(device).clamp(0, 1)

    # Shape of the LR tensor: (bands, height, width)
    tshp = lr.shape
    
    # If the image is smaller than the crop size, return a dummy coordinate list
    if (tshp[1] < cropsize) and (tshp[2] < cropsize):
        return [(0, 0)]

    # Define relative coordinates.
    xmn, xmx, ymn, ymx = (0, tshp[1], 0, tshp[2])

    if overlap > cropsize:
        raise ValueError("The overlap must be smaller than the cropsize")

    xrange = np.arange(xmn, xmx, (cropsize - overlap))
    yrange = np.arange(ymn, ymx, (cropsize - overlap))

    # If there are negative values in the range, replace them with zero.
    xrange[xrange < 0] = 0
    yrange[yrange < 0] = 0

    # Drop starting offsets whose crop would extend past the tensor
    xrange = xrange[xrange - (tshp[1] - cropsize) <= 0]
    yrange = yrange[yrange - (tshp[2] - cropsize) <= 0]

    # If the last element is not (tshp[1] - cropsize) add it!
    if xrange[-1] != (tshp[1] - cropsize):
        xrange = np.append(xrange, tshp[1] - cropsize)
    if yrange[-1] != (tshp[2] - cropsize):
        yrange = np.append(yrange, tshp[2] - cropsize)

    # Create all the relative coordinates
    mrs = list(itertools.product(xrange, yrange))

    # Run the model crop by crop and mosaic the 4x super-resolved outputs
    sr = torch.zeros(3, tshp[1] * 4, tshp[2] * 4)
    for x, y in mrs:
        crop = lr[:, x:x + cropsize, y:y + cropsize]
        sr_crop = model(crop[None])[0]
        sr[:, x * 4:(x + cropsize) * 4, y * 4:(y + cropsize) * 4] = sr_crop.cpu()
    
    # Collect the results, rescaling back to the original reflectance units
    results = {
        "lr": (lr.cpu().numpy() * 3558).astype(np.uint16),
        "sr": (sr.cpu().numpy() * 3558).astype(np.uint16),
        "hr": hr[0:3]
    }

    return results
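

# Example usage: a minimal sketch. The arrays below are random placeholders
# standing in for real Sentinel-2 rasters of shape (bands, height, width),
# and the weights file referenced in `load_satlas_sr` is assumed to exist.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_satlas_sr(device=device)

    # Placeholder LR raster: 13 Sentinel-2 bands, 128 x 128 pixels, uint16
    lr_array = np.random.randint(0, 3558, size=(13, 128, 128), dtype=np.uint16)
    # Placeholder HR reference: 3 bands at 4x the LR resolution
    hr_array = np.random.randint(0, 255, size=(3, 512, 512), dtype=np.uint8)

    results = run_satlas(
        model, lr_array, hr_array, cropsize=32, overlap=0, device=device
    )
    print({k: v.shape for k, v in results.items()})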