|
import torch |
|
import numpy as np |
|
import opensr_model |
|
from typing import Union |
|
|
|
def create_opensr_model(
    device: Union[str, torch.device] = "cpu",
    weights_path: str = "./weights/opensr_10m_v4_v5.ckpt",
) -> "opensr_model.SRLatentDiffusion":
    """Create the OpenSR latent-diffusion super-resolution model.

    Args:
        device: Device forwarded to the ``SRLatentDiffusion`` constructor.
        weights_path: Checkpoint handed to ``load_pretrained``. The default
            is the path that used to be hard-coded here; being relative, it
            is resolved against the current working directory.

    Returns:
        The ``opensr_model.SRLatentDiffusion`` instance with the pretrained
        weights loaded and switched to evaluation mode.
    """
    model = opensr_model.SRLatentDiffusion(device=device)
    model.load_pretrained(weights_path)
    # Inference only: put the model in evaluation mode before handing it out.
    model.eval()
    return model
|
|
|
|
|
def _scaled_to_uint16(img: torch.Tensor, scale: int) -> np.ndarray:
    """Scale the first three channels of ``img`` and convert them to uint16."""
    scaled = img.detach().cpu().numpy()[0:3] * scale
    # Round instead of truncating (the float32 round trip can land just below
    # an integer) and clip so negative / too-large values cannot wrap around
    # in the unsigned cast.
    limit = np.iinfo(np.uint16).max
    return np.clip(np.rint(scaled), 0, limit).astype(np.uint16)


def run_opensr_model(
    model: "opensr_model.SRLatentDiffusion",
    lr: np.ndarray,
    hr: np.ndarray,
    device: Union[str, torch.device] = "cpu"
) -> dict:
    """Super-resolve one low-resolution image with the given model.

    Args:
        model: Callable applied to a batched ``(1, 4, H, W)`` float tensor;
            the cropping below assumes it upsamples by a factor of 4.
        lr: Low-resolution array indexed as ``(band, row, col)``; bands
            3, 2, 1 and 7 are used. NOTE(review): presumably R, G, B and NIR
            of a multispectral stack stored as reflectance * 10000 - confirm
            against the caller.
        hr: High-resolution reference array; only its first three channels
            are returned, untouched.
        device: Device the input tensor is moved to before inference.

    Returns:
        A dict with keys ``"lr"``, ``"sr"`` and ``"hr"``. ``"lr"`` and
        ``"sr"`` hold the first three channels of the model input and output,
        rescaled by 10000, rounded, clipped to the uint16 range and cast to
        ``np.uint16``. ``"hr"`` is ``hr[0:3]``.
    """
    scale = 10000
    upscale = 4
    # 121 + 3 + 4 = 128: a 121-pixel tile is reflect-padded to 128 before
    # inference and cropped back afterwards.
    pad_lo, pad_hi = 3, 4

    lr_img = torch.from_numpy(lr[[3, 2, 1, 7]] / scale).to(device).float()
    hr_img = hr[0:3]

    needs_padding = lr_img.shape[1] == 121
    if needs_padding:
        lr_img = torch.nn.functional.pad(
            lr_img[None],
            pad=(pad_lo, pad_hi, pad_lo, pad_hi),
            mode="reflect",
        ).squeeze(0)

    with torch.no_grad():
        # squeeze(0) removes only the batch axis, never a spatial/channel one.
        sr_img = model(lr_img[None]).squeeze(0)

    if needs_padding:
        # Undo the padding, in LR pixels and in upsampled SR pixels.
        lr_img = lr_img[:, pad_lo:-pad_hi, pad_lo:-pad_hi]
        sr_img = sr_img[
            :,
            pad_lo * upscale:-pad_hi * upscale,
            pad_lo * upscale:-pad_hi * upscale,
        ]

    return {
        "lr": _scaled_to_uint16(lr_img, scale),
        "sr": _scaled_to_uint16(sr_img, scale),
        "hr": hr_img,
    }