valhalla committed
Commit 41c77df
1 parent: ecbc8c8
bert/config.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "_name_or_path": "../fusing-models/bert/",
+   "activation_dropout": 0.0,
+   "activation_function": "gelu",
+   "architectures": [
+     "LDMBertModel"
+   ],
+   "attention_dropout": 0.0,
+   "classifier_dropout": 0.0,
+   "d_model": 1280,
+   "dropout": 0.1,
+   "encoder_attention_heads": 8,
+   "encoder_ffn_dim": 5120,
+   "encoder_layerdrop": 0.0,
+   "encoder_layers": 32,
+   "head_dim": 64,
+   "init_std": 0.02,
+   "max_position_embeddings": 77,
+   "model_type": "ldmbert",
+   "num_hidden_layers": 32,
+   "pad_token_id": 0,
+   "scale_embedding": false,
+   "torch_dtype": "float32",
+   "transformers_version": "4.20.0.dev0",
+   "use_cache": true,
+   "vocab_size": 30522
+ }
bert/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b33de66bbe4f4a28993bf2620f27252ebbfa4ef9a7e4dfb967ad093b4578c5eb
+ size 2328112821
model_index.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "_class_name": "LatentDiffusion",
+   "_diffusers_version": "0.0.1",
+   "_module": "modeling_latent_diffusion.py",
+   "bert": [
+     "transformers",
+     "LDMBertModel"
+   ],
+   "noise_scheduler": [
+     "diffusers",
+     "GaussianDDPMScheduler"
+   ],
+   "tokenizer": [
+     "transformers",
+     "BertTokenizer"
+   ],
+   "unet": [
+     "diffusers",
+     "UNetLDMModel"
+   ],
+   "vqvae": [
+     "modeling_latent_diffusion",
+     "AutoencoderKL"
+   ]
+ }
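Each non-underscore entry in model_index.json names a pipeline component and the (library, class) pair used to instantiate it from the matching subfolder. The snippet below is only a rough illustration of that lookup, assuming the listed libraries are importable as plain modules; the real DiffusionPipeline.from_pretrained loader additionally handles the custom _module file and the LFS weights.

import importlib
import json

with open("model_index.json") as f:
    index = json.load(f)

components = {}
for name, spec in index.items():
    if name.startswith("_"):  # skip metadata such as _class_name and _diffusers_version
        continue
    library, class_name = spec
    module = importlib.import_module(library)  # e.g. "transformers" or "diffusers"
    components[name] = getattr(module, class_name)

# components now maps e.g. "tokenizer" -> BertTokenizer, "vqvae" -> AutoencoderKL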
modeling_latent_diffusion.py ADDED
@@ -0,0 +1,965 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+
4
+ import numpy as np
5
+ import tqdm
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from diffusers import DiffusionPipeline
10
+ from diffusers.configuration_utils import ConfigMixin
11
+ from diffusers.modeling_utils import ModelMixin
+ from einops import rearrange
12
+
13
+
14
+ def get_timestep_embedding(timesteps, embedding_dim):
15
+ """
16
+ Build sinusoidal timestep embeddings (ported from Fairseq).
+ This matches the implementation in Denoising Diffusion Probabilistic Models and in tensor2tensor,
+ but differs slightly from the description in Section 3.5 of "Attention Is All You Need".
21
+ """
22
+ assert len(timesteps.shape) == 1
23
+
24
+ half_dim = embedding_dim // 2
25
+ emb = math.log(10000) / (half_dim - 1)
26
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
27
+ emb = emb.to(device=timesteps.device)
28
+ emb = timesteps.float()[:, None] * emb[None, :]
29
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
30
+ if embedding_dim % 2 == 1: # zero pad
31
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
32
+ return emb
33
+
34
+
35
+ def nonlinearity(x):
36
+ # swish
37
+ return x * torch.sigmoid(x)
38
+
39
+
40
+ def Normalize(in_channels):
41
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
42
+
43
+
44
+ class Upsample(nn.Module):
45
+ def __init__(self, in_channels, with_conv):
46
+ super().__init__()
47
+ self.with_conv = with_conv
48
+ if self.with_conv:
49
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
50
+
51
+ def forward(self, x):
52
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
53
+ if self.with_conv:
54
+ x = self.conv(x)
55
+ return x
56
+
57
+
58
+ class Downsample(nn.Module):
59
+ def __init__(self, in_channels, with_conv):
60
+ super().__init__()
61
+ self.with_conv = with_conv
62
+ if self.with_conv:
63
+ # no asymmetric padding in torch conv, must do it ourselves
64
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
65
+
66
+ def forward(self, x):
67
+ if self.with_conv:
68
+ pad = (0, 1, 0, 1)
69
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
70
+ x = self.conv(x)
71
+ else:
72
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
73
+ return x
74
+
75
+
76
+ class ResnetBlock(nn.Module):
77
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
78
+ super().__init__()
79
+ self.in_channels = in_channels
80
+ out_channels = in_channels if out_channels is None else out_channels
81
+ self.out_channels = out_channels
82
+ self.use_conv_shortcut = conv_shortcut
83
+
84
+ self.norm1 = Normalize(in_channels)
85
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
86
+ if temb_channels > 0:
87
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
88
+ self.norm2 = Normalize(out_channels)
89
+ self.dropout = torch.nn.Dropout(dropout)
90
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
91
+ if self.in_channels != self.out_channels:
92
+ if self.use_conv_shortcut:
93
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
94
+ else:
95
+ self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
96
+
97
+ def forward(self, x, temb):
98
+ h = x
99
+ h = self.norm1(h)
100
+ h = nonlinearity(h)
101
+ h = self.conv1(h)
102
+
103
+ if temb is not None:
104
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
105
+
106
+ h = self.norm2(h)
107
+ h = nonlinearity(h)
108
+ h = self.dropout(h)
109
+ h = self.conv2(h)
110
+
111
+ if self.in_channels != self.out_channels:
112
+ if self.use_conv_shortcut:
113
+ x = self.conv_shortcut(x)
114
+ else:
115
+ x = self.nin_shortcut(x)
116
+
117
+ return x + h
118
+
119
+
120
+ class AttnBlock(nn.Module):
121
+ def __init__(self, in_channels):
122
+ super().__init__()
123
+ self.in_channels = in_channels
124
+
125
+ self.norm = Normalize(in_channels)
126
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
127
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
128
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
129
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
130
+
131
+ def forward(self, x):
132
+ h_ = x
133
+ h_ = self.norm(h_)
134
+ q = self.q(h_)
135
+ k = self.k(h_)
136
+ v = self.v(h_)
137
+
138
+ # compute attention
139
+ b, c, h, w = q.shape
140
+ q = q.reshape(b, c, h * w)
141
+ q = q.permute(0, 2, 1) # b,hw,c
142
+ k = k.reshape(b, c, h * w) # b,c,hw
143
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
144
+ w_ = w_ * (int(c) ** (-0.5))
145
+ w_ = torch.nn.functional.softmax(w_, dim=2)
146
+
147
+ # attend to values
148
+ v = v.reshape(b, c, h * w)
149
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
150
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
151
+ h_ = h_.reshape(b, c, h, w)
152
+
153
+ h_ = self.proj_out(h_)
154
+
155
+ return x + h_
156
+
157
+
158
+ class Model(nn.Module):
159
+ def __init__(
160
+ self,
161
+ *,
162
+ ch,
163
+ out_ch,
164
+ ch_mult=(1, 2, 4, 8),
165
+ num_res_blocks,
166
+ attn_resolutions,
167
+ dropout=0.0,
168
+ resamp_with_conv=True,
169
+ in_channels,
170
+ resolution,
171
+ use_timestep=True,
172
+ ):
173
+ super().__init__()
174
+ self.ch = ch
175
+ self.temb_ch = self.ch * 4
176
+ self.num_resolutions = len(ch_mult)
177
+ self.num_res_blocks = num_res_blocks
178
+ self.resolution = resolution
179
+ self.in_channels = in_channels
180
+
181
+ self.use_timestep = use_timestep
182
+ if self.use_timestep:
183
+ # timestep embedding
184
+ self.temb = nn.Module()
185
+ self.temb.dense = nn.ModuleList(
186
+ [
187
+ torch.nn.Linear(self.ch, self.temb_ch),
188
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
189
+ ]
190
+ )
191
+
192
+ # downsampling
193
+ self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
194
+
195
+ curr_res = resolution
196
+ in_ch_mult = (1,) + tuple(ch_mult)
197
+ self.down = nn.ModuleList()
198
+ for i_level in range(self.num_resolutions):
199
+ block = nn.ModuleList()
200
+ attn = nn.ModuleList()
201
+ block_in = ch * in_ch_mult[i_level]
202
+ block_out = ch * ch_mult[i_level]
203
+ for i_block in range(self.num_res_blocks):
204
+ block.append(
205
+ ResnetBlock(
206
+ in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
207
+ )
208
+ )
209
+ block_in = block_out
210
+ if curr_res in attn_resolutions:
211
+ attn.append(AttnBlock(block_in))
212
+ down = nn.Module()
213
+ down.block = block
214
+ down.attn = attn
215
+ if i_level != self.num_resolutions - 1:
216
+ down.downsample = Downsample(block_in, resamp_with_conv)
217
+ curr_res = curr_res // 2
218
+ self.down.append(down)
219
+
220
+ # middle
221
+ self.mid = nn.Module()
222
+ self.mid.block_1 = ResnetBlock(
223
+ in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
224
+ )
225
+ self.mid.attn_1 = AttnBlock(block_in)
226
+ self.mid.block_2 = ResnetBlock(
227
+ in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
228
+ )
229
+
230
+ # upsampling
231
+ self.up = nn.ModuleList()
232
+ for i_level in reversed(range(self.num_resolutions)):
233
+ block = nn.ModuleList()
234
+ attn = nn.ModuleList()
235
+ block_out = ch * ch_mult[i_level]
236
+ skip_in = ch * ch_mult[i_level]
237
+ for i_block in range(self.num_res_blocks + 1):
238
+ if i_block == self.num_res_blocks:
239
+ skip_in = ch * in_ch_mult[i_level]
240
+ block.append(
241
+ ResnetBlock(
242
+ in_channels=block_in + skip_in,
243
+ out_channels=block_out,
244
+ temb_channels=self.temb_ch,
245
+ dropout=dropout,
246
+ )
247
+ )
248
+ block_in = block_out
249
+ if curr_res in attn_resolutions:
250
+ attn.append(AttnBlock(block_in))
251
+ up = nn.Module()
252
+ up.block = block
253
+ up.attn = attn
254
+ if i_level != 0:
255
+ up.upsample = Upsample(block_in, resamp_with_conv)
256
+ curr_res = curr_res * 2
257
+ self.up.insert(0, up) # prepend to get consistent order
258
+
259
+ # end
260
+ self.norm_out = Normalize(block_in)
261
+ self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
262
+
263
+ def forward(self, x, t=None):
264
+ # assert x.shape[2] == x.shape[3] == self.resolution
265
+
266
+ if self.use_timestep:
267
+ # timestep embedding
268
+ assert t is not None
269
+ temb = get_timestep_embedding(t, self.ch)
270
+ temb = self.temb.dense[0](temb)
271
+ temb = nonlinearity(temb)
272
+ temb = self.temb.dense[1](temb)
273
+ else:
274
+ temb = None
275
+
276
+ # downsampling
277
+ hs = [self.conv_in(x)]
278
+ for i_level in range(self.num_resolutions):
279
+ for i_block in range(self.num_res_blocks):
280
+ h = self.down[i_level].block[i_block](hs[-1], temb)
281
+ if len(self.down[i_level].attn) > 0:
282
+ h = self.down[i_level].attn[i_block](h)
283
+ hs.append(h)
284
+ if i_level != self.num_resolutions - 1:
285
+ hs.append(self.down[i_level].downsample(hs[-1]))
286
+
287
+ # middle
288
+ h = hs[-1]
289
+ h = self.mid.block_1(h, temb)
290
+ h = self.mid.attn_1(h)
291
+ h = self.mid.block_2(h, temb)
292
+
293
+ # upsampling
294
+ for i_level in reversed(range(self.num_resolutions)):
295
+ for i_block in range(self.num_res_blocks + 1):
296
+ h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
297
+ if len(self.up[i_level].attn) > 0:
298
+ h = self.up[i_level].attn[i_block](h)
299
+ if i_level != 0:
300
+ h = self.up[i_level].upsample(h)
301
+
302
+ # end
303
+ h = self.norm_out(h)
304
+ h = nonlinearity(h)
305
+ h = self.conv_out(h)
306
+ return h
307
+
308
+
309
+ class Encoder(nn.Module):
310
+ def __init__(
311
+ self,
312
+ *,
313
+ ch,
314
+ out_ch,
315
+ ch_mult=(1, 2, 4, 8),
316
+ num_res_blocks,
317
+ attn_resolutions,
318
+ dropout=0.0,
319
+ resamp_with_conv=True,
320
+ in_channels,
321
+ resolution,
322
+ z_channels,
323
+ double_z=True,
324
+ **ignore_kwargs,
325
+ ):
326
+ super().__init__()
327
+ self.ch = ch
328
+ self.temb_ch = 0
329
+ self.num_resolutions = len(ch_mult)
330
+ self.num_res_blocks = num_res_blocks
331
+ self.resolution = resolution
332
+ self.in_channels = in_channels
333
+
334
+ # downsampling
335
+ self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
336
+
337
+ curr_res = resolution
338
+ in_ch_mult = (1,) + tuple(ch_mult)
339
+ self.down = nn.ModuleList()
340
+ for i_level in range(self.num_resolutions):
341
+ block = nn.ModuleList()
342
+ attn = nn.ModuleList()
343
+ block_in = ch * in_ch_mult[i_level]
344
+ block_out = ch * ch_mult[i_level]
345
+ for i_block in range(self.num_res_blocks):
346
+ block.append(
347
+ ResnetBlock(
348
+ in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
349
+ )
350
+ )
351
+ block_in = block_out
352
+ if curr_res in attn_resolutions:
353
+ attn.append(AttnBlock(block_in))
354
+ down = nn.Module()
355
+ down.block = block
356
+ down.attn = attn
357
+ if i_level != self.num_resolutions - 1:
358
+ down.downsample = Downsample(block_in, resamp_with_conv)
359
+ curr_res = curr_res // 2
360
+ self.down.append(down)
361
+
362
+ # middle
363
+ self.mid = nn.Module()
364
+ self.mid.block_1 = ResnetBlock(
365
+ in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
366
+ )
367
+ self.mid.attn_1 = AttnBlock(block_in)
368
+ self.mid.block_2 = ResnetBlock(
369
+ in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
370
+ )
371
+
372
+ # end
373
+ self.norm_out = Normalize(block_in)
374
+ self.conv_out = torch.nn.Conv2d(
375
+ block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1
376
+ )
377
+
378
+ def forward(self, x):
379
+ # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
380
+
381
+ # timestep embedding
382
+ temb = None
383
+
384
+ # downsampling
385
+ hs = [self.conv_in(x)]
386
+ for i_level in range(self.num_resolutions):
387
+ for i_block in range(self.num_res_blocks):
388
+ h = self.down[i_level].block[i_block](hs[-1], temb)
389
+ if len(self.down[i_level].attn) > 0:
390
+ h = self.down[i_level].attn[i_block](h)
391
+ hs.append(h)
392
+ if i_level != self.num_resolutions - 1:
393
+ hs.append(self.down[i_level].downsample(hs[-1]))
394
+
395
+ # middle
396
+ h = hs[-1]
397
+ h = self.mid.block_1(h, temb)
398
+ h = self.mid.attn_1(h)
399
+ h = self.mid.block_2(h, temb)
400
+
401
+ # end
402
+ h = self.norm_out(h)
403
+ h = nonlinearity(h)
404
+ h = self.conv_out(h)
405
+ return h
406
+
407
+
408
+ class Decoder(nn.Module):
409
+ def __init__(
410
+ self,
411
+ *,
412
+ ch,
413
+ out_ch,
414
+ ch_mult=(1, 2, 4, 8),
415
+ num_res_blocks,
416
+ attn_resolutions,
417
+ dropout=0.0,
418
+ resamp_with_conv=True,
419
+ in_channels,
420
+ resolution,
421
+ z_channels,
422
+ give_pre_end=False,
423
+ **ignorekwargs,
424
+ ):
425
+ super().__init__()
426
+ self.ch = ch
427
+ self.temb_ch = 0
428
+ self.num_resolutions = len(ch_mult)
429
+ self.num_res_blocks = num_res_blocks
430
+ self.resolution = resolution
431
+ self.in_channels = in_channels
432
+ self.give_pre_end = give_pre_end
433
+
434
+ # compute in_ch_mult, block_in and curr_res at lowest res
435
+ in_ch_mult = (1,) + tuple(ch_mult)
436
+ block_in = ch * ch_mult[self.num_resolutions - 1]
437
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
438
+ self.z_shape = (1, z_channels, curr_res, curr_res)
439
+ print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
440
+
441
+ # z to block_in
442
+ self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
443
+
444
+ # middle
445
+ self.mid = nn.Module()
446
+ self.mid.block_1 = ResnetBlock(
447
+ in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
448
+ )
449
+ self.mid.attn_1 = AttnBlock(block_in)
450
+ self.mid.block_2 = ResnetBlock(
451
+ in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
452
+ )
453
+
454
+ # upsampling
455
+ self.up = nn.ModuleList()
456
+ for i_level in reversed(range(self.num_resolutions)):
457
+ block = nn.ModuleList()
458
+ attn = nn.ModuleList()
459
+ block_out = ch * ch_mult[i_level]
460
+ for i_block in range(self.num_res_blocks + 1):
461
+ block.append(
462
+ ResnetBlock(
463
+ in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
464
+ )
465
+ )
466
+ block_in = block_out
467
+ if curr_res in attn_resolutions:
468
+ attn.append(AttnBlock(block_in))
469
+ up = nn.Module()
470
+ up.block = block
471
+ up.attn = attn
472
+ if i_level != 0:
473
+ up.upsample = Upsample(block_in, resamp_with_conv)
474
+ curr_res = curr_res * 2
475
+ self.up.insert(0, up) # prepend to get consistent order
476
+
477
+ # end
478
+ self.norm_out = Normalize(block_in)
479
+ self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
480
+
481
+ def forward(self, z):
482
+ # assert z.shape[1:] == self.z_shape[1:]
483
+ self.last_z_shape = z.shape
484
+
485
+ # timestep embedding
486
+ temb = None
487
+
488
+ # z to block_in
489
+ h = self.conv_in(z)
490
+
491
+ # middle
492
+ h = self.mid.block_1(h, temb)
493
+ h = self.mid.attn_1(h)
494
+ h = self.mid.block_2(h, temb)
495
+
496
+ # upsampling
497
+ for i_level in reversed(range(self.num_resolutions)):
498
+ for i_block in range(self.num_res_blocks + 1):
499
+ h = self.up[i_level].block[i_block](h, temb)
500
+ if len(self.up[i_level].attn) > 0:
501
+ h = self.up[i_level].attn[i_block](h)
502
+ if i_level != 0:
503
+ h = self.up[i_level].upsample(h)
504
+
505
+ # end
506
+ if self.give_pre_end:
507
+ return h
508
+
509
+ h = self.norm_out(h)
510
+ h = nonlinearity(h)
511
+ h = self.conv_out(h)
512
+ return h
513
+
514
+
515
+ class VectorQuantizer(nn.Module):
516
+ """
517
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
518
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
519
+ """
520
+
521
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
522
+ # backwards compatibility we use the buggy version by default, but you can
523
+ # specify legacy=False to fix it.
524
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
525
+ super().__init__()
526
+ self.n_e = n_e
527
+ self.e_dim = e_dim
528
+ self.beta = beta
529
+ self.legacy = legacy
530
+
531
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
532
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
533
+
534
+ self.remap = remap
535
+ if self.remap is not None:
536
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
537
+ self.re_embed = self.used.shape[0]
538
+ self.unknown_index = unknown_index # "random" or "extra" or integer
539
+ if self.unknown_index == "extra":
540
+ self.unknown_index = self.re_embed
541
+ self.re_embed = self.re_embed + 1
542
+ print(
543
+ f"Remapping {self.n_e} indices to {self.re_embed} indices. "
544
+ f"Using {self.unknown_index} for unknown indices."
545
+ )
546
+ else:
547
+ self.re_embed = n_e
548
+
549
+ self.sane_index_shape = sane_index_shape
550
+
551
+ def remap_to_used(self, inds):
552
+ ishape = inds.shape
553
+ assert len(ishape) > 1
554
+ inds = inds.reshape(ishape[0], -1)
555
+ used = self.used.to(inds)
556
+ match = (inds[:, :, None] == used[None, None, ...]).long()
557
+ new = match.argmax(-1)
558
+ unknown = match.sum(2) < 1
559
+ if self.unknown_index == "random":
560
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
561
+ else:
562
+ new[unknown] = self.unknown_index
563
+ return new.reshape(ishape)
564
+
565
+ def unmap_to_all(self, inds):
566
+ ishape = inds.shape
567
+ assert len(ishape) > 1
568
+ inds = inds.reshape(ishape[0], -1)
569
+ used = self.used.to(inds)
570
+ if self.re_embed > self.used.shape[0]: # extra token
571
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
572
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
573
+ return back.reshape(ishape)
574
+
575
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
576
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
577
+ assert rescale_logits == False, "Only for interface compatible with Gumbel"
578
+ assert return_logits == False, "Only for interface compatible with Gumbel"
579
+ # reshape z -> (batch, height, width, channel) and flatten
580
+ z = rearrange(z, "b c h w -> b h w c").contiguous()
581
+ z_flattened = z.view(-1, self.e_dim)
582
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
583
+
584
+ d = (
585
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
586
+ + torch.sum(self.embedding.weight**2, dim=1)
587
+ - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
588
+ )
589
+
590
+ min_encoding_indices = torch.argmin(d, dim=1)
591
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
592
+ perplexity = None
593
+ min_encodings = None
594
+
595
+ # compute loss for embedding
596
+ if not self.legacy:
597
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
598
+ else:
599
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
600
+
601
+ # preserve gradients
602
+ z_q = z + (z_q - z).detach()
603
+
604
+ # reshape back to match original input shape
605
+ z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
606
+
607
+ if self.remap is not None:
608
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
609
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
610
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
611
+
612
+ if self.sane_index_shape:
613
+ min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
614
+
615
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
616
+
617
+ def get_codebook_entry(self, indices, shape):
618
+ # shape specifying (batch, height, width, channel)
619
+ if self.remap is not None:
620
+ indices = indices.reshape(shape[0], -1) # add batch axis
621
+ indices = self.unmap_to_all(indices)
622
+ indices = indices.reshape(-1) # flatten again
623
+
624
+ # get quantized latent vectors
625
+ z_q = self.embedding(indices)
626
+
627
+ if shape is not None:
628
+ z_q = z_q.view(shape)
629
+ # reshape back to match original input shape
630
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
631
+
632
+ return z_q
633
+
634
+
635
+ class VQModel(ModelMixin, ConfigMixin):
636
+ def __init__(
637
+ self,
638
+ ch,
639
+ out_ch,
640
+ num_res_blocks,
641
+ attn_resolutions,
642
+ in_channels,
643
+ resolution,
644
+ z_channels,
645
+ n_embed,
646
+ embed_dim,
647
+ remap=None,
648
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
649
+ ch_mult=(1, 2, 4, 8),
650
+ dropout=0.0,
651
+ double_z=True,
652
+ resamp_with_conv=True,
653
+ give_pre_end=False,
654
+ ):
655
+ super().__init__()
656
+
657
+ # register all __init__ params with self.register
658
+ self.register(
659
+ ch=ch,
660
+ out_ch=out_ch,
661
+ num_res_blocks=num_res_blocks,
662
+ attn_resolutions=attn_resolutions,
663
+ in_channels=in_channels,
664
+ resolution=resolution,
665
+ z_channels=z_channels,
666
+ n_embed=n_embed,
667
+ embed_dim=embed_dim,
668
+ remap=remap,
669
+ sane_index_shape=sane_index_shape,
670
+ ch_mult=ch_mult,
671
+ dropout=dropout,
672
+ double_z=double_z,
673
+ resamp_with_conv=resamp_with_conv,
674
+ give_pre_end=give_pre_end,
675
+ )
676
+
677
+ # pass init params to Encoder
678
+ self.encoder = Encoder(
679
+ ch=ch,
680
+ out_ch=out_ch,
681
+ num_res_blocks=num_res_blocks,
682
+ attn_resolutions=attn_resolutions,
683
+ in_channels=in_channels,
684
+ resolution=resolution,
685
+ z_channels=z_channels,
686
+ ch_mult=ch_mult,
687
+ dropout=dropout,
688
+ resamp_with_conv=resamp_with_conv,
689
+ double_z=double_z,
690
+ give_pre_end=give_pre_end,
691
+ )
692
+
693
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape)
694
+
695
+ # pass init params to Decoder
696
+ self.decoder = Decoder(
697
+ ch=ch,
698
+ out_ch=out_ch,
699
+ num_res_blocks=num_res_blocks,
700
+ attn_resolutions=attn_resolutions,
701
+ in_channels=in_channels,
702
+ resolution=resolution,
703
+ z_channels=z_channels,
704
+ ch_mult=ch_mult,
705
+ dropout=dropout,
706
+ resamp_with_conv=resamp_with_conv,
707
+ give_pre_end=give_pre_end,
708
+ )
+
+ # project between the latent z_channels and the codebook embedding dimension,
+ # used by encode() and decode() below
+ self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
709
+
710
+ def encode(self, x):
711
+ h = self.encoder(x)
712
+ h = self.quant_conv(h)
713
+ return h
714
+
715
+ def decode(self, h, force_not_quantize=False):
716
+ # also go through quantization layer
717
+ if not force_not_quantize:
718
+ quant, emb_loss, info = self.quantize(h)
719
+ else:
720
+ quant = h
721
+ quant = self.post_quant_conv(quant)
722
+ dec = self.decoder(quant)
723
+ return dec
724
+
725
+
726
+ class DiagonalGaussianDistribution(object):
727
+ def __init__(self, parameters, deterministic=False):
728
+ self.parameters = parameters
729
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
730
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
731
+ self.deterministic = deterministic
732
+ self.std = torch.exp(0.5 * self.logvar)
733
+ self.var = torch.exp(self.logvar)
734
+ if self.deterministic:
735
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
736
+
737
+ def sample(self):
738
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
739
+ return x
740
+
741
+ def kl(self, other=None):
742
+ if self.deterministic:
743
+ return torch.Tensor([0.])
744
+ else:
745
+ if other is None:
746
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
747
+ + self.var - 1.0 - self.logvar,
748
+ dim=[1, 2, 3])
749
+ else:
750
+ return 0.5 * torch.sum(
751
+ torch.pow(self.mean - other.mean, 2) / other.var
752
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
753
+ dim=[1, 2, 3])
754
+
755
+ def nll(self, sample, dims=[1,2,3]):
756
+ if self.deterministic:
757
+ return torch.Tensor([0.])
758
+ logtwopi = np.log(2.0 * np.pi)
759
+ return 0.5 * torch.sum(
760
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
761
+ dim=dims)
762
+
763
+ def mode(self):
764
+ return self.mean
765
+
766
+ class AutoencoderKL(ModelMixin, ConfigMixin):
767
+ def __init__(
768
+ self,
769
+ ch,
770
+ out_ch,
771
+ num_res_blocks,
772
+ attn_resolutions,
773
+ in_channels,
774
+ resolution,
775
+ z_channels,
776
+ embed_dim,
777
+ remap=None,
778
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
779
+ ch_mult=(1, 2, 4, 8),
780
+ dropout=0.0,
781
+ double_z=True,
782
+ resamp_with_conv=True,
783
+ give_pre_end=False,
784
+ ):
785
+ super().__init__()
786
+
787
+ # register all __init__ params with self.register
788
+ self.register(
789
+ ch=ch,
790
+ out_ch=out_ch,
791
+ num_res_blocks=num_res_blocks,
792
+ attn_resolutions=attn_resolutions,
793
+ in_channels=in_channels,
794
+ resolution=resolution,
795
+ z_channels=z_channels,
796
+ embed_dim=embed_dim,
797
+ remap=remap,
798
+ sane_index_shape=sane_index_shape,
799
+ ch_mult=ch_mult,
800
+ dropout=dropout,
801
+ double_z=double_z,
802
+ resamp_with_conv=resamp_with_conv,
803
+ give_pre_end=give_pre_end,
804
+ )
805
+
806
+ # pass init params to Encoder
807
+ self.encoder = Encoder(
808
+ ch=ch,
809
+ out_ch=out_ch,
810
+ num_res_blocks=num_res_blocks,
811
+ attn_resolutions=attn_resolutions,
812
+ in_channels=in_channels,
813
+ resolution=resolution,
814
+ z_channels=z_channels,
815
+ ch_mult=ch_mult,
816
+ dropout=dropout,
817
+ resamp_with_conv=resamp_with_conv,
818
+ double_z=double_z,
819
+ give_pre_end=give_pre_end,
820
+ )
821
+
822
+ # pass init params to Decoder
823
+ self.decoder = Decoder(
824
+ ch=ch,
825
+ out_ch=out_ch,
826
+ num_res_blocks=num_res_blocks,
827
+ attn_resolutions=attn_resolutions,
828
+ in_channels=in_channels,
829
+ resolution=resolution,
830
+ z_channels=z_channels,
831
+ ch_mult=ch_mult,
832
+ dropout=dropout,
833
+ resamp_with_conv=resamp_with_conv,
834
+ give_pre_end=give_pre_end,
835
+ )
836
+
837
+ self.quant_conv = torch.nn.Conv2d(2*z_channels, 2*embed_dim, 1)
838
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
839
+
840
+ def encode(self, x):
841
+ h = self.encoder(x)
842
+ moments = self.quant_conv(h)
843
+ posterior = DiagonalGaussianDistribution(moments)
844
+ return posterior
845
+
846
+ def decode(self, z):
847
+ z = self.post_quant_conv(z)
848
+ dec = self.decoder(z)
849
+ return dec
850
+
851
+ def forward(self, input, sample_posterior=True):
852
+ posterior = self.encode(input)
853
+ if sample_posterior:
854
+ z = posterior.sample()
855
+ else:
856
+ z = posterior.mode()
857
+ dec = self.decode(z)
858
+ return dec, posterior
859
+
860
+
861
+ class LatentDiffusion(DiffusionPipeline):
862
+ def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):
863
+ super().__init__()
864
+ self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)
865
+
866
+ def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, guidance_scale=1.0, num_inference_steps=50):
867
+ # eta corresponds to η in paper and should be between [0, 1]
868
+
869
+ if torch_device is None:
870
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
871
+
872
+ self.unet.to(torch_device)
873
+ self.vqvae.to(torch_device)
874
+ self.bert.to(torch_device)
875
+
876
+ if guidance_scale != 1.0:
877
+ uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
878
+ uncond_embeddings = self.bert(uncond_input.input_ids)[0]
879
+
880
+ # get text embedding
881
+ text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
882
+ text_embedding = self.bert(text_input.input_ids)[0]
883
+
884
+ num_trained_timesteps = self.noise_scheduler.num_timesteps
885
+ inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
886
+
887
+ image = self.noise_scheduler.sample_noise(
888
+ (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
889
+ device=torch_device,
890
+ generator=generator,
891
+ )
892
+
893
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
894
+ # Ideally, read the DDIM paper in detail for a full understanding of the notation below
895
+
896
+ # Notation (<variable name> -> <name in paper>
897
+ # - pred_noise_t -> e_theta(x_t, t)
898
+ # - pred_original_image -> f_theta(x_t, t) or x_0
899
+ # - std_dev_t -> sigma_t
900
+ # - eta -> η
901
+ # - pred_image_direction -> "direction pointing to x_t"
902
+ # - pred_prev_image -> "x_t-1"
903
+ for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
904
+ # 1. predict noise residual
905
+ if guidance_scale == 1.0:
906
+ timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
907
+ context = text_embedding
908
+ image_in = image
909
+ else:
910
+ image_in = torch.cat([image] * 2)
911
+ timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
912
+ context = torch.cat([uncond_embeddings, text_embedding])
913
+
914
+ with torch.no_grad():
915
+ pred_noise_t = self.unet(image_in, timesteps, context=context)
916
+
917
+ if guidance_scale != 1.0:
918
+ pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
919
+ pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
920
+
921
+ # 2. get actual t and t-1
922
+ train_step = inference_step_times[t]
923
+ prev_train_step = inference_step_times[t - 1] if t > 0 else -1
924
+
925
+ # 3. compute alphas, betas
926
+ alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
927
+ alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step)
928
+ beta_prod_t = 1 - alpha_prod_t
929
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
930
+
931
+ # 4. Compute predicted previous image from predicted noise
932
+ # First: compute predicted original image from predicted noise also called
933
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
934
+ pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()
935
+
936
+ # Second: Clip "predicted x_0"
937
+ # pred_original_image = torch.clamp(pred_original_image, -1, 1)
938
+
939
+ # Third: Compute variance: "sigma_t(η)" -> see formula (16)
940
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
941
+ std_dev_t = (beta_prod_t_prev / beta_prod_t).sqrt() * (1 - alpha_prod_t / alpha_prod_t_prev).sqrt()
942
+ std_dev_t = eta * std_dev_t
943
+
944
+ # Fourth: Compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
945
+ pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * pred_noise_t
946
+
947
+ # Fifth: Compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
948
+ pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction
949
+
950
+ # 5. Sample x_t-1 image optionally if η > 0.0 by adding noise to pred_prev_image
951
+ # Note: eta = 1.0 essentially corresponds to DDPM
952
+ if eta > 0.0:
953
+ noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
954
+ prev_image = pred_prev_image + std_dev_t * noise
955
+ else:
956
+ prev_image = pred_prev_image
957
+
958
+ # 6. Set current image to prev_image: x_t -> x_t-1
959
+ image = prev_image
960
+
961
+ image = 1 / 0.18215 * image
962
+ image = self.vqvae.decode(image)
963
+ image = torch.clamp((image+1.0)/2.0, min=0.0, max=1.0)
964
+
965
+ return image
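A minimal usage sketch of the LatentDiffusion pipeline defined above, assuming a local clone of this repository with the LFS weights pulled and the early diffusers version targeted by this commit; the prompt, seed, guidance values and output path are illustrative, not part of this commit.

import torch
from PIL import Image
from diffusers import DiffusionPipeline

ldm = DiffusionPipeline.from_pretrained(".")  # loads the components listed in model_index.json

generator = torch.manual_seed(42)
images = ldm(
    ["A painting of a squirrel eating a burger"],
    generator=generator,
    eta=0.3,
    guidance_scale=6.0,
    num_inference_steps=50,
)

# __call__ returns images clamped to [0, 1] with shape (batch, 3, H, W)
array = (images[0].detach().permute(1, 2, 0).cpu().numpy() * 255).round().astype("uint8")
Image.fromarray(array).save("squirrel.png")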
noise_scheduler/scheduler_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "_class_name": "GaussianDDPMScheduler",
+   "_diffusers_version": "0.0.1",
+   "beta_end": 0.012,
+   "beta_schedule": "linear",
+   "beta_start": 0.00085,
+   "timesteps": 1000,
+   "variance_type": "fixed_small"
+ }
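The sampling loop in modeling_latent_diffusion.py only queries this scheduler for cumulative alpha products (get_alpha_prod) and the initial noise. A rough sketch of those quantities, assuming "beta_schedule": "linear" means a plain torch.linspace between beta_start and beta_end (some latent-diffusion implementations interpolate in sqrt-space instead) and that step -1 corresponds to a cumulative alpha of 1:

import torch

beta_start, beta_end, timesteps = 0.00085, 0.012, 1000
betas = torch.linspace(beta_start, beta_end, timesteps)  # assumed plain linear schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def get_alpha_prod(step):
    # step -1 is used as the "previous step" of the final denoising iteration
    return alphas_cumprod[step] if step >= 0 else torch.tensor(1.0)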
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "cls_token": "[CLS]",
+   "mask_token": "[MASK]",
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "unk_token": "[UNK]"
+ }
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,16 @@
+ {
+   "cls_token": "[CLS]",
+   "do_basic_tokenize": true,
+   "do_lower_case": true,
+   "mask_token": "[MASK]",
+   "model_max_length": 512,
+   "name_or_path": "bert-base-uncased",
+   "never_split": null,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "special_tokens_map_file": null,
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "BertTokenizer",
+   "unk_token": "[UNK]"
+ }
tokenizer/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
unet/config.json ADDED
@@ -0,0 +1,38 @@
+ {
+   "_class_name": "UNetLDMModel",
+   "_diffusers_version": "0.0.1",
+   "attention_resolutions": [
+     4,
+     2,
+     1
+   ],
+   "channel_mult": [
+     1,
+     2,
+     4,
+     4
+   ],
+   "context_dim": 1280,
+   "conv_resample": true,
+   "dims": 2,
+   "dropout": 0,
+   "image_size": 32,
+   "in_channels": 4,
+   "legacy": false,
+   "model_channels": 320,
+   "n_embed": null,
+   "name_or_path": "../fusing-models/unet/",
+   "num_classes": null,
+   "num_head_channels": -1,
+   "num_heads": 8,
+   "num_heads_upsample": -1,
+   "num_res_blocks": 2,
+   "out_channels": 4,
+   "resblock_updown": false,
+   "transformer_depth": 1,
+   "use_checkpoint": false,
+   "use_fp16": false,
+   "use_new_attention_order": false,
+   "use_scale_shift_norm": false,
+   "use_spatial_transformer": true
+ }
unet/diffusion_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:95549fac1575e6dc07e532a2e5fbf2e2dc3844bdd25224aa5e9d07f74ae2ede6
+ size 3489482533
vqvae/config.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "_class_name": "AutoencoderKL",
+   "_diffusers_version": "0.0.1",
+   "attn_resolutions": [],
+   "ch": 128,
+   "ch_mult": [
+     1,
+     2,
+     4,
+     4
+   ],
+   "double_z": true,
+   "dropout": 0.0,
+   "embed_dim": 4,
+   "give_pre_end": false,
+   "in_channels": 3,
+   "name_or_path": "../fusing-models/vqvae/",
+   "num_res_blocks": 2,
+   "out_ch": 3,
+   "remap": null,
+   "resamp_with_conv": true,
+   "resolution": 256,
+   "sane_index_shape": false,
+   "z_channels": 4
+ }
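A small round-trip sketch of the AutoencoderKL defined in modeling_latent_diffusion.py with this configuration, assuming the diffusers version pinned by this commit; the dummy image and the 0.18215 latent scaling mirror the pipeline above, and the weights here are randomly initialised, so the outputs are only shape checks.

import json
import torch
from modeling_latent_diffusion import AutoencoderKL

with open("vqvae/config.json") as f:
    cfg = {k: v for k, v in json.load(f).items() if not k.startswith("_") and k != "name_or_path"}

vae = AutoencoderKL(**cfg)

image = torch.randn(1, 3, 256, 256)      # stand-in for a real RGB image in [-1, 1]
posterior = vae.encode(image)            # DiagonalGaussianDistribution over latents
z = posterior.sample() * 0.18215         # scaled latent, shape (1, 4, 32, 32)
recon = vae.decode(z / 0.18215)          # back to image space, shape (1, 3, 256, 256)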
vqvae/diffusion_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:40e9811ed7c6c4775c110fd8347ee11283d99b603852f96863d172a23787a3b5
+ size 334704849