Init commits
Files changed:
- .gitignore +1 -0
- README.md +3 -2
- app.py +209 -0
- imgs/out1.png +0 -0
- imgs/out2.png +0 -0
- model.py +395 -0
- packages.txt +1 -0
- requirements.txt +8 -0
.gitignore
ADDED
@@ -0,0 +1 @@
+gradio_queue*
README.md
CHANGED
@@ -1,6 +1,6 @@
 ---
 title: Anime BigGAN
-emoji:
+emoji: (ミ●ﻌ●ミ)ฅ
 colorFrom: pink
 colorTo: indigo
 sdk: gradio
@@ -10,4 +10,5 @@ pinned: false
 license: mit
 ---
 
-
+<center><h1>Anime-BigGAN</h1></center>
+This is a Gradio Blocks app of <a href="https://github.com/HighCWu/anime_biggan_toy">HighCWu/anime_biggan_toy in github</a>.
app.py
ADDED
@@ -0,0 +1,209 @@
+#!/usr/bin/env python
+
+from __future__ import annotations
+
+import argparse
+import os
+import pickle
+import sys
+from typing import List, Tuple
+
+import gradio as gr
+import numpy as np
+import torch
+import torch.nn as nn
+from model import Generator
+from huggingface_hub import hf_hub_download
+
+from moviepy.editor import *
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--device', type=str, default='cpu')
+    parser.add_argument('--theme', type=str)
+    parser.add_argument('--share', action='store_true')
+    parser.add_argument('--port', type=int)
+    parser.add_argument('--disable-queue',
+                        dest='enable_queue',
+                        action='store_false')
+    return parser.parse_args()
+
+cache_mp4_path = [f'/tmp/{str(i).zfill(2)}.mp4' for i in range(50)]
+path_iter = iter(cache_mp4_path)
+
+class App:
+    '''
+    Construct refer to https://huggingface.co/spaces/Gradio-Blocks/StyleGAN-Human
+    '''
+    def __init__(self, device: torch.device):
+        self.device = device
+        self.model = self.load_model()
+
+    def load_model(self) -> nn.Module:
+        path = hf_hub_download('HighCWu/anime-biggan-pytorch',
+                               f'pytorch_model.bin')
+        state_dict = torch.load(path, map_location='cpu')
+        model = Generator(
+            code_dim=140, n_class=1000, chn=96,
+            blocks_with_attention="B5", resolution=256
+        )
+        model.load_state_dict(state_dict)
+        model.eval()
+        model.to(self.device)
+        with torch.inference_mode():
+            z = torch.zeros((1, model.z_dim)).to(self.device)
+            label = torch.zeros([1, model.c_dim], device=self.device)
+            label[:,0] = 1
+            model(z, label)
+        return model
+
+    def get_levels(self) -> List[str]:
+        return [f'Level {i}' for i in range(self.model.n_level)]
+
+    def generate_z_label(self, z_dim: int, c_dim: int, seed: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        rng = np.random.RandomState(seed)
+        z = rng.randn(
+            1, z_dim)
+        label = rng.randint(0, c_dim, size=(1,))
+        z = torch.from_numpy(z).to(self.device).float()
+        label = torch.from_numpy(label).to(self.device).long()
+        label = torch.nn.functional.one_hot(label, 1000).float()
+        return z, label
+
+    @torch.inference_mode()
+    def generate_single_image(self, seed: int) -> np.ndarray:
+        seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
+
+        z, label = self.generate_z_label(self.model.z_dim, self.model.c_dim, seed)
+
+        out = self.model(z, label)
+        out = (out.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(
+            torch.uint8)
+        return out[0].cpu().numpy()
+
+    @torch.inference_mode()
+    def generate_interpolated_images(
+            self, seed0: int, seed1: int,
+            num_intermediate: int, levels: List[str]) -> List[np.ndarray]:
+        seed0 = int(np.clip(seed0, 0, np.iinfo(np.uint32).max))
+        seed1 = int(np.clip(seed1, 0, np.iinfo(np.uint32).max))
+        levels = [int(level.split(' ')[1]) for level in levels]
+
+        z0, label0 = self.generate_z_label(self.model.z_dim, self.model.c_dim, seed0)
+        z1, label1 = self.generate_z_label(self.model.z_dim, self.model.c_dim, seed1)
+        vec = z1 - z0
+        dvec = vec / (num_intermediate + 1)
+        zs = [z0 + dvec * i for i in range(num_intermediate + 2)]
+
+        vec = label1 - label0
+        dvec = vec / (num_intermediate + 1)
+        labels = [label0 + dvec * i for i in range(num_intermediate + 2)]
+
+        res = []
+        for z, label in zip(zs, labels):
+            z0_split = list(torch.chunk(z0, self.model.n_level, 1))
+            z_split = list(torch.chunk(z, self.model.n_level, 1))
+            for j in levels:
+                z_split[j] = z0_split[j]
+            z = torch.cat(z_split, 1)
+            out = self.model(z, label)
+            out = (out.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(
+                torch.uint8)
+            out = out[0].cpu().numpy()
+            res.append(out)
+
+        fps = 1 / (5 / len(res))
+        video = ImageSequenceClip(res, fps=fps)
+        global path_iter
+        try:
+            video_path = next(path_iter)
+        except:
+            path_iter = iter(cache_mp4_path)
+            video_path = next(path_iter)
+        video.write_videofile(video_path, fps=fps)
+
+        return res, video_path
+
+
+def main():
+    args = parse_args()
+    app = App(device=torch.device(args.device))
+
+    with gr.Blocks(theme=args.theme) as demo:
+        gr.Markdown('''<center><h1>Anime-BigGAN</h1></center>
+This is a Gradio Blocks app of <a href="https://github.com/HighCWu/anime_biggan_toy">HighCWu/anime_biggan_toy in github</a>.
+''')
+
+        with gr.Row():
+            with gr.Box():
+                with gr.Column():
+                    with gr.Row():
+                        with gr.Column():
+                            with gr.Row():
+                                seed1 = gr.Number(value=128, label='Seed 1')
+                            with gr.Row():
+                                generate_button1 = gr.Button('Generate')
+                            with gr.Row():
+                                generated_image1 = gr.Image(type='numpy', shape=(256,256),
+                                                            label='Generated Image 1')
+                        with gr.Column():
+                            with gr.Row():
+                                seed2 = gr.Number(value=6886, label='Seed 2')
+                            with gr.Row():
+                                generate_button2 = gr.Button('Generate')
+                            with gr.Row():
+                                generated_image2 = gr.Image(type='numpy', shape=(256,256),
+                                                            label='Generated Image 2')
+
+                    with gr.Row():
+                        gr.Image(value='imgs/out1.png', type='filepath',
+                                 interactive=False, label='Sample results 1')
+                    with gr.Row():
+                        gr.Image(value='imgs/out2.png', type='filepath',
+                                 interactive=False, label='Sample results 2')
+
+            with gr.Box():
+                with gr.Column():
+                    with gr.Row():
+                        num_frames = gr.Slider(
+                            0,
+                            41,
+                            value=7,
+                            step=1,
+                            label='Number of Intermediate Frames between image 1 and image 2')
+                    with gr.Row():
+                        level_choices = gr.CheckboxGroup(
+                            choices=app.get_levels(),
+                            label='Levels of latents to fix based on the first latent')
+                    with gr.Row():
+                        interpolate_button = gr.Button('Interpolate')
+
+                    with gr.Row():
+                        interpolated_images = gr.Gallery(label='Output Images')
+                    with gr.Row():
+                        interpolated_video = gr.Video(label='Output Video')
+
+        gr.Markdown(
+            '<center><img src="https://visitor-badge.glitch.me/badge?page_id=gradio-blocks.anime-biggan" alt="visitor badge"/></center>'
+        )
+
+        generate_button1.click(app.generate_single_image,
+                               inputs=[seed1],
+                               outputs=generated_image1)
+        generate_button2.click(app.generate_single_image,
+                               inputs=[seed2],
+                               outputs=generated_image2)
+        interpolate_button.click(app.generate_interpolated_images,
+                                 inputs=[seed1, seed2, num_frames, level_choices],
+                                 outputs=[interpolated_images, interpolated_video])
+
+    demo.launch(
+        enable_queue=args.enable_queue,
+        server_port=args.port,
+        share=args.share,
+    )
+
+
+if __name__ == '__main__':
+    main()
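For reference, here is a minimal usage sketch (not part of this commit) that drives the App class from app.py directly, without the Gradio UI. It assumes the pinned requirements below are installed, that the HighCWu/anime-biggan-pytorch checkpoint downloads from the Hub, and that ffmpeg (from packages.txt) is available for the interpolation video.

# Hypothetical usage sketch, not from the repo.
import torch
from PIL import Image

from app import App

app = App(device=torch.device('cpu'))

# Single image: returns a 256x256x3 uint8 numpy array.
img = app.generate_single_image(seed=128)
Image.fromarray(img).save('sample.png')

# Interpolation between two seeds: returns the frames plus the path of the
# cached .mp4 written by moviepy (requires ffmpeg).
frames, video_path = app.generate_interpolated_images(
    seed0=128, seed1=6886, num_intermediate=7, levels=[])
print(len(frames), video_path)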
imgs/out1.png
ADDED
imgs/out2.png
ADDED
model.py
ADDED
@@ -0,0 +1,395 @@
+#@title Define Generator and Discriminator model
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import Parameter
+from torch.nn import functional as F
+
+
+def l2_normalize(v, dim=None, eps=1e-12):
+    return v / (v.norm(dim=dim, keepdim=True) + eps)
+
+
+def unpool(value):
+    """Unpooling operation.
+    N-dimensional version of the unpooling operation from
+    https://www.robots.ox.ac.uk/~vgg/rg/papers/Dosovitskiy_Learning_to_Generate_2015_CVPR_paper.pdf
+    Taken from: https://github.com/tensorflow/tensorflow/issues/2169
+    Args:
+        value: a Tensor of shape [b, d0, d1, ..., dn, ch]
+        name: name of the op
+    Returns:
+        A Tensor of shape [b, 2*d0, 2*d1, ..., 2*dn, ch]
+    """
+    value = torch.Tensor.permute(value, [0,2,3,1])
+    sh = list(value.shape)
+    dim = len(sh[1:-1])
+    out = (torch.reshape(value, [-1] + sh[-dim:]))
+    for i in range(dim, 0, -1):
+        out = torch.cat([out, torch.zeros_like(out)], i)
+    out_size = [-1] + [s * 2 for s in sh[1:-1]] + [sh[-1]]
+    out = torch.reshape(out, out_size)
+    out = torch.Tensor.permute(out, [0,3,1,2])
+    return out
+
+
+class BatchNorm2d(nn.BatchNorm2d):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.initialized = False
+        self.accumulating = False
+        self.accumulated_mean = Parameter(torch.zeros(args[0]), requires_grad=False)
+        self.accumulated_var = Parameter(torch.zeros(args[0]), requires_grad=False)
+        self.accumulated_counter = Parameter(torch.zeros(1)+1e-12, requires_grad=False)
+
+    def forward(self, inputs, *args, **kwargs):
+        if not self.initialized:
+            self.check_accumulation()
+            self.set_initialized(True)
+        if self.accumulating:
+            self.eval()
+            with torch.no_grad():
+                axes = [0] + ([] if len(inputs.shape) == 2 else list(range(2,len(inputs.shape))))
+                _mean = torch.mean(inputs, axes, keepdim=True)
+                mean = torch.mean(inputs, axes, keepdim=False)
+                var = torch.mean((inputs-_mean)**2, axes)
+                self.accumulated_mean.copy_(self.accumulated_mean + mean)
+                self.accumulated_var.copy_(self.accumulated_var + var)
+                self.accumulated_counter.copy_(self.accumulated_counter + 1)
+                _mean = self.running_mean*1.0
+                _variance = self.running_var*1.0
+                self._mean.copy_(self.accumulated_mean / self.accumulated_counter)
+                self._variance.copy_(self.accumulated_var / self.accumulated_counter)
+                out = super().forward(inputs, *args, **kwargs)
+                self.running_mean.copy_(_mean)
+                self.running_var.copy_(_variance)
+                return out
+        out = super().forward(inputs, *args, **kwargs)
+        return out
+
+    def check_accumulation(self):
+        if self.accumulated_counter.detach().cpu().numpy().mean() > 1-1e-12:
+            self.running_mean.copy_(self.accumulated_mean / self.accumulated_counter)
+            self.running_var.copy_(self.accumulated_var / self.accumulated_counter)
+            return True
+        return False
+
+    def clear_accumulated(self):
+        self.accumulated_mean.copy_(self.accumulated_mean*0.0)
+        self.accumulated_var.copy_(self.accumulated_var*0.0)
+        self.accumulated_counter.copy_(self.accumulated_counter*0.0+1e-2)
+
+    def set_accumulating(self, status=True):
+        if status:
+            self.accumulating = True
+        else:
+            self.accumulating = False
+
+    def set_initialized(self, status=False):
+        if not status:
+            self.initialized = False
+        else:
+            self.initialized = True
+
+
+class SpectralNorm(nn.Module):
+    def __init__(self, module, name='weight', power_iterations=2):
+        super().__init__()
+        self.module = module
+        self.name = name
+        self.power_iterations = power_iterations
+        if not self._made_params():
+            self._make_params()
+
+    def _update_u(self):
+        w = self.weight
+        u = self.weight_u
+
+        if len(w.shape) == 4:
+            _w = torch.Tensor.permute(w, [2,3,1,0])
+            _w = torch.reshape(_w, [-1, _w.shape[-1]])
+        elif isinstance(self.module, nn.Linear) or isinstance(self.module, nn.Embedding):
+            _w = torch.Tensor.permute(w, [1,0])
+            _w = torch.reshape(_w, [-1, _w.shape[-1]])
+        else:
+            _w = torch.reshape(w, [-1, w.shape[-1]])
+            _w = torch.reshape(_w, [-1, _w.shape[-1]])
+        singular_value = "left" if _w.shape[0] <= _w.shape[1] else "right"
+        norm_dim = 0 if _w.shape[0] <= _w.shape[1] else 1
+        for _ in range(self.power_iterations):
+            if singular_value == "left":
+                v = l2_normalize(torch.matmul(_w.t(), u), dim=norm_dim)
+                u = l2_normalize(torch.matmul(_w, v), dim=norm_dim)
+            else:
+                v = l2_normalize(torch.matmul(u, _w.t()), dim=norm_dim)
+                u = l2_normalize(torch.matmul(v, _w), dim=norm_dim)
+
+        if singular_value == "left":
+            sigma = torch.matmul(torch.matmul(u.t(), _w), v)
+        else:
+            sigma = torch.matmul(torch.matmul(v, _w), u.t())
+        _w = w / sigma.detach()
+        setattr(self.module, self.name, _w)
+        self.weight_u.copy_(u.detach())
+
+    def _made_params(self):
+        try:
+            self.weight
+            self.weight_u
+            return True
+        except AttributeError:
+            return False
+
+    def _make_params(self):
+        w = getattr(self.module, self.name)
+
+        if len(w.shape) == 4:
+            _w = torch.Tensor.permute(w, [2,3,1,0])
+            _w = torch.reshape(_w, [-1, _w.shape[-1]])
+        elif isinstance(self.module, nn.Linear) or isinstance(self.module, nn.Embedding):
+            _w = torch.Tensor.permute(w, [1,0])
+            _w = torch.reshape(_w, [-1, _w.shape[-1]])
+        else:
+            _w = torch.reshape(w, [-1, w.shape[-1]])
+        singular_value = "left" if _w.shape[0] <= _w.shape[1] else "right"
+        norm_dim = 0 if _w.shape[0] <= _w.shape[1] else 1
+        u_shape = (_w.shape[0], 1) if singular_value == "left" else (1, _w.shape[-1])
+
+        u = Parameter(w.data.new(*u_shape).normal_(0, 1), requires_grad=False)
+        u.copy_(l2_normalize(u, dim=norm_dim).detach())
+
+        del self.module._parameters[self.name]
+        self.weight = w
+        self.weight_u = u
+
+    def forward(self, *args, **kwargs):
+        self._update_u()
+        return self.module.forward(*args, **kwargs)
+
+
+class SelfAttention(nn.Module):
+    def __init__(self, in_dim, activation=torch.relu):
+        super().__init__()
+        self.chanel_in = in_dim
+        self.activation = activation
+
+        self.theta = SpectralNorm(nn.Conv2d(in_dim, in_dim // 8, 1, bias=False))
+        self.phi = SpectralNorm(nn.Conv2d(in_dim, in_dim // 8, 1, bias=False))
+        self.pool = nn.MaxPool2d(2, 2)
+        self.g = SpectralNorm(nn.Conv2d(in_dim, in_dim // 2, 1, bias=False))
+        self.o_conv = SpectralNorm(nn.Conv2d(in_dim // 2, in_dim, 1, bias=False))
+        self.gamma = Parameter(torch.zeros(1))
+
+    def forward(self, x):
+        m_batchsize, C, width, height = x.shape
+        N = height * width
+
+        theta = self.theta(x)
+        phi = self.phi(x)
+        phi = self.pool(phi)
+        phi = torch.reshape(phi,(m_batchsize, -1, N // 4))
+        theta = torch.reshape(theta,(m_batchsize, -1, N))
+        theta = torch.Tensor.permute(theta,(0, 2, 1))
+        attention = torch.softmax(torch.bmm(theta, phi), -1)
+        g = self.g(x)
+        g = torch.reshape(self.pool(g),(m_batchsize, -1, N // 4))
+        attn_g = torch.reshape(torch.bmm(g, torch.Tensor.permute(attention,(0, 2, 1))),(m_batchsize, -1, width, height))
+        out = self.o_conv(attn_g)
+        return self.gamma * out + x
+
+
+class ConditionalBatchNorm2d(nn.Module):
+    def __init__(self, num_features, num_classes, eps=1e-5, momentum=0.1):
+        super().__init__()
+        self.bn_in_cond = BatchNorm2d(num_features, affine=False, eps=eps, momentum=momentum)
+        self.gamma_embed = SpectralNorm(nn.Linear(num_classes, num_features, bias=False))
+        self.beta_embed = SpectralNorm(nn.Linear(num_classes, num_features, bias=False))
+
+    def forward(self, x, y):
+        out = self.bn_in_cond(x)
+
+        if isinstance(y, list):
+            gamma, beta = y
+            out = torch.reshape(gamma, (gamma.shape[0], -1, 1, 1)) * out + torch.reshape(beta, (beta.shape[0], -1, 1, 1))
+            return out
+
+        gamma = self.gamma_embed(y)
+        # gamma = gamma + 1
+        beta = self.beta_embed(y)
+        out = torch.reshape(gamma, (gamma.shape[0], -1, 1, 1)) * out + torch.reshape(beta, (beta.shape[0], -1, 1, 1))
+        return out
+
+
+class ResBlock(nn.Module):
+    def __init__(
+        self,
+        in_channel,
+        out_channel,
+        kernel_size=[3, 3],
+        padding=1,
+        stride=1,
+        n_class=None,
+        conditional=True,
+        activation=torch.relu,
+        upsample=True,
+        downsample=False,
+        z_dim=128,
+        use_attention=False,
+        skip_proj=None
+    ):
+        super().__init__()
+
+        if conditional:
+            self.cond_norm1 = ConditionalBatchNorm2d(in_channel, z_dim)
+
+        self.conv0 = SpectralNorm(
+            nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding)
+        )
+
+        if conditional:
+            self.cond_norm2 = ConditionalBatchNorm2d(out_channel, z_dim)
+
+        self.conv1 = SpectralNorm(
+            nn.Conv2d(out_channel, out_channel, kernel_size, stride, padding)
+        )
+
+        self.skip_proj = False
+        if skip_proj is not True and (upsample or downsample):
+            self.conv_sc = SpectralNorm(nn.Conv2d(in_channel, out_channel, 1, 1, 0))
+            self.skip_proj = True
+
+        if use_attention:
+            self.attention = SelfAttention(out_channel)
+
+        self.upsample = upsample
+        self.downsample = downsample
+        self.activation = activation
+        self.conditional = conditional
+        self.use_attention = use_attention
+
+    def forward(self, input, condition=None):
+        out = input
+
+        if self.conditional:
+            out = self.cond_norm1(out, condition if not isinstance(condition, list) else condition[0])
+        out = self.activation(out)
+        if self.upsample:
+            out = unpool(out)  # out = F.interpolate(out, scale_factor=2)
+        out = self.conv0(out)
+        if self.conditional:
+            out = self.cond_norm2(out, condition if not isinstance(condition, list) else condition[1])
+        out = self.activation(out)
+        out = self.conv1(out)
+
+        if self.downsample:
+            out = F.avg_pool2d(out, 2, 2)
+
+        if self.skip_proj:
+            skip = input
+            if self.upsample:
+                skip = unpool(skip)  # skip = F.interpolate(skip, scale_factor=2)
+            skip = self.conv_sc(skip)
+            if self.downsample:
+                skip = F.avg_pool2d(skip, 2, 2)
+            out = out + skip
+        else:
+            skip = input
+
+        if self.use_attention:
+            out = self.attention(out)
+
+        return out
+
+
+class Generator(nn.Module):
+    def __init__(self, code_dim=128, n_class=1000, chn=96, blocks_with_attention="B4", resolution=512):
+        super().__init__()
+
+        def GBlock(in_channel, out_channel, n_class, z_dim, use_attention):
+            return ResBlock(in_channel, out_channel, n_class=n_class, z_dim=z_dim, use_attention=use_attention)
+
+        self.embed_y = nn.Linear(n_class, 128, bias=False)
+
+        self.chn = chn
+        self.resolution = resolution
+        self.blocks_with_attention = set(blocks_with_attention.split(","))
+        self.blocks_with_attention.discard('')
+
+        gblock = []
+        in_channels, out_channels = self.get_in_out_channels()
+        self.num_split = len(in_channels) + 1
+
+        z_dim = code_dim//self.num_split + 128
+        self.noise_fc = SpectralNorm(nn.Linear(code_dim//self.num_split, 4 * 4 * in_channels[0]))
+
+        self.sa_ids = [int(s.split('B')[-1]) for s in self.blocks_with_attention]
+
+        for i, (nc_in, nc_out) in enumerate(zip(in_channels, out_channels)):
+            gblock.append(GBlock(nc_in, nc_out, n_class=n_class, z_dim=z_dim, use_attention=(i+1) in self.sa_ids))
+        self.blocks = nn.ModuleList(gblock)
+
+        self.output_layer_bn = BatchNorm2d(1 * chn, eps=1e-5)
+        self.output_layer_conv = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1))
+
+        self.z_dim = code_dim
+        self.c_dim = n_class
+        self.n_level = self.num_split
+
+    def get_in_out_channels(self):
+        resolution = self.resolution
+        if resolution == 1024:
+            channel_multipliers = [16, 16, 8, 8, 4, 2, 1, 1, 1]
+        elif resolution == 512:
+            channel_multipliers = [16, 16, 8, 8, 4, 2, 1, 1]
+        elif resolution == 256:
+            channel_multipliers = [16, 16, 8, 8, 4, 2, 1]
+        elif resolution == 128:
+            channel_multipliers = [16, 16, 8, 4, 2, 1]
+        elif resolution == 64:
+            channel_multipliers = [16, 16, 8, 4, 2]
+        elif resolution == 32:
+            channel_multipliers = [4, 4, 4, 4]
+        else:
+            raise ValueError("Unsupported resolution: {}".format(resolution))
+        in_channels = [self.chn * c for c in channel_multipliers[:-1]]
+        out_channels = [self.chn * c for c in channel_multipliers[1:]]
+        return in_channels, out_channels
+
+    def forward(self, input, class_id):
+        codes = torch.chunk(input, self.num_split, 1)
+        class_emb = self.embed_y(class_id)  # 128
+        out = self.noise_fc(codes[0])
+        out = torch.Tensor.permute(torch.reshape(out,(out.shape[0], 4, 4, -1)),(0, 3, 1, 2))
+        for i, (code, gblock) in enumerate(zip(codes[1:], self.blocks)):
+            condition = torch.cat([code, class_emb], 1)
+            out = gblock(out, condition)
+
+        out = self.output_layer_bn(out)
+        out = torch.relu(out)
+        out = self.output_layer_conv(out)
+
+        return (torch.tanh(out) + 1) / 2
+
+    def forward_w(self, ws):
+        out = self.noise_fc(ws[0])
+        out = torch.Tensor.permute(torch.reshape(out,(out.shape[0], 4, 4, -1)),(0, 3, 1, 2))
+        for i, (w, gblock) in enumerate(zip(ws[1:], self.blocks)):
+            out = gblock(out, w)
+
+        out = self.output_layer_bn(out)
+        out = torch.relu(out)
+        out = self.output_layer_conv(out)
+
+        return (torch.tanh(out) + 1) / 2
+
+    def forward_wp(self, z0, gammas, betas):
+        out = self.noise_fc(z0)
+        out = torch.Tensor.permute(torch.reshape(out,(out.shape[0], 4, 4, -1)),(0, 3, 1, 2))
+        for i, (gamma, beta, gblock) in enumerate(zip(gammas, betas, self.blocks)):
+            out = gblock(out, [[gamma[0], beta[0]], [gamma[1], beta[1]]])
+
+        out = self.output_layer_bn(out)
+        out = torch.relu(out)
+        out = self.output_layer_conv(out)
+
+        return (torch.tanh(out) + 1) / 2
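As a sanity check, the generator above can be exercised on its own. This is a minimal smoke-test sketch (not part of this commit) that uses the same hyperparameters app.py passes, with randomly initialized weights rather than the Hub checkpoint.

# Hypothetical smoke test, not from the repo.
import torch
from model import Generator

model = Generator(code_dim=140, n_class=1000, chn=96,
                  blocks_with_attention="B5", resolution=256)
model.eval()

with torch.inference_mode():
    # The 140-dim latent is split into model.n_level chunks: chunk 0 feeds the
    # 4x4 seed tensor, the rest condition the residual blocks together with the
    # 128-dim class embedding.
    z = torch.randn(1, model.z_dim)
    label = torch.nn.functional.one_hot(
        torch.randint(0, model.c_dim, (1,)), model.c_dim).float()
    img = model(z, label)

print(model.n_level, img.shape)  # 7, torch.Size([1, 3, 256, 256])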
packages.txt
ADDED
@@ -0,0 +1 @@
+ffmpeg
requirements.txt
ADDED
@@ -0,0 +1,8 @@
+numpy==1.22.3
+Pillow==9.0.1
+scipy==1.8.0
+torch==1.11.0
+torchvision==0.12.0
+gradio==3.0.3
+huggingface-hub==0.6.0
+moviepy==1.0.3