Cesar Aybar commited on
Commit
caa7010
1 Parent(s): ce8cb02

benchmark script

Browse files
benchmark.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import rasterio
2
+ import pathlib
3
+
4
+ from typing import Callable
5
+ from rasterio.transform import from_origin
6
+
7
+
8
+ def create_geotiff(
9
+ fn: Callable,
10
+ dataset_snippet: str,
11
+ output_path: str
12
+ ) -> pathlib.Path:
13
+ """Create all the GeoTIFFs for a specific dataset snippet
14
+
15
+ Args:
16
+ fn (Callable): A function that return a dictionary with the following keys:
17
+ - "lr": Low resolution image
18
+ - "sr": Super resolution image
19
+ - "hr": High resolution image
20
+ dataset_snippet (str): The dataset snippet to use to run the fn function.
21
+ output_path (str): The output path to save the GeoTIFFs.
22
+
23
+ Returns:
24
+ pathlib.Path: The output path where the GeoTIFFs are saved.
25
+ """
26
+ pass
27
+
28
+
29
+ def run(
30
+ model_path: str
31
+ ) -> pathlib.Path:
32
+ """Run the all metrics for a specific model.
33
+
34
+ Args:
35
+ model_path (str): The path to the model folder.
36
+
37
+ Returns:
38
+ pathlib.Path: The output path where the metrics are
39
+ saved as a pickle file.
40
+ """
41
+ pass
42
+
43
+
44
+ def plot(
45
+ model_path: str
46
+ ) -> pathlib.Path:
47
+ """Generate the plots and tables for a specific model.
48
+
49
+ Args:
50
+ model_path (str): The path to the model folder.
51
+
52
+ Returns:
53
+ pathlib.Path: The output path where the plots and tables are
54
+ saved.
55
+ """
56
+ pass
ldm_baseline/metadata.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "ldm-baseline",
3
+ "authors": ["CompVis team"],
4
+ "affiliations": ["None"],
5
+ "description": "A baseline of LDM models trained on the Open Images dataset.",
6
+ "code": "open-source",
7
+ "scale": "x4",
8
+ "url": "https://huggingface.co/CompVis/ldm-super-resolution-4x-openimages",
9
+ "license": "apache-2.0"
10
+ }
ldm_baseline/run.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import opensr_test
3
+
4
+ from ldm_baseline.utils import create_stable_diffusion_model, run_diffuser
5
+
6
+ # set the device
7
+ device = "cuda:0"
8
+
9
+ # Load the model
10
+ model = create_stable_diffusion_model(device=device)
11
+
12
+ # Load the dataset
13
+ dataset = opensr_test.load("spain_crops")
14
+ lr_dataset, hr_dataset = dataset["L2A"], dataset["HRharm"]
15
+
16
+ # Run the model
17
+ results = run_diffuser(model=model, lr=lr_dataset[5], hr=hr_dataset[5], device=device)
18
+
19
+ # Display the results
20
+ fig, ax = plt.subplots(1, 3, figsize=(10, 5))
21
+ ax[0].imshow(results["lr"].transpose(1, 2, 0) / 3000)
22
+ ax[0].set_title("LR")
23
+ ax[0].axis("off")
24
+ ax[1].imshow(results["sr"].transpose(1, 2, 0) / 3000)
25
+ ax[1].set_title("SR")
26
+ ax[1].axis("off")
27
+ ax[2].imshow(results["hr"].transpose(1, 2, 0) / 3000)
28
+ ax[2].set_title("HR")
29
+ plt.show()
30
+
31
+ # Run the experiment
32
+ #
33
+ # benchmark.create_geotiff(run_diffuser, "all", "ldm_baseline/")
34
+ # benchmark.run("all")
35
+ # benchmark.plot("all")
ldm_baseline/utils.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ from typing import Union
3
+
4
+ import numpy as np
5
+ import opensr_test
6
+ import torch
7
+ from diffusers import LDMSuperResolutionPipeline
8
+
9
+
10
+ def create_stable_diffusion_model(
11
+ device: Union[str, torch.device] = "cuda"
12
+ ) -> LDMSuperResolutionPipeline:
13
+ """Create the stable diffusion model
14
+
15
+ Returns:
16
+ LDMSuperResolutionPipeline: The model to use for
17
+ super resolution.
18
+ """
19
+ model_id = "CompVis/ldm-super-resolution-4x-openimages"
20
+ pipeline = LDMSuperResolutionPipeline.from_pretrained(model_id)
21
+ pipeline = pipeline.to(device)
22
+ return pipeline
23
+
24
+
25
+ def run_diffuser(
26
+ model: LDMSuperResolutionPipeline,
27
+ lr: torch.Tensor,
28
+ hr: torch.Tensor,
29
+ device: Union[str, torch.device] = "cuda",
30
+ ) -> dict:
31
+ """Run the model on the low resolution image
32
+
33
+ Args:
34
+ model (LDMSuperResolutionPipeline): The model to use
35
+ lr (torch.Tensor): The low resolution image
36
+ hr (torch.Tensor): The high resolution image
37
+ device (Union[str, torch.device], optional): The device
38
+ to use. Defaults to "cuda".
39
+
40
+ Returns:
41
+ dict: The results of the model
42
+ """
43
+
44
+ # move the images to the device
45
+ lr = (torch.from_numpy(lr[[3, 2, 1]]) / 2000).to(device).clamp(0, 1)
46
+
47
+ if lr.shape[1] == 121:
48
+ # add padding
49
+ lr = torch.nn.functional.pad(
50
+ lr[None], pad=(3, 4, 3, 4), mode="reflect"
51
+ ).squeeze()
52
+
53
+ # run the model
54
+ with torch.no_grad():
55
+ sr = model(lr[None], num_inference_steps=100, eta=1)
56
+ sr = torch.from_numpy(np.array(sr.images[0]) / 255).permute(2, 0, 1).float()
57
+
58
+ # remove padding
59
+ sr = sr[:, 3 * 4 : -4 * 4, 3 * 4 : -4 * 4]
60
+ lr = lr[:, 3:-4, 3:-4]
61
+ else:
62
+ # run the model
63
+ with torch.no_grad():
64
+ sr = model(lr[None], num_inference_steps=100, eta=1)
65
+ sr = torch.from_numpy(np.array(sr.images[0]) / 255).permute(2, 0, 1).float()
66
+
67
+ lr = (lr.cpu().numpy() * 2000).astype(np.uint16)
68
+ hr = ((hr[0:3] / 2000).clip(0, 1) * 2000).astype(np.uint16)
69
+ sr = (sr.cpu().numpy() * 2000).astype(np.uint16)
70
+
71
+ results = {"lr": lr, "hr": hr, "sr": sr}
72
+
73
+ return results