# Spaces: Sleeping (page-status text captured together with the code; not part of the module)
import torch | |
from torch import nn | |
class MLP(nn.Module):
    """Feed-forward network made of Linear/BatchNorm1d/ReLU/Dropout stages.

    Every entry of ``hidden_dims`` contributes one
    ``Linear -> BatchNorm1d -> ReLU -> Dropout`` stage, and a final ``Linear``
    projects to ``output_dim``. With an empty ``hidden_dims`` the network is
    a single linear layer.
    """

    def __init__(self, input_dim, hidden_dims, output_dim, dropout):
        """Build the layer stack.

        Args:
            input_dim: Size of the last dimension of the input.
            hidden_dims: Iterable of hidden layer widths, in order.
            output_dim: Size of the last dimension of the output.
            dropout: Drop probability used after every hidden activation.
        """
        super().__init__()
        layers = []
        curr_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(curr_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(p=dropout),
            ])
            # The next stage consumes what this stage produced.
            curr_dim = hidden_dim
        layers.append(nn.Linear(curr_dim, output_dim))
        self.mlp = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the stacked layers and return the result."""
        return self.mlp(input)
class MaskAvg(nn.Module):
    """Average a sequence over the positions its mask keeps."""

    def __init__(self):
        super().__init__()

    def forward(self, input, mask):
        """Return the mean of ``input`` over unmasked positions.

        Args:
            input: Tensor indexed as (batch, length, features); the batched
                matmul below relies on that layout.
            mask: Tensor broadcastable to (batch, length). Entries equal to 0
                mark positions that are excluded from the average.

        Returns:
            Tensor of shape (batch, features). A row whose mask is entirely
            zero yields a zero vector (a softmax over all ``-inf`` scores
            would produce NaN there).
        """
        keep = (mask != 0).expand(input.shape[0], input.shape[1])
        # Build the weights in the input's dtype/device so the matmul below
        # also works for float64 / half inputs.
        weights = keep.to(dtype=input.dtype, device=input.device)
        # Uniform weights over kept positions; clamp avoids 0/0 for rows
        # that keep nothing.
        kept_count = weights.sum(dim=-1, keepdim=True).clamp(min=1)
        weights = (weights / kept_count).unsqueeze(1)
        return torch.matmul(weights, input).squeeze(1)
class CVRL(nn.Module):
    """Caption-guided attention pooling over per-frame object features.

    Each frame's caption is encoded with a bidirectional GRU and max-pooled
    over its tokens; the pooled caption vector and a linear score of every
    object feature are combined into attention weights that pool the
    frame's object features into a single vector.
    """

    def __init__(self, d_w, d_f, obj_num, gru_dim):
        """Create the caption encoder and the two scoring layers.

        Args:
            d_w: Feature size of each caption token.
            d_f: Feature size of each visual object.
            obj_num: Number of objects per frame (width of the attention).
            gru_dim: Hidden size of each GRU direction.
        """
        super().__init__()
        self.gru = nn.GRU(d_w, gru_dim, batch_first=True, bidirectional=True)
        self.linear_r = nn.Linear(d_f, 1)
        self.linear_h = nn.Linear(2 * gru_dim, obj_num)

    def forward(self, caption_feature, visual_feature):
        """Pool object features per frame using caption-driven attention.

        Args:
            caption_feature: (bs, K, S, d_w)
            visual_feature: (bs, K, obj_num, d_f)

        Returns:
            frame_visual_rep: (bs, K, d_f)
        """
        num_frames, seq_len, word_dim = caption_feature.shape[-3:]

        # Fold frames into the batch axis so the GRU sees one caption per row.
        flat_caption = caption_feature.reshape(-1, seq_len, word_dim)
        encoded_caption, _ = self.gru(flat_caption)  # (bs*K, S, 2*gru_dim)
        encoded_caption = encoded_caption.reshape(
            -1, num_frames, seq_len, encoded_caption.shape[-1]
        )  # (bs, K, S, 2*gru_dim)
        frame_caption_rep = encoded_caption.max(dim=2).values  # (bs, K, 2*gru_dim)

        # Squeeze only the trailing singleton axis: a bare squeeze() would
        # also drop bs, K or obj_num whenever one of them equals 1.
        object_score = self.linear_r(visual_feature).squeeze(-1)  # (bs, K, obj_num)
        alpha = object_score + self.linear_h(frame_caption_rep)  # (bs, K, obj_num)
        alpha = torch.softmax(torch.tanh(alpha), dim=-1).unsqueeze(dim=-2)  # (bs, K, 1, obj_num)

        frame_visual_rep = alpha.matmul(visual_feature)  # (bs, K, 1, d_f)
        return frame_visual_rep.squeeze(-2)  # (bs, K, d_f)
class ASRL(nn.Module):
    """Bidirectional GRU encoder for the speech (ASR) feature sequence."""

    def __init__(self, d_w, gru_dim):
        """Create the encoder.

        Args:
            d_w: Feature size of each input step.
            gru_dim: Hidden size of each GRU direction.
        """
        super().__init__()
        self.gru = nn.GRU(d_w, gru_dim, batch_first=True, bidirectional=True)

    def forward(self, asr_feature):
        """Encode the sequence and return the per-step GRU outputs.

        Args:
            asr_feature: (bs, N, d_w)

        Returns:
            text_audio_rep: (bs, N, 2*gru_dim)
        """
        # The final hidden state is not used; only per-step outputs are returned.
        text_audio_rep, _ = self.gru(asr_feature)
        return text_audio_rep
class VCIF(nn.Module):
    """Co-attention fusion of frame-level visual and speech representations.

    An affinity matrix between frames and speech steps drives two attention
    maps, one per modality. Each modality is re-weighted by its attention,
    encoded with its own GRU, averaged over the positions kept by its mask,
    and the two pooled vectors are concatenated.
    """

    def __init__(self, d_f, d_w, d_H, gru_f_dim, gru_w_dim, dropout):
        """Create the co-attention parameters, the two GRUs and the pooling.

        Args:
            d_f: Feature size of the frame representation.
            d_w: Feature size of the speech representation.
            d_H: Width of the co-attention maps.
            gru_f_dim: Hidden size of the frame GRU.
            gru_w_dim: Hidden size of the speech GRU.
            dropout: Drop probability applied to the affinity matrix and to
                both co-attention maps.
        """
        super().__init__()
        self.param_D = nn.Parameter(torch.empty((d_f, d_w)))
        self.param_Df = nn.Parameter(torch.empty((d_f, d_H)))
        self.param_Dw = nn.Parameter(torch.empty((d_w, d_H)))
        self.param_df = nn.Parameter(torch.empty(d_H))
        self.param_dw = nn.Parameter(torch.empty(d_H))
        self.gru_f = nn.GRU(d_f, gru_f_dim, batch_first=True)
        self.gru_w = nn.GRU(d_w, gru_w_dim, batch_first=True)
        self.mask_avg = MaskAvg()
        self.dropout = nn.Dropout(p=dropout)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise matrices with Xavier-uniform and vectors with U(0, 1)."""
        for matrix in (self.param_D, self.param_Df, self.param_Dw):
            nn.init.xavier_uniform_(matrix)
        for vector in (self.param_df, self.param_dw):
            nn.init.uniform_(vector)

    def forward(self, frame_visual_rep, text_audio_rep, mask_K, mask_N):
        """Fuse the two modalities into one video-level vector.

        Args:
            frame_visual_rep: (bs, K, d_f)
            text_audio_rep: (bs, N, d_w)
            mask_K: Mask passed to the masked average of the frame GRU output.
            mask_N: Mask passed to the masked average of the speech GRU output.

        Returns:
            video_rep: (bs, gru_f_dim + gru_w_dim)
        """
        # Frame-vs-speech affinity: (bs, K, N).
        speech_t = text_audio_rep.transpose(-1, -2)
        affinity_matrix = torch.tanh(
            frame_visual_rep.matmul(self.param_D).matmul(speech_t)
        )
        affinity_matrix = self.dropout(affinity_matrix)

        # Each map mixes a modality's own projection with the other
        # modality's projection routed through the affinity matrix.
        frame_co_att_map = torch.tanh(
            frame_visual_rep.matmul(self.param_Df)
            + affinity_matrix.matmul(text_audio_rep).matmul(self.param_Dw)
        )
        word_co_att_map = torch.tanh(
            text_audio_rep.matmul(self.param_Dw)
            + affinity_matrix.transpose(-1, -2).matmul(frame_visual_rep).matmul(self.param_Df)
        )
        frame_co_att_map = self.dropout(frame_co_att_map)
        word_co_att_map = self.dropout(word_co_att_map)

        # NOTE(review): these softmaxes run over every position; mask_K and
        # mask_N are only applied in the final averaging below.
        frame_att_weight = torch.softmax(frame_co_att_map.matmul(self.param_df), dim=-1)
        word_att_weight = torch.softmax(word_co_att_map.matmul(self.param_dw), dim=-1)

        frame_visual_weighted_rep = frame_att_weight.unsqueeze(dim=-1) * frame_visual_rep
        text_audio_weighted_rep = word_att_weight.unsqueeze(dim=-1) * text_audio_rep

        encoded_visual_rep, _ = self.gru_f(frame_visual_weighted_rep)
        encoded_speech_rep, _ = self.gru_w(text_audio_weighted_rep)

        visual_rep = self.mask_avg(encoded_visual_rep, mask_K)  # (bs, gru_f_dim)
        speech_rep = self.mask_avg(encoded_speech_rep, mask_N)  # (bs, gru_w_dim)
        return torch.cat([visual_rep, speech_rep], dim=-1)
class TikTecModel(nn.Module):
    """End-to-end model: CVRL + ASRL encoders, VCIF fusion, MLP head.

    The head produces two output values per sample.
    """

    def __init__(
        self,
        word_dim=300,
        mfcc_dim=650,
        visual_dim=1000,
        obj_num=45,
        CVRL_gru_dim=200,
        ASRL_gru_dim=500,
        VCIF_d_H=200,
        VCIF_gru_f_dim=200,
        VCIF_gru_w_dim=100,
        VCIF_dropout=0.2,
        MLP_hidden_dims=(512,),
        MLP_dropout=0.2,
    ):
        """Wire the sub-modules together.

        Args:
            word_dim: Feature size of caption/ASR word embeddings.
            mfcc_dim: Feature size of the audio features concatenated to the
                ASR word embeddings.
            visual_dim: Feature size of each visual object.
            obj_num: Number of objects per frame.
            CVRL_gru_dim: Per-direction hidden size of the caption GRU.
            ASRL_gru_dim: Per-direction hidden size of the speech GRU.
            VCIF_d_H: Width of the co-attention maps.
            VCIF_gru_f_dim: Hidden size of the fused frame GRU.
            VCIF_gru_w_dim: Hidden size of the fused speech GRU.
            VCIF_dropout: Dropout probability inside the fusion module.
            MLP_hidden_dims: Hidden layer widths of the head. The default is
                an immutable tuple so it cannot be shared and mutated across
                instances.
            MLP_dropout: Dropout probability inside the head.
        """
        super().__init__()
        self.CVRL = CVRL(d_w=word_dim, d_f=visual_dim, obj_num=obj_num, gru_dim=CVRL_gru_dim)
        self.ASRL = ASRL(d_w=(word_dim + mfcc_dim), gru_dim=ASRL_gru_dim)
        # The speech encoder is bidirectional, hence the doubled width here.
        self.VCIF = VCIF(
            d_f=visual_dim,
            d_w=2 * ASRL_gru_dim,
            d_H=VCIF_d_H,
            gru_f_dim=VCIF_gru_f_dim,
            gru_w_dim=VCIF_gru_w_dim,
            dropout=VCIF_dropout,
        )
        self.MLP = MLP(VCIF_gru_f_dim + VCIF_gru_w_dim, MLP_hidden_dims, 2, MLP_dropout)

    def forward(self, **kwargs):
        """Compute the two-way output for a batch.

        Keyword Args (shapes with the default configuration):
            caption_feature: (bs, K, S, word_dim) = (bs, 200, 100, 300)
            visual_feature: (bs, K, obj_num, visual_dim) = (bs, 200, 45, 1000)
            asr_feature: (bs, N, word_dim + mfcc_dim) = (bs, 500, 300 + 650)
            mask_K: (bs, K) = (bs, 200)
            mask_N: (bs, N) = (bs, 500)

        Returns:
            Tensor of shape (bs, 2).

        Raises:
            KeyError: If any of the keyword arguments above is missing.
        """
        caption_feature = kwargs['caption_feature']
        visual_feature = kwargs['visual_feature']
        asr_feature = kwargs['asr_feature']
        mask_K = kwargs['mask_K']
        mask_N = kwargs['mask_N']

        frame_visual_rep = self.CVRL(caption_feature, visual_feature)
        text_audio_rep = self.ASRL(asr_feature)
        video_rep = self.VCIF(frame_visual_rep, text_audio_rep, mask_K, mask_N)
        return self.MLP(video_rep)