import cv2
import numpy as np
import torch
from loguru import logger

from iopaint.helper import download_model
from iopaint.plugins.base_plugin import BasePlugin
from iopaint.schema import RunPluginRequest, RealESRGANModel

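
# Illustrative sketch, not part of the original plugin: this hypothetical helper
# is not called anywhere in this module. It restates the half-precision rule
# used when constructing RealESRGANer below: fp16 is enabled only when the
# device string contains "cuda" and the caller has not passed no_half=True.
# CPU (or any other non-CUDA device) never gets fp16 under this rule.
def _use_half_precision_sketch(device, no_half: bool = False) -> bool:
    """Return True if fp16 inference would be enabled for this device."""
    # str() handles both plain strings ("cuda:0") and torch.device objects.
    return "cuda" in str(device) and not no_half
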

class RealESRGANUpscaler(BasePlugin):
    name = "RealESRGAN"
    support_gen_image = True

    def __init__(self, name, device, no_half=False):
        super().__init__()
        self.model_name = name
        self.device = device
        self.no_half = no_half
        self._init_model(name)

    def _init_model(self, name):
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer
        from realesrgan.archs.srvgg_arch import SRVGGNetCompact

        REAL_ESRGAN_MODELS = {
            RealESRGANModel.realesr_general_x4v3: {
                "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
                "scale": 4,
                "model": lambda: SRVGGNetCompact(
                    num_in_ch=3,
                    num_out_ch=3,
                    num_feat=64,
                    num_conv=32,
                    upscale=4,
                    act_type="prelu",
                ),
                "model_md5": "91a7644643c884ee00737db24e478156",
            },
            RealESRGANModel.RealESRGAN_x4plus: {
                "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
                "scale": 4,
                "model": lambda: RRDBNet(
                    num_in_ch=3,
                    num_out_ch=3,
                    num_feat=64,
                    num_block=23,
                    num_grow_ch=32,
                    scale=4,
                ),
                "model_md5": "99ec365d4afad750833258a1a24f44ca",
            },
            RealESRGANModel.RealESRGAN_x4plus_anime_6B: {
                "url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
                "scale": 4,
                "model": lambda: RRDBNet(
                    num_in_ch=3,
                    num_out_ch=3,
                    num_feat=64,
                    num_block=6,
                    num_grow_ch=32,
                    scale=4,
                ),
                "model_md5": "d58ce384064ec1591c2ea7b79dbf47ba",
            },
        }
        if name not in REAL_ESRGAN_MODELS:
            raise ValueError(f"Unknown RealESRGAN model name: {name}")
        model_info = REAL_ESRGAN_MODELS[name]

        model_path = download_model(model_info["url"], model_info["model_md5"])
        logger.info(f"RealESRGAN model path: {model_path}")

        self.model = RealESRGANer(
            scale=model_info["scale"],
            model_path=model_path,
            model=model_info["model"](),
            half="cuda" in str(self.device) and not self.no_half,
            tile=512,
            tile_pad=10,
            pre_pad=10,
            device=self.device,
        )

    def switch_model(self, new_model_name: str):
        if self.model_name == new_model_name:
            return
        self._init_model(new_model_name)
        self.model_name = new_model_name

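    def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
        """Upscale an RGB image by ``req.scale``; the returned array is in BGR order."""
        # RealESRGANer expects OpenCV-style BGR input.
        bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)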
        logger.info(f"RealESRGAN input shape: {bgr_np_img.shape}, scale: {req.scale}")
        result = self.forward(bgr_np_img, req.scale)
        logger.info(f"RealESRGAN output shape: {result.shape}")
        return result

    @torch.inference_mode()
    def forward(self, bgr_np_img, scale: float):
        # enhance() returns a tuple; its first element is the upscaled image in BGR order.
        upsampled = self.model.enhance(bgr_np_img, outscale=scale)[0]
        return upsampled

    def check_dep(self):
        try:
            import realesrgan
        except ImportError:
            return "RealESRGAN is not installed. Please install it first: pip install realesrgan"
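

# Illustrative sketch, not part of the original plugin: a hypothetical helper
# that predicts the output height and width forward() should produce.
# It assumes the realesrgan package's RealESRGANer.enhance behaves as follows:
# the network upsamples by its fixed scale (4 for every model registered
# above), and when outscale differs from that scale the result is resized to
# (int(w * outscale), int(h * outscale)). Verify against the installed
# realesrgan version before relying on it.
def _expected_output_hw_sketch(h: int, w: int, outscale=None, net_scale: int = 4):
    """Return the (height, width) expected from enhance(img, outscale=outscale)."""
    # With no outscale, the raw network output (net_scale x the input) is returned.
    factor = net_scale if outscale is None else outscale
    # int() truncates, matching the assumed resize-target computation.
    return int(h * factor), int(w * factor)


# Example: a 100x150 input with req.scale == 2.0 is expected to come back as 200x300.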