superIX / evoland /utils.py
Cesar Aybar
up
93e9c59
raw
history blame
No virus
2.08 kB
import torch
import pickle
import numpy as np
import opensr_test
import onnxruntime as ort
from typing import List, Union
def load_evoland() -> np.ndarray:
# ONNX inference session options
so = ort.SessionOptions()
so.intra_op_num_threads = 10
so.inter_op_num_threads = 10
so.use_deterministic_compute = True
# Execute on cpu only
ep_list = ["CPUExecutionProvider"]
ep_list.insert(0, "CUDAExecutionProvider")
ort_session = ort.InferenceSession(
"evoland/weights/carn_3x3x64g4sw_bootstrap.onnx",
sess_options=so,
providers=ep_list
)
ort_session.set_providers(["CPUExecutionProvider"])
ro = ort.RunOptions()
return [ort_session, ro]
def run_evoland(
model: List,
lr: np.ndarray,
hr: np.ndarray
) -> dict:
ort_session, ro = model
# Bands to use
bands = [1, 2, 3, 7, 4, 5, 6, 8, 10, 11]
lr = lr[bands]
if lr.shape[1] == 121:
# add padding
lr = torch.nn.functional.pad(
torch.from_numpy(lr[None]).float(),
pad=(3, 4, 3, 4),
mode='reflect'
).squeeze().cpu().numpy()
# run the model
sr = ort_session.run(
None,
{"input": lr[None]},
run_options=ro
)[0].squeeze()
# remove padding
sr = sr[:, 3*2:-4*2, 3*2:-4*2].astype(np.uint16)
lr = lr[:, 3:-4, 3:-4].astype(np.uint16)
else:
# run the model
sr = ort_session.run(
None,
{"input": lr[None].astype(np.float32)},
run_options=ro
)[0].squeeze()
# Use nn interpolation to go back to x2 without distortion
# during metrics calculation
if sr.shape[1] != hr.shape[1]:
sr = torch.nn.functional.interpolate(
torch.from_numpy(sr)[None].float(),
size=hr.shape[1:],
mode='nearest'
).squeeze().numpy().astype('uint16')
# Run the model
return {
"lr": lr[[2, 1, 0]],
"sr": sr[[2, 1, 0]],
"hr": hr[0:3]
}