HighCWu committed
Commit bfa0d3e · 1 Parent(s): 8668c0c

Init commits

Files changed (8)
  1. .gitignore +1 -0
  2. README.md +3 -2
  3. app.py +209 -0
  4. imgs/out1.png +0 -0
  5. imgs/out2.png +0 -0
  6. model.py +395 -0
  7. packages.txt +1 -0
  8. requirements.txt +8 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ gradio_queue*
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
  title: Anime BigGAN
- emoji:
+ emoji: (ミ●ﻌ●ミ)ฅ
  colorFrom: pink
  colorTo: indigo
  sdk: gradio
@@ -10,4 +10,5 @@ pinned: false
  license: mit
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
+ <center><h1>Anime-BigGAN</h1></center>
+ This is a Gradio Blocks app of <a href="https://github.com/HighCWu/anime_biggan_toy">HighCWu/anime_biggan_toy in github</a>.
app.py ADDED
@@ -0,0 +1,209 @@
+ #!/usr/bin/env python
+
+ from __future__ import annotations
+
+ import argparse
+ import os
+ import pickle
+ import sys
+ from typing import List, Tuple
+
+ import gradio as gr
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from model import Generator
+ from huggingface_hub import hf_hub_download
+
+ from moviepy.editor import *
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--device', type=str, default='cpu')
+     parser.add_argument('--theme', type=str)
+     parser.add_argument('--share', action='store_true')
+     parser.add_argument('--port', type=int)
+     parser.add_argument('--disable-queue',
+                         dest='enable_queue',
+                         action='store_false')
+     return parser.parse_args()
+
+ cache_mp4_path = [f'/tmp/{str(i).zfill(2)}.mp4' for i in range(50)]
+ path_iter = iter(cache_mp4_path)
+
+ class App:
+     '''
+     Construct refer to https://huggingface.co/spaces/Gradio-Blocks/StyleGAN-Human
+     '''
+     def __init__(self, device: torch.device):
+         self.device = device
+         self.model = self.load_model()
+
+     def load_model(self) -> nn.Module:
+         path = hf_hub_download('HighCWu/anime-biggan-pytorch',
+                                f'pytorch_model.bin')
+         state_dict = torch.load(path, map_location='cpu')
+         model = Generator(
+             code_dim=140, n_class=1000, chn=96,
+             blocks_with_attention="B5", resolution=256
+         )
+         model.load_state_dict(state_dict)
+         model.eval()
+         model.to(self.device)
+         with torch.inference_mode():
+             z = torch.zeros((1, model.z_dim)).to(self.device)
+             label = torch.zeros([1, model.c_dim], device=self.device)
+             label[:, 0] = 1
+             model(z, label)
+         return model
+
+     def get_levels(self) -> List[str]:
+         return [f'Level {i}' for i in range(self.model.n_level)]
+
+     def generate_z_label(self, z_dim: int, c_dim: int, seed: int) -> Tuple[torch.Tensor, torch.Tensor]:
+         rng = np.random.RandomState(seed)
+         z = rng.randn(
+             1, z_dim)
+         label = rng.randint(0, c_dim, size=(1,))
+         z = torch.from_numpy(z).to(self.device).float()
+         label = torch.from_numpy(label).to(self.device).long()
+         label = torch.nn.functional.one_hot(label, 1000).float()
+         return z, label
+
+     @torch.inference_mode()
+     def generate_single_image(self, seed: int) -> np.ndarray:
+         seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
+
+         z, label = self.generate_z_label(self.model.z_dim, self.model.c_dim, seed)
+
+         out = self.model(z, label)
+         out = (out.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(
+             torch.uint8)
+         return out[0].cpu().numpy()
+
+     @torch.inference_mode()
+     def generate_interpolated_images(
+             self, seed0: int, seed1: int,
+             num_intermediate: int, levels: List[str]) -> List[np.ndarray]:
+         seed0 = int(np.clip(seed0, 0, np.iinfo(np.uint32).max))
+         seed1 = int(np.clip(seed1, 0, np.iinfo(np.uint32).max))
+         levels = [int(level.split(' ')[1]) for level in levels]
+
+         z0, label0 = self.generate_z_label(self.model.z_dim, self.model.c_dim, seed0)
+         z1, label1 = self.generate_z_label(self.model.z_dim, self.model.c_dim, seed1)
+         vec = z1 - z0
+         dvec = vec / (num_intermediate + 1)
+         zs = [z0 + dvec * i for i in range(num_intermediate + 2)]
+
+         vec = label1 - label0
+         dvec = vec / (num_intermediate + 1)
+         labels = [label0 + dvec * i for i in range(num_intermediate + 2)]
+
+         res = []
+         for z, label in zip(zs, labels):
+             z0_split = list(torch.chunk(z0, self.model.n_level, 1))
+             z_split = list(torch.chunk(z, self.model.n_level, 1))
+             for j in levels:
+                 z_split[j] = z0_split[j]
+             z = torch.cat(z_split, 1)
+             out = self.model(z, label)
+             out = (out.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(
+                 torch.uint8)
+             out = out[0].cpu().numpy()
+             res.append(out)
+
+         fps = 1 / (5 / len(res))
+         video = ImageSequenceClip(res, fps=fps)
+         global path_iter
+         try:
+             video_path = next(path_iter)
+         except:
+             path_iter = iter(cache_mp4_path)
+             video_path = next(path_iter)
+         video.write_videofile(video_path, fps=fps)
+
+         return res, video_path
+
+
+ def main():
+     args = parse_args()
+     app = App(device=torch.device(args.device))
+
+     with gr.Blocks(theme=args.theme) as demo:
+         gr.Markdown('''<center><h1>Anime-BigGAN</h1></center>
+         This is a Gradio Blocks app of <a href="https://github.com/HighCWu/anime_biggan_toy">HighCWu/anime_biggan_toy in github</a>.
+         ''')
+
+         with gr.Row():
+             with gr.Box():
+                 with gr.Column():
+                     with gr.Row():
+                         with gr.Column():
+                             with gr.Row():
+                                 seed1 = gr.Number(value=128, label='Seed 1')
+                             with gr.Row():
+                                 generate_button1 = gr.Button('Generate')
+                             with gr.Row():
+                                 generated_image1 = gr.Image(type='numpy', shape=(256,256),
+                                                             label='Generated Image 1')
+                         with gr.Column():
+                             with gr.Row():
+                                 seed2 = gr.Number(value=6886, label='Seed 2')
+                             with gr.Row():
+                                 generate_button2 = gr.Button('Generate')
+                             with gr.Row():
+                                 generated_image2 = gr.Image(type='numpy', shape=(256,256),
+                                                             label='Generated Image 2')
+
+                     with gr.Row():
+                         gr.Image(value='imgs/out1.png', type='filepath',
+                                  interactive=False, label='Sample results 1')
+                     with gr.Row():
+                         gr.Image(value='imgs/out2.png', type='filepath',
+                                  interactive=False, label='Sample results 2')
+
+             with gr.Box():
+                 with gr.Column():
+                     with gr.Row():
+                         num_frames = gr.Slider(
+                             0,
+                             41,
+                             value=7,
+                             step=1,
+                             label='Number of Intermediate Frames between image 1 and image 2')
+                     with gr.Row():
+                         level_choices = gr.CheckboxGroup(
+                             choices=app.get_levels(),
+                             label='Levels of latents to fix based on the first latent')
+                     with gr.Row():
+                         interpolate_button = gr.Button('Interpolate')
+
+                     with gr.Row():
+                         interpolated_images = gr.Gallery(label='Output Images')
+                     with gr.Row():
+                         interpolated_video = gr.Video(label='Output Video')
+
+         gr.Markdown(
+             '<center><img src="https://visitor-badge.glitch.me/badge?page_id=gradio-blocks.anime-biggan" alt="visitor badge"/></center>'
+         )
+
+         generate_button1.click(app.generate_single_image,
+                                inputs=[seed1],
+                                outputs=generated_image1)
+         generate_button2.click(app.generate_single_image,
+                                inputs=[seed2],
+                                outputs=generated_image2)
+         interpolate_button.click(app.generate_interpolated_images,
+                                  inputs=[seed1, seed2, num_frames, level_choices],
+                                  outputs=[interpolated_images, interpolated_video])
+
+     demo.launch(
+         enable_queue=args.enable_queue,
+         server_port=args.port,
+         share=args.share,
+     )
+
+
+ if __name__ == '__main__':
+     main()
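
Note (not part of the commit): the App class above can also be driven outside of the Gradio UI. The following is a minimal sketch, assuming app.py and model.py from this commit are importable and the HighCWu/anime-biggan-pytorch checkpoint can be downloaded; App.generate_single_image returns a 256x256x3 uint8 numpy array, saved here with Pillow (already in requirements.txt):

    import torch
    from PIL import Image
    from app import App  # App class from the app.py added above

    app = App(device=torch.device('cpu'))      # downloads pytorch_model.bin via hf_hub_download
    img = app.generate_single_image(seed=128)  # 256x256x3 uint8 numpy array
    Image.fromarray(img).save('sample.png')

generate_interpolated_images(seed0, seed1, num_intermediate, levels) can be called the same way; it additionally writes an MP4 under /tmp via moviepy, so it also needs the ffmpeg package listed in packages.txt.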
imgs/out1.png ADDED
imgs/out2.png ADDED
model.py ADDED
@@ -0,0 +1,395 @@
+ #@title Define Generator and Discriminator model
+ import numpy as np
+ import torch
+ from torch import nn
+ from torch.nn import Parameter
+ from torch.nn import functional as F
+
+
+ def l2_normalize(v, dim=None, eps=1e-12):
+     return v / (v.norm(dim=dim, keepdim=True) + eps)
+
+
+ def unpool(value):
+     """Unpooling operation.
+     N-dimensional version of the unpooling operation from
+     https://www.robots.ox.ac.uk/~vgg/rg/papers/Dosovitskiy_Learning_to_Generate_2015_CVPR_paper.pdf
+     Taken from: https://github.com/tensorflow/tensorflow/issues/2169
+     Args:
+         value: a Tensor of shape [b, d0, d1, ..., dn, ch]
+         name: name of the op
+     Returns:
+         A Tensor of shape [b, 2*d0, 2*d1, ..., 2*dn, ch]
+     """
+     value = torch.Tensor.permute(value, [0,2,3,1])
+     sh = list(value.shape)
+     dim = len(sh[1:-1])
+     out = (torch.reshape(value, [-1] + sh[-dim:]))
+     for i in range(dim, 0, -1):
+         out = torch.cat([out, torch.zeros_like(out)], i)
+     out_size = [-1] + [s * 2 for s in sh[1:-1]] + [sh[-1]]
+     out = torch.reshape(out, out_size)
+     out = torch.Tensor.permute(out, [0,3,1,2])
+     return out
+
+
+ class BatchNorm2d(nn.BatchNorm2d):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.initialized = False
+         self.accumulating = False
+         self.accumulated_mean = Parameter(torch.zeros(args[0]), requires_grad=False)
+         self.accumulated_var = Parameter(torch.zeros(args[0]), requires_grad=False)
+         self.accumulated_counter = Parameter(torch.zeros(1)+1e-12, requires_grad=False)
+
+     def forward(self, inputs, *args, **kwargs):
+         if not self.initialized:
+             self.check_accumulation()
+             self.set_initialized(True)
+         if self.accumulating:
+             self.eval()
+             with torch.no_grad():
+                 axes = [0] + ([] if len(inputs.shape) == 2 else list(range(2,len(inputs.shape))))
+                 _mean = torch.mean(inputs, axes, keepdim=True)
+                 mean = torch.mean(inputs, axes, keepdim=False)
+                 var = torch.mean((inputs-_mean)**2, axes)
+                 self.accumulated_mean.copy_(self.accumulated_mean + mean)
+                 self.accumulated_var.copy_(self.accumulated_var + var)
+                 self.accumulated_counter.copy_(self.accumulated_counter + 1)
+                 _mean = self.running_mean*1.0
+                 _variance = self.running_var*1.0
+                 self._mean.copy_(self.accumulated_mean / self.accumulated_counter)
+                 self._variance.copy_(self.accumulated_var / self.accumulated_counter)
+                 out = super().forward(inputs, *args, **kwargs)
+                 self.running_mean.copy_(_mean)
+                 self.running_var.copy_(_variance)
+                 return out
+         out = super().forward(inputs, *args, **kwargs)
+         return out
+
+     def check_accumulation(self):
+         if self.accumulated_counter.detach().cpu().numpy().mean() > 1-1e-12:
+             self.running_mean.copy_(self.accumulated_mean / self.accumulated_counter)
+             self.running_var.copy_(self.accumulated_var / self.accumulated_counter)
+             return True
+         return False
+
+     def clear_accumulated(self):
+         self.accumulated_mean.copy_(self.accumulated_mean*0.0)
+         self.accumulated_var.copy_(self.accumulated_var*0.0)
+         self.accumulated_counter.copy_(self.accumulated_counter*0.0+1e-2)
+
+     def set_accumulating(self, status=True):
+         if status:
+             self.accumulating = True
+         else:
+             self.accumulating = False
+
+     def set_initialized(self, status=False):
+         if not status:
+             self.initialized = False
+         else:
+             self.initialized = True
+
+
+ class SpectralNorm(nn.Module):
+     def __init__(self, module, name='weight', power_iterations=2):
+         super().__init__()
+         self.module = module
+         self.name = name
+         self.power_iterations = power_iterations
+         if not self._made_params():
+             self._make_params()
+
+     def _update_u(self):
+         w = self.weight
+         u = self.weight_u
+
+         if len(w.shape) == 4:
+             _w = torch.Tensor.permute(w, [2,3,1,0])
+             _w = torch.reshape(_w, [-1, _w.shape[-1]])
+         elif isinstance(self.module, nn.Linear) or isinstance(self.module, nn.Embedding):
+             _w = torch.Tensor.permute(w, [1,0])
+             _w = torch.reshape(_w, [-1, _w.shape[-1]])
+         else:
+             _w = torch.reshape(w, [-1, w.shape[-1]])
+             _w = torch.reshape(_w, [-1, _w.shape[-1]])
+         singular_value = "left" if _w.shape[0] <= _w.shape[1] else "right"
+         norm_dim = 0 if _w.shape[0] <= _w.shape[1] else 1
+         for _ in range(self.power_iterations):
+             if singular_value == "left":
+                 v = l2_normalize(torch.matmul(_w.t(), u), dim=norm_dim)
+                 u = l2_normalize(torch.matmul(_w, v), dim=norm_dim)
+             else:
+                 v = l2_normalize(torch.matmul(u, _w.t()), dim=norm_dim)
+                 u = l2_normalize(torch.matmul(v, _w), dim=norm_dim)
+
+         if singular_value == "left":
+             sigma = torch.matmul(torch.matmul(u.t(), _w), v)
+         else:
+             sigma = torch.matmul(torch.matmul(v, _w), u.t())
+         _w = w / sigma.detach()
+         setattr(self.module, self.name, _w)
+         self.weight_u.copy_(u.detach())
+
+     def _made_params(self):
+         try:
+             self.weight
+             self.weight_u
+             return True
+         except AttributeError:
+             return False
+
+     def _make_params(self):
+         w = getattr(self.module, self.name)
+
+         if len(w.shape) == 4:
+             _w = torch.Tensor.permute(w, [2,3,1,0])
+             _w = torch.reshape(_w, [-1, _w.shape[-1]])
+         elif isinstance(self.module, nn.Linear) or isinstance(self.module, nn.Embedding):
+             _w = torch.Tensor.permute(w, [1,0])
+             _w = torch.reshape(_w, [-1, _w.shape[-1]])
+         else:
+             _w = torch.reshape(w, [-1, w.shape[-1]])
+         singular_value = "left" if _w.shape[0] <= _w.shape[1] else "right"
+         norm_dim = 0 if _w.shape[0] <= _w.shape[1] else 1
+         u_shape = (_w.shape[0], 1) if singular_value == "left" else (1, _w.shape[-1])
+
+         u = Parameter(w.data.new(*u_shape).normal_(0, 1), requires_grad=False)
+         u.copy_(l2_normalize(u, dim=norm_dim).detach())
+
+         del self.module._parameters[self.name]
+         self.weight = w
+         self.weight_u = u
+
+     def forward(self, *args, **kwargs):
+         self._update_u()
+         return self.module.forward(*args, **kwargs)
+
+
+ class SelfAttention(nn.Module):
+     def __init__(self, in_dim, activation=torch.relu):
+         super().__init__()
+         self.chanel_in = in_dim
+         self.activation = activation
+
+         self.theta = SpectralNorm(nn.Conv2d(in_dim, in_dim // 8, 1, bias=False))
+         self.phi = SpectralNorm(nn.Conv2d(in_dim, in_dim // 8, 1, bias=False))
+         self.pool = nn.MaxPool2d(2, 2)
+         self.g = SpectralNorm(nn.Conv2d(in_dim, in_dim // 2, 1, bias=False))
+         self.o_conv = SpectralNorm(nn.Conv2d(in_dim // 2, in_dim, 1, bias=False))
+         self.gamma = Parameter(torch.zeros(1))
+
+     def forward(self, x):
+         m_batchsize, C, width, height = x.shape
+         N = height * width
+
+         theta = self.theta(x)
+         phi = self.phi(x)
+         phi = self.pool(phi)
+         phi = torch.reshape(phi,(m_batchsize, -1, N // 4))
+         theta = torch.reshape(theta,(m_batchsize, -1, N))
+         theta = torch.Tensor.permute(theta,(0, 2, 1))
+         attention = torch.softmax(torch.bmm(theta, phi), -1)
+         g = self.g(x)
+         g = torch.reshape(self.pool(g),(m_batchsize, -1, N // 4))
+         attn_g = torch.reshape(torch.bmm(g, torch.Tensor.permute(attention,(0, 2, 1))),(m_batchsize, -1, width, height))
+         out = self.o_conv(attn_g)
+         return self.gamma * out + x
+
+
+ class ConditionalBatchNorm2d(nn.Module):
+     def __init__(self, num_features, num_classes, eps=1e-5, momentum=0.1):
+         super().__init__()
+         self.bn_in_cond = BatchNorm2d(num_features, affine=False, eps=eps, momentum=momentum)
+         self.gamma_embed = SpectralNorm(nn.Linear(num_classes, num_features, bias=False))
+         self.beta_embed = SpectralNorm(nn.Linear(num_classes, num_features, bias=False))
+
+     def forward(self, x, y):
+         out = self.bn_in_cond(x)
+
+         if isinstance(y, list):
+             gamma, beta = y
+             out = torch.reshape(gamma, (gamma.shape[0], -1, 1, 1)) * out + torch.reshape(beta, (beta.shape[0], -1, 1, 1))
+             return out
+
+         gamma = self.gamma_embed(y)
+         # gamma = gamma + 1
+         beta = self.beta_embed(y)
+         out = torch.reshape(gamma, (gamma.shape[0], -1, 1, 1)) * out + torch.reshape(beta, (beta.shape[0], -1, 1, 1))
+         return out
+
+
+ class ResBlock(nn.Module):
+     def __init__(
+         self,
+         in_channel,
+         out_channel,
+         kernel_size=[3, 3],
+         padding=1,
+         stride=1,
+         n_class=None,
+         conditional=True,
+         activation=torch.relu,
+         upsample=True,
+         downsample=False,
+         z_dim=128,
+         use_attention=False,
+         skip_proj=None
+     ):
+         super().__init__()
+
+         if conditional:
+             self.cond_norm1 = ConditionalBatchNorm2d(in_channel, z_dim)
+
+         self.conv0 = SpectralNorm(
+             nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding)
+         )
+
+         if conditional:
+             self.cond_norm2 = ConditionalBatchNorm2d(out_channel, z_dim)
+
+         self.conv1 = SpectralNorm(
+             nn.Conv2d(out_channel, out_channel, kernel_size, stride, padding)
+         )
+
+         self.skip_proj = False
+         if skip_proj is not True and (upsample or downsample):
+             self.conv_sc = SpectralNorm(nn.Conv2d(in_channel, out_channel, 1, 1, 0))
+             self.skip_proj = True
+
+         if use_attention:
+             self.attention = SelfAttention(out_channel)
+
+         self.upsample = upsample
+         self.downsample = downsample
+         self.activation = activation
+         self.conditional = conditional
+         self.use_attention = use_attention
+
+     def forward(self, input, condition=None):
+         out = input
+
+         if self.conditional:
+             out = self.cond_norm1(out, condition if not isinstance(condition, list) else condition[0])
+         out = self.activation(out)
+         if self.upsample:
+             out = unpool(out) # out = F.interpolate(out, scale_factor=2)
+         out = self.conv0(out)
+         if self.conditional:
+             out = self.cond_norm2(out, condition if not isinstance(condition, list) else condition[1])
+         out = self.activation(out)
+         out = self.conv1(out)
+
+         if self.downsample:
+             out = F.avg_pool2d(out, 2, 2)
+
+         if self.skip_proj:
+             skip = input
+             if self.upsample:
+                 skip = unpool(skip) # skip = F.interpolate(skip, scale_factor=2)
+             skip = self.conv_sc(skip)
+             if self.downsample:
+                 skip = F.avg_pool2d(skip, 2, 2)
+             out = out + skip
+         else:
+             skip = input
+
+         if self.use_attention:
+             out = self.attention(out)
+
+         return out
+
+
+ class Generator(nn.Module):
+     def __init__(self, code_dim=128, n_class=1000, chn=96, blocks_with_attention="B4", resolution=512):
+         super().__init__()
+
+         def GBlock(in_channel, out_channel, n_class, z_dim, use_attention):
+             return ResBlock(in_channel, out_channel, n_class=n_class, z_dim=z_dim, use_attention=use_attention)
+
+         self.embed_y = nn.Linear(n_class, 128, bias=False)
+
+         self.chn = chn
+         self.resolution = resolution
+         self.blocks_with_attention = set(blocks_with_attention.split(","))
+         self.blocks_with_attention.discard('')
+
+         gblock = []
+         in_channels, out_channels = self.get_in_out_channels()
+         self.num_split = len(in_channels) + 1
+
+         z_dim = code_dim//self.num_split + 128
+         self.noise_fc = SpectralNorm(nn.Linear(code_dim//self.num_split, 4 * 4 * in_channels[0]))
+
+         self.sa_ids = [int(s.split('B')[-1]) for s in self.blocks_with_attention]
+
+         for i, (nc_in, nc_out) in enumerate(zip(in_channels, out_channels)):
+             gblock.append(GBlock(nc_in, nc_out, n_class=n_class, z_dim=z_dim, use_attention=(i+1) in self.sa_ids))
+         self.blocks = nn.ModuleList(gblock)
+
+         self.output_layer_bn = BatchNorm2d(1 * chn, eps=1e-5)
+         self.output_layer_conv = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1))
+
+         self.z_dim = code_dim
+         self.c_dim = n_class
+         self.n_level = self.num_split
+
+     def get_in_out_channels(self):
+         resolution = self.resolution
+         if resolution == 1024:
+             channel_multipliers = [16, 16, 8, 8, 4, 2, 1, 1, 1]
+         elif resolution == 512:
+             channel_multipliers = [16, 16, 8, 8, 4, 2, 1, 1]
+         elif resolution == 256:
+             channel_multipliers = [16, 16, 8, 8, 4, 2, 1]
+         elif resolution == 128:
+             channel_multipliers = [16, 16, 8, 4, 2, 1]
+         elif resolution == 64:
+             channel_multipliers = [16, 16, 8, 4, 2]
+         elif resolution == 32:
+             channel_multipliers = [4, 4, 4, 4]
+         else:
+             raise ValueError("Unsupported resolution: {}".format(resolution))
+         in_channels = [self.chn * c for c in channel_multipliers[:-1]]
+         out_channels = [self.chn * c for c in channel_multipliers[1:]]
+         return in_channels, out_channels
+
+     def forward(self, input, class_id):
+         codes = torch.chunk(input, self.num_split, 1)
+         class_emb = self.embed_y(class_id)  # 128
+         out = self.noise_fc(codes[0])
+         out = torch.Tensor.permute(torch.reshape(out,(out.shape[0], 4, 4, -1)),(0, 3, 1, 2))
+         for i, (code, gblock) in enumerate(zip(codes[1:], self.blocks)):
+             condition = torch.cat([code, class_emb], 1)
+             out = gblock(out, condition)
+
+         out = self.output_layer_bn(out)
+         out = torch.relu(out)
+         out = self.output_layer_conv(out)
+
+         return (torch.tanh(out) + 1) / 2
+
+     def forward_w(self, ws):
+         out = self.noise_fc(ws[0])
+         out = torch.Tensor.permute(torch.reshape(out,(out.shape[0], 4, 4, -1)),(0, 3, 1, 2))
+         for i, (w, gblock) in enumerate(zip(ws[1:], self.blocks)):
+             out = gblock(out, w)
+
+         out = self.output_layer_bn(out)
+         out = torch.relu(out)
+         out = self.output_layer_conv(out)
+
+         return (torch.tanh(out) + 1) / 2
+
+     def forward_wp(self, z0, gammas, betas):
+         out = self.noise_fc(z0)
+         out = torch.Tensor.permute(torch.reshape(out,(out.shape[0], 4, 4, -1)),(0, 3, 1, 2))
+         for i, (gamma, beta, gblock) in enumerate(zip(gammas, betas, self.blocks)):
+             out = gblock(out, [[gamma[0], beta[0]], [gamma[1], beta[1]]])
+
+         out = self.output_layer_bn(out)
+         out = torch.relu(out)
+         out = self.output_layer_conv(out)
+
+         return (torch.tanh(out) + 1) / 2
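
Note (not part of the commit): the Generator above can be smoke-tested on its own. The sketch below mirrors how app.py instantiates it, but uses random weights instead of the released checkpoint, which is enough to check shapes: a 140-dim latent is split into n_level chunks internally and combined with a 1000-way one-hot class vector to produce a 256x256 RGB image in [0, 1].

    import torch
    from model import Generator  # model.py from this commit

    model = Generator(code_dim=140, n_class=1000, chn=96,
                      blocks_with_attention="B5", resolution=256).eval()
    z = torch.randn(1, model.z_dim)  # 140-dim latent
    label = torch.nn.functional.one_hot(torch.tensor([0]), model.c_dim).float()  # one-hot class vector
    with torch.inference_mode():
        out = model(z, label)        # shape (1, 3, 256, 256), values in [0, 1]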
packages.txt ADDED
@@ -0,0 +1 @@
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ numpy==1.22.3
+ Pillow==9.0.1
+ scipy==1.8.0
+ torch==1.11.0
+ torchvision==0.12.0
+ gradio==3.0.3
+ huggingface-hub==0.6.0
+ moviepy==1.0.3