Spaces:
Running
Running
import numpy as np | |
import gradio as gr | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from model import SRCNNModel, pred_SRCNN | |
from PIL import Image | |
# Static text rendered on the Gradio page (title, markdown description shown
# above the demo, and markdown article shown below it).
title = "Super Resolution with CNN"
description = """
Your low resolution image will be reconstructed to high resolution with a scale of 2 with a convolutional neural network!
CNN output on the left, bicubic interpolation output on the right.
"""
article = "Check out the original [paper](https://arxiv.org/abs/1501.00092) proposed by Dong *et al*."
# Load the trained SRCNN weights once, at import time, so every request
# handled by the Gradio callback reuses the same module-level `model`.
print("Loading SRCNN model...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SRCNNModel().to(device)
# map_location remaps the checkpoint's tensors onto `device`; without it a
# checkpoint saved from a GPU cannot be loaded on a CPU-only host, even though
# `device` above already falls back to CPU.
model.load_state_dict(torch.load('SRCNNmodel_trained.pt', map_location=device))
# Evaluation mode: the model is only used for inference here.
model.eval()
print("SRCNN model loaded!")
def image_grid(imgs, rows, cols):
    """Tile images into a single RGB grid image.

    imgs: list of PIL images, laid out row-major (left to right, then top
        to bottom). Tile offsets are computed from ``imgs[0].size``, so the
        images are expected to share that size.
    rows, cols: grid dimensions; ``len(imgs)`` must equal ``rows * cols``.

    Returns a new RGB PIL image of size ``(cols * w, rows * h)`` where
    ``(w, h) = imgs[0].size``.

    Raises ValueError if ``imgs`` is empty or its length does not match
    ``rows * cols``.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not imgs:
        raise ValueError("imgs must contain at least one image")
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected rows*cols = {rows * cols} images, got {len(imgs)}"
        )
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        # Row-major placement: row = i // cols, column = i % cols.
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid
def sepia(image_path):
    """Gradio callback: super-resolve one image and show it beside bicubic.

    Despite its name, ``image_path`` is not a file path. The ``gr.Image``
    input passes the uploaded picture as a numpy array, or None when no image
    was supplied. The function name is kept because the ``gr.Interface`` in
    this file refers to it.

    Returns a PIL image with the SRCNN output on the left and the bicubic
    output on the right, or None when there is no input image.
    """
    if image_path is None:
        # Nothing submitted: leave the output empty instead of crashing
        # inside Image.fromarray.
        return None
    # Let Pillow infer the mode from the array, then normalize to RGB. This
    # avoids fromarray's deprecated `mode=` argument and also tolerates
    # grayscale or RGBA arrays; convert() is a no-op for RGB input.
    image = Image.fromarray(image_path).convert('RGB')
    # The third value returned by pred_SRCNN is not needed for the display.
    out_final, image_bicubic, _ = pred_SRCNN(model=model, image=image, device=device)
    return image_grid([out_final, image_bicubic], 1, 2)
# Web UI: one image in, one composite image (SRCNN result | bicubic) out,
# with two bundled example inputs.
# NOTE(review): `shape=` on gr.Image is a Gradio 3.x argument (removed in
# Gradio 4) — this presumes gradio is pinned below 4; confirm in requirements.
demo = gr.Interface(
    fn=sepia,
    inputs=gr.Image(shape=(200, 200)),
    outputs="image",
    title=title,
    description=description,
    article=article,
    examples=['LR_image.png', 'barbara.png'],
)
demo.launch()