File size: 3,638 Bytes
d8eb455
 
 
ea1509a
d8eb455
 
 
 
 
ea1509a
 
 
 
 
 
 
 
 
 
 
 
 
 
d8eb455
1fed0d2
ea1509a
d8eb455
 
 
 
 
1fed0d2
d8eb455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
from PIL import Image

import _thread
import torch
import numpy as np
from models.network_swinir import SwinIR as net

# model load
def load_model():
    """Build the SwinIR x4 network, load its pretrained weights and publish it.

    Sets the module globals ``device`` and ``super_res_model``. The network is
    built and initialised in a local variable and only assigned to
    ``super_res_model`` once its weights are loaded, it sits on ``device`` and it
    is in eval mode. This matters because ``predict`` treats any non-None
    ``super_res_model`` as ready to use, and this function runs on a separate
    thread.
    """
    global super_res_model
    global device

    param_key_g = 'params_ema'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = net(upscale=4, in_chans=3, img_size=64, window_size=8,
                img_range=1., depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], embed_dim=240,
                num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
                mlp_ratio=2, upsampler='nearest+conv', resi_connection='3conv')
    # Load onto the CPU first. A checkpoint saved from a GPU then also loads on
    # CPU-only hosts, and no second copy of the weights sits on the GPU while
    # they are copied into the model.
    checkpoint = torch.load(
        "model_zoo/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_PSNR.pth",
        map_location="cpu",
    )
    # The weights may be nested under 'params_ema'; otherwise the loaded object
    # is used as the state dict directly.
    state_dict = checkpoint[param_key_g] if param_key_g in checkpoint else checkpoint
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()

    # Publish only the fully initialised model.
    super_res_model = model

# Both globals are filled in by load_model(), which runs on a background thread
# so the UI can start before the (large) checkpoint is loaded. Until then
# predict() sees super_res_model is None and returns None.
super_res_model = None
device = None
_thread.start_new_thread(load_model, tuple())

def predict(input_img):
    """Upscale an image 4x with the SwinIR model.

    Args:
        input_img: image array from the Gradio image input, or None. Assumed to
            hold 0-255 values laid out (H, W) or (H, W, C) with C in {1, 3, 4}.
            Grayscale is replicated to three channels and alpha is dropped.

    Returns:
        A PIL image four times the input size. Returns None when there is no
        input or the model has not finished loading.
    """
    # Read the global once. It is published by the loader thread, so checking
    # it and then reading it again could observe two different values.
    model = super_res_model
    if input_img is None or model is None:
        return None

    scale = 4  # the network is built with upscale=4
    window_size = 8
    img_lq = _to_model_input(input_img)

    with torch.no_grad():
        _, _, h_old, w_old = img_lq.size()
        img_lq = _pad_to_window(img_lq, window_size)
        output = test(model.to(device), img_lq)
        # Drop the rows/columns that came from the padding.
        output = output[..., :h_old * scale, :w_old * scale]

    return _to_pil(output)


def _to_model_input(img):
    """Convert an image array to a (1, 3, H, W) float32 tensor in [0, 1] on ``device``."""
    arr = np.asarray(img)
    if arr.ndim == 2:
        arr = arr[:, :, None]
    # The network is built with in_chans=3: replicate gray, drop alpha.
    if arr.shape[2] == 1:
        arr = np.repeat(arr, 3, axis=2)
    elif arr.shape[2] == 4:
        arr = arr[:, :, :3]
    arr = arr.astype(np.float32) / 255
    # HWC -> CHW. No channel swap: the Gradio image input delivers RGB
    # (image_mode="RGB" by default), which is the order the model expects.
    chw = np.ascontiguousarray(np.transpose(arr, (2, 0, 1)))
    return torch.from_numpy(chw).unsqueeze(0).to(device)


def _pad_to_window(img, window_size):
    """Mirror-pad an NCHW tensor at the bottom/right up to a multiple of window_size.

    Each spatial size s is padded to (s // window_size + 1) * window_size, so a
    size that is already a multiple still gains one full window. The mirror is
    repeated when the image is smaller than the padding it needs.
    """
    _, _, h, w = img.size()
    h_target = (h // window_size + 1) * window_size
    w_target = (w // window_size + 1) * window_size
    while img.size(2) < h_target:
        img = torch.cat([img, torch.flip(img, [2])], 2)
    img = img[:, :, :h_target, :]
    while img.size(3) < w_target:
        img = torch.cat([img, torch.flip(img, [3])], 3)
    return img[:, :, :, :w_target]


def _to_pil(output):
    """Convert a (1, C, H, W) model output to a uint8 PIL image, clamping to [0, 1]."""
    arr = output[0].float().cpu().clamp_(0, 1).numpy()
    arr = np.transpose(arr, (1, 2, 0))  # CHW -> HWC
    if arr.shape[2] == 1:
        arr = arr[:, :, 0]
    arr = (arr * 255.0).round().astype(np.uint8)
    return Image.fromarray(arr)

def test(model, img_lq, *, tile_size=800, tile_overlap=32, sf=4):
    """Run ``model`` over ``img_lq`` in overlapping square tiles and blend the results.

    The input is covered by tiles of side ``min(tile_size, H, W)``; the last
    tile in each direction is aligned to the image edge. Each tile is upscaled
    by ``model``, and pixels covered by several tiles are averaged.

    Args:
        model: callable mapping an (N, C, t, t) tensor to (N, C, t*sf, t*sf).
        img_lq: input tensor of shape (N, C, H, W).
        tile_size: maximum tile side, in input pixels.
        tile_overlap: overlap between neighbouring tiles, in input pixels.
        sf: upscaling factor produced by ``model``.

    Returns:
        Tensor of shape (N, C, H*sf, W*sf) with the same dtype/device as ``img_lq``.
    """
    b, c, h, w = img_lq.size()
    tile = min(tile_size, h, w)
    # Small inputs can make the tile no larger than the overlap. A stride <= 0
    # then either makes range() raise (stride 0) or skips tiles (stride < 0),
    # leaving zero weights and NaNs. Fall back to non-overlapping tiles.
    stride = tile - tile_overlap if tile > tile_overlap else tile

    h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
    w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
    E = torch.zeros(b, c, h * sf, w * sf).type_as(img_lq)  # summed tile outputs
    W = torch.zeros_like(E)                                # per-pixel tile counts

    for h_idx in h_idx_list:
        for w_idx in w_idx_list:
            in_patch = img_lq[..., h_idx:h_idx + tile, w_idx:w_idx + tile]
            out_patch = model(in_patch)
            out_region = (..., slice(h_idx * sf, (h_idx + tile) * sf),
                          slice(w_idx * sf, (w_idx + tile) * sf))
            E[out_region].add_(out_patch)
            W[out_region].add_(torch.ones_like(out_patch))
    return E.div_(W)


# This is not a unit test: stop pytest from collecting it as one if this module
# is ever imported into a test file.
test.__test__ = False

# Build and launch the web UI: one image in, the 4x-upscaled image out.
gr.Interface(
    fn=predict,
    inputs=[
        gr.inputs.Image()
    ],
    # predict() returns a PIL image, so declare an output component typed for
    # it. Input components (gr.inputs.*) are not accepted as outputs by older
    # Gradio releases.
    outputs=[
        gr.outputs.Image(type="pil")
    ],
    title="SwinIR moon super resolution",
    description="Description of the app",
    examples=[
        "render0001.png", "render1546.png", "render1682.png"
    ]
).launch()