File size: 5,230 Bytes
61e0235 caa7010 61e0235 caa7010 307a330 caa7010 61e0235 caa7010 307a330 61e0235 307a330 61e0235 307a330 61e0235 307a330 61e0235 307a330 caa7010 61e0235 caa7010 61e0235 caa7010 61e0235 caa7010 61e0235 307a330 61e0235 307a330 61e0235 caa7010 307a330 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import rasterio as rio
import pathlib
import opensr_test
import matplotlib.pyplot as plt
from typing import Callable, Union
def create_geotiff(
    model: Callable,
    fn: Callable,
    datasets: Union[str, list],
    output_path: str,
    force: bool = False,
    **kwargs
) -> None:
    """Create the GeoTIFFs for one or more dataset snippets.

    Every selected snippet is handed, one at a time, to
    ``create_geotiff_batch`` together with the remaining arguments.

    Args:
        model (Callable): The model to use to run the fn function.
        fn (Callable): A function that returns a dictionary with the
            following keys:
                - "lr": Low resolution image
                - "sr": Super resolution image
                - "hr": High resolution image
        datasets (Union[str, list]): The dataset snippets to process.
            The string "all" selects ``opensr_test.datasets``; any other
            string is treated as the name of a single snippet.
        output_path (str): The output path to save the GeoTIFFs.
        force (bool, optional): Forwarded unchanged to
            ``create_geotiff_batch``. Defaults to False.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``create_geotiff_batch``.
    """
    if isinstance(datasets, str):
        # A bare snippet name must be wrapped in a list; iterating the
        # string directly would yield its individual characters.
        datasets = opensr_test.datasets if datasets == "all" else [datasets]

    for snippet in datasets:
        create_geotiff_batch(
            model=model,
            fn=fn,
            snippet=snippet,
            output_path=output_path,
            force=force,
            **kwargs
        )
    return None
def create_geotiff_batch(
    model: Callable,
    fn: Callable,
    snippet: str,
    output_path: str,
    force: bool = False,
    **kwargs
) -> pathlib.Path:
    """Create the GeoTIFFs and preview PNGs for a single dataset snippet.

    For every sample of the snippet, ``fn`` is run, its "sr" output is
    written as a GeoTIFF and a side-by-side LR / SR / HR figure is saved
    as a PNG. Files are placed under
    ``<output_path>/results/SR/<snippet>/geotiff`` and
    ``<output_path>/results/SR/<snippet>/png``.

    Args:
        model (Callable): The model to use to run the fn function.
        fn (Callable): A function called as
            ``fn(model=..., lr=..., hr=..., **kwargs)`` that returns a
            dictionary with the following keys:
                - "lr": Low resolution image
                - "sr": Super resolution image
                - "hr": High resolution image
            Each image is indexed as (band, row, column) by this function.
        snippet (str): The dataset snippet to use to run the fn function.
        output_path (str): The output path to save the GeoTIFFs.
        force (bool, optional): Passed to ``opensr_test.load``;
            presumably forces the dataset to be downloaded again
            (confirm against the opensr_test documentation).
            Defaults to False.
        **kwargs: Extra keyword arguments forwarded unchanged to ``fn``.

    Returns:
        pathlib.Path: The output path where the GeoTIFFs are saved.
    """
    # Create folders to save results
    output_path = pathlib.Path(output_path) / "results" / "SR"
    output_path.mkdir(parents=True, exist_ok=True)
    output_path_dataset_geotiff = output_path / snippet / "geotiff"
    output_path_dataset_geotiff.mkdir(parents=True, exist_ok=True)
    output_path_dataset_png = output_path / snippet / "png"
    output_path_dataset_png.mkdir(parents=True, exist_ok=True)

    # Load the dataset
    dataset = opensr_test.load(snippet, force=force)
    lr_dataset, hr_dataset, metadata = dataset["L2A"], dataset["HRharm"], dataset["metadata"]
    total = len(lr_dataset)

    for index in range(total):
        print(f"Processing {index}/{total}")

        # Run the model
        results = fn(
            model=model,
            lr=lr_dataset[index],
            hr=hr_dataset[index],
            **kwargs
        )

        # Fetch the metadata row once instead of once per field
        row = metadata.iloc[index]
        image_name = row["hr_file"]
        crs = row["crs"]

        # "affine" is a comma-separated list of numbers.
        # NOTE(review): assumes the usual (a, b, c, d, e, f) affine order,
        # i.e. index 0 = pixel width, 2 = west edge, 4 = negative pixel
        # height, 5 = north edge -- confirm against the dataset metadata.
        transform_list = [float(x) for x in row["affine"].split(",")]
        # from_origin takes (west, north, xsize, ysize) with a positive
        # ysize, hence the sign flip on the fourth argument.
        transform_rio = rio.transform.from_origin(
            transform_list[2],
            transform_list[5],
            transform_list[0],
            transform_list[4] * -1
        )

        # Create rio dict
        meta_img = {
            "driver": "GTiff",
            "count": 3,
            "dtype": "uint16",
            "height": results["hr"].shape[1],
            "width": results["hr"].shape[2],
            "crs": crs,
            "transform": transform_rio,
            "compress": "deflate",
            "predictor": 2,
            "tiled": True
        }

        # Save the GeoTIFF
        with rio.open(output_path_dataset_geotiff / (image_name + ".tif"), "w", **meta_img) as dst:
            dst.write(results["sr"])

        # Save the PNG
        fig, ax = plt.subplots(1, 3, figsize=(15, 5))
        try:
            for axis, key, title in zip(ax, ("lr", "sr", "hr"), ("LR", "SR", "HR")):
                # Move the band axis last for imshow; 3000 is the
                # display scaling divisor applied to all three panels.
                axis.imshow(results[key].transpose(1, 2, 0) / 3000)
                axis.set_title(title)
                axis.axis("off")
            # remove whitespace around the image
            fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
            fig.savefig(output_path_dataset_png / (image_name + ".png"))
        finally:
            # Close this exact figure, even on error. Calling plt.clf()
            # after closing would implicitly create a new empty figure
            # that is never released.
            plt.close(fig)

    return output_path_dataset_geotiff
def run(model_path: str) -> pathlib.Path:
    """Compute all the metrics for a specific model.

    Placeholder: the body is not implemented yet, so calling this
    function currently does nothing and returns ``None``.

    Args:
        model_path (str): The path to the model folder.

    Returns:
        pathlib.Path: Intended to be the output path where the metrics
            are saved as a pickle file, once implemented.
    """
def plot(model_path: str) -> pathlib.Path:
    """Produce the plots and tables for a specific model.

    Placeholder: the body is not implemented yet, so calling this
    function currently does nothing and returns ``None``.

    Args:
        model_path (str): The path to the model folder.

    Returns:
        pathlib.Path: Intended to be the output path where the plots
            and tables are saved, once implemented.
    """
|