Spaces:
Runtime error
Runtime error
File size: 2,531 Bytes
7eb697d c4bae44 7eb697d c4bae44 7eb697d 88e25ba cb57f88 88e25ba cb57f88 88e25ba 7eb697d 88e25ba c4bae44 88e25ba cb57f88 88e25ba 7eb697d 88e25ba 7eb697d cb57f88 7eb697d 88e25ba 858cd5e 88e25ba 858cd5e 88e25ba 858cd5e c4bae44 88e25ba c4bae44 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
#!/usr/bin/env python
import pathlib
import pickle
import sys
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from torch import nn
# Put the local "stylegan3" directory first on the module search path.
# NOTE(review): presumably required so pickle.load() in load_model can import
# the classes referenced by the checkpoint — confirm.
sys.path.insert(0, "stylegan3")
# Title passed to the Gradio interface.
TITLE = "StyleGAN3 Anime Face Generation"
# Hugging Face Hub repo id and file name handed to hf_hub_download in load_model.
MODEL_REPO = "hysts/stylegan3-anime-face-exp002-model"
MODEL_FILE_NAME = "009000.pkl"
def make_transform(translate: tuple[float, float], angle: float) -> np.ndarray:
    """Build a 3x3 homogeneous 2D transform from a translation and a rotation.

    Args:
        translate: ``(x, y)`` offsets written into the last column of the
            first two rows.
        angle: Rotation angle in degrees.

    Returns:
        A float64 array laid out as
        ``[[cos, sin, x], [-sin, cos, y], [0, 0, 1]]``.
    """
    # Degrees -> radians, kept in the same arithmetic form as before so the
    # floating-point result is bit-for-bit unchanged.
    radians = angle / 360 * np.pi * 2
    sin_a = np.sin(radians)
    cos_a = np.cos(radians)

    # Start from identity so the bottom row stays (0, 0, 1).
    matrix = np.eye(3)
    matrix[0, :] = (cos_a, sin_a, translate[0])
    matrix[1, :] = (-sin_a, cos_a, translate[1])
    return matrix
def load_model(device: torch.device) -> nn.Module:
    """Fetch the pickled model, move it to *device*, and run one warm-up pass.

    Args:
        device: Device the model and the warm-up inputs are placed on.

    Returns:
        The unpickled model, switched to eval mode and moved to *device*.
    """
    checkpoint = pathlib.Path(hf_hub_download(MODEL_REPO, MODEL_FILE_NAME))
    # SECURITY: unpickling can execute arbitrary code; this is only acceptable
    # while the repository named by MODEL_REPO is trusted.
    with checkpoint.open("rb") as fp:
        generator = pickle.load(fp)  # noqa: S301
    generator.eval()
    generator.to(device)

    # One throwaway forward pass with an all-zero (1, 512) input and an empty
    # tensor as the second argument; the output is discarded.
    # NOTE(review): presumably warms up one-time setup inside the model — confirm.
    with torch.inference_mode():
        latent = torch.zeros((1, 512), device=device)
        cond = torch.zeros(0, device=device)
        generator(latent, cond)
    return generator
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load the model once at import time; generate_image reads it as a module-level global.
model = load_model(device)
def generate_z(seed: int, device: torch.device) -> torch.Tensor:
    """Draw a reproducible standard-normal latent for *seed*.

    Args:
        seed: Seed for a fresh ``np.random.RandomState``.
        device: Device the resulting tensor is moved to.

    Returns:
        A float64 tensor of shape ``(1, 512)`` on *device*.
    """
    rng = np.random.RandomState(seed)
    samples = rng.randn(1, 512)
    return torch.from_numpy(samples).to(device)
@torch.inference_mode()
def generate_image(seed: int, truncation_psi: float, tx: float, ty: float, angle: float) -> np.ndarray:
    """Generate one image from the module-level model for the given settings.

    Args:
        seed: Latent seed; clipped to the uint32 range before use.
        truncation_psi: Forwarded to the model as ``truncation_psi``.
        tx: X translation passed to ``make_transform``.
        ty: Y translation passed to ``make_transform``.
        angle: Rotation angle passed to ``make_transform``.

    Returns:
        A uint8 NumPy array holding the first batch element of the model
        output, with axis 1 of the output moved to the end.
    """
    max_seed = np.iinfo(np.uint32).max
    clipped_seed = int(np.clip(seed, 0, max_seed))
    latent = generate_z(clipped_seed, device)
    cond = torch.zeros(0).to(device)

    # The model's input transform buffer receives the *inverse* of the
    # requested translate/rotate matrix.
    inverse = np.linalg.inv(make_transform((tx, ty), angle))
    model.synthesis.input.transform.copy_(torch.from_numpy(inverse))

    raw = model(latent, cond, truncation_psi=truncation_psi)
    # Rescale with x * 127.5 + 128, clamp to [0, 255], then cast to uint8.
    # NOTE(review): assumes the model output is (N, C, H, W) in roughly [-1, 1] — confirm.
    pixels = raw.permute(0, 2, 3, 1) * 127.5 + 128
    pixels = pixels.clamp(0, 255).to(torch.uint8)
    return pixels[0].cpu().numpy()
# Gradio UI definition. The sliders are passed to generate_image positionally,
# so their order must match its parameters: seed, truncation_psi, tx, ty, angle.
demo = gr.Interface(
    title=TITLE,
    css_paths="style.css",
    fn=generate_image,
    inputs=[
        gr.Slider(minimum=0, maximum=np.iinfo(np.uint32).max, step=1, value=3407851645, label="Seed"),
        gr.Slider(minimum=0, maximum=2, step=0.05, value=0.7, label="Truncation psi"),
        gr.Slider(minimum=-1, maximum=1, step=0.05, value=0, label="Translate X"),
        gr.Slider(minimum=-1, maximum=1, step=0.05, value=0, label="Translate Y"),
        gr.Slider(minimum=-180, maximum=180, step=5, value=0, label="Angle"),
    ],
    outputs=gr.Image(label="Output"),
)
if __name__ == "__main__":
    # Enable queuing with max_size=20, then start the app.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()
|