media_enhancer / image_enhancer.py
d0tpy's picture
Update image_enhancer.py
62e3f73 verified
import asyncio
import os
import warnings
from enum import Enum

import cv2
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from gfpgan import GFPGANer
from realesrgan import RealESRGANer
from tqdm import tqdm
class EnhancementMethod(str, Enum):
    """Names of the enhancement back-ends an ``Enhancer`` can be built with.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value (e.g. ``EnhancementMethod.gfpgan == "gfpgan"``).
    """

    # Methods that are routed through the face enhancer.
    gfpgan = "gfpgan"
    RestoreFormer = "RestoreFormer"
    codeformer = "codeformer"

    # Method that is routed through the standalone RealESRGAN enhancer.
    realesrgan = "realesrgan"
class Enhancer:
    """Image enhancer backed by either a GFPGAN-style face restorer or RealESRGAN.

    Depending on ``method`` the instance owns one of two back-ends:

    * ``EnhancementMethod.realesrgan`` -> ``self.realesrgan_enhancer``
      (a ``RealESRGANer``); requires CUDA.
    * any other method -> ``self.face_enhancer`` (a ``GFPGANer``), optionally
      combined with ``self.bg_upsampler`` (a ``RealESRGANer``) when
      ``background_enhancement`` is true and CUDA is available.

    Attributes that are not used by the selected method stay ``None``.
    """

    # Every RealESRGAN instance created here loads the x2plus checkpoint, so
    # the network is always built for that checkpoint's scale of 2. The
    # user-requested factor is applied by the callers of the upsampler
    # (``outscale=`` in ``enhance`` and ``upscale=`` on GFPGANer).
    # NOTE(review): building RRDBNet with ``scale=self.upscale`` (the previous
    # behaviour) is expected to fail at weight loading for upscale != 2 --
    # confirm against the installed basicsr/realesrgan versions.
    _REALESRGAN_NATIVE_SCALE = 2
    _REALESRGAN_MODEL_URL = 'https://huggingface.co/dtarnow/UPscaler/resolve/main/RealESRGAN_x2plus.pth'

    def __init__(self, method: "EnhancementMethod", background_enhancement=True, upscale=2):
        """Create the back-end(s) required by ``method``.

        Args:
            method: Which enhancement back-end to use.
            background_enhancement: For face-restoration methods, whether to
                also create a RealESRGAN background upsampler (skipped with a
                warning when CUDA is unavailable).
            upscale: Requested output scale factor.

        Raises:
            ValueError: If ``method`` is ``realesrgan`` and CUDA is not
                available, or if ``method`` has no face-enhancer config.
        """
        self.method = method
        self.background_enhancement = background_enhancement
        self.upscale = upscale
        self.bg_upsampler = None
        self.realesrgan_enhancer = None
        self.face_enhancer = None

        if self.method != EnhancementMethod.realesrgan:
            # The background upsampler must exist *before* the face enhancer
            # is constructed: setup_face_enhancer() hands self.bg_upsampler to
            # GFPGANer, and doing it the other way round always passed None.
            if self.background_enhancement:
                self.setup_background_enhancer()
            self.setup_face_enhancer()
        else:
            self.setup_realesrgan_enhancer()

    def _build_realesrgan(self):
        """Return a ``RealESRGANer`` loaded with the x2plus checkpoint."""
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=self._REALESRGAN_NATIVE_SCALE)
        return RealESRGANer(
            scale=self._REALESRGAN_NATIVE_SCALE,
            model_path=self._REALESRGAN_MODEL_URL,
            model=model,
            tile=400,
            tile_pad=10,
            pre_pad=0,
            half=True)

    def setup_background_enhancer(self):
        """Populate ``self.bg_upsampler``; warn and leave it ``None`` without CUDA."""
        if not torch.cuda.is_available():
            warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it.')
            return
        self.bg_upsampler = self._build_realesrgan()

    def setup_realesrgan_enhancer(self):
        """Populate ``self.realesrgan_enhancer``.

        Raises:
            ValueError: If CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise ValueError('CUDA is not available for RealESRGAN')
        self.realesrgan_enhancer = self._build_realesrgan()

    def setup_face_enhancer(self):
        """Populate ``self.face_enhancer`` with a ``GFPGANer`` for ``self.method``.

        Weights are looked up in ``gfpgan/weights`` and then ``checkpoints``;
        if neither contains ``<model_name>.pth`` the configured URL is passed
        as the model path instead.

        Raises:
            ValueError: If ``self.method`` has no entry in the config table.
        """
        model_configs = {
            EnhancementMethod.gfpgan: {
                'arch': 'clean',
                'channel_multiplier': 2,
                'model_name': 'GFPGANv1.4',
                'url': 'https://huggingface.co/gmk123/GFPGAN/resolve/main/GFPGANv1.4.pth'
            },
            EnhancementMethod.RestoreFormer: {
                'arch': 'RestoreFormer',
                'channel_multiplier': 2,
                'model_name': 'RestoreFormer',
                'url': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'
            },
            EnhancementMethod.codeformer: {
                'arch': 'CodeFormer',
                'channel_multiplier': 2,
                'model_name': 'CodeFormer',
                'url': 'https://huggingface.co/sinadi/aar/resolve/main/codeformer.pth'
            }
        }

        config = model_configs.get(self.method)
        if config is None:
            raise ValueError(f'Wrong model version {self.method}')

        # Prefer a local copy of the weights; fall back to the remote URL.
        weights_file = config['model_name'] + '.pth'
        model_path = config['url']
        for directory in ('gfpgan/weights', 'checkpoints'):
            candidate = os.path.join(directory, weights_file)
            if os.path.isfile(candidate):
                model_path = candidate
                break

        self.face_enhancer = GFPGANer(
            model_path=model_path,
            upscale=self.upscale,
            arch=config['arch'],
            channel_multiplier=config['channel_multiplier'],
            bg_upsampler=self.bg_upsampler)

    def check_image_resolution(self, image):
        """Return ``(width, height)`` taken from the first two entries of ``image.shape``.

        Only the leading two dimensions are read, so arrays without a channel
        axis are accepted as well.
        """
        height, width = image.shape[:2]
        return width, height

    async def enhance(self, image):
        """Enhance ``image`` with the configured back-end without blocking the event loop.

        Args:
            image: Input image; it is converted with ``cv2.COLOR_RGB2BGR``
                before being handed to the back-end, so RGB channel order is
                expected.

        Returns:
            A tuple ``(enhanced_img, (width, height), (enhanced_width,
            enhanced_height))`` where ``enhanced_img`` has been converted back
            with ``cv2.COLOR_BGR2RGB`` and the two size pairs describe the
            input and the output respectively.
        """
        img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        width, height = self.check_image_resolution(img)

        # The back-end call is synchronous and heavy, so it runs in a worker
        # thread to keep the event loop responsive.
        if self.method == EnhancementMethod.realesrgan:
            enhanced_img, _ = await asyncio.to_thread(
                self.realesrgan_enhancer.enhance, img, outscale=self.upscale)
        else:
            _, _, enhanced_img = await asyncio.to_thread(
                self.face_enhancer.enhance,
                img,
                has_aligned=False,
                only_center_face=False,
                paste_back=True)

        enhanced_img = cv2.cvtColor(enhanced_img, cv2.COLOR_BGR2RGB)
        enhanced_width, enhanced_height = self.check_image_resolution(enhanced_img)
        return enhanced_img, (width, height), (enhanced_width, enhanced_height)