|
import torch |
|
import numpy as np |
|
from super_image import HanModel |
|
from typing import Union |
|
|
|
def create_superimage_model(
    device: Union[str, torch.device] = "cuda",
    *,
    model_name: str = "eugenesiow/han",
    scale: int = 4,
) -> HanModel:
    """Create the pretrained HAN super-resolution model.

    Args:
        device (Union[str, torch.device], optional): Device the model is moved to.
            Defaults to "cuda".
        model_name (str, optional): Pretrained checkpoint identifier forwarded to
            ``HanModel.from_pretrained``. Defaults to "eugenesiow/han".
        scale (int, optional): Scale factor forwarded to ``HanModel.from_pretrained``
            (the upscaling factor of the checkpoint to load). Defaults to 4.

    Returns:
        HanModel: The super image model, placed on ``device`` and switched to
        evaluation mode.
    """
    model = HanModel.from_pretrained(model_name, scale=scale)
    # The model is only used for inference, so switch to eval mode explicitly
    # rather than relying on whatever mode from_pretrained leaves it in.
    # nn.Module.eval() returns the module itself.
    return model.to(device).eval()
|
|
|
|
|
def _to_uint16(values: np.ndarray) -> np.ndarray:
    """Round ``values`` to the nearest integer and saturate them into the uint16 range.

    A bare ``astype(np.uint16)`` truncates toward zero, so 1233.9999 becomes 1233.
    It also does not saturate negative or > 65535 values. Rounding and clipping
    first avoids both problems.
    """
    limits = np.iinfo(np.uint16)
    return np.clip(np.rint(values), limits.min, limits.max).astype(np.uint16)


def run_superimage(
    model: "HanModel",
    lr: np.ndarray,
    hr: np.ndarray,
    device: Union[str, torch.device] = "cuda",
    *,
    bands: tuple = (3, 2, 1),
    normalization: float = 2000.0,
) -> dict:
    """Run the super image model on selected bands of a low resolution image.

    The bands ``bands`` of ``lr`` (indexed along axis 0) are divided by
    ``normalization`` and passed to ``model`` as a batch of one. The model output
    is multiplied back by ``normalization``.

    Args:
        model (HanModel): The super image model (any callable taking a
            ``(1, C, H, W)`` float tensor works).
        lr (np.ndarray): The low resolution image, bands on axis 0.
        hr (np.ndarray): The high resolution image; its first three entries
            along axis 0 are returned unchanged. NOTE(review): this assumes ``hr``
            already stores its three target bands first, unlike ``lr``. Confirm
            against the caller.
        device (Union[str, torch.device], optional): The device to run the model
            on. Defaults to "cuda".
        bands (tuple, optional): Indices along axis 0 of ``lr`` fed to the
            model, in order. Defaults to (3, 2, 1).
        normalization (float, optional): Divisor applied to ``lr`` before
            inference and multiplier applied to the outputs. Defaults to 2000.

    Returns:
        dict: ``{"lr", "hr", "sr"}``, each squeezed. ``"lr"`` and ``"sr"`` are
        rounded and clipped into ``uint16``.
    """
    # Use a list, not a tuple: numpy reads a tuple index as one index per axis,
    # whereas a list selects several entries along axis 0.
    lr_bands = lr[list(bands)]

    # Convert to float32 on the numpy side. Some torch versions cannot
    # wrap certain integer dtypes (e.g. uint16) with torch.from_numpy.
    lr_tensor = torch.from_numpy(lr_bands.astype(np.float32)).to(device) / normalization

    with torch.no_grad():
        sr_tensor = model(lr_tensor[None])  # add a batch dimension of 1

    # Return the input bands directly instead of un-normalising the float32
    # tensor; that round trip can lose a unit when truncated.
    lr_out = _to_uint16(lr_bands)
    sr_out = _to_uint16(sr_tensor.cpu().numpy() * normalization)

    return {
        "lr": lr_out.squeeze(),
        "hr": hr[0:3].squeeze(),
        "sr": sr_out.squeeze(),
    }