File size: 3,940 Bytes
523fb10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9be4e6
523fb10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9be4e6
523fb10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9be4e6
523fb10
 
 
 
 
 
 
 
 
 
 
 
b9be4e6
523fb10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import cv2
import numpy as np
from PIL import Image
import glob

import torch
import tqdm
import shutil
import argparse
from third_party.GPEN.face_enhancement import FaceEnhancement

def make_abs_path(fn):
    """Return *fn* as an absolute path anchored at this source file's directory.

    The module's own path is passed through ``os.path.realpath`` first, so the
    anchor is the real location of this file rather than the caller's working
    directory; the joined result is then normalised with ``os.path.abspath``.
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))
    joined = os.path.join(module_dir, fn)
    return os.path.abspath(joined)


class GPENImageInfer(object):
    """Face-restoration wrapper around GPEN's ``FaceEnhancement`` (GPEN-BFR-512 weights).

    Three entry points share the same enhancer:

    * :meth:`image_infer`   -- one (H,W,BGR) image in [0,255];
    * :meth:`ndarray_infer` -- an (N,H,W,BGR) array in [0,255];
    * :meth:`batch_infer`   -- an (N,RGB,H,W) torch tensor in [-1,1].

    Every enhanced image is resized back to the spatial size of its input.
    """

    def __init__(self, device):
        """Build the GPEN face enhancer.

        :param device: forwarded unchanged to ``FaceEnhancement``; presumably a
            torch device spec such as ``'cuda'`` or ``'cpu'`` -- TODO confirm.
        """
        super().__init__()

        model = {
            "name": "GPEN-BFR-512",
            "in_size": 512,
            "out_size": 512,
            "channel_multiplier": 2,
            "narrow": 1,
        }
        # base_dir is this module's directory; presumably FaceEnhancement
        # resolves its weight files relative to it (not visible from here).
        self.faceenhancer = FaceEnhancement(
            base_dir=make_abs_path('./'),
            use_sr=True,
            in_size=model["in_size"],
            out_size=model["out_size"],
            model=model["name"],
            channel_multiplier=model["channel_multiplier"],
            narrow=model["narrow"],
            device=device,
        )

    @staticmethod
    def _tensor_to_bgr_uint8(in_batch):
        """Convert an (N,RGB,H,W) tensor in [-1,1] to a C-contiguous (N,H,W,BGR) uint8 array.

        Values are clamped to [0,255] and rounded before the uint8 cast. A bare
        ``astype(np.uint8)`` truncates (50.99999 -> 50, so a uint8 -> [-1,1] ->
        uint8 round trip can lose one level) and out-of-range floats may wrap
        around. ``detach()`` lets tensors that require grad be converted.
        """
        arr = ((in_batch.detach().float() + 1.) * 127.5).clamp(0., 255.).round()
        arr = arr.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)  # (N,H,W,RGB)
        # Reversing the channel axis yields a negative-stride view; copy it into a
        # contiguous buffer because OpenCV can reject arrays with negative strides.
        return np.ascontiguousarray(arr[..., ::-1])  # (N,H,W,BGR)

    @staticmethod
    def _rgb_uint8_to_tensor(rgb, device):
        """Convert an (N,H,W,RGB) uint8 array to an (N,RGB,H,W) float32 tensor in [-1,1] on *device*."""
        out = torch.from_numpy(np.ascontiguousarray(rgb)).to(device=device, dtype=torch.float32)
        out = out / 127.5 - 1.
        return out.permute(0, 3, 1, 2).clamp(-1, 1)

    def image_infer(self, in_img: np.ndarray):
        """Enhance the faces in a single image.

        :param in_img: np.ndarray, (H,W,BGR), in [0,255]
        :return: out_img: np.ndarray, (H,W,BGR), in [0,255], resized to the input's H and W
        """
        h, w, _ = in_img.shape
        # process() returns a 3-tuple; only its first item (the full output image)
        # is used. Non-contiguous views are copied so OpenCV-based code can use them.
        out_img, _, _ = self.faceenhancer.process(np.ascontiguousarray(in_img))
        # The enhancer's output size may differ from the input (use_sr=True),
        # so bring it back to the caller's resolution.
        return cv2.resize(out_img, (w, h))

    def ndarray_infer(self, in_ndarray: np.ndarray,
                      save_folder: str = 'demo_images/out/',
                      save_name: str = 'reen.png',
                      ):
        """Enhance every image of an (N,H,W,BGR) batch.

        :param in_ndarray: np.ndarray, (N,H,W,BGR), in [0,255]
        :param save_folder: not used
        :param save_name: not used
        :return: out_ndarray: np.ndarray, (N,H,W,BGR), uint8 in [0,255]
        :raises ValueError: if in_ndarray is not 4-D
        """
        if in_ndarray.ndim != 4:
            raise ValueError(f'expected an (N,H,W,C) array, got shape {in_ndarray.shape}')

        out_ndarray = np.zeros_like(in_ndarray, dtype=np.uint8)  # (N,H,W,BGR)
        for b_idx, single_img in enumerate(in_ndarray):
            out_ndarray[b_idx] = self.image_infer(single_img)  # (H,W,BGR), in [0,255]
        return out_ndarray

    def batch_infer(self, in_batch: torch.Tensor,
                          save_folder: str = 'demo_images/out/',
                          save_name: str = 'reen.png',
                          save_batch_idx: int = 0,
                          ):
        """Enhance every image of an (N,RGB,H,W) tensor batch.

        No gradient flows through this method: the input is detached and
        round-tripped through uint8 numpy arrays.

        :param in_batch: (N,RGB,H,W), in [-1,1]
        :param save_folder: folder for the optional debug image (created if missing)
        :param save_name: file name of the debug image
        :param save_batch_idx: index of the sample whose enhanced (BGR) result is
            written to ``save_folder/save_name``; ``None`` disables saving
        :return: out_batch: (N,RGB,H,W), float32 in [-1,1], on ``in_batch.device``
        :raises ValueError: if in_batch is not 4-D
        """
        if in_batch.dim() != 4:
            raise ValueError(f'expected an (N,C,H,W) tensor, got shape {tuple(in_batch.shape)}')
        device = in_batch.device

        bgr_batch = self._tensor_to_bgr_uint8(in_batch)  # (N,H,W,BGR), uint8
        rgb_out = np.zeros_like(bgr_batch)  # (N,H,W,RGB), uint8
        for b_idx, single_img in enumerate(bgr_batch):
            out_img = self.image_infer(single_img)  # (H,W,BGR), in [0,255]
            rgb_out[b_idx] = out_img[:, :, ::-1]
            if save_batch_idx is not None and b_idx == save_batch_idx:
                # cv2.imwrite returns False (no exception) for a missing folder.
                if save_folder:
                    os.makedirs(save_folder, exist_ok=True)
                cv2.imwrite(os.path.join(save_folder, save_name), out_img)

        return self._rgb_uint8_to_tensor(rgb_out, device)


if __name__ == '__main__':
    # Demo: enhance every image in examples/imgs/ and write "<stem>_gpen<ext>"
    # next to it. GPENImageInfer requires a device (the old call passed none and
    # raised TypeError); assumes FaceEnhancement accepts a device string.
    demo_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    gpen = GPENImageInfer(demo_device)

    in_folder = 'examples/imgs/'
    for img_name in sorted(os.listdir(in_folder)):
        if 'gpen' in img_name:
            continue  # skip outputs written by earlier runs

        in_path = os.path.join(in_folder, img_name)
        # splitext handles any extension; the old chained str.replace calls
        # overwrote each other and made out_path == in_path for non-.jpg files.
        stem, ext = os.path.splitext(in_path)
        out_path = stem + '_gpen' + ext

        im = cv2.imread(in_path, cv2.IMREAD_COLOR)  # BGR, or None if unreadable
        if im is None:
            print('skipping unreadable file:', in_path)
            continue
        img = gpen.image_infer(im)
        cv2.imwrite(out_path, img)