File size: 4,811 Bytes
e8e478e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import os
import torch
import torchvision.transforms as transforms
import torch.utils.data
import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score
import pickle
from tqdm import tqdm
from datetime import datetime
from copy import deepcopy
from dataset_paths import DATASET_PATHS
import random
from datasets import create_test_dataloader
from utils.logger import create_logger
import options
from networks.validator import Validator
SEED = 0


def set_seed(seed=SEED):
    """Seed every random number generator used by this script.

    Args:
        seed: value passed to torch (CPU and current CUDA device), NumPy's
            global RNG and Python's ``random`` module. Defaults to the
            module-level ``SEED`` so existing ``set_seed()`` calls behave
            exactly as before.
    """
    torch.manual_seed(seed)
    # Silently ignored by torch when CUDA is unavailable.
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
# Channel-wise mean / std values keyed by preprocessing family
# ("imagenet" or "clip"). Not referenced elsewhere in this script.
MEAN = {
    "imagenet": [0.485, 0.456, 0.406],
    "clip": [0.48145466, 0.4578275, 0.40821073],
}
STD = {
    "imagenet": [0.229, 0.224, 0.225],
    "clip": [0.26862954, 0.26130258, 0.27577711],
}
def find_best_threshold(y_true, y_pred):
    """Return the score threshold that maximises accuracy on (y_true, y_pred).

    We assume first half is real 0, and the second half is fake 1. That
    ordering is only relied on by the perfectly-separable shortcut below.

    Args:
        y_true: 1-D array of ground-truth labels (0 = real, 1 = fake).
        y_pred: 1-D array of scores, same length as ``y_true``.

    Returns:
        Perfectly separable case: the midpoint between the highest score in
        the first half and the lowest score in the second half.
        Otherwise: the candidate ``t`` drawn from ``y_pred`` whose rule
        ``score >= t -> fake`` gives the highest accuracy. Ties go to the
        candidate appearing last in ``y_pred``.

    NOTE(review): ``calculate_acc`` classifies with ``score > thres`` while
    this search uses ``>=``. A sample scoring exactly the returned threshold
    is therefore counted differently there.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    N = y_true.shape[0]
    half = N // 2
    real_max = y_pred[:half].max()
    fake_min = y_pred[half:N].min()
    if real_max <= fake_min:  # perfectly separable case
        return (real_max + fake_min) / 2

    best_acc = 0
    best_thres = 0
    for thres in y_pred:
        # Binarise with a single comparison instead of copying and then
        # overwriting in place. The in-place version re-tested values that
        # were already set to 1 and turned them all into 0 whenever
        # thres > 1.
        acc = ((y_pred >= thres) == y_true).sum() / N
        if acc >= best_acc:
            best_thres = thres
            best_acc = acc
    return best_thres
def calculate_acc(y_true, y_pred, thres):
    """Compute accuracy on the real subset, the fake subset and overall.

    A sample is predicted fake (1) when its score is strictly greater than
    ``thres``.

    Returns:
        ``(r_acc, f_acc, acc)``: accuracy over samples with ``y_true == 0``,
        accuracy over samples with ``y_true == 1``, and overall accuracy.
    """
    predicted_fake = y_pred > thres
    is_real = y_true == 0
    is_fake = y_true == 1
    real_accuracy = accuracy_score(y_true[is_real], predicted_fake[is_real])
    fake_accuracy = accuracy_score(y_true[is_fake], predicted_fake[is_fake])
    overall_accuracy = accuracy_score(y_true, predicted_fake)
    return real_accuracy, fake_accuracy, overall_accuracy
def validate(model, loader, logger, find_thres=False):
    """Run ``model`` over every batch of ``loader`` and compute metrics.

    Args:
        model: wrapper exposing ``set_input(batch)``, an ``input`` attribute
            and the network as ``model.model``. Its output is passed through
            a sigmoid here, so it is presumably a raw logit per sample
            (confirm against the Validator class).
        loader: iterable of batches; element ``[1]`` of each batch holds
            the labels.
        logger: receives the dataset-length message.
        find_thres: when True, also search for the accuracy-maximising
            threshold and report accuracies at it.

    Returns:
        ``(ap, r_acc0, f_acc0, acc0)`` when ``find_thres`` is False, else
        ``(ap, r_acc0, f_acc0, acc0, r_acc1, f_acc1, acc1, best_thres)``.
        The ``*0`` accuracies use a fixed 0.5 threshold; the ``*1`` ones use
        ``best_thres``. The threshold search assumes real samples come
        first in loader order.
    """
    # NOTE(review): no eval-mode switch happens here; confirm the Validator
    # puts its network in eval mode before validation.
    with torch.no_grad():
        # NOTE(review): for a torch DataLoader, len(loader) counts batches,
        # not samples, despite the message text.
        logger.info ("Length of dataset: %d" %(len(loader)))
        scores, labels = [], []
        progress = tqdm(loader)
        for batch in progress:
            progress.set_description(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
            model.set_input(batch)
            logits = model.model(model.input)
            scores.extend(logits.view(-1).sigmoid().tolist())
            labels.extend(batch[1].flatten().tolist())

        y_true = np.array(labels)
        y_pred = np.array(scores)
        # To plot PR/ROC curves, dump the raw predictions at this point, e.g.
        # torch.save(torch.stack([torch.tensor(y_true), torch.tensor(y_pred)]),
        #            'baseline_predication_for_pr_roc_curve.pth')

        ap = average_precision_score(y_true, y_pred)
        r_acc0, f_acc0, acc0 = calculate_acc(y_true, y_pred, 0.5)
        if not find_thres:
            return ap, r_acc0, f_acc0, acc0

        best_thres = find_best_threshold(y_true, y_pred)
        r_acc1, f_acc1, acc1 = calculate_acc(y_true, y_pred, best_thres)
        return ap, r_acc0, f_acc0, acc0, r_acc1, f_acc1, acc1, best_thres
# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = #
def recursively_read(rootdir, must_contain, exts=("png", "jpg", "JPEG", "jpeg", "bmp")):
    """Collect image paths under ``rootdir`` whose path contains ``must_contain``.

    Args:
        rootdir: directory walked recursively with ``os.walk``.
        must_contain: substring that must appear in the joined path
            (``''`` matches everything).
        exts: accepted extensions without the leading dot; matching is
            case-sensitive. A tuple default avoids a shared mutable default.
            Any container supporting ``in`` (e.g. a list) still works.

    Returns:
        List of matching file paths, in ``os.walk`` order.
    """
    out = []
    for root, _dirs, files in os.walk(rootdir):
        for name in files:
            # Use the text after the LAST dot. Splitting on the first dot
            # raises IndexError for extension-less files and inspects the
            # wrong component for names like "frame.v2.png" or "img.png.bak".
            _stem, dot, ext = name.rpartition('.')
            if not dot or ext not in exts:
                continue
            path = os.path.join(root, name)
            if must_contain in path:
                out.append(path)
    return out
def get_list(path, must_contain=''):
    """Return image paths from a pickled list or from a directory walk.

    If ``path`` contains ".pickle", it is unpickled and the stored entries
    containing ``must_contain`` are returned. Otherwise ``path`` is treated
    as a directory and searched with ``recursively_read``.
    """
    if ".pickle" not in path:
        return recursively_read(path, must_contain)
    # NOTE(review): pickle.load can execute arbitrary code. Only point this
    # at trusted files.
    with open(path, 'rb') as handle:
        entries = pickle.load(handle)
    return [entry for entry in entries if must_contain in entry]
if __name__ == '__main__':
    # Evaluate a checkpoint on the test loader and append summary metrics.
    val_opt = options.TestOptions().parse()
    output_dir = os.path.join(val_opt.output, val_opt.name)
    os.makedirs(output_dir, exist_ok=True)
    logger = create_logger(output_dir=output_dir, name="FakeVideoDetector")
    logger.info(f"working dir: {output_dir}")

    model = Validator(val_opt)
    model.load_state_dict(val_opt.ckpt)
    logger.info("ckpt loaded!")

    val_loader = create_test_dataloader(val_opt, clip_model=None, transform=model.clip_model.preprocess)
    (ap, r_acc0, f_acc0, acc0,
     r_acc1, f_acc1, acc1, best_thres) = validate(model, val_loader, logger, find_thres=True)
    print(f"ap: {ap}, r_acc0: {r_acc0}, f_acc0: {f_acc0}, acc0:{acc0}, r_acc1: {r_acc1}, f_acc1: {f_acc1}, acc1: {acc1}, best_thres: {best_thres} ")

    # Result files go to val_opt.name relative to the CWD, not to output_dir.
    # NOTE(review): confirm that location is intended. Create the directory
    # so the appends below cannot fail after a full evaluation run.
    results_dir = val_opt.name
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)
    with open(os.path.join(results_dir, 'ap.txt'), 'a') as f:
        f.write(str(round(ap * 100, 2)) + '\n')
    with open(os.path.join(results_dir, 'acc0.txt'), 'a') as f:
        f.write(' '.join(str(round(v * 100, 2)) for v in (r_acc0, f_acc0, acc0)) + '\n')
|