Upload modeling.py

modeling.py (ADDED)
## Embedding Layer

import torch
import torch.nn as nn
from einops import rearrange

class Embedding(nn.Module):

    def __init__(self,
                 vocab_size : int = 50265,      ## RoBERTa's tokenizer.vocab_size -> 50265
                 hidden_dim_t : int = 768,      ## hidden_dim_text -> 768
                 hidden_dim_l : int = 768 // 6, ## hidden_dim_layout -> 768 // 6 for each of the 6 coordinates
                 max_x_coord : int = 1001,      ## X coordinate ranges from 0 to 1000
                 max_y_coord : int = 1001,      ## Y coordinate ranges from 0 to 1000
                 max_seq_len_t : int = 512,
                 max_seq_len_l : int = 512):

        super(Embedding, self).__init__()
        self.lang_embedding = nn.Embedding(
            num_embeddings = vocab_size,
            embedding_dim = hidden_dim_t
        )

        self.top_left_x_emb = nn.Embedding(num_embeddings = max_x_coord, embedding_dim = hidden_dim_l)
        self.top_left_y_emb = nn.Embedding(num_embeddings = max_y_coord, embedding_dim = hidden_dim_l)
        self.bottom_right_x_emb = nn.Embedding(num_embeddings = max_x_coord, embedding_dim = hidden_dim_l)
        self.bottom_right_y_emb = nn.Embedding(num_embeddings = max_y_coord, embedding_dim = hidden_dim_l)
        self.width_emb = nn.Embedding(num_embeddings = max_x_coord, embedding_dim = hidden_dim_l)
        self.height_emb = nn.Embedding(num_embeddings = max_y_coord, embedding_dim = hidden_dim_l)

        self.box_position_embeddings = nn.Embedding(num_embeddings = max_seq_len_l + 1, embedding_dim = 6 * hidden_dim_l)
        self.textual_position_embeddings = nn.Embedding(num_embeddings = max_seq_len_t + 1, embedding_dim = hidden_dim_t)

        # ## Layer normalization would be added as pre-normalization and post-normalization
        # self.ln_t = nn.LayerNorm(normalized_shape = hidden_dim_t)
        # self.ln_l = nn.LayerNorm(normalized_shape = 6 * hidden_dim_l)

    def forward(self, tokenized_words, tokenized_bbox):

        ## Generating position ids
        text_len, box_len = tokenized_words.shape[1], tokenized_bbox.shape[1]
        word_pos_ids = torch.arange(text_len).unsqueeze(0).to(tokenized_words.device)
        box_pos_ids = torch.arange(box_len).unsqueeze(0).to(tokenized_bbox.device)

        ## Using the embedding tables to extract the corresponding features
        text_feature = self.lang_embedding(tokenized_words)
        top_left_x_feat = self.top_left_x_emb(tokenized_bbox[:, :, 0])
        top_left_y_feat = self.top_left_y_emb(tokenized_bbox[:, :, 1])
        bottom_right_x_feat = self.bottom_right_x_emb(tokenized_bbox[:, :, 2])
        bottom_right_y_feat = self.bottom_right_y_emb(tokenized_bbox[:, :, 3])
        width_feat = self.width_emb(tokenized_bbox[:, :, 4])
        height_feat = self.height_emb(tokenized_bbox[:, :, 5])

        ## Layout feature: concatenation of the six coordinate embeddings
        layout_feature = torch.cat(
            [top_left_x_feat,
             top_left_y_feat,
             bottom_right_x_feat,
             bottom_right_y_feat,
             width_feat,
             height_feat
            ],
            dim = -1
        )

        ## Generating positional embeddings
        pos_emb_t = self.textual_position_embeddings(word_pos_ids)
        pos_emb_l = self.box_position_embeddings(box_pos_ids)

        ## Adding the positional encodings
        layout_feature = layout_feature + pos_emb_l
        text_feature = text_feature + pos_emb_t

        # ## Layer normalization is applied in the encoder part
        # layout_feature = self.ln_l(layout_feature)
        # text_feature = self.ln_t(text_feature)

        return {'layout_feature': layout_feature, 'text_feature': text_feature}
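
## A minimal usage sketch for the embedding layer (illustrative only, not part
## of the original upload; batch size and sequence length are arbitrary):
def _demo_embedding():
    emb = Embedding()
    words = torch.randint(0, 50265, (2, 16))    ## (batch, seq_len) token ids
    bbox = torch.randint(0, 1001, (2, 16, 6))   ## (batch, seq_len, 6) box coords in [0, 1000]
    out = emb(words, bbox)
    print(out['text_feature'].shape)            ## torch.Size([2, 16, 768])
    print(out['layout_feature'].shape)          ## torch.Size([2, 16, 768]), i.e. 6 * 128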

## Attention Layer

## Reference: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
class MultiModalAttentionLayer(nn.Module):

    def __init__(self, embed_dim : int = 768,
                 n_heads : int = 12,
                 dim_head : int = 64,
                 fine_tune : bool = False,
                 dropout : float = 0.0
                 ):
        super(MultiModalAttentionLayer, self).__init__()

        inner_dim = n_heads * dim_head
        self.n_heads = n_heads
        self.fine_tune = fine_tune

        self.proj_text_k = nn.Linear(in_features = embed_dim, out_features = inner_dim) ## 768 -> 12 * 64 = 768
        self.proj_text_q = nn.Linear(in_features = embed_dim, out_features = inner_dim)
        self.proj_text_v = nn.Linear(in_features = embed_dim, out_features = inner_dim)

        self.proj_layout_k = nn.Linear(in_features = embed_dim, out_features = inner_dim)
        self.proj_layout_q = nn.Linear(in_features = embed_dim, out_features = inner_dim)
        self.proj_layout_v = nn.Linear(in_features = embed_dim, out_features = inner_dim)

        self.attend = nn.Softmax(dim = -1)
        self.scale = dim_head ** -0.5

        self.dropout = nn.Dropout(dropout)
        self.to_out_l = nn.Sequential(
            nn.Linear(inner_dim, embed_dim),
            nn.Dropout(dropout)
        )
        self.to_out_t = nn.Sequential(
            nn.Linear(inner_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, text_feature, layout_feature):

        query_vec_t = rearrange(self.proj_text_q(text_feature), 'b t (head k) -> head b t k', head = self.n_heads) ## batch, 512, 768 -> 12, batch, 512, 64
        key_vec_t = rearrange(self.proj_text_k(text_feature), 'b t (head k) -> head b t k', head = self.n_heads)
        value_vec_t = rearrange(self.proj_text_v(text_feature), 'b t (head k) -> head b t k', head = self.n_heads)

        query_vec_l = rearrange(self.proj_layout_q(layout_feature), 'b t (head k) -> head b t k', head = self.n_heads)
        key_vec_l = rearrange(self.proj_layout_k(layout_feature), 'b t (head k) -> head b t k', head = self.n_heads)
        value_vec_l = rearrange(self.proj_layout_v(layout_feature), 'b t (head k) -> head b t k', head = self.n_heads)

        attn_t = torch.einsum('hblk,hbtk->hblt', query_vec_t, key_vec_t) * self.scale
        attn_l = torch.einsum('hblk,hbtk->hblt', query_vec_l, key_vec_l) * self.scale

        ## Each modality's attention scores are complemented with the other's
        attn_tilde_t = attn_t + attn_l

        if self.fine_tune:
            attn_tilde_l = attn_l + attn_t
        else:
            ## During pre-training the textual scores are detached, so gradients
            ## from the layout stream do not flow back into the text projections
            attn_tilde_l = attn_l + attn_t.detach()

        text_attn_probs = self.dropout(self.attend(attn_tilde_t))
        layout_attn_probs = self.dropout(self.attend(attn_tilde_l))

        text_context = rearrange(torch.einsum('hblt,hbtv->hblv', text_attn_probs, value_vec_t), 'h b l k -> b l (h k)')
        layout_context = rearrange(torch.einsum('hblt,hbtv->hblv', layout_attn_probs, value_vec_l), 'h b l k -> b l (h k)')

        text_context = self.to_out_t(text_context)
        layout_context = self.to_out_l(layout_context)

        return {'layout_feature': layout_context, 'text_feature': text_context,
                'layout_attention': attn_l, 'textual_attention': attn_t}
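
## Shape check for the cross-modality attention (an illustrative sketch, not
## part of the original upload):
def _demo_attention():
    attn = MultiModalAttentionLayer()
    text = torch.randn(2, 16, 768)
    layout = torch.randn(2, 16, 768)
    out = attn(text, layout)
    print(out['text_feature'].shape)            ## torch.Size([2, 16, 768])
    print(out['textual_attention'].shape)       ## torch.Size([12, 2, 16, 16]) -> (head, batch, query, key)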

## Constructing the Encoder Layer

class PreNorm(nn.Module):
    def __init__(self, dim, fn, eps = 1e-12):
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps = eps)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class PreNormAttn(nn.Module):
    def __init__(self, dim, fn, eps = 1e-12):
        super().__init__()

        self.norm_t = nn.LayerNorm(dim, eps = eps)
        self.norm_l = nn.LayerNorm(dim, eps = eps)
        self.fn = fn

    def forward(self, text_feat, layout_feat, **kwargs):
        return self.fn(self.norm_t(text_feat),
                       self.norm_l(layout_feat), **kwargs)

## FFN Network
class FeedForward(nn.Module):
    def __init__(self, dim : int = 768, hidden_dim : int = 4 * 768, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
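
## Pre-norm composition in isolation (an illustrative sketch, not part of the
## original upload): LayerNorm is applied before the wrapped module, and the
## residual connection is added by the caller, as in LiLTEncoder.forward below.
def _demo_prenorm():
    block = PreNorm(768, FeedForward(768, 4 * 768))
    x = torch.randn(2, 16, 768)
    y = block(x) + x
    print(y.shape)                              ## torch.Size([2, 16, 768])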

## Encoder
class LiLTEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([])
        for _ in range(config['num_hidden_layers']):
            encoder_block = nn.ModuleList([
                PreNormAttn(config['hidden_size'],
                            MultiModalAttentionLayer(embed_dim = config['hidden_size'],
                                                     n_heads = config['num_attention_heads'],
                                                     dim_head = config['dim_head'],
                                                     fine_tune = config['fine_tune'],
                                                     dropout = config['hidden_dropout_prob'],
                                                     ),
                            eps = config['eps']
                            ),
                PreNorm(config['hidden_size'],
                        FeedForward(config['hidden_size'],
                                    config['hidden_size'] * config['intermediate_ff_size_factor'],
                                    dropout = config['hidden_dropout_prob'],
                                    ),
                        eps = config['eps']),
                PreNorm(config['hidden_size'],
                        FeedForward(config['hidden_size'],
                                    config['hidden_size'] * config['intermediate_ff_size_factor'],
                                    dropout = config['hidden_dropout_prob']
                                    ),
                        eps = config['eps'])
            ])
            self.layers.append(encoder_block)

    def forward(self, text_feat, layout_feat):

        text_attn = []
        layout_attn = []
        text_hidden_states = []
        layout_hidden_states = []

        for attn, ff_t, ff_l in self.layers:

            context_vec = attn(text_feat, layout_feat)
            text_feat = text_feat + context_vec['text_feature']
            layout_feat = layout_feat + context_vec['layout_feature']

            text_feat = ff_t(text_feat) + text_feat
            layout_feat = ff_l(layout_feat) + layout_feat

            text_attn.append(context_vec['textual_attention'])
            layout_attn.append(context_vec['layout_attention'])
            text_hidden_states.append(text_feat)
            layout_hidden_states.append(layout_feat)

        return {'text_hidden_states': text_hidden_states, 'layout_hidden_states': layout_hidden_states,
                'text_attn': text_attn, 'layout_attn': layout_attn}
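
## The config keys consumed by LiLTEncoder and LiLT, collected in one example
## dict (the values are base-size assumptions, not part of the original upload):
_example_config = {
    'vocab_size': 50265,
    'hidden_size': 768,
    'hidden_size_t': 768,
    'hidden_size_l': 768 // 6,
    'num_hidden_layers': 12,
    'num_attention_heads': 12,
    'dim_head': 64,
    'fine_tune': False,
    'hidden_dropout_prob': 0.1,
    'intermediate_ff_size_factor': 4,
    'eps': 1e-12,
    'max_2d_position_embeddings': 1001,
    'max_seq_len_t': 512,
    'max_seq_len_l': 512,
}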

## Constructing the whole model from embeddings to the hidden states and attention
class LiLT(nn.Module):

    def __init__(self, config):
        super(LiLT, self).__init__()
        self.lilt = LiLTEncoder(config)
        self.emb = Embedding(vocab_size = config['vocab_size'],
                             hidden_dim_t = config['hidden_size_t'],
                             hidden_dim_l = config['hidden_size_l'],
                             max_x_coord = config['max_2d_position_embeddings'],
                             max_y_coord = config['max_2d_position_embeddings'],
                             max_seq_len_t = config['max_seq_len_t'],
                             max_seq_len_l = config['max_seq_len_l'])

    def forward(self, tokenized_words, tokenized_bbox):
        hidden_enc = self.emb(tokenized_words, tokenized_bbox)
        encodings = self.lilt(hidden_enc['text_feature'], hidden_enc['layout_feature'])
        return encodings
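
## End-to-end smoke test (a minimal sketch built on the illustrative config above):
if __name__ == '__main__':
    model = LiLT(_example_config)
    words = torch.randint(0, _example_config['vocab_size'], (2, 16))
    bbox = torch.randint(0, _example_config['max_2d_position_embeddings'], (2, 16, 6))
    out = model(words, bbox)
    print(out['text_hidden_states'][-1].shape)      ## torch.Size([2, 16, 768])
    print(out['layout_hidden_states'][-1].shape)    ## torch.Size([2, 16, 768])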