"""Demo script: super-resolve one low-resolution sample with the Satlas model.

Loads the model through ``load_satlas_sr``, reads the "naip" dataset through
``opensr_test.load``, runs ``run_satlas`` on a single low-resolution sample
and shows the input and the output side by side with matplotlib.
"""
import torch
import opensr_test
import matplotlib.pyplot as plt

from utils import load_satlas_sr, run_satlas

# Use the GPU when one is available; otherwise fall back to the CPU so the
# script does not crash on machines without CUDA. The same device is used for
# both the model and the input tensor so they can never end up mismatched.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model
model = load_satlas_sr(device=device)

# Load the dataset
dataset = opensr_test.load("naip")
lr_dataset, hr_dataset = dataset["L1C"], dataset["HRharm"]

# Predict an image
index = 20
# Select channels 3, 2, 1 along the first axis (presumably the R, G, B bands
# of the L1C product -- TODO confirm), divide by 3558 (presumably a
# normalisation constant for this dataset -- TODO confirm) and clamp the
# result to [0, 1].
lr = (
    torch.from_numpy(lr_dataset[index][[3, 2, 1]] / 3558)
    .float()
    .to(device)
    .clamp(0, 1)
)
# Inference only: no autograd graph is needed, so disable gradient tracking.
with torch.no_grad():
    sr = run_satlas(model=model, lr=lr, cropsize=32, overlap=0)

# Plot the low-resolution input (left) next to the model output (right).
# transpose(1, 2, 0) moves the first (channel) axis last, the layout
# ``imshow`` expects for colour images.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(lr.cpu().numpy().transpose(1, 2, 0))
# detach() guards against ``sr`` still being attached to an autograd graph,
# in which case ``.numpy()`` would raise.
ax[1].imshow(sr.detach().cpu().numpy().transpose(1, 2, 0))
plt.show()