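"""Single-video inference script for the CLIP-based fake video detector.

Loads a trained Validator checkpoint and prints whether a given video is
classified as real or fake, along with the model's confidence.
"""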
import os
import pickle
import random
from copy import deepcopy
from datetime import datetime

import numpy as np
import torch
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torchvision.io import read_video
from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score
from tqdm import tqdm

import options
from dataset_paths import DATASET_PATHS
from datasetss import create_test_dataloader
from networks.validator import Validator
from utilss.logger import create_logger


def get_model():
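    """Parse the test options, build the Validator model, and load the checkpoint from val_opt.ckpt."""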
    val_opt = options.TestOptions().parse(print_options=False)
    output_dir = os.path.join(val_opt.output, val_opt.name)
    os.makedirs(output_dir, exist_ok=True)
    # logger = create_logger(output_dir=output_dir, name="FakeVideoDetector")
    print("working...")
    model = Validator(val_opt)
    model.load_state_dict(val_opt.ckpt)
    print("ckpt loaded!")
    return model


def detect_video(video_path, model):
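    """Run the detector on a single video.

    Reads the video, keeps the first 16 frames, applies the model's CLIP
    preprocessing to each frame, and returns the sigmoid score for the clip
    (values above 0.5 are treated as fake by the caller).
    """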
    frames, _, _ = read_video(str(video_path), pts_unit='sec')
    frames = frames[:16]  # keep only the first 16 frames
    frames = frames.permute(0, 3, 1, 2)  # (T,H,W,C) -> (T,C,H,W)
    # Apply the CLIP preprocessing transform to each frame and stack into a batch
    video_frames = torch.cat([model.clip_model.preprocess(TF.to_pil_image(frame)).unsqueeze(0) for frame in frames])
    with torch.no_grad():
        model.set_input([torch.as_tensor(video_frames), torch.tensor([0])])  # frames plus a placeholder label
        pred = model.model(model.input).view(-1).unsqueeze(1).sigmoid()
    return pred[0].item()


if __name__ == '__main__':
    video_path = '../../dataset/MSRVTT/videos/all/video1.mp4'
    model = get_model()
    pred = detect_video(video_path, model)
    # Scores above 0.5 are reported as fake; otherwise report the real confidence.
    if pred > 0.5:
        print(f"Fake: {pred*100:.2f}%")
    else:
        print(f"Real: {(1-pred)*100:.2f}%")