import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from einops import rearrange
from torch import Tensor
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.max_len = max_len
        self.d_model = d_model

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self) -> Tensor:
        # Takes no input: returns the full (1, max_len, d_model) table, which callers
        # broadcast over the batch dimension.
        x = self.pe[0, : self.max_len]
        return self.dropout(x).unsqueeze(0)
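
# Illustrative sketch (not part of the original model): a quick shape check for the sinusoidal
# table above. The d_model/max_len values below are assumptions chosen for the demo only.
def _demo_positional_encoding():
    pe = PositionalEncoding(d_model=768, dropout=0.0, max_len=512)
    table = pe()  # (1, 512, 768), ready to be added to a (batch, 512, 768) sequence
    assert table.shape == (1, 512, 768)
    return table
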
class ResNetFeatureExtractor(nn.Module):
    def __init__(self, hidden_dim=512):
        super().__init__()

        # ResNet-50 backbone (used in DocFormer for visual feature extraction),
        # truncated before the average-pool and classification head.
        resnet50 = models.resnet50(pretrained=False)
        modules = list(resnet50.children())[:-2]
        self.resnet50 = nn.Sequential(*modules)

        # 1x1 convolution projects the 2048-channel feature map to 768 channels; the linear
        # layer then maps the flattened spatial positions (it expects 192 of them, e.g. a
        # 16x12 feature map) to a visual sequence of length hidden_dim.
        self.conv1 = nn.Conv2d(2048, 768, 1)
        self.relu1 = F.relu
        self.linear1 = nn.Linear(192, hidden_dim)

    def forward(self, x):
        x = self.resnet50(x)
        x = self.conv1(x)
        x = self.relu1(x)
        x = rearrange(x, "b e w h -> b e (w h)")  # b -> batch, e -> embedding dim, w/h -> spatial dims
        x = self.linear1(x)
        x = rearrange(x, "b e s -> b s e")  # b -> batch, e -> embedding dim, s -> sequence length
        return x
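
# Illustrative sketch (not part of the original model): the linear layer above expects the
# backbone to yield 192 spatial positions, which holds e.g. for 500x384 inputs
# (500/32 ~ 16, 384/32 = 12, 16 * 12 = 192). The input size here is an assumption for the demo.
def _demo_resnet_feature_extractor():
    extractor = ResNetFeatureExtractor(hidden_dim=512)
    image = torch.rand(1, 3, 500, 384)  # (batch, channels, height, width)
    v_bar = extractor(image)            # (1, 512, 768): 512 visual tokens of dim 768
    assert v_bar.shape == (1, 512, 768)
    return v_bar
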
class DocFormerEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super(DocFormerEmbeddings, self).__init__()
        self.config = config

        self.position_embeddings_v = PositionalEncoding(
            d_model=config["hidden_size"],
            dropout=0.1,
            max_len=config["max_position_embeddings"],
        )

        self.x_topleft_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
        self.x_bottomright_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
        self.w_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["shape_size"])
        self.x_topleft_distance_to_prev_embeddings_v = nn.Embedding(2 * config["max_2d_position_embeddings"] + 1, config["shape_size"])
        self.x_bottomleft_distance_to_prev_embeddings_v = nn.Embedding(2 * config["max_2d_position_embeddings"] + 1, config["shape_size"])
        self.x_topright_distance_to_prev_embeddings_v = nn.Embedding(2 * config["max_2d_position_embeddings"] + 1, config["shape_size"])
        self.x_bottomright_distance_to_prev_embeddings_v = nn.Embedding(2 * config["max_2d_position_embeddings"] + 1, config["shape_size"])
        self.x_centroid_distance_to_prev_embeddings_v = nn.Embedding(2 * config["max_2d_position_embeddings"] + 1, config["shape_size"])

        self.y_topleft_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
        self.y_bottomright_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
        self.h_position_embeddings_v = nn.Embedding(config["max_2d_position_embeddings"], config["shape_size"])
        self.y_topleft_distance_to_prev_embeddings_v = nn.Embedding(2 * config["max_2d_position_embeddings"] + 1, config["shape_size"])
        self.y_bottomleft_distance_to_prev_embeddings_v = nn.Embedding(2 * config["max_2d_position_embeddings"] + 1, config["shape_size"])
        self.y_topright_distance_to_prev_embeddings_v = nn.Embedding(2 * config["max_2d_position_embeddings"] + 1, config["shape_size"])
        self.y_bottomright_distance_to_prev_embeddings_v = nn.Embedding(2 * config["max_2d_position_embeddings"] + 1, config["shape_size"])
        self.y_centroid_distance_to_prev_embeddings_v = nn.Embedding(2 * config["max_2d_position_embeddings"] + 1, config["shape_size"])

        self.position_embeddings_t = PositionalEncoding(
            d_model=config["hidden_size"],
            dropout=0.1,
            max_len=config["max_position_embeddings"],
        )

        self.x_topleft_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
        self.x_bottomright_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
        self.w_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["shape_size"])
        self.x_topleft_distance_to_prev_embeddings_t = nn.Embedding(2 * config["max_2d_position_embeddings"] + 1, config["shape_size"])
        self.x_bottomleft_distance_to_prev_embeddings_t = nn.Embedding(2 * config["max_2d_position_embeddings"] + 1, config["shape_size"])
        self.x_topright_distance_to_prev_embeddings_t = nn.Embedding(2 * config["max_2d_position_embeddings"] + 1, config["shape_size"])
        self.x_bottomright_distance_to_prev_embeddings_t = nn.Embedding(2 * config["max_2d_position_embeddings"] + 1, config["shape_size"])
        self.x_centroid_distance_to_prev_embeddings_t = nn.Embedding(2 * config["max_2d_position_embeddings"] + 1, config["shape_size"])

        self.y_topleft_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
        self.y_bottomright_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["coordinate_size"])
        self.h_position_embeddings_t = nn.Embedding(config["max_2d_position_embeddings"], config["shape_size"])
        self.y_topleft_distance_to_prev_embeddings_t = nn.Embedding(2 * config["max_2d_position_embeddings"] + 1, config["shape_size"])
        self.y_bottomleft_distance_to_prev_embeddings_t = nn.Embedding(2 * config["max_2d_position_embeddings"] + 1, config["shape_size"])
        self.y_topright_distance_to_prev_embeddings_t = nn.Embedding(2 * config["max_2d_position_embeddings"] + 1, config["shape_size"])
        self.y_bottomright_distance_to_prev_embeddings_t = nn.Embedding(2 * config["max_2d_position_embeddings"] + 1, config["shape_size"])
        self.y_centroid_distance_to_prev_embeddings_t = nn.Embedding(2 * config["max_2d_position_embeddings"] + 1, config["shape_size"])

        self.LayerNorm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])
        self.dropout = nn.Dropout(config["hidden_dropout_prob"])
    def forward(self, x_feature, y_feature):
        """
        Arguments:
            x_feature: tensor of shape (batch_size, seq_len, 8)
            y_feature: tensor of shape (batch_size, seq_len, 8)
        Outputs:
            (v_bar_s, t_bar_s), each of shape (batch_size, seq_len, hidden_size), e.g. (batch_size, 512, 768)

        Feature indices (x/y respectively):
            0 -> top left x/y
            1 -> bottom right x/y
            2 -> width/height
            3 -> diff top left x/y
            4 -> diff bottom left x/y
            5 -> diff top right x/y
            6 -> diff bottom right x/y
            7 -> centroids diff x/y
        """
        batch, seq_len = x_feature.shape[:-1]
        hidden_size = self.config["hidden_size"]
        num_feat = x_feature.shape[-1]
        sub_dim = hidden_size // num_feat

        # Clamp the relative-distance features and shift them by max_2d_position_embeddings
        # so that negative distances map to valid (non-negative) embedding indices.
        x_feature[:, :, 3:] = torch.clamp(x_feature[:, :, 3:], -self.config["max_2d_position_embeddings"], self.config["max_2d_position_embeddings"])
        x_feature[:, :, 3:] += self.config["max_2d_position_embeddings"]
        y_feature[:, :, 3:] = torch.clamp(y_feature[:, :, 3:], -self.config["max_2d_position_embeddings"], self.config["max_2d_position_embeddings"])
        y_feature[:, :, 3:] += self.config["max_2d_position_embeddings"]
        x_topleft_position_embeddings_v = self.x_topleft_position_embeddings_v(x_feature[:, :, 0])
        x_bottomright_position_embeddings_v = self.x_bottomright_position_embeddings_v(x_feature[:, :, 1])
        w_position_embeddings_v = self.w_position_embeddings_v(x_feature[:, :, 2])
        x_topleft_distance_to_prev_embeddings_v = self.x_topleft_distance_to_prev_embeddings_v(x_feature[:, :, 3])
        x_bottomleft_distance_to_prev_embeddings_v = self.x_bottomleft_distance_to_prev_embeddings_v(x_feature[:, :, 4])
        x_topright_distance_to_prev_embeddings_v = self.x_topright_distance_to_prev_embeddings_v(x_feature[:, :, 5])
        x_bottomright_distance_to_prev_embeddings_v = self.x_bottomright_distance_to_prev_embeddings_v(x_feature[:, :, 6])
        x_centroid_distance_to_prev_embeddings_v = self.x_centroid_distance_to_prev_embeddings_v(x_feature[:, :, 7])

        x_calculated_embedding_v = torch.cat(
            [
                x_topleft_position_embeddings_v,
                x_bottomright_position_embeddings_v,
                w_position_embeddings_v,
                x_topleft_distance_to_prev_embeddings_v,
                x_bottomleft_distance_to_prev_embeddings_v,
                x_topright_distance_to_prev_embeddings_v,
                x_bottomright_distance_to_prev_embeddings_v,
                x_centroid_distance_to_prev_embeddings_v,
            ],
            dim=-1,
        )

        y_topleft_position_embeddings_v = self.y_topleft_position_embeddings_v(y_feature[:, :, 0])
        y_bottomright_position_embeddings_v = self.y_bottomright_position_embeddings_v(y_feature[:, :, 1])
        h_position_embeddings_v = self.h_position_embeddings_v(y_feature[:, :, 2])
        y_topleft_distance_to_prev_embeddings_v = self.y_topleft_distance_to_prev_embeddings_v(y_feature[:, :, 3])
        y_bottomleft_distance_to_prev_embeddings_v = self.y_bottomleft_distance_to_prev_embeddings_v(y_feature[:, :, 4])
        y_topright_distance_to_prev_embeddings_v = self.y_topright_distance_to_prev_embeddings_v(y_feature[:, :, 5])
        y_bottomright_distance_to_prev_embeddings_v = self.y_bottomright_distance_to_prev_embeddings_v(y_feature[:, :, 6])
        y_centroid_distance_to_prev_embeddings_v = self.y_centroid_distance_to_prev_embeddings_v(y_feature[:, :, 7])
        y_calculated_embedding_v = torch.cat(
            [
                y_topleft_position_embeddings_v,
                y_bottomright_position_embeddings_v,
                h_position_embeddings_v,
                y_topleft_distance_to_prev_embeddings_v,
                y_bottomleft_distance_to_prev_embeddings_v,
                y_topright_distance_to_prev_embeddings_v,
                y_bottomright_distance_to_prev_embeddings_v,
                y_centroid_distance_to_prev_embeddings_v,
            ],
            dim=-1,
        )

        v_bar_s = x_calculated_embedding_v + y_calculated_embedding_v + self.position_embeddings_v()
        x_topleft_position_embeddings_t = self.x_topleft_position_embeddings_t(x_feature[:, :, 0])
        x_bottomright_position_embeddings_t = self.x_bottomright_position_embeddings_t(x_feature[:, :, 1])
        w_position_embeddings_t = self.w_position_embeddings_t(x_feature[:, :, 2])
        x_topleft_distance_to_prev_embeddings_t = self.x_topleft_distance_to_prev_embeddings_t(x_feature[:, :, 3])
        x_bottomleft_distance_to_prev_embeddings_t = self.x_bottomleft_distance_to_prev_embeddings_t(x_feature[:, :, 4])
        x_topright_distance_to_prev_embeddings_t = self.x_topright_distance_to_prev_embeddings_t(x_feature[:, :, 5])
        x_bottomright_distance_to_prev_embeddings_t = self.x_bottomright_distance_to_prev_embeddings_t(x_feature[:, :, 6])
        x_centroid_distance_to_prev_embeddings_t = self.x_centroid_distance_to_prev_embeddings_t(x_feature[:, :, 7])

        x_calculated_embedding_t = torch.cat(
            [
                x_topleft_position_embeddings_t,
                x_bottomright_position_embeddings_t,
                w_position_embeddings_t,
                x_topleft_distance_to_prev_embeddings_t,
                x_bottomleft_distance_to_prev_embeddings_t,
                x_topright_distance_to_prev_embeddings_t,
                x_bottomright_distance_to_prev_embeddings_t,
                x_centroid_distance_to_prev_embeddings_t,
            ],
            dim=-1,
        )

        y_topleft_position_embeddings_t = self.y_topleft_position_embeddings_t(y_feature[:, :, 0])
        y_bottomright_position_embeddings_t = self.y_bottomright_position_embeddings_t(y_feature[:, :, 1])
        h_position_embeddings_t = self.h_position_embeddings_t(y_feature[:, :, 2])
        y_topleft_distance_to_prev_embeddings_t = self.y_topleft_distance_to_prev_embeddings_t(y_feature[:, :, 3])
        y_bottomleft_distance_to_prev_embeddings_t = self.y_bottomleft_distance_to_prev_embeddings_t(y_feature[:, :, 4])
        y_topright_distance_to_prev_embeddings_t = self.y_topright_distance_to_prev_embeddings_t(y_feature[:, :, 5])
        y_bottomright_distance_to_prev_embeddings_t = self.y_bottomright_distance_to_prev_embeddings_t(y_feature[:, :, 6])
        y_centroid_distance_to_prev_embeddings_t = self.y_centroid_distance_to_prev_embeddings_t(y_feature[:, :, 7])
        y_calculated_embedding_t = torch.cat(
            [
                y_topleft_position_embeddings_t,
                y_bottomright_position_embeddings_t,
                h_position_embeddings_t,
                y_topleft_distance_to_prev_embeddings_t,
                y_bottomleft_distance_to_prev_embeddings_t,
                y_topright_distance_to_prev_embeddings_t,
                y_bottomright_distance_to_prev_embeddings_t,
                y_centroid_distance_to_prev_embeddings_t,
            ],
            dim=-1,
        )

        t_bar_s = x_calculated_embedding_t + y_calculated_embedding_t + self.position_embeddings_t()
        return v_bar_s, t_bar_s
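
# Illustrative sketch (not part of the original model): a minimal config and forward pass for
# the spatial embeddings. The concrete values are assumptions for the demo; the only hard
# constraint visible in the code is 2 * coordinate_size + 6 * shape_size == hidden_size.
def _demo_docformer_embeddings():
    config = {
        "hidden_size": 768,
        "coordinate_size": 96,
        "shape_size": 96,
        "max_position_embeddings": 512,
        "max_2d_position_embeddings": 1024,
        "layer_norm_eps": 1e-12,
        "hidden_dropout_prob": 0.1,
    }
    emb = DocFormerEmbeddings(config)
    # 8 x-features and 8 y-features per token; non-negative indices are used here for simplicity.
    x_feature = torch.randint(0, 500, (1, 512, 8))
    y_feature = torch.randint(0, 500, (1, 512, 8))
    v_bar_s, t_bar_s = emb(x_feature, y_feature)
    assert v_bar_s.shape == t_bar_s.shape == (1, 512, 768)
    return v_bar_s, t_bar_s
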
# fmt: off
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        # Fig 1: http://proceedings.mlr.press/v119/xiong20b/xiong20b.pdf
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
class PreNormAttn(nn.Module):
    def __init__(self, dim, fn):
        # Fig 1: http://proceedings.mlr.press/v119/xiong20b/xiong20b.pdf
        super().__init__()
        self.norm_t_bar = nn.LayerNorm(dim)
        self.norm_v_bar = nn.LayerNorm(dim)
        self.norm_t_bar_s = nn.LayerNorm(dim)
        self.norm_v_bar_s = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, t_bar, v_bar, t_bar_s, v_bar_s, **kwargs):
        return self.fn(self.norm_t_bar(t_bar),
                       self.norm_v_bar(v_bar),
                       self.norm_t_bar_s(t_bar_s),
                       self.norm_v_bar_s(v_bar_s), **kwargs)
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
class RelativePosition(nn.Module):
    def __init__(self, num_units, max_relative_position, max_seq_length):
        super().__init__()
        self.num_units = num_units
        self.max_relative_position = max_relative_position
        self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))
        self.max_length = max_seq_length

        # Pre-compute the clipped relative-distance lookup table once.
        range_vec_q = torch.arange(max_seq_length)
        range_vec_k = torch.arange(max_seq_length)
        distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
        distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
        final_mat = distance_mat_clipped + self.max_relative_position
        # Registered as a buffer so the index tensor follows the module across devices.
        self.register_buffer("final_mat", final_mat.long(), persistent=False)

        nn.init.xavier_uniform_(self.embeddings_table)

    def forward(self, length_q, length_k):
        embeddings = self.embeddings_table[self.final_mat[:length_q, :length_k]]
        return embeddings
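
# Illustrative sketch (not part of the original model): the table returns one embedding per
# (query position, key position) pair, clipped to the maximum relative distance. The sizes
# below are assumptions for the demo.
def _demo_relative_position():
    rel = RelativePosition(num_units=64, max_relative_position=8, max_seq_length=512)
    table = rel(512, 512)  # (length_q, length_k, num_units)
    assert table.shape == (512, 512, 64)
    return table
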
class MultiModalAttentionLayer(nn.Module):
    def __init__(self, embed_dim, n_heads, max_relative_position, max_seq_length, dropout):
        super().__init__()
        assert embed_dim % n_heads == 0

        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads

        self.relative_positions_text = RelativePosition(self.head_dim, max_relative_position, max_seq_length)
        self.relative_positions_img = RelativePosition(self.head_dim, max_relative_position, max_seq_length)

        # text qkv embeddings
        self.fc_k_text = nn.Linear(embed_dim, embed_dim)
        self.fc_q_text = nn.Linear(embed_dim, embed_dim)
        self.fc_v_text = nn.Linear(embed_dim, embed_dim)

        # image qkv embeddings
        self.fc_k_img = nn.Linear(embed_dim, embed_dim)
        self.fc_q_img = nn.Linear(embed_dim, embed_dim)
        self.fc_v_img = nn.Linear(embed_dim, embed_dim)

        # spatial qk embeddings (shared for visual and text)
        self.fc_k_spatial = nn.Linear(embed_dim, embed_dim)
        self.fc_q_spatial = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.Dropout(dropout),
        )
        self.scale = embed_dim ** 0.5
    def forward(self, text_feat, img_feat, text_spatial_feat, img_spatial_feat):
        seq_length = text_feat.shape[1]

        # self-attention of text
        # b -> batch, t -> time steps (l -> length has same meaning), head -> # of heads, k -> head dim.
        key_text_nh = rearrange(self.fc_k_text(text_feat), 'b t (head k) -> head b t k', head=self.n_heads)
        query_text_nh = rearrange(self.fc_q_text(text_feat), 'b l (head k) -> head b l k', head=self.n_heads)
        value_text_nh = rearrange(self.fc_v_text(text_feat), 'b t (head k) -> head b t k', head=self.n_heads)
        dots_text = torch.einsum('hblk,hbtk->hblt', query_text_nh, key_text_nh)
        dots_text = dots_text / self.scale

        # 1D relative positions (query, key)
        rel_pos_embed_text = self.relative_positions_text(seq_length, seq_length)
        rel_pos_key_text = torch.einsum('bhrd,lrd->bhlr', key_text_nh, rel_pos_embed_text)
        rel_pos_query_text = torch.einsum('bhld,lrd->bhlr', query_text_nh, rel_pos_embed_text)

        # shared spatial <-> text hidden features
        key_spatial_text = self.fc_k_spatial(text_spatial_feat)
        query_spatial_text = self.fc_q_spatial(text_spatial_feat)
        key_spatial_text_nh = rearrange(key_spatial_text, 'b t (head k) -> head b t k', head=self.n_heads)
        query_spatial_text_nh = rearrange(query_spatial_text, 'b l (head k) -> head b l k', head=self.n_heads)
        dots_text_spatial = torch.einsum('hblk,hbtk->hblt', query_spatial_text_nh, key_spatial_text_nh)
        dots_text_spatial = dots_text_spatial / self.scale

        # Line 38 of pseudo-code
        text_attn_scores = dots_text + rel_pos_key_text + rel_pos_query_text + dots_text_spatial
        # self-attention of image
        key_img_nh = rearrange(self.fc_k_img(img_feat), 'b t (head k) -> head b t k', head=self.n_heads)
        query_img_nh = rearrange(self.fc_q_img(img_feat), 'b l (head k) -> head b l k', head=self.n_heads)
        value_img_nh = rearrange(self.fc_v_img(img_feat), 'b t (head k) -> head b t k', head=self.n_heads)
        dots_img = torch.einsum('hblk,hbtk->hblt', query_img_nh, key_img_nh)
        dots_img = dots_img / self.scale

        # 1D relative positions (query, key) -- use the image-specific table, not the text one
        rel_pos_embed_img = self.relative_positions_img(seq_length, seq_length)
        rel_pos_key_img = torch.einsum('bhrd,lrd->bhlr', key_img_nh, rel_pos_embed_img)
        rel_pos_query_img = torch.einsum('bhld,lrd->bhlr', query_img_nh, rel_pos_embed_img)

        # shared spatial <-> image features
        key_spatial_img = self.fc_k_spatial(img_spatial_feat)
        query_spatial_img = self.fc_q_spatial(img_spatial_feat)
        key_spatial_img_nh = rearrange(key_spatial_img, 'b t (head k) -> head b t k', head=self.n_heads)
        query_spatial_img_nh = rearrange(query_spatial_img, 'b l (head k) -> head b l k', head=self.n_heads)
        dots_img_spatial = torch.einsum('hblk,hbtk->hblt', query_spatial_img_nh, key_spatial_img_nh)
        dots_img_spatial = dots_img_spatial / self.scale

        # Line 59 of pseudo-code
        img_attn_scores = dots_img + rel_pos_key_img + rel_pos_query_img + dots_img_spatial

        text_attn_probs = self.dropout(torch.softmax(text_attn_scores, dim=-1))
        img_attn_probs = self.dropout(torch.softmax(img_attn_scores, dim=-1))

        text_context = torch.einsum('hblt,hbtv->hblv', text_attn_probs, value_text_nh)
        img_context = torch.einsum('hblt,hbtv->hblv', img_attn_probs, value_img_nh)
        context = text_context + img_context

        embeddings = rearrange(context, 'head b t d -> b t (head d)')
        return self.to_out(embeddings)
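
# Illustrative sketch (not part of the original model): the attention layer consumes four
# aligned (batch, seq, embed_dim) streams and returns one fused (batch, seq, embed_dim) tensor.
# The dimensions below are assumptions chosen for the demo.
def _demo_multimodal_attention():
    attn = MultiModalAttentionLayer(embed_dim=768, n_heads=12, max_relative_position=8,
                                    max_seq_length=512, dropout=0.1)
    t = torch.rand(1, 512, 768)   # text features
    v = torch.rand(1, 512, 768)   # visual features
    ts = torch.rand(1, 512, 768)  # text spatial features
    vs = torch.rand(1, 512, 768)  # visual spatial features
    out = attn(t, v, ts, vs)      # (1, 512, 768)
    assert out.shape == (1, 512, 768)
    return out
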
class DocFormerEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([])
        for _ in range(config['num_hidden_layers']):
            encoder_block = nn.ModuleList([
                PreNormAttn(config['hidden_size'],
                            MultiModalAttentionLayer(config['hidden_size'],
                                                     config['num_attention_heads'],
                                                     config['max_relative_positions'],
                                                     config['max_position_embeddings'],
                                                     config['hidden_dropout_prob'],
                                                     )
                            ),
                PreNorm(config['hidden_size'],
                        FeedForward(config['hidden_size'],
                                    config['hidden_size'] * config['intermediate_ff_size_factor'],
                                    dropout=config['hidden_dropout_prob']))
            ])
            self.layers.append(encoder_block)

    def forward(
        self,
        text_feat,          # text feat or output from last encoder block
        img_feat,
        text_spatial_feat,
        img_spatial_feat,
    ):
        # Fig 1 encoder part (skip conn for both attn & FF): https://arxiv.org/abs/1706.03762
        # TODO: ensure 1st skip conn (var "skip") in such a multimodal setting makes sense (most likely does)
        for attn, ff in self.layers:
            skip = text_feat + img_feat + text_spatial_feat + img_spatial_feat
            x = attn(text_feat, img_feat, text_spatial_feat, img_spatial_feat) + skip
            x = ff(x) + x
            text_feat = x
        return x
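
# Illustrative sketch (not part of the original model): one pass through the encoder stack with
# random feature streams. The config values are assumptions for a lightweight demo (e.g. only
# 2 layers); only the dictionary keys themselves are taken from the code above.
def _demo_docformer_encoder():
    config = {
        "hidden_size": 768,
        "num_hidden_layers": 2,
        "num_attention_heads": 12,
        "max_relative_positions": 8,
        "max_position_embeddings": 512,
        "hidden_dropout_prob": 0.1,
        "intermediate_ff_size_factor": 4,
    }
    encoder = DocFormerEncoder(config)
    t_bar = torch.rand(1, 512, 768)
    v_bar = torch.rand(1, 512, 768)
    t_bar_s = torch.rand(1, 512, 768)
    v_bar_s = torch.rand(1, 512, 768)
    out = encoder(t_bar, v_bar, t_bar_s, v_bar_s)  # (1, 512, 768)
    assert out.shape == (1, 512, 768)
    return out
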
class LanguageFeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        from transformers import LayoutLMForTokenClassification
        layoutlm_dummy = LayoutLMForTokenClassification.from_pretrained("microsoft/layoutlm-base-uncased", num_labels=1)
        self.embedding_vector = nn.Embedding.from_pretrained(layoutlm_dummy.layoutlm.embeddings.word_embeddings.weight)

    def forward(self, x):
        return self.embedding_vector(x)
class ExtractFeatures(nn.Module):
    '''
    Input: the encoding dictionary produced during preprocessing
    Output: v_bar, t_bar, v_bar_s, t_bar_s
    '''

    def __init__(self, config):
        super().__init__()
        # hidden_dim here is the visual sequence length, so v_bar has shape
        # (batch, max_position_embeddings, 768) and lines up with the text sequence.
        self.visual_feature = ResNetFeatureExtractor(hidden_dim=config['max_position_embeddings'])
        self.language_feature = LanguageFeatureExtractor()
        self.spatial_feature = DocFormerEmbeddings(config)

    def forward(self, encoding):
        image = encoding['resized_scaled_img']
        language = encoding['input_ids']
        x_feature = encoding['x_features']
        y_feature = encoding['y_features']

        v_bar = self.visual_feature(image)
        t_bar = self.language_feature(language)
        v_bar_s, t_bar_s = self.spatial_feature(x_feature, y_feature)
        return v_bar, t_bar, v_bar_s, t_bar_s
class DocFormer(nn.Module):
    '''
    Thin wrapper: the model simply takes as input the dictionary produced by the create_features function.
    '''

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.extract_feature = ExtractFeatures(config)
        self.encoder = DocFormerEncoder(config)
        self.dropout = nn.Dropout(config['hidden_dropout_prob'])

    def forward(self, x, use_tdi=False):
        # NOTE: use_tdi is kept for API compatibility but is not consumed here;
        # ExtractFeatures accepts only the encoding dictionary.
        v_bar, t_bar, v_bar_s, t_bar_s = self.extract_feature(x)
        features = {'v_bar': v_bar, 't_bar': t_bar, 'v_bar_s': v_bar_s, 't_bar_s': t_bar_s}
        output = self.encoder(features['t_bar'], features['v_bar'], features['t_bar_s'], features['v_bar_s'])
        output = self.dropout(output)
        return output
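
# Illustrative sketch (not part of the original model): an end-to-end forward pass on dummy
# inputs. The config values, the 500x384 image size, and the random feature ranges are
# assumptions for the demo; only the dictionary keys come from the code above. Note that
# constructing the model downloads the microsoft/layoutlm-base-uncased weights.
def _demo_docformer():
    config = {
        "hidden_size": 768,
        "coordinate_size": 96,
        "shape_size": 96,
        "max_position_embeddings": 512,
        "max_2d_position_embeddings": 1024,
        "layer_norm_eps": 1e-12,
        "hidden_dropout_prob": 0.1,
        "num_hidden_layers": 2,
        "num_attention_heads": 12,
        "max_relative_positions": 8,
        "intermediate_ff_size_factor": 4,
    }
    model = DocFormer(config)
    encoding = {
        "resized_scaled_img": torch.rand(1, 3, 500, 384),
        "input_ids": torch.randint(0, 30000, (1, 512)),
        "x_features": torch.randint(0, 500, (1, 512, 8)),
        "y_features": torch.randint(0, 500, (1, 512, 8)),
    }
    output = model(encoding)  # (1, 512, 768)
    assert output.shape == (1, 512, 768)
    return output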