superIX / satlas /run.py
csaybar's picture
Upload 5 files
99a3901 verified
raw
history blame
No virus
631 Bytes
import torch
import opensr_test
import matplotlib.pyplot as plt
from utils import load_satlas_sr, run_satlas
# Pick the device once so the model and every tensor agree. Fall back to CPU
# when no GPU is present (assumes load_satlas_sr accepts "cpu" — TODO confirm).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model
model = load_satlas_sr(device=device)

# Load the dataset: low-resolution inputs under "L1C" and the paired
# high-resolution references under "HRharm" (the latter is unused below).
dataset = opensr_test.load("naip")
lr_dataset, hr_dataset = dataset["L1C"], dataset["HRharm"]

# Prepare one low-resolution image
index = 20
# The first axis holds bands; [3, 2, 1] presumably picks red/green/blue —
# confirm against the dataset's band order. LR_SCALE is the divisor this demo
# uses to bring raw values into roughly [0, 1] before clamping.
LR_SCALE = 3558
lr = (
    torch.from_numpy(lr_dataset[index][[3, 2, 1]] / LR_SCALE)
    .float()
    .to(device)
    .clamp(0, 1)
)

# Run the model. This is inference only, so skip autograd bookkeeping; this
# also guarantees sr can be converted with .numpy() below.
with torch.no_grad():
    sr = run_satlas(model=model, lr=lr, cropsize=32, overlap=0)

# Plot the input next to the super-resolved output. The tensors are
# (bands, H, W) while imshow expects (H, W, bands), hence the transpose.
# Clamp sr so imshow does not warn about clipping out-of-range values.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(lr.cpu().numpy().transpose(1, 2, 0))
ax[0].set_title("LR")
ax[1].imshow(sr.clamp(0, 1).cpu().numpy().transpose(1, 2, 0))
ax[1].set_title("SR")
plt.show()