import collections
import json
import os
import time

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# from gensim.models import KeyedVectors
from FakeVD.code_test.models.Baselines import *
from FakeVD.code_test.models.FANVM import FANVMModel
from FakeVD.code_test.models.SVFEND import SVFENDModel
from FakeVD.code_test.models.TikTec import TikTecModel
from FakeVD.code_test.models.Trainer import Trainer
from FakeVD.code_test.models.Trainer_3set import Trainer3
from FakeVD.code_test.utils.dataloader import *


def pad_sequence(seq_len, lst, emb):
    """Pad or truncate each per-video sequence in ``lst`` to ``seq_len`` rows.

    Args:
        seq_len: Target number of rows per video.
        lst: List of per-video sequences (tensors, or lists of tensors that
            are stacked first).
        emb: Embedding width, used for the all-zero placeholder when a video
            is empty. ``emb == 200`` selects float output, anything else long
            output (matches the callers: 250-wide token ids -> long).

    Returns:
        A single tensor of shape ``(len(lst), seq_len, emb)``.
    """
    result = []
    for video in lst:
        if isinstance(video, list):
            video = torch.stack(video)
        ori_len = video.shape[0]
        if ori_len == 0:
            # Empty video: substitute an all-zero block.
            video = torch.zeros([seq_len, emb], dtype=torch.long)
        elif ori_len >= seq_len:
            # Too long: keep the first seq_len rows.
            if emb == 200:
                video = torch.FloatTensor(video[:seq_len])
            else:
                video = torch.LongTensor(video[:seq_len])
        else:
            # Too short: zero-pad at the end.
            video = torch.cat(
                [video, torch.zeros([seq_len - ori_len, video.shape[1]], dtype=torch.long)],
                dim=0,
            )
            if emb == 200:
                video = torch.FloatTensor(video)
            else:
                video = torch.LongTensor(video)
        result.append(video)
    return torch.stack(result)


def pad_sequence_bbox(seq_len, lst):
    """Pad/truncate per-video bbox features to ``seq_len`` frames.

    Each frame carries 45 boxes of 4096-d VGG features; empty or short videos
    are zero-padded. Returns a float tensor of shape
    ``(len(lst), seq_len, 45, 4096)``.
    """
    result = []
    for video in lst:
        if isinstance(video, list):
            video = torch.stack(video)
        ori_len = video.shape[0]
        if ori_len == 0:
            video = torch.zeros([seq_len, 45, 4096], dtype=torch.float)
        elif ori_len >= seq_len:
            video = torch.FloatTensor(video[:seq_len])
        else:
            video = torch.cat(
                [video, torch.zeros([seq_len - ori_len, 45, 4096], dtype=torch.float)],
                dim=0,
            )
        result.append(video)
    return torch.stack(result)


def pad_frame_sequence(seq_len, lst):
    """Subsample or zero-pad per-video frame features to ``seq_len`` frames.

    Long videos are strided (every ``ori_len // seq_len``-th frame) instead of
    truncated so the kept frames span the whole video; short videos are
    zero-padded and masked out.

    Returns:
        Tuple ``(frames, attention_masks)`` — a float tensor of shape
        ``(len(lst), seq_len, dim)`` and an int mask tensor of shape
        ``(len(lst), seq_len)`` with 1 for real frames and 0 for padding.
    """
    attention_masks = []
    result = []
    for video in lst:
        video = torch.FloatTensor(video)
        ori_len = video.shape[0]
        if ori_len >= seq_len:
            gap = ori_len // seq_len
            video = video[::gap][:seq_len]
            mask = np.ones((seq_len))
        else:
            video = torch.cat(
                (video, torch.zeros([seq_len - ori_len, video.shape[1]], dtype=torch.float)),
                dim=0,
            )
            mask = np.append(np.ones(ori_len), np.zeros(seq_len - ori_len))
        result.append(video)
        mask = torch.IntTensor(mask)
        attention_masks.append(mask)
    return torch.stack(result), torch.stack(attention_masks)


def _init_fn(worker_id):
    """DataLoader worker init: fix the numpy seed for reproducible loading."""
    np.random.seed(2022)


def _resort_and_pad_comments(comments_inputid, comments_mask, comments_like, num_comments):
    """Sort each video's comments by like count (descending) and pad to ``num_comments``.

    Shared helper for FANVM_collate_fn and comments_collate_fn (the logic was
    previously duplicated verbatim in both).

    Returns:
        Tuple ``(inputid, mask, like)`` of batched tensors, each with
        ``num_comments`` comments per video.
    """
    inputid_resorted = []
    mask_resorted = []
    like_resorted = []
    for idx in range(len(comments_like)):
        like_one = comments_like[idx]
        inputid_one = comments_inputid[idx]
        mask_one = comments_mask[idx]
        if like_one.shape != torch.Size([0]):
            # Sort the three parallel sequences together, keyed on like count.
            inputid_one, mask_one, like_one = (
                list(t)
                for t in zip(
                    *sorted(zip(inputid_one, mask_one, like_one), key=lambda s: s[2], reverse=True)
                )
            )
        inputid_resorted.append(inputid_one)
        mask_resorted.append(mask_one)
        like_resorted.append(like_one)

    inputid = pad_sequence(num_comments, inputid_resorted, 250)
    mask = pad_sequence(num_comments, mask_resorted, 250)

    like_padded = []
    for like_one in like_resorted:
        if len(like_one) >= num_comments:
            like_padded.append(torch.tensor(like_one[:num_comments]))
        elif isinstance(like_one, list):
            like_padded.append(torch.tensor(like_one + [0] * (num_comments - len(like_one))))
        else:
            # Tensor case (videos whose comments were never re-sorted).
            like_padded.append(
                torch.tensor(like_one.tolist() + [0] * (num_comments - len(like_one)))
            )
    return inputid, mask, torch.stack(like_padded)


def SVFEND_collate_fn(batch):
    """Collate SVFEND samples: pad frame/audio/C3D sequences and stack the rest."""
    num_frames = 83
    num_audioframes = 50

    title_inputid = [item['title_inputid'] for item in batch]
    title_mask = [item['title_mask'] for item in batch]

    frames = [item['frames'] for item in batch]
    frames, frames_masks = pad_frame_sequence(num_frames, frames)

    audioframes = [item['audioframes'] for item in batch]
    audioframes, audioframes_masks = pad_frame_sequence(num_audioframes, audioframes)

    c3d = [item['c3d'] for item in batch]
    c3d, c3d_masks = pad_frame_sequence(num_frames, c3d)

    label = [item['label'] for item in batch]
    return {
        'label': torch.stack(label),
        'title_inputid': torch.stack(title_inputid),
        'title_mask': torch.stack(title_mask),
        'audioframes': audioframes,
        'audioframes_masks': audioframes_masks,
        'frames': frames,
        'frames_masks': frames_masks,
        'c3d': c3d,
        'c3d_masks': c3d_masks,
    }


def FANVM_collate_fn(batch):
    """Collate FANVM samples: sorted/padded comments, padded frames, event labels."""
    num_comments = 23
    num_frames = 83

    title_inputid = [item['title_inputid'] for item in batch]
    title_mask = [item['title_mask'] for item in batch]

    comments_inputid, comments_mask, comments_like = _resort_and_pad_comments(
        [item['comments_inputid'] for item in batch],
        [item['comments_mask'] for item in batch],
        [item['comments_like'] for item in batch],
        num_comments,
    )

    frames = [item['frames'] for item in batch]
    frames, frames_masks = pad_frame_sequence(num_frames, frames)
    frame_thmub = [item['frame_thmub'] for item in batch]

    label = [item['label'] for item in batch]
    label_event = [item['label_event'] for item in batch]
    s = [item['s'] for item in batch]
    return {
        'label': torch.stack(label),
        'title_inputid': torch.stack(title_inputid),
        'title_mask': torch.stack(title_mask),
        'comments_inputid': comments_inputid,
        'comments_mask': comments_mask,
        'comments_like': comments_like,
        'frames': frames,
        'frames_masks': frames_masks,
        'frame_thmub': torch.stack(frame_thmub),
        's': torch.stack(s),
        'label_event': torch.stack(label_event),
    }


def bbox_collate_fn(batch):
    """Collate bbox-baseline samples: pad per-frame bbox VGG features."""
    num_frames = 83
    bbox_vgg = [item['bbox_vgg'] for item in batch]
    bbox_vgg = pad_sequence_bbox(num_frames, bbox_vgg)
    label = [item['label'] for item in batch]
    return {
        'label': torch.stack(label),
        'bbox_vgg': bbox_vgg,
    }


def c3d_collate_fn(batch):
    """Collate C3D-baseline samples: pad C3D frame features and build masks."""
    num_frames = 83
    c3d = [item['c3d'] for item in batch]
    c3d, c3d_masks = pad_frame_sequence(num_frames, c3d)
    label = [item['label'] for item in batch]
    return {
        'label': torch.stack(label),
        'c3d': c3d,
        'c3d_masks': c3d_masks,
    }


def vgg_collate_fn(batch):
    """Collate VGG-baseline samples: pad frame features and build masks."""
    num_frames = 83
    frames = [item['frames'] for item in batch]
    frames, frames_masks = pad_frame_sequence(num_frames, frames)
    label = [item['label'] for item in batch]
    return {
        'label': torch.stack(label),
        'frames': frames,
        'frames_masks': frames_masks,
    }


def comments_collate_fn(batch):
    """Collate comments-baseline samples: sorted/padded comment ids, masks, likes."""
    num_comments = 23
    comments_inputid, comments_mask, comments_like = _resort_and_pad_comments(
        [item['comments_inputid'] for item in batch],
        [item['comments_mask'] for item in batch],
        [item['comments_like'] for item in batch],
        num_comments,
    )
    label = [item['label'] for item in batch]
    return {
        'label': torch.stack(label),
        'comments_inputid': comments_inputid,
        'comments_mask': comments_mask,
        'comments_like': comments_like,
    }


def title_w2v_collate_fn(batch):
    """Collate TextCNN samples: pad word2vec title embeddings to 128 tokens."""
    length_title = 128
    title_w2v = [item['title_w2v'] for item in batch]
    title_w2v = pad_sequence(length_title, title_w2v, 100)
    label = [item['label'] for item in batch]
    return {
        'label': torch.stack(label),
        'title_w2v': title_w2v,
    }


def tictec_collate_fn(batch):
    """Combine a batch of TikTec samples into one batch.

    Args:
        batch: List of sample dicts, each containing 'label',
            'caption_feature', 'visual_feature', 'asr_feature', 'mask_K'
            and 'mask_N'.

    Returns:
        dict: Batched tensors for the labels, the three feature streams and
        the two masks.
    """
    labels = torch.stack([item['label'] for item in batch])
    caption_features = torch.stack([item['caption_feature'] for item in batch])
    visual_features = torch.stack([item['visual_feature'] for item in batch])
    asr_features = torch.stack([item['asr_feature'] for item in batch])
    mask_Ks = torch.stack([item['mask_K'] for item in batch])
    mask_Ns = torch.stack([item['mask_N'] for item in batch])
    return {
        'label': labels,
        'caption_feature': caption_features,
        'visual_feature': visual_features,
        'asr_feature': asr_features,
        'mask_K': mask_Ks,
        'mask_N': mask_Ns,
    }


class Run():
    """Experiment driver: builds the model and dataloaders from a config dict
    and runs training/evaluation in 'nocv', 'temporal' or 'cv' (5-fold) mode.
    """

    def __init__(self, config):
        self.model_name = config['model_name']
        self.mode_eval = config['mode_eval']
        self.fold = config['fold']
        self.data_type = 'SVFEND'  # may be overridden by get_model()
        self.epoches = config['epoches']
        self.batch_size = config['batch_size']
        self.num_workers = config['num_workers']
        self.epoch_stop = config['epoch_stop']
        self.seed = config['seed']
        self.device = config['device']
        self.lr = config['lr']
        self.lambd = config['lambd']
        self.save_param_dir = config['path_param']
        self.path_tensorboard = config['path_tensorboard']
        self.dropout = config['dropout']
        self.weight_decay = config['weight_decay']
        self.event_num = 616
        self.mode = 'normal'

    def get_dataloader(self, data_type, data_fold):
        """Build train/test dataloaders for one fold of the given data type.

        Raises:
            ValueError: If ``data_type`` has no registered dataset.
        """
        if data_type == 'SVFEND':
            # BUGFIX: previously hard-coded to folds 1/2 (f'vid_fold_{1}.txt'
            # / f'vid_fold_{2}.txt'), which made 5-fold CV train on the same
            # split every iteration. Now follows the same no_{fold}/{fold}
            # convention as every other branch.
            dataset_train = SVFENDDataset(f'vid_fold_no_{data_fold}.txt')
            dataset_test = SVFENDDataset(f'vid_fold_{data_fold}.txt')
            collate_fn = SVFEND_collate_fn
        elif data_type == 'FANVM':
            dataset_train = FANVMDataset_train(f'vid_fold_no_{data_fold}.txt')
            dataset_test = FANVMDataset_test(
                path_vid_train=f'vid_fold_no_{data_fold}.txt',
                path_vid_test=f'vid_fold_{data_fold}.txt',
            )
            collate_fn = FANVM_collate_fn
        elif data_type == 'c3d':
            dataset_train = C3DDataset(f'vid_fold_no_{data_fold}.txt')
            dataset_test = C3DDataset(f'vid_fold_{data_fold}.txt')
            collate_fn = c3d_collate_fn
        elif data_type == 'vgg':
            dataset_train = VGGDataset(f'vid_fold_no_{data_fold}.txt')
            dataset_test = VGGDataset(f'vid_fold_{data_fold}.txt')
            collate_fn = vgg_collate_fn
        elif data_type == 'bbox':
            # NOTE(review): bbox ignores data_fold and always uses fold 1 —
            # presumably intentional (bbox features only extracted for fold 1);
            # confirm before relying on bbox CV results.
            dataset_train = BboxDataset('vid_fold_no1.txt')
            dataset_test = BboxDataset('vid_fold_1.txt')
            collate_fn = bbox_collate_fn
        elif data_type == 'comments':
            dataset_train = CommentsDataset(f'vid_fold_no_{data_fold}.txt')
            dataset_test = CommentsDataset(f'vid_fold_{data_fold}.txt')
            collate_fn = comments_collate_fn
        elif data_type == 'TikTec':
            dataset_train = TikTecDataset(f'vid_fold_no_{data_fold}.txt')
            dataset_test = TikTecDataset(f'vid_fold_{data_fold}.txt')
            collate_fn = tictec_collate_fn
        # elif data_type=='w2v':
        #     wv_from_text = KeyedVectors.load_word2vec_format("./stores/tencent-ailab-embedding-zh-d100-v0.2.0-s/tencent-ailab-embedding-zh-d100-v0.2.0-s.txt", binary=False)
        #     dataset_train = Title_W2V_Dataset(f'vid_fold_no{data_fold}.txt', wv_from_text)
        #     dataset_test = Title_W2V_Dataset(f'vid_fold_{data_fold}.txt', wv_from_text)
        #     collate_fn = title_w2v_collate_fn
        else:
            # Previously fell through to an opaque NameError on dataset_train.
            raise ValueError(f"Unknown data_type: {data_type!r}")

        train_dataloader = DataLoader(
            dataset_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=True,
            worker_init_fn=_init_fn,
            collate_fn=collate_fn,
        )
        test_dataloader = DataLoader(
            dataset_test,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False,
            worker_init_fn=_init_fn,
            collate_fn=collate_fn,
        )
        dataloaders = dict(zip(['train', 'test'], [train_dataloader, test_dataloader]))
        return dataloaders

    def get_dataloader_temporal(self, data_type):
        """Build train/val/test dataloaders for the temporal split.

        Raises:
            ValueError: If ``data_type`` has no temporal dataset registered.
        """
        if data_type == 'SVFEND':
            dataset_train = SVFENDDataset('vid_time3_train.txt')
            dataset_val = SVFENDDataset('vid_time3_val.txt')
            dataset_test = SVFENDDataset('vid_time3_test.txt')
            collate_fn = SVFEND_collate_fn
        elif data_type == 'FANVM':
            dataset_train = FANVMDataset_train('vid_time3_train.txt')
            dataset_val = FANVMDataset_test(
                path_vid_train='vid_time3_train.txt',
                path_vid_test='vid_time3_valid.txt',
            )
            dataset_test = FANVMDataset_test(
                path_vid_train='vid_time3_train.txt',
                path_vid_test='vid_time3_test.txt',
            )
            collate_fn = FANVM_collate_fn
        else:  # can be added
            # Previously only printed "Not available" and then crashed with a
            # NameError on dataset_train; fail fast instead.
            raise ValueError(f"Not available: no temporal dataset for {data_type!r}")

        train_dataloader = DataLoader(
            dataset_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=True,
            worker_init_fn=_init_fn,
            collate_fn=collate_fn,
        )
        val_dataloader = DataLoader(
            dataset_val,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False,
            worker_init_fn=_init_fn,
            collate_fn=collate_fn,
        )
        test_dataloader = DataLoader(
            dataset_test,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False,
            worker_init_fn=_init_fn,
            collate_fn=collate_fn,
        )
        dataloaders = dict(
            zip(['train', 'val', 'test'], [train_dataloader, val_dataloader, test_dataloader])
        )
        return dataloaders

    def get_model(self):
        """Instantiate the configured model; also sets self.data_type/self.mode
        as side effects for the model's dataset and training mode.
        """
        if self.model_name == 'SVFEND':
            self.model = SVFENDModel(bert_model='bert-base-chinese', fea_dim=128, dropout=self.dropout)
        elif self.model_name == 'FANVM':
            self.model = FANVMModel(bert_model='bert-base-chinese', fea_dim=128)
            self.data_type = "FANVM"
            self.mode = 'eann'
        elif self.model_name == 'C3D':
            self.model = bC3D(fea_dim=128)
            self.data_type = "c3d"
        elif self.model_name == 'VGG':
            self.model = bVGG(fea_dim=128)
            self.data_type = "vgg"
        elif self.model_name == 'Bbox':
            self.model = bBbox(fea_dim=128)
            self.data_type = "bbox"
        elif self.model_name == 'Vggish':
            # Reuses the default 'SVFEND' data_type (audio comes from that dataset).
            self.model = bVggish(fea_dim=128)
        elif self.model_name == 'Bert':
            self.model = bBert(bert_model='bert-base-chinese', fea_dim=128, dropout=self.dropout)
        elif self.model_name == 'TextCNN':
            self.model = bTextCNN(fea_dim=128, vocab_size=100)
            self.data_type = "w2v"
        elif self.model_name == 'Comments':
            self.model = bComments(bert_model='bert-base-chinese', fea_dim=128)
            self.data_type = "comments"
        elif self.model_name == 'TikTec':
            self.model = TikTecModel(VCIF_dropout=self.dropout, MLP_dropout=self.dropout)
            self.data_type = 'TikTec'
        return self.model

    def main(self):
        """Run the configured evaluation mode and print the resulting metrics.

        Returns the metrics dict in 'temporal' mode (matching the original
        behavior); 'nocv' and 'cv' only print.
        """
        if self.mode_eval == "nocv":
            self.model = self.get_model()
            dataloaders = self.get_dataloader(data_type=self.data_type, data_fold=self.fold)
            trainer = Trainer(
                model=self.model,
                device=self.device,
                lr=self.lr,
                dataloaders=dataloaders,
                epoches=self.epoches,
                dropout=self.dropout,
                weight_decay=self.weight_decay,
                mode=self.mode,
                model_name=self.model_name,
                event_num=self.event_num,
                epoch_stop=self.epoch_stop,
                save_param_path=self.save_param_dir + self.data_type + "/" + self.model_name + "/",
                writer=SummaryWriter(self.path_tensorboard),
            )
            result = trainer.train()
            for metric in ['acc', 'f1', 'precision', 'recall', 'auc']:
                print('%s : %.4f' % (metric, result[metric]))

        elif self.mode_eval == "temporal":
            self.model = self.get_model()
            dataloaders = self.get_dataloader_temporal(data_type=self.data_type)
            trainer = Trainer3(
                model=self.model,
                device=self.device,
                lr=self.lr,
                dataloaders=dataloaders,
                epoches=self.epoches,
                dropout=self.dropout,
                weight_decay=self.weight_decay,
                mode=self.mode,
                model_name=self.model_name,
                event_num=self.event_num,
                epoch_stop=self.epoch_stop,
                save_param_path=self.save_param_dir + self.data_type + "/" + self.model_name + "/",
                writer=SummaryWriter(self.path_tensorboard),
            )
            result = trainer.train()
            for metric in ['acc', 'f1', 'precision', 'recall', 'auc']:
                print('%s : %.4f' % (metric, result[metric]))
            return result

        elif self.mode_eval == "cv":
            # if self.model_name == 'TextCNN':
            #     wv_from_text = KeyedVectors.load_word2vec_format("./stores/tencent-ailab-embedding-zh-d100-v0.2.0-s/tencent-ailab-embedding-zh-d100-v0.2.0-s.txt", binary=False)
            history = collections.defaultdict(list)
            for fold in range(1, 6):
                print('-' * 50)
                print('fold %d:' % fold)
                print('-' * 50)
                self.model = self.get_model()
                dataloaders = self.get_dataloader(data_type=self.data_type, data_fold=fold)
                trainer = Trainer(
                    model=self.model,
                    device=self.device,
                    lr=self.lr,
                    dataloaders=dataloaders,
                    epoches=self.epoches,
                    dropout=self.dropout,
                    weight_decay=self.weight_decay,
                    mode=self.mode,
                    model_name=self.model_name,
                    event_num=self.event_num,
                    epoch_stop=self.epoch_stop,
                    save_param_path=self.save_param_dir + self.data_type + "/" + self.model_name + "/",
                    writer=SummaryWriter(self.path_tensorboard + "fold_" + str(fold) + "/"),
                )
                result = trainer.train()
                history['auc'].append(result['auc'])
                history['f1'].append(result['f1'])
                history['recall'].append(result['recall'])
                history['precision'].append(result['precision'])
                history['acc'].append(result['acc'])

            print('results on 5-fold cross-validation: ')
            for metric in ['acc', 'f1', 'precision', 'recall', 'auc']:
                print('%s : %.4f +/- %.4f' % (metric, np.mean(history[metric]), np.std(history[metric])))

        else:
            print("Not Available")