File size: 1,338 Bytes
6c08128 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import torch
import numpy as np
from super_image import HanModel
from typing import Union
def create_superimage_model(
    device: Union[str, torch.device] = "cuda"
) -> HanModel:
    """Build the pretrained super image model and place it on a device.

    Args:
        device (Union[str, torch.device], optional): The device the model
            is moved to. Defaults to "cuda".

    Returns:
        HanModel: The 'eugenesiow/han' checkpoint loaded with scale=4 and
            moved to ``device``.
    """
    # Load the pretrained weights first, then move the model in a second step
    pretrained = HanModel.from_pretrained('eugenesiow/han', scale=4)
    return pretrained.to(device)
def _as_uint16(tensor: torch.Tensor, scale: float) -> np.ndarray:
    """Scale a float tensor back up and convert it to a uint16 numpy array."""
    values = tensor.detach().cpu().numpy() * scale
    # Round instead of truncating (the /scale, *scale float32 round trip can
    # land just below the original integer) and clip so that out-of-range
    # model outputs saturate instead of wrapping around in the uint16 cast.
    return np.clip(np.rint(values), 0, np.iinfo(np.uint16).max).astype(np.uint16)


def run_superimage(
    model: HanModel,
    lr: np.ndarray,
    hr: np.ndarray,
    device: Union[str, torch.device] = "cuda"
) -> dict:
    """ Run the super image model

    Args:
        model (HanModel): The super image model. It is called directly on the
            input tensor, so it must already live on ``device``.
        lr (np.ndarray): The low resolution image. Entries 3, 2 and 1 along
            the first axis (in that order) are used as the model input.
        hr (np.ndarray): The high resolution image. It is not fed to the
            model; its first three entries along the first axis are returned.
        device (Union[str, torch.device], optional): The device to run the
            model on. Defaults to "cuda".

    Returns:
        dict: The results, with the keys
            "lr": the selected model input as a uint16 array,
            "hr": ``hr[0:3]``,
            "sr": the model output as a uint16 array.
            All three have their length-1 axes squeezed out.
    """
    # NOTE(review): 2000 is presumably the value range the stored images are
    # normalised by before entering the model -- confirm against the data.
    scale = 2000

    # Convert the images to tensors. The cast to float32 happens in numpy
    # because torch.from_numpy does not accept every integer dtype (uint16
    # is not in its documented list).
    bands = np.asarray(lr)[[3, 2, 1]].astype(np.float32)
    lr_tensor = torch.from_numpy(bands).to(device) / scale

    # Run the model on a batch of one image
    with torch.no_grad():
        sr_tensor = model(lr_tensor[None])

    # Convert the tensors to numpy arrays and return the results
    return {
        "lr": _as_uint16(lr_tensor, scale).squeeze(),
        "hr": hr[0:3].squeeze(),
        "sr": _as_uint16(sr_tensor, scale).squeeze()
    }