ybbwcwaps
some FakeVD
711b041
raw
history blame
No virus
5.87 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BertModel
from .layers import Attention
class bBbox(torch.nn.Module):
def __init__(self,fea_dim):
super(bBbox, self).__init__()
self.img_dim = 4096
self.attention1 = Attention(dim=128,heads=4)
self.attention2 = Attention(dim=128,heads=4)
self.linear_img = nn.Sequential(torch.nn.Linear(self.img_dim, fea_dim),torch.nn.ReLU())
self.classifier = nn.Linear(fea_dim,2)
def forward(self, **kwargs):
frames=kwargs['bbox_vgg']
fea_img = self.linear_img(frames)
fea_img = torch.reshape(fea_img, (-1, 45, 128))
fea_img = self.attention1(fea_img)
fea_img = torch.mean(fea_img, -2)
fea_img = torch.reshape(fea_img, (-1, 83, 128))
fea_img = self.attention2(fea_img)
fea_img = torch.mean(fea_img, -2)
output = self.classifier(fea_img)
return output, fea_img
class bC3D(torch.nn.Module):
def __init__(self,fea_dim):
super(bC3D, self).__init__()
# self.video_dim = 4096
self.video_dim = 2048
self.attention = Attention(dim=128,heads=4)
self.linear_video = nn.Sequential(torch.nn.Linear(self.video_dim, fea_dim),torch.nn.ReLU())
self.classifier = nn.Linear(fea_dim,2)
def forward(self, **kwargs):
c3d = kwargs['c3d']
fea_video = self.linear_video(c3d)
fea_video = self.attention(fea_video)
fea_video = torch.mean(fea_video, -2)
output = self.classifier(fea_video)
return output
class bVGG(torch.nn.Module):
def __init__(self,fea_dim):
super(bVGG, self).__init__()
# self.img_dim = 4096
self.img_dim = 2048
self.attention = Attention(dim=128,heads=4)
self.linear_img = nn.Sequential(torch.nn.Linear(self.img_dim, fea_dim),torch.nn.ReLU())
self.classifier = nn.Linear(fea_dim,2)
def forward(self, **kwargs):
frames=kwargs['frames']
fea_img = self.linear_img(frames)
fea_img = self.attention(fea_img)
fea_img = torch.mean(fea_img, -2)
output = self.classifier(fea_img)
return output
class bVggish(torch.nn.Module):
def __init__(self,fea_dim):
super(bVggish, self).__init__()
# self.audio_dim = 128
self.attention = Attention(dim=128,heads=4)
self.vggish_layer = torch.hub.load('./torchvggish/', 'vggish', source = 'local')
net_structure = list(self.vggish_layer.children())
self.vggish_modified = nn.Sequential(*net_structure[-2:-1])
self.classifier = nn.Linear(fea_dim,2)
def forward(self, **kwargs):
audioframes=kwargs['audioframes']
fea_audio = self.vggish_modified(audioframes)
fea_audio = self.attention(fea_audio)
fea_audio = torch.mean(fea_audio, -2)
print (fea_audio.shape)
output = self.classifier(fea_audio)
return output, fea_audio
class bBert(torch.nn.Module):
def __init__(self,bert_model,fea_dim, dropout):
super(bBert, self).__init__()
self.text_dim = 768
self.bert = BertModel.from_pretrained(bert_model).requires_grad_(False)
self.linear_text = nn.Sequential(torch.nn.Linear(self.text_dim, fea_dim),torch.nn.ReLU())
self.classifier = nn.Linear(fea_dim,2)
def forward(self, **kwargs):
title_inputid = kwargs['title_inputid']
title_mask=kwargs['title_mask']
fea_text=self.bert(title_inputid,attention_mask=title_mask)[1]
fea_text=self.linear_text(fea_text)
output = self.classifier(fea_text)
return output,fea_text
class bTextCNN(nn.Module):
def __init__(self, fea_dim, vocab_size):
super(bTextCNN, self).__init__()
self.vocab_size = vocab_size
self.fea_dim=fea_dim
self.channel_in = 1
self.filter_num = 14
self.window_size = [3,4,5]
self.textcnn =nn.ModuleList([nn.Conv2d(self.channel_in, self.filter_num, (K,self.vocab_size)) for K in self.window_size])
self.linear = nn.Sequential(torch.nn.Linear(len(self.window_size) * self.filter_num, self.fea_dim),torch.nn.ReLU())
self.classifier = nn.Linear(self.fea_dim,2)
def forward(self, **kwargs):
title_w2v = kwargs['title_w2v']
text = title_w2v.unsqueeze(1)
text = [F.relu(conv(text)).squeeze(3) for conv in self.textcnn]
text = [F.max_pool1d(i.squeeze(2), i.shape[-1]).squeeze(2) for i in text]
fea_text = torch.cat(text, 1)
fea_text = self.linear(fea_text)
output = self.classifier(fea_text)
return output
class bComments(torch.nn.Module):
def __init__(self,bert_model,fea_dim):
super(bComments, self).__init__()
self.comment_dim = 768
self.bert = BertModel.from_pretrained(bert_model).requires_grad_(False)
self.attention = Attention(dim=128,heads=4)
self.linear_comment = nn.Sequential(torch.nn.Linear(self.comment_dim, fea_dim),torch.nn.ReLU())
self.classifier = nn.Linear(fea_dim,2)
def forward(self, **kwargs):
comments_inputid = kwargs['comments_inputid']
comments_mask=kwargs['comments_mask']
comments_feature=[]
for i in range(comments_inputid.shape[0]):
bert_fea=self.bert(comments_inputid[i], attention_mask=comments_mask[i])[1]
comments_feature.append(bert_fea)
comments_feature=torch.stack(comments_feature)
fea_comments=self.linear_comment(comments_feature)
print (fea_comments.shape)
fea_comments = self.attention(fea_comments)
fea_comments = torch.mean(fea_comments, -2)
output = self.classifier(fea_comments)
return output