File size: 4,637 Bytes
29b0bbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62e3f73
29b0bbc
 
 
 
 
 
 
 
 
 
 
 
 
 
62e3f73
29b0bbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import asyncio
import os
import warnings
from enum import Enum

import cv2
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from gfpgan import GFPGANer
from realesrgan import RealESRGANer
from tqdm import tqdm

class EnhancementMethod(str, Enum):
    """Names of the enhancement back-ends that ``Enhancer`` can be built with.

    The ``str`` mix-in makes every member compare equal to its plain string
    value, so ``EnhancementMethod.gfpgan == 'gfpgan'`` holds.
    """

    # Face-restoration back-ends (handled through the face enhancer).
    gfpgan = 'gfpgan'
    RestoreFormer = 'RestoreFormer'
    codeformer = 'codeformer'

    # Whole-image upscaling back-end (handled through RealESRGAN only).
    realesrgan = 'realesrgan'


class Enhancer:
    """Enhance images with a GFPGAN-family face restorer or with RealESRGAN.

    For ``EnhancementMethod.realesrgan`` the whole image is upscaled by a
    ``RealESRGANer``. For every other method a ``GFPGANer`` face enhancer is
    built, optionally with a RealESRGAN background upsampler.

    Attributes:
        method: The selected ``EnhancementMethod``.
        background_enhancement: Whether a background upsampler is requested
            for the face-enhancement methods.
        upscale: Final upscaling factor requested by the caller.
        bg_upsampler: ``RealESRGANer`` used for the background, or ``None``.
        realesrgan_enhancer: ``RealESRGANer`` used for the ``realesrgan``
            method, or ``None``.
        face_enhancer: ``GFPGANer`` used for the face methods, or ``None``.
    """

    # The only RealESRGAN checkpoint used here is the "x2plus" one, so the
    # network itself must be built as a 2x model regardless of the final
    # factor requested by the caller. The requested factor is applied later
    # (``outscale=`` in ``enhance`` / ``upscale=`` of ``GFPGANer``).
    _REALESRGAN_NETSCALE = 2
    _REALESRGAN_MODEL_URL = 'https://huggingface.co/dtarnow/UPscaler/resolve/main/RealESRGAN_x2plus.pth'

    def __init__(self, method: "EnhancementMethod", background_enhancement=True, upscale=2):
        """Build the models required by ``method``.

        Args:
            method: Which enhancement back-end to use.
            background_enhancement: For face methods, also set up a
                RealESRGAN background upsampler (skipped without CUDA).
            upscale: Final upscaling factor of the output image.

        Raises:
            ValueError: If ``method`` is ``realesrgan`` and CUDA is not
                available, or if ``method`` has no face-model configuration.
        """
        self.method = method
        self.background_enhancement = background_enhancement
        self.upscale = upscale
        self.bg_upsampler = None
        self.realesrgan_enhancer = None
        self.face_enhancer = None

        if self.method == EnhancementMethod.realesrgan:
            self.setup_realesrgan_enhancer()
            return

        # The background upsampler has to exist BEFORE the face enhancer is
        # constructed, because it is handed to GFPGANer at construction time.
        if self.background_enhancement:
            self.setup_background_enhancer()
        self.setup_face_enhancer()

    def _build_realesrgan(self):
        """Return a ``RealESRGANer`` loaded with the RealESRGAN_x2plus weights."""
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=self._REALESRGAN_NETSCALE)
        return RealESRGANer(
            scale=self._REALESRGAN_NETSCALE,
            model_path=self._REALESRGAN_MODEL_URL,
            model=model,
            tile=400,
            tile_pad=10,
            pre_pad=0,
            half=True)

    def setup_background_enhancer(self):
        """Create ``self.bg_upsampler``; warn and leave it ``None`` without CUDA."""
        if not torch.cuda.is_available():
            warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it.')
            return

        self.bg_upsampler = self._build_realesrgan()

    def setup_realesrgan_enhancer(self):
        """Create ``self.realesrgan_enhancer``.

        Raises:
            ValueError: If CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise ValueError('CUDA is not available for RealESRGAN')

        self.realesrgan_enhancer = self._build_realesrgan()

    def setup_face_enhancer(self):
        """Create ``self.face_enhancer`` for the configured face method.

        Weights are looked up in ``gfpgan/weights`` first, then in
        ``checkpoints``; if neither file exists the download URL is passed
        to ``GFPGANer`` instead.

        Raises:
            ValueError: If ``self.method`` has no face-model configuration.
        """
        model_configs = {
            EnhancementMethod.gfpgan: {
                'arch': 'clean',
                'channel_multiplier': 2,
                'model_name': 'GFPGANv1.4',
                'url': 'https://huggingface.co/gmk123/GFPGAN/resolve/main/GFPGANv1.4.pth'
            },
            EnhancementMethod.RestoreFormer: {
                'arch': 'RestoreFormer',
                'channel_multiplier': 2,
                'model_name': 'RestoreFormer',
                'url': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'
            },
            EnhancementMethod.codeformer: {
                'arch': 'CodeFormer',
                'channel_multiplier': 2,
                'model_name': 'CodeFormer',
                'url': 'https://huggingface.co/sinadi/aar/resolve/main/codeformer.pth'
            }
        }

        config = model_configs.get(self.method)
        if config is None:
            raise ValueError(f'Wrong model version {self.method}')

        # Prefer a local copy of the weights; fall back to the remote URL.
        weights_file = config['model_name'] + '.pth'
        model_path = config['url']
        for directory in ('gfpgan/weights', 'checkpoints'):
            candidate = os.path.join(directory, weights_file)
            if os.path.isfile(candidate):
                model_path = candidate
                break

        self.face_enhancer = GFPGANer(
            model_path=model_path,
            upscale=self.upscale,
            arch=config['arch'],
            channel_multiplier=config['channel_multiplier'],
            bg_upsampler=self.bg_upsampler)

    def check_image_resolution(self, image):
        """Return ``(width, height)`` taken from the first two axes of ``image.shape``.

        Only the first two entries of ``shape`` are read, so both 2-D and
        3-D arrays are accepted.
        """
        height, width = image.shape[:2]
        return width, height

    async def enhance(self, image):
        """Enhance ``image`` and report the resolution before and after.

        The input is converted RGB -> BGR before being handed to the model
        and the result is converted back BGR -> RGB. The model call runs in
        a worker thread via ``asyncio.to_thread`` so the event loop is not
        blocked while it executes.

        Args:
            image: Input image array in RGB channel order.

        Returns:
            A tuple ``(enhanced_img, (width, height),
            (enhanced_width, enhanced_height))``.
        """
        img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        width, height = self.check_image_resolution(img)

        if self.method == EnhancementMethod.realesrgan:
            enhanced_img, _ = await asyncio.to_thread(
                self.realesrgan_enhancer.enhance, img, outscale=self.upscale)
        else:
            # Only the third returned item (the restored full image) is used.
            _, _, enhanced_img = await asyncio.to_thread(
                self.face_enhancer.enhance,
                img,
                has_aligned=False,
                only_center_face=False,
                paste_back=True)

        enhanced_img = cv2.cvtColor(enhanced_img, cv2.COLOR_BGR2RGB)
        enhanced_width, enhanced_height = self.check_image_resolution(enhanced_img)

        return enhanced_img, (width, height), (enhanced_width, enhanced_height)