#!/usr/bin/env python from __future__ import annotations import argparse import os import sys from typing import Callable import dlib import gradio as gr import huggingface_hub import numpy as np import PIL.Image import torch import torch.nn as nn import torchvision.transforms as T if os.environ.get('SYSTEM') == 'spaces': os.system("sed -i '10,17d' DualStyleGAN/model/stylegan/op/fused_act.py") os.system("sed -i '10,17d' DualStyleGAN/model/stylegan/op/upfirdn2d.py") sys.path.insert(0, 'DualStyleGAN') from model.dualstylegan import DualStyleGAN from model.encoder.align_all_parallel import align_face from model.encoder.psp import pSp STYLE_IMAGE_PATHS = { 'cartoon': 'https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/cartoon_overview.jpg', 'caricature': 'https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/caricature_overview.jpg', 'anime': 'https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/anime_overview.jpg', 'arcane': 'https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/Reconstruction_arcane_overview.jpg', 'comic': 'https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/Reconstruction_comic_overview.jpg', 'pixar': 'https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/Reconstruction_pixar_overview.jpg', 'slamdunk': 'https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/Reconstruction_slamdunk_overview.jpg', } TOKEN = os.environ['TOKEN'] MODEL_REPO = 'hysts/DualStyleGAN' def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cpu') parser.add_argument('--theme', type=str) parser.add_argument('--live', action='store_true') parser.add_argument('--share', action='store_true') parser.add_argument('--port', type=int) parser.add_argument('--disable-queue', dest='enable_queue', action='store_false') parser.add_argument('--allow-flagging', type=str, default='never') return parser.parse_args() class App: def __init__(self, device: torch.device): self.device = device self.face_detector = self._create_dlib_landmark_model() self.encoder = self._load_encoder() self.transform = self._create_transform() self.style_types = [ 'cartoon', 'caricature', 'anime', 'arcane', 'comic', 'pixar', 'slamdunk', ] self.generator_dict = { style_type: self._load_generator(style_type) for style_type in self.style_types } self.exstyle_dict = { style_type: self._load_exstylecode(style_type) for style_type in self.style_types } @staticmethod def _create_dlib_landmark_model(): path = huggingface_hub.hf_hub_download( 'hysts/dlib_face_landmark_model', 'shape_predictor_68_face_landmarks.dat', use_auth_token=TOKEN) return dlib.shape_predictor(path) def _load_encoder(self) -> nn.Module: ckpt_path = huggingface_hub.hf_hub_download(MODEL_REPO, 'models/encoder.pt', use_auth_token=TOKEN) ckpt = torch.load(ckpt_path, map_location='cpu') opts = ckpt['opts'] opts['device'] = self.device.type opts['checkpoint_path'] = ckpt_path opts = argparse.Namespace(**opts) model = pSp(opts) model.to(self.device) model.eval() return model @staticmethod def _create_transform() -> Callable: transform = T.Compose([ T.Resize(256), T.CenterCrop(256), T.ToTensor(), T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), ]) return transform def _load_generator(self, style_type: str) -> nn.Module: model = DualStyleGAN(1024, 512, 8, 2, res_index=6) ckpt_path = huggingface_hub.hf_hub_download( MODEL_REPO, f'models/{style_type}/generator.pt', use_auth_token=TOKEN) ckpt = torch.load(ckpt_path, map_location='cpu') model.load_state_dict(ckpt['g_ema']) model.to(self.device) model.eval() return model @staticmethod def _load_exstylecode(style_type: str) -> dict[str, np.ndarray]: if style_type in ['cartoon', 'caricature', 'anime']: filename = 'refined_exstyle_code.npy' else: filename = 'exstyle_code.npy' path = huggingface_hub.hf_hub_download( MODEL_REPO, f'models/{style_type}/{filename}', use_auth_token=TOKEN) exstyles = np.load(path, allow_pickle=True).item() return exstyles def detect_and_align_face(self, image) -> np.ndarray: image = align_face(filepath=image.name, predictor=self.face_detector) return image @staticmethod def denormalize(tensor: torch.Tensor) -> torch.Tensor: return torch.clamp((tensor + 1) / 2 * 255, 0, 255).to(torch.uint8) def postprocess(self, tensor: torch.Tensor) -> np.ndarray: tensor = self.denormalize(tensor) return tensor.cpu().numpy().transpose(1, 2, 0) @torch.inference_mode() def reconstruct_face(self, image: np.ndarray) -> tuple[np.ndarray, torch.Tensor]: image = PIL.Image.fromarray(image) input_data = self.transform(image).unsqueeze(0).to(self.device) img_rec, instyle = self.encoder(input_data, randomize_noise=False, return_latents=True, z_plus_latent=True, return_z_plus_latent=True, resize=False) img_rec = torch.clamp(img_rec.detach(), -1, 1) img_rec = self.postprocess(img_rec[0]) return img_rec, instyle @torch.inference_mode() def generate(self, style_type: str, style_id: int, structure_weight: float, color_weight: float, structure_only: bool, instyle: torch.Tensor) -> np.ndarray: generator = self.generator_dict[style_type] exstyles = self.exstyle_dict[style_type] style_id = int(style_id) stylename = list(exstyles.keys())[style_id] latent = torch.tensor(exstyles[stylename]).to(self.device) if structure_only: latent[0, 7:18] = instyle[0, 7:18] exstyle = generator.generator.style( latent.reshape(latent.shape[0] * latent.shape[1], latent.shape[2])).reshape(latent.shape) img_gen, _ = generator([instyle], exstyle, z_plus_latent=True, truncation=0.7, truncation_latent=0, use_res=True, interp_weights=[structure_weight] * 7 + [color_weight] * 11) img_gen = torch.clamp(img_gen.detach(), -1, 1) img_gen = self.postprocess(img_gen[0]) return img_gen def update_slider(choice: str): max_vals = { 'cartoon': 316, 'caricature': 198, 'anime': 173, 'arcane': 99, 'comic': 100, 'pixar': 121, 'slamdunk': 119, } return gr.Slider.update(maximum=max_vals[choice] + 1, value=26) def update_style_image(choice: str): style_image_path = STYLE_IMAGE_PATHS[choice] text = f'
style image
' return gr.Markdown.update(value=text) def main(): args = parse_args() app = App(device=torch.device(args.device)) with gr.Blocks(theme=args.theme) as demo: gr.Markdown( '''

Portrait Style Transfer with DualStyleGAN

This is an unofficial demo app for https://github.com/williamyang1991/DualStyleGAN.
overview
Related App: https://huggingface.co/spaces/hysts/DualStyleGAN ''') with gr.Box(): gr.Markdown('''## Step 1 - Drop an image containing a near-frontal face to the **Input Image**. - If there are multiple faces in the image, hit the Edit button in the upper right corner and crop the input image beforehand. - Hit the **Detect & Align** button. - Hit the **Reconstruct Face** button. - The final result will be based on this **Reconstructed Face**. So, if the reconstructed image is not satisfactory, you may want to change the input image. ''') with gr.Row(): with gr.Column(): with gr.Row(): input_image = gr.Image(label='Input Image', type='file') with gr.Row(): detect_button = gr.Button('Detect & Align Face') with gr.Column(): with gr.Row(): face_image = gr.Image(label='Aligned Face', type='numpy') with gr.Row(): reconstruct_button = gr.Button('Reconstruct Face') with gr.Column(): reconstructed_face = gr.Image(label='Reconstructed Face', type='numpy') instyle = gr.Variable() with gr.Box(): gr.Markdown('''## Step 2 - Select **Style Type**. - Select **Style Image Index** from the image table below. ''') with gr.Row(): with gr.Column(): with gr.Column(): style_type = gr.Radio(app.style_types, label='Style Type') with gr.Column(): style_index = gr.Slider(0, 317, value=26, step=1, label='Style Image Index', interactive=True) style_image_path = STYLE_IMAGE_PATHS['cartoon'] text = f'
style image
' style_image = gr.Markdown(value=text) with gr.Box(): gr.Markdown('''## Step 3 - Adjust **Structure Weight** and **Color Weight**. - These are weights for the style image, so the larger the value, the closer the resulting image will be to the style image. - Hit the **Generate** button. ''') with gr.Row(): with gr.Column(): with gr.Row(): structure_weight = gr.Slider(0, 1, value=0.6, step=0.1, label='Structure Weight') with gr.Row(): color_weight = gr.Slider(0, 1, value=1, step=0.1, label='Color Weight') with gr.Row(): structure_only = gr.Checkbox(label='Structure Only') with gr.Row(): generate_button = gr.Button('Generate') with gr.Column(): output_image = gr.Image(label='Output Image') gr.Markdown( '
visitor badge
' ) detect_button.click(fn=app.detect_and_align_face, inputs=input_image, outputs=face_image) reconstruct_button.click(fn=app.reconstruct_face, inputs=face_image, outputs=[reconstructed_face, instyle]) style_type.change(fn=update_slider, inputs=style_type, outputs=style_index) style_type.change(fn=update_style_image, inputs=style_type, outputs=style_image) generate_button.click(fn=app.generate, inputs=[ style_type, style_index, structure_weight, color_weight, structure_only, instyle, ], outputs=output_image) demo.launch( enable_queue=args.enable_queue, server_port=args.port, share=args.share, ) if __name__ == '__main__': main()