keithhon commited on
Commit
3a88044
·
1 Parent(s): 5fe0715

Upload dalle/models/stage2/transformer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dalle/models/stage2/transformer.py +255 -0
dalle/models/stage2/transformer.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Minimal DALL-E
3
+ # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+ # Modified from minGPT (https://github.com/karpathy/minGPT)
7
+ # Copyright (c) 2020 Andrej Karpathy. All Rights Reserved.
8
+ # ------------------------------------------------------------------------------------
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from typing import Optional, Tuple, List
13
+ from torch.cuda.amp import autocast
14
+ from omegaconf import OmegaConf
15
+ from .layers import Block
16
+
17
+
18
+ class Transformer1d(nn.Module):
19
+
20
+ def __init__(self,
21
+ vocab_size_txt: int,
22
+ vocab_size_img: int,
23
+ hparams: OmegaConf) -> None:
24
+ super().__init__()
25
+ assert hparams.n_layers == hparams.n_dense_layers
26
+
27
+ # input embedding for image and text
28
+ self.tok_emb_img = nn.Embedding(vocab_size_img, hparams.embed_dim)
29
+ self.tok_emb_txt = nn.Embedding(vocab_size_txt, hparams.embed_dim)
30
+
31
+ self.pos_emb_img = nn.Embedding(hparams.ctx_len_img, hparams.embed_dim)
32
+ self.pos_emb_txt = nn.Embedding(hparams.ctx_len_txt, hparams.embed_dim)
33
+
34
+ self.drop = nn.Dropout(hparams.embd_pdrop)
35
+
36
+ # transformer blocks
37
+ self.blocks = [Block(ctx_len=hparams.ctx_len_img + hparams.ctx_len_txt,
38
+ embed_dim=hparams.embed_dim,
39
+ n_heads=hparams.n_heads,
40
+ mlp_bias=hparams.mlp_bias,
41
+ attn_bias=hparams.attn_bias,
42
+ resid_pdrop=hparams.resid_pdrop,
43
+ attn_pdrop=hparams.attn_pdrop,
44
+ gelu_use_approx=hparams.gelu_use_approx) for i in range(1, hparams.n_layers+1)]
45
+ self.blocks = nn.Sequential(*self.blocks)
46
+
47
+ # heads for image and text
48
+ self.ln_f = nn.LayerNorm(hparams.embed_dim)
49
+ self.head_img = nn.Linear(hparams.embed_dim, vocab_size_img, bias=False)
50
+ self.head_txt = nn.Linear(hparams.embed_dim, vocab_size_txt, bias=False)
51
+
52
+ self.ctx_len_img = hparams.ctx_len_img
53
+ self.ctx_len_txt = hparams.ctx_len_txt
54
+ self.n_layers = hparams.n_layers
55
+
56
+ self.apply(self._init_weights)
57
+
58
+ def _init_weights(self, module: nn.Module) -> None:
59
+ if isinstance(module, (nn.Linear, nn.Embedding)):
60
+ module.weight.data.normal_(mean=0.0, std=0.02)
61
+ if isinstance(module, nn.Linear) and module.bias is not None:
62
+ module.bias.data.zero_()
63
+ elif isinstance(module, nn.LayerNorm):
64
+ module.bias.data.zero_()
65
+ module.weight.data.fill_(1.0)
66
+
67
+ def forward(self,
68
+ images: torch.LongTensor,
69
+ texts: torch.LongTensor,
70
+ pos_images: torch.LongTensor,
71
+ pos_texts: torch.LongTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
72
+ B, T = images.shape
73
+ _, N = texts.shape
74
+
75
+ assert T <= self.ctx_len_img, "Already reached the maximum context length (image)."
76
+ assert N == self.ctx_len_txt, "Already reached the maximum context length (text)."
77
+
78
+ texts = self.tok_emb_txt(texts)
79
+ images = self.tok_emb_img(images)
80
+
81
+ texts = texts + self.pos_emb_txt(pos_texts)
82
+ images = images + self.pos_emb_img(pos_images)
83
+
84
+ x = torch.cat([texts, images], axis=1).contiguous()
85
+ x = self.drop(x)
86
+ x = self.blocks(x)
87
+ x = self.ln_f(x)
88
+
89
+ texts = x[:, :N-1].contiguous()
90
+ images = x[:, N-1:-1].contiguous()
91
+
92
+ logits_txt = self.head_txt(texts)
93
+ logits_img = self.head_img(images)
94
+ return logits_img, logits_txt
95
+
96
+ @torch.no_grad()
97
+ def sampling(self,
98
+ images: torch.LongTensor,
99
+ texts: torch.LongTensor,
100
+ pos_images: torch.LongTensor,
101
+ pos_texts: torch.LongTensor,
102
+ use_fp16: bool = True,
103
+ past: Optional[List[torch.Tensor]] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
104
+ _, N = texts.shape
105
+ assert N == self.ctx_len_txt, "Already reached the maximum context length (text)."
106
+
107
+ with autocast(enabled=use_fp16):
108
+ if images is None:
109
+ assert past is None
110
+
111
+ texts = self.tok_emb_txt(texts)
112
+ x = texts + self.pos_emb_txt(pos_texts)
113
+ x = self.drop(x)
114
+
115
+ presents = []
116
+ for i, block in enumerate(self.blocks):
117
+ x, present = block.sample(x, layer_past=None)
118
+ presents.append(present)
119
+ x = self.ln_f(x)
120
+ x = x[:, N-1].contiguous()
121
+ logits = self.head_img(x)
122
+ else:
123
+ if past is None:
124
+ texts = self.tok_emb_txt(texts)
125
+ images = self.tok_emb_img(images)
126
+ texts = texts + self.pos_emb_txt(pos_texts)
127
+ images = images + self.pos_emb_img(pos_images)
128
+ x = torch.cat([texts, images], axis=1).contiguous()
129
+ else:
130
+ images = self.tok_emb_img(images)
131
+ x = images + self.pos_emb_img(pos_images)
132
+ x = self.drop(x)
133
+
134
+ if past is not None:
135
+ past = torch.cat(past, dim=-2)
136
+ presents = []
137
+ for i, block in enumerate(self.blocks):
138
+ x, present = block.sample(x, layer_past=None if past is None else past[i])
139
+ presents.append(present)
140
+ x = self.ln_f(x)
141
+ x = x[:, -1].contiguous()
142
+ logits = self.head_img(x)
143
+ return logits, presents
144
+
145
+ def from_ckpt(self, path: str) -> None:
146
+ ckpt = torch.load(path, map_location='cpu')['state_dict']
147
+ self.load_state_dict(ckpt, strict=True)
148
+ print(f'{path} succesfully restored..')
149
+
150
+
151
+ class iGPT(nn.Module):
152
+ def __init__(self,
153
+ vocab_size_img: int,
154
+ use_cls_cond: bool,
155
+ hparams: OmegaConf) -> None:
156
+ super().__init__()
157
+ self.use_cls_cond = use_cls_cond
158
+
159
+ # sos token embedding
160
+ if self.use_cls_cond:
161
+ self.sos = nn.Embedding(hparams.n_classes, hparams.embed_dim)
162
+ else:
163
+ self.sos = nn.Parameter(torch.randn(1, 1, hparams.embed_dim))
164
+
165
+ # input embedding
166
+ self.tok_emb_img = nn.Embedding(vocab_size_img, hparams.embed_dim)
167
+ self.pos_emb_img = nn.Embedding(hparams.ctx_len_img, hparams.embed_dim)
168
+
169
+ self.drop = nn.Dropout(hparams.embd_pdrop)
170
+
171
+ # transformer blocks
172
+ self.blocks = [Block(ctx_len=hparams.ctx_len_img + 1,
173
+ embed_dim=hparams.embed_dim,
174
+ n_heads=hparams.n_heads,
175
+ mlp_bias=hparams.mlp_bias,
176
+ attn_bias=hparams.attn_bias,
177
+ resid_pdrop=hparams.resid_pdrop,
178
+ attn_pdrop=hparams.attn_pdrop,
179
+ gelu_use_approx=hparams.gelu_use_approx) for i in range(1, hparams.n_layers+1)]
180
+ self.blocks = nn.Sequential(*self.blocks)
181
+
182
+ # head
183
+ self.ln_f = nn.LayerNorm(hparams.embed_dim)
184
+ self.head = nn.Linear(hparams.embed_dim, vocab_size_img, bias=False)
185
+
186
+ self.ctx_len_img = hparams.ctx_len_img
187
+ self.n_layers = hparams.n_layers
188
+
189
+ self.apply(self._init_weights)
190
+
191
+ def _init_weights(self, module: nn.Module) -> None:
192
+ if isinstance(module, (nn.Linear, nn.Embedding)):
193
+ module.weight.data.normal_(mean=0.0, std=0.02)
194
+ if isinstance(module, nn.Linear) and module.bias is not None:
195
+ module.bias.data.zero_()
196
+ elif isinstance(module, nn.LayerNorm):
197
+ module.bias.data.zero_()
198
+ module.weight.data.fill_(1.0)
199
+
200
+ @torch.no_grad()
201
+ def sampling(self,
202
+ sos: torch.FloatTensor,
203
+ codes: torch.LongTensor,
204
+ pos_codes: torch.LongTensor,
205
+ n_samples: int = 16,
206
+ use_fp16: bool = True,
207
+ past: Optional[torch.Tensor] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
208
+ with autocast(enabled=use_fp16):
209
+ if codes is None:
210
+ assert past is None
211
+ xs = self.drop(sos)
212
+ presents = []
213
+ for i, block in enumerate(self.blocks):
214
+ xs, present = block.sample(xs, layer_past=None)
215
+ presents.append(present)
216
+ xs = self.ln_f(xs)
217
+ logits = self.head(xs)[:, -1]
218
+ else:
219
+ if past is None:
220
+ xs = self.tok_emb_img(codes) + self.pos_emb_img(pos_codes)
221
+ xs = torch.cat([sos, xs], dim=1)
222
+ else:
223
+ xs = self.tok_emb_img(codes) + self.pos_emb_img(pos_codes)
224
+ xs = self.drop(xs)
225
+
226
+ past = torch.cat(past, dim=-2) if past is not None else past
227
+ presents = []
228
+ for i, block in enumerate(self.blocks):
229
+ xs, present = block.sample(xs, layer_past=None if past is None else past[i])
230
+ presents.append(present)
231
+
232
+ xs = self.ln_f(xs)
233
+ logits = self.head(xs)[:, -1]
234
+ return logits, presents
235
+
236
+ def forward(self,
237
+ codes: torch.LongTensor,
238
+ labels: Optional[torch.LongTensor] = None) -> torch.FloatTensor:
239
+ B, T = codes.shape
240
+ xps = torch.arange(T, device=codes.device).repeat((B, 1))
241
+ sos = self.sos.repeat((B, 1, 1)) if labels is None else self.sos(labels).unsqueeze(1)
242
+
243
+ h = self.tok_emb_img(codes) + self.pos_emb_img(xps)
244
+ h = torch.cat([sos, h[:, :-1]], dim=1).contiguous()
245
+
246
+ h = self.drop(h)
247
+ h = self.blocks(h)
248
+ h = self.ln_f(h)
249
+ logits = self.head(h)
250
+ return logits
251
+
252
+ def from_ckpt(self, path: str, strict: bool = True) -> None:
253
+ ckpt = torch.load(path, map_location='cpu')['state_dict']
254
+ self.load_state_dict(ckpt, strict=strict)
255
+ print(f'{path} successfully restored..')