csaybar commited on
Commit
7a13af2
1 Parent(s): 8cb5e5a

Upload 3 files

Browse files
evoland/run.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import opensr_test
3
+ import matplotlib.pyplot as plt
4
+ from utils import load_evoland, run_evoland
5
+
6
+ # Load the model
7
+ model = load_evoland()
8
+
9
+ # Load the dataset
10
+ dataset = opensr_test.load("naip")
11
+ lr_dataset, hr_dataset = dataset["L2A"], dataset["HRharm"]
12
+
13
+ # Predict a image
14
+ results = run_evoland(
15
+ model=model,
16
+ lr=lr_dataset[4],
17
+ hr=hr_dataset[4]
18
+ )
19
+
20
+
21
+ # Display the results
22
+ fig, ax = plt.subplots(1, 3, figsize=(10, 5))
23
+ ax[0].imshow(results["lr"].transpose(1, 2, 0)/3000)
24
+ ax[0].set_title("LR")
25
+ ax[0].axis("off")
26
+ ax[1].imshow(results["sr"].transpose(1, 2, 0)/3000)
27
+ ax[1].set_title("SR")
28
+ ax[1].axis("off")
29
+ ax[2].imshow(results["hr"].transpose(1, 2, 0) / 3000)
30
+ ax[2].set_title("HR")
31
+ plt.show()
evoland/utils.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pickle
3
+ import numpy as np
4
+ import opensr_test
5
+ import onnxruntime as ort
6
+ from typing import List, Union
7
+
8
+ def load_evoland() -> np.ndarray:
9
+ # ONNX inference session options
10
+ so = ort.SessionOptions()
11
+ so.intra_op_num_threads = 10
12
+ so.inter_op_num_threads = 10
13
+ so.use_deterministic_compute = True
14
+
15
+ # Execute on cpu only
16
+ ep_list = ["CPUExecutionProvider"]
17
+ ep_list.insert(0, "CUDAExecutionProvider")
18
+
19
+ ort_session = ort.InferenceSession(
20
+ "weights/carn_3x3x64g4sw_bootstrap.onnx",
21
+ sess_options=so,
22
+ providers=ep_list
23
+ )
24
+ ort_session.set_providers(["CPUExecutionProvider"])
25
+ ro = ort.RunOptions()
26
+
27
+ return [ort_session, ro]
28
+
29
+
30
+ def run_evoland(
31
+ model: List,
32
+ lr: np.ndarray,
33
+ hr: np.ndarray
34
+ ) -> dict:
35
+
36
+ ort_session, ro = model
37
+
38
+ # Bands to use
39
+ bands = [1, 2, 3, 7, 4, 5, 6, 8, 10, 11]
40
+ lr = lr[bands]
41
+
42
+ if lr.shape[1] == 121:
43
+ # add padding
44
+ lr = torch.nn.functional.pad(
45
+ torch.from_numpy(lr[None]).float(),
46
+ pad=(3, 4, 3, 4),
47
+ mode='reflect'
48
+ ).squeeze().cpu().numpy()
49
+
50
+ # run the model
51
+ sr = ort_session.run(
52
+ None,
53
+ {"input": lr[None]},
54
+ run_options=ro
55
+ )[0].squeeze()
56
+
57
+ # remove padding
58
+ sr = sr[:, 3*4:-4*4, 3*4:-4*4].astype(np.uint16)
59
+ lr = lr[:, 3:-4, 3:-4].astype(np.uint16)
60
+ else:
61
+ # run the model
62
+ sr = ort_session.run(
63
+ None,
64
+ {"input": lr[None]},
65
+ run_options=ro
66
+ )[0].squeeze()
67
+
68
+ # Run the model
69
+ return {
70
+ "lr": lr[[2, 1, 0]],
71
+ "sr": sr[[2, 1, 0]],
72
+ "hr": hr[0:3]
73
+ }
evoland/weights/carn_3x3x64g4sw_bootstrap.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:361233ee1cb1977a6e2c41e3bb40eb55cea8bdfe001e945d3b74b6eecaff6516
3
+ size 10103338