DualStyleGAN / app.py
hysts's picture
hysts HF staff
Add examples
dbbdaef
raw
history blame
14.5 kB
#!/usr/bin/env python
from __future__ import annotations
import argparse
import os
import pathlib
import sys
from typing import Callable
import dlib
import gradio as gr
import huggingface_hub
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torchvision.transforms as T
# On Hugging Face Spaces (SYSTEM == 'spaces'), delete lines 10-17 of the two
# custom-op modules in the vendored DualStyleGAN checkout. This must happen
# before those modules are (presumably indirectly) imported by the `model`
# imports that follow. What lines 10-17 contain is not visible from here.
# NOTE(review): `sed '10,17d'` deletes by line number, so it is not
# idempotent -- running it a second time on the same checkout removes a
# further 8 lines. The exit status returned by os.system is also ignored.
if os.environ.get('SYSTEM') == 'spaces':
    os.system("sed -i '10,17d' DualStyleGAN/model/stylegan/op/fused_act.py")
    os.system("sed -i '10,17d' DualStyleGAN/model/stylegan/op/upfirdn2d.py")
# Put the vendored repo first on the import path so that the `model.*`
# imports that follow resolve to it.
sys.path.insert(0, 'DualStyleGAN')
from model.dualstylegan import DualStyleGAN
from model.encoder.align_all_parallel import align_face
from model.encoder.psp import pSp
# Hugging Face access token passed to every hf_hub_download call. It is read
# with [] rather than .get(), so importing this module raises KeyError when
# the TOKEN environment variable is not set.
TOKEN = os.environ['TOKEN']
# Hub repository the encoder, the per-style generators and the style-code
# tables are downloaded from.
MODEL_REPO = 'hysts/DualStyleGAN'
def parse_args() -> argparse.Namespace:
    """Parse the demo's command-line options from ``sys.argv``.

    Returns:
        A namespace with ``device`` (default ``'cpu'``), ``theme`` (or None),
        ``share`` (bool), ``port`` (int or None) and ``enable_queue``
        (True unless ``--disable-queue`` is passed).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--device', default='cpu', type=str)
    cli.add_argument('--theme', type=str)
    cli.add_argument('--share', action='store_true')
    cli.add_argument('--port', type=int)
    # A negative flag: its presence stores False into `enable_queue`, which
    # otherwise defaults to True.
    cli.add_argument('--disable-queue', action='store_false',
                     dest='enable_queue')
    options = cli.parse_args()
    return options
class App:
    """Holds every model used by the demo and exposes its inference steps.

    Construction downloads all checkpoints from the Hugging Face Hub and
    loads them onto ``device``. This covers the dlib landmark predictor used
    for face alignment and the pSp encoder. It also covers, for each entry of
    ``self.style_types``, one DualStyleGAN generator and one table of
    extrinsic style codes.
    """

    def __init__(self, device: torch.device):
        self.device = device
        self.landmark_model = self._create_dlib_landmark_model()
        self.encoder = self._load_encoder()
        self.transform = self._create_transform()
        self.style_types = [
            'cartoon',
            'caricature',
            'anime',
            'arcane',
            'comic',
            'pixar',
            'slamdunk',
        ]
        # One generator and one style-code table per style type, keyed by name.
        self.generator_dict = {
            style_type: self._load_generator(style_type)
            for style_type in self.style_types
        }
        self.exstyle_dict = {
            style_type: self._load_exstylecode(style_type)
            for style_type in self.style_types
        }

    @staticmethod
    def _create_dlib_landmark_model():
        """Download the dlib 68-point landmark file and build a predictor."""
        path = huggingface_hub.hf_hub_download(
            'hysts/dlib_face_landmark_model',
            'shape_predictor_68_face_landmarks.dat',
            use_auth_token=TOKEN)
        return dlib.shape_predictor(path)

    def _load_encoder(self) -> nn.Module:
        """Build the pSp encoder from its checkpoint, on ``self.device``, in eval mode."""
        ckpt_path = huggingface_hub.hf_hub_download(MODEL_REPO,
                                                    'models/encoder.pt',
                                                    use_auth_token=TOKEN)
        # torch.load unpickles the file: only acceptable because the
        # checkpoint comes from our own model repository.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        # The checkpoint carries the options pSp was built with. Override the
        # device and point it at the local copy of this checkpoint.
        opts = ckpt['opts']
        opts['device'] = self.device.type
        opts['checkpoint_path'] = ckpt_path
        opts = argparse.Namespace(**opts)
        model = pSp(opts)
        model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _create_transform() -> Callable:
        """Return the PIL-image -> tensor preprocessing used before encoding.

        Resizes the shorter side to 256, center-crops to 256x256, converts to
        a tensor and maps values from [0, 1] to [-1, 1].
        """
        transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(256),
            T.ToTensor(),
            T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        return transform

    def _load_generator(self, style_type: str) -> nn.Module:
        """Load the DualStyleGAN generator (``g_ema`` weights) for ``style_type``."""
        model = DualStyleGAN(1024, 512, 8, 2, res_index=6)
        ckpt_path = huggingface_hub.hf_hub_download(
            MODEL_REPO,
            f'models/{style_type}/generator.pt',
            use_auth_token=TOKEN)
        ckpt = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(ckpt['g_ema'])
        model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _load_exstylecode(style_type: str) -> dict[str, np.ndarray]:
        """Load the dict of extrinsic style codes for ``style_type``.

        'cartoon', 'caricature' and 'anime' use the refined code file; the
        other styles use the plain one.
        """
        if style_type in ['cartoon', 'caricature', 'anime']:
            filename = 'refined_exstyle_code.npy'
        else:
            filename = 'exstyle_code.npy'
        path = huggingface_hub.hf_hub_download(
            MODEL_REPO,
            f'models/{style_type}/{filename}',
            use_auth_token=TOKEN)
        # The file stores a pickled dict inside a 0-d object array (hence
        # allow_pickle=True and .item()). Only load such files from a
        # trusted repository.
        exstyles = np.load(path, allow_pickle=True).item()
        return exstyles

    def detect_and_align_face(self, image) -> np.ndarray:
        """Detect the face in an uploaded file and return it aligned.

        Args:
            image: a file-like upload; only its ``.name`` (path on disk) is
                used. None when nothing has been uploaded.

        Raises:
            ValueError: if ``image`` is None.
        """
        if image is None:
            raise ValueError('Please upload an input image first.')
        image = align_face(filepath=image.name, predictor=self.landmark_model)
        return image

    @staticmethod
    def denormalize(tensor: torch.Tensor) -> torch.Tensor:
        """Map values in [-1, 1] to uint8 in [0, 255].

        Out-of-range values are clamped, and fractional values are truncated
        by the uint8 cast.
        """
        return torch.clamp((tensor + 1) / 2 * 255, 0, 255).to(torch.uint8)

    def postprocess(self, tensor: torch.Tensor) -> np.ndarray:
        """Convert a (C, H, W) tensor in [-1, 1] into an (H, W, C) uint8 array."""
        tensor = self.denormalize(tensor)
        return tensor.cpu().numpy().transpose(1, 2, 0)

    @torch.inference_mode()
    def reconstruct_face(self,
                         image: np.ndarray) -> tuple[np.ndarray, torch.Tensor]:
        """Encode an aligned face with pSp and decode it back.

        Args:
            image: aligned face as an array accepted by ``PIL.Image.fromarray``,
                or None if no face has been aligned yet.

        Returns:
            ``(img_rec, instyle)``: the reconstruction as an (H, W, C) uint8
            array, and the encoder's z+ latent code
            (``return_z_plus_latent=True``). ``generate`` consumes the latter
            as its ``instyle`` argument.

        Raises:
            ValueError: if ``image`` is None.
        """
        if image is None:
            raise ValueError(
                'No aligned face yet: run "Detect & Align Face" first.')
        image = PIL.Image.fromarray(image)
        input_data = self.transform(image).unsqueeze(0).to(self.device)
        img_rec, instyle = self.encoder(input_data,
                                        randomize_noise=False,
                                        return_latents=True,
                                        z_plus_latent=True,
                                        return_z_plus_latent=True,
                                        resize=False)
        img_rec = torch.clamp(img_rec.detach(), -1, 1)
        img_rec = self.postprocess(img_rec[0])
        return img_rec, instyle

    @torch.inference_mode()
    def generate(self, style_type: str, style_id: int, structure_weight: float,
                 color_weight: float, structure_only: bool,
                 instyle: torch.Tensor) -> np.ndarray:
        """Re-render the reconstructed face in the style of one style image.

        Args:
            style_type: key into ``self.generator_dict`` and
                ``self.exstyle_dict``.
            style_id: position (in key order) in the style-code table of
                ``style_type``. It is cast with ``int``, so floats are
                accepted.
            structure_weight: weight used for the first 7 entries of
                ``interp_weights``.
            color_weight: weight used for the remaining 11 entries.
            structure_only: if True, entries 7-17 of the selected style code
                are overwritten with the input's own code before styling.
            instyle: latent code returned by ``reconstruct_face``.

        Returns:
            The generated image as an (H, W, C) uint8 array.

        Raises:
            ValueError: if ``instyle`` is None (face not reconstructed yet)
                or ``style_type`` is not a loaded style.
        """
        if instyle is None:
            raise ValueError(
                'No reconstructed face yet: run "Reconstruct Face" first.')
        if style_type not in self.generator_dict:
            raise ValueError(f'Unknown or unselected style type: {style_type!r}')
        generator = self.generator_dict[style_type]
        exstyles = self.exstyle_dict[style_type]
        style_id = int(style_id)
        stylename = list(exstyles.keys())[style_id]
        latent = torch.tensor(exstyles[stylename]).to(self.device)
        if structure_only:
            latent[0, 7:18] = instyle[0, 7:18]
        # Flatten the 3-D code to 2-D for the generator's style network, then
        # restore the original shape.
        exstyle = generator.generator.style(
            latent.reshape(latent.shape[0] * latent.shape[1],
                           latent.shape[2])).reshape(latent.shape)
        img_gen, _ = generator([instyle],
                               exstyle,
                               z_plus_latent=True,
                               truncation=0.7,
                               truncation_latent=0,
                               use_res=True,
                               interp_weights=[structure_weight] * 7 +
                               [color_weight] * 11)
        img_gen = torch.clamp(img_gen.detach(), -1, 1)
        img_gen = self.postprocess(img_gen[0])
        return img_gen
def get_style_image_url(style_name: str) -> str:
    """Return the URL of the overview image listing the styles of ``style_name``.

    Raises:
        KeyError: if ``style_name`` is not one of the seven known style types.
    """
    overview_files = dict(
        cartoon='cartoon_overview.jpg',
        caricature='caricature_overview.jpg',
        anime='anime_overview.jpg',
        arcane='Reconstruction_arcane_overview.jpg',
        comic='Reconstruction_comic_overview.jpg',
        pixar='Reconstruction_pixar_overview.jpg',
        slamdunk='Reconstruction_slamdunk_overview.jpg',
    )
    chosen = overview_files[style_name]
    base_url = 'https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images'
    return '/'.join((base_url, chosen))
def get_style_image_markdown_text(style_name: str) -> str:
    """Return centered HTML embedding the overview image for ``style_name``."""
    template = '<center><img id="style-image" src="{}" alt="style image"></center>'
    return template.format(get_style_image_url(style_name))
def update_slider(choice: str) -> dict:
    """Return a Slider update fitted to the style type ``choice``.

    The slider's maximum is set to the per-style upper bound below, and its
    value is reset to 26.

    Raises:
        KeyError: if ``choice`` is not one of the seven known style types.
    """
    upper_bound = dict(
        cartoon=316,
        caricature=198,
        anime=173,
        arcane=99,
        comic=100,
        pixar=121,
        slamdunk=119,
    )[choice]
    return gr.Slider.update(maximum=upper_bound, value=26)
def update_style_image(style_name: str) -> dict:
    """Return a Markdown update that shows the overview image of ``style_name``."""
    return gr.Markdown.update(value=get_style_image_markdown_text(style_name))
def set_example(example: list) -> list[dict]:
    """Turn one Examples row into value updates for the six input widgets.

    Item ``i`` of ``example`` becomes the value of the ``i``-th widget, in
    this order: image, radio, slider, slider, slider, checkbox.
    """
    widget_types = (gr.Image, gr.Radio, gr.Slider, gr.Slider, gr.Slider,
                    gr.Checkbox)
    return [
        widget.update(value=example[position])
        for position, widget in enumerate(widget_types)
    ]
def main():
    """Parse CLI options, build the Gradio Blocks UI around ``App`` and launch it."""
    args = parse_args()
    app = App(device=torch.device(args.device))

    # The style image shown initially, the slider range (0-316) and the
    # example rows all assume 'cartoon'. The Style Type radio must therefore
    # start there too; without a default it stays None, and clicking Generate
    # before picking a style crashed.
    default_style_type = 'cartoon'

    css = '''
h1#title {
text-align: center;
}
img#overview {
max-width: 800px;
max-height: 600px;
}
img#style-image {
max-width: 1000px;
max-height: 600px;
}
'''

    with gr.Blocks(theme=args.theme, css=css) as demo:
        gr.Markdown(
            '''<h1 id="title">Portrait Style Transfer with DualStyleGAN</h1>
This is an unofficial demo app for [https://github.com/williamyang1991/DualStyleGAN](https://github.com/williamyang1991/DualStyleGAN).
<center><img id="overview" src="https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/overview.jpg" alt="overview"></center>
''')

        with gr.Box():
            gr.Markdown('''## Step 1 (Preprocess Input Image)
- Drop an image containing a near-frontal face to the **Input Image**.
- If there are multiple faces in the image, hit the Edit button in the upper right corner and crop the input image beforehand.
- You can also load example inputs from the **Examples** section at the bottom of this page.
- Hit the **Detect & Align Face** button.
- Hit the **Reconstruct Face** button.
- The final result will be based on this **Reconstructed Face**. So, if the reconstructed image is not satisfactory, you may want to change the input image.
''')
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        input_image = gr.Image(label='Input Image',
                                               type='file')
                    with gr.Row():
                        detect_button = gr.Button('Detect & Align Face')
                with gr.Column():
                    with gr.Row():
                        face_image = gr.Image(label='Aligned Face',
                                              type='numpy')
                    with gr.Row():
                        reconstruct_button = gr.Button('Reconstruct Face')
                with gr.Column():
                    reconstructed_face = gr.Image(label='Reconstructed Face',
                                                  type='numpy')
                    # Hidden state: latent code passed from Step 1 to Step 3.
                    instyle = gr.Variable()

        with gr.Box():
            gr.Markdown('''## Step 2 (Select Style Image)
- Select **Style Type**.
- Select **Style Image Index** from the image table below.
''')
            with gr.Row():
                with gr.Column():
                    style_type = gr.Radio(app.style_types,
                                          value=default_style_type,
                                          label='Style Type')
                    text = get_style_image_markdown_text(default_style_type)
                    style_image = gr.Markdown(value=text)
                    style_index = gr.Slider(0,
                                            316,
                                            value=26,
                                            step=1,
                                            label='Style Image Index',
                                            interactive=True)

        with gr.Box():
            gr.Markdown('''## Step 3 (Generate Style Transferred Image)
- Adjust **Structure Weight** and **Color Weight**.
- These are weights for the style image, so the larger the value, the closer the resulting image will be to the style image.
- Hit the **Generate** button.
''')
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        structure_weight = gr.Slider(0,
                                                     1,
                                                     value=0.6,
                                                     step=0.1,
                                                     label='Structure Weight')
                    with gr.Row():
                        color_weight = gr.Slider(0,
                                                 1,
                                                 value=1,
                                                 step=0.1,
                                                 label='Color Weight')
                    with gr.Row():
                        structure_only = gr.Checkbox(label='Structure Only')
                    with gr.Row():
                        generate_button = gr.Button('Generate')
                with gr.Column():
                    output_image = gr.Image(label='Output Image')

        with gr.Box():
            gr.Markdown('## Examples')
            paths = sorted(pathlib.Path('images').glob('*.jpg'))
            samples = [[path.as_posix(), default_style_type, 26, 0.6, 1.0, False]
                       for path in paths]
            examples = gr.Dataset(components=[
                input_image, style_type, style_index, structure_weight,
                color_weight, structure_only
            ],
                                  samples=samples)

        gr.Markdown(
            'Related App: [https://huggingface.co/spaces/hysts/DualStyleGAN](https://huggingface.co/spaces/hysts/DualStyleGAN)'
        )
        gr.Markdown(
            '<center><img src="https://visitor-badge.glitch.me/badge?page_id=gradio-blocks.dualstylegan" alt="visitor badge"/></center>'
        )

        examples.click(fn=set_example,
                       inputs=examples,
                       outputs=examples.components)
        detect_button.click(fn=app.detect_and_align_face,
                            inputs=input_image,
                            outputs=face_image)
        reconstruct_button.click(fn=app.reconstruct_face,
                                 inputs=face_image,
                                 outputs=[reconstructed_face, instyle])
        style_type.change(fn=update_slider,
                          inputs=style_type,
                          outputs=style_index)
        style_type.change(fn=update_style_image,
                          inputs=style_type,
                          outputs=style_image)
        generate_button.click(fn=app.generate,
                              inputs=[
                                  style_type,
                                  style_index,
                                  structure_weight,
                                  color_weight,
                                  structure_only,
                                  instyle,
                              ],
                              outputs=output_image)

    demo.launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )
# Build and launch the demo only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()