"""Demo: run the Satlas SR model on one opensr-test "naip" sample and plot it.

The script loads the model through ``utils.load_satlas_sr``, takes a single
item from the ``"L1C"`` entry of the dataset, passes it through
``utils.run_satlas`` and displays the input and the output side by side.
"""

import torch
import opensr_test
import matplotlib.pyplot as plt

from utils import load_satlas_sr, run_satlas

# Single source of truth for the device used by both the model and the input
# tensor. Falls back to the CPU when CUDA is not available so the demo does
# not fail on the very first line.
# NOTE(review): CPU support inside load_satlas_sr/run_satlas is assumed here,
# not verified -- confirm against utils.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Divisor applied to the raw L1C values before they are clamped to [0, 1].
# The value is kept exactly as in the original script; where it comes from is
# not visible in this file.
LR_SCALE = 3558

# --- Model -------------------------------------------------------------------
model = load_satlas_sr(device=DEVICE)

# --- Data --------------------------------------------------------------------
dataset = opensr_test.load("naip")
# hr_dataset is not used in this section; it is kept so later code (if any)
# that refers to it keeps working.
lr_dataset, hr_dataset = dataset["L1C"], dataset["HRharm"]

# --- Inference ---------------------------------------------------------------
index = 20

# Pick channels 3, 2, 1 (in that order) of the selected sample and rescale.
# Presumably these are the red/green/blue bands -- TODO confirm against the
# dataset's band layout.
lr_bands = lr_dataset[index][[3, 2, 1]] / LR_SCALE
lr = torch.from_numpy(lr_bands).float().to(DEVICE).clamp(0, 1)

sr = run_satlas(model=model, lr=lr, cropsize=32, overlap=0)

# --- Visualisation -----------------------------------------------------------
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# Tensors are channel-first; imshow wants channel-last, hence the transpose.
ax[0].imshow(lr.cpu().numpy().transpose(1, 2, 0))
ax[1].imshow(sr.cpu().numpy().transpose(1, 2, 0))
plt.show()
|