johnowhitaker committed on
Commit
ef231cd
1 Parent(s): 8f8cbb9

Create app.py

Files changed (1):
  1. app.py +477 -0
app.py ADDED
@@ -0,0 +1,477 @@
#@title Gradio demo (used in space: )

from matplotlib import pyplot as plt
from huggingface_hub import PyTorchModelHubMixin
import numpy as np
import gradio as gr

### A big chunk of this is copied from lightweight_gan, since the original package asserts that a GPU is available

import os
import json
import multiprocessing
from random import random
import math
from math import log2, floor
from functools import partial
from contextlib import contextmanager, ExitStack
from pathlib import Path
from shutil import rmtree

import torch
from torch.cuda.amp import autocast, GradScaler
from torch.optim import Adam
from torch import nn, einsum
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd import grad as torch_grad
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

from PIL import Image
import torchvision
from torchvision import transforms
from kornia.filters import filter2d

from tqdm import tqdm
from einops import rearrange, reduce, repeat

from adabelief_pytorch import AdaBelief

# helpers

# NOTE: AUGMENT_FNS is not defined in this trimmed copy, so DiffAugment only
# works with the default types=[]; the demo below never calls it.
def DiffAugment(x, types=[]):
    for p in types:
        for f in AUGMENT_FNS[p]:
            x = f(x)
    return x.contiguous()

@contextmanager
def null_context():
    yield

def combine_contexts(contexts):
    @contextmanager
    def multi_contexts():
        with ExitStack() as stack:
            yield [stack.enter_context(ctx()) for ctx in contexts]
    return multi_contexts

def exists(val):
    return val is not None

def is_power_of_two(val):
    return log2(val).is_integer()

def default(val, d):
    return val if exists(val) else d

def set_requires_grad(model, bool):
    for p in model.parameters():
        p.requires_grad = bool

def cycle(iterable):
    while True:
        for i in iterable:
            yield i

def raise_if_nan(t):
    if torch.isnan(t):
        raise NanException

def evaluate_in_chunks(max_batch_size, model, *args):
    split_args = list(zip(*list(map(lambda x: x.split(max_batch_size, dim=0), args))))
    chunked_outputs = [model(*i) for i in split_args]
    if len(chunked_outputs) == 1:
        return chunked_outputs[0]
    return torch.cat(chunked_outputs, dim=0)
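
# Usage sketch (mine, not the original author's): evaluate_in_chunks splits an
# oversized batch into model-sized pieces and concatenates the outputs, capping
# peak memory. E.g., with any callable `model`:
#
#   latents = torch.randn(64, 256)
#   ims = evaluate_in_chunks(8, model, latents)   # runs 8 samples at a time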

def slerp(val, low, high):
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    omega = torch.acos((low_norm * high_norm).sum(1))
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res
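
# Illustrative sketch (not from the original): slerp interpolates along the
# hypersphere between two latents, which usually gives smoother GAN
# interpolations than a straight line in latent space.
#
#   z1, z2 = torch.randn(1, 256), torch.randn(1, 256)
#   z_mid = slerp(0.5, z1, z2)   # halfway along the great circle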

def safe_div(n, d):
    try:
        res = n / d
    except ZeroDivisionError:
        prefix = '' if n >= 0 else '-'
        res = float(f'{prefix}inf')
    return res

# loss functions

def gen_hinge_loss(fake, real):
    return fake.mean()

def hinge_loss(real, fake):
    return (F.relu(1 + real) + F.relu(1 - fake)).mean()

def dual_contrastive_loss(real_logits, fake_logits):
    device = real_logits.device
    real_logits, fake_logits = map(lambda t: rearrange(t, '... -> (...)'), (real_logits, fake_logits))

    def loss_half(t1, t2):
        t1 = rearrange(t1, 'i -> i ()')
        t2 = repeat(t2, 'j -> i j', i = t1.shape[0])
        t = torch.cat((t1, t2), dim = -1)
        return F.cross_entropy(t, torch.zeros(t1.shape[0], device = device, dtype = torch.long))

    return loss_half(real_logits, fake_logits) + loss_half(-fake_logits, -real_logits)
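
# Shape note (mine): each loss_half contrasts every logit in t1 against all of
# t2 via cross-entropy over an (i, 1 + j) matrix whose correct class is index 0.
#
#   real, fake = torch.randn(8), torch.randn(8)   # per-sample D outputs
#   loss = dual_contrastive_loss(real, fake)      # scalar tensor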

# helper classes

class NanException(Exception):
    pass

class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta
    def update_average(self, old, new):
        if not exists(old):
            return new
        return old * self.beta + (1 - self.beta) * new

class RandomApply(nn.Module):
    def __init__(self, prob, fn, fn_else = lambda x: x):
        super().__init__()
        self.fn = fn
        self.fn_else = fn_else
        self.prob = prob
    def forward(self, x):
        fn = self.fn if random() < self.prob else self.fn_else
        return fn(x)

class ChanNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = ChanNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

class SumBranches(nn.Module):
    def __init__(self, branches):
        super().__init__()
        self.branches = nn.ModuleList(branches)
    def forward(self, x):
        return sum(map(lambda fn: fn(x), self.branches))

class Blur(nn.Module):
    def __init__(self):
        super().__init__()
        f = torch.Tensor([1, 2, 1])
        self.register_buffer('f', f)
    def forward(self, x):
        f = self.f
        f = f[None, None, :] * f[None, :, None]
        return filter2d(x, f, normalized=True)

# attention

class DepthWiseConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding = 0, stride = 1, bias = True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
            nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        )
    def forward(self, x):
        return self.net(x)

class LinearAttention(nn.Module):
    def __init__(self, dim, dim_head = 64, heads = 8):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.nonlin = nn.GELU()
        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias = False)
        self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding = 1, bias = False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, fmap):
        h, x, y = self.heads, *fmap.shape[-2:]
        q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim = 1))
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)

        out = self.nonlin(out)
        return self.to_out(out)

# global context network
# https://arxiv.org/abs/2012.13375
# similar to squeeze-excite, but with a simplified attention pooling and a subsequent layer norm

class GlobalContext(nn.Module):
    def __init__(
        self,
        *,
        chan_in,
        chan_out
    ):
        super().__init__()
        self.to_k = nn.Conv2d(chan_in, 1, 1)
        chan_intermediate = max(3, chan_out // 2)

        self.net = nn.Sequential(
            nn.Conv2d(chan_in, chan_intermediate, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(chan_intermediate, chan_out, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        context = self.to_k(x)
        context = context.flatten(2).softmax(dim = -1)
        out = einsum('b i n, b c n -> b c i', context, x.flatten(2))
        out = out.unsqueeze(-1)
        return self.net(out)
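
# Hedged sketch (mine) of how GlobalContext is used further down: it
# attention-pools a feature map into per-channel sigmoid gates, which the
# generator multiplies into a later layer's activations (the residuals step
# in Generator.forward).
#
#   gc = GlobalContext(chan_in = 64, chan_out = 128)
#   gates = gc(torch.randn(1, 64, 32, 32))   # -> (1, 128, 1, 1)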

# dataset

def convert_image_to(img_type, image):
    if image.mode != img_type:
        return image.convert(img_type)
    return image

class identity(object):
    def __call__(self, tensor):
        return tensor

class expand_greyscale(object):
    def __init__(self, transparent):
        self.transparent = transparent

    def __call__(self, tensor):
        channels = tensor.shape[0]
        num_target_channels = 4 if self.transparent else 3

        if channels == num_target_channels:
            return tensor

        alpha = None
        if channels == 1:
            color = tensor.expand(3, -1, -1)
        elif channels == 2:
            color = tensor[:1].expand(3, -1, -1)
            alpha = tensor[1:]
        else:
            raise Exception(f'image with invalid number of channels given {channels}')

        if not exists(alpha) and self.transparent:
            alpha = torch.ones(1, *tensor.shape[1:], device=tensor.device)

        return color if not self.transparent else torch.cat((color, alpha))


class FCANet(nn.Module):
    def __init__(
        self,
        *,
        chan_in,
        chan_out,
        reduction = 4,
        width
    ):
        super().__init__()

        # NOTE: get_dct_weights is not included in this trimmed copy, so FCANet
        # (only reached when freq_chan_attn = True) would fail here; the demo
        # below leaves freq_chan_attn at its default of False.
        freq_w, freq_h = ([0] * 8), list(range(8)) # in paper, it seems 16 frequencies was ideal
        dct_weights = get_dct_weights(width, chan_in, [*freq_w, *freq_h], [*freq_h, *freq_w])
        self.register_buffer('dct_weights', dct_weights)

        chan_intermediate = max(3, chan_out // reduction)

        self.net = nn.Sequential(
            nn.Conv2d(chan_in, chan_intermediate, 1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(chan_intermediate, chan_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = reduce(x * self.dct_weights, 'b c (h h1) (w w1) -> b c h1 w1', 'sum', h1 = 1, w1 = 1)
        return self.net(x)

# modifiable global variables

norm_class = nn.BatchNorm2d

def upsample(scale_factor = 2):
    return nn.Upsample(scale_factor = scale_factor)


# generative adversarial network

class Generator(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        latent_dim = 256,
        fmap_max = 512,
        fmap_inverse_coef = 12,
        transparent = False,
        greyscale = False,
        attn_res_layers = [],
        freq_chan_attn = False
    ):
        super().__init__()
        resolution = log2(image_size)
        assert is_power_of_two(image_size), 'image size must be a power of 2'

        if transparent:
            init_channel = 4
        elif greyscale:
            init_channel = 1
        else:
            init_channel = 3

        fmap_max = default(fmap_max, latent_dim)

        self.initial_conv = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
            norm_class(latent_dim * 2),
            nn.GLU(dim = 1)
        )

        num_layers = int(resolution) - 2
        features = list(map(lambda n: (n, 2 ** (fmap_inverse_coef - n)), range(2, num_layers + 2)))
        features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))
        features = list(map(lambda n: 3 if n[0] >= 8 else n[1], features))
        features = [latent_dim, *features]

        in_out_features = list(zip(features[:-1], features[1:]))

        self.res_layers = range(2, num_layers + 2)
        self.layers = nn.ModuleList([])
        self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))

        self.sle_map = ((3, 7), (4, 8), (5, 9), (6, 10))
        self.sle_map = list(filter(lambda t: t[0] <= resolution and t[1] <= resolution, self.sle_map))
        self.sle_map = dict(self.sle_map)

        self.num_layers_spatial_res = 1

        for (res, (chan_in, chan_out)) in zip(self.res_layers, in_out_features):
            image_width = 2 ** res

            attn = None
            if image_width in attn_res_layers:
                attn = PreNorm(chan_in, LinearAttention(chan_in))

            sle = None
            if res in self.sle_map:
                residual_layer = self.sle_map[res]
                sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]

                if freq_chan_attn:
                    sle = FCANet(
                        chan_in = chan_out,
                        chan_out = sle_chan_out,
                        width = 2 ** (res + 1)
                    )
                else:
                    sle = GlobalContext(
                        chan_in = chan_out,
                        chan_out = sle_chan_out
                    )

            layer = nn.ModuleList([
                nn.Sequential(
                    upsample(),
                    Blur(),
                    nn.Conv2d(chan_in, chan_out * 2, 3, padding = 1),
                    norm_class(chan_out * 2),
                    nn.GLU(dim = 1)
                ),
                sle,
                attn
            ])
            self.layers.append(layer)

        self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding = 1)

    def forward(self, x):
        x = rearrange(x, 'b c -> b c () ()')
        x = self.initial_conv(x)
        x = F.normalize(x, dim = 1)

        residuals = dict()

        for (res, (up, sle, attn)) in zip(self.res_layers, self.layers):
            if exists(attn):
                x = attn(x) + x

            x = up(x)

            if exists(sle):
                out_res = self.sle_map[res]
                residual = sle(x)
                residuals[out_res] = residual

            next_res = res + 1
            if next_res in residuals:
                x = x * residuals[next_res]

        return self.out_conv(x)
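
# Quick shape sanity check (a sketch, not part of the original app): with
# image_size = 256 the generator maps (batch, latent_dim) noise to a
# (batch, 3, 256, 256) image tensor.
#
#   g = Generator(latent_dim = 256, image_size = 256, attn_res_layers = [32])
#   with torch.no_grad():
#       assert g(torch.randn(1, 256)).shape == (1, 3, 256, 256)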

#### ACTUALLY LOAD THE MODEL AND DEFINE THE INTERFACE

# Initialize a generator model
gan_new = Generator(latent_dim=256, image_size=256, attn_res_layers = [32])

# Load from local saved state dict
# gan_new.load_state_dict(torch.load('/content/orbgan_e3_state_dict.pt'))

# Load from model hub:
class GeneratorWithPyTorchModelHubMixin(gan_new.__class__, PyTorchModelHubMixin):
    pass
gan_new.__class__ = GeneratorWithPyTorchModelHubMixin
gan_new = gan_new.from_pretrained('johnowhitaker/orbgan_e1', latent_dim=256, image_size=256, attn_res_layers = [32])
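
# Note (not in the original): PyTorchModelHubMixin also provides
# save_pretrained / push_to_hub, so a locally trained generator could be
# written back the same way, e.g.
#   gan_new.save_pretrained('orbgan_local')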

def gen_ims(n_rows):
    # no_grad keeps the demo from building an autograd graph during inference
    with torch.no_grad():
        ims = gan_new(torch.randn(int(n_rows)**2, 256)).clamp_(0., 1.)
    grid = torchvision.utils.make_grid(ims, nrow=int(n_rows)).permute(1, 2, 0).detach().cpu().numpy()
    return (grid*255).astype(np.uint8)
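
# Usage sketch (illustrative): gen_ims(3) samples 9 latents, tiles the decoded
# images into a 3x3 grid, and returns an (H, W, 3) uint8 array, the numpy
# format the Gradio Image output below expects.
#   grid = gen_ims(3)   # approx (776, 776, 3) with make_grid's default padding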

# gr.inputs / gr.outputs is the legacy (pre-3.0) Gradio component namespace
iface = gr.Interface(fn=gen_ims,
                     inputs=[gr.inputs.Number(label="N rows", default=3)],
                     outputs=[gr.outputs.Image(type="numpy", label="Generated Images")],
                     title='Demo for https://huggingface.co/johnowhitaker/orbgan_e1'
                     )
iface.launch()