"""Single-video fake-news inference with a pretrained SVFEND model.

Given one video file, this script extracts audio (VGGish), frame (VGG19),
clip (C3D) and ASR-text features, runs them through the SVFEND classifier,
and prints / returns the predicted probability that the video is fake.
"""

import os

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm

from FakeVD.code_test.utils.metrics import *
from FakeVD.code_test.models.SVFEND import SVFENDModel
from FakeVD.code_test.utils.dataloader import SVFENDDataset
from FakeVD.code_test.run import _init_fn, SVFEND_collate_fn
from FakeVD.code_test.VGGish_Feature_Extractor.my_vggish_fun import vggish_audio, load_model_vggish
from FakeVD.code_test.VGG19_Feature_Extractor.vgg19_feature import process_video as vgg19_frame
from FakeVD.code_test.VGG19_Feature_Extractor.vgg19_feature import load_model_vgg19
from FakeVD.code_test.C3D_Feature_Extractor.feature_extractor_vid import feature_extractor as c3d_video
from FakeVD.code_test.C3D_Feature_Extractor.feature_extractor_vid import load_model_c3d
from FakeVD.code_test.Text_Feature_Extractor.main import video_work as asr_text
from FakeVD.code_test.Text_Feature_Extractor.wav2text import wav2text


def load_model(checkpoint_path):
    """Load SVFEND classifier weights from *checkpoint_path*.

    Returns the model in eval mode, with weights mapped onto the available
    device (CUDA if present, else CPU).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SVFENDModel(bert_model='bert-base-chinese', fea_dim=128, dropout=0.1)
    # strict=False (was passed positionally): tolerate missing/unexpected keys
    # in the checkpoint — presumably frozen submodules are not saved; confirm
    # against the training code.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device),
                          strict=False)
    model.eval()
    return model


def get_model(checkpoint_path='./FakeVD/code_test/checkpoints/SVFEND/SVFEND/_test_epoch4_0.7943'):
    """Build every model needed for inference.

    Parameters:
        checkpoint_path: path of the saved SVFEND detection checkpoint.

    Returns a dict with keys 'model_main' (SVFEND classifier), 'model_vggish',
    'model_vgg19', 'model_c3d' (feature extractors) and 'model_text' (ASR).
    """
    return {
        'model_main': load_model(checkpoint_path),
        'model_vggish': load_model_vggish(),
        'model_vgg19': load_model_vgg19(),
        'model_c3d': load_model_c3d(),
        'model_text': wav2text(),
    }


def test(model, dataloader):
    """Run inference over *dataloader* and collect labels/predictions/probs.

    Each batch dict is expected to contain a 'label' tensor plus the keyword
    tensors `model` accepts. Returns a tuple
    ``(labels, predicted_classes, softmax_probabilities)`` as Python lists.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    pred = []
    label = []
    prob = []
    for batch in tqdm(dataloader):
        with torch.no_grad():
            # Copy the batch dict so moving tensors to the device does not
            # mutate the caller's (dataloader's) dict in place.
            batch_data = {k: v.to(device) for k, v in batch.items()}
            batch_label = batch_data['label']
            batch_outputs, _fea = model(**batch_data)
            _, batch_preds = torch.max(batch_outputs, 1)
            # Softmax over the class dimension gives per-class probabilities.
            softmax_probs = F.softmax(batch_outputs, dim=1)
            label.extend(batch_label.detach().cpu().numpy().tolist())
            pred.extend(batch_preds.detach().cpu().numpy().tolist())
            prob.extend(softmax_probs.detach().cpu().numpy().tolist())
    return (label, pred, prob)


def main(models, video_file_path, preprocessed_flag=False,
         feature_path='./FakeVD/code_test/preprocessed_feature'):
    """Predict whether one video is fake.

    Parameters:
        models: dict produced by :func:`get_model`.
        video_file_path: path to the video file to analyse.
        preprocessed_flag: if True, skip feature extraction and reuse
            features already stored under *feature_path*.
        feature_path: root directory where extracted features are stored.

    Returns the softmax probability of the 'fake' class (index 1).
    """
    model_main = models['model_main']
    model_vggish = models['model_vggish']
    model_vgg19 = models['model_vgg19']
    model_c3d = models['model_c3d']
    model_text = models['model_text']

    # Video ID = file name without extension; the dataset is a list of IDs.
    video_folder_path = os.path.dirname(video_file_path)
    video_file_name = os.path.basename(video_file_path)
    vid = os.path.splitext(video_file_name)[0]
    vids = [vid]

    # Per-modality feature locations (VGGish is a single .pkl; the others
    # are per-modality sub-directories / a per-video JSON).
    VGGish_audio_feature_path = os.path.join(feature_path, vid+'.pkl')
    C3D_video_feature_path = os.path.join(feature_path, 'C3D/')
    VGG19_frame_feature_path = os.path.join(feature_path, 'VGG19/')
    asr_text_feature_path = os.path.join(feature_path, 'ASR/'+vid+'.json')

    # Extract features unless the caller says they already exist on disk.
    if not preprocessed_flag:
        vggish_audio(model_vggish, video_file_path, VGGish_audio_feature_path)
        vgg19_frame(model_vgg19, video_file_name, video_folder_path,
                    VGG19_frame_feature_path)
        c3d_video(model_c3d, C3D_video_feature_path, video_folder_path,
                  video_file_name)
        asr_text(model_text, model_vggish, video_file_path,
                 asr_text_feature_path)

    data_paths = {
        'VGGish_audio': VGGish_audio_feature_path,
        'C3D_video': C3D_video_feature_path,
        'VGG19_frame': VGG19_frame_feature_path,
        'ASR_text': asr_text_feature_path,
    }

    dataset = SVFENDDataset(vids, data_paths)
    dataloader = DataLoader(dataset, batch_size=1, num_workers=0,
                            pin_memory=True, shuffle=False,
                            worker_init_fn=_init_fn,
                            collate_fn=SVFEND_collate_fn)

    # predictions = (labels, predicted_classes, softmax_probs); one video only.
    predictions = test(model_main, dataloader)
    # Class 0 = real ('真'), class 1 = fake ('假').
    annotation = '真' if predictions[1][0] == 0 else '假'
    prob_softmax = predictions[2]
    annotation_prob = prob_softmax[0][0]   # probability of 'real'
    annotation_prob1 = prob_softmax[0][1]  # probability of 'fake'

    print(annotation, annotation_prob, annotation_prob1)
    return annotation_prob1


if __name__ == "__main__":
    preprocessed_flag = False
    video_file_path = "./FakeVD/dataset/videos_1/douyin_6700861687563570439.mp4"
    models = get_model()
    main(models, video_file_path, preprocessed_flag)