Jonathan Malott committed on
Commit
44df93e
·
0 Parent(s):
.gitignore ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .ipynb_checkpoints/
2
+
3
+
4
+ __pycache__/
5
+
6
+
7
+ _archives/
8
+
9
+
10
+ _exampleImages/
11
+
12
+
13
+ _trash/
14
+
15
+
16
+ minDALL-E/
17
+
18
+
19
+
20
+ temp/
Procfile ADDED
@@ -0,0 +1 @@
 
 
1
+ web: sh setup.sh && streamlit run streamlit_app.py
dalle/__init__.py ADDED
File without changes
dalle/models/__init__.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # minDALL-E
3
+ # Copyright (c) 2021 Kakao Brain Corp. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+
7
+ import os
8
+ import torch
9
+ import torch.nn as nn
10
+ import pytorch_lightning as pl
11
+ from typing import Optional, Tuple
12
+ from omegaconf import OmegaConf
13
+ from torch.cuda.amp import autocast
14
+ from torch.optim.lr_scheduler import CosineAnnealingLR
15
+ from torch.nn import functional as F
16
+ from .stage1.vqgan import VQGAN
17
+ from .stage2.transformer import Transformer1d, iGPT
18
+ from .. import utils
19
+ from ..utils.config import get_base_config
20
+ from ..utils.sampling import sampling, sampling_igpt
21
+ from .tokenizer import build_tokenizer
22
+
23
+ _MODELS = {
24
+ 'minDALL-E/1.3B': 'https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz'
25
+ }
26
+
27
+
28
class Dalle(nn.Module):
    """Text-to-image model: a VQGAN (stage 1) decoder driven by a 1-D transformer (stage 2).

    Stage 2 samples a grid of image codes conditioned on tokenized text;
    stage 1 decodes those codes into pixels.
    """

    def __init__(self,
                 config: OmegaConf) -> None:
        """Build both stages from ``config.stage1`` / ``config.stage2``.

        The tokenizer is left unset here; ``from_pretrained`` attaches it.
        """
        super().__init__()
        cfg_stage1 = config.stage1
        cfg_stage2 = config.stage2

        self.tokenizer = None
        self.stage1 = VQGAN(n_embed=cfg_stage1.n_embed,
                            embed_dim=cfg_stage1.embed_dim,
                            hparams=cfg_stage1.hparams)
        self.stage2 = Transformer1d(vocab_size_txt=cfg_stage2.vocab_size_txt,
                                    vocab_size_img=cfg_stage2.vocab_size_img,
                                    hparams=cfg_stage2.hparams)
        self.config_stage1 = cfg_stage1
        self.config_stage2 = cfg_stage2
        self.config_dataset = config.dataset

    @classmethod
    def from_pretrained(cls,
                        path: str) -> nn.Module:
        """Create a model and restore its tokenizer and both checkpoints.

        NOTE: the ``path`` argument is currently ignored. The upstream lookup
        through ``_MODELS`` / ``utils.realpath_url_or_path`` was disabled, so the
        config and tokenizer are always read from the relative
        ``.cache/minDALL-E/1.3B`` directory and the stage-2 weights from a fixed URL.

        Returns:
            The constructed ``Dalle`` instance.
        """
        path = ''  # overrides the caller's value (see NOTE above)

        config_path = os.path.join(path, '.cache/minDALL-E/1.3B/config.yaml')
        merged_config = OmegaConf.merge(get_base_config(), OmegaConf.load(config_path))

        model = cls(merged_config)
        model.tokenizer = build_tokenizer('.cache/minDALL-E/1.3B/tokenizer',
                                          context_length=model.config_dataset.context_length,
                                          lowercase=True,
                                          dropout=None)
        model.stage1.from_ckpt('.cache/minDALL-E/1.3B/stage1_last.ckpt')
        model.stage2.from_ckpt('https://utexas.box.com/shared/static/54jc9fw0bious5nx6wvayeqaskcrdgv4.ckpt')
        return model

    @torch.no_grad()
    def sampling(self,
                 prompt: str,
                 top_k: int = 256,
                 top_p: Optional[float] = None,
                 softmax_temperature: float = 1.0,
                 num_candidates: int = 96,
                 device: str = 'cuda:0',
                 use_fp16: bool = True) -> torch.FloatTensor:
        """Generate ``num_candidates`` images for a text prompt.

        Args:
            prompt: Text to condition on; encoded with ``self.tokenizer``.
            top_k, top_p, softmax_temperature, use_fp16: Forwarded unchanged to
                the module-level ``sampling`` helper.
            num_candidates: Number of images sampled for the same prompt.
            device: Device the token batch is moved to before sampling.

        Returns:
            Decoded images, rescaled with ``* 0.5 + 0.5`` and clamped to [0, 1].
        """
        for stage in (self.stage1, self.stage2):
            stage.eval()

        token_ids = torch.LongTensor(self.tokenizer.encode(prompt).ids)
        # One identical row of text tokens per requested candidate.
        token_batch = torch.repeat_interleave(token_ids.unsqueeze(0), num_candidates, dim=0)
        token_batch = token_batch.to(device)

        # Inside this method the bare name `sampling` is the module-level
        # helper imported from utils.sampling, not this method.
        codes = sampling(self.stage2,
                         token_batch,
                         top_k=top_k,
                         top_p=top_p,
                         softmax_temperature=softmax_temperature,
                         use_fp16=use_fp16)
        codes = codes.view(num_candidates, 16, 16)  # [B, 16, 16]
        decoded = self.stage1.decode_code(codes) * 0.5 + 0.5
        return torch.clamp(decoded, 0, 1)
95
+
96
+
97
class ImageGPT(pl.LightningModule):
    """Lightning module training an iGPT prior (stage 2) on codes from a frozen VQGAN (stage 1)."""

    def __init__(self,
                 config: OmegaConf) -> None:
        """Build both stages from ``config`` and freeze every stage-1 parameter."""
        super().__init__()
        cfg_stage1 = config.stage1
        cfg_stage2 = config.stage2

        self.stage1 = VQGAN(n_embed=cfg_stage1.n_embed,
                            embed_dim=cfg_stage1.embed_dim,
                            hparams=cfg_stage1.hparams)
        self.stage2 = iGPT(vocab_size_img=cfg_stage2.vocab_size_img,
                           use_cls_cond=cfg_stage2.use_cls_cond,
                           hparams=cfg_stage2.hparams)
        self.config = config
        self.use_cls_cond = cfg_stage2.use_cls_cond

        # Stage 1 only produces codes / decodes them; it is never trained here.
        self.stage1.eval()
        for param in self.stage1.parameters():
            param.requires_grad = False

    @classmethod
    def from_pretrained(cls,
                        path_upstream: str,
                        path_downstream: str) -> Tuple[nn.Module, OmegaConf]:
        """Build a model from a downstream config and restore upstream checkpoints.

        Args:
            path_upstream: Directory expected to hold ``stage1_last.ckpt`` and
                ``stage2_last.ckpt``.
            path_downstream: Config file merged over the base config.

        Returns:
            ``(model, merged_config)``.
        """
        merged_config = OmegaConf.merge(get_base_config(use_default=False),
                                        OmegaConf.load(path_downstream))

        model = cls(merged_config)
        model.stage1.from_ckpt(os.path.join(path_upstream, 'stage1_last.ckpt'), strict=True)
        # Stage 2 is restored non-strictly.
        model.stage2.from_ckpt(os.path.join(path_upstream, 'stage2_last.ckpt'), strict=False)
        return model, merged_config

    def sample(self,
               cls_idx: Optional[int] = None,
               top_k: int = 256,
               top_p: Optional[float] = None,
               softmax_temperature: float = 1.0,
               num_candidates: int = 16,
               device: str = 'cuda:0',
               use_fp16: bool = True,
               is_tqdm: bool = True) -> torch.FloatTensor:
        """Sample ``num_candidates`` images, optionally conditioned on a class index.

        When ``cls_idx`` is None, ``self.stage2.sos`` is repeated directly;
        otherwise it is called with the repeated class index (presumably an
        embedding lookup - confirm against iGPT). ``device`` is only used in
        the class-conditional branch.

        Returns:
            Decoded images, rescaled with ``* 0.5 + 0.5`` and clamped to [0, 1].
        """
        for stage in (self.stage1, self.stage2):
            stage.eval()

        if cls_idx is not None:
            class_ids = torch.LongTensor([cls_idx]).to(device=device)
            class_ids = class_ids.repeat(num_candidates)
            sos = self.stage2.sos(class_ids).unsqueeze(1)
        else:
            sos = self.stage2.sos.repeat(num_candidates, 1, 1)

        codes = sampling_igpt(self.stage2,
                              sos=sos,
                              top_k=top_k,
                              top_p=top_p,
                              softmax_temperature=softmax_temperature,
                              use_fp16=use_fp16,
                              is_tqdm=is_tqdm)
        codes = codes.view(num_candidates, 16, 16)  # [B, 16, 16]
        decoded = self.stage1.decode_code(codes) * 0.5 + 0.5
        return torch.clamp(decoded, 0, 1)

    def forward(self,
                images: torch.FloatTensor,
                labels: Optional[torch.LongTensor] = None) -> torch.FloatTensor:
        """Encode ``images`` into codes with stage 1 and predict logits with stage 2.

        Returns:
            ``(logits, codes)``; ``codes`` are detached and serve as targets.
        """
        B, C, H, W = images.shape  # also enforces a 4-D input
        # Code extraction runs without gradients and with autocast disabled.
        with torch.no_grad(), autocast(enabled=False):
            codes = self.stage1.get_codes(images).detach()
        return self.stage2(codes, labels), codes

    def training_step(self, batch, batch_idx):
        """Compute and log (per step and per epoch) the training cross-entropy."""
        images, labels = batch
        cond = labels if self.use_cls_cond else None
        logits, codes = self(images, labels=cond)
        vocab = logits.shape[-1]
        loss = F.cross_entropy(logits.view(-1, vocab), codes.view(-1))
        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log (per epoch only) the validation cross-entropy."""
        images, labels = batch
        cond = labels if self.use_cls_cond else None
        logits, codes = self(images, labels=cond)
        vocab = logits.shape[-1]
        loss = F.cross_entropy(logits.view(-1, vocab), codes.view(-1))
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer with a cosine-annealing schedule.

        Only ``opt_type == 'adamW'`` and ``sched_type == 'cosine'`` are supported.
        """
        opt_cfg = self.config.optimizer
        assert opt_cfg.opt_type == 'adamW'
        assert opt_cfg.sched_type == 'cosine'

        opt = torch.optim.AdamW(self.parameters(),
                                lr=opt_cfg.base_lr,
                                betas=opt_cfg.betas,
                                weight_decay=opt_cfg.weight_decay)
        cosine = CosineAnnealingLR(opt,
                                   T_max=opt_cfg.max_steps,
                                   eta_min=opt_cfg.min_lr)
        return [opt], [{'scheduler': cosine, 'name': 'cosine'}]

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure,
                       on_tpu=False, using_native_amp=False, using_lbfgs=False):
        """Step the optimizer, then the LR scheduler, and log the current learning rate."""
        optimizer.step(closure=optimizer_closure)
        scheduler = self.lr_schedulers()
        scheduler.step()
        self.log("lr", scheduler.get_last_lr()[0], on_step=True, on_epoch=False, prog_bar=True, logger=True)

    def on_epoch_start(self):
        """Put the frozen stage 1 back into eval mode at the start of every epoch."""
        self.stage1.eval()
dalle/models/stage1/layers.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Modified from VQGAN (https://github.com/CompVis/taming-transformers)
3
+ # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
4
+ # ------------------------------------------------------------------------------------
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from typing import Tuple, Optional
9
+
10
+
11
def nonlinearity(x):
    """Apply the swish activation, ``x * sigmoid(x)``, element-wise."""
    gate = torch.sigmoid(x)
    return x * gate
14
+
15
+
16
def Normalize(in_channels):
    """Return an affine GroupNorm over ``in_channels`` channels with 32 groups and eps=1e-6."""
    settings = {
        "num_groups": 32,
        "num_channels": in_channels,
        "eps": 1e-6,
        "affine": True,
    }
    return torch.nn.GroupNorm(**settings)
21
+
22
+
23
class Upsample(nn.Module):
    """Double the spatial size by nearest-neighbour interpolation, optionally followed by a 3x3 conv."""

    def __init__(self, in_channels, with_conv):
        """Create the optional channel-preserving 3x3 convolution (only when ``with_conv``)."""
        super().__init__()
        self.with_conv = with_conv
        if not self.with_conv:
            return
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Upsample ``x`` by a factor of 2 and apply the convolution if enabled."""
        upsampled = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.with_conv else upsampled
39
+
40
+
41
class Downsample(nn.Module):
    """Halve the spatial size with either a strided 3x3 conv or 2x2 average pooling."""

    def __init__(self, in_channels, with_conv):
        """Create the stride-2, unpadded 3x3 convolution (only when ``with_conv``)."""
        super().__init__()
        self.with_conv = with_conv
        if not self.with_conv:
            return
        # The conv itself is unpadded; forward() adds the asymmetric padding by hand.
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        """Downsample ``x`` by a factor of 2."""
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # Zero-pad one pixel on the right and bottom only.
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
61
+
62
+
63
class ResnetBlock(nn.Module):
    """Pre-activation residual block: two (GroupNorm -> swish -> 3x3 conv) stages plus a shortcut.

    When the channel count changes, the shortcut is a 3x3 conv
    (``conv_shortcut=True``) or a 1x1 conv (default); otherwise it is the identity.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        """Build the block.

        Args:
            in_channels: Channels of the input.
            out_channels: Channels of the output; defaults to ``in_channels``.
            conv_shortcut: Use a 3x3 instead of a 1x1 shortcut conv when channels differ.
            dropout: Dropout probability applied before the second conv.
            temb_channels: Must be 0; time embeddings are not supported, so
                callers have to override the default explicitly.
        """
        assert temb_channels == 0
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = Normalize(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if self.in_channels == self.out_channels:
            return  # identity shortcut, nothing more to build
        if self.use_conv_shortcut:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels,
                                           kernel_size=3, stride=1, padding=1)
        else:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels,
                                          kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb=None):
        """Return ``shortcut(x) + residual(x)``; ``temb`` must be None."""
        assert temb is None

        residual = self.conv1(nonlinearity(self.norm1(x)))
        residual = self.conv2(self.dropout(nonlinearity(self.norm2(residual))))

        if self.in_channels != self.out_channels:
            shortcut = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
            x = shortcut(x)
        return x + residual
119
+
120
+
121
class AttnBlock(nn.Module):
    """Single-head self-attention over all spatial positions of a feature map, with a residual add."""

    def __init__(self, in_channels):
        """Create the norm layer and the 1x1 query/key/value/output projections."""
        super().__init__()
        self.in_channels = in_channels

        def pointwise_conv():
            # Channel-preserving 1x1 projection.
            return nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

        self.norm = Normalize(in_channels)
        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

    def forward(self, x):
        """Return ``x`` plus the projected attention output (same shape as ``x``)."""
        normed = self.norm(x)
        q, k, v = self.q(normed), self.k(normed), self.v(normed)

        b, c, h, w = q.shape
        n_pos = h * w

        # Attention weights: scores[b, i, j] = sum_c q[b, i, c] * k[b, c, j], scaled by c^-0.5.
        q = q.reshape(b, c, n_pos).permute(0, 2, 1)  # b, hw, c
        k = k.reshape(b, c, n_pos)                   # b, c, hw
        scores = torch.bmm(q, k) * (int(c) ** (-0.5))
        weights = torch.nn.functional.softmax(scores, dim=2)

        # Weighted sum of values: out[b, c, j] = sum_i v[b, c, i] * weights[b, j, i].
        v = v.reshape(b, c, n_pos)
        attended = torch.bmm(v, weights.permute(0, 2, 1))  # b, c, hw
        attended = attended.reshape(b, c, h, w)

        return x + self.proj_out(attended)
172
+
173
+
174
class Encoder(nn.Module):
    """Convolutional VQGAN encoder: stem conv, a pyramid of residual levels, a middle block and an output conv.

    Each level holds ``num_res_blocks`` residual blocks (each followed by
    attention when the current resolution is listed in ``attn_resolutions``),
    and every level except the last halves the resolution.
    """

    def __init__(self,
                 *,  # forced to use named arguments
                 ch: int,
                 out_ch: int,
                 ch_mult: Tuple[int] = (1, 2, 4, 8),
                 num_res_blocks: int,
                 attn_resolutions: Tuple[int],
                 pdrop: float = 0.0,
                 resamp_with_conv: bool = True,
                 in_channels: int,
                 resolution: int,
                 z_channels: int,
                 double_z: Optional[bool] = None) -> None:
        """Build the encoder.

        Args:
            ch: Base channel width; level ``i`` uses ``ch * ch_mult[i]`` channels.
            out_ch: Not used by the encoder (the same hparams are also passed to ``Decoder``).
            ch_mult: Per-level channel multipliers; its length sets the number of levels.
            num_res_blocks: Residual blocks per level.
            attn_resolutions: Resolutions at which attention follows each residual block.
            pdrop: Dropout probability inside the residual blocks.
            resamp_with_conv: Downsample with a strided conv instead of average pooling.
            in_channels: Channels of the input image.
            resolution: Required height and width of the input.
            z_channels: Output channels (doubled when ``double_z`` is truthy).
            double_z: See ``z_channels``.
        """
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # stem
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        # downsampling pyramid
        res = resolution
        final_level = self.num_resolutions - 1
        input_mults = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for level in range(self.num_resolutions):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            width_in = ch * input_mults[level]
            width_out = ch * ch_mult[level]
            for _ in range(self.num_res_blocks):
                res_blocks.append(ResnetBlock(in_channels=width_in,
                                              out_channels=width_out,
                                              temb_channels=self.temb_ch,
                                              dropout=pdrop))
                width_in = width_out
                if res in attn_resolutions:
                    attn_blocks.append(AttnBlock(width_in))
            stage = nn.Module()
            stage.block = res_blocks
            stage.attn = attn_blocks
            if level != final_level:
                stage.downsample = Downsample(width_in, resamp_with_conv)
                res = res // 2
            self.down.append(stage)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=width_in,
                                       out_channels=width_in,
                                       temb_channels=self.temb_ch,
                                       dropout=pdrop)
        self.mid.attn_1 = AttnBlock(width_in)
        self.mid.block_2 = ResnetBlock(in_channels=width_in,
                                       out_channels=width_in,
                                       temb_channels=self.temb_ch,
                                       dropout=pdrop)

        # end
        self.norm_out = Normalize(width_in)
        self.conv_out = nn.Conv2d(width_in,
                                  2 * z_channels if double_z else z_channels,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)

    def forward(self, x):
        """Encode a square ``resolution`` x ``resolution`` image batch into latent features."""
        assert x.shape[2] == x.shape[3] == self.resolution, \
            "{}, {}".format(x.shape, self.resolution)

        final_level = self.num_resolutions - 1

        # downsampling
        h = self.conv_in(x)
        for level in range(self.num_resolutions):
            stage = self.down[level]
            has_attn = len(stage.attn) > 0
            for idx in range(self.num_res_blocks):
                h = stage.block[idx](h)
                if has_attn:
                    h = stage.attn[idx](h)
            if level != final_level:
                h = stage.downsample(h)

        # middle
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h)))

        # end
        return self.conv_out(nonlinearity(self.norm_out(h)))
271
+
272
+
273
class Decoder(nn.Module):
    """Convolutional VQGAN decoder: input conv, a middle block, an upsampling pyramid and an output conv.

    Each level holds ``num_res_blocks + 1`` residual blocks (each followed by
    attention when the current resolution is listed in ``attn_resolutions``),
    and every level except level 0 doubles the resolution.
    """

    def __init__(self,
                 *,  # forced to use named arguments
                 ch: int,
                 out_ch: int,
                 ch_mult: Tuple[int] = (1, 2, 4, 8),
                 num_res_blocks: int,
                 attn_resolutions: Tuple[int],
                 pdrop: float = 0.0,
                 resamp_with_conv: bool = True,
                 in_channels: int,
                 resolution: int,
                 z_channels: int,
                 double_z: bool) -> None:
        """Build the decoder.

        Args:
            ch: Base channel width; level ``i`` uses ``ch * ch_mult[i]`` channels.
            out_ch: Channels of the decoded output.
            ch_mult: Per-level channel multipliers; its length sets the number of levels.
            num_res_blocks: One fewer than the number of residual blocks per level.
            attn_resolutions: Resolutions at which attention follows each residual block.
            pdrop: Dropout probability inside the residual blocks.
            resamp_with_conv: Add a 3x3 conv after each nearest-neighbour upsampling.
            in_channels: Stored on the module but not otherwise used here.
            resolution: Output resolution; the latent resolution is derived from it.
            z_channels: Channels of the latent input.
            double_z: Accepted but not used by the decoder.
        """
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # Channel width and resolution at the coarsest level.
        top_level = self.num_resolutions - 1
        width = ch * ch_mult[top_level]
        res = resolution // 2 ** top_level
        self.z_shape = (1, z_channels, res, res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, width, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=width,
                                       out_channels=width,
                                       temb_channels=self.temb_ch,
                                       dropout=pdrop)
        self.mid.attn_1 = AttnBlock(width)
        self.mid.block_2 = ResnetBlock(in_channels=width,
                                       out_channels=width,
                                       temb_channels=self.temb_ch,
                                       dropout=pdrop)

        # upsampling pyramid, built from the coarsest level down to level 0
        self.up = nn.ModuleList()
        for level in reversed(range(self.num_resolutions)):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()
            width_out = ch * ch_mult[level]
            for _ in range(self.num_res_blocks + 1):
                res_blocks.append(ResnetBlock(in_channels=width,
                                              out_channels=width_out,
                                              temb_channels=self.temb_ch,
                                              dropout=pdrop))
                width = width_out
                if res in attn_resolutions:
                    attn_blocks.append(AttnBlock(width))
            stage = nn.Module()
            stage.block = res_blocks
            stage.attn = attn_blocks
            if level != 0:
                stage.upsample = Upsample(width, resamp_with_conv)
                res = res * 2
            self.up.insert(0, stage)  # prepend so that self.up[i] is level i

        # end
        self.norm_out = Normalize(width)
        self.conv_out = nn.Conv2d(width, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        """Decode latent ``z`` (whose non-batch shape must equal ``self.z_shape[1:]``).

        Also records the input shape in ``self.last_z_shape``.
        """
        assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h)))

        # upsampling, from the coarsest level to level 0
        for level in reversed(range(self.num_resolutions)):
            stage = self.up[level]
            has_attn = len(stage.attn) > 0
            for idx in range(self.num_res_blocks + 1):
                h = stage.block[idx](h)
                if has_attn:
                    h = stage.attn[idx](h)
            if level != 0:
                h = stage.upsample(h)

        return self.conv_out(nonlinearity(self.norm_out(h)))
dalle/models/stage1/vqgan.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Modified from VQGAN (https://github.com/CompVis/taming-transformers)
3
+ # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
4
+ # ------------------------------------------------------------------------------------
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from typing import List, Tuple, Optional
9
+ from einops import rearrange
10
+ from omegaconf import OmegaConf
11
+ from .layers import Encoder, Decoder
12
+
13
+
14
class VectorQuantizer(nn.Module):
    """
    Simplified VectorQuantizer from the original VQGAN repository,
    with the modules that are unnecessary for sampling removed.
    """
    def __init__(self, dim: int, n_embed: int, beta: float) -> None:
        """Create a codebook of ``n_embed`` vectors of size ``dim``.

        ``beta`` is stored but not used anywhere in this class.
        """
        super().__init__()
        self.n_embed = n_embed
        self.dim = dim
        self.beta = beta

        self.embedding = nn.Embedding(self.n_embed, self.dim)
        bound = 1.0 / self.n_embed
        self.embedding.weight.data.uniform_(-bound, bound)

    def forward(self,
                z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """Quantize every spatial vector of ``z`` to its nearest codebook entry.

        Args:
            z: Features laid out as [B, C, H, W] with ``C == dim``.

        Returns:
            ``(z_q, indices)``: quantized features in [B, H, W, C] layout and
            the flat index of the chosen codebook entry for each position.
        """
        z = z.permute(0, 2, 3, 1).contiguous()  # [B,C,H,W] -> [B,H,W,C]
        flat = z.view(-1, self.dim)
        codebook = self.embedding.weight

        # Squared distance expanded as |z|^2 + |e|^2 - 2 * z.e
        z_sq = torch.sum(flat ** 2, dim=1, keepdim=True)
        e_sq = torch.sum(codebook ** 2, dim=1)
        cross = torch.einsum('bd,dn->bn', flat, codebook.t())
        d = z_sq + e_sq - 2 * cross

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        return z_q, min_encoding_indices

    def get_codebook_entry(self,
                           indices: torch.LongTensor,
                           shape: Optional[List[int]] = None) -> torch.FloatTensor:
        """Look up codebook vectors for ``indices``.

        Without ``shape`` the embeddings are returned as-is (embedding
        dimension last). With ``shape`` they are viewed to that 4-D shape and
        permuted with ``(0, 3, 1, 2)``, i.e. channels moved to dim 1.
        """
        z_q = self.embedding(indices)
        if shape is None:
            return z_q
        return z_q.view(shape).permute(0, 3, 1, 2).contiguous()
49
+
50
+
51
class VQGAN(nn.Module):
    """VQGAN image tokenizer: encoder, vector quantizer and decoder.

    Only the pieces needed for encoding images to codes and decoding codes
    back to images are present.
    """

    # Hosted copy of the stage-1 checkpoint, used by from_ckpt() when the
    # requested path is neither a URL nor an existing local file.
    _DEFAULT_CKPT_URL = 'https://utexas.box.com/shared/static/rpt9miyj2kikogyekpqnkd6y115xp51i.ckpt'

    def __init__(self, n_embed: int, embed_dim: int, hparams: OmegaConf) -> None:
        """Build the sub-modules.

        Args:
            n_embed: Number of codebook entries.
            embed_dim: Dimensionality of each codebook entry.
            hparams: Keyword arguments shared by ``Encoder`` and ``Decoder``;
                ``z_channels`` and ``attn_resolutions`` are also read here.
        """
        super().__init__()
        self.encoder = Encoder(**hparams)
        self.decoder = Decoder(**hparams)
        self.quantize = VectorQuantizer(dim=embed_dim, n_embed=n_embed, beta=0.25)
        self.quant_conv = torch.nn.Conv2d(hparams.z_channels, embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, hparams.z_channels, 1)
        # The first attention resolution is used as the side length of the code grid.
        self.latent_dim = hparams.attn_resolutions[0]

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Reconstruct ``x`` by encoding, quantizing and decoding it."""
        quant = self.encode(x)
        dec = self.decode(quant)
        return dec

    def encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Return the quantized latent of ``x`` in [B, C, H, W] layout."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant = self.quantize(h)[0]  # quantizer output is [B, H, W, C]
        quant = rearrange(quant, 'b h w c -> b c h w').contiguous()
        return quant

    def decode(self, quant: torch.FloatTensor) -> torch.FloatTensor:
        """Decode a quantized latent ([B, C, H, W]) into an image."""
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code: torch.LongTensor) -> torch.FloatTensor:
        """Decode a grid of codebook indices into an image.

        ``code`` is looked up without reshaping, so a [B, H, W] index grid
        yields [B, H, W, C] embeddings that are permuted to [B, C, H, W].
        """
        quant = self.quantize.get_codebook_entry(code)
        quant = quant.permute(0, 3, 1, 2)
        dec = self.decode(quant)
        return dec

    def get_codes(self, x: torch.FloatTensor) -> torch.LongTensor:
        """Return the codebook indices of ``x`` as a [B, latent_dim ** 2] tensor."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        codes = self.quantize(h)[1].view(x.shape[0], self.latent_dim ** 2)
        return codes

    def from_ckpt(self, path: str, strict: bool = True) -> None:
        """Restore the weights from a checkpoint.

        The checkpoint source is resolved as follows:

        1. ``path`` is an ``http(s)://`` URL: it is downloaded (and cached by torch).
        2. ``path`` is an existing local file: it is loaded directly.
        3. otherwise: the hosted default checkpoint is downloaded, so callers
           that pass a path which is not present locally still work.

        Args:
            path: Checkpoint URL or local file path.
            strict: Forwarded to ``load_state_dict``.
        """
        import os  # local import: this module has no file-level `os` import

        # NOTE: checkpoints are unpickled on load; only use trusted sources.
        if path.startswith(('http://', 'https://')):
            source = path
            ckpt = torch.utils.model_zoo.load_url(source, map_location='cpu')
        elif os.path.isfile(path):
            source = path
            ckpt = torch.load(source, map_location='cpu')
        else:
            source = self._DEFAULT_CKPT_URL
            ckpt = torch.utils.model_zoo.load_url(source, map_location='cpu')

        self.load_state_dict(ckpt['state_dict'], strict=strict)
        # Report where the weights actually came from, not just what was requested.
        print(f'{source} succesfully restored..')
dalle/models/stage2/layers.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # minDALL-E
3
+ # Copyright (c) 2021 Kakao Brain Corp. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+ # Modified from minGPT (https://github.com/karpathy/minGPT)
7
+ # Copyright (c) 2020 Andrej Karpathy. All Rights Reserved.
8
+ # ------------------------------------------------------------------------------------
9
+
10
+ import math
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import functional as F
14
+
15
+
16
class GELU(nn.Module):
    """GELU activation, either exact (``F.gelu``) or the sigmoid approximation ``x * sigmoid(1.702 * x)``."""

    def __init__(self, use_approx=False):
        """Select the sigmoid approximation when ``use_approx`` is truthy."""
        super().__init__()
        self.use_approx = use_approx

    def forward(self, x):
        """Apply the selected GELU variant element-wise."""
        if not self.use_approx:
            return F.gelu(x)
        return x * torch.sigmoid(1.702 * x)
26
+
27
+
28
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention with an optional causal mask and key/value caching."""

    def __init__(self,
                 ctx_len: int,
                 embed_dim: int,
                 n_heads: int,
                 resid_pdrop: float,
                 attn_pdrop: float,
                 attn_bias: bool,
                 use_mask: bool = True):
        """Build the projections, dropouts and (optionally) the causal mask.

        Args:
            ctx_len: Maximum sequence length; side length of the causal mask.
            embed_dim: Model width; must be divisible by ``n_heads``.
            n_heads: Number of attention heads.
            resid_pdrop: Dropout probability on the projected output.
            attn_pdrop: Dropout probability on the attention weights.
            attn_bias: Whether the four linear layers carry a bias.
            use_mask: Register a lower-triangular mask and apply it in the
                non-cached path.
        """
        super().__init__()
        assert embed_dim % n_heads == 0

        # key, query, value projections for all heads
        self.key = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
        self.query = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
        self.value = nn.Linear(embed_dim, embed_dim, bias=attn_bias)

        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)

        # output projection
        self.proj = nn.Linear(embed_dim, embed_dim, attn_bias)

        self.n_heads = n_heads
        self.ctx_len = ctx_len
        self.use_mask = use_mask
        if self.use_mask:
            # Non-persistent: the mask is rebuilt here and never stored in checkpoints.
            causal = torch.tril(torch.ones(ctx_len, ctx_len)).view(1, ctx_len, ctx_len)
            self.register_buffer("mask", causal, persistent=False)

    def forward(self, x, use_cache=False, layer_past=None):
        """Attend over ``x`` of shape (B, T, C).

        Args:
            x: Input sequence.
            use_cache: Also return this call's keys/values, stacked as one tensor.
            layer_past: Previously returned keys/values to prepend along the
                sequence axis.

        Returns:
            The (B, T, C) output, or ``(output, present)`` when ``use_cache`` is set.
            The causal mask is skipped only when both ``use_cache`` and
            ``layer_past`` are given.
        """
        B, T, C = x.shape
        n_heads = self.n_heads
        head_dim = C // n_heads
        x = x.transpose(0, 1).contiguous()  # (B, T, C) -> (T, B, C)

        def split_heads(projection):
            # (T, B, C) -> (B*nh, T, hs): heads are folded into the batch dim
            return projection(x).view(T, B * n_heads, head_dim).transpose(0, 1)

        k = split_heads(self.key)
        q = split_heads(self.query)
        v = split_heads(self.value)

        if use_cache:
            # Only the keys/values computed in this call are handed back.
            present = torch.stack([k, v])

        if layer_past is not None:
            past_key, past_value = layer_past
            k = torch.cat([past_key, k], dim=-2)
            v = torch.cat([past_value, v], dim=-2)

        # (B*nh, T, hs) X (B*nh, hs, K) -> (B*nh, T, K)
        att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))

        incremental = use_cache and layer_past is not None
        if not incremental and self.use_mask:
            mask = self.mask if T == self.ctx_len else self.mask[:, :T, :T]
            att = att.masked_fill(mask == 0, float('-inf'))

        att = self.attn_drop(F.softmax(att, dim=-1))
        y = torch.bmm(att, v)  # (B*nh, T, K) X (B*nh, K, hs) -> (B*nh, T, hs)
        y = y.transpose(0, 1).contiguous().view(T, B, C)  # re-assemble all head outputs side by side

        # output projection, then (T, B, C) -> (B, T, C)
        out = self.resid_drop(self.proj(y)).transpose(0, 1).contiguous()
        return (out, present) if use_cache else out
100
+
101
+
102
class Block(nn.Module):
    """Pre-LayerNorm transformer block: masked self-attention followed by a 4x-wide MLP, each with a residual add."""

    def __init__(self,
                 ctx_len: int,
                 embed_dim: int,
                 n_heads: int,
                 mlp_bias: bool,
                 attn_bias: bool,
                 resid_pdrop: float,
                 attn_pdrop: float,
                 gelu_use_approx: bool):
        """Build the block.

        Args:
            ctx_len: Maximum sequence length, forwarded to the attention layer.
            embed_dim: Model width.
            n_heads: Number of attention heads.
            mlp_bias: Whether the two MLP linear layers carry a bias.
            attn_bias: Whether the attention linear layers carry a bias.
            resid_pdrop: Dropout probability after the attention projection and after the MLP.
            attn_pdrop: Dropout probability on the attention weights.
            gelu_use_approx: Use the sigmoid approximation of GELU in the MLP.
        """
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)

        self.attn = MultiHeadSelfAttention(ctx_len=ctx_len,
                                           embed_dim=embed_dim,
                                           n_heads=n_heads,
                                           attn_pdrop=attn_pdrop,
                                           resid_pdrop=resid_pdrop,
                                           attn_bias=attn_bias,
                                           use_mask=True)
        hidden_dim = 4 * embed_dim
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim, bias=mlp_bias),
            GELU(gelu_use_approx),
            nn.Linear(hidden_dim, embed_dim, bias=mlp_bias),
            nn.Dropout(resid_pdrop),
        )

    def forward(self, x):
        """Apply attention and MLP sub-layers with residual connections."""
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

    def sample(self, x, layer_past=None):
        """Like ``forward`` but with key/value caching for incremental decoding.

        Returns:
            ``(x, present)`` where ``present`` is the key/value tensor returned
            by the attention layer for this call.
        """
        attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past)
        x = x + attn
        x = x + self.mlp(self.ln2(x))
        return x, present
dalle/models/stage2/transformer.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # minDALL-E
3
+ # Copyright (c) 2021 Kakao Brain Corp. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+ # Modified from minGPT (https://github.com/karpathy/minGPT)
7
+ # Copyright (c) 2020 Andrej Karpathy. All Rights Reserved.
8
+ # ------------------------------------------------------------------------------------
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from typing import Optional, Tuple, List
13
+ from torch.cuda.amp import autocast
14
+ from omegaconf import OmegaConf
15
+ from .layers import Block
16
+ import io
17
+
18
class Transformer1d(nn.Module):
    """Autoregressive transformer over a text-token prefix followed by image tokens.

    Text and image tokens have separate token/position embeddings and separate
    output heads; the two embedded sequences are concatenated (text first) and
    processed by a shared stack of ``Block`` layers.
    """

    # Checkpoint that `from_ckpt` restores, regardless of the `path` argument.
    _CKPT_URL = 'https://utexas.box.com/shared/static/54jc9fw0bious5nx6wvayeqaskcrdgv4.ckpt'

    def __init__(self,
                 vocab_size_txt: int,
                 vocab_size_img: int,
                 hparams: OmegaConf) -> None:
        super().__init__()
        assert hparams.n_layers == hparams.n_dense_layers

        # input embedding for image and text
        self.tok_emb_img = nn.Embedding(vocab_size_img, hparams.embed_dim)
        self.tok_emb_txt = nn.Embedding(vocab_size_txt, hparams.embed_dim)

        self.pos_emb_img = nn.Embedding(hparams.ctx_len_img, hparams.embed_dim)
        self.pos_emb_txt = nn.Embedding(hparams.ctx_len_txt, hparams.embed_dim)

        self.drop = nn.Dropout(hparams.embd_pdrop)

        # transformer blocks
        blocks = [Block(ctx_len=hparams.ctx_len_img + hparams.ctx_len_txt,
                        embed_dim=hparams.embed_dim,
                        n_heads=hparams.n_heads,
                        mlp_bias=hparams.mlp_bias,
                        attn_bias=hparams.attn_bias,
                        resid_pdrop=hparams.resid_pdrop,
                        attn_pdrop=hparams.attn_pdrop,
                        gelu_use_approx=hparams.gelu_use_approx)
                  for _ in range(hparams.n_layers)]
        self.blocks = nn.Sequential(*blocks)

        # heads for image and text
        self.ln_f = nn.LayerNorm(hparams.embed_dim)
        self.head_img = nn.Linear(hparams.embed_dim, vocab_size_img, bias=False)
        self.head_txt = nn.Linear(hparams.embed_dim, vocab_size_txt, bias=False)

        self.ctx_len_img = hparams.ctx_len_img
        self.ctx_len_txt = hparams.ctx_len_txt
        self.n_layers = hparams.n_layers

        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise Linear/Embedding weights from N(0, 0.02) and reset LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self,
                images: torch.LongTensor,
                texts: torch.LongTensor,
                pos_images: torch.LongTensor,
                pos_texts: torch.LongTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Return ``(logits_img, logits_txt)`` for next-token prediction.

        ``texts`` must have exactly ``ctx_len_txt`` columns and ``images`` at
        most ``ctx_len_img`` columns.
        """
        B, T = images.shape
        _, N = texts.shape

        assert T <= self.ctx_len_img, "Already reached the maximum context length (image)."
        assert N == self.ctx_len_txt, "Already reached the maximum context length (text)."

        texts = self.tok_emb_txt(texts)
        images = self.tok_emb_img(images)

        texts = texts + self.pos_emb_txt(pos_texts)
        images = images + self.pos_emb_img(pos_images)

        x = torch.cat([texts, images], dim=1).contiguous()
        x = self.drop(x)
        x = self.blocks(x)
        x = self.ln_f(x)

        # Position i predicts token i+1: the first N-1 outputs predict text
        # tokens, outputs from N-1 up to (not including) the last predict
        # image tokens.
        texts = x[:, :N-1].contiguous()
        images = x[:, N-1:-1].contiguous()

        logits_txt = self.head_txt(texts)
        logits_img = self.head_img(images)
        return logits_img, logits_txt

    @torch.no_grad()
    def sampling(self,
                 images: torch.LongTensor,
                 texts: torch.LongTensor,
                 pos_images: torch.LongTensor,
                 pos_texts: torch.LongTensor,
                 use_fp16: bool = True,
                 past: Optional[List[torch.Tensor]] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
        """Compute image-token logits for the next position during generation.

        Args:
            images: image tokens generated so far, or ``None`` for the first step.
            texts: text tokens; must have exactly ``ctx_len_txt`` columns.
            pos_images: positions matching ``images`` (ignored when it is ``None``).
            pos_texts: positions matching ``texts``.
            use_fp16: passed to ``autocast(enabled=...)``.
            past: list of per-step caches (each stacked over layers); they are
                concatenated along ``dim=-2`` before being handed to the blocks.

        Returns:
            ``(logits, presents)`` where ``presents`` holds one cache entry per
            block for the current step.
        """
        _, N = texts.shape
        assert N == self.ctx_len_txt, "Already reached the maximum context length (text)."

        with autocast(enabled=use_fp16):
            layer_pasts = None
            if images is None:
                # First step: only the text prefix is available.
                assert past is None
                x = self.tok_emb_txt(texts) + self.pos_emb_txt(pos_texts)
            else:
                img = self.tok_emb_img(images) + self.pos_emb_img(pos_images)
                if past is None:
                    # No cache: re-encode the text prefix together with the images.
                    txt = self.tok_emb_txt(texts) + self.pos_emb_txt(pos_texts)
                    x = torch.cat([txt, img], dim=1).contiguous()
                else:
                    # Cached: only the new image token(s) need to be encoded.
                    x = img
                    layer_pasts = torch.cat(past, dim=-2)
            x = self.drop(x)

            presents = []
            for i, block in enumerate(self.blocks):
                layer_past = None if layer_pasts is None else layer_pasts[i]
                x, present = block.sample(x, layer_past=layer_past)
                presents.append(present)
            x = self.ln_f(x)
            # The last position always holds the prediction for the next image token.
            x = x[:, -1].contiguous()
            logits = self.head_img(x)
            return logits, presents

    def from_ckpt(self, path: str) -> None:
        """Restore weights from the hosted checkpoint at ``_CKPT_URL``.

        NOTE: ``path`` is accepted for interface compatibility but is not
        used; the weights always come from the hard-coded URL (downloaded and
        cached by ``torch.utils.model_zoo``).
        """
        # Import explicitly: `import torch` alone does not guarantee that the
        # `torch.utils.model_zoo` submodule has been loaded.
        from torch.utils import model_zoo

        ckpt = model_zoo.load_url(self._CKPT_URL, map_location='cpu')['state_dict']

        self.load_state_dict(ckpt, strict=True)
        print(f'{self._CKPT_URL} successfully restored..')
151
+
152
+
153
class iGPT(nn.Module):
    """Image-only autoregressive transformer with a start-of-sequence embedding.

    The start token is either a learned parameter or, when ``use_cls_cond`` is
    true, an embedding looked up from a class label.
    """

    def __init__(self,
                 vocab_size_img: int,
                 use_cls_cond: bool,
                 hparams: OmegaConf) -> None:
        super().__init__()
        self.use_cls_cond = use_cls_cond
        dim = hparams.embed_dim

        # start-of-sequence embedding: per-class table or a single learned vector
        self.sos = (nn.Embedding(hparams.n_classes, dim)
                    if self.use_cls_cond
                    else nn.Parameter(torch.randn(1, 1, dim)))

        # token and position embeddings
        self.tok_emb_img = nn.Embedding(vocab_size_img, dim)
        self.pos_emb_img = nn.Embedding(hparams.ctx_len_img, dim)

        self.drop = nn.Dropout(hparams.embd_pdrop)

        # stack of transformer blocks (context has one extra slot for sos)
        layers = []
        for _ in range(hparams.n_layers):
            layers.append(Block(ctx_len=hparams.ctx_len_img + 1,
                                embed_dim=dim,
                                n_heads=hparams.n_heads,
                                mlp_bias=hparams.mlp_bias,
                                attn_bias=hparams.attn_bias,
                                resid_pdrop=hparams.resid_pdrop,
                                attn_pdrop=hparams.attn_pdrop,
                                gelu_use_approx=hparams.gelu_use_approx))
        self.blocks = nn.Sequential(*layers)

        # output head
        self.ln_f = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, vocab_size_img, bias=False)

        self.ctx_len_img = hparams.ctx_len_img
        self.n_layers = hparams.n_layers

        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Reset LayerNorm to identity and draw Linear/Embedding weights from N(0, 0.02)."""
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @torch.no_grad()
    def sampling(self,
                 sos: torch.FloatTensor,
                 codes: torch.LongTensor,
                 pos_codes: torch.LongTensor,
                 n_samples: int = 16,
                 use_fp16: bool = True,
                 past: Optional[torch.Tensor] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
        """Return next-token logits and the per-block cache for one generation step.

        ``codes`` is ``None`` on the first step (only ``sos`` is fed). When
        ``past`` is given its entries are concatenated along ``dim=-2`` and
        indexed per block. ``n_samples`` is accepted but not used here.
        """
        with autocast(enabled=use_fp16):
            cached = None
            if codes is None:
                assert past is None
                hidden = sos
            else:
                hidden = self.tok_emb_img(codes) + self.pos_emb_img(pos_codes)
                if past is None:
                    # no cache yet: prepend the start token to the full sequence
                    hidden = torch.cat([sos, hidden], dim=1)
                else:
                    cached = torch.cat(past, dim=-2)
            hidden = self.drop(hidden)

            presents = []
            for idx, layer in enumerate(self.blocks):
                layer_past = cached[idx] if cached is not None else None
                hidden, present = layer.sample(hidden, layer_past=layer_past)
                presents.append(present)

            hidden = self.ln_f(hidden)
            logits = self.head(hidden)[:, -1]
            return logits, presents

    def forward(self,
                codes: torch.LongTensor,
                labels: Optional[torch.LongTensor] = None) -> torch.FloatTensor:
        """Return logits for every position of ``codes`` (teacher forcing).

        The input is shifted right by one: the start token is prepended and
        the last code embedding is dropped.
        """
        B, T = codes.shape
        positions = torch.arange(T, device=codes.device).repeat((B, 1))
        if labels is None:
            start = self.sos.repeat((B, 1, 1))
        else:
            start = self.sos(labels).unsqueeze(1)

        embedded = self.tok_emb_img(codes) + self.pos_emb_img(positions)
        shifted = torch.cat([start, embedded[:, :-1]], dim=1).contiguous()

        out = self.ln_f(self.blocks(self.drop(shifted)))
        return self.head(out)

    def from_ckpt(self, path: str, strict: bool = True) -> None:
        """Load the ``state_dict`` entry of the checkpoint file at ``path``."""
        state = torch.load(path, map_location='cpu')['state_dict']
        self.load_state_dict(state, strict=strict)
        print(f'{path} successfully restored..')
dalle/models/tokenizer.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # minDALL-E
3
+ # Copyright (c) 2021 Kakao Brain Corp. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+
7
+ import os
8
+ from functools import partial
9
+ from tokenizers import CharBPETokenizer
10
+
11
+
12
def build_tokenizer(path: str,
                    context_length: int = 64,
                    *args,
                    **kwargs):
    """Load the 16k CharBPE tokenizer stored under ``path``.

    The tokenizer gets a ``[PAD]`` special token and is configured to pad and
    truncate every encoding to ``context_length``. Extra positional and
    keyword arguments are forwarded to ``CharBPETokenizer.from_file``;
    keyword arguments override the defaults set here.
    """
    options = {
        'vocab_filename': os.path.join(path, 'bpe-16k-vocab.json'),
        'merges_filename': os.path.join(path, 'bpe-16k-merges.txt'),
        'unk_token': '[UNK]',
    }
    options.update(kwargs)
    tokenizer = CharBPETokenizer.from_file(*args, **options)

    tokenizer.add_special_tokens(['[PAD]'])
    pad_id = tokenizer.token_to_id('[PAD]')
    tokenizer.enable_padding(length=context_length, pad_id=pad_id)
    tokenizer.enable_truncation(max_length=context_length)
    print(f'{path} successfully restored..')
    return tokenizer
dalle/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .utils import *
2
+ from .config import *
3
+ from .sampling import *
dalle/utils/config.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # minDALL-E
3
+ # Copyright (c) 2021 Kakao Brain Corp. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+
7
+ from typing import Optional, List
8
+ from dataclasses import dataclass, field
9
+ from omegaconf import OmegaConf
10
+
11
+
12
@dataclass
class DataConfig:
    """Dataset and preprocessing options."""
    dataset: Optional[str] = None
    tokenizer_type: str = 'CharBPE'
    context_length: int = 64
    image_resolution: int = 256
    transforms: str = 'dalle-vqvae'
    bpe_pdrop: Optional[float] = None


@dataclass
class Stage1Hparams:
    """Hyper-parameters of the stage-1 model."""
    double_z: bool = False
    z_channels: int = 256
    resolution: int = 256
    in_channels: int = 3
    out_ch: int = 3
    ch: int = 128
    ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
    num_res_blocks: int = 2
    attn_resolutions: List[int] = field(default_factory=lambda: [16])
    pdrop: float = 0.0


@dataclass
class Stage2Hparams:
    """Hyper-parameters of the stage-2 transformer."""
    embed_dim: int = 1536
    n_layers: int = 42
    n_heads: int = 24
    n_dense_layers: int = 42
    ctx_len_img: int = 256
    ctx_len_txt: int = 64
    embd_pdrop: float = 0.0
    resid_pdrop: float = 0.0
    attn_pdrop: float = 0.0
    mlp_bias: bool = True
    attn_bias: bool = True
    gelu_use_approx: bool = False
    use_head_txt: bool = True
    n_classes: Optional[int] = None


@dataclass
class Stage1Config:
    """Stage-1 model selection plus its hyper-parameters."""
    type: str = 'vqgan'
    embed_dim: int = 256
    n_embed: int = 16384
    # default_factory gives every config its own instance; a dataclass
    # instance used directly as a default is shared between configs and is
    # rejected as a mutable default by Python 3.11+.
    hparams: Stage1Hparams = field(default_factory=Stage1Hparams)


@dataclass
class Stage2Config:
    """Stage-2 model selection plus its hyper-parameters."""
    type: str = 'transformer1d'
    vocab_size_txt: int = 16384
    vocab_size_img: int = 16384
    use_cls_cond: Optional[bool] = None
    hparams: Stage2Hparams = field(default_factory=Stage2Hparams)


@dataclass
class WarmupConfig:
    """Learning-rate warm-up options."""
    epoch: int = 1
    multiplier: int = 1
    buffer_epoch: int = 0
    min_lr: float = 0.0
    mode: str = 'fix'
    peak_lr: float = 1e-4
    start_from_zero: bool = True


@dataclass
class OptConfig:
    """Optimizer and learning-rate schedule options."""
    opt_type: str = 'adamW'
    base_lr: float = 1e-4
    weight_decay: float = 1e-4
    betas: List[float] = field(default_factory=lambda: [0.9, 0.99])
    grad_clip_norm: float = 1.0

    sched_type: str = 'cosine'
    max_steps: int = 0
    min_lr: float = 0.0


@dataclass
class ExpConfig:
    """Experiment-level options (batch sizes, epochs, checkpoint/test cadence)."""
    local_batch_size: int = 4
    total_batch_size: int = 512
    valid_batch_size: int = 32
    epochs: int = 10
    save_ckpt_freq: int = 2
    test_freq: int = 1
    use_amp: bool = True


@dataclass
class DefaultConfig:
    """Base configuration: dataset plus both model stages."""
    dataset: DataConfig = field(default_factory=DataConfig)
    stage1: Stage1Config = field(default_factory=Stage1Config)
    stage2: Stage2Config = field(default_factory=Stage2Config)


@dataclass
class FineTuningConfig:
    """Base configuration extended with optimizer and experiment sections."""
    dataset: DataConfig = field(default_factory=DataConfig)
    stage1: Stage1Config = field(default_factory=Stage1Config)
    stage2: Stage2Config = field(default_factory=Stage2Config)
    optimizer: OptConfig = field(default_factory=OptConfig)
    experiment: ExpConfig = field(default_factory=ExpConfig)


def get_base_config(use_default=True):
    """Return a structured OmegaConf config.

    Built from ``DefaultConfig`` when ``use_default`` is true, otherwise from
    ``FineTuningConfig``.
    """
    return OmegaConf.structured(DefaultConfig if use_default else FineTuningConfig)
dalle/utils/sampling.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # minDALL-E
3
+ # Copyright (c) 2021 Kakao Brain Corp. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+
7
+ import torch
8
+ from typing import Optional
9
+ from tqdm import tqdm
10
+ from torch.nn import functional as F
11
+ import streamlit as st
12
+
13
def cutoff_topk_logits(logits: torch.FloatTensor, k: int) -> torch.FloatTensor:
    """Keep the ``k`` largest logits of each row and set the rest to ``-inf``.

    ``logits`` is returned untouched when ``k`` is ``None``; otherwise a new
    tensor is returned and the input is left unmodified.
    """
    if k is None:
        return logits

    top_values, _ = torch.topk(logits, k)
    # smallest logit that still belongs to the top-k of each row
    threshold = top_values[:, -1:]
    return logits.masked_fill(logits < threshold, -float('Inf'))
21
+
22
+
23
def cutoff_topp_probs(probs: torch.FloatTensor, p: float) -> torch.FloatTensor:
    """Apply nucleus (top-p) filtering along the last dimension.

    Entries are visited in descending order of probability; everything after
    the first entry whose cumulative probability reaches ``p`` is zeroed, and
    the survivors are renormalised to sum to one. ``probs`` is returned
    unchanged when ``p`` is ``None``.
    """
    if p is None:
        return probs

    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
    reached = torch.cumsum(sorted_probs, dim=-1) >= p

    # Shift right by one so the entry that crosses the threshold is kept and
    # the most probable entry is never removed.
    drop_sorted = torch.zeros_like(reached)
    drop_sorted[..., 1:] = reached[..., :-1]

    # Map the removal flags from sorted order back to the original positions.
    drop_mask = drop_sorted.scatter(-1, sorted_indices, drop_sorted)
    kept = probs.masked_fill(drop_mask, 0.0)
    return kept / torch.sum(kept, dim=-1, keepdim=True)
39
+
40
+
41
def get_positional_encoding(inputs: torch.LongTensor, mode: str = '1d') -> torch.LongTensor:
    """Build integer position indices matching the shape of ``inputs``.

    * ``'1d'``: ``inputs`` is (B, N); returns a (B, N) tensor whose rows are
      ``0..N-1``.
    * ``'2d'``: ``inputs`` is (B, H, W); returns a ``(row_idx, col_idx)``
      tuple of (B, H, W) tensors.

    Raises:
        ValueError: for any other ``mode``.
    """
    device = inputs.device

    if mode == '1d':
        batch, length = inputs.shape
        return torch.arange(length, device=device).repeat((batch, 1))

    if mode == '2d':
        batch, height, width = inputs.shape
        rows = torch.arange(height, device=device).repeat(batch, width, 1).transpose(1, 2)
        cols = torch.arange(width, device=device).repeat(batch, height, 1)
        return (rows, cols)

    raise ValueError('%s positional encoding invalid' % mode)
54
+
55
+
56
@torch.no_grad()
def sampling(model: torch.nn.Module,
             tokens: torch.LongTensor,
             top_k: Optional[int] = None,
             top_p: Optional[float] = None,
             softmax_temperature: float = 1.0,
             is_tqdm: bool = True,
             use_fp16: bool = True,
             max_seq_len: int = 256) -> torch.LongTensor:
    """Autoregressively sample up to ``max_seq_len`` image codes from text tokens.

    Args:
        model: object exposing ``sampling(images, texts, pos_images, pos_texts,
            use_fp16, past)`` that returns ``(logits, presents)``.
        tokens: text tokens; a 1d positional encoding is built for them.
        top_k: keep only the ``top_k`` largest logits (``None`` disables).
        top_p: nucleus-sampling threshold (``None`` disables).
        softmax_temperature: logits are divided by this before the softmax.
        is_tqdm: show a tqdm progress bar on the console.
        use_fp16: forwarded to ``model.sampling``.
        max_seq_len: maximum number of codes to generate.

    Returns:
        Tensor with one column per generated code. Generation stops early
        (returning a shorter tensor) when the Streamlit session leaves page 0.
    """
    code = None
    past = None

    steps = range(max_seq_len)
    pbar = tqdm(steps, total=max_seq_len) if is_tqdm else steps
    pos_enc_tokens = get_positional_encoding(tokens, mode='1d')

    for cnt in pbar:
        if code is None:
            code_ = None
            pos_enc_code_ = None
        else:
            # Only the most recent code is fed to the model; earlier steps
            # are covered by the `past` cache. Its position is cnt-1, so
            # there is no need to rebuild the whole positional encoding.
            code_ = code[:, cnt-1].unsqueeze(-1).clone().detach()
            pos_enc_code_ = torch.full_like(code_, cnt - 1)

        logits, present = model.sampling(images=code_,
                                         texts=tokens,
                                         pos_images=pos_enc_code_,
                                         pos_texts=pos_enc_tokens,
                                         use_fp16=use_fp16,
                                         past=past)
        logits = logits.to(dtype=torch.float32)
        logits = logits / softmax_temperature

        present = torch.stack(present).clone().detach()
        if past is None:
            past = [present]
        else:
            past.append(present)

        logits = cutoff_topk_logits(logits, top_k)
        probs = F.softmax(logits, dim=-1)
        probs = cutoff_topp_probs(probs, top_p)

        idx = torch.multinomial(probs, num_samples=1).clone().detach()
        code = idx if code is None else torch.cat([code, idx], dim=1)

        # Streamlit integration. Both session keys are optional so the
        # sampler does not crash when the app has not set them.
        if 'page' in st.session_state and st.session_state.page != 0:
            # The user navigated away from the generation page: stop early.
            break

        if 'bar' in st.session_state:
            st.session_state.bar.progress(cnt/max_seq_len)

    return code
115
+
116
+
117
@torch.no_grad()
def sampling_igpt(model: torch.nn.Module,
                  sos: torch.FloatTensor,
                  top_k: Optional[float] = None,
                  top_p: Optional[float] = None,
                  softmax_temperature: float = 1.0,
                  is_tqdm: bool = True,
                  use_fp16: bool = True,
                  max_seq_len: int = 256) -> torch.LongTensor:
    """Autoregressively sample ``max_seq_len`` codes from an iGPT-style model.

    At every step only the most recent code (with its position) is passed to
    ``model.sampling`` together with the accumulated cache. The resulting
    logits are temperature-scaled and filtered with top-k / top-p before the
    next code is drawn.
    """
    generated = None
    cache = None

    steps = range(max_seq_len)
    if is_tqdm:
        steps = tqdm(steps, total=max_seq_len)

    for step, _ in enumerate(steps):
        last_code, last_pos = None, None
        if generated is not None:
            snapshot = generated.clone().detach()
            positions = get_positional_encoding(snapshot, mode='1d')
            last_code = snapshot[:, step-1].unsqueeze(-1)
            last_pos = positions[:, step-1].unsqueeze(-1)

        logits, present = model.sampling(sos=sos,
                                         codes=last_code,
                                         pos_codes=last_pos,
                                         use_fp16=use_fp16,
                                         past=cache)
        logits = logits.to(dtype=torch.float32) / softmax_temperature

        # accumulate this step's per-layer cache for the next call
        present = torch.stack(present).clone().detach()
        if cache is None:
            cache = [present]
        else:
            cache.append(present)

        filtered = cutoff_topk_logits(logits, top_k)
        probs = cutoff_topp_probs(F.softmax(filtered, dim=-1), top_p)

        drawn = torch.multinomial(probs, num_samples=1).clone().detach()
        if generated is None:
            generated = drawn
        else:
            generated = torch.cat([generated, drawn], axis=1)

    return generated
dalle/utils/utils.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # minDALL-E
3
+ # Copyright (c) 2021 Kakao Brain Corp. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+
7
+ import os
8
+ import random
9
+ import urllib
10
+ import hashlib
11
+ import tarfile
12
+ import torch
13
+ import clip
14
+ import numpy as np
15
+ from PIL import Image
16
+ from torch.nn import functional as F
17
+ from tqdm import tqdm
18
+
19
+
20
def set_seed(seed: int):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
25
+
26
+
27
@torch.no_grad()
def clip_score(prompt: str,
               images: np.ndarray,
               model_clip: torch.nn.Module,
               preprocess_clip,
               device: str) -> np.ndarray:
    """Rank ``images`` by CLIP similarity to ``prompt``.

    Each image is scaled by 255, converted to a uint8 PIL image and passed
    through ``preprocess_clip``. The prompt is tokenized once and repeated
    for every image.

    Returns:
        Indices of ``images`` ordered from highest to lowest cosine
        similarity between image and text features.
    """
    processed = []
    for image in images:
        pil_image = Image.fromarray((image*255).astype(np.uint8))
        processed.append(preprocess_clip(pil_image))
    batch = torch.stack(processed, dim=0).to(device=device)

    tokens = clip.tokenize(prompt).to(device=device)
    tokens = torch.repeat_interleave(tokens, batch.shape[0], dim=0)

    image_features = model_clip.encode_image(batch)
    text_features = model_clip.encode_text(tokens)

    similarity = F.cosine_similarity(image_features, text_features).squeeze()
    return torch.argsort(similarity, descending=True).cpu().numpy()
44
+
45
+
46
def download(url: str, root: str) -> str:
    """Download a ``.tar.gz`` archive into ``root``, verify it and extract it.

    The URL is expected to look like ``.../<md5>/<name>.tar.gz``: the
    second-to-last path component is the expected MD5 of the archive.

    Args:
        url: archive URL.
        root: directory that receives both the archive and its contents
            (created if missing).

    Returns:
        ``<root>/<name>``, the directory the archive is expected to extract to.
        If the archive file and that directory already exist, nothing is
        downloaded.

    Raises:
        RuntimeError: if the MD5 checksum does not match, or if an archive
            member would be extracted outside ``root``.
    """
    # `import urllib` alone does not guarantee that the `request` submodule
    # has been loaded, so import it explicitly where it is used.
    import urllib.request

    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)
    pathname = filename[:-len('.tar.gz')]

    expected_md5 = url.split("/")[-2]
    download_target = os.path.join(root, filename)
    result_path = os.path.join(root, pathname)

    if os.path.isfile(download_target) and (os.path.exists(result_path) and not os.path.isfile(result_path)):
        return result_path

    # Hash while streaming so the (potentially multi-GB) archive is never
    # loaded into memory in one piece.
    digest = hashlib.md5()
    with urllib.request.urlopen(url) as source, open(download_target, 'wb') as output:
        content_length = source.info().get('Content-Length')
        # The header is optional; tqdm accepts total=None.
        total = int(content_length) if content_length is not None else None
        with tqdm(total=total, ncols=80, unit='iB', unit_scale=True,
                  unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                digest.update(buffer)
                loop.update(len(buffer))

    if digest.hexdigest() != expected_md5:
        raise RuntimeError('Model has been downloaded but the md5 checksum does not match')

    root_real = os.path.realpath(root)
    with tarfile.open(download_target, 'r:gz') as f:
        members = f.getmembers()
        pbar = tqdm(members, total=len(members))
        for member in pbar:
            # Refuse members that would escape `root` (e.g. "../x" or
            # absolute paths) before writing anything to disk.
            target = os.path.realpath(os.path.join(root_real, member.name))
            if os.path.commonpath([root_real, target]) != root_real:
                raise RuntimeError(f'Refusing to extract {member.name!r} outside of {root!r}')
            pbar.set_description(f'extracting: {member.name} (size:{member.size // (1024 * 1024)}MB)')
            f.extract(member=member, path=root)

    return result_path
79
+
80
+
81
def realpath_url_or_path(url_or_path: str, root: str = None) -> str:
    """Resolve ``url_or_path`` to a local path.

    ``http``/``https`` URLs are fetched into ``root`` via ``download`` and the
    resulting path is returned; anything else is returned unchanged.
    """
    scheme = urllib.parse.urlparse(url_or_path).scheme
    if scheme not in ('http', 'https'):
        return url_or_path
    return download(url_or_path, root)
page/generate.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ from numpy.core.defchararray import lower
3
+ import streamlit as st
4
+ import numpy as np
5
+ import pandas as pd
6
+ import streamlit as st
7
+ import pandas as pd
8
+ import numpy as np
9
+ import os, random, time
10
+ from utils import footer, generate, drawGrid
11
+ from PIL import Image
12
+
13
+ mode = "ai"
14
+ #mode = "dummy"
15
+
16
def app():
    """Render the generation page.

    Collects a prompt and sampling settings, then (once "Start" has been
    pressed) keeps calling ``generate`` and redrawing the result grid until
    more than 15 results exist or the user presses the "finished" button,
    which switches the session to page 1.
    """
    st.title('AI-Generated Architecture')
    st.subheader('Describe a building, interior, or other architecture you would like to see.')

    prompt = st.text_input(label="", value="Modern architecture museum with black brick and large windows.")

    st.text("")

    with st.expander("Having trouble thinking of something? Click here to view examples."):
        st.write("""
• Modern architecture museum with black brick and large windows.\n
• A prosaic, simple architecture.\n
• An urban, post-modern architecture with concrete and steel.\n
• A sleek urban interior design.
""")

    st.text("")

    crazy = st.slider('Temperature. This controls how "crazy" generated images are, where 0 is the least crazy.', 0.0, 1.0, 0.75)
    k = st.slider('Top K. The higher the value, the higher quality the results tend to be at the cost of extra processing time.', 1, 10, 1)

    if 'results' not in st.session_state:
        st.session_state.results = []

    # One slot shared by the "Start" button and its replacement below.
    holder = st.empty()
    start_clicked = holder.button("Start")

    already = []

    print("-0-")

    # `load_state` keeps generation going across the reruns Streamlit
    # triggers after the first click.
    generation_active = start_clicked or hasattr(st.session_state, 'load_state')
    if not generation_active:
        return

    with st.spinner("Generating..."):
        print("-1-")

        holder.empty()

        finish_clicked = holder.button("finished generating images")
        st.session_state.load_state = True

        progress_slot = st.empty()
        grid_slot = st.empty()

        with grid_slot.container():
            drawGrid()

        while len(st.session_state.results) <= 15:
            print("Length "+str(len(st.session_state.results)))

            with progress_slot.container():
                # The sampler reports progress through `st.session_state.bar`.
                st.session_state.bar = progress_slot.progress(0)

            if finish_clicked:
                st.session_state.page = 1
                break

            generate(prompt, crazy, k)

            with grid_slot.container():
                drawGrid()
87
+
88
+
89
+
90
+
91
+ #placeholder.empty()
92
+
93
+ #st.session_state.bar = placeholder.progress(0)
94
+ #drawGrid(placeholder)
95
+
96
+
97
+
page/reduce.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ from numpy.core.defchararray import lower
3
+ import streamlit as st
4
+ import numpy as np
5
+ import pandas as pd
6
+ from zipfile import ZipFile
7
+ import io
8
+ import os
9
+
10
def dell(ix):
    """Button callback: drop the result at index ``ix`` from the session's working set."""
    print("!!!!")
    results = st.session_state.results
    results.pop(ix)
13
+
14
+
15
def app():
    """Render the review page: download all results as a zip, or delete some.

    The archive is built in memory. The earlier approach changed the process
    working directory with ``os.chdir`` on every rerun and then wrote to
    ``temp/`` relative to that already-changed directory, and it left the zip
    file handle open.
    """
    st.title('AI-Generated Architecture')

    st.subheader('Choose which images you would like to remove from your working set.')

    # Encode every image as JPEG straight into an in-memory zip archive.
    archive = io.BytesIO()
    with ZipFile(archive, 'w') as zip_file:
        for ix, entry in enumerate(st.session_state.results):
            encoded = io.BytesIO()
            # A format must be given explicitly when saving to a buffer.
            entry['image'].save(encoded, format='JPEG')
            # Same member names as the on-disk version produced.
            zip_file.writestr("temp/"+str(ix)+".jpeg", encoded.getvalue())

    st.download_button(
        label="Download images as zip",
        data=archive.getvalue(),
        file_name='ai_architecture.zip',
        mime='application/zip'
    )

    for ix, result in enumerate(st.session_state.results):

        with st.container():
            col1, col2 = st.columns(2)

            with col1:
                st.image(result['image'])
            with col2:
                # The callback removes this entry before the next rerun.
                st.button("delete ", key=ix, on_click=dell, kwargs=dict(ix=ix))

            st.markdown("\n<hr />", unsafe_allow_html=True)
57
+
58
+
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ clip==0.2.0
2
+ Cython==0.29.30
3
+ clip_anytorch==2.4.0
4
+ htbuilder==0.6.0
5
+ iteration_utilities==0.11.0
6
+ numpy==1.22.4
7
+ omegaconf==2.2.2
8
+ pages==0.3
9
+ pandas==1.4.2
10
+ Pillow==9.2.0
11
+ pytorch_lightning==1.6.3
12
+ ruclip==0.0.1
13
+ rudalle==1.1.3
14
+ streamlit==1.10.0
15
+ tokenizers==0.12.1
16
+ torch==1.8.0
17
+ torchvision==0.9.0
18
+ tqdm==4.64.0
streamlit_app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+
5
+ import os, random, time
6
+
7
+ from utils import footer
8
+ from page import generate, reduce
9
+
10
+
11
# --- session defaults -------------------------------------------------------
if 'page' not in st.session_state:
    st.session_state.page = 0

if 'results' not in st.session_state:
    st.session_state.results = []

# One placeholder per page so the inactive pages can be cleared.
p1 = st.empty()
p2 = st.empty()
p3 = st.empty()

# Reset on every script run.
st.session_state.stop = False
st.session_state.progress = 0
st.session_state.regenerate = False

# --- page dispatch ----------------------------------------------------------
# These are deliberately independent `if`s rather than an `elif` chain:
# a page may change `st.session_state.page` while it runs, in which case the
# next page is rendered in the same script run.
if st.session_state.page == 0:
    p2.empty()
    p3.empty()
    with p1.container():
        generate.app()

if st.session_state.page == 1:
    p1.empty()
    p3.empty()
    with p2.container():
        reduce.app()

if st.session_state.page == 2:
    p1.empty()
    p2.empty()
    with p3.container():
        st.write("This 333")
        startButton = st.button("S3")
        if startButton:
            st.session_state.page = 0

footer()
utils.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from htbuilder import HtmlElement, div, ul, li, br, hr, a, p, img, styles, classes, fonts
2
+ from htbuilder.units import percent, px
3
+ from htbuilder.funcs import rgba, rgb
4
+ import streamlit as st
5
+ import os
6
+ import sys
7
+ import argparse
8
+ import clip
9
+ import numpy as np
10
+ from PIL import Image
11
+ from dalle.models import Dalle
12
+ from dalle.utils.utils import set_seed, clip_score
13
+
14
def link(link, text, **style):
    """Build an anchor element pointing at ``link`` with ``text`` as content.

    The anchor is created with ``_target="_blank"``; any extra keyword
    arguments are passed to ``styles`` to form the element's ``style``.
    """
    anchor = a(_href=link, _target="_blank", style=styles(**style))
    return anchor(text)
16
+
17
def layout(*args):
    """Render a footer fixed to the bottom of the page.

    A ``<style>`` block is emitted first, then a ``div`` containing a
    horizontal rule and a paragraph. Every positional argument that is a
    ``str`` or an ``HtmlElement`` is added to that paragraph in order;
    arguments of any other type are ignored.
    """
    # The ID selector must be written ``#MainMenu`` with no space after the
    # hash: ``# MainMenu`` is not a valid CSS selector, so the rule was dropped
    # and the menu it targets stayed visible.
    style = """
    <style>
      #MainMenu {visibility: hidden;}
      footer {visibility: hidden;}
      .stApp { bottom: 105px; }
    </style>
    """

    style_div = styles(
        position="fixed",
        left=0,
        bottom=0,
        margin=px(0, 0, 0, 0),
        width=percent(100),
        color="black",
        text_align="center",
        height="auto",
        opacity=1,
    )

    style_hr = styles(
        display="block",
        margin=px(8, 8, "auto", "auto"),
        border_style="inset",
        border_width=px(2),
    )

    body = p()
    foot = div(style=style_div)(hr(style=style_hr), body)

    st.markdown(style, unsafe_allow_html=True)

    # Calling ``body`` with a child is how content is added to the paragraph;
    # ``foot`` is stringified only after all children are in place.
    for arg in args:
        if isinstance(arg, (str, HtmlElement)):
            body(arg)

    st.markdown(str(foot), unsafe_allow_html=True)
66
+
67
def footer():
    """Render the attribution footer by passing its pieces to ``layout``."""
    layout(
        "Created by ",
        link("https://jonathanmalott.com", "Jonathan Malott"),
        br(),
        link("https://bridgingbarriers.utexas.edu/good-systems", "Good Systems"),
        " Grand Challenge",
        ", The University of Texas at Austin.",
        " Advised by Dr. Junfeng Jiao.",
        br(),
        br(),
    )
80
+
81
+ #footer()
82
+
83
def generate(prompt,crazy,k):
    """Sample one image for ``prompt`` and append it to the session results.

    Args:
        prompt: Text prompt passed to the model and to CLIP scoring.
        crazy: Value passed to sampling as ``softmax_temperature``.
        k: Stored alongside the result under the ``'k'`` key.

    Returns:
        None. A dict with keys ``'prompt'``, ``'crazy'``, ``'k'`` and
        ``'image'`` is appended to ``st.session_state.results``.
    """
    device = 'cpu'
    candidate_count = 1

    print("-2-")
    # NOTE(review): the original comment says this call downloads the
    # pretrained model automatically - confirm that holds for this local path.
    dalle_model = Dalle.from_pretrained('.cache/minDALL-E/1.3B')
    print("-3-")
    dalle_model.to(device=device)

    # A fresh random seed is drawn for every call.
    set_seed(np.random.randint(0,10000))

    # Sampling
    # NOTE(review): ``k`` is recorded with the result (and shown as "top k"
    # elsewhere) but sampling always uses top_k=2048 - confirm this is intended.
    sampled = dalle_model.sampling(
        prompt=prompt,
        top_k=2048,
        top_p=None,
        softmax_temperature=crazy,
        num_candidates=candidate_count,
        device=device,
    )
    # Move axis 1 to the end before handing the array to CLIP and PIL.
    images = np.transpose(sampled.cpu().numpy(), (0, 2, 3, 1))

    # CLIP Re-ranking
    model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
    model_clip.to(device=device)
    rank = clip_score(
        prompt=prompt,
        images=images,
        model_clip=model_clip,
        preprocess_clip=preprocess_clip,
        device=device,
    )

    best = images[rank]

    st.session_state.results.append({
        'prompt': prompt,
        'crazy': crazy,
        'k': k,
        'image': Image.fromarray((best*255).astype(np.uint8)),
    })
124
+
125
+
126
+
127
def drawGrid():
    """Draw all stored results, grouped by prompt, temperature and k.

    Results in ``st.session_state.results`` are visited from the end of the
    list to the start and bucketed by the combination of their ``'prompt'``,
    ``'crazy'`` and ``'k'`` values. Each bucket gets a subheader followed by
    its images spread across three columns.
    """
    grouped = {}

    for entry in reversed(st.session_state.results):
        group_key = " ".join((entry['prompt'], str(entry['crazy']), str(entry['k'])))
        grouped.setdefault(group_key, []).append(entry)

    for entries in grouped.values():
        first = entries[0]
        heading = (
            first['prompt']
            + " (temperature:" + str(first['crazy'])
            + ", top k:" + str(first['k'])
            + ")"
        )
        st.subheader(heading)

        columns = st.columns(3)
        # Images cycle through the three columns in order.
        for ix, item in enumerate(entries):
            with columns[ix % 3]:
                st.image(item["image"])
160
+