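# Page residue preserved from the original listing: "Spaces: Sleeping Sleeping"
# (appears to be a Hugging Face Spaces status banner rather than code).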
import asyncio
import os
import warnings
from enum import Enum

import torch
from gfpgan import GFPGANer
from tqdm import tqdm
import cv2
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
class EnhancementMethod(str, Enum):
    """Enhancement back-ends understood by ``Enhancer``.

    Because members mix in ``str``, each one compares equal to its plain
    string value and can be looked up from it, e.g.
    ``EnhancementMethod("gfpgan") is EnhancementMethod.gfpgan``.
    """

    # Face-restoration methods (handled through a GFPGANer instance).
    gfpgan = "gfpgan"
    RestoreFormer = "RestoreFormer"
    codeformer = "codeformer"

    # Whole-image Real-ESRGAN super-resolution, no face restoration.
    realesrgan = "realesrgan"
class Enhancer:
    """Image enhancer backed by GFPGAN-style face restoration or Real-ESRGAN.

    All models are built once in ``__init__``. :meth:`enhance` runs the
    selected model in a worker thread (``asyncio.to_thread``) so it can be
    awaited without blocking the event loop.
    """

    # Real-ESRGAN checkpoint shared by the background upsampler and the
    # standalone ``realesrgan`` method. These are *x2plus* weights, so the
    # RRDBNet must be built with scale=2, and RealESRGANer must be told its
    # network scale is 2, whatever output scale the caller asks for. The
    # requested scale is applied through ``outscale`` at enhancement time
    # instead.
    _REALESRGAN_URL = 'https://huggingface.co/dtarnow/UPscaler/resolve/main/RealESRGAN_x2plus.pth'
    _REALESRGAN_NETSCALE = 2

    # GFPGANer settings per face-restoration method. ``model_name`` is also
    # the checkpoint file stem that is searched for locally before falling
    # back to downloading ``url``.
    # NOTE(review): upstream gfpgan's GFPGANer is believed to accept only the
    # arch values 'clean', 'bilinear', 'original' and 'RestoreFormer'. Confirm
    # that the installed build really supports 'CodeFormer'.
    _FACE_MODEL_CONFIGS = {
        EnhancementMethod.gfpgan: {
            'arch': 'clean',
            'channel_multiplier': 2,
            'model_name': 'GFPGANv1.4',
            'url': 'https://huggingface.co/gmk123/GFPGAN/resolve/main/GFPGANv1.4.pth',
        },
        EnhancementMethod.RestoreFormer: {
            'arch': 'RestoreFormer',
            'channel_multiplier': 2,
            'model_name': 'RestoreFormer',
            'url': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth',
        },
        EnhancementMethod.codeformer: {
            'arch': 'CodeFormer',
            'channel_multiplier': 2,
            'model_name': 'CodeFormer',
            'url': 'https://huggingface.co/sinadi/aar/resolve/main/codeformer.pth',
        },
    }

    # Local directories searched (in order) for a face-model checkpoint.
    _LOCAL_WEIGHT_DIRS = ('gfpgan/weights', 'checkpoints')

    def __init__(self, method: EnhancementMethod, background_enhancement=True, upscale=2):
        """Build the model(s) required for ``method``.

        Args:
            method: Back-end to use. ``realesrgan`` builds only a Real-ESRGAN
                upscaler. Every other method builds a GFPGANer face enhancer.
            background_enhancement: For face methods, also upscale the
                non-face background with Real-ESRGAN. This needs CUDA; on CPU
                it is skipped with a warning.
            upscale: Output scale factor requested from the enhancers.

        Raises:
            ValueError: If ``method`` has no face-model configuration, or if
                ``realesrgan`` is requested while CUDA is unavailable.
        """
        self.method = method
        self.background_enhancement = background_enhancement
        self.upscale = upscale
        self.bg_upsampler = None
        self.realesrgan_enhancer = None
        self.face_enhancer = None
        if self.method == EnhancementMethod.realesrgan:
            self.setup_realesrgan_enhancer()
        else:
            # The background upsampler is passed to GFPGANer's constructor, so
            # it must exist *before* the face enhancer is built. Building it
            # afterwards would leave GFPGANer holding bg_upsampler=None.
            if self.background_enhancement:
                self.setup_background_enhancer()
            self.setup_face_enhancer()

    def _build_realesrgan(self):
        """Return a tiled, half-precision RealESRGANer using the x2plus weights."""
        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23,
                        num_grow_ch=32, scale=self._REALESRGAN_NETSCALE)
        return RealESRGANer(
            scale=self._REALESRGAN_NETSCALE,
            model_path=self._REALESRGAN_URL,
            model=model,
            tile=400,
            tile_pad=10,
            pre_pad=0,
            half=True)

    def setup_background_enhancer(self):
        """Create ``self.bg_upsampler``, or warn and leave it None without CUDA."""
        if not torch.cuda.is_available():
            warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it.')
            return
        self.bg_upsampler = self._build_realesrgan()

    def setup_realesrgan_enhancer(self):
        """Create ``self.realesrgan_enhancer`` for the standalone Real-ESRGAN method.

        Raises:
            ValueError: If CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise ValueError('CUDA is not available for RealESRGAN')
        self.realesrgan_enhancer = self._build_realesrgan()

    def setup_face_enhancer(self):
        """Create ``self.face_enhancer`` (a GFPGANer) for the selected face method.

        A local checkpoint is preferred (first match among
        ``_LOCAL_WEIGHT_DIRS``). Otherwise the configured URL is handed to
        GFPGANer. The current ``self.bg_upsampler`` (possibly None) is
        attached as its background upsampler.

        Raises:
            ValueError: If ``self.method`` has no face-model configuration.
        """
        config = self._FACE_MODEL_CONFIGS.get(self.method)
        if not config:
            raise ValueError(f'Wrong model version {self.method}')
        filename = config['model_name'] + '.pth'
        for folder in self._LOCAL_WEIGHT_DIRS:
            candidate = os.path.join(folder, filename)
            if os.path.isfile(candidate):
                model_path = candidate
                break
        else:
            model_path = config['url']
        self.face_enhancer = GFPGANer(
            model_path=model_path,
            upscale=self.upscale,
            arch=config['arch'],
            channel_multiplier=config['channel_multiplier'],
            bg_upsampler=self.bg_upsampler)

    def check_image_resolution(self, image):
        """Return ``(width, height)`` of an array whose first two axes are rows, columns."""
        height, width = image.shape[:2]
        return width, height

    async def enhance(self, image):
        """Enhance ``image`` with the configured back-end, off the event loop.

        Args:
            image: Image array in RGB channel order. It is converted to BGR
                for the OpenCV-based enhancers and back to RGB afterwards.
                Presumably an HxWx3 uint8 array; confirm with callers.

        Returns:
            ``(enhanced_image, (width, height), (enhanced_width, enhanced_height))``
            where ``enhanced_image`` is in RGB order.
        """
        img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        width, height = self.check_image_resolution(img)
        if self.method == EnhancementMethod.realesrgan:
            enhanced_img, _ = await asyncio.to_thread(
                self.realesrgan_enhancer.enhance, img, outscale=self.upscale)
        else:
            # GFPGANer.enhance yields a 3-tuple; only the last element (the
            # full image with restored faces pasted back) is needed here.
            _, _, enhanced_img = await asyncio.to_thread(
                self.face_enhancer.enhance,
                img,
                has_aligned=False,
                only_center_face=False,
                paste_back=True)
        enhanced_img = cv2.cvtColor(enhanced_img, cv2.COLOR_BGR2RGB)
        enhanced_width, enhanced_height = self.check_image_resolution(enhanced_img)
        return enhanced_img, (width, height), (enhanced_width, enhanced_height)