|
from diffusers import LDMSuperResolutionPipeline |
|
import numpy as np |
|
import opensr_test |
|
import torch |
|
import pickle |
|
from typing import Union |
|
|
|
|
|
def create_stable_diffusion_model(
    device: Union[str, torch.device] = "cuda"
) -> LDMSuperResolutionPipeline:
    """Load the pretrained LDM 4x super-resolution pipeline.

    The checkpoint "CompVis/ldm-super-resolution-4x-openimages" is fetched
    with ``LDMSuperResolutionPipeline.from_pretrained`` and then moved to
    ``device``.

    Args:
        device: Target device for the pipeline (e.g. "cuda" or "cpu").
            Defaults to "cuda".

    Returns:
        LDMSuperResolutionPipeline: The model to use for super resolution,
        already placed on ``device``.
    """
    checkpoint_name = "CompVis/ldm-super-resolution-4x-openimages"
    loaded = LDMSuperResolutionPipeline.from_pretrained(checkpoint_name)
    return loaded.to(device)
|
|
|
def _ldm_upscale(
    model: "LDMSuperResolutionPipeline",
    image: torch.Tensor,
    num_inference_steps: int = 100,
    eta: float = 1,
) -> torch.Tensor:
    """Super-resolve one channel-first image with an LDM pipeline.

    Args:
        model: The super-resolution pipeline to call.
        image: A (C, H, W) tensor with values in [0, 1].
        num_inference_steps: Number of denoising steps passed to the pipeline.
        eta: ``eta`` value passed to the pipeline.

    Returns:
        torch.Tensor: The first generated image as a channel-first float32
        CPU tensor divided by 255 (i.e. in [0, 1] for 8-bit output images).
    """
    # Tensor inputs bypass the pipeline's own preprocessing, which maps
    # PIL images from [0, 1] to [-1, 1]; apply the same mapping here so the
    # model sees the value range it expects.
    model_input = image[None] * 2.0 - 1.0
    with torch.no_grad():
        output = model(
            model_input, num_inference_steps=num_inference_steps, eta=eta
        )
    pixels = np.asarray(output.images[0]) / 255
    return torch.from_numpy(pixels).permute(2, 0, 1).float()


def _unit_to_uint16(values: np.ndarray, scale: float) -> np.ndarray:
    """Map [0, 1] values onto [0, scale] and round to the nearest uint16.

    Rounding (instead of truncating) avoids losing one unit when a value
    such as ``x / scale * scale`` lands just below the integer ``x``.
    """
    return np.rint(np.clip(values, 0, 1) * scale).astype(np.uint16)


def run_diffuser(
    model: "LDMSuperResolutionPipeline",
    lr: np.ndarray,
    hr: np.ndarray,
    device: Union[str, torch.device] = "cuda",
    num_inference_steps: int = 100,
    eta: float = 1,
) -> dict:
    """Super-resolve the RGB bands of a low resolution image.

    The bands at indices 3, 2, 1 of ``lr`` (in that order) become the three
    model channels. Presumably these are red, green and blue of a
    Sentinel-2 stack; confirm against the dataset. Values are divided by 2000
    and clipped to [0, 1] before inference. Every returned array is expressed
    in the same 0-2000 units as ``uint16``.

    Inputs whose second axis has 121 pixels are reflect-padded to 128
    (3 pixels before, 4 after on both spatial axes) before inference. The
    super-resolved result is then cropped by four times that padding, which
    assumes a 4x upscaling pipeline.

    Args:
        model (LDMSuperResolutionPipeline): The model to use.
        lr (np.ndarray): The low resolution image, channel-first
            (bands, H, W) with at least 4 bands.
        hr (np.ndarray): The high resolution image, channel-first. Only its
            first three bands are kept (assumed to be RGB; confirm).
        device (Union[str, torch.device], optional): The device to run on.
            Defaults to "cuda".
        num_inference_steps (int, optional): Denoising steps passed to the
            pipeline. Defaults to 100.
        eta (float, optional): ``eta`` passed to the pipeline. Defaults to 1.

    Returns:
        dict: ``{"lr": ..., "hr": ..., "sr": ...}``, each a ``uint16``
        array of 3 channel-first bands clipped to [0, 2000].
    """
    reflectance_max = 2000  # value mapped to 1.0; larger values saturate
    scale_factor = 4  # output/input size ratio assumed by the crop below
    pad_before, pad_after = 3, 4

    # Casting to float32 first also accepts uint16 rasters, which some
    # torch versions cannot wrap directly with ``torch.from_numpy``.
    rgb = np.asarray(lr[[3, 2, 1]], dtype=np.float32)
    lr_unit = (torch.from_numpy(rgb) / reflectance_max).to(device).clamp(0, 1)

    if lr_unit.shape[1] == 121:
        # Reflect-pad 121 -> 128 (presumably so the size suits the model,
        # e.g. a multiple of 32; confirm), then crop the padding back out
        # of the super-resolved output.
        padded = torch.nn.functional.pad(
            lr_unit[None],
            pad=(pad_before, pad_after, pad_before, pad_after),
            mode="reflect",
        )[0]
        sr = _ldm_upscale(model, padded, num_inference_steps, eta)
        top = pad_before * scale_factor
        bottom = -pad_after * scale_factor
        sr = sr[:, top:bottom, top:bottom]
    else:
        sr = _ldm_upscale(model, lr_unit, num_inference_steps, eta)

    # Clip hr directly in its native units instead of a lossy /2000 * 2000
    # float round trip.
    hr_rgb = np.clip(np.asarray(hr[0:3], dtype=np.float64), 0, reflectance_max)

    results = {
        "lr": _unit_to_uint16(lr_unit.cpu().numpy(), reflectance_max),
        "hr": np.rint(hr_rgb).astype(np.uint16),
        "sr": _unit_to_uint16(sr.cpu().numpy(), reflectance_max),
    }
    return results