File size: 3,752 Bytes
9080570
 
 
 
 
 
865d8a3
9080570
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33e2863
 
 
 
 
 
9080570
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8463d2
dae1a1c
 
9080570
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
865d8a3
 
 
 
 
 
9080570
 
 
da42bf0
 
038540f
 
 
9080570
 
 
 
36815eb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import yaml
import torch
import argparse
import numpy as np
import gradio as gr
import requests

from PIL import Image
from copy import deepcopy
from torch.nn.parallel import DataParallel, DistributedDataParallel

from huggingface_hub import hf_hub_download
from gradio_imageslider import ImageSlider

## local code
from models import seemore


def dict2namespace(config):
    """Recursively convert a (possibly nested) dict into an ``argparse.Namespace``.

    Nested dict values become nested namespaces, so config entries can be read
    with attribute access (e.g. ``cfg.model.scale``). Every other value is
    stored unchanged.
    """
    return argparse.Namespace(**{
        key: dict2namespace(value) if isinstance(value, dict) else value
        for key, value in config.items()
    })

def load_img(filename, norm=True):
    """Load an image file as an RGB numpy array, downscaling very large inputs.

    Images wider than 1920 px or taller than 1080 px are shrunk by a factor of
    4 in each dimension (bicubic) before being returned.

    Args:
        filename: Anything accepted by ``PIL.Image.open`` (path or file object).
        norm: If True, return float32 values scaled to [0, 1]; otherwise return
            the raw uint8 array.

    Returns:
        numpy array of shape (H, W, 3).
    """
    # The context manager closes the underlying file handle deterministically
    # instead of leaving it to garbage collection. convert() loads the pixel
    # data, so the result stays valid after the source is closed.
    with Image.open(filename) as src:
        pil_img = src.convert("RGB")

    w, h = pil_img.size
    if w > 1920 or h > 1080:
        # Clamp to 1 px so very thin images do not produce a zero-sized resize.
        new_w, new_h = max(1, w // 4), max(1, h // 4)
        pil_img = pil_img.resize((new_w, new_h), Image.BICUBIC)

    img = np.array(pil_img)
    if norm:
        img = (img / 255.).astype(np.float32)
    return img

def process_img(image):
    """Run the module-level ``model`` on one image and return a before/after pair.

    Args:
        image: Input picture from the ``gr.Image(type="pil")`` component. It is
            presumably an RGB ``PIL.Image``; anything ``np.array`` turns into an
            (H, W, 3) uint8 array works.

    Returns:
        ``(image, restored)``: the untouched input and the model output as a
        ``PIL.Image``. This is the pair displayed by the ``ImageSlider`` output.

    Raises:
        gr.Error: if no image was supplied.
    """
    if image is None:
        # Gradio hands the function None when the input component is empty;
        # report it in the UI instead of failing inside numpy.
        raise gr.Error("Пожалуйста, загрузите изображение.")

    img = (np.array(image) / 255.).astype(np.float32)
    # HWC -> NCHW with a batch of one.
    y = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        x_hat = model(y)

    # squeeze(0) drops only the batch axis; a bare squeeze() would also drop
    # any spatial axis of size 1. clamp() already bounds values to [0, 1], so
    # no extra numpy clip is needed.
    restored = x_hat.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
    restored = (restored * 255.0).round().astype(np.uint8)  # float32 -> uint8
    return image, Image.fromarray(restored)

def load_network(net, load_path, strict=True, param_key='params'):
    """Load checkpoint weights from ``load_path`` into ``net``.

    Args:
        net: Target module. If it is wrapped in ``DataParallel`` or
            ``DistributedDataParallel``, the weights are loaded into the
            wrapped ``.module``.
        load_path: Path of a checkpoint file readable by ``torch.load``.
        strict: Forwarded to ``nn.Module.load_state_dict``.
        param_key: Key under which the state dict is stored in the checkpoint.
            If that key is missing but ``'params'`` exists, ``'params'`` is used
            instead. Pass ``None`` when the file is itself a state dict.
    """
    if isinstance(net, (DataParallel, DistributedDataParallel)):
        net = net.module
    # The identity map_location loads every storage onto the CPU.
    # NOTE(review): torch.load unpickles arbitrary objects unless
    # weights_only=True; only load checkpoints from trusted sources.
    load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
    if param_key is not None:
        if param_key not in load_net and 'params' in load_net:
            param_key = 'params'
        load_net = load_net[param_key]
    # Strip the 'module.' prefix left by (Distributed)DataParallel.
    # Iterating over a snapshot of the items allows in-place mutation, which
    # keeps any metadata attached to the state dict, without deep-copying
    # every tensor.
    prefix = 'module.'
    for k, v in list(load_net.items()):
        if k.startswith(prefix):
            load_net[k[len(prefix):]] = v
            load_net.pop(k)
    net.load_state_dict(load_net, strict=strict)

CONFIG = "configs/eval_seemore_t_x4.yml"
MODEL_NAME = "SeemoRe_T_X4.pth"
# hf_hub_download returns the local path of the downloaded file; use it
# instead of assuming where local_dir placed the checkpoint.
model_path = hf_hub_download(repo_id="eduardzamfir/SeemoRe-T", filename=MODEL_NAME, local_dir="./")

# Parse the YAML config (safe_load does not construct arbitrary objects).
with open(CONFIG, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

cfg = dict2namespace(config)

device = torch.device("cpu")
model_cfg = cfg.model
model = seemore.SeemoRe(scale=model_cfg.scale, in_chans=model_cfg.in_chans,
                        num_experts=model_cfg.num_experts, num_layers=model_cfg.num_layers,
                        embedding_dim=model_cfg.embedding_dim, img_range=model_cfg.img_range,
                        use_shuffle=model_cfg.use_shuffle, global_kernel_size=model_cfg.global_kernel_size,
                        recursive=model_cfg.recursive, lr_space=model_cfg.lr_space, topk=model_cfg.topk)

model = model.to(device)
print("IMAGE MODEL CKPT:", MODEL_NAME)
load_network(model, model_path, strict=True, param_key='params')
# nn.Module starts in training mode. This app only runs inference, so switch
# to eval mode so that any train-only behaviour (dropout, normalisation
# statistics updates) is disabled.
model.eval()





# URL of the CSS file.
css_url = "https://aihubyufi-aihub.static.hf.space/style.css"

# Fetch the CSS once at start-up. Styling is cosmetic, so a network failure
# or HTTP error falls back to Gradio's default styling (css=None). This avoids
# hanging forever, crashing the app, or injecting an error page as CSS.
try:
    response = requests.get(css_url, timeout=10)
    response.raise_for_status()
    css = response.text
except requests.RequestException as err:
    print("Could not load custom CSS:", err)
    css = None

# Single-image UI; labels are Russian ("Image" / "Enhanced image").
demo = gr.Interface(
    fn=process_img,
    inputs=[gr.Image(type="pil", label="Изображение")],
    # ImageSlider displays the (input, restored) pair returned by process_img.
    outputs=ImageSlider(
        label="Улучшенное изображение",
        type="pil",
        show_download_button=True,
    ),
    css=css,
)

if __name__ == "__main__":
    # Enable the request queue (at most 2 pending events), then serve locally
    # without a public share link and without the API page.
    app = demo.queue(max_size=2)
    app.launch(show_api=False, share=False)