video2music / model /video_regression.py
kjysmu's picture
add files
4e46a55
raw
history blame contribute delete
No virus
2.64 kB
import torch
import torch.nn as nn
from torch.nn.modules.normalization import LayerNorm
import random
import numpy as np
from utilities.constants import *
from utilities.device import get_device
from datetime import datetime
import torch.nn.functional as F
class VideoRegression(nn.Module):
def __init__(self, n_layers=2, d_model=64, dropout=0.1, max_sequence_video=300, total_vf_dim = 0, regModel="bilstm"):
super(VideoRegression, self).__init__()
self.nlayers = n_layers
self.d_model = d_model
self.dropout = dropout
self.max_seq_video = max_sequence_video
self.total_vf_dim = total_vf_dim
self.regModel = regModel
self.bilstm = nn.LSTM(self.total_vf_dim, self.d_model, self.nlayers, bidirectional=True)
self.bigru = nn.GRU(self.total_vf_dim, self.d_model, self.nlayers, bidirectional=True)
self.bifc = nn.Linear(self.d_model * 2, 2)
self.lstm = nn.LSTM(self.total_vf_dim, self.d_model, self.nlayers)
self.gru = nn.GRU(self.total_vf_dim, self.d_model, self.nlayers)
self.fc = nn.Linear(self.d_model, 2)
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def forward(self, feature_semantic_list, feature_scene_offset, feature_motion, feature_emotion):
### Video (SemanticList + SceneOffset + Motion + Emotion) (ENCODER) ###
vf_concat = feature_semantic_list[0].float()
for i in range(1, len(feature_semantic_list)):
vf_concat = torch.cat( (vf_concat, feature_semantic_list[i].float()), dim=2)
vf_concat = torch.cat([vf_concat, feature_scene_offset.unsqueeze(-1).float()], dim=-1)
vf_concat = torch.cat([vf_concat, feature_motion.unsqueeze(-1).float()], dim=-1)
vf_concat = torch.cat([vf_concat, feature_emotion.float()], dim=-1)
vf_concat = vf_concat.permute(1,0,2)
vf_concat = F.dropout(vf_concat, p=self.dropout, training=self.training)
if self.regModel == "bilstm":
out, _ = self.bilstm(vf_concat)
out = out.permute(1,0,2)
out = self.bifc(out)
elif self.regModel == "bigru":
out, _ = self.bigru(vf_concat)
out = out.permute(1,0,2)
out = self.bifc(out)
elif self.regModel == "lstm":
out, _ = self.lstm(vf_concat)
out = out.permute(1,0,2)
out = self.fc(out)
elif self.regModel == "gru":
out, _ = self.gru(vf_concat)
out = out.permute(1,0,2)
out = self.fc(out)
return out