File size: 3,234 Bytes
6fee505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import numpy as np

import cv2

from insightface import model_zoo
from dofaker.utils import download_file, get_model_url


class BSRGAN:

    def __init__(self, name='bsrgan', root='weights/models', scale=1) -> None:
        _, model_file = download_file(get_model_url(name),
                                      save_dir=root,
                                      overwrite=False)
        self.scale = scale
        providers = model_zoo.model_zoo.get_default_providers()
        self.session = model_zoo.model_zoo.PickableInferenceSession(
            model_file, providers=providers)

        self.input_mean = 0.0
        self.input_std = 255.0
        inputs = self.session.get_inputs()
        self.input_names = []
        for inp in inputs:
            self.input_names.append(inp.name)
        outputs = self.session.get_outputs()
        output_names = []
        for out in outputs:
            output_names.append(out.name)
        self.output_names = output_names
        assert len(
            self.output_names
        ) == 1, "The output number of BSRGAN model should be 1, but got {}, please check your model.".format(
            len(self.output_names))
        output_shape = outputs[0].shape
        input_cfg = inputs[0]
        input_shape = input_cfg.shape
        self.input_shape = input_shape
        print('image super resolution shape:', self.input_shape)

    def forward(self, image, image_format='bgr'):
        if isinstance(image, str):
            image = cv2.imread(image, 1)
            image_format = 'bgr'
        elif isinstance(image, np.ndarray):
            if image_format == 'bgr':
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            elif image_format == 'rgb':
                pass
            else:
                raise UserWarning(
                    "BSRGAN not support image format {}".format(image_format))
        else:
            raise UserWarning(
                "BSRGAN input must be str or np.ndarray, but got {}.".format(
                    type(image)))
        img = (image - self.input_mean) / self.input_std
        pred = self.session.run(self.output_names,
                                {self.input_names[0]: img})[0]
        return pred

    def get(self, img, image_format='bgr'):
        if image_format.lower() == 'bgr':
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        elif image_format.lower() == 'rgb':
            pass
        else:
            raise UserWarning(
                "gfpgan not support image format {}".format(image_format))
        h, w, c = img.shape
        blob = cv2.dnn.blobFromImage(
            img,
            1.0 / self.input_std, (w, h),
            (self.input_mean, self.input_mean, self.input_mean),
            swapRB=False)
        pred = self.session.run(self.output_names,
                                {self.input_names[0]: blob})[0]
        image_aug = pred.transpose((0, 2, 3, 1))[0]
        rgb_aug = np.clip(self.input_std * image_aug + self.input_mean, 0,
                          255).astype(np.uint8)
        rgb_aug = cv2.resize(rgb_aug,
                             (int(w * self.scale), int(h * self.scale)))
        bgr_aug = rgb_aug[:, :, ::-1]
        return bgr_aug