"""PyTorch implementation of a LiLT-style model: a word-embedding / box-embedding
front-end (Embedding) feeding a two-stream Transformer encoder whose text and
layout attention scores are shared across modalities (MultiModalAttentionLayer)."""

import torch
import torch.nn as nn
from einops import rearrange

class Embedding(nn.Module):
    """Builds the initial text-stream and layout-stream features from token ids
    and per-token bounding boxes."""

    def __init__(self,
                 vocab_size: int = 50265,
                 hidden_dim_t: int = 768,
                 hidden_dim_l: int = 768 // 6,
                 max_x_coord: int = 1001,
                 max_y_coord: int = 1001,
                 max_seq_len_t: int = 512,
                 max_seq_len_l: int = 512):
        super().__init__()

        # Word-piece embedding for the textual stream.
        self.lang_embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_dim_t)

        # One lookup table per box coordinate; the six features are concatenated
        # in forward() into a layout vector of size 6 * hidden_dim_l.
        self.top_left_x_emb = nn.Embedding(num_embeddings=max_x_coord, embedding_dim=hidden_dim_l)
        self.top_left_y_emb = nn.Embedding(num_embeddings=max_y_coord, embedding_dim=hidden_dim_l)
        self.bottom_right_x_emb = nn.Embedding(num_embeddings=max_x_coord, embedding_dim=hidden_dim_l)
        self.bottom_right_y_emb = nn.Embedding(num_embeddings=max_y_coord, embedding_dim=hidden_dim_l)
        self.width_emb = nn.Embedding(num_embeddings=max_x_coord, embedding_dim=hidden_dim_l)
        self.height_emb = nn.Embedding(num_embeddings=max_y_coord, embedding_dim=hidden_dim_l)

        # 1D position embeddings for each stream.
        self.box_position_embeddings = nn.Embedding(num_embeddings=max_seq_len_l + 1, embedding_dim=6 * hidden_dim_l)
        self.textual_position_embeddings = nn.Embedding(num_embeddings=max_seq_len_t + 1, embedding_dim=hidden_dim_t)

    def forward(self, tokenized_words, tokenized_bbox):
        # tokenized_words: (batch, text_len) token ids.
        # tokenized_bbox:  (batch, box_len, 6) integer coords (x0, y0, x1, y1, w, h).
        text_len, box_len = tokenized_words.shape[1], tokenized_bbox.shape[1]
        word_pos_ids = torch.arange(text_len, device=tokenized_words.device).unsqueeze(0)
        box_pos_ids = torch.arange(box_len, device=tokenized_bbox.device).unsqueeze(0)

        text_feature = self.lang_embedding(tokenized_words)

        top_left_x_feat = self.top_left_x_emb(tokenized_bbox[:, :, 0])
        top_left_y_feat = self.top_left_y_emb(tokenized_bbox[:, :, 1])
        bottom_right_x_feat = self.bottom_right_x_emb(tokenized_bbox[:, :, 2])
        bottom_right_y_feat = self.bottom_right_y_emb(tokenized_bbox[:, :, 3])
        width_feat = self.width_emb(tokenized_bbox[:, :, 4])
        height_feat = self.height_emb(tokenized_bbox[:, :, 5])

        # Concatenate the six coordinate embeddings into one layout vector per box.
        layout_feature = torch.cat(
            [top_left_x_feat,
             top_left_y_feat,
             bottom_right_x_feat,
             bottom_right_y_feat,
             width_feat,
             height_feat],
            dim=-1
        )

        # Add 1D position embeddings to each stream.
        pos_emb_t = self.textual_position_embeddings(word_pos_ids)
        pos_emb_l = self.box_position_embeddings(box_pos_ids)

        layout_feature = layout_feature + pos_emb_l
        text_feature = text_feature + pos_emb_t

        return {'layout_feature': layout_feature, 'text_feature': text_feature}


class MultiModalAttentionLayer(nn.Module):
    """Self-attention computed separately for the text and layout streams, with the
    raw attention scores of each stream added to the other before softmax."""

    def __init__(self,
                 embed_dim: int = 768,
                 n_heads: int = 12,
                 dim_head: int = 64,
                 fine_tune: bool = False,
                 dropout: float = 0.0):
        super().__init__()

        inner_dim = n_heads * dim_head
        self.n_heads = n_heads
        self.fine_tune = fine_tune

        # Query / key / value projections for the textual stream.
        self.proj_text_k = nn.Linear(in_features=embed_dim, out_features=inner_dim)
        self.proj_text_q = nn.Linear(in_features=embed_dim, out_features=inner_dim)
        self.proj_text_v = nn.Linear(in_features=embed_dim, out_features=inner_dim)

        # Query / key / value projections for the layout stream.
        self.proj_layout_k = nn.Linear(in_features=embed_dim, out_features=inner_dim)
        self.proj_layout_q = nn.Linear(in_features=embed_dim, out_features=inner_dim)
        self.proj_layout_v = nn.Linear(in_features=embed_dim, out_features=inner_dim)

        self.attend = nn.Softmax(dim=-1)
        self.scale = dim_head ** -0.5

        self.dropout = nn.Dropout(dropout)
        self.to_out_l = nn.Sequential(
            nn.Linear(inner_dim, embed_dim),
            nn.Dropout(dropout)
        )
        self.to_out_t = nn.Sequential(
            nn.Linear(inner_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, text_feature, layout_feature):
        # Split the projected Q/K/V tensors into heads: (batch, seq, h*k) -> (h, batch, seq, k).
        query_vec_t = rearrange(self.proj_text_q(text_feature), 'b t (head k) -> head b t k', head=self.n_heads)
        key_vec_t = rearrange(self.proj_text_k(text_feature), 'b t (head k) -> head b t k', head=self.n_heads)
        value_vec_t = rearrange(self.proj_text_v(text_feature), 'b t (head k) -> head b t k', head=self.n_heads)

        query_vec_l = rearrange(self.proj_layout_q(layout_feature), 'b t (head k) -> head b t k', head=self.n_heads)
        key_vec_l = rearrange(self.proj_layout_k(layout_feature), 'b t (head k) -> head b t k', head=self.n_heads)
        value_vec_l = rearrange(self.proj_layout_v(layout_feature), 'b t (head k) -> head b t k', head=self.n_heads)

        # Scaled dot-product attention scores for each stream.
        attn_t = torch.einsum('hblk,hbtk->hblt', query_vec_t, key_vec_t) * self.scale
        attn_l = torch.einsum('hblk,hbtk->hblt', query_vec_l, key_vec_l) * self.scale

        # Cross-modality score sharing: each stream attends with the sum of both score maps.
        attn_tilde_t = attn_t + attn_l

        if self.fine_tune:
            attn_tilde_l = attn_l + attn_t
        else:
            # During pre-training the textual scores are detached before being added to the
            # layout stream, so gradients from the layout flow do not update the text flow.
            attn_tilde_l = attn_l + attn_t.detach()

        text_attn_probs = self.dropout(self.attend(attn_tilde_t))
        layout_attn_probs = self.dropout(self.attend(attn_tilde_l))

        # Weighted sum of values, then merge heads back: (h, b, seq, k) -> (b, seq, h*k).
        text_context = rearrange(torch.einsum('hblt,hbtv->hblv', text_attn_probs, value_vec_t), 'h b l k -> b l (h k)')
        layout_context = rearrange(torch.einsum('hblt,hbtv->hblv', layout_attn_probs, value_vec_l), 'h b l k -> b l (h k)')

        text_context = self.to_out_t(text_context)
        layout_context = self.to_out_l(layout_context)

        return {'layout_feature': layout_context, 'text_feature': text_context,
                'layout_attention': attn_l, 'textual_attention': attn_t}


class PreNorm(nn.Module):
    """Applies LayerNorm to the input before calling the wrapped module."""

    def __init__(self, dim, fn, eps=1e-12):
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps=eps)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class PreNormAttn(nn.Module):
    """Pre-norm wrapper for the two-stream attention: each stream gets its own LayerNorm."""

    def __init__(self, dim, fn, eps=1e-12):
        super().__init__()
        self.norm_t = nn.LayerNorm(dim, eps=eps)
        self.norm_l = nn.LayerNorm(dim, eps=eps)
        self.fn = fn

    def forward(self, text_feat, layout_feat, **kwargs):
        return self.fn(self.norm_t(text_feat),
                       self.norm_l(layout_feat), **kwargs)


class FeedForward(nn.Module):
    """Position-wise MLP with GELU activation and dropout."""

    def __init__(self, dim: int = 768, hidden_dim: int = 4 * 768, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class LiLTEncoder(nn.Module):
    """Stack of encoder blocks; each block holds one shared attention layer and one
    feed-forward network per stream."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([])
        for _ in range(config['num_hidden_layers']):
            encoder_block = nn.ModuleList([
                PreNormAttn(config['hidden_size'],
                            MultiModalAttentionLayer(embed_dim=config['hidden_size'],
                                                     n_heads=config['num_attention_heads'],
                                                     dim_head=config['dim_head'],
                                                     fine_tune=config['fine_tune'],
                                                     dropout=config['hidden_dropout_prob']),
                            eps=config['eps']),
                # Text-stream feed-forward.
                PreNorm(config['hidden_size'],
                        FeedForward(config['hidden_size'],
                                    config['hidden_size'] * config['intermediate_ff_size_factor'],
                                    dropout=config['hidden_dropout_prob']),
                        eps=config['eps']),
                # Layout-stream feed-forward.
                PreNorm(config['hidden_size'],
                        FeedForward(config['hidden_size'],
                                    config['hidden_size'] * config['intermediate_ff_size_factor'],
                                    dropout=config['hidden_dropout_prob']),
                        eps=config['eps'])
            ])
            self.layers.append(encoder_block)

    def forward(self, text_feat, layout_feat):
        text_attn = []
        layout_attn = []
        text_hidden_states = []
        layout_hidden_states = []

        for attn, ff_t, ff_l in self.layers:
            # Shared-attention sub-layer with residual connections on both streams.
            context_vec = attn(text_feat, layout_feat)
            text_feat = text_feat + context_vec['text_feature']
            layout_feat = layout_feat + context_vec['layout_feature']

            # Per-stream feed-forward sub-layers with residual connections.
            text_feat = ff_t(text_feat) + text_feat
            layout_feat = ff_l(layout_feat) + layout_feat

            # Keep per-layer attention maps and hidden states for inspection.
            text_attn.append(context_vec['textual_attention'])
            layout_attn.append(context_vec['layout_attention'])
            text_hidden_states.append(text_feat)
            layout_hidden_states.append(layout_feat)

        return {'text_hidden_states': text_hidden_states, 'layout_hidden_states': layout_hidden_states,
                'text_attn': text_attn, 'layout_attn': layout_attn}


class LiLT(nn.Module):
    """Full model: embedding front-end followed by the two-stream encoder.
    Note that hidden_size_t and 6 * hidden_size_l are both expected to match
    the encoder's hidden_size so that the two streams share one width."""

    def __init__(self, config):
        super().__init__()
        self.lilt = LiLTEncoder(config)
        self.emb = Embedding(vocab_size=config['vocab_size'],
                             hidden_dim_t=config['hidden_size_t'],
                             hidden_dim_l=config['hidden_size_l'],
                             max_x_coord=config['max_2d_position_embeddings'],
                             max_y_coord=config['max_2d_position_embeddings'],
                             max_seq_len_t=config['max_seq_len_t'],
                             max_seq_len_l=config['max_seq_len_l'])

    def forward(self, tokenized_words, tokenized_bbox):
        hidden_enc = self.emb(tokenized_words, tokenized_bbox)
        encodings = self.lilt(hidden_enc['text_feature'], hidden_enc['layout_feature'])
        return encodings

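# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original model code).
# The config keys mirror the lookups made by LiLTEncoder and LiLT above; the
# concrete values are assumptions for demonstration, not released
# hyper-parameters. The two streams share the encoder width, so hidden_size_t
# and 6 * hidden_size_l must both equal hidden_size.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    config = {
        'vocab_size': 50265,
        'hidden_size': 768,
        'hidden_size_t': 768,
        'hidden_size_l': 768 // 6,
        'max_2d_position_embeddings': 1001,
        'max_seq_len_t': 512,
        'max_seq_len_l': 512,
        'num_hidden_layers': 12,
        'num_attention_heads': 12,
        'dim_head': 64,
        'hidden_dropout_prob': 0.1,
        'intermediate_ff_size_factor': 4,
        'fine_tune': False,
        'eps': 1e-12,
    }

    model = LiLT(config)

    # Dummy batch: token ids and one (x0, y0, x1, y1, w, h) box per token.
    words = torch.randint(0, config['vocab_size'], (2, 128))
    boxes = torch.randint(0, config['max_2d_position_embeddings'], (2, 128, 6))

    out = model(words, boxes)
    print(out['text_hidden_states'][-1].shape)    # torch.Size([2, 128, 768])
    print(out['layout_hidden_states'][-1].shape)  # torch.Size([2, 128, 768])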