import os
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torchvision.io import read_video
import torch.utils.data
import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score
import pickle
from tqdm import tqdm
from datetime import datetime
from copy import deepcopy
from dataset_paths import DATASET_PATHS
import random
from datasets import create_test_dataloader
from utils.logger import create_logger
import options
from networks.validator import Validator


def detect_video(video_path, num_frames=16):
    """Score a single video with the fake-video detector.

    Parses the test options, builds a ``Validator``, loads the checkpoint
    given by ``val_opt.ckpt``, decodes the video, preprocesses its leading
    ``num_frames`` frames with ``model.clip_model.preprocess`` and runs the
    model on them.

    Args:
        video_path: Path (str or path-like) of the video file to score.
        num_frames: Maximum number of leading frames fed to the model.
            Defaults to 16, the value that used to be hard-coded. Videos
            shorter than this use all of their decoded frames.

    Returns:
        float: sigmoid of the model's first output value. The ``__main__``
        block below reports values above 0.5 as "Fake".

    Raises:
        ValueError: if ``num_frames`` is less than 1.
        RuntimeError: if no frames could be decoded from ``video_path``.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")

    # NOTE: TestOptions().parse() reads the process command line, so the
    # options used here come from sys.argv, not from this function's args.
    val_opt = options.TestOptions().parse(print_options=False)
    output_dir = os.path.join(val_opt.output, val_opt.name)
    os.makedirs(output_dir, exist_ok=True)

    print("working...")
    model = Validator(val_opt)
    model.load_state_dict(val_opt.ckpt)
    print("ckpt loaded!")

    # NOTE(review): read_video decodes the entire clip before it is sliced;
    # for long videos, limiting decoding (e.g. via end_pts) would save memory.
    frames, _, _ = read_video(str(video_path), pts_unit='sec')
    frames = frames[:num_frames]
    if len(frames) == 0:
        # Fail with a clear message instead of torch.stack/cat's
        # "expected a non-empty list of Tensors".
        raise RuntimeError(f"No frames could be decoded from {video_path!s}")
    frames = frames.permute(0, 3, 1, 2)  # (T,H,W,C) -> (T,C,H,W)

    # One preprocessed tensor per frame, stacked along a new leading dim.
    video_frames = torch.stack(
        [model.clip_model.preprocess(TF.to_pil_image(frame)) for frame in frames]
    )

    with torch.no_grad():
        # torch.tensor([0]) is a dummy label; presumably unused at inference
        # time -- confirm against Validator.set_input.
        model.set_input([video_frames, torch.tensor([0])])
        pred = model.model(model.input).view(-1).unsqueeze(1).sigmoid()
    return pred[0].item()


if __name__ == '__main__':
    video_path = '../../dataset/MSRVTT/videos/all/video1.mp4'
    pred = detect_video(video_path)
    # Scores above 0.5 are reported as fake; otherwise report the
    # complementary probability as "real".
    if pred > 0.5:
        print(f"Fake: {pred*100:.2f}%")
    else:
        print(f"Real: {(1-pred)*100:.2f}%")