# modeling.py — "Upload modeling.py" by iakarshu, commit 15dee1b (26.3 kB).
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from einops import rearrange
from torch import Tensor
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position encodings (sine on even channels, cosine on odd).

    The table is precomputed once for ``max_len`` positions and stored in the
    ``pe`` buffer with shape (1, max_len, d_model).

    Args:
        d_model: Number of channels per position (even or odd).
        dropout: Dropout probability applied to the returned encodings.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.max_len = max_len
        self.d_model = d_model
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd channel than even channel,
        # so the cosine half only takes the first d_model // 2 frequencies.
        pe[0, :, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, seq_len=None) -> Tensor:
        """Return the encodings, with dropout applied, as shape (1, n, d_model).

        Args:
            seq_len: Number of leading positions to return. Defaults to all
                ``max_len`` positions, which was the only behaviour before.
        """
        n_positions = self.max_len if seq_len is None else seq_len
        x = self.pe[0, :n_positions]
        return self.dropout(x).unsqueeze(0)
class ResNetFeatureExtractor(nn.Module):
    """Visual backbone: a ResNet-50 trunk, then a 1x1 conv, then a linear map over spatial cells.

    The last two children of torchvision's ResNet-50 (its pooling and
    classifier head) are dropped. The 2048 backbone channels are projected to
    768 with a 1x1 convolution followed by ReLU. The flattened spatial grid is
    then mapped linearly to ``hidden_dim`` slots. ``linear1`` expects 192 grid
    cells; the input resolution that yields this is not shown here (TODO
    confirm). The output has shape (batch, hidden_dim, 768).
    """

    def __init__(self, hidden_dim = 512):
        super().__init__()
        # ResNet-50 is the visual feature extractor used by DocFormer.
        backbone = models.resnet50(pretrained=False)
        trunk_layers = list(backbone.children())[:-2]
        self.resnet50 = nn.Sequential(*trunk_layers)
        # Channel projection, then a linear map from spatial cells to sequence slots.
        self.conv1 = nn.Conv2d(2048, 768, 1)
        self.relu1 = F.relu
        self.linear1 = nn.Linear(192, hidden_dim)

    def forward(self, x):
        feats = self.relu1(self.conv1(self.resnet50(x)))
        # (batch, channels, H, W) -> (batch, channels, H*W)
        feats = feats.flatten(start_dim=2)
        # Map the H*W spatial cells to hidden_dim sequence slots.
        feats = self.linear1(feats)
        # (batch, channels, seq) -> (batch, seq, channels)
        return feats.transpose(1, 2)
class DocFormerEmbeddings(nn.Module):
    """Spatial embeddings for the visual (``v_bar_s``) and text (``t_bar_s``) streams.

    Each stream ("v" and "t") has its own set of tables. For each axis, the
    eight box features are looked up in separate ``nn.Embedding`` tables and
    concatenated. The x-axis and y-axis results are summed together with a
    sinusoidal 1D position encoding.

    Each table lives in the attribute ``f"{name}_embeddings_{modality}"``,
    where ``name`` comes from ``_X_FEATURE_NAMES`` / ``_Y_FEATURE_NAMES``.
    These names, and the registration order, match the original explicit
    attributes, so existing state_dict keys still load.
    """

    # Per-axis feature names, in the column order of x_features / y_features.
    _X_FEATURE_NAMES = (
        "x_topleft_position",
        "x_bottomright_position",
        "w_position",
        "x_topleft_distance_to_prev",
        "x_bottomleft_distance_to_prev",
        "x_topright_distance_to_prev",
        "x_bottomright_distance_to_prev",
        "x_centroid_distance_to_prev",
    )
    _Y_FEATURE_NAMES = (
        "y_topleft_position",
        "y_bottomright_position",
        "h_position",
        "y_topleft_distance_to_prev",
        "y_bottomleft_distance_to_prev",
        "y_topright_distance_to_prev",
        "y_bottomright_distance_to_prev",
        "y_centroid_distance_to_prev",
    )
    # Columns from this index onwards hold signed distances to the previous box.
    _FIRST_DISTANCE_COLUMN = 3

    def __init__(self, config):
        super().__init__()
        self.config = config
        max_2d = config["max_2d_position_embeddings"]
        # Registration order matches the original: position_v, x-tables v, y-tables v,
        # then the same for "t".
        for modality in ("v", "t"):
            setattr(
                self,
                f"position_embeddings_{modality}",
                PositionalEncoding(
                    d_model=config["hidden_size"],
                    dropout=0.1,
                    max_len=config["max_position_embeddings"],
                ),
            )
            for names in (self._X_FEATURE_NAMES, self._Y_FEATURE_NAMES):
                for column, name in enumerate(names):
                    setattr(self, f"{name}_embeddings_{modality}", self._make_table(column, max_2d, config))
        # NOTE(review): LayerNorm and dropout are constructed but not applied in
        # forward. They are kept so that state_dict keys still match.
        self.LayerNorm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])
        self.dropout = nn.Dropout(config["hidden_dropout_prob"])

    @classmethod
    def _make_table(cls, column, max_2d, config):
        """Build the embedding table for feature column ``column``."""
        if column < 2:  # absolute corner coordinates
            return nn.Embedding(max_2d, config["coordinate_size"])
        if column < cls._FIRST_DISTANCE_COLUMN:  # box width / height
            return nn.Embedding(max_2d, config["shape_size"])
        # After shifting, signed distances span [0, 2 * max_2d].
        return nn.Embedding(2 * max_2d + 1, config["shape_size"])

    def _shift_distances(self, feature):
        """Return a copy of ``feature`` with distance columns clamped and made non-negative."""
        max_2d = self.config["max_2d_position_embeddings"]
        start = self._FIRST_DISTANCE_COLUMN
        # Clone so the caller's tensor is never modified (forward used to do so in place).
        feature = feature.clone()
        feature[:, :, start:] = torch.clamp(feature[:, :, start:], -max_2d, max_2d) + max_2d
        return feature

    def _embed_axis(self, feature, names, modality):
        """Look up every feature column of one axis and concatenate along the last dim."""
        return torch.cat(
            [getattr(self, f"{name}_embeddings_{modality}")(feature[:, :, column])
             for column, name in enumerate(names)],
            dim=-1,
        )

    def forward(self, x_feature, y_feature):
        """Compute the spatial embeddings of the visual and text streams.

        Args:
            x_feature: Integer index tensor of shape (batch, seq_len, 8) with the
                x-axis box features.
            y_feature: The same layout for the y-axis.

        Feature columns (x / y):
            0 -> top-left coordinate
            1 -> bottom-right coordinate
            2 -> width / height
            3 -> distance of the top-left corner to the previous box
            4 -> distance of the bottom-left corner to the previous box
            5 -> distance of the top-right corner to the previous box
            6 -> distance of the bottom-right corner to the previous box
            7 -> distance of the centroid to the previous box

        Columns 3-7 are clamped to [-max_2d_position_embeddings,
        max_2d_position_embeddings] and shifted to be non-negative before the
        lookup. The input tensors are left unmodified.

        Returns:
            (v_bar_s, t_bar_s). Each is the x-axis concatenation plus the
            y-axis concatenation plus a (1, max_position_embeddings,
            hidden_size) sinusoidal encoding. For these to add, the shapes
            must match: 2 * coordinate_size + 6 * shape_size must equal
            hidden_size, and seq_len must equal max_position_embeddings.
            The original docstring gives (batch, 512, 768).
        """
        x_feature = self._shift_distances(x_feature)
        y_feature = self._shift_distances(y_feature)

        v_bar_s = (
            self._embed_axis(x_feature, self._X_FEATURE_NAMES, "v")
            + self._embed_axis(y_feature, self._Y_FEATURE_NAMES, "v")
            + self.position_embeddings_v()
        )
        t_bar_s = (
            self._embed_axis(x_feature, self._X_FEATURE_NAMES, "t")
            + self._embed_axis(y_feature, self._Y_FEATURE_NAMES, "t")
            + self.position_embeddings_t()
        )
        return v_bar_s, t_bar_s
# fmt: off
class PreNorm(nn.Module):
    """Layer-normalise the input, then delegate to the wrapped module ``fn``.

    Pre-LN placement as in Fig. 1 of
    http://proceedings.mlr.press/v119/xiong20b/xiong20b.pdf
    Any keyword arguments are forwarded to ``fn`` unchanged.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)
class PreNormAttn(nn.Module):
    """Pre-LN wrapper for the four-input multi-modal attention.

    Each of ``t_bar``, ``v_bar``, ``t_bar_s`` and ``v_bar_s`` is normalised
    with its own LayerNorm before being passed to ``fn``. Placement follows
    Fig. 1 of http://proceedings.mlr.press/v119/xiong20b/xiong20b.pdf
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm_t_bar = nn.LayerNorm(dim)
        self.norm_v_bar = nn.LayerNorm(dim)
        self.norm_t_bar_s = nn.LayerNorm(dim)
        self.norm_v_bar_s = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, t_bar, v_bar, t_bar_s, v_bar_s, **kwargs):
        normed_inputs = (
            self.norm_t_bar(t_bar),
            self.norm_v_bar(v_bar),
            self.norm_t_bar_s(t_bar_s),
            self.norm_v_bar_s(v_bar_s),
        )
        return self.fn(*normed_inputs, **kwargs)
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(dim->hidden) -> GELU -> Dropout -> Linear(hidden->dim) -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        project = nn.Linear(hidden_dim, dim)
        # Sequential indices (net.0 / net.3 hold the weights) match the checkpoint layout.
        self.net = nn.Sequential(expand, nn.GELU(), nn.Dropout(dropout), project, nn.Dropout(dropout))

    def forward(self, x):
        return self.net(x)
class RelativePosition(nn.Module):
    """Learned embeddings for clipped relative distances between positions.

    ``final_mat[q, k]`` holds ``clamp(k - q, -max_relative_position,
    max_relative_position) + max_relative_position``, an index into
    ``embeddings_table`` (2 * max_relative_position + 1 rows of ``num_units``).

    Args:
        num_units: Embedding size of each relative distance.
        max_relative_position: Distances beyond this are clipped.
        max_seq_length: Largest query/key length the index matrix covers.
    """

    def __init__(self, num_units, max_relative_position, max_seq_length):
        super().__init__()
        self.num_units = num_units
        self.max_relative_position = max_relative_position
        self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))
        self.max_length = max_seq_length
        positions = torch.arange(max_seq_length)
        distance_mat = positions[None, :] - positions[:, None]
        distance_mat_clipped = torch.clamp(distance_mat, -max_relative_position, max_relative_position)
        # Register as a non-persistent buffer so it follows .to()/.cuda() with
        # the module, while staying out of the state_dict (as before).
        self.register_buffer("final_mat", (distance_mat_clipped + max_relative_position).long(), persistent=False)
        nn.init.xavier_uniform_(self.embeddings_table)

    def forward(self, length_q, length_k):
        """Return embeddings of shape (length_q, length_k, num_units).

        Both lengths are expected to be <= max_seq_length. Larger values are
        silently truncated by the slice.
        """
        return self.embeddings_table[self.final_mat[:length_q, :length_k]]
class MultiModalAttentionLayer(nn.Module):
    """Multi-modal self-attention over a text stream and an image stream.

    For each stream, the attention scores are the sum of:
      * scaled query-key dot products of the stream's content features,
      * 1D relative-position terms for keys and for queries, from the
        stream's own RelativePosition table,
      * scaled query-key dot products of the stream's spatial features,
        projected with ``fc_q_spatial`` / ``fc_k_spatial`` (shared by both
        streams).
    Each stream attends over its own values. The two contexts are summed
    and passed through ``to_out``.

    Args:
        embed_dim: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        max_relative_position: Clipping distance for relative positions.
        max_seq_length: Longest sequence the relative-position tables cover.
        dropout: Dropout on the attention probabilities and on the output.
    """

    def __init__(self, embed_dim, n_heads, max_relative_position, max_seq_length, dropout):
        super().__init__()
        assert embed_dim % n_heads == 0
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        self.relative_positions_text = RelativePosition(self.head_dim, max_relative_position, max_seq_length)
        self.relative_positions_img = RelativePosition(self.head_dim, max_relative_position, max_seq_length)
        # text qkv embeddings
        self.fc_k_text = nn.Linear(embed_dim, embed_dim)
        self.fc_q_text = nn.Linear(embed_dim, embed_dim)
        self.fc_v_text = nn.Linear(embed_dim, embed_dim)
        # image qkv embeddings
        self.fc_k_img = nn.Linear(embed_dim, embed_dim)
        self.fc_q_img = nn.Linear(embed_dim, embed_dim)
        self.fc_v_img = nn.Linear(embed_dim, embed_dim)
        # spatial qk embeddings (shared for visual and text)
        self.fc_k_spatial = nn.Linear(embed_dim, embed_dim)
        self.fc_q_spatial = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.Dropout(dropout)
        )
        # NOTE(review): standard multi-head attention scales by sqrt(head_dim);
        # this layer scales by sqrt(embed_dim). Left unchanged so trained
        # checkpoints keep their behaviour.
        self.scale = embed_dim**0.5

    def _split_heads(self, x):
        """Reshape (batch, seq, embed_dim) to (n_heads, batch, seq, head_dim)."""
        return rearrange(x, 'b t (head k) -> head b t k', head=self.n_heads)

    def _stream_scores(self, query, key, spatial_feat, rel_pos_embed):
        """Return the unnormalised scores (heads, batch, seq_q, seq_k) for one stream."""
        dots = torch.einsum('hblk,hbtk->hblt', query, key) / self.scale
        # 1D relative positions (query, key). Here the einsum letters are only
        # labels: the tensors are laid out as (head, batch, seq, head_dim).
        rel_pos_key = torch.einsum('bhrd,lrd->bhlr', key, rel_pos_embed)
        rel_pos_query = torch.einsum('bhld,lrd->bhlr', query, rel_pos_embed)
        # Spatial features use projections shared between the text and image streams.
        key_spatial = self._split_heads(self.fc_k_spatial(spatial_feat))
        query_spatial = self._split_heads(self.fc_q_spatial(spatial_feat))
        dots_spatial = torch.einsum('hblk,hbtk->hblt', query_spatial, key_spatial) / self.scale
        # Lines 38 / 59 of the DocFormer pseudo-code
        return dots + rel_pos_key + rel_pos_query + dots_spatial

    def forward(self, text_feat, img_feat, text_spatial_feat, img_spatial_feat):
        # NOTE: both streams use the text length for their relative positions,
        # so the image sequence is assumed to have the same length.
        seq_length = text_feat.shape[1]

        # self-attention of text
        query_text = self._split_heads(self.fc_q_text(text_feat))
        key_text = self._split_heads(self.fc_k_text(text_feat))
        value_text = self._split_heads(self.fc_v_text(text_feat))
        rel_pos_embed_text = self.relative_positions_text(seq_length, seq_length)
        text_attn_scores = self._stream_scores(query_text, key_text, text_spatial_feat, rel_pos_embed_text)

        # self-attention of image. It uses its own relative-position table
        # (relative_positions_img). NOTE(review): checkpoints trained while the
        # text table was reused here never updated this table.
        query_img = self._split_heads(self.fc_q_img(img_feat))
        key_img = self._split_heads(self.fc_k_img(img_feat))
        value_img = self._split_heads(self.fc_v_img(img_feat))
        rel_pos_embed_img = self.relative_positions_img(seq_length, seq_length)
        img_attn_scores = self._stream_scores(query_img, key_img, img_spatial_feat, rel_pos_embed_img)

        text_attn_probs = self.dropout(torch.softmax(text_attn_scores, dim=-1))
        img_attn_probs = self.dropout(torch.softmax(img_attn_scores, dim=-1))
        text_context = torch.einsum('hblt,hbtv->hblv', text_attn_probs, value_text)
        img_context = torch.einsum('hblt,hbtv->hblv', img_attn_probs, value_img)
        context = text_context + img_context
        embeddings = rearrange(context, 'head b t d -> b t (head d)')
        return self.to_out(embeddings)
class DocFormerEncoder(nn.Module):
    """Stack of ``num_hidden_layers`` pre-norm (multi-modal attention + feed-forward) blocks.

    Only the text stream is updated between blocks. The image features and
    both spatial features are fed unchanged to every block.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([])
        for _ in range(config['num_hidden_layers']):
            encoder_block = nn.ModuleList([
                PreNormAttn(config['hidden_size'],
                            MultiModalAttentionLayer(config['hidden_size'],
                                                     config['num_attention_heads'],
                                                     config['max_relative_positions'],
                                                     config['max_position_embeddings'],
                                                     config['hidden_dropout_prob'],
                                                     )
                            ),
                PreNorm(config['hidden_size'],
                        FeedForward(config['hidden_size'],
                                    config['hidden_size'] * config['intermediate_ff_size_factor'],
                                    dropout=config['hidden_dropout_prob']))
            ])
            self.layers.append(encoder_block)

    def forward(
        self,
        text_feat,  # text feat or output from last encoder block
        img_feat,
        text_spatial_feat,
        img_spatial_feat,
    ):
        """Run all blocks and return the final hidden states.

        With zero layers, ``text_feat`` is returned unchanged. This used to
        raise UnboundLocalError.
        """
        # Fig 1 encoder part (skip conn for both attn & FF): https://arxiv.org/abs/1706.03762
        # TODO: ensure 1st skip conn (var "skip") in such a multimodal setting makes sense (most likely does)
        for attn, ff in self.layers:
            skip = text_feat + img_feat + text_spatial_feat + img_spatial_feat
            x = attn(text_feat, img_feat, text_spatial_feat, img_spatial_feat) + skip
            x = ff(x) + x
            text_feat = x
        return text_feat
class LanguageFeatureExtractor(nn.Module):
    """Token-embedding lookup whose table is copied from LayoutLM's word embeddings.

    The table is taken from ``microsoft/layoutlm-base-uncased`` through
    transformers' ``from_pretrained``, which may hit the network or the local
    cache. ``nn.Embedding.from_pretrained`` defaults to ``freeze=True``, so the
    table is not updated during training.
    """

    def __init__(self):
        super().__init__()
        # Imported lazily so that transformers is only required when this extractor is built.
        from transformers import LayoutLMForTokenClassification
        donor = LayoutLMForTokenClassification.from_pretrained("microsoft/layoutlm-base-uncased", num_labels=1)
        word_table = donor.layoutlm.embeddings.word_embeddings.weight
        self.embedding_vector = nn.Embedding.from_pretrained(word_table)

    def forward(self, x):
        token_vectors = self.embedding_vector(x)
        return token_vectors
class ExtractFeatures(nn.Module):
    """Run the visual, language and spatial extractors over one encoded batch.

    ``forward`` reads the keys ``'resized_scaled_img'``, ``'input_ids'``,
    ``'x_features'`` and ``'y_features'`` from the ``encoding`` dict. It
    returns ``(v_bar, t_bar, v_bar_s, t_bar_s)``: visual features, token
    embeddings, and the visual/text spatial embeddings.
    """

    def __init__(self, config):
        super().__init__()
        # The visual head maps its spatial grid to one slot per sequence position.
        self.visual_feature = ResNetFeatureExtractor(hidden_dim=config['max_position_embeddings'])
        self.language_feature = LanguageFeatureExtractor()
        self.spatial_feature = DocFormerEmbeddings(config)

    def forward(self, encoding):
        # Fetch every input up front, so a missing key fails before any extractor runs.
        image, input_ids, x_feature, y_feature = (
            encoding['resized_scaled_img'],
            encoding['input_ids'],
            encoding['x_features'],
            encoding['y_features'],
        )
        v_bar = self.visual_feature(image)
        t_bar = self.language_feature(input_ids)
        v_bar_s, t_bar_s = self.spatial_feature(x_feature, y_feature)
        return v_bar, t_bar, v_bar_s, t_bar_s
class DocFormer(nn.Module):
    '''
    Easy boiler plate, because this model will just take as an input, the dictionary which is obtained from create_features function
    '''

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.extract_feature = ExtractFeatures(config)
        self.encoder = DocFormerEncoder(config)
        self.dropout = nn.Dropout(config['hidden_dropout_prob'])

    def forward(self, x, use_tdi=False):
        """Encode one batch and return the encoder output, with dropout applied.

        Args:
            x: The feature dict passed on to ``ExtractFeatures``.
            use_tdi: Accepted for backward compatibility. It is currently
                ignored, because ``ExtractFeatures.forward`` only takes the
                encoding (passing it through used to raise TypeError).
        """
        v_bar, t_bar, v_bar_s, t_bar_s = self.extract_feature(x)
        output = self.encoder(t_bar, v_bar, t_bar_s, v_bar_s)
        return self.dropout(output)