import torch
import torch.nn as nn
import torch.nn.functional as F

from .modeling_bert import BertModel


class MultiModalBert(BertModel):
    def __init__(self, config, embed_dim, pwam_idx=(3, 6, 9, 12), num_heads_fusion=(1, 1, 1, 1), fusion_drop=0.0):
        super().__init__(config)
        # pwam_idx[k] is the (exclusive) index of the last BERT layer in stage k+1,
        # i.e. the point after which the k-th PWAM fusion is applied.
        self.pwam_idx = pwam_idx
        self.num_heads_fusion = num_heads_fusion
        self.fusion_drop = fusion_drop

        # Width of the visual features fed to each PWAM; it doubles at every stage.
        pwam_dims = [embed_dim * 2 ** i for i in range(len(pwam_idx))]

        self.pwams = nn.ModuleList()
        self.res_gates = nn.ModuleList()
        self.norms = nn.ModuleList()
        for i in range(len(pwam_idx)):
            dim = pwam_dims[i]
            # 768 is the hidden size of BERT-base, i.e. the width of the language stream.
            fusion = PWAM(768,            # dim: width of the language features
                          dim,            # v_in_channels: visual width at this stage
                          768,            # l_in_channels
                          768,            # key_channels
                          768,            # value_channels
                          num_heads=num_heads_fusion[i],
                          dropout=fusion_drop)
            self.pwams.append(fusion)

            # Gate on the fused residual; zero-initialised so fusion starts as an identity mapping.
            res_gate = nn.Sequential(
                nn.Linear(768, 768, bias=False),
                nn.ReLU(),
                nn.Linear(768, 768, bias=False),
                nn.Tanh()
            )
            nn.init.zeros_(res_gate[0].weight)
            nn.init.zeros_(res_gate[2].weight)
            self.res_gates.append(res_gate)

            self.norms.append(nn.LayerNorm(768))
|
    def forward_stem(self, input_ids, attention_mask):
        # Token embeddings plus the additive attention mask, computed once before the staged encoder passes.
        input_shape = input_ids.size()
        token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)

        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, input_ids.device)

        embedding_output = self.embeddings(
            input_ids=input_ids, token_type_ids=token_type_ids
        )

        return embedding_output, extended_attention_mask
|
    def forward_stage1(self, hidden_states, attention_mask):
        # Run the BERT layers of the first stage, i.e. layers [0, pwam_idx[0]).
        for i in range(0, self.pwam_idx[0]):
            layer_outputs = self.encoder.layer[i](
                hidden_states,
                attention_mask,
            )
            hidden_states = layer_outputs[0]

        return hidden_states
|
    def forward_stage2(self, hidden_states, attention_mask):
        for i in range(self.pwam_idx[0], self.pwam_idx[1]):
            layer_outputs = self.encoder.layer[i](
                hidden_states,
                attention_mask,
            )
            hidden_states = layer_outputs[0]

        return hidden_states
|
    def forward_stage3(self, hidden_states, attention_mask):
        for i in range(self.pwam_idx[1], self.pwam_idx[2]):
            layer_outputs = self.encoder.layer[i](
                hidden_states,
                attention_mask,
            )
            hidden_states = layer_outputs[0]

        return hidden_states
|
    def forward_stage4(self, hidden_states, attention_mask):
        for i in range(self.pwam_idx[2], self.pwam_idx[3]):
            layer_outputs = self.encoder.layer[i](
                hidden_states,
                attention_mask,
            )
            hidden_states = layer_outputs[0]

        return hidden_states
|
    def forward_pwam1(self, x, l, l_mask):
        # Fuse the stage-1 visual features x into the language stream l.
        # Returns (normalized fused residual, gated update of l).
        l_residual = self.pwams[0](x, l, l_mask)
        l = l + (self.res_gates[0](l_residual) * l_residual)
        return self.norms[0](l_residual), l

    def forward_pwam2(self, x, l, l_mask):
        l_residual = self.pwams[1](x, l, l_mask)
        l = l + (self.res_gates[1](l_residual) * l_residual)
        return self.norms[1](l_residual), l

    def forward_pwam3(self, x, l, l_mask):
        l_residual = self.pwams[2](x, l, l_mask)
        l = l + (self.res_gates[2](l_residual) * l_residual)
        return self.norms[2](l_residual), l

    def forward_pwam4(self, x, l, l_mask):
        l_residual = self.pwams[3](x, l, l_mask)
        l = l + (self.res_gates[3](l_residual) * l_residual)
        return self.norms[3](l_residual), l
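

# A minimal sketch of how the staged methods above are meant to interleave with a visual
# backbone.  `vis_feats` and the mask reshape are assumptions, not part of this file:
# vis_feats[i] is the stage-i visual feature map flattened to (B, H_i*W_i, embed_dim * 2**i),
# and the 0/1 padding mask is expanded so that it broadcasts inside SpatialImageLanguageAttention.
def _example_staged_forward(bert, vis_feats, input_ids, attention_mask):
    l, ext_mask = bert.forward_stem(input_ids, attention_mask)
    l_mask = attention_mask.unsqueeze(1).unsqueeze(1).float()  # (B, 1, 1, n_l), assumed layout

    l = bert.forward_stage1(l, ext_mask)
    res1, l = bert.forward_pwam1(vis_feats[0], l, l_mask)

    l = bert.forward_stage2(l, ext_mask)
    res2, l = bert.forward_pwam2(vis_feats[1], l, l_mask)

    l = bert.forward_stage3(l, ext_mask)
    res3, l = bert.forward_pwam3(vis_feats[2], l, l_mask)

    l = bert.forward_stage4(l, ext_mask)
    res4, l = bert.forward_pwam4(vis_feats[3], l, l_mask)

    # res1..res4 are the LayerNorm-ed fused residuals; l is the final gated language stream.
    return (res1, res2, res3, res4), l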
|
|
|
class PWAM(nn.Module):
    def __init__(self, dim, v_in_channels, l_in_channels, key_channels, value_channels, num_heads=0, dropout=0.0):
        super(PWAM, self).__init__()

        # Despite its name, vis_project acts on the language features l (B, n_l, dim);
        # the visual features enter only through the cross-attention below.
        self.vis_project = nn.Sequential(nn.Linear(dim, dim),
                                         nn.GELU(),
                                         nn.Dropout(dropout)
                                         )

        # Language-to-vision cross-attention: queries come from l, keys/values from x.
        self.image_lang_att = SpatialImageLanguageAttention(v_in_channels,
                                                            l_in_channels,
                                                            key_channels,
                                                            value_channels,
                                                            out_channels=value_channels,
                                                            num_heads=num_heads)

        self.project_mm = nn.Sequential(nn.Conv1d(value_channels, value_channels, 1, 1),
                                        nn.GELU(),
                                        nn.Dropout(dropout)
                                        )

    def forward(self, x, l, l_mask):
        # x: (B, H*W, v_in_channels) flattened visual features
        # l: (B, n_l, dim) language features
        # l_mask: 0/1 padding mask, broadcastable to (B, 1, n_l) after squeeze(1)
        vis = self.vis_project(l)                 # (B, n_l, dim)

        lang = self.image_lang_att(x, l, l_mask)  # (B, n_l, value_channels)
        lang = lang.permute(0, 2, 1)              # (B, value_channels, n_l)

        # Element-wise gating of the attended visual context with the projected language features.
        mm = torch.mul(vis.permute(0, 2, 1), lang)
        mm = self.project_mm(mm)                  # (B, value_channels, n_l)
        mm = mm.permute(0, 2, 1)                  # (B, n_l, value_channels)

        return mm
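

# Illustrative standalone use of PWAM (sizes are arbitrary; kept as comments so importing
# this module stays side-effect free; the mask layout is an assumption):
#   pwam = PWAM(768, 96, 768, 768, 768, num_heads=1)
#   x = torch.randn(2, 49, 96)         # flattened visual features (B, H*W, v_in_channels)
#   l = torch.randn(2, 20, 768)        # language features (B, n_l, 768)
#   l_mask = torch.ones(2, 1, 1, 20)   # 1 = real token, 0 = padding
#   pwam(x, l, l_mask).shape           # torch.Size([2, 20, 768])

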
class SpatialImageLanguageAttention(nn.Module):
    # Cross-attention from language queries to flattened image positions:
    #   x input shape: (B, H*W, v_in_channels)
    #   l input shape: (B, n_l, l_in_channels)
    #   output shape:  (B, n_l, out_channels)
    def __init__(self, v_in_channels, l_in_channels, key_channels, value_channels, out_channels=None, num_heads=1):
        super(SpatialImageLanguageAttention, self).__init__()

        self.v_in_channels = v_in_channels
        self.l_in_channels = l_in_channels
        self.out_channels = out_channels
        self.key_channels = key_channels
        self.value_channels = value_channels
        self.num_heads = num_heads
        if out_channels is None:
            self.out_channels = self.value_channels

        # Queries are projected from the language features.
        self.f_query = nn.Sequential(
            nn.Conv1d(self.l_in_channels, self.key_channels, kernel_size=1, stride=1),
        )

        # Keys are projected from the visual features.
        self.f_key = nn.Sequential(
            nn.Conv1d(self.v_in_channels, self.key_channels, kernel_size=1, stride=1),
            nn.InstanceNorm1d(self.key_channels),
        )

        # Values are projected from the visual features.
        self.f_value = nn.Sequential(
            nn.Conv1d(self.v_in_channels, self.value_channels, kernel_size=1, stride=1),
            nn.InstanceNorm1d(self.value_channels),
        )

        # Output projection applied to the aggregated values.
        self.W = nn.Sequential(
            nn.Conv1d(self.value_channels, self.out_channels, kernel_size=1, stride=1),
            nn.InstanceNorm1d(self.out_channels),
        )
|
    def forward(self, x, l, l_mask):
        # x: (B, H*W, v_in_channels), l: (B, n_l, l_in_channels)
        # l_mask: 0/1 padding mask; squeezing dim 1 is assumed to leave a shape
        # broadcastable to (B, 1, n_l).
        l_mask = l_mask.squeeze(1)

        B, HW = x.size(0), x.size(1)
        x = x.permute(0, 2, 1)  # (B, v_in_channels, H*W)
        l = l.permute(0, 2, 1)  # (B, l_in_channels, n_l)

        query = self.f_query(l)          # (B, key_channels, n_l)
        query = query * l_mask           # zero out padded language positions
        query = query.permute(0, 2, 1)   # (B, n_l, key_channels)

        key = self.f_key(x)              # (B, key_channels, H*W)
        value = self.f_value(x)          # (B, value_channels, H*W)

        n_l = query.size(1)

        # Split channels across heads.
        key = key.reshape(B, self.num_heads, self.key_channels // self.num_heads, HW)
        value = value.reshape(B, self.num_heads, self.value_channels // self.num_heads, HW)
        query = query.reshape(B, n_l, self.num_heads, self.key_channels // self.num_heads).permute(0, 2, 1, 3)
        # query: (B, num_heads, n_l, key_channels // num_heads)

        l_mask = l_mask.unsqueeze(-1)    # (B, 1, n_l, 1)

        sim_map = torch.matmul(query, key)              # (B, num_heads, n_l, H*W)
        sim_map = (self.key_channels ** -.5) * sim_map  # scaled dot product

        # Rows belonging to padded language tokens get a large negative bias (-1e4).
        sim_map = sim_map + (1e4 * l_mask - 1e4)
        sim_map = F.softmax(sim_map, dim=-1)            # attention over image positions
        out = torch.matmul(sim_map, value.permute(0, 1, 3, 2))  # (B, num_heads, n_l, value_channels // num_heads)

        out = out.permute(0, 2, 1, 3).contiguous().reshape(B, n_l, self.value_channels)
        out = out.permute(0, 2, 1)  # (B, value_channels, n_l)
        out = self.W(out)           # (B, out_channels, n_l)
        out = out.permute(0, 2, 1)  # (B, n_l, out_channels)

        return out
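

if __name__ == "__main__":
    # Minimal smoke test for the cross-attention block alone; MultiModalBert itself needs a
    # BertConfig plus the local modeling_bert weights, so it is not exercised here.  Because
    # of the relative import at the top, run this file as a module inside its package.
    B, HW, n_l, c_v = 2, 49, 20, 96
    att = SpatialImageLanguageAttention(c_v, 768, 768, 768, out_channels=768, num_heads=1)
    x = torch.randn(B, HW, c_v)        # flattened visual features
    l = torch.randn(B, n_l, 768)       # language features
    l_mask = torch.ones(B, 1, 1, n_l)  # 1 = real token, 0 = padding (assumed layout)
    out = att(x, l, l_mask)
    print(out.shape)                   # expected: torch.Size([2, 20, 768])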
|
|