russel0719 commited on
Commit
55da56b
·
1 Parent(s): 2638f8b

Upload 38 files

Browse files
Files changed (39) hide show
  1. .gitattributes +3 -0
  2. app.py +52 -0
  3. preprocessing/__init__.py +1 -0
  4. preprocessing/compress_videos.py +45 -0
  5. preprocessing/detect_original_faces.py +51 -0
  6. preprocessing/extract_crops.py +86 -0
  7. preprocessing/extract_images.py +42 -0
  8. preprocessing/face_detector.py +72 -0
  9. preprocessing/face_encodings.py +55 -0
  10. preprocessing/generate_diffs.py +73 -0
  11. preprocessing/generate_folds.py +114 -0
  12. preprocessing/generate_landmarks.py +75 -0
  13. preprocessing/utils.py +51 -0
  14. sample/sample1.mp4 +3 -0
  15. sample/sample2.mp4 +3 -0
  16. training/__init__.py +0 -0
  17. training/__pycache__/__init__.cpython-37.pyc +0 -0
  18. training/__pycache__/__init__.cpython-39.pyc +0 -0
  19. training/datasets/__init__.py +0 -0
  20. training/datasets/classifier_dataset.py +378 -0
  21. training/datasets/validation_set.py +60 -0
  22. training/losses.py +28 -0
  23. training/pipelines/__init__.py +0 -0
  24. training/pipelines/train_classifier.py +361 -0
  25. training/tools/__init__.py +0 -0
  26. training/tools/config.py +43 -0
  27. training/tools/schedulers.py +46 -0
  28. training/tools/utils.py +121 -0
  29. training/transforms/__init__.py +0 -0
  30. training/transforms/albu.py +99 -0
  31. training/zoo/__init__.py +0 -0
  32. training/zoo/__pycache__/__init__.cpython-37.pyc +0 -0
  33. training/zoo/__pycache__/__init__.cpython-39.pyc +0 -0
  34. training/zoo/__pycache__/classifiers.cpython-37.pyc +0 -0
  35. training/zoo/__pycache__/classifiers.cpython-39.pyc +0 -0
  36. training/zoo/classifiers.py +172 -0
  37. training/zoo/unet.py +151 -0
  38. utils.py +354 -0
  39. weights/final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36 +3 -0
.gitattributes CHANGED
@@ -32,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ sample/sample1.mp4 filter=lfs diff=lfs merge=lfs -text
36
+ sample/sample2.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ weights/final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36 filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import re
4
+ import torch
5
+ from utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video
6
+ from training.zoo.classifiers import DeepFakeClassifier
7
+
8
+ def detect(video):
9
+ # Load model
10
+ model = DeepFakeClassifier(encoder="tf_efficientnet_b7")
11
+ path = os.path.join('weights', 'final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36')
12
+ checkpoint = torch.load(path, map_location="cpu")
13
+ state_dict = checkpoint.get("state_dict", checkpoint)
14
+ model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=True)
15
+ model.eval()
16
+ del checkpoint
17
+ models = [model.float()]
18
+
19
+ # Setting Video
20
+ frames_per_video = 32
21
+ video_reader = VideoReader()
22
+ video_read_fn = lambda x: video_reader.read_frames(x, num_frames=frames_per_video)
23
+ face_extractor = FaceExtractor(video_read_fn)
24
+ input_size = 380
25
+ strategy = confident_strategy
26
+
27
+ # Predict
28
+ pred = predict_on_video(
29
+ face_extractor=face_extractor,
30
+ video=video,
31
+ batch_size=frames_per_video,
32
+ input_size=input_size,
33
+ models=models,
34
+ strategy=strategy
35
+ )
36
+ prob = {'Fake': float(pred), 'Real': float(1 - pred)}
37
+ return prob
38
+
39
+ gr_inputs = gr.Video(format='mp4', source='upload')
40
+ gr_outputs = gr.Label()
41
+ gr_ex = [
42
+ [os.path.join(os.path.dirname(__file__),"sample/sample1.mp4")],
43
+ [os.path.join(os.path.dirname(__file__),"sample/sample2.mp4")],
44
+ ]
45
+ iface = gr.Interface(
46
+ fn=detect,
47
+ inputs=gr_inputs,
48
+ outputs=gr_outputs,
49
+ examples=gr_ex,
50
+ )
51
+
52
+ iface.launch()
preprocessing/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .face_detector import *
preprocessing/compress_videos.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import random
4
+ import subprocess
5
+
6
+ os.environ["MKL_NUM_THREADS"] = "1"
7
+ os.environ["NUMEXPR_NUM_THREADS"] = "1"
8
+ os.environ["OMP_NUM_THREADS"] = "1"
9
+ from functools import partial
10
+ from glob import glob
11
+ from multiprocessing.pool import Pool
12
+ from os import cpu_count
13
+
14
+ import cv2
15
+
16
+ cv2.ocl.setUseOpenCL(False)
17
+ cv2.setNumThreads(0)
18
+ from tqdm import tqdm
19
+
20
+
21
+ def compress_video(video, root_dir):
22
+ parent_dir = video.split("/")[-2]
23
+ out_dir = os.path.join(root_dir, "compressed", parent_dir)
24
+ os.makedirs(out_dir, exist_ok=True)
25
+ video_name = video.split("/")[-1]
26
+ out_path = os.path.join(out_dir, video_name)
27
+ lvl = random.choice([23, 28, 32])
28
+ command = "ffmpeg -i {} -c:v libx264 -crf {} -threads 1 {}".format(video, lvl, out_path)
29
+ try:
30
+ subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT)
31
+ except Exception as e:
32
+ print("Could not process vide", str(e))
33
+
34
+
35
+ if __name__ == '__main__':
36
+ parser = argparse.ArgumentParser(
37
+ description="Extracts jpegs from video")
38
+ parser.add_argument("--root-dir", help="root directory", default="/mnt/sota/datasets/deepfake")
39
+
40
+ args = parser.parse_args()
41
+ videos = [video_path for video_path in glob(os.path.join(args.root_dir, "*/*.mp4"))]
42
+ with Pool(processes=cpu_count() - 2) as p:
43
+ with tqdm(total=len(videos)) as pbar:
44
+ for v in p.imap_unordered(partial(compress_video, root_dir=args.root_dir), videos):
45
+ pbar.update()
preprocessing/detect_original_faces.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ from os import cpu_count
5
+ from typing import Type
6
+
7
+ from torch.utils.data.dataloader import DataLoader
8
+ from tqdm import tqdm
9
+
10
+ from preprocessing import face_detector, VideoDataset
11
+ from preprocessing.face_detector import VideoFaceDetector
12
+ from preprocessing.utils import get_original_video_paths
13
+
14
+
15
+ def parse_args():
16
+ parser = argparse.ArgumentParser(
17
+ description="Process a original videos with face detector")
18
+ parser.add_argument("--root-dir", help="root directory")
19
+ parser.add_argument("--detector-type", help="type of the detector", default="FacenetDetector",
20
+ choices=["FacenetDetector"])
21
+ args = parser.parse_args()
22
+ return args
23
+
24
+
25
+ def process_videos(videos, root_dir, detector_cls: Type[VideoFaceDetector]):
26
+ detector = face_detector.__dict__[detector_cls](device="cuda:0")
27
+ dataset = VideoDataset(videos)
28
+ loader = DataLoader(dataset, shuffle=False, num_workers=cpu_count() - 2, batch_size=1, collate_fn=lambda x: x)
29
+ for item in tqdm(loader):
30
+ result = {}
31
+ video, indices, frames = item[0]
32
+ batches = [frames[i:i + detector._batch_size] for i in range(0, len(frames), detector._batch_size)]
33
+ for j, frames in enumerate(batches):
34
+ result.update({int(j * detector._batch_size) + i : b for i, b in zip(indices, detector._detect_faces(frames))})
35
+ id = os.path.splitext(os.path.basename(video))[0]
36
+ out_dir = os.path.join(root_dir, "boxes")
37
+ os.makedirs(out_dir, exist_ok=True)
38
+ with open(os.path.join(out_dir, "{}.json".format(id)), "w") as f:
39
+ json.dump(result, f)
40
+
41
+
42
+
43
+
44
+ def main():
45
+ args = parse_args()
46
+ originals = get_original_video_paths(args.root_dir)
47
+ process_videos(originals, args.root_dir, args.detector_type)
48
+
49
+
50
+ if __name__ == "__main__":
51
+ main()
preprocessing/extract_crops.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ from os import cpu_count
5
+ from pathlib import Path
6
+
7
+ os.environ["MKL_NUM_THREADS"] = "1"
8
+ os.environ["NUMEXPR_NUM_THREADS"] = "1"
9
+ os.environ["OMP_NUM_THREADS"] = "1"
10
+ from functools import partial
11
+ from glob import glob
12
+ from multiprocessing.pool import Pool
13
+
14
+ import cv2
15
+
16
+ cv2.ocl.setUseOpenCL(False)
17
+ cv2.setNumThreads(0)
18
+ from tqdm import tqdm
19
+
20
+
21
+ def extract_video(param, root_dir, crops_dir):
22
+ video, bboxes_path = param
23
+ with open(bboxes_path, "r") as bbox_f:
24
+ bboxes_dict = json.load(bbox_f)
25
+
26
+ capture = cv2.VideoCapture(video)
27
+ frames_num = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
28
+
29
+ for i in range(frames_num):
30
+ capture.grab()
31
+ if i % 10 != 0:
32
+ continue
33
+ success, frame = capture.retrieve()
34
+ if not success or str(i) not in bboxes_dict:
35
+ continue
36
+ id = os.path.splitext(os.path.basename(video))[0]
37
+ crops = []
38
+ bboxes = bboxes_dict[str(i)]
39
+ if bboxes is None:
40
+ continue
41
+ for bbox in bboxes:
42
+ xmin, ymin, xmax, ymax = [int(b * 2) for b in bbox]
43
+ w = xmax - xmin
44
+ h = ymax - ymin
45
+ p_h = h // 3
46
+ p_w = w // 3
47
+ crop = frame[max(ymin - p_h, 0):ymax + p_h, max(xmin - p_w, 0):xmax + p_w]
48
+ h, w = crop.shape[:2]
49
+ crops.append(crop)
50
+ img_dir = os.path.join(root_dir, crops_dir, id)
51
+ os.makedirs(img_dir, exist_ok=True)
52
+ for j, crop in enumerate(crops):
53
+ cv2.imwrite(os.path.join(img_dir, "{}_{}.png".format(i, j)), crop)
54
+
55
+
56
+ def get_video_paths(root_dir):
57
+ paths = []
58
+ for json_path in glob(os.path.join(root_dir, "*/metadata.json")):
59
+ dir = Path(json_path).parent
60
+ with open(json_path, "r") as f:
61
+ metadata = json.load(f)
62
+ for k, v in metadata.items():
63
+ original = v.get("original", None)
64
+ if not original:
65
+ original = k
66
+ bboxes_path = os.path.join(root_dir, "boxes", original[:-4] + ".json")
67
+ if not os.path.exists(bboxes_path):
68
+ continue
69
+ paths.append((os.path.join(dir, k), bboxes_path))
70
+
71
+ return paths
72
+
73
+
74
+ if __name__ == '__main__':
75
+ parser = argparse.ArgumentParser(
76
+ description="Extracts crops from video")
77
+ parser.add_argument("--root-dir", help="root directory")
78
+ parser.add_argument("--crops-dir", help="crops directory")
79
+
80
+ args = parser.parse_args()
81
+ os.makedirs(os.path.join(args.root_dir, args.crops_dir), exist_ok=True)
82
+ params = get_video_paths(args.root_dir)
83
+ with Pool(processes=cpu_count()) as p:
84
+ with tqdm(total=len(params)) as pbar:
85
+ for v in p.imap_unordered(partial(extract_video, root_dir=args.root_dir, crops_dir=args.crops_dir), params):
86
+ pbar.update()
preprocessing/extract_images.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ os.environ["MKL_NUM_THREADS"] = "1"
4
+ os.environ["NUMEXPR_NUM_THREADS"] = "1"
5
+ os.environ["OMP_NUM_THREADS"] = "1"
6
+ from functools import partial
7
+ from glob import glob
8
+ from multiprocessing.pool import Pool
9
+ from os import cpu_count
10
+
11
+ import cv2
12
+ cv2.ocl.setUseOpenCL(False)
13
+ cv2.setNumThreads(0)
14
+ from tqdm import tqdm
15
+
16
+
17
+ def extract_video(video, root_dir):
18
+ capture = cv2.VideoCapture(video)
19
+ frames_num = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
20
+
21
+ for i in range(frames_num):
22
+ capture.grab()
23
+ success, frame = capture.retrieve()
24
+ if not success:
25
+ continue
26
+ id = os.path.splitext(os.path.basename(video))[0]
27
+ cv2.imwrite(os.path.join(root_dir, "jpegs", "{}_{}.jpg".format(id, i)), frame, [cv2.IMWRITE_JPEG_QUALITY, 100])
28
+
29
+
30
+
31
+ if __name__ == '__main__':
32
+ parser = argparse.ArgumentParser(
33
+ description="Extracts jpegs from video")
34
+ parser.add_argument("--root-dir", help="root directory")
35
+
36
+ args = parser.parse_args()
37
+ os.makedirs(os.path.join(args.root_dir, "jpegs"), exist_ok=True)
38
+ videos = [video_path for video_path in glob(os.path.join(args.root_dir, "*/*.mp4"))]
39
+ with Pool(processes=cpu_count() - 2) as p:
40
+ with tqdm(total=len(videos)) as pbar:
41
+ for v in p.imap_unordered(partial(extract_video, root_dir=args.root_dir), videos):
42
+ pbar.update()
preprocessing/face_detector.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["MKL_NUM_THREADS"] = "1"
3
+ os.environ["NUMEXPR_NUM_THREADS"] = "1"
4
+ os.environ["OMP_NUM_THREADS"] = "1"
5
+
6
+ from abc import ABC, abstractmethod
7
+ from collections import OrderedDict
8
+ from typing import List
9
+
10
+ import cv2
11
+ cv2.ocl.setUseOpenCL(False)
12
+ cv2.setNumThreads(0)
13
+
14
+ from PIL import Image
15
+ from facenet_pytorch.models.mtcnn import MTCNN
16
+ from torch.utils.data import Dataset
17
+
18
+
19
+ class VideoFaceDetector(ABC):
20
+
21
+ def __init__(self, **kwargs) -> None:
22
+ super().__init__()
23
+
24
+ @property
25
+ @abstractmethod
26
+ def _batch_size(self) -> int:
27
+ pass
28
+
29
+ @abstractmethod
30
+ def _detect_faces(self, frames) -> List:
31
+ pass
32
+
33
+
34
+ class FacenetDetector(VideoFaceDetector):
35
+
36
+ def __init__(self, device="cuda:0") -> None:
37
+ super().__init__()
38
+ self.detector = MTCNN(margin=0,thresholds=[0.85, 0.95, 0.95], device=device)
39
+
40
+ def _detect_faces(self, frames) -> List:
41
+ batch_boxes, *_ = self.detector.detect(frames, landmarks=False)
42
+ return [b.tolist() if b is not None else None for b in batch_boxes]
43
+
44
+ @property
45
+ def _batch_size(self):
46
+ return 32
47
+
48
+
49
+ class VideoDataset(Dataset):
50
+
51
+ def __init__(self, videos) -> None:
52
+ super().__init__()
53
+ self.videos = videos
54
+
55
+ def __getitem__(self, index: int):
56
+ video = self.videos[index]
57
+ capture = cv2.VideoCapture(video)
58
+ frames_num = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
59
+ frames = OrderedDict()
60
+ for i in range(frames_num):
61
+ capture.grab()
62
+ success, frame = capture.retrieve()
63
+ if not success:
64
+ continue
65
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
66
+ frame = Image.fromarray(frame)
67
+ frame = frame.resize(size=[s // 2 for s in frame.size])
68
+ frames[i] = frame
69
+ return video, list(frames.keys()), list(frames.values())
70
+
71
+ def __len__(self) -> int:
72
+ return len(self.videos)
preprocessing/face_encodings.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from functools import partial
4
+ from multiprocessing.pool import Pool
5
+
6
+ from tqdm import tqdm
7
+
8
+ from preprocessing.utils import get_original_video_paths
9
+
10
+ os.environ["MKL_NUM_THREADS"] = "1"
11
+ os.environ["NUMEXPR_NUM_THREADS"] = "1"
12
+ os.environ["OMP_NUM_THREADS"] = "1"
13
+
14
+ import random
15
+
16
+ import face_recognition
17
+ import numpy as np
18
+
19
+
20
+ def write_face_encodings(video, root_dir):
21
+ video_id, *_ = os.path.splitext(video)
22
+ crops_dir = os.path.join(root_dir, "crops", video_id)
23
+ if not os.path.exists(crops_dir):
24
+ return
25
+ crop_files = [f for f in os.listdir(crops_dir) if f.endswith("jpg")]
26
+ if crop_files:
27
+ crop_files = random.sample(crop_files, min(10, len(crop_files)))
28
+ encodings = []
29
+ for crop_file in crop_files:
30
+ img = face_recognition.load_image_file(os.path.join(crops_dir, crop_file))
31
+ encoding = face_recognition.face_encodings(img, num_jitters=10)
32
+ if encoding:
33
+ encodings.append(encoding[0])
34
+ np.save(os.path.join(crops_dir, "encodings"), encodings)
35
+
36
+
37
+ def parse_args():
38
+ parser = argparse.ArgumentParser(
39
+ description="Extract 10 crops encodings for each video")
40
+ parser.add_argument("--root-dir", help="root directory", default="/home/selim/datasets/deepfake")
41
+ args = parser.parse_args()
42
+ return args
43
+
44
+
45
+ def main():
46
+ args = parse_args()
47
+ originals = get_original_video_paths(args.root_dir, basename=True)
48
+ with Pool(processes=os.cpu_count() - 4) as p:
49
+ with tqdm(total=len(originals)) as pbar:
50
+ for v in p.imap_unordered(partial(write_face_encodings, root_dir=args.root_dir), originals):
51
+ pbar.update()
52
+
53
+
54
+ if __name__ == '__main__':
55
+ main()
preprocessing/generate_diffs.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ os.environ["MKL_NUM_THREADS"] = "1"
5
+ os.environ["NUMEXPR_NUM_THREADS"] = "1"
6
+ os.environ["OMP_NUM_THREADS"] = "1"
7
+ from skimage.measure import compare_ssim
8
+
9
+ from functools import partial
10
+ from multiprocessing.pool import Pool
11
+
12
+ from tqdm import tqdm
13
+
14
+ from preprocessing.utils import get_original_with_fakes
15
+
16
+ import cv2
17
+
18
+ cv2.ocl.setUseOpenCL(False)
19
+ cv2.setNumThreads(0)
20
+
21
+ import numpy as np
22
+
23
+ cache = {}
24
+
25
+
26
+ def save_diffs(pair, root_dir):
27
+ ori_id, fake_id = pair
28
+ ori_dir = os.path.join(root_dir, "crops", ori_id)
29
+ fake_dir = os.path.join(root_dir, "crops", fake_id)
30
+ diff_dir = os.path.join(root_dir, "diffs", fake_id)
31
+ os.makedirs(diff_dir, exist_ok=True)
32
+ for frame in range(320):
33
+ if frame % 10 != 0:
34
+ continue
35
+ for actor in range(2):
36
+ image_id = "{}_{}.png".format(frame, actor)
37
+ diff_image_id = "{}_{}_diff.png".format(frame, actor)
38
+ ori_path = os.path.join(ori_dir, image_id)
39
+ fake_path = os.path.join(fake_dir, image_id)
40
+ diff_path = os.path.join(diff_dir, diff_image_id)
41
+ if os.path.exists(ori_path) and os.path.exists(fake_path):
42
+ img1 = cv2.imread(ori_path, cv2.IMREAD_COLOR)
43
+ img2 = cv2.imread(fake_path, cv2.IMREAD_COLOR)
44
+ try:
45
+ d, a = compare_ssim(img1, img2, multichannel=True, full=True)
46
+ a = 1 - a
47
+ diff = (a * 255).astype(np.uint8)
48
+ diff = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
49
+ cv2.imwrite(diff_path, diff)
50
+ except:
51
+ pass
52
+
53
+ def parse_args():
54
+ parser = argparse.ArgumentParser(
55
+ description="Extract image diffs")
56
+ parser.add_argument("--root-dir", help="root directory", default="/mnt/sota/datasets/deepfake")
57
+ args = parser.parse_args()
58
+ return args
59
+
60
+
61
+ def main():
62
+ args = parse_args()
63
+ pairs = get_original_with_fakes(args.root_dir)
64
+ os.makedirs(os.path.join(args.root_dir, "diffs"), exist_ok=True)
65
+ with Pool(processes=os.cpu_count() - 2) as p:
66
+ with tqdm(total=len(pairs)) as pbar:
67
+ func = partial(save_diffs, root_dir=args.root_dir)
68
+ for v in p.imap_unordered(func, pairs):
69
+ pbar.update()
70
+
71
+
72
+ if __name__ == '__main__':
73
+ main()
preprocessing/generate_folds.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import random
5
+ from functools import partial
6
+ from multiprocessing.pool import Pool
7
+ from pathlib import Path
8
+
9
+ os.environ["MKL_NUM_THREADS"] = "1"
10
+ os.environ["NUMEXPR_NUM_THREADS"] = "1"
11
+ os.environ["OMP_NUM_THREADS"] = "1"
12
+ import pandas as pd
13
+
14
+ from tqdm import tqdm
15
+
16
+ from preprocessing.utils import get_original_with_fakes
17
+
18
+ import cv2
19
+
20
+ cv2.ocl.setUseOpenCL(False)
21
+ cv2.setNumThreads(0)
22
+
23
+
24
+ def get_paths(vid, label, root_dir):
25
+ ori_vid, fake_vid = vid
26
+ ori_dir = os.path.join(root_dir, "crops", ori_vid)
27
+ fake_dir = os.path.join(root_dir, "crops", fake_vid)
28
+ data = []
29
+ for frame in range(320):
30
+ if frame % 10 != 0:
31
+ continue
32
+ for actor in range(2):
33
+ image_id = "{}_{}.png".format(frame, actor)
34
+ ori_img_path = os.path.join(ori_dir, image_id)
35
+ fake_img_path = os.path.join(fake_dir, image_id)
36
+ img_path = ori_img_path if label == 0 else fake_img_path
37
+ try:
38
+ # img = cv2.imread(img_path)[..., ::-1]
39
+ if os.path.exists(img_path):
40
+ data.append([img_path, label, ori_vid])
41
+ except:
42
+ pass
43
+ return data
44
+
45
+
46
+ def parse_args():
47
+ parser = argparse.ArgumentParser(
48
+ description="Generate Folds")
49
+ parser.add_argument("--root-dir", help="root directory", default="/mnt/sota/datasets/deepfake")
50
+ parser.add_argument("--out", type=str, default="folds02.csv", help="CSV file to save")
51
+ parser.add_argument("--seed", type=int, default=777, help="Seed to split, default 777")
52
+ parser.add_argument("--n_splits", type=int, default=16, help="Num folds, default 10")
53
+ args = parser.parse_args()
54
+
55
+ return args
56
+
57
+
58
+ def main():
59
+ args = parse_args()
60
+ ori_fakes = get_original_with_fakes(args.root_dir)
61
+ sz = 50 // args.n_splits
62
+ folds = []
63
+ for fold in range(args.n_splits):
64
+ folds.append(list(range(sz * fold, sz * fold + sz if fold < args.n_splits - 1 else 50)))
65
+ print(folds)
66
+ video_fold = {}
67
+ for d in os.listdir(args.root_dir):
68
+ if "dfdc" in d:
69
+ part = int(d.split("_")[-1])
70
+ for f in os.listdir(os.path.join(args.root_dir, d)):
71
+ if "metadata.json" in f:
72
+ with open(os.path.join(args.root_dir, d, "metadata.json")) as metadata_json:
73
+ metadata = json.load(metadata_json)
74
+
75
+ for k, v in metadata.items():
76
+ fold = None
77
+ for i, fold_dirs in enumerate(folds):
78
+ if part in fold_dirs:
79
+ fold = i
80
+ break
81
+ assert fold is not None
82
+ video_id = k[:-4]
83
+ video_fold[video_id] = fold
84
+ for fold in range(len(folds)):
85
+ holdoutset = {k for k, v in video_fold.items() if v == fold}
86
+ trainset = {k for k, v in video_fold.items() if v != fold}
87
+ assert holdoutset.isdisjoint(trainset), "Folds have leaks"
88
+ data = []
89
+ ori_ori = set([(ori, ori) for ori, fake in ori_fakes])
90
+ with Pool(processes=os.cpu_count()) as p:
91
+ with tqdm(total=len(ori_ori)) as pbar:
92
+ func = partial(get_paths, label=0, root_dir=args.root_dir)
93
+ for v in p.imap_unordered(func, ori_ori):
94
+ pbar.update()
95
+ data.extend(v)
96
+ with tqdm(total=len(ori_fakes)) as pbar:
97
+ func = partial(get_paths, label=1, root_dir=args.root_dir)
98
+ for v in p.imap_unordered(func, ori_fakes):
99
+ pbar.update()
100
+ data.extend(v)
101
+ fold_data = []
102
+ for img_path, label, ori_vid in data:
103
+ path = Path(img_path)
104
+ video = path.parent.name
105
+ file = path.name
106
+ assert video_fold[video] == video_fold[ori_vid], "original video and fake have leak {} {}".format(ori_vid,
107
+ video)
108
+ fold_data.append([video, file, label, ori_vid, int(file.split("_")[0]), video_fold[video]])
109
+ random.shuffle(fold_data)
110
+ pd.DataFrame(fold_data, columns=["video", "file", "label", "original", "frame", "fold"]).to_csv(args.out, index=False)
111
+
112
+
113
+ if __name__ == '__main__':
114
+ main()
preprocessing/generate_landmarks.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from functools import partial
4
+ from multiprocessing.pool import Pool
5
+
6
+
7
+
8
+ os.environ["MKL_NUM_THREADS"] = "1"
9
+ os.environ["NUMEXPR_NUM_THREADS"] = "1"
10
+ os.environ["OMP_NUM_THREADS"] = "1"
11
+
12
+ from tqdm import tqdm
13
+
14
+
15
+ import cv2
16
+
17
+ cv2.ocl.setUseOpenCL(False)
18
+ cv2.setNumThreads(0)
19
+ from preprocessing.utils import get_original_video_paths
20
+
21
+ from PIL import Image
22
+ from facenet_pytorch.models.mtcnn import MTCNN
23
+ import numpy as np
24
+
25
+ detector = MTCNN(margin=0, thresholds=[0.65, 0.75, 0.75], device="cpu")
26
+
27
+
28
+ def save_landmarks(ori_id, root_dir):
29
+ ori_id = ori_id[:-4]
30
+ ori_dir = os.path.join(root_dir, "crops", ori_id)
31
+ landmark_dir = os.path.join(root_dir, "landmarks", ori_id)
32
+ os.makedirs(landmark_dir, exist_ok=True)
33
+ for frame in range(320):
34
+ if frame % 10 != 0:
35
+ continue
36
+ for actor in range(2):
37
+ image_id = "{}_{}.png".format(frame, actor)
38
+ landmarks_id = "{}_{}".format(frame, actor)
39
+ ori_path = os.path.join(ori_dir, image_id)
40
+ landmark_path = os.path.join(landmark_dir, landmarks_id)
41
+
42
+ if os.path.exists(ori_path):
43
+ try:
44
+ image_ori = cv2.imread(ori_path, cv2.IMREAD_COLOR)[...,::-1]
45
+ frame_img = Image.fromarray(image_ori)
46
+ batch_boxes, conf, landmarks = detector.detect(frame_img, landmarks=True)
47
+ if landmarks is not None:
48
+ landmarks = np.around(landmarks[0]).astype(np.int16)
49
+ np.save(landmark_path, landmarks)
50
+ except Exception as e:
51
+ print(e)
52
+ pass
53
+
54
+
55
+ def parse_args():
56
+ parser = argparse.ArgumentParser(
57
+ description="Extract image landmarks")
58
+ parser.add_argument("--root-dir", help="root directory", default="/mnt/sota/datasets/deepfake")
59
+ args = parser.parse_args()
60
+ return args
61
+
62
+
63
+ def main():
64
+ args = parse_args()
65
+ ids = get_original_video_paths(args.root_dir, basename=True)
66
+ os.makedirs(os.path.join(args.root_dir, "landmarks"), exist_ok=True)
67
+ with Pool(processes=os.cpu_count()) as p:
68
+ with tqdm(total=len(ids)) as pbar:
69
+ func = partial(save_landmarks, root_dir=args.root_dir)
70
+ for v in p.imap_unordered(func, ids):
71
+ pbar.update()
72
+
73
+
74
+ if __name__ == '__main__':
75
+ main()
preprocessing/utils.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from glob import glob
4
+ from pathlib import Path
5
+
6
+
7
+ def get_original_video_paths(root_dir, basename=False):
8
+ originals = set()
9
+ originals_v = set()
10
+ for json_path in glob(os.path.join(root_dir, "*/metadata.json")):
11
+ dir = Path(json_path).parent
12
+ with open(json_path, "r") as f:
13
+ metadata = json.load(f)
14
+ for k, v in metadata.items():
15
+ original = v.get("original", None)
16
+ if v["label"] == "REAL":
17
+ original = k
18
+ originals_v.add(original)
19
+ originals.add(os.path.join(dir, original))
20
+ originals = list(originals)
21
+ originals_v = list(originals_v)
22
+ print(len(originals))
23
+ return originals_v if basename else originals
24
+
25
+
26
+ def get_original_with_fakes(root_dir):
27
+ pairs = []
28
+ for json_path in glob(os.path.join(root_dir, "*/metadata.json")):
29
+ with open(json_path, "r") as f:
30
+ metadata = json.load(f)
31
+ for k, v in metadata.items():
32
+ original = v.get("original", None)
33
+ if v["label"] == "FAKE":
34
+ pairs.append((original[:-4], k[:-4] ))
35
+
36
+ return pairs
37
+
38
+
39
+ def get_originals_and_fakes(root_dir):
40
+ originals = []
41
+ fakes = []
42
+ for json_path in glob(os.path.join(root_dir, "*/metadata.json")):
43
+ with open(json_path, "r") as f:
44
+ metadata = json.load(f)
45
+ for k, v in metadata.items():
46
+ if v["label"] == "FAKE":
47
+ fakes.append(k[:-4])
48
+ else:
49
+ originals.append(k[:-4])
50
+
51
+ return originals, fakes
sample/sample1.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37f0ff1e337e6fe2211757b8255c8098f8a4f303b333e9c5e0da27c37e095876
3
+ size 5128324
sample/sample2.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa3a01923fdfe38abc7b12b436c0e631575b6337fcde90cb9017e8dd5b733db0
3
+ size 16193231
training/__init__.py ADDED
File without changes
training/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (157 Bytes). View file
 
training/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (161 Bytes). View file
 
training/datasets/__init__.py ADDED
File without changes
training/datasets/classifier_dataset.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import random
4
+ import sys
5
+ import traceback
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import pandas as pd
10
+ import skimage.draw
11
+ from albumentations import ImageCompression, OneOf, GaussianBlur, Blur
12
+ from albumentations.augmentations.functional import image_compression, rot90
13
+ from albumentations.pytorch.functional import img_to_tensor
14
+ from scipy.ndimage import binary_erosion, binary_dilation
15
+ from skimage import measure
16
+ from torch.utils.data import Dataset
17
+ import dlib
18
+
19
+ from training.datasets.validation_set import PUBLIC_SET
20
+
21
+
22
+ def prepare_bit_masks(mask):
23
+ h, w = mask.shape
24
+ mid_w = w // 2
25
+ mid_h = w // 2
26
+ masks = []
27
+ ones = np.ones_like(mask)
28
+ ones[:mid_h] = 0
29
+ masks.append(ones)
30
+ ones = np.ones_like(mask)
31
+ ones[mid_h:] = 0
32
+ masks.append(ones)
33
+ ones = np.ones_like(mask)
34
+ ones[:, :mid_w] = 0
35
+ masks.append(ones)
36
+ ones = np.ones_like(mask)
37
+ ones[:, mid_w:] = 0
38
+ masks.append(ones)
39
+ ones = np.ones_like(mask)
40
+ ones[:mid_h, :mid_w] = 0
41
+ ones[mid_h:, mid_w:] = 0
42
+ masks.append(ones)
43
+ ones = np.ones_like(mask)
44
+ ones[:mid_h, mid_w:] = 0
45
+ ones[mid_h:, :mid_w] = 0
46
+ masks.append(ones)
47
+ return masks
48
+
49
+
50
+ detector = dlib.get_frontal_face_detector()
51
+ predictor = dlib.shape_predictor('libs/shape_predictor_68_face_landmarks.dat')
52
+
53
+
54
+ def blackout_convex_hull(img):
55
+ try:
56
+ rect = detector(img)[0]
57
+ sp = predictor(img, rect)
58
+ landmarks = np.array([[p.x, p.y] for p in sp.parts()])
59
+ outline = landmarks[[*range(17), *range(26, 16, -1)]]
60
+ Y, X = skimage.draw.polygon(outline[:, 1], outline[:, 0])
61
+ cropped_img = np.zeros(img.shape[:2], dtype=np.uint8)
62
+ cropped_img[Y, X] = 1
63
+ # if random.random() > 0.5:
64
+ # img[cropped_img == 0] = 0
65
+ # #leave only face
66
+ # return img
67
+
68
+ y, x = measure.centroid(cropped_img)
69
+ y = int(y)
70
+ x = int(x)
71
+ first = random.random() > 0.5
72
+ if random.random() > 0.5:
73
+ if first:
74
+ cropped_img[:y, :] = 0
75
+ else:
76
+ cropped_img[y:, :] = 0
77
+ else:
78
+ if first:
79
+ cropped_img[:, :x] = 0
80
+ else:
81
+ cropped_img[:, x:] = 0
82
+
83
+ img[cropped_img > 0] = 0
84
+ except Exception as e:
85
+ pass
86
+
87
+
88
+ def dist(p1, p2):
89
+ return math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
90
+
91
+
92
+ def remove_eyes(image, landmarks):
93
+ image = image.copy()
94
+ (x1, y1), (x2, y2) = landmarks[:2]
95
+ mask = np.zeros_like(image[..., 0])
96
+ line = cv2.line(mask, (x1, y1), (x2, y2), color=(1), thickness=2)
97
+ w = dist((x1, y1), (x2, y2))
98
+ dilation = int(w // 4)
99
+ line = binary_dilation(line, iterations=dilation)
100
+ image[line, :] = 0
101
+ return image
102
+
103
+
104
+ def remove_nose(image, landmarks):
105
+ image = image.copy()
106
+ (x1, y1), (x2, y2) = landmarks[:2]
107
+ x3, y3 = landmarks[2]
108
+ mask = np.zeros_like(image[..., 0])
109
+ x4 = int((x1 + x2) / 2)
110
+ y4 = int((y1 + y2) / 2)
111
+ line = cv2.line(mask, (x3, y3), (x4, y4), color=(1), thickness=2)
112
+ w = dist((x1, y1), (x2, y2))
113
+ dilation = int(w // 4)
114
+ line = binary_dilation(line, iterations=dilation)
115
+ image[line, :] = 0
116
+ return image
117
+
118
+
119
+ def remove_mouth(image, landmarks):
120
+ image = image.copy()
121
+ (x1, y1), (x2, y2) = landmarks[-2:]
122
+ mask = np.zeros_like(image[..., 0])
123
+ line = cv2.line(mask, (x1, y1), (x2, y2), color=(1), thickness=2)
124
+ w = dist((x1, y1), (x2, y2))
125
+ dilation = int(w // 3)
126
+ line = binary_dilation(line, iterations=dilation)
127
+ image[line, :] = 0
128
+ return image
129
+
130
+
131
+ def remove_landmark(image, landmarks):
132
+ if random.random() > 0.5:
133
+ image = remove_eyes(image, landmarks)
134
+ elif random.random() > 0.5:
135
+ image = remove_mouth(image, landmarks)
136
+ elif random.random() > 0.5:
137
+ image = remove_nose(image, landmarks)
138
+ return image
139
+
140
+
141
+ def change_padding(image, part=5):
142
+ h, w = image.shape[:2]
143
+ # original padding was done with 1/3 from each side, too much
144
+ pad_h = int(((3 / 5) * h) / part)
145
+ pad_w = int(((3 / 5) * w) / part)
146
+ image = image[h // 5 - pad_h:-h // 5 + pad_h, w // 5 - pad_w:-w // 5 + pad_w]
147
+ return image
148
+
149
+
150
+ def blackout_random(image, mask, label):
151
+ binary_mask = mask > 0.4 * 255
152
+ h, w = binary_mask.shape[:2]
153
+
154
+ tries = 50
155
+ current_try = 1
156
+ while current_try < tries:
157
+ first = random.random() < 0.5
158
+ if random.random() < 0.5:
159
+ pivot = random.randint(h // 2 - h // 5, h // 2 + h // 5)
160
+ bitmap_msk = np.ones_like(binary_mask)
161
+ if first:
162
+ bitmap_msk[:pivot, :] = 0
163
+ else:
164
+ bitmap_msk[pivot:, :] = 0
165
+ else:
166
+ pivot = random.randint(w // 2 - w // 5, w // 2 + w // 5)
167
+ bitmap_msk = np.ones_like(binary_mask)
168
+ if first:
169
+ bitmap_msk[:, :pivot] = 0
170
+ else:
171
+ bitmap_msk[:, pivot:] = 0
172
+
173
+ if label < 0.5 and np.count_nonzero(image * np.expand_dims(bitmap_msk, axis=-1)) / 3 > (h * w) / 5 \
174
+ or np.count_nonzero(binary_mask * bitmap_msk) > 40:
175
+ mask *= bitmap_msk
176
+ image *= np.expand_dims(bitmap_msk, axis=-1)
177
+ break
178
+ current_try += 1
179
+ return image
180
+
181
+
182
+ def blend_original(img):
183
+ img = img.copy()
184
+ h, w = img.shape[:2]
185
+ rect = detector(img)
186
+ if len(rect) == 0:
187
+ return img
188
+ else:
189
+ rect = rect[0]
190
+ sp = predictor(img, rect)
191
+ landmarks = np.array([[p.x, p.y] for p in sp.parts()])
192
+ outline = landmarks[[*range(17), *range(26, 16, -1)]]
193
+ Y, X = skimage.draw.polygon(outline[:, 1], outline[:, 0])
194
+ raw_mask = np.zeros(img.shape[:2], dtype=np.uint8)
195
+ raw_mask[Y, X] = 1
196
+ face = img * np.expand_dims(raw_mask, -1)
197
+
198
+ # add warping
199
+ h1 = random.randint(h - h // 2, h + h // 2)
200
+ w1 = random.randint(w - w // 2, w + w // 2)
201
+ while abs(h1 - h) < h // 3 and abs(w1 - w) < w // 3:
202
+ h1 = random.randint(h - h // 2, h + h // 2)
203
+ w1 = random.randint(w - w // 2, w + w // 2)
204
+ face = cv2.resize(face, (w1, h1), interpolation=random.choice([cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC]))
205
+ face = cv2.resize(face, (w, h), interpolation=random.choice([cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC]))
206
+
207
+ raw_mask = binary_erosion(raw_mask, iterations=random.randint(4, 10))
208
+ img[raw_mask, :] = face[raw_mask, :]
209
+ if random.random() < 0.2:
210
+ img = OneOf([GaussianBlur(), Blur()], p=0.5)(image=img)["image"]
211
+ # image compression
212
+ if random.random() < 0.5:
213
+ img = ImageCompression(quality_lower=40, quality_upper=95)(image=img)["image"]
214
+ return img
215
+
216
+
217
+ class DeepFakeClassifierDataset(Dataset):
218
+
219
+ def __init__(self,
220
+ data_path="/mnt/sota/datasets/deepfake",
221
+ fold=0,
222
+ label_smoothing=0.01,
223
+ padding_part=3,
224
+ hardcore=True,
225
+ crops_dir="crops",
226
+ folds_csv="folds.csv",
227
+ normalize={"mean": [0.485, 0.456, 0.406],
228
+ "std": [0.229, 0.224, 0.225]},
229
+ rotation=False,
230
+ mode="train",
231
+ reduce_val=True,
232
+ oversample_real=True,
233
+ transforms=None
234
+ ):
235
+ super().__init__()
236
+ self.data_root = data_path
237
+ self.fold = fold
238
+ self.folds_csv = folds_csv
239
+ self.mode = mode
240
+ self.rotation = rotation
241
+ self.padding_part = padding_part
242
+ self.hardcore = hardcore
243
+ self.crops_dir = crops_dir
244
+ self.label_smoothing = label_smoothing
245
+ self.normalize = normalize
246
+ self.transforms = transforms
247
+ self.df = pd.read_csv(self.folds_csv)
248
+ self.oversample_real = oversample_real
249
+ self.reduce_val = reduce_val
250
+
251
+ def __getitem__(self, index: int):
252
+
253
+ while True:
254
+ video, img_file, label, ori_video, frame, fold = self.data[index]
255
+ try:
256
+ if self.mode == "train":
257
+ label = np.clip(label, self.label_smoothing, 1 - self.label_smoothing)
258
+ img_path = os.path.join(self.data_root, self.crops_dir, video, img_file)
259
+ image = cv2.imread(img_path, cv2.IMREAD_COLOR)
260
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
261
+ mask = np.zeros(image.shape[:2], dtype=np.uint8)
262
+ diff_path = os.path.join(self.data_root, "diffs", video, img_file[:-4] + "_diff.png")
263
+ try:
264
+ msk = cv2.imread(diff_path, cv2.IMREAD_GRAYSCALE)
265
+ if msk is not None:
266
+ mask = msk
267
+ except:
268
+ print("not found mask", diff_path)
269
+ pass
270
+ if self.mode == "train" and self.hardcore and not self.rotation:
271
+ landmark_path = os.path.join(self.data_root, "landmarks", ori_video, img_file[:-4] + ".npy")
272
+ if os.path.exists(landmark_path) and random.random() < 0.7:
273
+ landmarks = np.load(landmark_path)
274
+ image = remove_landmark(image, landmarks)
275
+ elif random.random() < 0.2:
276
+ blackout_convex_hull(image)
277
+ elif random.random() < 0.1:
278
+ binary_mask = mask > 0.4 * 255
279
+ masks = prepare_bit_masks((binary_mask * 1).astype(np.uint8))
280
+ tries = 6
281
+ current_try = 1
282
+ while current_try < tries:
283
+ bitmap_msk = random.choice(masks)
284
+ if label < 0.5 or np.count_nonzero(mask * bitmap_msk) > 20:
285
+ mask *= bitmap_msk
286
+ image *= np.expand_dims(bitmap_msk, axis=-1)
287
+ break
288
+ current_try += 1
289
+ if self.mode == "train" and self.padding_part > 3:
290
+ image = change_padding(image, self.padding_part)
291
+ valid_label = np.count_nonzero(mask[mask > 20]) > 32 or label < 0.5
292
+ valid_label = 1 if valid_label else 0
293
+ rotation = 0
294
+ if self.transforms:
295
+ data = self.transforms(image=image, mask=mask)
296
+ image = data["image"]
297
+ mask = data["mask"]
298
+ if self.mode == "train" and self.hardcore and self.rotation:
299
+ # landmark_path = os.path.join(self.data_root, "landmarks", ori_video, img_file[:-4] + ".npy")
300
+ dropout = 0.8 if label > 0.5 else 0.6
301
+ if self.rotation:
302
+ dropout *= 0.7
303
+ elif random.random() < dropout:
304
+ blackout_random(image, mask, label)
305
+
306
+ #
307
+ # os.makedirs("../images", exist_ok=True)
308
+ # cv2.imwrite(os.path.join("../images", video+ "_" + str(1 if label > 0.5 else 0) + "_"+img_file), image[...,::-1])
309
+
310
+ if self.mode == "train" and self.rotation:
311
+ rotation = random.randint(0, 3)
312
+ image = rot90(image, rotation)
313
+
314
+ image = img_to_tensor(image, self.normalize)
315
+ return {"image": image, "labels": np.array((label,)), "img_name": os.path.join(video, img_file),
316
+ "valid": valid_label, "rotations": rotation}
317
+ except Exception as e:
318
+ traceback.print_exc(file=sys.stdout)
319
+ print("Broken image", os.path.join(self.data_root, self.crops_dir, video, img_file))
320
+ index = random.randint(0, len(self.data) - 1)
321
+
322
+ def random_blackout_landmark(self, image, mask, landmarks):
323
+ x, y = random.choice(landmarks)
324
+ first = random.random() > 0.5
325
+ # crop half face either vertically or horizontally
326
+ if random.random() > 0.5:
327
+ # width
328
+ if first:
329
+ image[:, :x] = 0
330
+ mask[:, :x] = 0
331
+ else:
332
+ image[:, x:] = 0
333
+ mask[:, x:] = 0
334
+ else:
335
+ # height
336
+ if first:
337
+ image[:y, :] = 0
338
+ mask[:y, :] = 0
339
+ else:
340
+ image[y:, :] = 0
341
+ mask[y:, :] = 0
342
+
343
+ def reset(self, epoch, seed):
344
+ self.data = self._prepare_data(epoch, seed)
345
+
346
+ def __len__(self) -> int:
347
+ return len(self.data)
348
+
349
+ def _prepare_data(self, epoch, seed):
350
+ df = self.df
351
+ if self.mode == "train":
352
+ rows = df[df["fold"] != self.fold]
353
+ else:
354
+ rows = df[df["fold"] == self.fold]
355
+ seed = (epoch + 1) * seed
356
+ if self.oversample_real:
357
+ rows = self._oversample(rows, seed)
358
+ if self.mode == "val" and self.reduce_val:
359
+ # every 2nd frame, to speed up validation
360
+ rows = rows[rows["frame"] % 20 == 0]
361
+ # another option is to use public validation set
362
+ #rows = rows[rows["video"].isin(PUBLIC_SET)]
363
+
364
+ print(
365
+ "real {} fakes {} mode {}".format(len(rows[rows["label"] == 0]), len(rows[rows["label"] == 1]), self.mode))
366
+ data = rows.values
367
+
368
+ np.random.seed(seed)
369
+ np.random.shuffle(data)
370
+ return data
371
+
372
+ def _oversample(self, rows: pd.DataFrame, seed):
373
+ real = rows[rows["label"] == 0]
374
+ fakes = rows[rows["label"] == 1]
375
+ num_real = real["video"].count()
376
+ if self.mode == "train":
377
+ fakes = fakes.sample(n=num_real, replace=False, random_state=seed)
378
+ return pd.concat([real, fakes])
training/datasets/validation_set.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ PUBLIC_SET = {'tjuihawuqm', 'prwsfljdjo', 'scrbqgpvzz', 'ziipxxchai', 'uubgqnvfdl', 'wclvkepakb', 'xjvxtuakyd',
4
+ 'qlvsqdroqo', 'bcbqxhziqz', 'yzuestxcbq', 'hxwtsaydal', 'kqlvggiqee', 'vtunvalyji', 'mohiqoogpb',
5
+ 'siebfpwuhu', 'cekwtyxdoo', 'hszwwswewp', 'orekjthsef', 'huvlwkxoxm', 'fmhiujydwo', 'lhvjzhjxdp',
6
+ 'ibxfxggtqh', 'bofrwgeyjo', 'rmufsuogzn', 'zbgssotnjm', 'dpevefkefv', 'sufvvwmbha', 'ncoeewrdlo',
7
+ 'qhsehzgxqj', 'yxadevzohx', 'aomqqjipcp', 'pcyswtgick', 'wfzjxzhdkj', 'rcjfxxhcal', 'lnjkpdviqb',
8
+ 'xmkwsnuzyq', 'ouaowjmigq', 'bkuzquigyt', 'vwxednhlwz', 'mszblrdprw', 'blnmxntbey', 'gccnvdoknm',
9
+ 'mkzaekkvej', 'hclsparpth', 'eryjktdexi', 'hfsvqabzfq', 'acazlolrpz', 'yoyhmxtrys', 'rerpivllud',
10
+ 'elackxuccp', 'zgbhzkditd', 'vjljdfopjg', 'famlupsgqm', 'nymodlmxni', 'qcbkztamqc', 'qclpbcbgeq',
11
+ 'lpkgabskbw', 'mnowxangqx', 'czfqlbcfpa', 'qyyhuvqmyf', 'toinozytsp', 'ztyvglkcsf', 'nplviymzlg',
12
+ 'opvqdabdap', 'uxuvkrjhws', 'mxahsihabr', 'cqxxumarvp', 'ptbfnkajyi', 'njzshtfmcw', 'dcqodpzomd',
13
+ 'ajiyrjfyzp', 'ywauoonmlr', 'gochxzemmq', 'lpgxwdgnio', 'hnfwagcxdf', 'gfcycflhbo', 'gunamloolc',
14
+ 'yhjlnisfel', 'srfefmyjvt', 'evysmtpnrf', 'aktnlyqpah', 'gpsxfxrjrr', 'zfobicuigx', 'mnzabbkpmt',
15
+ 'rfjuhbnlro', 'zuwwbbusgl', 'csnkohqxdv', 'bzvzpwrabw', 'yietrwuncf', 'wynotylpnm', 'ekboxwrwuv',
16
+ 'rcecrgeotc', 'rklawjhbpv', 'ilqwcbprqa', 'jsysgmycsx', 'sqixhnilfm', 'wnlubukrki', 'nikynwcvuh',
17
+ 'sjkfxrlxxs', 'btdxnajogv', 'wjhpisoeaj', 'dyjklprkoc', 'qlqhjcshpk', 'jyfvaequfg', 'dozjwhnedd',
18
+ 'owaogcehvc', 'oyqgwjdwaj', 'vvfszaosiv', 'kmcdjxmnoa', 'jiswxuqzyz', 'ddtbarpcgo', 'wqysrieiqu',
19
+ 'xcruhaccxc', 'honxqdilvv', 'nxgzmgzkfv', 'cxsvvnxpyz', 'demuhxssgl', 'hzoiotcykp', 'fwykevubzy',
20
+ 'tejfudfgpq', 'kvmpmhdxly', 'oojxonbgow', 'vurjckblge', 'oysopgovhu', 'khpipxnsvx', 'pqthmvwonf',
21
+ 'fddmkqjwsh', 'pcoxcmtroa', 'cnxccbjlct', 'ggzjfrirjh', 'jquevmhdvc', 'ecumyiowzs', 'esmqxszybs',
22
+ 'mllzkpgatp', 'ryxaqpfubf', 'hbufmvbium', 'vdtsbqidjb', 'sjwywglgym', 'qxyrtwozyw', 'upmgtackuf',
23
+ 'ucthmsajay', 'zgjosltkie', 'snlyjbnpgw', 'nswtvttxre', 'iznnzjvaxc', 'jhczqfefgw', 'htzbnroagi',
24
+ 'pdswwyyntw', 'uvrzaczrbx', 'vbcgoyxsvn', 'hzssdinxec', 'novarhxpbj', 'vizerpsvbz', 'jawgcggquk',
25
+ 'iorbtaarte', 'yarpxfqejd', 'vhbbwdflyh', 'rrrfjhugvb', 'fneqiqpqvs', 'jytrvwlewz', 'bfjsthfhbd',
26
+ 'rxdoimqble', 'ekelfsnqof', 'uqvxjfpwdo', 'cjkctqqakb', 'tynfsthodx', 'yllztsrwjw', 'bktkwbcawi',
27
+ 'wcqvzujamg', 'bcvheslzrq', 'aqrsylrzgi', 'sktpeppbkc', 'mkmgcxaztt', 'etdliwticv', 'hqzwudvhih',
28
+ 'swsaoktwgi', 'temjefwaas', 'papagllumt', 'xrtvqhdibb', 'oelqpetgwj', 'ggdpclfcgk', 'imdmhwkkni',
29
+ 'lebzjtusnr', 'xhtppuyqdr', 'nxzgekegsp', 'waucvvmtkq', 'rnfcjxynfa', 'adohdulfwb', 'tjywwgftmv',
30
+ 'fjrueenjyp', 'oaguiggjyv', 'ytopzxrswu', 'yxvmusxvcz', 'rukyxomwcx', 'qdqdsaiitt', 'mxlipjhmqk',
31
+ 'voawxrmqyl', 'kezwvsxxzj', 'oocincvedt', 'qooxnxqqjb', 'mwwploizlj', 'yaxgpxhavq', 'uhakqelqri',
32
+ 'bvpeerislp', 'bkcyglmfci', 'jyoxdvxpza', 'gkutjglghz', 'knxltsvzyu', 'ybbrkacebd', 'apvzjkvnwn',
33
+ 'ahjnxtiamx', 'hsbljbsgxr', 'fnxgqcvlsd', 'xphdfgmfmz', 'scbdenmaed', 'ywxpquomgt', 'yljecirelf',
34
+ 'wcvsqnplsk', 'vmxfwxgdei', 'icbsahlivv', 'yhylappzid', 'irqzdokcws', 'petmyhjclt', 'rmlzgerevr',
35
+ 'qarqtkvgby', 'nkhzxomani', 'viteugozpv', 'qhkzlnzruj', 'eisofhptvk', 'gqnaxievjx', 'heiyoojifp',
36
+ 'zcxcmneefk', 'wvgviwnwob', 'gcdtglsoqj', 'yqhouqakbx', 'fopjiyxiqd', 'hierggamuo', 'ypbtpunjvm',
37
+ 'sjinmmbipg', 'kmqkiihrmj', 'wmoqzxddkb', 'lnhkjhyhvw', 'wixbuuzygv', 'fsdrwikhge', 'sfsayjgzrh',
38
+ 'pqdeutauqc', 'frqfsucgao', 'pdufsewrec', 'bfdopzvxbi', 'shnsajrsow', 'rvvpazsffd', 'pxcfrszlgi',
39
+ 'itfsvvmslp', 'ayipraspbn', 'prhmixykhr', 'doniqevxeg', 'dvtpwatuja', 'jiavqbrkyk', 'ipkpxvwroe',
40
+ 'syxobtuucp', 'syuxttuyhm', 'nwvsbmyndn', 'eqslzbqfea', 'ytddugrwph', 'vokrpfjpeb', 'bdshuoldwx',
41
+ 'fmvvmcbdrw', 'bnuwxhfahw', 'gbnzicjyhz', 'txnmkabufs', 'gfdjzwnpyp', 'hweshqpfwe', 'dxgnpnowgk',
42
+ 'xugmhbetrw', 'rktrpsdlci', 'nthpnwylxo', 'ihglzxzroo', 'ocgdbrgmtq', 'ruhtnngrqv', 'xljemofssi',
43
+ 'zxacihctqp', 'ghnpsltzyn', 'lbigytrrtr', 'ndikguxzek', 'mdfndlljvt', 'lyoslorecs', 'oefukgnvel',
44
+ 'zmxeiipnqb', 'cosghhimnd', 'alrtntfxtd', 'eywdmustbb', 'ooafcxxfrs', 'fqgypsunzr', 'hevcclcklc',
45
+ 'uhrqlmlclw', 'ipvwtgdlre', 'wcssbghcpc', 'didzujjhtg', 'fjxovgmwnm', 'dmmvuaikkv', 'hitfycdavv',
46
+ 'zyufpqvpyu', 'coujjnypba', 'temeqbmzxu', 'apedduehoy', 'iksxzpqxzi', 'kwfdyqofzw', 'aassnaulhq',
47
+ 'eyguqfmgzh', 'yiykshcbaz', 'sngjsueuhs', 'okgelildpc', 'ztyuiqrhdk', 'tvhjcfnqtg', 'gfgcwxkbjd',
48
+ 'lbfqksftuo', 'kowiwvrjht', 'dkuqbduxev', 'mwnibuujwz', 'sodvtfqbpf', 'hsbwhlolsn', 'qsjiypnjwi',
49
+ 'blszgmxkvu', 'ystdtnetgj', 'rfwxcinshk', 'vnlzxqwthl', 'ljouzjaqqe', 'gahgyuwzbu', 'xxzefxwyku',
50
+ 'xitgdpzbxv', 'sylnrepacf', 'igpvrfjdzc', 'nxnmkytwze', 'psesikjaxx', 'dvwpvqdflx', 'bjyaxvggle',
51
+ 'dpmgoiwhuf', 'wadvzjhwtw', 'kcjvhgvhpt', 'eppyqpgewp', 'tyjpjpglgx', 'cekarydqba', 'dvkdfhrpph',
52
+ 'cnpanmywno', 'ljauauuyka', 'hicjuubiau', 'cqhwesrciw', 'dnmowthjcj', 'lujvyveojc', 'wndursivcx',
53
+ 'espkiocpxq', 'jsbpkpxwew', 'dsnxgrfdmd', 'hyjqolupxn', 'xdezcezszc', 'axfhbpkdlc', 'qqnlrngaft',
54
+ 'coqwgzpbhx', 'ncmpqwmnzb', 'sznkemeqro', 'omphqltjdd', 'uoccaiathd', 'jzmzdispyo', 'pxjkzvqomp',
55
+ 'udxqbhgvvx', 'dzkyxbbqkr', 'dtozwcapoa', 'qswlzfgcgj', 'tgawasvbbr', 'lmdyicksrv', 'fzvpbrzssi',
56
+ 'dxfdovivlw', 'zzmgnglanj', 'vssmlqoiti', 'vajkicalux', 'ekvwecwltj', 'ylxwcwhjjd', 'keioymnobc',
57
+ 'usqqvxcjmg', 'phjvutxpoi', 'nycmyuzpml', 'bwdmzwhdnw', 'fxuxxtryjn', 'orixbcfvdz', 'hefisnapds',
58
+ 'fpevfidstw', 'halvwiltfs', 'dzojiwfvba', 'ojsxxkalat', 'esjdyghhog', 'ptbnewtvon', 'hcanfkwivl',
59
+ 'yronlutbgm', 'llplvmcvbl', 'yxirnfyijn', 'nwvloufjty', 'rtpbawlmxr', 'aayfryxljh', 'zfrrixsimm',
60
+ 'txmnoyiyte'}
training/losses.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ from pytorch_toolbelt.losses import BinaryFocalLoss
4
+ from torch import nn
5
+ from torch.nn.modules.loss import BCEWithLogitsLoss
6
+
7
+
8
+ class WeightedLosses(nn.Module):
9
+ def __init__(self, losses, weights):
10
+ super().__init__()
11
+ self.losses = losses
12
+ self.weights = weights
13
+
14
+ def forward(self, *input: Any, **kwargs: Any):
15
+ cum_loss = 0
16
+ for loss, w in zip(self.losses, self.weights):
17
+ cum_loss += w * loss.forward(*input, **kwargs)
18
+ return cum_loss
19
+
20
+
21
+ class BinaryCrossentropy(BCEWithLogitsLoss):
22
+ pass
23
+
24
+
25
+ class FocalLoss(BinaryFocalLoss):
26
+ def __init__(self, alpha=None, gamma=3, ignore_index=None, reduction="mean", normalized=False,
27
+ reduced_threshold=None):
28
+ super().__init__(alpha, gamma, ignore_index, reduction, normalized, reduced_threshold)
training/pipelines/__init__.py ADDED
File without changes
training/pipelines/train_classifier.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ from collections import defaultdict
5
+
6
+ from sklearn.metrics import log_loss
7
+ from torch import topk
8
+
9
+ from training import losses
10
+ from training.datasets.classifier_dataset import DeepFakeClassifierDataset
11
+ from training.losses import WeightedLosses
12
+ from training.tools.config import load_config
13
+ from training.tools.utils import create_optimizer, AverageMeter
14
+ from training.transforms.albu import IsotropicResize
15
+ from training.zoo import classifiers
16
+
17
+ os.environ["MKL_NUM_THREADS"] = "1"
18
+ os.environ["NUMEXPR_NUM_THREADS"] = "1"
19
+ os.environ["OMP_NUM_THREADS"] = "1"
20
+
21
+ import cv2
22
+
23
+ cv2.ocl.setUseOpenCL(False)
24
+ cv2.setNumThreads(0)
25
+ import numpy as np
26
+ from albumentations import Compose, RandomBrightnessContrast, \
27
+ HorizontalFlip, FancyPCA, HueSaturationValue, OneOf, ToGray, \
28
+ ShiftScaleRotate, ImageCompression, PadIfNeeded, GaussNoise, GaussianBlur
29
+
30
+ from apex.parallel import DistributedDataParallel, convert_syncbn_model
31
+ from tensorboardX import SummaryWriter
32
+
33
+ from apex import amp
34
+
35
+ import torch
36
+ from torch.backends import cudnn
37
+ from torch.nn import DataParallel
38
+ from torch.utils.data import DataLoader
39
+ from tqdm import tqdm
40
+ import torch.distributed as dist
41
+
42
+ torch.backends.cudnn.benchmark = True
43
+
44
+
45
+ def create_train_transforms(size=300):
46
+ return Compose([
47
+ ImageCompression(quality_lower=60, quality_upper=100, p=0.5),
48
+ GaussNoise(p=0.1),
49
+ GaussianBlur(blur_limit=3, p=0.05),
50
+ HorizontalFlip(),
51
+ OneOf([
52
+ IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
53
+ IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_LINEAR),
54
+ IsotropicResize(max_side=size, interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR),
55
+ ], p=1),
56
+ PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
57
+ OneOf([RandomBrightnessContrast(), FancyPCA(), HueSaturationValue()], p=0.7),
58
+ ToGray(p=0.2),
59
+ ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=10, border_mode=cv2.BORDER_CONSTANT, p=0.5),
60
+ ]
61
+ )
62
+
63
+
64
+ def create_val_transforms(size=300):
65
+ return Compose([
66
+ IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
67
+ PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
68
+ ])
69
+
70
+
71
+ def main():
72
+ parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
73
+ arg = parser.add_argument
74
+ arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
75
+ arg('--workers', type=int, default=6, help='number of cpu threads to use')
76
+ arg('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
77
+ arg('--output-dir', type=str, default='weights/')
78
+ arg('--resume', type=str, default='')
79
+ arg('--fold', type=int, default=0)
80
+ arg('--prefix', type=str, default='classifier_')
81
+ arg('--data-dir', type=str, default="/mnt/sota/datasets/deepfake")
82
+ arg('--folds-csv', type=str, default='folds.csv')
83
+ arg('--crops-dir', type=str, default='crops')
84
+ arg('--label-smoothing', type=float, default=0.01)
85
+ arg('--logdir', type=str, default='logs')
86
+ arg('--zero-score', action='store_true', default=False)
87
+ arg('--from-zero', action='store_true', default=False)
88
+ arg('--distributed', action='store_true', default=False)
89
+ arg('--freeze-epochs', type=int, default=0)
90
+ arg("--local_rank", default=0, type=int)
91
+ arg("--seed", default=777, type=int)
92
+ arg("--padding-part", default=3, type=int)
93
+ arg("--opt-level", default='O1', type=str)
94
+ arg("--test_every", type=int, default=1)
95
+ arg("--no-oversample", action="store_true")
96
+ arg("--no-hardcore", action="store_true")
97
+ arg("--only-changed-frames", action="store_true")
98
+
99
+ args = parser.parse_args()
100
+ os.makedirs(args.output_dir, exist_ok=True)
101
+ if args.distributed:
102
+ torch.cuda.set_device(args.local_rank)
103
+ torch.distributed.init_process_group(backend='nccl', init_method='env://')
104
+ else:
105
+ os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
106
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
107
+
108
+ cudnn.benchmark = True
109
+
110
+ conf = load_config(args.config)
111
+ model = classifiers.__dict__[conf['network']](encoder=conf['encoder'])
112
+
113
+ model = model.cuda()
114
+ if args.distributed:
115
+ model = convert_syncbn_model(model)
116
+ ohem = conf.get("ohem_samples", None)
117
+ reduction = "mean"
118
+ if ohem:
119
+ reduction = "none"
120
+ loss_fn = []
121
+ weights = []
122
+ for loss_name, weight in conf["losses"].items():
123
+ loss_fn.append(losses.__dict__[loss_name](reduction=reduction).cuda())
124
+ weights.append(weight)
125
+ loss = WeightedLosses(loss_fn, weights)
126
+ loss_functions = {"classifier_loss": loss}
127
+ optimizer, scheduler = create_optimizer(conf['optimizer'], model)
128
+ bce_best = 100
129
+ start_epoch = 0
130
+ batch_size = conf['optimizer']['batch_size']
131
+
132
+ data_train = DeepFakeClassifierDataset(mode="train",
133
+ oversample_real=not args.no_oversample,
134
+ fold=args.fold,
135
+ padding_part=args.padding_part,
136
+ hardcore=not args.no_hardcore,
137
+ crops_dir=args.crops_dir,
138
+ data_path=args.data_dir,
139
+ label_smoothing=args.label_smoothing,
140
+ folds_csv=args.folds_csv,
141
+ transforms=create_train_transforms(conf["size"]),
142
+ normalize=conf.get("normalize", None))
143
+ data_val = DeepFakeClassifierDataset(mode="val",
144
+ fold=args.fold,
145
+ padding_part=args.padding_part,
146
+ crops_dir=args.crops_dir,
147
+ data_path=args.data_dir,
148
+ folds_csv=args.folds_csv,
149
+ transforms=create_val_transforms(conf["size"]),
150
+ normalize=conf.get("normalize", None))
151
+ val_data_loader = DataLoader(data_val, batch_size=batch_size * 2, num_workers=args.workers, shuffle=False,
152
+ pin_memory=False)
153
+ os.makedirs(args.logdir, exist_ok=True)
154
+ summary_writer = SummaryWriter(args.logdir + '/' + conf.get("prefix", args.prefix) + conf['encoder'] + "_" + str(args.fold))
155
+ if args.resume:
156
+ if os.path.isfile(args.resume):
157
+ print("=> loading checkpoint '{}'".format(args.resume))
158
+ checkpoint = torch.load(args.resume, map_location='cpu')
159
+ state_dict = checkpoint['state_dict']
160
+ state_dict = {k[7:]: w for k, w in state_dict.items()}
161
+ model.load_state_dict(state_dict, strict=False)
162
+ if not args.from_zero:
163
+ start_epoch = checkpoint['epoch']
164
+ if not args.zero_score:
165
+ bce_best = checkpoint.get('bce_best', 0)
166
+ print("=> loaded checkpoint '{}' (epoch {}, bce_best {})"
167
+ .format(args.resume, checkpoint['epoch'], checkpoint['bce_best']))
168
+ else:
169
+ print("=> no checkpoint found at '{}'".format(args.resume))
170
+ if args.from_zero:
171
+ start_epoch = 0
172
+ current_epoch = start_epoch
173
+
174
+ if conf['fp16']:
175
+ model, optimizer = amp.initialize(model, optimizer,
176
+ opt_level=args.opt_level,
177
+ loss_scale='dynamic')
178
+
179
+ snapshot_name = "{}{}_{}_{}".format(conf.get("prefix", args.prefix), conf['network'], conf['encoder'], args.fold)
180
+
181
+ if args.distributed:
182
+ model = DistributedDataParallel(model, delay_allreduce=True)
183
+ else:
184
+ model = DataParallel(model).cuda()
185
+ data_val.reset(1, args.seed)
186
+ max_epochs = conf['optimizer']['schedule']['epochs']
187
+ for epoch in range(start_epoch, max_epochs):
188
+ data_train.reset(epoch, args.seed)
189
+ train_sampler = None
190
+ if args.distributed:
191
+ train_sampler = torch.utils.data.distributed.DistributedSampler(data_train)
192
+ train_sampler.set_epoch(epoch)
193
+ if epoch < args.freeze_epochs:
194
+ print("Freezing encoder!!!")
195
+ model.module.encoder.eval()
196
+ for p in model.module.encoder.parameters():
197
+ p.requires_grad = False
198
+ else:
199
+ model.module.encoder.train()
200
+ for p in model.module.encoder.parameters():
201
+ p.requires_grad = True
202
+
203
+ train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=args.workers,
204
+ shuffle=train_sampler is None, sampler=train_sampler, pin_memory=False,
205
+ drop_last=True)
206
+
207
+ train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
208
+ args.local_rank, args.only_changed_frames)
209
+ model = model.eval()
210
+
211
+ if args.local_rank == 0:
212
+ torch.save({
213
+ 'epoch': current_epoch + 1,
214
+ 'state_dict': model.state_dict(),
215
+ 'bce_best': bce_best,
216
+ }, args.output_dir + '/' + snapshot_name + "_last")
217
+ torch.save({
218
+ 'epoch': current_epoch + 1,
219
+ 'state_dict': model.state_dict(),
220
+ 'bce_best': bce_best,
221
+ }, os.path.join(args.output_dir, snapshot_name + "_{}".format(current_epoch)))
222
+ if (epoch + 1) % args.test_every == 0:
223
+ bce_best = evaluate_val(args, val_data_loader, bce_best, model,
224
+ snapshot_name=snapshot_name,
225
+ current_epoch=current_epoch,
226
+ summary_writer=summary_writer)
227
+ current_epoch += 1
228
+
229
+
230
+ def evaluate_val(args, data_val, bce_best, model, snapshot_name, current_epoch, summary_writer):
231
+ print("Test phase")
232
+ model = model.eval()
233
+
234
+ bce, probs, targets = validate(model, data_loader=data_val)
235
+ if args.local_rank == 0:
236
+ summary_writer.add_scalar('val/bce', float(bce), global_step=current_epoch)
237
+ if bce < bce_best:
238
+ print("Epoch {} improved from {} to {}".format(current_epoch, bce_best, bce))
239
+ if args.output_dir is not None:
240
+ torch.save({
241
+ 'epoch': current_epoch + 1,
242
+ 'state_dict': model.state_dict(),
243
+ 'bce_best': bce,
244
+ }, os.path.join(args.output_dir, snapshot_name + "_best_dice"))
245
+ bce_best = bce
246
+ with open("predictions_{}.json".format(args.fold), "w") as f:
247
+ json.dump({"probs": probs, "targets": targets}, f)
248
+ torch.save({
249
+ 'epoch': current_epoch + 1,
250
+ 'state_dict': model.state_dict(),
251
+ 'bce_best': bce_best,
252
+ }, os.path.join(args.output_dir, snapshot_name + "_last"))
253
+ print("Epoch: {} bce: {}, bce_best: {}".format(current_epoch, bce, bce_best))
254
+ return bce_best
255
+
256
+
257
+ def validate(net, data_loader, prefix=""):
258
+ probs = defaultdict(list)
259
+ targets = defaultdict(list)
260
+
261
+ with torch.no_grad():
262
+ for sample in tqdm(data_loader):
263
+ imgs = sample["image"].cuda()
264
+ img_names = sample["img_name"]
265
+ labels = sample["labels"].cuda().float()
266
+ out = net(imgs)
267
+ labels = labels.cpu().numpy()
268
+ preds = torch.sigmoid(out).cpu().numpy()
269
+ for i in range(out.shape[0]):
270
+ video, img_id = img_names[i].split("/")
271
+ probs[video].append(preds[i].tolist())
272
+ targets[video].append(labels[i].tolist())
273
+ data_x = []
274
+ data_y = []
275
+ for vid, score in probs.items():
276
+ score = np.array(score)
277
+ lbl = targets[vid]
278
+
279
+ score = np.mean(score)
280
+ lbl = np.mean(lbl)
281
+ data_x.append(score)
282
+ data_y.append(lbl)
283
+ y = np.array(data_y)
284
+ x = np.array(data_x)
285
+ fake_idx = y > 0.1
286
+ real_idx = y < 0.1
287
+ fake_loss = log_loss(y[fake_idx], x[fake_idx], labels=[0, 1])
288
+ real_loss = log_loss(y[real_idx], x[real_idx], labels=[0, 1])
289
+ print("{}fake_loss".format(prefix), fake_loss)
290
+ print("{}real_loss".format(prefix), real_loss)
291
+
292
+ return (fake_loss + real_loss) / 2, probs, targets
293
+
294
+
295
+ def train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
296
+ local_rank, only_valid):
297
+ losses = AverageMeter()
298
+ fake_losses = AverageMeter()
299
+ real_losses = AverageMeter()
300
+ max_iters = conf["batches_per_epoch"]
301
+ print("training epoch {}".format(current_epoch))
302
+ model.train()
303
+ pbar = tqdm(enumerate(train_data_loader), total=max_iters, desc="Epoch {}".format(current_epoch), ncols=0)
304
+ if conf["optimizer"]["schedule"]["mode"] == "epoch":
305
+ scheduler.step(current_epoch)
306
+ for i, sample in pbar:
307
+ imgs = sample["image"].cuda()
308
+ labels = sample["labels"].cuda().float()
309
+ out_labels = model(imgs)
310
+ if only_valid:
311
+ valid_idx = sample["valid"].cuda().float() > 0
312
+ out_labels = out_labels[valid_idx]
313
+ labels = labels[valid_idx]
314
+ if labels.size(0) == 0:
315
+ continue
316
+
317
+ fake_loss = 0
318
+ real_loss = 0
319
+ fake_idx = labels > 0.5
320
+ real_idx = labels <= 0.5
321
+
322
+ ohem = conf.get("ohem_samples", None)
323
+ if torch.sum(fake_idx * 1) > 0:
324
+ fake_loss = loss_functions["classifier_loss"](out_labels[fake_idx], labels[fake_idx])
325
+ if torch.sum(real_idx * 1) > 0:
326
+ real_loss = loss_functions["classifier_loss"](out_labels[real_idx], labels[real_idx])
327
+ if ohem:
328
+ fake_loss = topk(fake_loss, k=min(ohem, fake_loss.size(0)), sorted=False)[0].mean()
329
+ real_loss = topk(real_loss, k=min(ohem, real_loss.size(0)), sorted=False)[0].mean()
330
+
331
+ loss = (fake_loss + real_loss) / 2
332
+ losses.update(loss.item(), imgs.size(0))
333
+ fake_losses.update(0 if fake_loss == 0 else fake_loss.item(), imgs.size(0))
334
+ real_losses.update(0 if real_loss == 0 else real_loss.item(), imgs.size(0))
335
+
336
+ optimizer.zero_grad()
337
+ pbar.set_postfix({"lr": float(scheduler.get_lr()[-1]), "epoch": current_epoch, "loss": losses.avg,
338
+ "fake_loss": fake_losses.avg, "real_loss": real_losses.avg})
339
+
340
+ if conf['fp16']:
341
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
342
+ scaled_loss.backward()
343
+ else:
344
+ loss.backward()
345
+ torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
346
+ optimizer.step()
347
+ torch.cuda.synchronize()
348
+ if conf["optimizer"]["schedule"]["mode"] in ("step", "poly"):
349
+ scheduler.step(i + current_epoch * max_iters)
350
+ if i == max_iters - 1:
351
+ break
352
+ pbar.close()
353
+ if local_rank == 0:
354
+ for idx, param_group in enumerate(optimizer.param_groups):
355
+ lr = param_group['lr']
356
+ summary_writer.add_scalar('group{}/lr'.format(idx), float(lr), global_step=current_epoch)
357
+ summary_writer.add_scalar('train/loss', float(losses.avg), global_step=current_epoch)
358
+
359
+
360
+ if __name__ == '__main__':
361
+ main()
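
Note: the JSON config passed via --config needs at least the keys that main() and train_epoch() read (network, encoder, losses, size, fp16, batches_per_epoch and the optimizer/schedule block). The sketch below is purely illustrative: the loss name must match a class defined in training/losses.py (not shown here), and none of these values are the settings used for the released checkpoint.

# Illustrative config for train_classifier.py; keys mirror what the script reads.
# "BinaryCrossentropy" is a placeholder loss name - replace it with a class from training/losses.py.
import json

example_conf = {
    "network": "DeepFakeClassifier",        # must exist in training/zoo/classifiers.py
    "encoder": "tf_efficientnet_b7",        # must be a key of encoder_params
    "fp16": False,
    "size": 380,                            # crop size passed to the transforms
    "batches_per_epoch": 2500,
    "losses": {"BinaryCrossentropy": 1.0},  # loss-name -> weight, combined by WeightedLosses
    "optimizer": {
        "batch_size": 12,
        "type": "SGD",
        "momentum": 0.9,
        "weight_decay": 1e-4,
        "nesterov": True,
        "learning_rate": 0.01,
        "classifier_lr": -1,
        "schedule": {"type": "poly", "mode": "step", "epochs": 30,
                     "params": {"max_iter": 75000}}
    }
}

with open("example_config.json", "w") as f:
    json.dump(example_conf, f, indent=2)
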
training/tools/__init__.py ADDED
File without changes
training/tools/config.py ADDED
@@ -0,0 +1,43 @@
1
+ import json
2
+
3
+ DEFAULTS = {
4
+ "network": "dpn",
5
+ "encoder": "dpn92",
6
+ "model_params": {},
7
+ "optimizer": {
8
+ "batch_size": 32,
9
+ "type": "SGD", # supported: SGD, FusedSGD, Adam, FusedAdam, AdamW, RmsProp (see create_optimizer)
10
+ "momentum": 0.9,
11
+ "weight_decay": 0,
12
+ "clip": 1.,
13
+ "learning_rate": 0.1,
14
+ "classifier_lr": -1,
15
+ "nesterov": True,
16
+ "schedule": {
17
+ "type": "constant", # supported: constant, step, clr, multistep, exponential, linear, poly
18
+ "mode": "epoch", # supported: epoch, step
19
+ "epochs": 10,
20
+ "params": {}
21
+ }
22
+ },
23
+ "normalize": {
24
+ "mean": [0.485, 0.456, 0.406],
25
+ "std": [0.229, 0.224, 0.225]
26
+ }
27
+ }
28
+
29
+
30
+ def _merge(src, dst):
31
+ for k, v in src.items():
32
+ if k in dst:
33
+ if isinstance(v, dict):
34
+ _merge(src[k], dst[k])
35
+ else:
36
+ dst[k] = v
37
+
38
+
39
+ def load_config(config_file, defaults=DEFAULTS):
40
+ with open(config_file, "r") as fd:
41
+ config = json.load(fd)
42
+ _merge(defaults, config)
43
+ return config
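
A quick sketch of the merge semantics, assuming the repository root is on PYTHONPATH: keys missing from the user config are filled in from DEFAULTS, nested dicts are merged recursively, and values the user does set are left untouched.

import json
import tempfile

from training.tools.config import load_config

user_conf = {"encoder": "tf_efficientnet_b7", "optimizer": {"batch_size": 12}}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(user_conf, f)
    path = f.name

conf = load_config(path)
print(conf["encoder"])                     # "tf_efficientnet_b7"  (kept from the file)
print(conf["optimizer"]["batch_size"])     # 12                    (kept from the file)
print(conf["optimizer"]["learning_rate"])  # 0.1                   (filled from DEFAULTS)
print(conf["normalize"]["mean"])           # [0.485, 0.456, 0.406] (filled from DEFAULTS)
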
training/tools/schedulers.py ADDED
@@ -0,0 +1,46 @@
1
+ from bisect import bisect_right
2
+
3
+ from torch.optim.lr_scheduler import _LRScheduler
4
+
5
+
6
+ class LRStepScheduler(_LRScheduler):
7
+ def __init__(self, optimizer, steps, last_epoch=-1):
8
+ self.lr_steps = steps
9
+ super().__init__(optimizer, last_epoch)
10
+
11
+ def get_lr(self):
12
+ pos = max(bisect_right([x for x, y in self.lr_steps], self.last_epoch) - 1, 0)
13
+ return [self.lr_steps[pos][1] if self.lr_steps[pos][0] <= self.last_epoch else base_lr for base_lr in self.base_lrs]
14
+
15
+
16
+ class PolyLR(_LRScheduler):
17
+ """Sets the learning rate of each parameter group according to poly learning rate policy
18
+ """
19
+ def __init__(self, optimizer, max_iter=90000, power=0.9, last_epoch=-1):
20
+ self.max_iter = max_iter
21
+ self.power = power
22
+ super(PolyLR, self).__init__(optimizer, last_epoch)
23
+
24
+ def get_lr(self):
25
+ self.last_epoch = (self.last_epoch + 1) % self.max_iter
26
+ return [base_lr * ((1 - float(self.last_epoch) / self.max_iter) ** (self.power)) for base_lr in self.base_lrs]
27
+
28
+ class ExponentialLRScheduler(_LRScheduler):
29
+ """Decays the learning rate of each parameter group by gamma every epoch.
30
+ When last_epoch=-1, sets initial lr as lr.
31
+
32
+ Args:
33
+ optimizer (Optimizer): Wrapped optimizer.
34
+ gamma (float): Multiplicative factor of learning rate decay.
35
+ last_epoch (int): The index of last epoch. Default: -1.
36
+ """
37
+
38
+ def __init__(self, optimizer, gamma, last_epoch=-1):
39
+ self.gamma = gamma
40
+ super(ExponentialLRScheduler, self).__init__(optimizer, last_epoch)
41
+
42
+ def get_lr(self):
43
+ if self.last_epoch <= 0:
44
+ return self.base_lrs
45
+ return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs]
46
+
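
Judging from the bisect over the first tuple element, LRStepScheduler expects steps as (start_epoch, lr) pairs sorted by epoch, while PolyLR decays the rate as base_lr * (1 - iter / max_iter) ** power on every call. A minimal sketch with illustrative boundaries:

import torch
from torch.optim import SGD

from training.tools.schedulers import LRStepScheduler

model = torch.nn.Linear(8, 1)
optimizer = SGD(model.parameters(), lr=0.01)
# keep 0.01 until epoch 10, then 0.001, then 1e-4 from epoch 25 onwards
scheduler = LRStepScheduler(optimizer, steps=[(0, 0.01), (10, 0.001), (25, 0.0001)])

for epoch in range(30):
    # ... one epoch of training ...
    scheduler.step(epoch)  # matches how train_classifier.py drives "epoch"-mode schedules
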
training/tools/utils.py ADDED
@@ -0,0 +1,121 @@
1
+ import cv2
2
+ from apex.optimizers import FusedAdam, FusedSGD
3
+ from timm.optim import AdamW
4
+ from torch import optim
5
+ from torch.optim import lr_scheduler
6
+ from torch.optim.rmsprop import RMSprop
7
+ from torch.optim.adamw import AdamW
8
+ from torch.optim.lr_scheduler import MultiStepLR, CyclicLR
9
+
10
+ from training.tools.schedulers import ExponentialLRScheduler, PolyLR, LRStepScheduler
11
+
12
+ cv2.ocl.setUseOpenCL(False)
13
+ cv2.setNumThreads(0)
14
+
15
+
16
+ class AverageMeter(object):
17
+ """Computes and stores the average and current value"""
18
+
19
+ def __init__(self):
20
+ self.reset()
21
+
22
+ def reset(self):
23
+ self.val = 0
24
+ self.avg = 0
25
+ self.sum = 0
26
+ self.count = 0
27
+
28
+ def update(self, val, n=1):
29
+ self.val = val
30
+ self.sum += val * n
31
+ self.count += n
32
+ self.avg = self.sum / self.count
33
+
34
+ def create_optimizer(optimizer_config, model, master_params=None):
35
+ """Creates optimizer and schedule from configuration
36
+
37
+ Parameters
38
+ ----------
39
+ optimizer_config : dict
40
+ Dictionary containing the configuration options for the optimizer.
41
+ model : Model
42
+ The network model.
43
+
44
+ Returns
45
+ -------
46
+ optimizer : Optimizer
47
+ The optimizer.
48
+ scheduler : LRScheduler
49
+ The learning rate scheduler.
50
+ """
51
+ if optimizer_config.get("classifier_lr", -1) != -1:
52
+ # Separate classifier parameters from all others
53
+ net_params = []
54
+ classifier_params = []
55
+ for k, v in model.named_parameters():
56
+ if not v.requires_grad:
57
+ continue
58
+ if k.find("encoder") != -1:
59
+ net_params.append(v)
60
+ else:
61
+ classifier_params.append(v)
62
+ params = [
63
+ {"params": net_params},
64
+ {"params": classifier_params, "lr": optimizer_config["classifier_lr"]},
65
+ ]
66
+ else:
67
+ if master_params:
68
+ params = master_params
69
+ else:
70
+ params = model.parameters()
71
+
72
+ if optimizer_config["type"] == "SGD":
73
+ optimizer = optim.SGD(params,
74
+ lr=optimizer_config["learning_rate"],
75
+ momentum=optimizer_config["momentum"],
76
+ weight_decay=optimizer_config["weight_decay"],
77
+ nesterov=optimizer_config["nesterov"])
78
+ elif optimizer_config["type"] == "FusedSGD":
79
+ optimizer = FusedSGD(params,
80
+ lr=optimizer_config["learning_rate"],
81
+ momentum=optimizer_config["momentum"],
82
+ weight_decay=optimizer_config["weight_decay"],
83
+ nesterov=optimizer_config["nesterov"])
84
+ elif optimizer_config["type"] == "Adam":
85
+ optimizer = optim.Adam(params,
86
+ lr=optimizer_config["learning_rate"],
87
+ weight_decay=optimizer_config["weight_decay"])
88
+ elif optimizer_config["type"] == "FusedAdam":
89
+ optimizer = FusedAdam(params,
90
+ lr=optimizer_config["learning_rate"],
91
+ weight_decay=optimizer_config["weight_decay"])
92
+ elif optimizer_config["type"] == "AdamW":
93
+ optimizer = AdamW(params,
94
+ lr=optimizer_config["learning_rate"],
95
+ weight_decay=optimizer_config["weight_decay"])
96
+ elif optimizer_config["type"] == "RmsProp":
97
+ optimizer = RMSprop(params,
98
+ lr=optimizer_config["learning_rate"],
99
+ weight_decay=optimizer_config["weight_decay"])
100
+ else:
101
+ raise KeyError("unrecognized optimizer {}".format(optimizer_config["type"]))
102
+
103
+ if optimizer_config["schedule"]["type"] == "step":
104
+ scheduler = LRStepScheduler(optimizer, **optimizer_config["schedule"]["params"])
105
+ elif optimizer_config["schedule"]["type"] == "clr":
106
+ scheduler = CyclicLR(optimizer, **optimizer_config["schedule"]["params"])
107
+ elif optimizer_config["schedule"]["type"] == "multistep":
108
+ scheduler = MultiStepLR(optimizer, **optimizer_config["schedule"]["params"])
109
+ elif optimizer_config["schedule"]["type"] == "exponential":
110
+ scheduler = ExponentialLRScheduler(optimizer, **optimizer_config["schedule"]["params"])
111
+ elif optimizer_config["schedule"]["type"] == "poly":
112
+ scheduler = PolyLR(optimizer, **optimizer_config["schedule"]["params"])
113
+ elif optimizer_config["schedule"]["type"] == "constant":
114
+ scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 1.0)
115
+ elif optimizer_config["schedule"]["type"] == "linear":
116
+ def linear_lr(it):
117
+ return it * optimizer_config["schedule"]["params"]["alpha"] + optimizer_config["schedule"]["params"]["beta"]
118
+
119
+ scheduler = lr_scheduler.LambdaLR(optimizer, linear_lr)
120
+
121
+ return optimizer, scheduler
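
A sketch of calling create_optimizer directly, assuming apex and timm are installed (the module imports FusedSGD/FusedAdam and AdamW at import time). The config fragment mirrors the "optimizer" block produced by training/tools/config.py; the values are illustrative.

import torch

from training.tools.utils import create_optimizer

model = torch.nn.Linear(16, 1)
optimizer_config = {
    "type": "SGD",
    "learning_rate": 0.01,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "nesterov": True,
    "classifier_lr": -1,   # -1 keeps a single parameter group
    "schedule": {"type": "step", "params": {"steps": [(0, 0.01), (10, 0.001)]}},
}
optimizer, scheduler = create_optimizer(optimizer_config, model)
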
training/transforms/__init__.py ADDED
File without changes
training/transforms/albu.py ADDED
@@ -0,0 +1,99 @@
1
+ import random
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from albumentations import DualTransform, ImageOnlyTransform
6
+ from albumentations.augmentations.functional import crop
7
+
8
+
9
+ def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC):
10
+ h, w = img.shape[:2]
11
+ if max(w, h) == size:
12
+ return img
13
+ if w > h:
14
+ scale = size / w
15
+ h = h * scale
16
+ w = size
17
+ else:
18
+ scale = size / h
19
+ w = w * scale
20
+ h = size
21
+ interpolation = interpolation_up if scale > 1 else interpolation_down
22
+ resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
23
+ return resized
24
+
25
+
26
+ class IsotropicResize(DualTransform):
27
+ def __init__(self, max_side, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC,
28
+ always_apply=False, p=1):
29
+ super(IsotropicResize, self).__init__(always_apply, p)
30
+ self.max_side = max_side
31
+ self.interpolation_down = interpolation_down
32
+ self.interpolation_up = interpolation_up
33
+
34
+ def apply(self, img, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC, **params):
35
+ return isotropically_resize_image(img, size=self.max_side, interpolation_down=interpolation_down,
36
+ interpolation_up=interpolation_up)
37
+
38
+ def apply_to_mask(self, img, **params):
39
+ return self.apply(img, interpolation_down=cv2.INTER_NEAREST, interpolation_up=cv2.INTER_NEAREST, **params)
40
+
41
+ def get_transform_init_args_names(self):
42
+ return ("max_side", "interpolation_down", "interpolation_up")
43
+
44
+
45
+ class Resize4xAndBack(ImageOnlyTransform):
46
+ def __init__(self, always_apply=False, p=0.5):
47
+ super(Resize4xAndBack, self).__init__(always_apply, p)
48
+
49
+ def apply(self, img, **params):
50
+ h, w = img.shape[:2]
51
+ scale = random.choice([2, 4])
52
+ img = cv2.resize(img, (w // scale, h // scale), interpolation=cv2.INTER_AREA)
53
+ img = cv2.resize(img, (w, h),
54
+ interpolation=random.choice([cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_NEAREST]))
55
+ return img
56
+
57
+
58
+ class RandomSizedCropNonEmptyMaskIfExists(DualTransform):
59
+
60
+ def __init__(self, min_max_height, w2h_ratio=[0.7, 1.3], always_apply=False, p=0.5):
61
+ super(RandomSizedCropNonEmptyMaskIfExists, self).__init__(always_apply, p)
62
+
63
+ self.min_max_height = min_max_height
64
+ self.w2h_ratio = w2h_ratio
65
+
66
+ def apply(self, img, x_min=0, x_max=0, y_min=0, y_max=0, **params):
67
+ cropped = crop(img, x_min, y_min, x_max, y_max)
68
+ return cropped
69
+
70
+ @property
71
+ def targets_as_params(self):
72
+ return ["mask"]
73
+
74
+ def get_params_dependent_on_targets(self, params):
75
+ mask = params["mask"]
76
+ mask_height, mask_width = mask.shape[:2]
77
+ crop_height = int(mask_height * random.uniform(self.min_max_height[0], self.min_max_height[1]))
78
+ w2h_ratio = random.uniform(*self.w2h_ratio)
79
+ crop_width = min(int(crop_height * w2h_ratio), mask_width - 1)
80
+ if mask.sum() == 0:
81
+ x_min = random.randint(0, mask_width - crop_width + 1)
82
+ y_min = random.randint(0, mask_height - crop_height + 1)
83
+ else:
84
+ mask = mask.sum(axis=-1) if mask.ndim == 3 else mask
85
+ non_zero_yx = np.argwhere(mask)
86
+ y, x = random.choice(non_zero_yx)
87
+ x_min = x - random.randint(0, crop_width - 1)
88
+ y_min = y - random.randint(0, crop_height - 1)
89
+ x_min = np.clip(x_min, 0, mask_width - crop_width)
90
+ y_min = np.clip(y_min, 0, mask_height - crop_height)
91
+
92
+ x_max = x_min + crop_width
93
+ y_max = y_min + crop_height
94
+ y_max = min(mask_height, y_max)
95
+ x_max = min(mask_width, x_max)
96
+ return {"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max}
97
+
98
+ def get_transform_init_args_names(self):
99
+ return "min_max_height", "w2h_ratio"
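
The actual create_train_transforms/create_val_transforms used by the training pipeline are defined elsewhere in the repo, so the snippet below is only a sketch of how the custom transforms in this file compose with albumentations (assuming an albumentations version whose own imports above still resolve):

import numpy as np
from albumentations import Compose

from training.transforms.albu import IsotropicResize, Resize4xAndBack

size = 380  # illustrative crop size
transforms = Compose([
    IsotropicResize(max_side=size),  # longest side scaled to `size`, aspect ratio preserved
    Resize4xAndBack(p=0.25),         # occasional down/up-scaling to simulate low-quality video
])

image = np.random.randint(0, 255, (720, 1280, 3), dtype=np.uint8)
out = transforms(image=image)["image"]
print(out.shape)  # longest side is now 380
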
training/zoo/__init__.py ADDED
File without changes
training/zoo/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (161 Bytes).
 
training/zoo/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (165 Bytes).
 
training/zoo/__pycache__/classifiers.cpython-37.pyc ADDED
Binary file (5.69 kB).
 
training/zoo/__pycache__/classifiers.cpython-39.pyc ADDED
Binary file (5.7 kB).
 
training/zoo/classifiers.py ADDED
@@ -0,0 +1,172 @@
1
+ from functools import partial
2
+
3
+ import numpy as np
4
+ import torch
5
+ from timm.models.efficientnet import tf_efficientnet_b4_ns, tf_efficientnet_b3_ns, \
6
+ tf_efficientnet_b5_ns, tf_efficientnet_b2_ns, tf_efficientnet_b6_ns, tf_efficientnet_b7
7
+ from torch import nn
8
+ from torch.nn.modules.dropout import Dropout
9
+ from torch.nn.modules.linear import Linear
10
+ from torch.nn.modules.pooling import AdaptiveAvgPool2d
11
+
12
+ encoder_params = {
13
+ "tf_efficientnet_b3_ns": {
14
+ "features": 1536,
15
+ "init_op": partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2)
16
+ },
17
+ "tf_efficientnet_b2_ns": {
18
+ "features": 1408,
19
+ "init_op": partial(tf_efficientnet_b2_ns, pretrained=False, drop_path_rate=0.2)
20
+ },
21
+ "tf_efficientnet_b4_ns": {
22
+ "features": 1792,
23
+ "init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.5)
24
+ },
25
+ "tf_efficientnet_b5_ns": {
26
+ "features": 2048,
27
+ "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2)
28
+ },
29
+ "tf_efficientnet_b4_ns_03d": {
30
+ "features": 1792,
31
+ "init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.3)
32
+ },
33
+ "tf_efficientnet_b5_ns_03d": {
34
+ "features": 2048,
35
+ "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.3)
36
+ },
37
+ "tf_efficientnet_b5_ns_04d": {
38
+ "features": 2048,
39
+ "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.4)
40
+ },
41
+ "tf_efficientnet_b6_ns": {
42
+ "features": 2304,
43
+ "init_op": partial(tf_efficientnet_b6_ns, pretrained=True, drop_path_rate=0.2)
44
+ },
45
+ "tf_efficientnet_b7": {
46
+ "features": 2560,
47
+ "init_op": partial(tf_efficientnet_b7, pretrained=True, drop_path_rate=0.2)
48
+ },
49
+ "tf_efficientnet_b6_ns_04d": {
50
+ "features": 2304,
51
+ "init_op": partial(tf_efficientnet_b6_ns, pretrained=True, drop_path_rate=0.4)
52
+ },
53
+ }
54
+
55
+
56
+ def setup_srm_weights(input_channels: int = 3) -> torch.Tensor:
57
+ """Creates the SRM kernels for noise analysis."""
58
+ # note: values taken from Zhou et al., "Learning Rich Features for Image Manipulation Detection", CVPR2018
59
+ srm_kernel = torch.from_numpy(np.array([
60
+ [ # srm 1/2 horiz
61
+ [0., 0., 0., 0., 0.], # noqa: E241,E201
62
+ [0., 0., 0., 0., 0.], # noqa: E241,E201
63
+ [0., 1., -2., 1., 0.], # noqa: E241,E201
64
+ [0., 0., 0., 0., 0.], # noqa: E241,E201
65
+ [0., 0., 0., 0., 0.], # noqa: E241,E201
66
+ ], [ # srm 1/4
67
+ [0., 0., 0., 0., 0.], # noqa: E241,E201
68
+ [0., -1., 2., -1., 0.], # noqa: E241,E201
69
+ [0., 2., -4., 2., 0.], # noqa: E241,E201
70
+ [0., -1., 2., -1., 0.], # noqa: E241,E201
71
+ [0., 0., 0., 0., 0.], # noqa: E241,E201
72
+ ], [ # srm 1/12
73
+ [-1., 2., -2., 2., -1.], # noqa: E241,E201
74
+ [2., -6., 8., -6., 2.], # noqa: E241,E201
75
+ [-2., 8., -12., 8., -2.], # noqa: E241,E201
76
+ [2., -6., 8., -6., 2.], # noqa: E241,E201
77
+ [-1., 2., -2., 2., -1.], # noqa: E241,E201
78
+ ]
79
+ ])).float()
80
+ srm_kernel[0] /= 2
81
+ srm_kernel[1] /= 4
82
+ srm_kernel[2] /= 12
83
+ return srm_kernel.view(3, 1, 5, 5).repeat(1, input_channels, 1, 1)
84
+
85
+
86
+ def setup_srm_layer(input_channels: int = 3) -> torch.nn.Module:
87
+ """Creates a SRM convolution layer for noise analysis."""
88
+ weights = setup_srm_weights(input_channels)
89
+ conv = torch.nn.Conv2d(input_channels, out_channels=3, kernel_size=5, stride=1, padding=2, bias=False)
90
+ with torch.no_grad():
91
+ conv.weight = torch.nn.Parameter(weights, requires_grad=False)
92
+ return conv
93
+
94
+
95
+ class DeepFakeClassifierSRM(nn.Module):
96
+ def __init__(self, encoder, dropout_rate=0.5) -> None:
97
+ super().__init__()
98
+ self.encoder = encoder_params[encoder]["init_op"]()
99
+ self.avg_pool = AdaptiveAvgPool2d((1, 1))
100
+ self.srm_conv = setup_srm_layer(3)
101
+ self.dropout = Dropout(dropout_rate)
102
+ self.fc = Linear(encoder_params[encoder]["features"], 1)
103
+
104
+ def forward(self, x):
105
+ noise = self.srm_conv(x)
106
+ x = self.encoder.forward_features(noise)
107
+ x = self.avg_pool(x).flatten(1)
108
+ x = self.dropout(x)
109
+ x = self.fc(x)
110
+ return x
111
+
112
+
113
+ class GlobalWeightedAvgPool2d(nn.Module):
114
+ """
115
+ Global Weighted Average Pooling from paper "Global Weighted Average
116
+ Pooling Bridges Pixel-level Localization and Image-level Classification"
117
+ """
118
+
119
+ def __init__(self, features: int, flatten=False):
120
+ super().__init__()
121
+ self.conv = nn.Conv2d(features, 1, kernel_size=1, bias=True)
122
+ self.flatten = flatten
123
+
124
+ def fscore(self, x):
125
+ m = self.conv(x)
126
+ m = m.sigmoid().exp()
127
+ return m
128
+
129
+ def norm(self, x: torch.Tensor):
130
+ return x / x.sum(dim=[2, 3], keepdim=True)
131
+
132
+ def forward(self, x):
133
+ input_x = x
134
+ x = self.fscore(x)
135
+ x = self.norm(x)
136
+ x = x * input_x
137
+ x = x.sum(dim=[2, 3], keepdim=not self.flatten)
138
+ return x
139
+
140
+
141
+ class DeepFakeClassifier(nn.Module):
142
+ def __init__(self, encoder, dropout_rate=0.0) -> None:
143
+ super().__init__()
144
+ self.encoder = encoder_params[encoder]["init_op"]()
145
+ self.avg_pool = AdaptiveAvgPool2d((1, 1))
146
+ self.dropout = Dropout(dropout_rate)
147
+ self.fc = Linear(encoder_params[encoder]["features"], 1)
148
+
149
+ def forward(self, x):
150
+ x = self.encoder.forward_features(x)
151
+ x = self.avg_pool(x).flatten(1)
152
+ x = self.dropout(x)
153
+ x = self.fc(x)
154
+ return x
155
+
156
+
157
+
158
+
159
+ class DeepFakeClassifierGWAP(nn.Module):
160
+ def __init__(self, encoder, dropout_rate=0.5) -> None:
161
+ super().__init__()
162
+ self.encoder = encoder_params[encoder]["init_op"]()
163
+ self.avg_pool = GlobalWeightedAvgPool2d(encoder_params[encoder]["features"])
164
+ self.dropout = Dropout(dropout_rate)
165
+ self.fc = Linear(encoder_params[encoder]["features"], 1)
166
+
167
+ def forward(self, x):
168
+ x = self.encoder.forward_features(x)
169
+ x = self.avg_pool(x).flatten(1)
170
+ x = self.dropout(x)
171
+ x = self.fc(x)
172
+ return x
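
A sketch of the SRM variant in isolation: the fixed SRM kernels turn the RGB input into a noise-residual map that the EfficientNet backbone then classifies. Instantiating any of these models triggers a timm download of the pretrained encoder, and the 380x380 input below is illustrative, not a requirement.

import torch

from training.zoo.classifiers import DeepFakeClassifierSRM

model = DeepFakeClassifierSRM(encoder="tf_efficientnet_b3_ns").eval()
with torch.no_grad():
    dummy = torch.rand(2, 3, 380, 380)  # batch of face crops in [0, 1]
    logits = model(dummy)               # shape (2, 1), raw logits
    probs = torch.sigmoid(logits)       # per-image fake probability
print(probs.shape)
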
training/zoo/unet.py ADDED
@@ -0,0 +1,151 @@
1
+ from functools import partial
2
+
3
+ import torch
4
+ from timm.models.efficientnet import tf_efficientnet_b3_ns, tf_efficientnet_b5_ns
5
+ from torch import nn
6
+ from torch.nn import Dropout2d, Conv2d
7
+ from torch.nn.modules.dropout import Dropout
8
+ from torch.nn.modules.linear import Linear
9
+ from torch.nn.modules.pooling import AdaptiveAvgPool2d
10
+ from torch.nn.modules.upsampling import UpsamplingBilinear2d
11
+
12
+ encoder_params = {
13
+ "tf_efficientnet_b3_ns": {
14
+ "features": 1536,
15
+ "filters": [40, 32, 48, 136, 1536],
16
+ "decoder_filters": [64, 128, 256, 256],
17
+ "init_op": partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2)
18
+ },
19
+ "tf_efficientnet_b5_ns": {
20
+ "features": 2048,
21
+ "filters": [48, 40, 64, 176, 2048],
22
+ "decoder_filters": [64, 128, 256, 256],
23
+ "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2)
24
+ },
25
+ }
26
+
27
+
28
+ class DecoderBlock(nn.Module):
29
+ def __init__(self, in_channels, out_channels):
30
+ super().__init__()
31
+ self.layer = nn.Sequential(
32
+ nn.Upsample(scale_factor=2),
33
+ nn.Conv2d(in_channels, out_channels, 3, padding=1),
34
+ nn.ReLU(inplace=True)
35
+ )
36
+
37
+ def forward(self, x):
38
+ return self.layer(x)
39
+
40
+
41
+ class ConcatBottleneck(nn.Module):
42
+ def __init__(self, in_channels, out_channels):
43
+ super().__init__()
44
+ self.seq = nn.Sequential(
45
+ nn.Conv2d(in_channels, out_channels, 3, padding=1),
46
+ nn.ReLU(inplace=True)
47
+ )
48
+
49
+ def forward(self, dec, enc):
50
+ x = torch.cat([dec, enc], dim=1)
51
+ return self.seq(x)
52
+
53
+
54
+ class Decoder(nn.Module):
55
+ def __init__(self, decoder_filters, filters, upsample_filters=None,
56
+ decoder_block=DecoderBlock, bottleneck=ConcatBottleneck, dropout=0):
57
+ super().__init__()
58
+ self.decoder_filters = decoder_filters
59
+ self.filters = filters
60
+ self.decoder_block = decoder_block
61
+ self.decoder_stages = nn.ModuleList([self._get_decoder(idx) for idx in range(0, len(decoder_filters))])
62
+ self.bottlenecks = nn.ModuleList([bottleneck(self.filters[-i - 2] + f, f)
63
+ for i, f in enumerate(reversed(decoder_filters))])
64
+ self.dropout = Dropout2d(dropout) if dropout > 0 else None
65
+ self.last_block = None
66
+ if upsample_filters:
67
+ self.last_block = decoder_block(decoder_filters[0], out_channels=upsample_filters)
68
+ else:
69
+ self.last_block = UpsamplingBilinear2d(scale_factor=2)
70
+
71
+ def forward(self, encoder_results: list):
72
+ x = encoder_results[0]
73
+ bottlenecks = self.bottlenecks
74
+ for idx, bottleneck in enumerate(bottlenecks):
75
+ rev_idx = - (idx + 1)
76
+ x = self.decoder_stages[rev_idx](x)
77
+ x = bottleneck(x, encoder_results[-rev_idx])
78
+ if self.last_block:
79
+ x = self.last_block(x)
80
+ if self.dropout:
81
+ x = self.dropout(x)
82
+ return x
83
+
84
+ def _get_decoder(self, layer):
85
+ idx = layer + 1
86
+ if idx == len(self.decoder_filters):
87
+ in_channels = self.filters[idx]
88
+ else:
89
+ in_channels = self.decoder_filters[idx]
90
+ return self.decoder_block(in_channels, self.decoder_filters[max(layer, 0)])
91
+
92
+
93
+ def _initialize_weights(module):
94
+ for m in module.modules():
95
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear):
96
+ m.weight.data = nn.init.kaiming_normal_(m.weight.data)
97
+ if m.bias is not None:
98
+ m.bias.data.zero_()
99
+ elif isinstance(m, nn.BatchNorm2d):
100
+ m.weight.data.fill_(1)
101
+ m.bias.data.zero_()
102
+
103
+
104
+ class EfficientUnetClassifier(nn.Module):
105
+ def __init__(self, encoder, dropout_rate=0.5) -> None:
106
+ super().__init__()
107
+ self.decoder = Decoder(decoder_filters=encoder_params[encoder]["decoder_filters"],
108
+ filters=encoder_params[encoder]["filters"])
109
+ self.avg_pool = AdaptiveAvgPool2d((1, 1))
110
+ self.dropout = Dropout(dropout_rate)
111
+ self.fc = Linear(encoder_params[encoder]["features"], 1)
112
+ self.final = Conv2d(encoder_params[encoder]["decoder_filters"][0], out_channels=1, kernel_size=1, bias=False)
113
+ _initialize_weights(self)
114
+ self.encoder = encoder_params[encoder]["init_op"]()
115
+
116
+ def get_encoder_features(self, x):
117
+ encoder_results = []
118
+ x = self.encoder.conv_stem(x)
119
+ x = self.encoder.bn1(x)
120
+ x = self.encoder.act1(x)
121
+ encoder_results.append(x)
122
+ x = self.encoder.blocks[:2](x)
123
+ encoder_results.append(x)
124
+ x = self.encoder.blocks[2:3](x)
125
+ encoder_results.append(x)
126
+ x = self.encoder.blocks[3:5](x)
127
+ encoder_results.append(x)
128
+ x = self.encoder.blocks[5:](x)
129
+ x = self.encoder.conv_head(x)
130
+ x = self.encoder.bn2(x)
131
+ x = self.encoder.act2(x)
132
+ encoder_results.append(x)
133
+ encoder_results = list(reversed(encoder_results))
134
+ return encoder_results
135
+
136
+ def forward(self, x):
137
+ encoder_results = self.get_encoder_features(x)
138
+ seg = self.final(self.decoder(encoder_results))
139
+ x = encoder_results[0]
140
+ x = self.avg_pool(x).flatten(1)
141
+ x = self.dropout(x)
142
+ x = self.fc(x)
143
+ return x, seg
144
+
145
+
146
+ if __name__ == '__main__':
147
+ model = EfficientUnetClassifier("tf_efficientnet_b5_ns")
148
+ model.eval()
149
+ with torch.no_grad():
150
+ input = torch.rand(4, 3, 224, 224)
151
+ print(model(input))
utils.py ADDED
@@ -0,0 +1,354 @@
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image
7
+ from albumentations.augmentations.functional import image_compression
8
+ from facenet_pytorch.models.mtcnn import MTCNN
9
+ from concurrent.futures import ThreadPoolExecutor
10
+
11
+ from torchvision.transforms import Normalize
12
+
13
+ mean = [0.485, 0.456, 0.406]
14
+ std = [0.229, 0.224, 0.225]
15
+ normalize_transform = Normalize(mean, std)
16
+
17
+
18
+ class VideoReader:
19
+ """Helper class for reading one or more frames from a video file."""
20
+
21
+ def __init__(self, verbose=True, insets=(0, 0)):
22
+ """Creates a new VideoReader.
23
+
24
+ Arguments:
25
+ verbose: whether to print warnings and error messages
26
+ insets: amount to inset the image by, as a percentage of
27
+ (width, height). This lets you "zoom in" to an image
28
+ to remove unimportant content around the borders.
29
+ Useful for face detection, which may not work if the
30
+ faces are too small.
31
+ """
32
+ self.verbose = verbose
33
+ self.insets = insets
34
+
35
+ def read_frames(self, path, num_frames, jitter=0, seed=None):
36
+ """Reads frames that are always evenly spaced throughout the video.
37
+
38
+ Arguments:
39
+ path: the video file
40
+ num_frames: how many frames to read; must be a positive number
41
+ (reading many frames takes up a lot of memory!)
42
+ jitter: if not 0, adds small random offsets to the frame indices;
43
+ this is useful so we don't always land on even or odd frames
44
+ seed: random seed for jittering; if you set this to a fixed value,
45
+ you probably want to set it only on the first video
46
+ """
47
+ assert num_frames > 0
48
+
49
+ capture = cv2.VideoCapture(path)
50
+ frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
51
+ if frame_count <= 0: return None
52
+
53
+ frame_idxs = np.linspace(0, frame_count - 1, num_frames, endpoint=True, dtype=int)
54
+ if jitter > 0:
55
+ np.random.seed(seed)
56
+ jitter_offsets = np.random.randint(-jitter, jitter, len(frame_idxs))
57
+ frame_idxs = np.clip(frame_idxs + jitter_offsets, 0, frame_count - 1)
58
+
59
+ result = self._read_frames_at_indices(path, capture, frame_idxs)
60
+ capture.release()
61
+ return result
62
+
63
+ def read_random_frames(self, path, num_frames, seed=None):
64
+ """Picks the frame indices at random.
65
+
66
+ Arguments:
67
+ path: the video file
68
+ num_frames: how many frames to read; must be a positive number
69
+ (reading many frames takes up a lot of memory!)
70
+ """
71
+ assert num_frames > 0
72
+ np.random.seed(seed)
73
+
74
+ capture = cv2.VideoCapture(path)
75
+ frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
76
+ if frame_count <= 0: return None
77
+
78
+ frame_idxs = sorted(np.random.choice(np.arange(0, frame_count), num_frames))
79
+ result = self._read_frames_at_indices(path, capture, frame_idxs)
80
+
81
+ capture.release()
82
+ return result
83
+
84
+ def read_frames_at_indices(self, path, frame_idxs):
85
+ """Reads frames from a video and puts them into a NumPy array.
86
+
87
+ Arguments:
88
+ path: the video file
89
+ frame_idxs: a list of frame indices. Important: should be
90
+ sorted from low-to-high! If an index appears multiple
91
+ times, the frame is still read only once.
92
+
93
+ Returns:
94
+ - a NumPy array of shape (num_frames, height, width, 3)
95
+ - a list of the frame indices that were read
96
+
97
+ Reading stops if loading a frame fails, in which case the first
98
+ dimension returned may actually be less than num_frames.
99
+
100
+ Returns None if an exception is thrown for any reason, or if no
101
+ frames were read.
102
+ """
103
+ assert len(frame_idxs) > 0
104
+ capture = cv2.VideoCapture(path)
105
+ result = self._read_frames_at_indices(path, capture, frame_idxs)
106
+ capture.release()
107
+ return result
108
+
109
+ def _read_frames_at_indices(self, path, capture, frame_idxs):
110
+ try:
111
+ frames = []
112
+ idxs_read = []
113
+ for frame_idx in range(frame_idxs[0], frame_idxs[-1] + 1):
114
+ # Get the next frame, but don't decode if we're not using it.
115
+ ret = capture.grab()
116
+ if not ret:
117
+ if self.verbose:
118
+ print("Error grabbing frame %d from movie %s" % (frame_idx, path))
119
+ break
120
+
121
+ # Need to look at this frame?
122
+ current = len(idxs_read)
123
+ if frame_idx == frame_idxs[current]:
124
+ ret, frame = capture.retrieve()
125
+ if not ret or frame is None:
126
+ if self.verbose:
127
+ print("Error retrieving frame %d from movie %s" % (frame_idx, path))
128
+ break
129
+
130
+ frame = self._postprocess_frame(frame)
131
+ frames.append(frame)
132
+ idxs_read.append(frame_idx)
133
+
134
+ if len(frames) > 0:
135
+ return np.stack(frames), idxs_read
136
+ if self.verbose:
137
+ print("No frames read from movie %s" % path)
138
+ return None
139
+ except:
140
+ if self.verbose:
141
+ print("Exception while reading movie %s" % path)
142
+ return None
143
+
144
+ def read_middle_frame(self, path):
145
+ """Reads the frame from the middle of the video."""
146
+ capture = cv2.VideoCapture(path)
147
+ frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
148
+ result = self._read_frame_at_index(path, capture, frame_count // 2)
149
+ capture.release()
150
+ return result
151
+
152
+ def read_frame_at_index(self, path, frame_idx):
153
+ """Reads a single frame from a video.
154
+
155
+ If you just want to read a single frame from the video, this is more
156
+ efficient than scanning through the video to find the frame. However,
157
+ for reading multiple frames it's not efficient.
158
+
159
+ My guess is that a "streaming" approach is more efficient than a
160
+ "random access" approach because, unless you happen to grab a keyframe,
161
+ the decoder still needs to read all the previous frames in order to
162
+ reconstruct the one you're asking for.
163
+
164
+ Returns a NumPy array of shape (1, H, W, 3) and the index of the frame,
165
+ or None if reading failed.
166
+ """
167
+ capture = cv2.VideoCapture(path)
168
+ result = self._read_frame_at_index(path, capture, frame_idx)
169
+ capture.release()
170
+ return result
171
+
172
+ def _read_frame_at_index(self, path, capture, frame_idx):
173
+ capture.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
174
+ ret, frame = capture.read()
175
+ if not ret or frame is None:
176
+ if self.verbose:
177
+ print("Error retrieving frame %d from movie %s" % (frame_idx, path))
178
+ return None
179
+ else:
180
+ frame = self._postprocess_frame(frame)
181
+ return np.expand_dims(frame, axis=0), [frame_idx]
182
+
183
+ def _postprocess_frame(self, frame):
184
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
185
+
186
+ if self.insets[0] > 0:
187
+ W = frame.shape[1]
188
+ p = int(W * self.insets[0])
189
+ frame = frame[:, p:-p, :]
190
+
191
+ if self.insets[1] > 0:
192
+ H = frame.shape[0]
193
+ q = int(H * self.insets[1])
194
+ frame = frame[q:-q, :, :]
195
+
196
+ return frame
197
+
198
+
199
+ class FaceExtractor:
200
+ def __init__(self, video_read_fn):
201
+ self.video_read_fn = video_read_fn
202
+ self.detector = MTCNN(margin=0, thresholds=[0.7, 0.8, 0.8], device="cpu")
203
+
204
+ def process_videos(self, video, video_idxs):
205
+ videos_read = []
206
+ frames_read = []
207
+ frames = []
208
+ results = []
209
+ for video_idx in video_idxs:
210
+ # Read the full-size frames from this video.
211
+ result = self.video_read_fn(video)
212
+ # Error? Then skip this video.
213
+ if result is None: continue
214
+
215
+ videos_read.append(video_idx)
216
+
217
+ # Keep track of the original frames (need them later).
218
+ my_frames, my_idxs = result
219
+
220
+ frames.append(my_frames)
221
+ frames_read.append(my_idxs)
222
+ for i, frame in enumerate(my_frames):
223
+ h, w = frame.shape[:2]
224
+ img = Image.fromarray(frame.astype(np.uint8))
225
+ img = img.resize(size=[s // 2 for s in img.size])
226
+
227
+ batch_boxes, probs = self.detector.detect(img, landmarks=False)
228
+
229
+ faces = []
230
+ scores = []
231
+ if batch_boxes is None:
232
+ continue
233
+ for bbox, score in zip(batch_boxes, probs):
234
+ if bbox is not None:
235
+ xmin, ymin, xmax, ymax = [int(b * 2) for b in bbox]
236
+ face_w = xmax - xmin  # separate names so frame_w/frame_h below keep the frame size
237
+ face_h = ymax - ymin
238
+ p_h = face_h // 3
239
+ p_w = face_w // 3
240
+ crop = frame[max(ymin - p_h, 0):ymax + p_h, max(xmin - p_w, 0):xmax + p_w]
241
+ faces.append(crop)
242
+ scores.append(score)
243
+
244
+ frame_dict = {"video_idx": video_idx,
245
+ "frame_idx": my_idxs[i],
246
+ "frame_w": w,
247
+ "frame_h": h,
248
+ "faces": faces,
249
+ "scores": scores}
250
+ results.append(frame_dict)
251
+
252
+ return results
253
+
254
+ def process_video(self, video):
255
+ """Convenience method for doing face extraction on a single video."""
256
+ return self.process_videos(video, [0])
257
+
258
+
259
+
260
+ def confident_strategy(pred, t=0.8):
261
+ pred = np.array(pred)
262
+ sz = len(pred)
263
+ fakes = np.count_nonzero(pred > t)
264
+ # treat the video as fake only if a large share of frames (and at least 11) are confidently fake
265
+ if fakes > sz // 2.5 and fakes > 11:
266
+ return np.mean(pred[pred > t])
267
+ elif np.count_nonzero(pred < 0.2) > 0.9 * sz:
268
+ return np.mean(pred[pred < 0.2])
269
+ else:
270
+ return np.mean(pred)
271
+
272
+ strategy = confident_strategy
273
+
274
+
275
+ def put_to_center(img, input_size):
276
+ img = img[:input_size, :input_size]
277
+ image = np.zeros((input_size, input_size, 3), dtype=np.uint8)
278
+ start_w = (input_size - img.shape[1]) // 2
279
+ start_h = (input_size - img.shape[0]) // 2
280
+ image[start_h:start_h + img.shape[0], start_w: start_w + img.shape[1], :] = img
281
+ return image
282
+
283
+
284
+ def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC):
285
+ h, w = img.shape[:2]
286
+ if max(w, h) == size:
287
+ return img
288
+ if w > h:
289
+ scale = size / w
290
+ h = h * scale
291
+ w = size
292
+ else:
293
+ scale = size / h
294
+ w = w * scale
295
+ h = size
296
+ interpolation = interpolation_up if scale > 1 else interpolation_down
297
+ resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
298
+ return resized
299
+
300
+
301
+ def predict_on_video(face_extractor, video, batch_size, input_size, models, strategy=np.mean,
302
+ apply_compression=False):
303
+ batch_size *= 4
304
+ try:
305
+ faces = face_extractor.process_video(video)
306
+ if len(faces) > 0:
307
+ x = np.zeros((batch_size, input_size, input_size, 3), dtype=np.uint8)
308
+ n = 0
309
+ for frame_data in faces:
310
+ for face in frame_data["faces"]:
311
+ resized_face = isotropically_resize_image(face, input_size)
312
+ resized_face = put_to_center(resized_face, input_size)
313
+ if apply_compression:
314
+ resized_face = image_compression(resized_face, quality=90, image_type=".jpg")
315
+ if n + 1 < batch_size:
316
+ x[n] = resized_face
317
+ n += 1
318
+ else:
319
+ pass
320
+ if n > 0:
321
+ x = torch.tensor(x, device="cpu").float()
322
+ # Preprocess the images.
323
+ x = x.permute((0, 3, 1, 2))
324
+ for i in range(len(x)):
325
+ x[i] = normalize_transform(x[i] / 255.)
326
+ # Make a prediction, then take the average.
327
+ with torch.no_grad():
328
+ preds = []
329
+ for model in models:
330
+ y_pred = model(x[:n].float())
331
+ y_pred = torch.sigmoid(y_pred.squeeze())
332
+ bpred = y_pred[:n].cpu().numpy()
333
+ preds.append(strategy(bpred))
334
+ return np.mean(preds)
335
+ except Exception as e:
336
+ print("Prediction error on video: %s" % str(e))
337
+
338
+ return 0.5
339
+
340
+
341
+ def predict_on_video_set(face_extractor, videos, input_size, num_workers, test_dir, frames_per_video, models,
342
+ strategy=np.mean,
343
+ apply_compression=False):
344
+ def process_file(i):
345
+ filename = videos[i]
346
+ y_pred = predict_on_video(face_extractor=face_extractor, video=os.path.join(test_dir, filename),
347
+ input_size=input_size,
348
+ batch_size=frames_per_video,
349
+ models=models, strategy=strategy, apply_compression=apply_compression)
350
+ return y_pred
351
+
352
+ with ThreadPoolExecutor(max_workers=num_workers) as ex:
353
+ predictions = ex.map(process_file, range(len(videos)))
354
+ return list(predictions)
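
A toy illustration of how confident_strategy turns per-frame fake probabilities into a single video-level score, assuming the repo's dependencies (facenet_pytorch, albumentations) are installed since utils.py imports them at module level; the arrays are made up.

import numpy as np

from utils import confident_strategy

mostly_fake = np.array([0.95] * 20 + [0.30] * 12)  # 20 of 32 frames confidently fake
mostly_real = np.array([0.05] * 30 + [0.60] * 2)   # almost all frames confidently real
mixed = np.array([0.40, 0.60, 0.50, 0.55])         # no confident majority

print(confident_strategy(mostly_fake))  # ~0.95: mean over the confident fake frames
print(confident_strategy(mostly_real))  # ~0.05: mean over the confident real frames
print(confident_strategy(mixed))        # ~0.51: falls back to the plain mean
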
weights/final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9db77ab9318863e2f8ab287c8eb83c2232584b82dc2fb41f1d614ddd7900cccb
3
+ size 266910617