import os
import torch
import torchvision.transforms as transforms
# NOTE(review): the source had a stray empty `import` statement here (text lost
# during extraction); restore it if a module turns out to be missing.
import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score
import pickle
from tqdm import tqdm
from datetime import datetime
from copy import deepcopy
from dataset_paths import DATASET_PATHS
import random
from datasetss import create_test_dataloader
from utilss.logger import create_logger
import options
from networks.validator import Validator


SEED = 0


def set_seed():
    """Seed torch (CPU and CUDA), NumPy and Python's `random` with `SEED`."""
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)


# Per-channel (3-value) normalisation statistics, keyed by backbone family.
MEAN = {
    "imagenet": [0.485, 0.456, 0.406],
    "clip": [0.48145466, 0.4578275, 0.40821073],
}

STD = {
    "imagenet": [0.229, 0.224, 0.225],
    "clip": [0.26862954, 0.26130258, 0.27577711],
}


def find_best_threshold(y_true, y_pred):
    """Return the threshold on `y_pred` that maximises accuracy.

    Labels are 0 for real and 1 for fake.  A sample is predicted fake when
    its score is strictly greater than the threshold -- the same rule
    `calculate_acc` uses -- so the returned value maximises the accuracy
    that `calculate_acc` will later report.

    Args:
        y_true: 1-D numpy array of 0/1 ground-truth labels.
        y_pred: 1-D numpy array of scores, same length as `y_true`.

    Returns:
        The midpoint between the classes when they are perfectly separable,
        otherwise the best-scoring candidate taken from `y_pred` (0 if
        `y_pred` is empty).
    """
    # Split by label rather than assuming "first half real, second half fake",
    # so unordered or unbalanced inputs are handled correctly.
    real_scores = y_pred[y_true == 0]
    fake_scores = y_pred[y_true == 1]
    if real_scores.size and fake_scores.size:
        real_max, fake_min = real_scores.max(), fake_scores.min()
        # Strict `<`: with the `>` decision rule a tie is not separable.
        if real_max < fake_min:
            # perfectly separable case: threshold midway between the classes
            return (real_max + fake_min) / 2

    best_acc = 0
    best_thres = 0
    for thres in y_pred:
        # Direct comparison also works for scores outside [0, 1].
        acc = accuracy_score(y_true, y_pred > thres)
        if acc >= best_acc:  # ties resolve to the later candidate
            best_thres = thres
            best_acc = acc
    return best_thres


def calculate_acc(y_true, y_pred, thres):
    """Return (real accuracy, fake accuracy, overall accuracy) at `thres`.

    A sample is predicted fake (1) when its score is strictly above `thres`.
    """
    r_acc = accuracy_score(y_true[y_true == 0], y_pred[y_true == 0] > thres)
    f_acc = accuracy_score(y_true[y_true == 1], y_pred[y_true == 1] > thres)
    acc = accuracy_score(y_true, y_pred > thres)
    return r_acc, f_acc, acc


def validate(model, loader, logger, find_thres=False):
    """Score every batch of `loader` with `model` and compute metrics.

    Args:
        model: object exposing `set_input(data)`, `.input` and `.model`
            (a callable returning logits) -- as used below.
        loader: iterable of batches; `data[1]` is taken as the label tensor.
        logger: logger used for progress messages.
        find_thres: when True, also compute metrics at the accuracy-optimal
            threshold.

    Returns:
        `(ap, r_acc0, f_acc0, acc0)` at threshold 0.5, extended with
        `(r_acc1, f_acc1, acc1, best_thres)` when `find_thres` is True.
    """
    with torch.no_grad():
        y_true, y_pred = [], []
        # len() of a loader counts batches, not samples.
        logger.info("Number of batches: %d" % (len(loader)))
        pbar = tqdm(loader)
        for data in pbar:
            pbar.set_description(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
            model.set_input(data)
            logits = model.model(model.input)
            y_pred.extend(logits.view(-1).sigmoid().tolist())
            y_true.extend(data[1].flatten().tolist())

    y_true, y_pred = np.array(y_true), np.array(y_pred)
    # To plot PR/ROC curves, torch.save() torch.stack([tensor(y_true), tensor(y_pred)]) here.

    # Get AP
    ap = average_precision_score(y_true, y_pred)
    # Acc based on 0.5
    r_acc0, f_acc0, acc0 = calculate_acc(y_true, y_pred, 0.5)
    if not find_thres:
        return ap, r_acc0, f_acc0, acc0

    # Acc based on the best thres
    best_thres = find_best_threshold(y_true, y_pred)
    r_acc1, f_acc1, acc1 = calculate_acc(y_true, y_pred, best_thres)
    return ap, r_acc0, f_acc0, acc0, r_acc1, f_acc1, acc1, best_thres


# = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = #


def recursively_read(rootdir, must_contain, exts=("png", "jpg", "JPEG", "jpeg", "bmp")):
    """Return paths under `rootdir` whose extension is in `exts` and whose
    full path contains the substring `must_contain`.

    The extension is the text after the LAST dot (case-sensitive); files
    without an extension are skipped rather than raising.
    """
    out = []
    for root, _dirs, files in os.walk(rootdir):
        for file in files:
            path = os.path.join(root, file)
            ext = os.path.splitext(file)[1][1:]
            if ext in exts and must_contain in path:
                out.append(path)
    return out


def get_list(path, must_contain=''):
    """Return image paths from a pickled list or a directory tree.

    If `path` contains ".pickle" it is unpickled and filtered by the
    substring `must_contain`; otherwise `path` is walked with
    `recursively_read`.
    """
    if ".pickle" in path:
        # SECURITY: pickle.load executes arbitrary code -- only use trusted files.
        with open(path, 'rb') as f:
            image_list = pickle.load(f)
        image_list = [item for item in image_list if must_contain in item]
    else:
        image_list = recursively_read(path, must_contain)
    return image_list


if __name__ == '__main__':
    val_opt = options.TestOptions().parse()

    # NOTE(review): the source joined a second path component onto
    # val_opt.output here, but that argument was lost; confirm and restore.
    output_dir = val_opt.output
    os.makedirs(output_dir, exist_ok=True)
    logger = create_logger(output_dir=output_dir, name="FakeVideoDetector")
    logger.info(f"working dir: {output_dir}")

    model = Validator(val_opt)
    model.load_state_dict(val_opt.ckpt)
    logger.info("ckpt loaded!")

    val_loader = create_test_dataloader(val_opt, clip_model=None, transform=model.clip_model.preprocess)

    ap, r_acc0, f_acc0, acc0, r_acc1, f_acc1, acc1, best_thres = validate(
        model, val_loader, logger, find_thres=True,
    )
    print(f"ap: {ap}, r_acc0: {r_acc0}, f_acc0: {f_acc0}, acc0:{acc0}, r_acc1: {r_acc1}, f_acc1: {f_acc1}, acc1: {acc1}, best_thres: {best_thres} ")

    # Results are appended so repeated runs accumulate in the same files.
    # NOTE(review): the directory argument of these joins was lost in the
    # source; output_dir is assumed -- confirm.
    with open(os.path.join(output_dir, 'ap.txt'), 'a') as f:
        f.write(str(round(ap * 100, 2)) + '\n')
    with open(os.path.join(output_dir, 'acc0.txt'), 'a') as f:
        f.write(str(round(r_acc0 * 100, 2)) + ' ' + str(round(f_acc0 * 100, 2)) + ' ' + str(round(acc0 * 100, 2)) + '\n')