"""Gradio demo application for GPEN face restoration.

@paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
@author: yangxy (yangtao9009@gmail.com)

On start-up the script makes sure the model weights are present under
``weights/`` (downloading them with ``wget`` when they are not), then builds
and launches a Gradio interface around :func:`inference`.
"""
import glob
import os
import subprocess

_WEIGHTS_DIR = 'weights'
# The download step is skipped only when exactly this many .pth files exist.
_EXPECTED_WEIGHT_COUNT = 6

# (pre-signed download URL, destination path) for the publicly hosted weights.
_PUBLIC_WEIGHTS = (
    (
        'https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/RetinaFace-R50.pth?OSSAccessKeyId=LTAI4G6bfnyW4TA4wFUXTYBe&Expires=1961116085&Signature=GlUNW6%2B8FxvxWmE9jKIZYOOciKQ%3D',
        'weights/RetinaFace-R50.pth',
    ),
    (
        'https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-512.pth?OSSAccessKeyId=LTAI4G6bfnyW4TA4wFUXTYBe&Expires=1961116208&Signature=hBgvVvKVSNGeXqT8glG%2Bd2t2OKc%3D',
        'weights/GPEN-512.pth',
    ),
    (
        'https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Colorization-1024.pth?OSSAccessKeyId=LTAI4G6bfnyW4TA4wFUXTYBe&Expires=1961116315&Signature=9tPavW2h%2F1LhIKiXj73sTQoWqcc%3D',
        'weights/GPEN-1024-Color.pth',
    ),
    (
        'https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/realesrnet_x2.pth?OSSAccessKeyId=LTAI4G6bfnyW4TA4wFUXTYBe&Expires=1962694780&Signature=lI%2FolhA%2FyigiTRvoDIVbtMIyhjI%3D',
        'weights/realesrnet_x2.pth',
    ),
    (
        'https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Inpainting-1024.pth?OSSAccessKeyId=LTAI4G6bfnyW4TA4wFUXTYBe&Expires=1961116338&Signature=tvYhdLaLgW7UdcUrApXp2jsek8w%3D',
        'weights/GPEN-Inpainting-1024.pth',
    ),
)


def _fetch(url, destination):
    """Download ``url`` to ``destination`` with ``wget -q``.

    Best effort, like the ``os.system`` calls this replaces: a failed or
    missing ``wget`` is reported on stdout but never raises. The command is
    passed as an argument list, so the URL is not interpreted by a shell.
    """
    try:
        result = subprocess.run(['wget', '-q', url, '-O', destination], check=False)
    except OSError as err:
        print(f'warning: could not run wget for {destination}: {err}')
        return
    if result.returncode != 0:
        print(f'warning: wget exited with status {result.returncode} for {destination}')


def _ensure_weights():
    """Download every model weight unless exactly six ``.pth`` files exist.

    Raises:
        KeyError: if the ``GPEN-BFR-2048`` environment variable, which holds
            the URL of the 2048 model, is not set.
    """
    if len(glob.glob('weights/*.pth')) == _EXPECTED_WEIGHT_COUNT:
        return
    # Read the URL first so a missing variable fails before any large download.
    bfr_2048_url = os.environ['GPEN-BFR-2048']
    # wget -O does not create missing parent directories.
    os.makedirs(_WEIGHTS_DIR, exist_ok=True)
    for url, destination in _PUBLIC_WEIGHTS:
        _fetch(url, destination)
    _fetch(bfr_2048_url, 'weights/GPEN-BFR-2048.pth')


_ensure_weights()

# Third-party and project imports stay after the download step, preserving the
# original start-up order of the script.
import cv2
import gradio as gr

from face_enhancement import FaceEnhancement
from face_colorization import FaceColorization
from face_inpainting import FaceInpainting


def inference(file, mode):
    """Process the image at ``file`` according to ``mode`` and return the result path.

    Args:
        file: Path of the input image on disk.
        mode: One of ``"enhance"``, ``"colorize"``, ``"inpainting"`` or
            ``"selfie"``; any other value (the UI offers
            ``"enhanced+background"``) selects the fallback branch.

    Returns:
        ``"e.png"`` for ``"enhance"``, otherwise ``"output.png"``; both are
        written to the current working directory.

    Raises:
        ValueError: if no file was supplied or OpenCV cannot decode it.
    """
    if file is None:
        raise ValueError('No input image was provided.')
    im_orig = cv2.imread(file, cv2.IMREAD_COLOR)
    if im_orig is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f'Could not read an image from {file!r}.')

    # Every enhancement branch works on a 2x upscaled copy of the input.
    im = cv2.resize(im_orig, (0, 0), fx=2, fy=2)

    if mode == "enhance":
        # The 512 model only runs here; its output is not used by any other mode.
        faceenhancer = FaceEnhancement(size=512, model='GPEN-512', channel_multiplier=2, device='cpu', u=False)
        img, orig_faces, enhanced_faces = faceenhancer.process(im)
        cv2.imwrite("e.png", img)
        return "e.png"

    if mode == "colorize":
        model = {'name': 'GPEN-1024-Color', 'size': 1024}
        grayf = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
        grayf = cv2.cvtColor(grayf, cv2.COLOR_GRAY2BGR)  # channel: 1->3
        facecolorizer = FaceColorization(size=model['size'], model=model['name'], channel_multiplier=2, device='cpu')
        colorf = facecolorizer.process(grayf)
        # Bring the colorized result back to the input resolution.
        colorf = cv2.resize(colorf, (grayf.shape[1], grayf.shape[0]))
        cv2.imwrite("output.png", colorf)
        return "output.png"

    if mode == "inpainting":
        model = {'name': 'GPEN-Inpainting-1024', 'size': 1024}
        faceinpainter = FaceInpainting(size=model['size'], model=model['name'], channel_multiplier=2, device='cpu')
        inpaint = faceinpainter.process(im_orig)
        cv2.imwrite("output.png", inpaint)
        return "output.png"

    if mode == "selfie":
        model = {'name': 'GPEN-BFR-2048', 'size': 2048}
        # A further 4x on top of the 2x copy, i.e. 8x the input size.
        im = cv2.resize(im, (0, 0), fx=4, fy=4)
        faceenhancer = FaceEnhancement(size=model['size'], model=model['name'], channel_multiplier=2, device='cpu')
        img, orig_faces, enhanced_faces = faceenhancer.process(im)
        cv2.imwrite("output.png", img)
        return "output.png"

    # Fallback ("enhanced+background" in the UI): same 512 model with u=True.
    # NOTE(review): u presumably toggles background upsampling - confirm
    # against FaceEnhancement.
    faceenhancer = FaceEnhancement(size=512, model='GPEN-512', channel_multiplier=2, device='cpu', u=True)
    img, orig_faces, enhanced_faces = faceenhancer.process(im)
    cv2.imwrite("output.png", img)
    return "output.png"


title = "GPEN"
description = "Gradio demo for GAN Prior Embedded Network for Blind Face Restoration in the Wild. This version of gradio demo includes face colorization from GPEN. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
# NOTE(review): this text looks like it once carried HTML links (paper and
# repository) that have been lost - restore the markup if it is available.
article = "\nGAN Prior Embedded Network for Blind Face Restoration in the Wild | Github Repo\n"

gr.Interface(
    inference,
    [
        gr.inputs.Image(type="filepath", label="Input"),
        gr.inputs.Radio(
            ["enhance", "colorize", "inpainting", "selfie", "enhanced+background"],
            type="value",
            default="enhance",
            label="Type",
        ),
    ],
    gr.outputs.Image(type="file", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[
        ['enhance.png', 'enhance'],
        ['color.png', 'colorize'],
        ['inpainting.png', 'inpainting'],
        ['selfie.png', 'selfie'],
    ],
    enable_queue=True,
).launch()