csaybar commited on
Commit
6c08128
1 Parent(s): 9c55c41

Upload 2 files

Browse files
Files changed (2) hide show
  1. superimage/run.py +31 -0
  2. superimage/utils.py +50 -0
superimage/run.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import opensr_test
2
+ import matplotlib.pyplot as plt
3
+ from utils import create_superimage_model, run_superimage
4
+
5
+
6
+ # Load the model
7
+ model = create_superimage_model(device="cuda")
8
+
9
+ # Load the dataset
10
+ dataset = opensr_test.load("naip")
11
+ lr_dataset, hr_dataset = dataset["L2A"], dataset["HRharm"]
12
+
13
+ # Run the model
14
+ results = run_superimage(
15
+ model=model,
16
+ lr=lr_dataset[7][:,0:64, 0:64],
17
+ hr=hr_dataset[7][:,0:256, 0:256],
18
+ device="cuda"
19
+ )
20
+
21
+ # Display the results
22
+ fig, ax = plt.subplots(1, 3, figsize=(10, 5))
23
+ ax[0].imshow(results["lr"].transpose(1, 2, 0)/3000)
24
+ ax[0].set_title("LR")
25
+ ax[0].axis("off")
26
+ ax[1].imshow(results["sr"].transpose(1, 2, 0)/3000)
27
+ ax[1].set_title("SR")
28
+ ax[1].axis("off")
29
+ ax[2].imshow(results["hr"].transpose(1, 2, 0) / 3000)
30
+ ax[2].set_title("HR")
31
+ plt.show()
superimage/utils.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from super_image import HanModel
4
+ from typing import Union
5
+
6
+ def create_superimage_model(
7
+ device: Union[str, torch.device] = "cuda"
8
+ ) -> HanModel:
9
+ """ Create the super image model
10
+
11
+ Returns:
12
+ HanModel: The super image model
13
+ """
14
+ return HanModel.from_pretrained('eugenesiow/han', scale=4).to(device)
15
+
16
+
17
+ def run_superimage(
18
+ model: HanModel,
19
+ lr: np.ndarray,
20
+ hr: np.ndarray,
21
+ device: Union[str, torch.device] = "cuda"
22
+ ):
23
+ """ Run the super image model
24
+
25
+ Args:
26
+ model (HanModel): The super image model
27
+ lr (np.ndarray): The low resolution image
28
+ hr (np.ndarray): The high resolution image
29
+ device (Union[str, torch.device], optional): The device to run the model on. Defaults to "cuda".
30
+
31
+ Returns:
32
+ dict: The results
33
+ """
34
+ # Convert the images to tensors
35
+ lr_tensor = (torch.from_numpy(lr[[3, 2, 1]]).to(device) / 2000).float()
36
+
37
+ # Run the model
38
+ with torch.no_grad():
39
+ sr_tensor = model(lr_tensor[None])
40
+
41
+ # Convert the tensors to numpy arrays
42
+ lr = (lr_tensor.cpu().numpy() * 2000).astype(np.uint16)
43
+ sr = (sr_tensor.cpu().numpy() * 2000).astype(np.uint16)
44
+
45
+ # Return the results
46
+ return {
47
+ "lr": lr.squeeze(),
48
+ "hr": hr[0:3].squeeze(),
49
+ "sr": sr.squeeze()
50
+ }