JMalott commited on
Commit
90bed62
·
1 Parent(s): 1a2253c

Upload dalle_bart_encoder.py

Browse files
min_dalle/models/dalle_bart_encoder.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ import torch
3
+ from torch import nn, BoolTensor, FloatTensor, LongTensor
4
+
5
+
6
+ class GLU(nn.Module):
7
+ def __init__(self, count_in_out: int, count_middle: int):
8
+ super().__init__()
9
+ self.gelu = nn.GELU()
10
+ self.ln0 = nn.LayerNorm(count_in_out)
11
+ self.ln1 = nn.LayerNorm(count_middle)
12
+ self.fc0 = nn.Linear(count_in_out, count_middle, bias=False)
13
+ self.fc1 = nn.Linear(count_in_out, count_middle, bias=False)
14
+ self.fc2 = nn.Linear(count_middle, count_in_out, bias=False)
15
+
16
+ def forward(self, z: FloatTensor) -> FloatTensor:
17
+ z = self.ln0.forward(z)
18
+ w = self.fc0.forward(z)
19
+ w = self.gelu.forward(w)
20
+ v = self.fc1.forward(z)
21
+ z = self.ln1.forward(w * v)
22
+ z = self.fc2.forward(z)
23
+ return z
24
+
25
+
26
+ class AttentionBase(nn.Module):
27
+ def __init__(self, head_count: int, embed_count: int):
28
+ super().__init__()
29
+ self.head_count = head_count
30
+ self.embed_count = embed_count
31
+
32
+ self.k_proj = nn.Linear(embed_count, embed_count, bias=False)
33
+ self.v_proj = nn.Linear(embed_count, embed_count, bias=False)
34
+ self.q_proj = nn.Linear(embed_count, embed_count, bias=False)
35
+ self.out_proj = nn.Linear(embed_count, embed_count, bias=False)
36
+
37
+ def forward(
38
+ self,
39
+ keys: FloatTensor,
40
+ values: FloatTensor,
41
+ queries: FloatTensor,
42
+ attention_mask: BoolTensor
43
+ ) -> FloatTensor:
44
+ keys = keys.reshape(keys.shape[:2] + (self.head_count, -1))
45
+ values = values.reshape(values.shape[:2] + (self.head_count, -1))
46
+ queries = queries.reshape(queries.shape[:2] + (self.head_count, -1))
47
+ queries /= queries.shape[-1] ** 0.5
48
+
49
+ attention_bias = (1 - attention_mask.to(torch.float32)) * -1e12
50
+ attention_weights: FloatTensor = torch.einsum(
51
+ 'bqhc,bkhc->bhqk',
52
+ queries,
53
+ keys
54
+ )
55
+ attention_weights += attention_bias[:, None, None, :]
56
+ attention_weights = torch.softmax(attention_weights, -1)
57
+ attention_output: FloatTensor = torch.einsum(
58
+ "bhqk,bkhc->bqhc",
59
+ attention_weights,
60
+ values
61
+ )
62
+ shape = attention_output.shape[:2] + (self.embed_count,)
63
+ attention_output = attention_output.reshape(shape)
64
+ attention_output = self.out_proj.forward(attention_output)
65
+ return attention_output
66
+
67
+
68
+ class EncoderSelfAttention(AttentionBase):
69
+ def forward(
70
+ self,
71
+ encoder_state: FloatTensor,
72
+ attention_mask: BoolTensor
73
+ ) -> FloatTensor:
74
+ keys = self.k_proj.forward(encoder_state)
75
+ values = self.v_proj.forward(encoder_state)
76
+ queries = self.q_proj.forward(encoder_state)
77
+ return super().forward(keys, values, queries, attention_mask)
78
+
79
+
80
+ class EncoderLayer(nn.Module):
81
+ def __init__(self, embed_count: int, head_count: int, glu_embed_count: int):
82
+ super().__init__()
83
+ self.pre_self_attn_layer_norm = nn.LayerNorm(embed_count)
84
+ self.self_attn = EncoderSelfAttention(head_count, embed_count)
85
+ self.self_attn_layer_norm = nn.LayerNorm(embed_count)
86
+ self.glu = GLU(embed_count, glu_embed_count)
87
+
88
+ def forward(
89
+ self,
90
+ encoder_state: FloatTensor,
91
+ attention_mask: BoolTensor
92
+ ) -> FloatTensor:
93
+ residual = encoder_state
94
+ encoder_state = self.pre_self_attn_layer_norm.forward(encoder_state)
95
+ encoder_state = self.self_attn.forward(encoder_state, attention_mask)
96
+ encoder_state = self.self_attn_layer_norm.forward(encoder_state)
97
+ encoder_state = residual + encoder_state
98
+ residual = encoder_state
99
+ encoder_state = self.glu.forward(encoder_state)
100
+ encoder_state = residual + encoder_state
101
+ return encoder_state
102
+
103
+
104
+ class DalleBartEncoder(nn.Module):
105
+ def __init__(
106
+ self,
107
+ layer_count: int,
108
+ embed_count: int,
109
+ attention_head_count: int,
110
+ text_vocab_count: int,
111
+ text_token_count: int,
112
+ glu_embed_count: int,
113
+ device: str
114
+ ):
115
+ super().__init__()
116
+ self.text_vocab_count = text_vocab_count
117
+ self.embed_tokens = nn.Embedding(text_vocab_count, embed_count)
118
+ self.embed_positions = nn.Embedding(text_token_count, embed_count)
119
+ self.layers: List[EncoderLayer] = nn.ModuleList([
120
+ EncoderLayer(
121
+ embed_count = embed_count,
122
+ head_count = attention_head_count,
123
+ glu_embed_count = glu_embed_count
124
+ )
125
+ for _ in range(layer_count)
126
+ ])
127
+ self.layernorm_embedding = nn.LayerNorm(embed_count)
128
+ self.final_ln = nn.LayerNorm(embed_count)
129
+ token_indices = torch.arange(text_token_count, device=device)
130
+ self.pose_tokens = torch.stack([token_indices] * 2)
131
+
132
+ def forward(self, text_tokens: LongTensor) -> FloatTensor:
133
+ attention_mask = text_tokens.not_equal(1)
134
+ encoder_state = (
135
+ self.embed_tokens.forward(text_tokens) +
136
+ self.embed_positions.forward(self.pose_tokens)
137
+ )
138
+ encoder_state = self.layernorm_embedding.forward(encoder_state)
139
+ for layer in self.layers:
140
+ encoder_state = layer.forward(encoder_state, attention_mask)
141
+ encoder_state = self.final_ln.forward(encoder_state)
142
+ return encoder_state