import torch
import numpy as np
import opensr_model
from typing import Union

# Integer pixel values are divided by this before inference and multiplied
# back afterwards (presumably the usual reflectance scale — confirm).
_VALUE_SCALE = 10000
# Upscaling factor assumed when cropping the padding off the SR output.
_SR_FACTOR = 4
# Reflect padding used to grow a 121-pixel tile to 128 pixels (3 + 121 + 4).
_PADDED_INPUT_SIZE = 121
_PAD_BEFORE = 3
_PAD_AFTER = 4


def create_opensr_model(
    device: Union[str, torch.device] = "cpu",
    weights_path: str = "./weights/opensr_10m_v4_v5.ckpt",
) -> "opensr_model.SRLatentDiffusion":
    """Create the OpenSR latent-diffusion super-resolution model.

    Args:
        device: Device passed to ``opensr_model.SRLatentDiffusion``.
        weights_path: Checkpoint loaded with ``load_pretrained``. The default
            is a relative path, resolved against the current working directory.

    Returns:
        The ``opensr_model.SRLatentDiffusion`` instance, switched to eval mode.
    """
    model = opensr_model.SRLatentDiffusion(device=device)
    model.load_pretrained(weights_path)
    model.eval()
    return model


def _to_uint16(img: torch.Tensor) -> np.ndarray:
    """Convert the first three bands of a float tensor back to uint16 values.

    Values are rescaled by ``_VALUE_SCALE``. They are rounded rather than
    truncated, so float32 round-off cannot shave one unit off an exact value.
    They are clipped to the uint16 range, so negative or oversized model
    outputs saturate instead of wrapping around.
    """
    scaled = img.detach().cpu().numpy()[0:3] * _VALUE_SCALE
    return np.clip(np.rint(scaled), 0, np.iinfo(np.uint16).max).astype(np.uint16)


def run_opensr_model(
    model: "opensr_model.SRLatentDiffusion",
    lr: np.ndarray,
    hr: np.ndarray,
    device: Union[str, torch.device] = "cpu",
) -> dict:
    """Super-resolve a low-resolution image with ``model``.

    Args:
        model: Model returned by ``create_opensr_model``. It is assumed to
            already live on ``device``; verify this against the caller.
        lr: Low-resolution band stack, indexed along axis 0. Bands
            ``[3, 2, 1, 7]`` are fed to the model (presumably R, G, B, NIR;
            confirm against the data source).
        hr: High-resolution reference. Only ``hr[0:3]`` is returned,
            otherwise untouched.
        device: Device the input tensor is moved to.

    Returns:
        A dict with keys ``"lr"``, ``"sr"`` and ``"hr"``. ``"lr"`` and ``"sr"``
        hold the first three selected bands as uint16 arrays; ``"hr"`` is
        ``hr[0:3]``.
    """
    # Select the model's input bands and move from integer to float scale.
    lr_img = torch.from_numpy(lr[[3, 2, 1, 7]] / _VALUE_SCALE).to(device).float()
    hr_img = hr[0:3]

    # 121-pixel inputs are reflect-padded up to 128 before inference, then
    # cropped afterwards. As in the original, only the height (dim 1) is checked.
    needs_padding = lr_img.shape[1] == _PADDED_INPUT_SIZE
    pad = (_PAD_BEFORE, _PAD_AFTER, _PAD_BEFORE, _PAD_AFTER)
    if needs_padding:
        lr_img = torch.nn.functional.pad(lr_img[None], pad=pad, mode="reflect")[0]

    with torch.no_grad():
        sr_img = model(lr_img[None]).squeeze()

    if needs_padding:
        # Remove the padding. On the SR output it is _SR_FACTOR times larger.
        lr_img = lr_img[:, _PAD_BEFORE:-_PAD_AFTER, _PAD_BEFORE:-_PAD_AFTER]
        sr_before = _PAD_BEFORE * _SR_FACTOR
        sr_after = _PAD_AFTER * _SR_FACTOR
        sr_img = sr_img[:, sr_before:-sr_after, sr_before:-sr_after]

    return {
        "lr": _to_uint16(lr_img),
        "sr": _to_uint16(sr_img),
        "hr": hr_img,
    }