File size: 2,531 Bytes
7eb697d
 
c4bae44
7eb697d
 
 
 
 
 
 
c4bae44
7eb697d
88e25ba
cb57f88
88e25ba
cb57f88
88e25ba
 
7eb697d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88e25ba
 
c4bae44
 
88e25ba
 
 
 
 
 
 
 
 
 
 
 
 
cb57f88
88e25ba
7eb697d
 
 
88e25ba
7eb697d
 
 
 
 
 
 
 
 
 
cb57f88
7eb697d
 
88e25ba
 
858cd5e
88e25ba
 
 
 
 
858cd5e
88e25ba
858cd5e
c4bae44
88e25ba
 
 
c4bae44
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#!/usr/bin/env python

import pathlib
import pickle
import sys

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from torch import nn

# Put the local "stylegan3" directory first on the import path. Presumably this
# is needed so pickle.load can resolve the classes referenced by the checkpoint
# (TODO confirm against the checkpoint's contents).
sys.path.insert(0, "stylegan3")

# Hugging Face Hub repository and file name of the pickled generator checkpoint.
MODEL_REPO = "hysts/stylegan3-anime-face-exp002-model"
MODEL_FILE_NAME = "009000.pkl"

# Title displayed by the Gradio interface.
TITLE = "StyleGAN3 Anime Face Generation"


def make_transform(translate: tuple[float, float], angle: float) -> np.ndarray:
    """Build a 3x3 homogeneous 2-D transform from a rotation and a translation.

    Args:
        translate: Sequence whose items ``[0]`` and ``[1]`` become the x/y offsets
            in the last column of the first two rows.
        angle: Rotation angle in degrees.

    Returns:
        A float64 ``(3, 3)`` array laid out as
        ``[[cos, sin, tx], [-sin, cos, ty], [0, 0, 1]]``.
    """
    # Degrees -> radians, written with the same arithmetic as before so results
    # are bit-for-bit identical.
    theta = angle / 360 * np.pi * 2
    c, s = np.cos(theta), np.sin(theta)
    # Index rather than unpack: any sequence with at least two items is accepted.
    tx, ty = translate[0], translate[1]
    return np.array(
        [
            [c, s, tx],
            [-s, c, ty],
            [0.0, 0.0, 1.0],
        ],
        dtype=np.float64,
    )


def load_model(device: torch.device) -> nn.Module:
    """Download the pickled generator, move it to *device*, and warm it up.

    Args:
        device: Device the model is moved to; the warm-up inputs are placed there too.

    Returns:
        The unpickled model, in eval mode, after one forward pass on dummy inputs.
    """
    checkpoint = pathlib.Path(hf_hub_download(MODEL_REPO, MODEL_FILE_NAME))
    with checkpoint.open("rb") as fh:
        # NOTE(review): unpickling can execute arbitrary code; this trusts the
        # fixed MODEL_REPO checkpoint.
        generator = pickle.load(fh)  # noqa: S301
    generator.eval()
    generator.to(device)

    # One forward pass with a zero (1, 512) latent and an empty conditioning
    # tensor. Presumably this triggers one-time initialization before the first
    # real request — TODO confirm.
    with torch.inference_mode():
        dummy_z = torch.zeros((1, 512)).to(device)
        dummy_c = torch.zeros(0).to(device)
        generator(dummy_z, dummy_c)
    return generator


# Use the first CUDA device when one is available; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Loaded once at import time and shared by every generate_image call.
model = load_model(device)


def generate_z(seed: int, device: torch.device) -> torch.Tensor:
    """Return a deterministic ``(1, 512)`` standard-normal latent for *seed*.

    Uses NumPy's legacy ``RandomState``, so a given seed always yields the same
    values. The result is a float64 tensor moved to *device*.
    """
    rng = np.random.RandomState(seed)
    latent = rng.randn(1, 512)
    return torch.from_numpy(latent).to(device)


@torch.inference_mode()
def generate_image(seed: int, truncation_psi: float, tx: float, ty: float, angle: float) -> np.ndarray:
    """Render one image with the module-level generator.

    Args:
        seed: Latent seed. It is clipped to ``[0, uint32 max]`` and converted to ``int``.
        truncation_psi: Forwarded unchanged to the model call.
        tx: X translation passed to ``make_transform``.
        ty: Y translation passed to ``make_transform``.
        angle: Rotation in degrees passed to ``make_transform``.

    Returns:
        The first image of the batch as a ``uint8`` NumPy array, permuted to
        channels-last (the model output is presumably NCHW — TODO confirm).
    """
    upper = np.iinfo(np.uint32).max
    seed = int(np.clip(seed, 0, upper))
    z = generate_z(seed, device)
    conditioning = torch.zeros(0).to(device)

    # Write the inverse of the requested transform into the synthesis input's
    # transform buffer. This mutates the shared global model in place, so the
    # setting persists until the next call overwrites it.
    inverse = np.linalg.inv(make_transform((tx, ty), angle))
    model.synthesis.input.transform.copy_(torch.from_numpy(inverse))

    images = model(z, conditioning, truncation_psi=truncation_psi)
    # Rescale to pixel values, assuming outputs roughly in [-1, 1]
    # (TODO confirm), then clamp and quantize.
    pixels = images.permute(0, 2, 3, 1) * 127.5 + 128
    return pixels.clamp(0, 255).to(torch.uint8)[0].cpu().numpy()


# Gradio UI: five sliders feed generate_image positionally (seed, psi, tx, ty, angle).
demo = gr.Interface(
    fn=generate_image,
    title=TITLE,
    css_paths="style.css",
    inputs=[
        # Upper bound matches the clip range applied to the seed in generate_image.
        gr.Slider(value=3407851645, minimum=0, maximum=np.iinfo(np.uint32).max, step=1, label="Seed"),
        gr.Slider(value=0.7, minimum=0, maximum=2, step=0.05, label="Truncation psi"),
        gr.Slider(value=0, minimum=-1, maximum=1, step=0.05, label="Translate X"),
        gr.Slider(value=0, minimum=-1, maximum=1, step=0.05, label="Translate Y"),
        # Degrees; forwarded to make_transform as the rotation angle.
        gr.Slider(value=0, minimum=-180, maximum=180, step=5, label="Angle"),
    ],
    outputs=gr.Image(label="Output"),
)

if __name__ == "__main__":
    # Enable request queueing (at most 20 pending requests), then start the server.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()