import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
from FakeVD.code_test.utils.metrics import *
from FakeVD.code_test.models.SVFEND import SVFENDModel
from FakeVD.code_test.utils.dataloader import SVFENDDataset
from FakeVD.code_test.run import _init_fn, SVFEND_collate_fn
# from VGGish_Feature_Extractor.my_vggish_folder_fun import vggish_audio
from FakeVD.code_test.VGGish_Feature_Extractor.my_vggish_fun import vggish_audio, load_model_vggish
from FakeVD.code_test.VGG19_Feature_Extractor.vgg19_feature import process_video as vgg19_frame
from FakeVD.code_test.VGG19_Feature_Extractor.vgg19_feature import load_model_vgg19
from FakeVD.code_test.C3D_Feature_Extractor.feature_extractor_vid import feature_extractor as c3d_video
from FakeVD.code_test.C3D_Feature_Extractor.feature_extractor_vid import load_model_c3d
from FakeVD.code_test.Text_Feature_Extractor.main import video_work as asr_text
from FakeVD.code_test.Text_Feature_Extractor.wav2text import wav2text


def load_model(checkpoint_path):
    # Load the SVFEND detection model from a checkpoint and put it in eval mode.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SVFENDModel(bert_model='bert-base-chinese', fea_dim=128, dropout=0.1)
    # model.load_state_dict(torch.load(checkpoint_path))
    model.load_state_dict(torch.load(checkpoint_path, map_location=device), strict=False)
    model.eval()
    return model


def get_model(checkpoint_path='./FakeVD/code_test/checkpoints/SVFEND/SVFEND/_test_epoch4_0.7943'):
    # Load the detection model stored at checkpoint_path, together with the
    # feature-extractor models (VGGish audio, VGG19 frame, C3D video, speech-to-text).
    model_main = load_model(checkpoint_path)
    model_vggish = load_model_vggish()
    model_vgg19 = load_model_vgg19()
    model_c3d = load_model_c3d()
    model_text = wav2text()
    models = {
        'model_main': model_main,
        'model_vggish': model_vggish,
        'model_vgg19': model_vgg19,
        'model_c3d': model_c3d,
        'model_text': model_text
    }
    return models
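

# Illustrative sketch (not part of the original pipeline): get_model() is intended to
# be called once and the returned dict reused across many videos, so the checkpoints
# are not reloaded for every prediction. predict_many() and its argument names are
# hypothetical; main() below returns the softmax probability of the "fake" class.
def predict_many(video_file_paths, preprocessed_flag=False):
    models = get_model()
    return {path: main(models, path, preprocessed_flag) for path in video_file_paths}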


# label = 0 if item['annotation'] == '真' else 1   # '真' (real) -> 0, '假' (fake) -> 1

def test(model, dataloader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # model.cuda()
    model.eval()
    pred = []
    label = []
    prob = []
    for batch in tqdm(dataloader):
        with torch.no_grad():
            batch_data = batch
            for k, v in batch_data.items():
                batch_data[k] = v.to(device)
            batch_label = batch_data['label']
            batch_outputs, fea = model(**batch_data)
            _, batch_preds = torch.max(batch_outputs, 1)
            softmax_probs = F.softmax(batch_outputs, dim=1)  # compute softmax probabilities
            label.extend(batch_label.detach().cpu().numpy().tolist())
            pred.extend(batch_preds.detach().cpu().numpy().tolist())
            prob.extend(softmax_probs.detach().cpu().numpy().tolist())  # collect softmax probabilities
    return (label, pred, prob)
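

# Minimal sketch, assuming only what test() returns: label, pred and prob are plain
# Python lists, with prob holding one two-element softmax row per sample. The project
# ships its own metrics in FakeVD.code_test.utils.metrics (imported above with *),
# whose exact API is not shown here, so plain numpy is used instead.
# summarize_predictions() is a hypothetical helper, not part of the original code.
def summarize_predictions(label, pred, prob):
    label = np.asarray(label)
    pred = np.asarray(pred)
    prob = np.asarray(prob)  # shape: (num_samples, 2)
    accuracy = float((label == pred).mean()) if label.size else float("nan")
    mean_fake_prob = float(prob[:, 1].mean()) if prob.size else float("nan")
    return {"accuracy": accuracy, "mean_fake_prob": mean_fake_prob}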


def main(models,
         video_file_path,
         preprocessed_flag=False,
         feature_path='./FakeVD/code_test/preprocessed_feature'):
    # preprocessed_flag: whether the video has already been preprocessed
    # feature_path: directory where the extracted features are stored
    # Unpack the models
    model_main = models['model_main']
    model_vggish = models['model_vggish']
    model_vgg19 = models['model_vgg19']
    model_c3d = models['model_c3d']
    model_text = models['model_text']
    # Folder containing the video
    video_folder_path = os.path.dirname(video_file_path)
    # Video file name (with extension)
    video_file_name = os.path.basename(video_file_path)
    # Use the file name without its extension as the video ID
    vids = []
    vid = os.path.splitext(video_file_name)[0]
    vids.append(vid)
    # video_file_name = os.path.basename(video_file_path)
    # vids.append(os.path.splitext(video_file_name)[0])
    # # vids.append(video_file_name.split('_')[1].split('.')[0]
    # VGGish_audio feature path
    VGGish_audio_feature_path = os.path.join(feature_path, vid + '.pkl')
    # C3D_video feature directory
    C3D_video_feature_path = os.path.join(feature_path, 'C3D/')
    # VGG19_frame feature directory
    VGG19_frame_feature_path = os.path.join(feature_path, 'VGG19/')
    # ASR_text feature path
    asr_text_feature_path = os.path.join(feature_path, 'ASR/' + vid + '.json')
    # Feature extraction (skipped if the video has already been preprocessed)
    if not preprocessed_flag:
        vggish_audio(model_vggish, video_file_path, VGGish_audio_feature_path)
        vgg19_frame(model_vgg19, video_file_name, video_folder_path, VGG19_frame_feature_path)
        c3d_video(model_c3d, C3D_video_feature_path, video_folder_path, video_file_name)
        asr_text(model_text, model_vggish, video_file_path, asr_text_feature_path)
    # Data and feature paths
    data = vids
    data_paths = {
        'VGGish_audio': VGGish_audio_feature_path,
        'C3D_video': C3D_video_feature_path,
        'VGG19_frame': VGG19_frame_feature_path,
        'ASR_text': asr_text_feature_path
    }
    # Build the Dataset and DataLoader (single video -> batch_size=1)
    dataset = SVFENDDataset(data, data_paths)
    dataloader = DataLoader(dataset, batch_size=1,
                            num_workers=0,
                            pin_memory=True,
                            shuffle=False,
                            worker_init_fn=_init_fn,
                            collate_fn=SVFEND_collate_fn)
    # Run prediction
    predictions = test(model_main, dataloader)
    annotation = '真' if predictions[1][0] == 0 else '假'  # '真' = real, '假' = fake
    prob_softmax = predictions[2]
    # annotation_prob = max(prob_softmax[0])
    annotation_prob = prob_softmax[0][0]   # probability of 'real' (class 0)
    annotation_prob1 = prob_softmax[0][1]  # probability of 'fake' (class 1)
    # Print the prediction result
    print(annotation, annotation_prob, annotation_prob1)
    return annotation_prob1
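

# Hedged usage sketch: main() returns the softmax probability of the "fake" class,
# so a caller that needs a hard decision can threshold it. A threshold of 0.5 matches
# the argmax decision used inside main(); is_fake() and its threshold parameter are
# illustrative additions, not part of the original code.
def is_fake(models, video_file_path, threshold=0.5, preprocessed_flag=False):
    fake_prob = main(models, video_file_path, preprocessed_flag)
    return fake_prob >= threshold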


if __name__ == "__main__":
    # Whether the video has already been preprocessed
    preprocessed_flag = False
    video_file_path = "./FakeVD/dataset/videos_1/douyin_6700861687563570439.mp4"
    models = get_model()
    main(models, video_file_path, preprocessed_flag)