File size: 1,716 Bytes
7a13af2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import torch
import pickle
import numpy as np
import opensr_test
import onnxruntime as ort
from typing import List, Union
def load_evoland(
    weights_path: str = "weights/carn_3x3x64g4sw_bootstrap.onnx",
    num_threads: int = 10,
) -> List:
    """Create a CPU-only ONNX Runtime session for the EVOLAND model.

    Args:
        weights_path: Path to the ``.onnx`` weights file, resolved relative
            to the current working directory when not absolute.
        num_threads: Thread count used for both the intra-op and inter-op
            thread pools.

    Returns:
        ``[ort_session, run_options]`` — the two-element list that
        ``run_evoland`` unpacks as its ``model`` argument.
    """
    # ONNX inference session options; deterministic compute is requested so
    # repeated runs on the same input give the same output.
    so = ort.SessionOptions()
    so.intra_op_num_threads = num_threads
    so.inter_op_num_threads = num_threads
    so.use_deterministic_compute = True

    # Execute on CPU only. Only the CPU provider is requested: listing CUDA
    # first and then resetting the providers to CPU would gain nothing and
    # may trigger CUDA initialisation or an "unavailable provider" warning.
    ort_session = ort.InferenceSession(
        weights_path,
        sess_options=so,
        providers=["CPUExecutionProvider"],
    )
    ro = ort.RunOptions()
    return [ort_session, ro]
def run_evoland(
    model: List,
    lr: np.ndarray,
    hr: np.ndarray
) -> dict:
    """Run the EVOLAND ONNX model on one low-resolution patch.

    Args:
        model: ``[ort_session, run_options]`` pair, as returned by
            ``load_evoland``.
        lr: Low-resolution image indexed band-first, i.e. ``(bands, H, W)``.
            It needs at least 12 bands, because bands
            ``[1, 2, 3, 7, 4, 5, 6, 8, 10, 11]`` are selected as model input.
        hr: High-resolution reference; only ``hr[0:3]`` is returned.

    Returns:
        ``{"lr": ..., "sr": ..., "hr": ...}`` with three bands each. ``lr``
        and ``sr`` keep the selected-band indices ``[2, 1, 0]`` (original
        bands 3, 2, 1) — presumably to put them in RGB order; confirm.

        NOTE(review): when ``H == 121`` the input is reflect-padded to 128
        pixels, and ``lr``/``sr`` are returned as ``uint16`` (``sr`` is
        clipped to the uint16 range first). Otherwise ``lr`` keeps its input
        dtype and ``sr`` keeps the model's output dtype.
    """
    ort_session, ro = model

    # Bands fed to the model, in the order it expects.
    bands = [1, 2, 3, 7, 4, 5, 6, 8, 10, 11]
    lr = lr[bands]

    # 121-pixel patches are reflect-padded by (3, 4) to 128 pixels; the crop
    # after inference assumes the model upsamples by a factor of 4.
    pad_before, pad_after, scale = 3, 4, 4
    needs_padding = lr.shape[1] == 121

    # Feed float32 on both paths (the padded path always did). Passing the
    # raw input dtype (e.g. uint16) straight to ONNX Runtime can be rejected
    # when it does not match the model's declared input type.
    model_input = lr.astype(np.float32)
    if needs_padding:
        # Same result as torch reflect padding: mirror without repeating
        # the edge pixel.
        model_input = np.pad(
            model_input,
            ((0, 0), (pad_before, pad_after), (pad_before, pad_after)),
            mode="reflect",
        )

    # Run the model on a batch of one.
    sr = ort_session.run(
        None,
        {"input": model_input[None]},
        run_options=ro
    )[0].squeeze()

    if needs_padding:
        # Remove the (upsampled) padding border from the prediction.
        top, bottom = pad_before * scale, pad_after * scale
        sr = sr[:, top:-bottom, top:-bottom]
        # Clip before casting: out-of-range floats (e.g. small negative
        # predictions) would otherwise produce platform-dependent,
        # typically wrapped-around, values in the uint16 cast.
        sr = np.clip(sr, 0, np.iinfo(np.uint16).max).astype(np.uint16)
        lr = lr.astype(np.uint16)

    # Assemble the three-band output.
    return {
        "lr": lr[[2, 1, 0]],
        "sr": sr[[2, 1, 0]],
        "hr": hr[0:3]
    }