File size: 1,568 Bytes
c2f815f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import torch
import numpy as np
import opensr_model
from typing import Union
def create_opensr_model(
    device: Union[str, torch.device] = "cpu",
    weights_path: str = "./weights/opensr_10m_v4_v5.ckpt",
) -> "opensr_model.SRLatentDiffusion":
    """Create the OpenSR latent-diffusion super-resolution model.

    Args:
        device: Device forwarded to ``opensr_model.SRLatentDiffusion``.
        weights_path: Checkpoint passed to ``load_pretrained``. Defaults to
            the path this function previously hard-coded. A relative path is
            resolved against the current working directory.

    Returns:
        The ``SRLatentDiffusion`` instance after ``load_pretrained`` and
        ``eval()`` have been called on it.
    """
    model = opensr_model.SRLatentDiffusion(device=device)
    model.load_pretrained(weights_path)
    # Switch to inference mode (presumably torch.nn.Module.eval semantics;
    # confirm against opensr_model).
    model.eval()
    return model
def _scaled_to_uint16(img: torch.Tensor) -> np.ndarray:
    """Undo the /10000 input scaling and convert a tensor to uint16 safely.

    Values are rounded to the nearest integer rather than truncated, because
    float32 round-off can leave an integer value a hair below itself and
    truncation would then drop it by one. Values are clipped to [0, 65535] so
    negative or overly large model outputs saturate instead of wrapping
    around when cast.
    """
    arr = img.detach().cpu().numpy() * 10000
    return np.clip(np.rint(arr), 0, np.iinfo(np.uint16).max).astype(np.uint16)


def run_opensr_model(
    model: "opensr_model.SRLatentDiffusion",
    lr: np.ndarray,
    hr: np.ndarray,
    device: Union[str, torch.device] = "cpu",
) -> dict:
    """Super-resolve a low-resolution image and package LR/SR/HR outputs.

    Args:
        model: Callable super-resolution model, e.g. from
            ``create_opensr_model``. It is called on a ``(1, 4, H, W)``
            float32 tensor. Its squeezed output is expected to be
            ``(C, s*H, s*W)`` for an integer scale ``s`` (assumption; confirm
            for other models).
        lr: Channel-first low-resolution image. Bands 3, 2, 1, 7 are fed to
            the model (presumably R, G, B, NIR; confirm against the data
            source). Values are divided by 10000 before inference.
        hr: Channel-first high-resolution reference. Only ``hr[0:3]`` is
            returned, unchanged.
        device: Device the LR tensor is moved to before inference.

    Returns:
        Dict with keys ``"lr"``, ``"sr"`` and ``"hr"``. ``"lr"`` and ``"sr"``
        hold the first three channels rescaled by 10000, rounded and clipped
        to uint16. ``"hr"`` is ``hr[0:3]``.
    """
    # 121 + 3 + 4 == 128 px after padding.
    pad_before, pad_after = 3, 4

    lr_img = torch.from_numpy(lr[[3, 2, 1, 7]] / 10000).to(device).float()
    hr_img = hr[0:3]

    # Only the height is checked, as in the original. The padding is applied
    # and removed symmetrically on both spatial axes.
    padded = lr_img.shape[1] == 121
    if padded:
        # Reflect-pad 121-px inputs to 128 px (presumably a size the model
        # supports). lr_img itself stays unpadded, so it needs no crop later.
        model_in = torch.nn.functional.pad(
            lr_img[None],
            pad=(pad_before, pad_after, pad_before, pad_after),
            mode="reflect",
        )
    else:
        model_in = lr_img[None]

    with torch.no_grad():
        sr_img = model(model_in).squeeze()

    if padded:
        # Infer the upscaling factor from the shapes instead of hard-coding x4.
        scale = sr_img.shape[-1] // model_in.shape[-1]
        sr_img = sr_img[
            ...,
            pad_before * scale:-pad_after * scale,
            pad_before * scale:-pad_after * scale,
        ]

    return {
        "lr": _scaled_to_uint16(lr_img[0:3]),
        "sr": _scaled_to_uint16(sr_img[0:3]),
        "hr": hr_img,
    }