File size: 1,674 Bytes
c7e6ba9
aa2d9c4
c7e6ba9
 
f518bf0
aa2d9c4
629756b
c7e6ba9
 
49012c2
f518bf0
 
c7e6ba9
629756b
 
0e065e1
c7e6ba9
048e7e4
c7e6ba9
 
 
 
 
 
 
0e065e1
 
f518bf0
0e065e1
048e7e4
629756b
0e065e1
aa2d9c4
30e3fae
46dda0b
30e3fae
0b2d163
 
30e3fae
 
082bc01
 
30e3fae
46dda0b
30e3fae
 
0b2d163
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import gradio as gr
import numpy as np
import torch.nn.functional as F
from skimage import img_as_ubyte

from Allweather.util import load_img, save_img
from basicsr.models.archs.histoformer_arch import Histoformer

# Hugging Face Hub repository holding the real-world Histoformer weights.
_HF_REPO_ID = "sunsean/Histoformer-real"

# Build the model once at import time and put it in evaluation mode so every
# request reuses the same instance.
model_restoration = Histoformer.from_pretrained(_HF_REPO_ID)
model_restoration.eval()

# `predict` pads image height/width up to a multiple of this value.
factor = 8
def predict(input_img):
    """Restore an image with the module-level Histoformer model.

    Args:
        input_img: Path of the uploaded image (the Gradio input component is
            declared with ``type="filepath"``), or ``None`` when the user
            submits without uploading anything.

    Returns:
        A ``(output_path, error_message)`` tuple, one value per Gradio output
        component: on success the path of the saved restored PNG and an empty
        string; on failure ``None`` and a human-readable error message.
    """
    import logging

    if input_img is None:
        return None, "Please upload an image."

    try:
        img = np.float32(load_img(input_img)) / 255.
        img = torch.from_numpy(img).permute(2, 0, 1)
        input_ = img.unsqueeze(0)

        # Pad H and W up to the next multiple of `factor`; the padding is
        # cropped off the output below.
        h, w = input_.shape[2], input_.shape[3]
        padh = (factor - h % factor) % factor
        padw = (factor - w % factor) % factor
        # 'reflect' padding requires each pad to be smaller than the padded
        # dimension, so fall back to 'replicate' for very small images.
        mode = 'reflect' if padh < h and padw < w else 'replicate'
        input_ = F.pad(input_, (0, padw, 0, padh), mode)

        # Pure inference: disable autograd so no graph is kept in memory.
        with torch.no_grad():
            restored = model_restoration(input_)

        restored = restored[:, :, :h, :w]
        restored = torch.clamp(restored, 0, 1).permute(0, 2, 3, 1).squeeze(0).cpu().numpy()

        output_path = "restored.png"
        save_img(output_path, img_as_ubyte(restored))
    except Exception as err:
        # UI boundary: log the traceback and surface the message in the
        # "Error Message" output instead of failing the request.
        logging.getLogger(__name__).exception("Restoration failed for %s", input_img)
        return None, f"Restoration failed: {err}"

    return output_path, ""

# Example inputs offered in the UI (paths relative to the working directory).
example_images = [
    "examples/example.jpg",
]

# One input (an image file path) and two outputs; Gradio maps the tuple
# returned by `predict` onto these outputs in order.
gradio_app = gr.Interface(
    predict,
    inputs=gr.Image(label="Upload images with adverse weather degradations", type="filepath"),
    outputs=[
        gr.Image(type="filepath", label="Restored image", height=768, width=768),
        gr.Textbox(label="Error Message"),
    ],
    title="Histoformer: All-in-one Image Restoration under Adverse Weather Conditions",
    description="[Histoformer](https://huggingface.co/sunsean/Histoformer/) is an image restoration model for all-in-one adverse weather.",
    examples=example_images,
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    gradio_app.launch()