import torch
from torch import nn


def _masked_softmax(scores, mask):
    """Softmax over the last dim of ``scores``, ignoring positions where ``mask == 0``.

    Masked positions are filled with the most negative finite value of the
    dtype (not ``-inf``), so a row whose mask is entirely zero yields a uniform
    distribution instead of NaN.
    """
    fill_value = torch.finfo(scores.dtype).min
    return torch.softmax(scores.masked_fill(mask == 0, fill_value), dim=-1)


class MLP(nn.Module):
    """Feed-forward head: [Linear -> BatchNorm1d -> ReLU -> Dropout] per hidden width, then Linear.

    Args:
        input_dim: size of the input feature vector.
        hidden_dims: iterable of hidden-layer widths (may be empty).
        output_dim: size of the output vector.
        dropout: dropout probability applied after each hidden layer.
    """

    def __init__(self, input_dim, hidden_dims, output_dim, dropout):
        super().__init__()
        layers = []
        curr_dim = input_dim
        for hidden_dim in hidden_dims:
            # Same layer order as before, so state_dict keys (mlp.0, mlp.1, ...) are unchanged.
            layers.extend([
                nn.Linear(curr_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(p=dropout),
            ])
            curr_dim = hidden_dim
        layers.append(nn.Linear(curr_dim, output_dim))
        self.mlp = nn.Sequential(*layers)

    def forward(self, input):
        return self.mlp(input)


class MaskAvg(nn.Module):
    """Average ``input`` (bs, L, D) over dim 1, counting only positions where ``mask != 0``.

    Returns a (bs, D) tensor. A row whose mask is all zero returns zeros
    (previously the all ``-inf`` softmax produced NaN).
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, mask):
        weights = (mask != 0).to(input.dtype)  # (bs, L): 1 for valid, 0 for padding
        counts = weights.sum(dim=-1, keepdim=True).clamp(min=1.0)  # avoid 0/0 on empty rows
        weights = (weights / counts).unsqueeze(1)  # (bs, 1, L)
        return torch.matmul(weights, input).squeeze(1)  # (bs, D)


class CVRL(nn.Module):
    """Caption-guided visual representation: attends over per-frame object features.

    IN:  caption_feature (bs, K, S, d_w), visual_feature (bs, K, obj_num, d_f)
    OUT: frame_visual_rep (bs, K, d_f)
    """

    def __init__(self, d_w, d_f, obj_num, gru_dim):
        super().__init__()
        self.gru = nn.GRU(d_w, gru_dim, batch_first=True, bidirectional=True)
        self.linear_r = nn.Linear(d_f, 1)
        self.linear_h = nn.Linear(2 * gru_dim, obj_num)

    def forward(self, caption_feature, visual_feature):
        num_frames, seq_len, word_dim = caption_feature.shape[-3:]
        # Encode each frame's caption as an independent sequence; reshape (not
        # view) also accepts non-contiguous inputs.
        encoded_caption, _ = self.gru(caption_feature.reshape(-1, seq_len, word_dim))  # (bs*K, S, 2*gru_dim)
        encoded_caption = encoded_caption.reshape(
            -1, num_frames, seq_len, encoded_caption.shape[-1]
        )  # (bs, K, S, 2*gru_dim)
        # NOTE(review): max-pooling spans all S tokens, including any caption
        # padding; no token mask is available here to exclude it.
        frame_caption_rep = encoded_caption.max(dim=2).values  # (bs, K, 2*gru_dim)
        # Squeeze only linear_r's size-1 output dim. A bare squeeze() would also
        # drop bs, K or obj_num whenever one of them equals 1.
        alpha = self.linear_r(visual_feature).squeeze(-1) + self.linear_h(frame_caption_rep)  # (bs, K, obj_num)
        alpha = torch.softmax(torch.tanh(alpha), dim=-1).unsqueeze(-2)  # (bs, K, 1, obj_num)
        return alpha.matmul(visual_feature).squeeze(-2)  # (bs, K, d_f)


class ASRL(nn.Module):
    """Bidirectional GRU encoder over ASR token features.

    IN:  asr_feature (bs, N, d_w)
    OUT: text_audio_rep (bs, N, 2*gru_dim)
    """

    def __init__(self, d_w, gru_dim):
        super().__init__()
        self.gru = nn.GRU(d_w, gru_dim, batch_first=True, bidirectional=True)

    def forward(self, asr_feature):
        # NOTE(review): sequences are not packed, so the backward direction also
        # reads trailing padding into valid positions; would need lengths.
        text_audio_rep, _ = self.gru(asr_feature)
        return text_audio_rep


class VCIF(nn.Module):
    """Co-attention fusion of frame-visual and text-audio sequences.

    IN:  frame_visual_rep (bs, K, d_f), text_audio_rep (bs, N, d_w),
         mask_K (bs, K), mask_N (bs, N) with 0 marking padding.
    OUT: video_rep (bs, gru_f_dim + gru_w_dim)
    """

    def __init__(self, d_f, d_w, d_H, gru_f_dim, gru_w_dim, dropout):
        super().__init__()
        self.param_D = nn.Parameter(torch.empty((d_f, d_w)))
        self.param_Df = nn.Parameter(torch.empty((d_f, d_H)))
        self.param_Dw = nn.Parameter(torch.empty((d_w, d_H)))
        self.param_df = nn.Parameter(torch.empty(d_H))
        self.param_dw = nn.Parameter(torch.empty(d_H))
        self.gru_f = nn.GRU(d_f, gru_f_dim, batch_first=True)
        self.gru_w = nn.GRU(d_w, gru_w_dim, batch_first=True)
        self.mask_avg = MaskAvg()
        self.dropout = nn.Dropout(p=dropout)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.param_D)
        nn.init.xavier_uniform_(self.param_Df)
        nn.init.xavier_uniform_(self.param_Dw)
        nn.init.uniform_(self.param_df)
        nn.init.uniform_(self.param_dw)

    def forward(self, frame_visual_rep, text_audio_rep, mask_K, mask_N):
        valid_frames = mask_K != 0  # (bs, K)
        valid_words = mask_N != 0  # (bs, N)

        affinity_matrix = torch.tanh(
            frame_visual_rep.matmul(self.param_D).matmul(text_audio_rep.transpose(-1, -2))
        )  # (bs, K, N)
        # Zero every affinity involving a padded frame or word, so padding cannot
        # leak into the co-attention maps of real positions.
        pair_valid = valid_frames.unsqueeze(-1) & valid_words.unsqueeze(-2)
        affinity_matrix = affinity_matrix.masked_fill(~pair_valid, 0.0)
        affinity_matrix = self.dropout(affinity_matrix)

        frame_co_att_map = torch.tanh(
            frame_visual_rep.matmul(self.param_Df)
            + affinity_matrix.matmul(text_audio_rep).matmul(self.param_Dw)
        )  # (bs, K, d_H)
        word_co_att_map = torch.tanh(
            text_audio_rep.matmul(self.param_Dw)
            + affinity_matrix.transpose(-1, -2).matmul(frame_visual_rep).matmul(self.param_Df)
        )  # (bs, N, d_H)
        frame_co_att_map = self.dropout(frame_co_att_map)
        word_co_att_map = self.dropout(word_co_att_map)

        # Padded positions get zero attention weight, so their count does not
        # dilute the weights of real frames/words.
        frame_att_weight = _masked_softmax(frame_co_att_map.matmul(self.param_df), mask_K)  # (bs, K)
        word_att_weight = _masked_softmax(word_co_att_map.matmul(self.param_dw), mask_N)  # (bs, N)

        frame_visual_weighted_rep = frame_att_weight.unsqueeze(dim=-1) * frame_visual_rep
        text_audio_weighted_rep = word_att_weight.unsqueeze(dim=-1) * text_audio_rep
        encoded_visual_rep, _ = self.gru_f(frame_visual_weighted_rep)
        encoded_speech_rep, _ = self.gru_w(text_audio_weighted_rep)
        visual_rep = self.mask_avg(encoded_visual_rep, mask_K)  # (bs, gru_f_dim)
        speech_rep = self.mask_avg(encoded_speech_rep, mask_N)  # (bs, gru_w_dim)
        return torch.cat([visual_rep, speech_rep], dim=-1)


class TikTecModel(nn.Module):
    """Full model: CVRL (frames) + ASRL (speech) -> VCIF fusion -> MLP head with 2 outputs."""

    def __init__(self, word_dim=300, mfcc_dim=650, visual_dim=1000, obj_num=45, CVRL_gru_dim=200,
                 ASRL_gru_dim=500, VCIF_d_H=200, VCIF_gru_f_dim=200, VCIF_gru_w_dim=100,
                 VCIF_dropout=0.2, MLP_hidden_dims=(512,), MLP_dropout=0.2):
        super().__init__()
        self.CVRL = CVRL(d_w=word_dim, d_f=visual_dim, obj_num=obj_num, gru_dim=CVRL_gru_dim)
        self.ASRL = ASRL(d_w=(word_dim + mfcc_dim), gru_dim=ASRL_gru_dim)
        self.VCIF = VCIF(d_f=visual_dim, d_w=2 * ASRL_gru_dim, d_H=VCIF_d_H, gru_f_dim=VCIF_gru_f_dim,
                         gru_w_dim=VCIF_gru_w_dim, dropout=VCIF_dropout)
        self.MLP = MLP(VCIF_gru_f_dim + VCIF_gru_w_dim, MLP_hidden_dims, 2, MLP_dropout)

    def forward(self, **kwargs):
        """Run the model on keyword tensors; returns logits of shape (bs, 2).

        Required keys (shapes with the default sizes):
            caption_feature: (bs, K, S, word_dim)            = (bs, 200, 100, 300)
            visual_feature:  (bs, K, obj_num, visual_dim)    = (bs, 200, 45, 1000)
            asr_feature:     (bs, N, word_dim + mfcc_dim)    = (bs, 500, 300 + 650)
            mask_K:          (bs, K)                         = (bs, 200)
            mask_N:          (bs, N)                         = (bs, 500)

        Raises:
            KeyError: if any required key is missing.
        """
        caption_feature = kwargs['caption_feature']
        visual_feature = kwargs['visual_feature']
        asr_feature = kwargs['asr_feature']
        mask_K = kwargs['mask_K']
        mask_N = kwargs['mask_N']
        frame_visual_rep = self.CVRL(caption_feature, visual_feature)
        text_audio_rep = self.ASRL(asr_feature)
        video_rep = self.VCIF(frame_visual_rep, text_audio_rep, mask_K, mask_N)
        return self.MLP(video_rep)