hysts (HF staff) committed
Commit: 20ed832
Parent: 94ce6bd
Files changed (5)
  1. .gitmodules +3 -0
  2. StyleSwin +1 -0
  3. app.py +135 -0
  4. packages.txt +1 -0
  5. requirements.txt +5 -0
.gitmodules ADDED
@@ -0,0 +1,3 @@
+[submodule "StyleSwin"]
+	path = StyleSwin
+	url = https://github.com/microsoft/StyleSwin
StyleSwin ADDED
@@ -0,0 +1 @@
+Subproject commit 52c23dcfa39a5da75f02892cb775fe8f424be6ec
app.py ADDED
@@ -0,0 +1,135 @@
+#!/usr/bin/env python
+
+from __future__ import annotations
+
+import argparse
+import functools
+import os
+import sys
+
+sys.path.insert(0, 'StyleSwin')
+
+import gradio as gr
+import huggingface_hub
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn as nn
+from models.generator import Generator
+
+TOKEN = os.environ['TOKEN']
+
+MODEL_REPO = 'hysts/StyleSwin'
+MODEL_NAMES = [
+    'CelebAHQ_256',
+    'FFHQ_256',
+    'LSUNChurch_256',
+    'CelebAHQ_1024',
+    'FFHQ_1024',
+]
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--device', type=str, default='cpu')
+    parser.add_argument('--theme', type=str)
+    parser.add_argument('--live', action='store_true')
+    parser.add_argument('--share', action='store_true')
+    parser.add_argument('--port', type=int)
+    parser.add_argument('--disable-queue',
+                        dest='enable_queue',
+                        action='store_false')
+    parser.add_argument('--allow-flagging', type=str, default='never')
+    parser.add_argument('--allow-screenshot', action='store_true')
+    return parser.parse_args()
+
+
+def load_model(model_name: str, device: torch.device) -> nn.Module:
+    size = int(model_name.split('_')[1])
+    channel_multiplier = 1 if size == 1024 else 2
+    model = Generator(size,
+                      style_dim=512,
+                      n_mlp=8,
+                      channel_multiplier=channel_multiplier)
+    ckpt_path = huggingface_hub.hf_hub_download(MODEL_REPO,
+                                                f'models/{model_name}.pt',
+                                                use_auth_token=TOKEN)
+    ckpt = torch.load(ckpt_path)
+    model.load_state_dict(ckpt['g_ema'])
+    model.to(device)
+    model.eval()
+    return model
+
+
+def generate_z(seed: int, device: torch.device) -> torch.Tensor:
+    return torch.from_numpy(np.random.RandomState(seed).randn(
+        1, 512)).to(device).float()
+
+
+def postprocess(tensors: torch.Tensor) -> torch.Tensor:
+    assert tensors.dim() == 4
+    tensors = tensors.cpu()
+    std = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None]
+    mean = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None]
+    tensors = tensors * std + mean
+    tensors = (tensors * 255).clamp(0, 255).to(torch.uint8)
+    return tensors
+
+
+@torch.inference_mode()
+def generate_image(model_name: str, seed: int, model_dict: dict,
+                   device: torch.device) -> PIL.Image.Image:
+    model = model_dict[model_name]
+    seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
+    z = generate_z(seed, device)
+    out, _ = model(z)
+    out = postprocess(out)
+    out = out.numpy()[0].transpose(1, 2, 0)
+    return PIL.Image.fromarray(out, 'RGB')
+
+
+def main():
+    gr.close_all()
+
+    args = parse_args()
+    device = torch.device(args.device)
+
+    model_dict = {name: load_model(name, device) for name in MODEL_NAMES}
+
+    func = functools.partial(generate_image,
+                             model_dict=model_dict,
+                             device=device)
+    func = functools.update_wrapper(func, generate_image)
+
+    repo_url = 'https://github.com/microsoft/StyleSwin'
+    title = 'microsoft/StyleSwin'
+    description = f'A demo for {repo_url}'
+    article = None
+
+    gr.Interface(
+        func,
+        [
+            gr.inputs.Radio(MODEL_NAMES,
+                            type='value',
+                            default='FFHQ_256',
+                            label='model',
+                            optional=False),
+            gr.inputs.Slider(0, 2147483647, step=1, default=0, label='Seed'),
+        ],
+        gr.outputs.Image(type='pil', label='Output'),
+        theme=args.theme,
+        title=title,
+        description=description,
+        article=article,
+        allow_screenshot=args.allow_screenshot,
+        allow_flagging=args.allow_flagging,
+        live=args.live,
+    ).launch(
+        enable_queue=args.enable_queue,
+        server_port=args.port,
+        share=args.share,
+    )
+
+
+if __name__ == '__main__':
+    main()
packages.txt ADDED
@@ -0,0 +1 @@
+ninja-build
requirements.txt ADDED
@@ -0,0 +1,5 @@
+numpy==1.22.3
+Pillow==9.0.1
+timm==0.5.4
+torch==1.11.0
+torchvision==0.12.0