# FakeVideoDetect / run.py
# ybbwcwaps · "AI Video" · commit 3cc4a06 · 2.7 kB
import os
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torchvision.io import read_video
import torch.utils.data
import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score
import pickle
from tqdm import tqdm
from datetime import datetime
from copy import deepcopy
from dataset_paths import DATASET_PATHS
import random
from datasetss import create_test_dataloader
from utilss.logger import create_logger
import options
from networks.validator import Validator
def get_model():
    """Build the fake-video detector and load its checkpoint.

    Steps, in order:
      1. Parse ``options.TestOptions`` with ``print_options=False``.
      2. Ensure the directory ``<output>/<name>`` from those options exists.
      3. Construct a ``Validator`` from the options.
      4. Load ``ckpt`` (as given in the options) into it via
         ``Validator.load_state_dict``.

    Returns:
        The ``Validator`` instance after the checkpoint has been loaded.
    """
    opts = options.TestOptions().parse(print_options=False)
    os.makedirs(os.path.join(opts.output, opts.name), exist_ok=True)

    print(f"working...")
    detector = Validator(opts)
    detector.load_state_dict(opts.ckpt)
    print("ckpt loaded!")
    return detector
def detect_video(video_path, model, num_frames=16):
    """Return the detector's sigmoid score for a single video.

    The video is decoded with ``read_video`` and truncated to its first
    ``num_frames`` frames. Each frame is converted to a PIL image and passed
    through ``model.clip_model.preprocess``. The preprocessed frames are
    stacked along a new leading axis and scored in one forward pass.

    Args:
        video_path: Video file to score. It is converted with ``str()``
            before decoding, so ``pathlib.Path`` objects work too.
        model: Detector as built by ``get_model()``. It must provide
            ``clip_model.preprocess``, ``set_input``, ``model`` and ``input``.
        num_frames: Maximum number of leading frames to score. Defaults to
            16, the value this function previously hard-coded. Videos with
            fewer frames are scored on all of them.

    Returns:
        float: sigmoid of the first output element of ``model.model``.
        Higher means "more likely fake"; the script's entry point reports
        values above 0.5 as fake.

    Raises:
        ValueError: if ``num_frames`` is less than 1.
        RuntimeError: if no video frames could be decoded from ``video_path``.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")

    # read_video returns (video, audio, info); only the video frames are used.
    frames, _, _ = read_video(str(video_path), pts_unit='sec')
    if frames.shape[0] == 0:
        # Without this check torch.stack below fails with an opaque
        # empty-list error that does not mention the offending file.
        raise RuntimeError(f"no video frames could be decoded from {video_path}")

    # (T,H,W,C) -> (T,C,H,W): to_pil_image interprets tensors as channels-first.
    frames = frames[:num_frames].permute(0, 3, 1, 2)
    preprocess = model.clip_model.preprocess
    video_frames = torch.stack([preprocess(TF.to_pil_image(frame)) for frame in frames])

    with torch.no_grad():
        # set_input takes a [frames, label] pair. The label is a dummy 0
        # because we are only scoring. NOTE(review): presumably ignored at
        # inference; confirm in Validator.set_input.
        model.set_input([video_frames, torch.tensor([0])])
        score = model.model(model.input).view(-1)[0].sigmoid()
    return score.item()
if __name__ == '__main__':
    # Sample clip to score. This relative path is resolved against the
    # current working directory, not against this file's location.
    video_path = '../../dataset/MSRVTT/videos/all/video1.mp4'

    model = get_model()
    pred = detect_video(video_path, model)

    # pred is the sigmoid score from detect_video. Above 0.5 is reported as
    # fake; otherwise the clip is reported as real with the complementary
    # confidence.
    if pred > 0.5:
        print(f"Fake: {pred*100:.2f}%")
    else:
        print(f"Real: {(1-pred)*100:.2f}%")