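"""Standalone inference script for the fake-video detector.

Loads a trained Validator checkpoint (configured through options.TestOptions),
reads the first 16 frames of a video, applies the CLIP preprocessing transform
per frame, and prints whether the clip is classified as real or fake.
"""
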
import os

import torch
import torchvision.transforms.functional as TF
from torchvision.io import read_video

import options
from networks.validator import Validator
from utilss.logger import create_logger  # only needed if the optional logger below is enabled


def get_model():
    """Parse the test options, build the Validator and load its checkpoint."""
    val_opt = options.TestOptions().parse(print_options=False)
    output_dir = os.path.join(val_opt.output, val_opt.name)
    os.makedirs(output_dir, exist_ok=True)

    # logger = create_logger(output_dir=output_dir, name="FakeVideoDetector")
    print("working...")

    model = Validator(val_opt)
    model.load_state_dict(val_opt.ckpt)  # checkpoint path comes from the test options
    print("ckpt loaded!")
    return model


def detect_video(video_path, model):
    """Score a single video and return its fakeness probability in [0, 1]."""
    frames, _, _ = read_video(str(video_path), pts_unit='sec')
    frames = frames[:16]                 # use at most the first 16 frames
    frames = frames.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)

    # CLIP preprocessing is applied frame by frame, then the frames are stacked
    # back into a single (T, C, H, W) batch.
    video_frames = torch.cat(
        [model.clip_model.preprocess(TF.to_pil_image(frame)).unsqueeze(0) for frame in frames]
    )

    with torch.no_grad():
        # set_input expects a (frames, label) pair; the label is a placeholder here.
        model.set_input([video_frames, torch.tensor([0])])

        # Flatten the model output and map logits to probabilities with a sigmoid.
        pred = model.model(model.input).view(-1).unsqueeze(1).sigmoid()

    return pred[0].item()
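

# Hypothetical convenience helper (not part of the original script): batch-score
# every .mp4 in a directory with detect_video(). The extension filter and flat
# directory layout are assumptions; adapt them to your data.
def detect_video_dir(video_dir, model):
    """Return a {filename: fakeness score} dict for all .mp4 files in video_dir."""
    scores = {}
    for name in sorted(os.listdir(video_dir)):
        if name.lower().endswith('.mp4'):
            scores[name] = detect_video(os.path.join(video_dir, name), model)
    return scores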


if __name__ == '__main__':
    video_path = '../../dataset/MSRVTT/videos/all/video1.mp4'

    model = get_model()

    pred = detect_video(video_path, model)
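    # The sigmoid score is read as the probability that the video is fake;
    # scores above 0.5 are reported as fake, otherwise as real.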
    if pred > 0.5:
        print(f"Fake: {pred*100:.2f}%")
    else:
        print(f"Real: {(1-pred)*100:.2f}%")