# DarkIR / app.py
# danifei's picture
# Update app.py
# e84698c verified
# raw
# history blame
# 2.53 kB
import gradio as gr
from PIL import Image
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
from archs import DarkIR
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Converters between PIL images and tensors, used around the model call.
pil_to_tensor = transforms.ToTensor()
tensor_to_pil = transforms.ToPILImage()

network = 'DarkIR'

PATH_MODEL = './darkir_v2real+lsrw.pt'

# Architecture settings; they have to describe the same network the
# checkpoint at PATH_MODEL was saved from, or load_state_dict will fail.
model = DarkIR(
    img_channel=3,
    width=32,
    middle_blk_num_enc=2,
    middle_blk_num_dec=2,
    enc_blk_nums=[1, 2, 3],
    dec_blk_nums=[3, 1, 1],
    dilations=[1, 4, 9],
    extra_depth_wise=True,
)

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
# NOTE(review): depending on the PyTorch version torch.load may unpickle
# arbitrary objects, so only point PATH_MODEL at a trusted file.
checkpoints = torch.load(PATH_MODEL, map_location=device)
# The weights are stored under the 'params' key of the checkpoint.
model.load_state_dict(checkpoints['params'])
model = model.to(device)
# The model is only used for inference here: switch it to evaluation mode so
# that train-time-only layers (dropout, batch-norm statistics updates), if
# the architecture has any, behave deterministically.
model.eval()
def pad_tensor(tensor, multiple=8):
    """Zero-pad the last two dimensions of ``tensor`` to a multiple of ``multiple``.

    Padding is added only on the right and bottom edges, so the original
    content stays in the top-left corner and can be recovered by cropping to
    the original height and width.

    Args:
        tensor: Tensor whose last two dimensions are height and width,
            e.g. ``(N, C, H, W)`` or ``(C, H, W)``.
        multiple: Positive integer that the padded height and width must be
            divisible by.

    Returns:
        The padded tensor. Its leading dimensions are unchanged.
    """
    height, width = tensor.shape[-2:]
    # The outer modulo keeps the padding at 0 when a side is already aligned.
    pad_h = (multiple - height % multiple) % multiple
    pad_w = (multiple - width % multiple) % multiple
    # F.pad expects (left, right, top, bottom) for the last two dimensions.
    return F.pad(tensor, (0, pad_w, 0, pad_h), value=0)
def process_img(image):
    """Run the DarkIR model on one image and return the restored image.

    Args:
        image: Input image (a PIL image when called from the Gradio
            interface), or ``None`` when no image was provided.

    Returns:
        The model output as a PIL image, cropped back to the input height
        and width, or ``None`` when ``image`` is ``None``.
    """
    # Nothing to restore, e.g. the form was submitted without an image.
    if image is None:
        return None

    # The network is built with img_channel=3, so grayscale / RGBA / palette
    # PIL images are converted to three-channel RGB first.
    if isinstance(image, Image.Image) and image.mode != 'RGB':
        image = image.convert('RGB')

    # Add a batch dimension and move the data to the model's device.
    tensor = pil_to_tensor(image).unsqueeze(0).to(device)
    height, width = tensor.shape[-2:]

    # Pad height and width to pad_tensor's default multiple before the
    # forward pass; the padding is cropped off again below.
    tensor = pad_tensor(tensor)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model(tensor, side_loss=False)

    # Keep values in the [0, 1] range before converting back to an image.
    output = torch.clamp(output, 0., 1.)
    # Drop the padded rows/columns and the batch dimension.
    output = output[:, :, :height, :width].squeeze(0)
    return tensor_to_pil(output)
# Heading text for the demo page.
title = "DarkIR ✏️🖼️ 🤗"

# Markdown text used as the interface description; one literal per text line.
description = (
    " ## [ DarkIR: Robust Low-Light Image Restoration](https://github.com/cidautai/DarkIR)\n"
    "[Daniel Feijoo](https://github.com/danifei)\n"
    "Fundación Cidaut\n"
    "> **Disclaimer:** please remember this is not a product, thus, you will notice some limitations.\n"
    "**This demo expects an image with some Low-Light degradations.**\n"
    "<br>\n"
)

# One single-element row per example image, all located under examples/.
examples = [
    [f'examples/{file_name}']
    for file_name in (
        '0010.png',
        'r13073518t_low.png',
        'low00733_low.png',
        '0087.png',
    )
]

# Custom CSS: size images in the image frame/container automatically and
# remove the max-width cap.
css = "\n".join([
    "",
    ".image-frame img, .image-container img {",
    "width: auto;",
    "height: auto;",
    "max-width: none;",
    "}",
    "",
])
# Gradio interface: one PIL image in, one PIL image out, processed by
# process_img. The input component is created before the output component.
demo = gr.Interface(
    process_img,
    inputs=[gr.Image(label='input', type='pil')],
    outputs=[gr.Image(label='output', type='pil')],
    css=css,
    examples=examples,
    description=description,
    title=title,
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()