iakarshu committed
Commit d7f6f38
Parent(s): cfa10c1

Upload modeling.py

Files changed (1)
  1. modeling.py +273 -0
modeling.py ADDED
@@ -0,0 +1,273 @@
## Embedding Layer

import torch.nn as nn
import torch
from einops import rearrange

class Embedding(nn.Module):

    def __init__(self,
                 vocab_size : int = 50265,      ## RoBERTa's tokenizer.vocab_size -> 50265
                 hidden_dim_t : int = 768,      ## hidden_dim_text -> 768
                 hidden_dim_l : int = 768 // 6, ## hidden_dim_layout -> 768 // 6 for each of the 6 coordinates
                 max_x_coord : int = 1001,      ## X coordinate ranges from 0 to 1000
                 max_y_coord : int = 1001,      ## Y coordinate ranges from 0 to 1000
                 max_seq_len_t : int = 512,
                 max_seq_len_l : int = 512):

        super(Embedding, self).__init__()
        self.lang_embedding = nn.Embedding(
            num_embeddings = vocab_size,
            embedding_dim = hidden_dim_t
        )

        self.top_left_x_emb = nn.Embedding(num_embeddings = max_x_coord, embedding_dim = hidden_dim_l)
        self.top_left_y_emb = nn.Embedding(num_embeddings = max_y_coord, embedding_dim = hidden_dim_l)
        self.bottom_right_x_emb = nn.Embedding(num_embeddings = max_x_coord, embedding_dim = hidden_dim_l)
        self.bottom_right_y_emb = nn.Embedding(num_embeddings = max_y_coord, embedding_dim = hidden_dim_l)
        self.width_emb = nn.Embedding(num_embeddings = max_x_coord, embedding_dim = hidden_dim_l)
        self.height_emb = nn.Embedding(num_embeddings = max_y_coord, embedding_dim = hidden_dim_l)

        self.box_position_embeddings = nn.Embedding(num_embeddings = max_seq_len_l + 1, embedding_dim = 6 * hidden_dim_l)
        self.textual_position_embeddings = nn.Embedding(num_embeddings = max_seq_len_t + 1, embedding_dim = hidden_dim_t)

        # ## Layer normalization, to be added as pre-normalization and post-normalization in the encoder
        # self.ln_t = nn.LayerNorm(normalized_shape = hidden_dim_t)
        # self.ln_l = nn.LayerNorm(normalized_shape = 6 * hidden_dim_l)

    def forward(self, tokenized_words, tokenized_bbox):

        ## Generating position ids
        text_len, box_len = tokenized_words.shape[1], tokenized_bbox.shape[1]
        word_pos_ids = torch.arange(text_len).unsqueeze(0).to(tokenized_words.device)
        box_pos_ids = torch.arange(box_len).unsqueeze(0).to(tokenized_bbox.device)

        ## Using the embedding tables to extract the corresponding features
        text_feature = self.lang_embedding(tokenized_words)
        top_left_x_feat = self.top_left_x_emb(tokenized_bbox[:, :, 0])
        top_left_y_feat = self.top_left_y_emb(tokenized_bbox[:, :, 1])
        bottom_right_x_feat = self.bottom_right_x_emb(tokenized_bbox[:, :, 2])
        bottom_right_y_feat = self.bottom_right_y_emb(tokenized_bbox[:, :, 3])
        width_feat = self.width_emb(tokenized_bbox[:, :, 4])
        height_feat = self.height_emb(tokenized_bbox[:, :, 5])

        ## Layout feature: concatenation of the six coordinate embeddings
        layout_feature = torch.cat(
            [top_left_x_feat,
             top_left_y_feat,
             bottom_right_x_feat,
             bottom_right_y_feat,
             width_feat,
             height_feat
             ],
            dim = -1
        )

        ## Generating the positional embeddings
        pos_emb_t = self.textual_position_embeddings(word_pos_ids)
        pos_emb_l = self.box_position_embeddings(box_pos_ids)

        ## Adding the positional encodings
        layout_feature = layout_feature + pos_emb_l
        text_feature = text_feature + pos_emb_t

        # ## Layer normalization is applied in the encoder part, so it is not added here
        # layout_feature = self.ln_l(layout_feature)
        # text_feature = self.ln_t(text_feature)

        return {'layout_feature': layout_feature, 'text_feature': text_feature}

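## Hedged usage sketch (illustrative addition, not part of the uploaded file): a hypothetical
## helper showing the input/output shapes Embedding expects, using dummy tensors.
def _demo_embedding():
    emb = Embedding()
    words = torch.randint(0, 50265, (2, 512))    ## dummy token ids: (batch, seq_len)
    boxes = torch.randint(0, 1001, (2, 512, 6))  ## dummy boxes: (batch, seq_len, [x0, y0, x1, y1, w, h])
    out = emb(words, boxes)
    print(out['text_feature'].shape)    ## torch.Size([2, 512, 768])
    print(out['layout_feature'].shape)  ## torch.Size([2, 512, 768]), i.e. 6 * 128
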
## Attention Layer

## Reference: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
class MultiModalAttentionLayer(nn.Module):

    def __init__(self, embed_dim : int = 768,
                 n_heads : int = 12,
                 dim_head : int = 64,
                 fine_tune : bool = False,
                 dropout : float = 0.0
                 ):
        super(MultiModalAttentionLayer, self).__init__()

        inner_dim = n_heads * dim_head
        self.n_heads = n_heads
        self.fine_tune = fine_tune

        self.proj_text_k = nn.Linear(in_features = embed_dim, out_features = inner_dim) ## embed_dim -> inner_dim = n_heads * dim_head (768 -> 768 with the defaults)
        self.proj_text_q = nn.Linear(in_features = embed_dim, out_features = inner_dim)
        self.proj_text_v = nn.Linear(in_features = embed_dim, out_features = inner_dim)

        self.proj_layout_k = nn.Linear(in_features = embed_dim, out_features = inner_dim)
        self.proj_layout_q = nn.Linear(in_features = embed_dim, out_features = inner_dim)
        self.proj_layout_v = nn.Linear(in_features = embed_dim, out_features = inner_dim)

        self.attend = nn.Softmax(dim = -1)
        self.scale = dim_head ** -0.5

        self.dropout = nn.Dropout(dropout)
        self.to_out_l = nn.Sequential(
            nn.Linear(inner_dim, embed_dim),
            nn.Dropout(dropout)
        )
        self.to_out_t = nn.Sequential(
            nn.Linear(inner_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, text_feature, layout_feature):

        query_vec_t = rearrange(self.proj_text_q(text_feature), 'b t (head k) -> head b t k', head = self.n_heads) ## (batch, 512, 768) -> (12, batch, 512, 64)
        key_vec_t = rearrange(self.proj_text_k(text_feature), 'b t (head k) -> head b t k', head = self.n_heads)
        value_vec_t = rearrange(self.proj_text_v(text_feature), 'b t (head k) -> head b t k', head = self.n_heads)

        query_vec_l = rearrange(self.proj_layout_q(layout_feature), 'b t (head k) -> head b t k', head = self.n_heads)
        key_vec_l = rearrange(self.proj_layout_k(layout_feature), 'b t (head k) -> head b t k', head = self.n_heads)
        value_vec_l = rearrange(self.proj_layout_v(layout_feature), 'b t (head k) -> head b t k', head = self.n_heads)

        attn_t = torch.einsum('hblk,hbtk->hblt', query_vec_t, key_vec_t) * self.scale
        attn_l = torch.einsum('hblk,hbtk->hblt', query_vec_l, key_vec_l) * self.scale

        ## Each stream adds the other stream's attention scores; during pre-training (fine_tune = False)
        ## the text scores are detached before being added to the layout scores, so gradients from the
        ## layout branch do not flow back into the text branch through this path.
        attn_tilde_t = attn_t + attn_l

        if self.fine_tune:
            attn_tilde_l = attn_l + attn_t
        else:
            attn_tilde_l = attn_l + attn_t.detach()

        text_attn_probs = self.dropout(self.attend(attn_tilde_t))
        layout_attn_probs = self.dropout(self.attend(attn_tilde_l))

        text_context = rearrange(torch.einsum('hblt,hbtv->hblv', text_attn_probs, value_vec_t), 'h b l k -> b l (h k)')
        layout_context = rearrange(torch.einsum('hblt,hbtv->hblv', layout_attn_probs, value_vec_l), 'h b l k -> b l (h k)')

        text_context = self.to_out_t(text_context)
        layout_context = self.to_out_l(layout_context)

        return {'layout_feature': layout_context, 'text_feature': text_context,
                'layout_attention': attn_l, 'textual_attention': attn_t}

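## Hedged usage sketch (illustrative addition, not part of the uploaded file): a hypothetical
## helper checking the shapes produced by MultiModalAttentionLayer on random features.
def _demo_attention():
    layer = MultiModalAttentionLayer(embed_dim = 768, n_heads = 12, dim_head = 64)
    text = torch.randn(2, 512, 768)
    layout = torch.randn(2, 512, 768)
    out = layer(text, layout)
    print(out['text_feature'].shape)       ## torch.Size([2, 512, 768])
    print(out['textual_attention'].shape)  ## torch.Size([12, 2, 512, 512]) -> (heads, batch, query, key)
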
## Constructing the Encoder Layer

class PreNorm(nn.Module):
    def __init__(self, dim, fn, eps = 1e-12):
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps = eps)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class PreNormAttn(nn.Module):
    def __init__(self, dim, fn, eps = 1e-12):
        super().__init__()

        self.norm_t = nn.LayerNorm(dim, eps = eps)
        self.norm_l = nn.LayerNorm(dim, eps = eps)
        self.fn = fn

    def forward(self, text_feat, layout_feat, **kwargs):
        return self.fn(self.norm_t(text_feat),
                       self.norm_l(layout_feat), **kwargs)


## FFN Network
class FeedForward(nn.Module):
    def __init__(self, dim : int = 768, hidden_dim : int = 4 * 768, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

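## Hedged usage sketch (illustrative addition, not part of the uploaded file): a hypothetical
## helper showing the pre-norm + residual pattern that LiLTEncoder applies around FeedForward.
def _demo_prenorm_ffn():
    block = PreNorm(768, FeedForward(dim = 768, hidden_dim = 4 * 768, dropout = 0.1), eps = 1e-12)
    x = torch.randn(2, 512, 768)
    y = block(x) + x   ## normalize -> feed-forward -> residual connection
    print(y.shape)     ## torch.Size([2, 512, 768])
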
## Encoder
class LiLTEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([])
        for _ in range(config['num_hidden_layers']):
            encoder_block = nn.ModuleList([
                PreNormAttn(config['hidden_size'],
                            MultiModalAttentionLayer(embed_dim = config['hidden_size'],
                                                     n_heads = config['num_attention_heads'],
                                                     dim_head = config['dim_head'],
                                                     fine_tune = config['fine_tune'],
                                                     dropout = config['hidden_dropout_prob'],
                                                     ),
                            eps = config['eps']
                            ),
                PreNorm(config['hidden_size'],
                        FeedForward(config['hidden_size'],
                                    config['hidden_size'] * config['intermediate_ff_size_factor'],
                                    dropout = config['hidden_dropout_prob'],
                                    ),
                        eps = config['eps']),
                PreNorm(config['hidden_size'],
                        FeedForward(config['hidden_size'],
                                    config['hidden_size'] * config['intermediate_ff_size_factor'],
                                    dropout = config['hidden_dropout_prob']
                                    ),
                        eps = config['eps'])
            ])
            self.layers.append(encoder_block)

    def forward(
        self,
        text_feat,
        layout_feat,
    ):

        text_attn = []
        layout_attn = []
        text_hidden_states = []
        layout_hidden_states = []

        for attn, ff_t, ff_l in self.layers:

            ## Cross-modality attention with residual connections
            context_vec = attn(text_feat, layout_feat)
            text_feat = text_feat + context_vec['text_feature']
            layout_feat = layout_feat + context_vec['layout_feature']

            ## Separate feed-forward blocks (with residuals) for each modality
            text_feat = ff_t(text_feat) + text_feat
            layout_feat = ff_l(layout_feat) + layout_feat

            text_attn.append(context_vec['textual_attention'])
            layout_attn.append(context_vec['layout_attention'])
            text_hidden_states.append(text_feat)
            layout_hidden_states.append(layout_feat)

        return {'text_hidden_states': text_hidden_states, 'layout_hidden_states': layout_hidden_states,
                'text_attn': text_attn, 'layout_attn': layout_attn}

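## Hedged example config (illustrative addition, not part of the uploaded file): the keys below are
## exactly those read by LiLTEncoder above and by LiLT below; the values mirror the defaults used
## throughout this file and are only an assumption of a reasonable base configuration.
EXAMPLE_CONFIG = {
    'vocab_size': 50265,                 ## RoBERTa tokenizer vocabulary
    'hidden_size': 768,                  ## shared width of the text and layout streams
    'hidden_size_t': 768,                ## textual embedding dimension
    'hidden_size_l': 768 // 6,           ## per-coordinate layout embedding dimension
    'num_hidden_layers': 12,
    'num_attention_heads': 12,
    'dim_head': 64,
    'hidden_dropout_prob': 0.1,
    'intermediate_ff_size_factor': 4,
    'fine_tune': False,
    'eps': 1e-12,
    'max_2d_position_embeddings': 1001,  ## coordinates range from 0 to 1000
    'max_seq_len_t': 512,
    'max_seq_len_l': 512,
}
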
## Constructing the whole model from embeddings to the hidden states and attention
class LiLT(nn.Module):

    def __init__(self, config):
        super(LiLT, self).__init__()
        self.lilt = LiLTEncoder(config)
        self.emb = Embedding(vocab_size = config['vocab_size'],
                             hidden_dim_t = config['hidden_size_t'],
                             hidden_dim_l = config['hidden_size_l'],
                             max_x_coord = config['max_2d_position_embeddings'],
                             max_y_coord = config['max_2d_position_embeddings'],
                             max_seq_len_t = config['max_seq_len_t'],
                             max_seq_len_l = config['max_seq_len_l'])

    def forward(self, tokenized_words, tokenized_bbox):
        hidden_enc = self.emb(tokenized_words, tokenized_bbox)
        encodings = self.lilt(hidden_enc['text_feature'], hidden_enc['layout_feature'])
        return encodings
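
## Hedged end-to-end sketch (illustrative addition, not part of the uploaded file): builds the
## full model from the EXAMPLE_CONFIG sketched above and runs dummy inputs through it.
def _demo_lilt():
    model = LiLT(EXAMPLE_CONFIG)
    words = torch.randint(0, EXAMPLE_CONFIG['vocab_size'], (2, 512))
    boxes = torch.randint(0, EXAMPLE_CONFIG['max_2d_position_embeddings'], (2, 512, 6))
    out = model(words, boxes)
    print(len(out['text_hidden_states']))         ## 12, one entry per encoder layer
    print(out['text_hidden_states'][-1].shape)    ## torch.Size([2, 512, 768])
    print(out['layout_hidden_states'][-1].shape)  ## torch.Size([2, 512, 768])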