import rasterio as rio
import pathlib
import opensr_test
import matplotlib.pyplot as plt
from typing import Callable, Union


def create_geotiff(
    model: Callable,
    fn: Callable,
    datasets: Union[str, list],
    output_path: str,
    force: bool = False,
    **kwargs
) -> None:
    """Create the GeoTIFFs (and PNG previews) for one or more dataset snippets.

    Args:
        model (Callable): The model to use to run the fn function.
        fn (Callable): A function that returns a dictionary with the following keys:
            - "lr": Low resolution image
            - "sr": Super resolution image
            - "hr": High resolution image
        datasets (Union[str, list]): A list of dataset snippets, a single
            snippet name, or "all" to use every snippet listed in
            ``opensr_test.datasets``.
        output_path (str): The output path to save the GeoTIFFs.
        force (bool, optional): If True, the dataset is redownloaded.
            Defaults to False.
        **kwargs: Extra keyword arguments forwarded to ``fn``.
    """
    if datasets == "all":
        datasets = opensr_test.datasets
    elif isinstance(datasets, str):
        # A bare snippet name would otherwise be iterated character by character.
        datasets = [datasets]

    for snippet in datasets:
        create_geotiff_batch(
            model=model,
            fn=fn,
            snippet=snippet,
            output_path=output_path,
            force=force,
            **kwargs
        )

    return None


def _transform_from_affine_string(affine: str):
    """Build a rasterio transform from a comma-separated affine string.

    Elements 0, 2, 4 and 5 of the split string are used as pixel width,
    x origin, (negated) pixel height and y origin respectively.
    NOTE(review): this assumes the string follows the ``a, b, c, d, e, f``
    ordering of an ``affine.Affine`` (no rotation terms) — confirm against
    the metadata producer.
    """
    values = [float(x) for x in affine.split(",")]
    return rio.transform.from_origin(values[2], values[5], values[0], values[4] * -1)


def _to_display(image):
    """Move bands last, scale by 1/3000 and clip to [0, 1] for ``imshow``.

    Assumes a band-first numpy-like array. Clipping matches what ``imshow``
    does for float RGB data anyway, but avoids its per-image clipping warning.
    """
    return (image.transpose(1, 2, 0) / 3000).clip(0, 1)


def _save_preview_png(results: dict, filename) -> None:
    """Save a side-by-side LR / SR / HR preview of ``results`` to ``filename``."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, key, title in zip(axes, ("lr", "sr", "hr"), ("LR", "SR", "HR")):
        ax.imshow(_to_display(results[key]))
        ax.set_title(title)
        ax.axis("off")

    # remove whitespace around the image
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
    fig.savefig(filename)

    # Close only this figure. The previous plt.close(); plt.clf() sequence made
    # clf() create (and leak) a new empty figure on every iteration.
    plt.close(fig)


def create_geotiff_batch(
    model: Callable,
    fn: Callable,
    snippet: str,
    output_path: str,
    force: bool = False,
    **kwargs
) -> pathlib.Path:
    """Create all the GeoTIFFs for a specific dataset snippet.

    For every sample in the snippet, ``fn`` is run and its "sr" output is
    written as a GeoTIFF under ``<output_path>/results/SR/<snippet>/geotiff``.
    The CRS and transform come from the sample's metadata row. A LR/SR/HR
    preview PNG is written under ``.../<snippet>/png``.

    Args:
        model (Callable): The model to use to run the fn function.
        fn (Callable): A function that returns a dictionary with the following keys:
            - "lr": Low resolution image
            - "sr": Super resolution image
            - "hr": High resolution image
        snippet (str): The dataset snippet to use to run the fn function.
        output_path (str): The output path to save the GeoTIFFs.
        force (bool, optional): If True, the dataset is redownloaded.
            Defaults to False.
        **kwargs: Extra keyword arguments forwarded to ``fn``.

    Returns:
        pathlib.Path: The output path where the GeoTIFFs are saved.
    """
    # Create folders to save results
    output_path = pathlib.Path(output_path) / "results" / "SR"
    output_path_dataset_geotiff = output_path / snippet / "geotiff"
    output_path_dataset_png = output_path / snippet / "png"
    for folder in (output_path_dataset_geotiff, output_path_dataset_png):
        folder.mkdir(parents=True, exist_ok=True)

    # Load the dataset
    dataset = opensr_test.load(snippet, force=force)
    lr_dataset = dataset["L2A"]
    hr_dataset = dataset["HRharm"]
    metadata = dataset["metadata"]

    total = len(lr_dataset)
    for index in range(total):
        print(f"Processing {index + 1}/{total}")

        # Run the model
        results = fn(
            model=model,
            lr=lr_dataset[index],
            hr=hr_dataset[index],
            **kwargs
        )

        row = metadata.iloc[index]
        image_name = row["hr_file"]
        sr = results["sr"]

        # Derive the raster profile from the array that is actually written (SR),
        # not from HR. The georeferencing still comes from the HR metadata, so SR
        # is expected to share the HR grid.
        # NOTE(review): dtype is fixed to uint16 — confirm fn returns integer SR values.
        meta_img = {
            "driver": "GTiff",
            "count": sr.shape[0],
            "dtype": "uint16",
            "height": sr.shape[1],
            "width": sr.shape[2],
            "crs": row["crs"],
            "transform": _transform_from_affine_string(row["affine"]),
            "compress": "deflate",
            "predictor": 2,
            "tiled": True
        }

        # Save the GeoTIFF
        geotiff_file = output_path_dataset_geotiff / (image_name + ".tif")
        with rio.open(geotiff_file, "w", **meta_img) as dst:
            dst.write(sr)

        # Save the PNG
        _save_preview_png(results, output_path_dataset_png / (image_name + ".png"))

    return output_path_dataset_geotiff


def run(
    model_path: str
) -> pathlib.Path:
    """Run all the metrics for a specific model.

    Not implemented yet: this placeholder currently returns None.

    Args:
        model_path (str): The path to the model folder.

    Returns:
        pathlib.Path: The output path where the metrics are saved as a pickle file.
    """
    pass


def plot(
    model_path: str
) -> pathlib.Path:
    """Generate the plots and tables for a specific model.

    Not implemented yet: this placeholder currently returns None.

    Args:
        model_path (str): The path to the model folder.

    Returns:
        pathlib.Path: The output path where the plots and tables are saved.
    """
    pass