Commit 55da56b · russel0719 committed
Parent(s): 2638f8b
Upload 38 files
- .gitattributes +3 -0
- app.py +52 -0
- preprocessing/__init__.py +1 -0
- preprocessing/compress_videos.py +45 -0
- preprocessing/detect_original_faces.py +51 -0
- preprocessing/extract_crops.py +86 -0
- preprocessing/extract_images.py +42 -0
- preprocessing/face_detector.py +72 -0
- preprocessing/face_encodings.py +55 -0
- preprocessing/generate_diffs.py +73 -0
- preprocessing/generate_folds.py +114 -0
- preprocessing/generate_landmarks.py +75 -0
- preprocessing/utils.py +51 -0
- sample/sample1.mp4 +3 -0
- sample/sample2.mp4 +3 -0
- training/__init__.py +0 -0
- training/__pycache__/__init__.cpython-37.pyc +0 -0
- training/__pycache__/__init__.cpython-39.pyc +0 -0
- training/datasets/__init__.py +0 -0
- training/datasets/classifier_dataset.py +378 -0
- training/datasets/validation_set.py +60 -0
- training/losses.py +28 -0
- training/pipelines/__init__.py +0 -0
- training/pipelines/train_classifier.py +361 -0
- training/tools/__init__.py +0 -0
- training/tools/config.py +43 -0
- training/tools/schedulers.py +46 -0
- training/tools/utils.py +121 -0
- training/transforms/__init__.py +0 -0
- training/transforms/albu.py +99 -0
- training/zoo/__init__.py +0 -0
- training/zoo/__pycache__/__init__.cpython-37.pyc +0 -0
- training/zoo/__pycache__/__init__.cpython-39.pyc +0 -0
- training/zoo/__pycache__/classifiers.cpython-37.pyc +0 -0
- training/zoo/__pycache__/classifiers.cpython-39.pyc +0 -0
- training/zoo/classifiers.py +172 -0
- training/zoo/unet.py +151 -0
- utils.py +354 -0
- weights/final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36 +3 -0
.gitattributes
CHANGED
@@ -32,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+sample/sample1.mp4 filter=lfs diff=lfs merge=lfs -text
+sample/sample2.mp4 filter=lfs diff=lfs merge=lfs -text
+weights/final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36 filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,52 @@
+import gradio as gr
+import os
+import re
+import torch
+from utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video
+from training.zoo.classifiers import DeepFakeClassifier
+
+def detect(video):
+    # Load model
+    model = DeepFakeClassifier(encoder="tf_efficientnet_b7")
+    path = os.path.join('weights', 'final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36')
+    checkpoint = torch.load(path, map_location="cpu")
+    state_dict = checkpoint.get("state_dict", checkpoint)
+    model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=True)
+    model.eval()
+    del checkpoint
+    models = [model.float()]
+
+    # Setting Video
+    frames_per_video = 32
+    video_reader = VideoReader()
+    video_read_fn = lambda x: video_reader.read_frames(x, num_frames=frames_per_video)
+    face_extractor = FaceExtractor(video_read_fn)
+    input_size = 380
+    strategy = confident_strategy
+
+    # Predict
+    pred = predict_on_video(
+        face_extractor=face_extractor,
+        video=video,
+        batch_size=frames_per_video,
+        input_size=input_size,
+        models=models,
+        strategy=strategy
+    )
+    prob = {'Fake': float(pred), 'Real': float(1 - pred)}
+    return prob
+
+gr_inputs = gr.Video(format='mp4', source='upload')
+gr_outputs = gr.Label()
+gr_ex = [
+    [os.path.join(os.path.dirname(__file__), "sample/sample1.mp4")],
+    [os.path.join(os.path.dirname(__file__), "sample/sample2.mp4")],
+]
+iface = gr.Interface(
+    fn=detect,
+    inputs=gr_inputs,
+    outputs=gr_outputs,
+    examples=gr_ex,
+)
+
+iface.launch()
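Note: running `python app.py` starts the Gradio UI defined above. For a quick headless check, the same prediction path can be driven directly; this is a minimal sketch (not part of the diff), assuming the LFS weights and sample videos have been pulled and a working torch/gradio environment. The printed probabilities are illustrative only.

# Sketch: reuse the pieces app.py wires together (importing app itself would call iface.launch()).
import os
import torch
from utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video
from training.zoo.classifiers import DeepFakeClassifier

model = DeepFakeClassifier(encoder="tf_efficientnet_b7")
ckpt = torch.load(os.path.join("weights", "final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36"),
                  map_location="cpu")
model.load_state_dict({k.replace("module.", "", 1): v
                       for k, v in ckpt.get("state_dict", ckpt).items()})
model.eval()

extractor = FaceExtractor(lambda x: VideoReader().read_frames(x, num_frames=32))
pred = predict_on_video(face_extractor=extractor, video="sample/sample1.mp4",
                        batch_size=32, input_size=380, models=[model.float()],
                        strategy=confident_strategy)
print({"Fake": float(pred), "Real": float(1 - pred)})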
preprocessing/__init__.py
ADDED
@@ -0,0 +1 @@
+from .face_detector import *
preprocessing/compress_videos.py
ADDED
@@ -0,0 +1,45 @@
+import argparse
+import os
+import random
+import subprocess
+
+os.environ["MKL_NUM_THREADS"] = "1"
+os.environ["NUMEXPR_NUM_THREADS"] = "1"
+os.environ["OMP_NUM_THREADS"] = "1"
+from functools import partial
+from glob import glob
+from multiprocessing.pool import Pool
+from os import cpu_count
+
+import cv2
+
+cv2.ocl.setUseOpenCL(False)
+cv2.setNumThreads(0)
+from tqdm import tqdm
+
+
+def compress_video(video, root_dir):
+    parent_dir = video.split("/")[-2]
+    out_dir = os.path.join(root_dir, "compressed", parent_dir)
+    os.makedirs(out_dir, exist_ok=True)
+    video_name = video.split("/")[-1]
+    out_path = os.path.join(out_dir, video_name)
+    lvl = random.choice([23, 28, 32])
+    command = "ffmpeg -i {} -c:v libx264 -crf {} -threads 1 {}".format(video, lvl, out_path)
+    try:
+        subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT)
+    except Exception as e:
+        print("Could not process video", str(e))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description="Extracts jpegs from video")
+    parser.add_argument("--root-dir", help="root directory", default="/mnt/sota/datasets/deepfake")
+
+    args = parser.parse_args()
+    videos = [video_path for video_path in glob(os.path.join(args.root_dir, "*/*.mp4"))]
+    with Pool(processes=cpu_count() - 2) as p:
+        with tqdm(total=len(videos)) as pbar:
+            for v in p.imap_unordered(partial(compress_video, root_dir=args.root_dir), videos):
+                pbar.update()
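Note (not part of the diff): compress_video() shells out to ffmpeg with a CRF chosen at random from 23/28/32 per video, presumably to make the data robust to heavier compression. A minimal sketch of the command string it builds, with hypothetical paths:

# Hypothetical paths, for illustration only; crf is drawn from [23, 28, 32].
video = "/mnt/sota/datasets/deepfake/dfdc_train_part_0/abc.mp4"
out_path = "/mnt/sota/datasets/deepfake/compressed/dfdc_train_part_0/abc.mp4"
print("ffmpeg -i {} -c:v libx264 -crf {} -threads 1 {}".format(video, 28, out_path))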
preprocessing/detect_original_faces.py
ADDED
@@ -0,0 +1,51 @@
+import argparse
+import json
+import os
+from os import cpu_count
+from typing import Type
+
+from torch.utils.data.dataloader import DataLoader
+from tqdm import tqdm
+
+from preprocessing import face_detector, VideoDataset
+from preprocessing.face_detector import VideoFaceDetector
+from preprocessing.utils import get_original_video_paths
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Process original videos with face detector")
+    parser.add_argument("--root-dir", help="root directory")
+    parser.add_argument("--detector-type", help="type of the detector", default="FacenetDetector",
+                        choices=["FacenetDetector"])
+    args = parser.parse_args()
+    return args
+
+
+def process_videos(videos, root_dir, detector_cls: Type[VideoFaceDetector]):
+    detector = face_detector.__dict__[detector_cls](device="cuda:0")
+    dataset = VideoDataset(videos)
+    loader = DataLoader(dataset, shuffle=False, num_workers=cpu_count() - 2, batch_size=1, collate_fn=lambda x: x)
+    for item in tqdm(loader):
+        result = {}
+        video, indices, frames = item[0]
+        batches = [frames[i:i + detector._batch_size] for i in range(0, len(frames), detector._batch_size)]
+        for j, frames in enumerate(batches):
+            result.update({int(j * detector._batch_size) + i: b for i, b in zip(indices, detector._detect_faces(frames))})
+        id = os.path.splitext(os.path.basename(video))[0]
+        out_dir = os.path.join(root_dir, "boxes")
+        os.makedirs(out_dir, exist_ok=True)
+        with open(os.path.join(out_dir, "{}.json".format(id)), "w") as f:
+            json.dump(result, f)
+
+
+
+
+def main():
+    args = parse_args()
+    originals = get_original_video_paths(args.root_dir)
+    process_videos(originals, args.root_dir, args.detector_type)
+
+
+if __name__ == "__main__":
+    main()
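Note (not part of the diff): process_videos() writes one <video_id>.json per original video into <root>/boxes, mapping a frame index to the MTCNN boxes found in that frame, or null where nothing was detected. A sketch of the expected shape, with made-up numbers; coordinates are in the half-sized frames produced by VideoDataset, which is why extract_crops.py later scales them by 2:

# Illustrative structure of boxes/<video_id>.json
example = {
    "0": [[212.4, 135.7, 311.9, 267.0]],   # one face: [xmin, ymin, xmax, ymax]
    "1": None,                             # no detection in this frame
}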
preprocessing/extract_crops.py
ADDED
@@ -0,0 +1,86 @@
+import argparse
+import json
+import os
+from os import cpu_count
+from pathlib import Path
+
+os.environ["MKL_NUM_THREADS"] = "1"
+os.environ["NUMEXPR_NUM_THREADS"] = "1"
+os.environ["OMP_NUM_THREADS"] = "1"
+from functools import partial
+from glob import glob
+from multiprocessing.pool import Pool
+
+import cv2
+
+cv2.ocl.setUseOpenCL(False)
+cv2.setNumThreads(0)
+from tqdm import tqdm
+
+
+def extract_video(param, root_dir, crops_dir):
+    video, bboxes_path = param
+    with open(bboxes_path, "r") as bbox_f:
+        bboxes_dict = json.load(bbox_f)
+
+    capture = cv2.VideoCapture(video)
+    frames_num = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    for i in range(frames_num):
+        capture.grab()
+        if i % 10 != 0:
+            continue
+        success, frame = capture.retrieve()
+        if not success or str(i) not in bboxes_dict:
+            continue
+        id = os.path.splitext(os.path.basename(video))[0]
+        crops = []
+        bboxes = bboxes_dict[str(i)]
+        if bboxes is None:
+            continue
+        for bbox in bboxes:
+            xmin, ymin, xmax, ymax = [int(b * 2) for b in bbox]
+            w = xmax - xmin
+            h = ymax - ymin
+            p_h = h // 3
+            p_w = w // 3
+            crop = frame[max(ymin - p_h, 0):ymax + p_h, max(xmin - p_w, 0):xmax + p_w]
+            h, w = crop.shape[:2]
+            crops.append(crop)
+        img_dir = os.path.join(root_dir, crops_dir, id)
+        os.makedirs(img_dir, exist_ok=True)
+        for j, crop in enumerate(crops):
+            cv2.imwrite(os.path.join(img_dir, "{}_{}.png".format(i, j)), crop)
+
+
+def get_video_paths(root_dir):
+    paths = []
+    for json_path in glob(os.path.join(root_dir, "*/metadata.json")):
+        dir = Path(json_path).parent
+        with open(json_path, "r") as f:
+            metadata = json.load(f)
+        for k, v in metadata.items():
+            original = v.get("original", None)
+            if not original:
+                original = k
+            bboxes_path = os.path.join(root_dir, "boxes", original[:-4] + ".json")
+            if not os.path.exists(bboxes_path):
+                continue
+            paths.append((os.path.join(dir, k), bboxes_path))
+
+    return paths
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description="Extracts crops from video")
+    parser.add_argument("--root-dir", help="root directory")
+    parser.add_argument("--crops-dir", help="crops directory")
+
+    args = parser.parse_args()
+    os.makedirs(os.path.join(args.root_dir, args.crops_dir), exist_ok=True)
+    params = get_video_paths(args.root_dir)
+    with Pool(processes=cpu_count()) as p:
+        with tqdm(total=len(params)) as pbar:
+            for v in p.imap_unordered(partial(extract_video, root_dir=args.root_dir, crops_dir=args.crops_dir), params):
+                pbar.update()
preprocessing/extract_images.py
ADDED
@@ -0,0 +1,42 @@
+import argparse
+import os
+os.environ["MKL_NUM_THREADS"] = "1"
+os.environ["NUMEXPR_NUM_THREADS"] = "1"
+os.environ["OMP_NUM_THREADS"] = "1"
+from functools import partial
+from glob import glob
+from multiprocessing.pool import Pool
+from os import cpu_count
+
+import cv2
+cv2.ocl.setUseOpenCL(False)
+cv2.setNumThreads(0)
+from tqdm import tqdm
+
+
+def extract_video(video, root_dir):
+    capture = cv2.VideoCapture(video)
+    frames_num = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    for i in range(frames_num):
+        capture.grab()
+        success, frame = capture.retrieve()
+        if not success:
+            continue
+        id = os.path.splitext(os.path.basename(video))[0]
+        cv2.imwrite(os.path.join(root_dir, "jpegs", "{}_{}.jpg".format(id, i)), frame, [cv2.IMWRITE_JPEG_QUALITY, 100])
+
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description="Extracts jpegs from video")
+    parser.add_argument("--root-dir", help="root directory")
+
+    args = parser.parse_args()
+    os.makedirs(os.path.join(args.root_dir, "jpegs"), exist_ok=True)
+    videos = [video_path for video_path in glob(os.path.join(args.root_dir, "*/*.mp4"))]
+    with Pool(processes=cpu_count() - 2) as p:
+        with tqdm(total=len(videos)) as pbar:
+            for v in p.imap_unordered(partial(extract_video, root_dir=args.root_dir), videos):
+                pbar.update()
preprocessing/face_detector.py
ADDED
@@ -0,0 +1,72 @@
+import os
+os.environ["MKL_NUM_THREADS"] = "1"
+os.environ["NUMEXPR_NUM_THREADS"] = "1"
+os.environ["OMP_NUM_THREADS"] = "1"
+
+from abc import ABC, abstractmethod
+from collections import OrderedDict
+from typing import List
+
+import cv2
+cv2.ocl.setUseOpenCL(False)
+cv2.setNumThreads(0)
+
+from PIL import Image
+from facenet_pytorch.models.mtcnn import MTCNN
+from torch.utils.data import Dataset
+
+
+class VideoFaceDetector(ABC):
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__()
+
+    @property
+    @abstractmethod
+    def _batch_size(self) -> int:
+        pass
+
+    @abstractmethod
+    def _detect_faces(self, frames) -> List:
+        pass
+
+
+class FacenetDetector(VideoFaceDetector):
+
+    def __init__(self, device="cuda:0") -> None:
+        super().__init__()
+        self.detector = MTCNN(margin=0, thresholds=[0.85, 0.95, 0.95], device=device)
+
+    def _detect_faces(self, frames) -> List:
+        batch_boxes, *_ = self.detector.detect(frames, landmarks=False)
+        return [b.tolist() if b is not None else None for b in batch_boxes]
+
+    @property
+    def _batch_size(self):
+        return 32
+
+
+class VideoDataset(Dataset):
+
+    def __init__(self, videos) -> None:
+        super().__init__()
+        self.videos = videos
+
+    def __getitem__(self, index: int):
+        video = self.videos[index]
+        capture = cv2.VideoCapture(video)
+        frames_num = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
+        frames = OrderedDict()
+        for i in range(frames_num):
+            capture.grab()
+            success, frame = capture.retrieve()
+            if not success:
+                continue
+            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frame = Image.fromarray(frame)
+            frame = frame.resize(size=[s // 2 for s in frame.size])
+            frames[i] = frame
+        return video, list(frames.keys()), list(frames.values())
+
+    def __len__(self) -> int:
+        return len(self.videos)
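Note (not part of the diff): FacenetDetector and VideoDataset can also be exercised outside the DataLoader pipeline. A minimal sketch, assuming facenet-pytorch is installed and a CUDA device is available (pass device="cpu" otherwise); the video path is hypothetical:

from preprocessing.face_detector import FacenetDetector, VideoDataset

detector = FacenetDetector(device="cuda:0")          # wraps facenet-pytorch MTCNN
dataset = VideoDataset(["/path/to/some_video.mp4"])  # hypothetical path
video, frame_indices, frames = dataset[0]            # half-size PIL images per frame

# Detect faces in one batch of up to detector._batch_size (32) frames.
boxes = detector._detect_faces(frames[:32])
print(len(frame_indices), boxes[0])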
preprocessing/face_encodings.py
ADDED
@@ -0,0 +1,55 @@
+import argparse
+import os
+from functools import partial
+from multiprocessing.pool import Pool
+
+from tqdm import tqdm
+
+from preprocessing.utils import get_original_video_paths
+
+os.environ["MKL_NUM_THREADS"] = "1"
+os.environ["NUMEXPR_NUM_THREADS"] = "1"
+os.environ["OMP_NUM_THREADS"] = "1"
+
+import random
+
+import face_recognition
+import numpy as np
+
+
+def write_face_encodings(video, root_dir):
+    video_id, *_ = os.path.splitext(video)
+    crops_dir = os.path.join(root_dir, "crops", video_id)
+    if not os.path.exists(crops_dir):
+        return
+    crop_files = [f for f in os.listdir(crops_dir) if f.endswith("jpg")]
+    if crop_files:
+        crop_files = random.sample(crop_files, min(10, len(crop_files)))
+        encodings = []
+        for crop_file in crop_files:
+            img = face_recognition.load_image_file(os.path.join(crops_dir, crop_file))
+            encoding = face_recognition.face_encodings(img, num_jitters=10)
+            if encoding:
+                encodings.append(encoding[0])
+        np.save(os.path.join(crops_dir, "encodings"), encodings)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Extract 10 crops encodings for each video")
+    parser.add_argument("--root-dir", help="root directory", default="/home/selim/datasets/deepfake")
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    originals = get_original_video_paths(args.root_dir, basename=True)
+    with Pool(processes=os.cpu_count() - 4) as p:
+        with tqdm(total=len(originals)) as pbar:
+            for v in p.imap_unordered(partial(write_face_encodings, root_dir=args.root_dir), originals):
+                pbar.update()
+
+
+if __name__ == '__main__':
+    main()
preprocessing/generate_diffs.py
ADDED
@@ -0,0 +1,73 @@
+import argparse
+import os
+
+os.environ["MKL_NUM_THREADS"] = "1"
+os.environ["NUMEXPR_NUM_THREADS"] = "1"
+os.environ["OMP_NUM_THREADS"] = "1"
+from skimage.measure import compare_ssim
+
+from functools import partial
+from multiprocessing.pool import Pool
+
+from tqdm import tqdm
+
+from preprocessing.utils import get_original_with_fakes
+
+import cv2
+
+cv2.ocl.setUseOpenCL(False)
+cv2.setNumThreads(0)
+
+import numpy as np
+
+cache = {}
+
+
+def save_diffs(pair, root_dir):
+    ori_id, fake_id = pair
+    ori_dir = os.path.join(root_dir, "crops", ori_id)
+    fake_dir = os.path.join(root_dir, "crops", fake_id)
+    diff_dir = os.path.join(root_dir, "diffs", fake_id)
+    os.makedirs(diff_dir, exist_ok=True)
+    for frame in range(320):
+        if frame % 10 != 0:
+            continue
+        for actor in range(2):
+            image_id = "{}_{}.png".format(frame, actor)
+            diff_image_id = "{}_{}_diff.png".format(frame, actor)
+            ori_path = os.path.join(ori_dir, image_id)
+            fake_path = os.path.join(fake_dir, image_id)
+            diff_path = os.path.join(diff_dir, diff_image_id)
+            if os.path.exists(ori_path) and os.path.exists(fake_path):
+                img1 = cv2.imread(ori_path, cv2.IMREAD_COLOR)
+                img2 = cv2.imread(fake_path, cv2.IMREAD_COLOR)
+                try:
+                    d, a = compare_ssim(img1, img2, multichannel=True, full=True)
+                    a = 1 - a
+                    diff = (a * 255).astype(np.uint8)
+                    diff = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
+                    cv2.imwrite(diff_path, diff)
+                except:
+                    pass
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Extract image diffs")
+    parser.add_argument("--root-dir", help="root directory", default="/mnt/sota/datasets/deepfake")
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    pairs = get_original_with_fakes(args.root_dir)
+    os.makedirs(os.path.join(args.root_dir, "diffs"), exist_ok=True)
+    with Pool(processes=os.cpu_count() - 2) as p:
+        with tqdm(total=len(pairs)) as pbar:
+            func = partial(save_diffs, root_dir=args.root_dir)
+            for v in p.imap_unordered(func, pairs):
+                pbar.update()
+
+
+if __name__ == '__main__':
+    main()
preprocessing/generate_folds.py
ADDED
@@ -0,0 +1,114 @@
+import argparse
+import json
+import os
+import random
+from functools import partial
+from multiprocessing.pool import Pool
+from pathlib import Path
+
+os.environ["MKL_NUM_THREADS"] = "1"
+os.environ["NUMEXPR_NUM_THREADS"] = "1"
+os.environ["OMP_NUM_THREADS"] = "1"
+import pandas as pd
+
+from tqdm import tqdm
+
+from preprocessing.utils import get_original_with_fakes
+
+import cv2
+
+cv2.ocl.setUseOpenCL(False)
+cv2.setNumThreads(0)
+
+
+def get_paths(vid, label, root_dir):
+    ori_vid, fake_vid = vid
+    ori_dir = os.path.join(root_dir, "crops", ori_vid)
+    fake_dir = os.path.join(root_dir, "crops", fake_vid)
+    data = []
+    for frame in range(320):
+        if frame % 10 != 0:
+            continue
+        for actor in range(2):
+            image_id = "{}_{}.png".format(frame, actor)
+            ori_img_path = os.path.join(ori_dir, image_id)
+            fake_img_path = os.path.join(fake_dir, image_id)
+            img_path = ori_img_path if label == 0 else fake_img_path
+            try:
+                # img = cv2.imread(img_path)[..., ::-1]
+                if os.path.exists(img_path):
+                    data.append([img_path, label, ori_vid])
+            except:
+                pass
+    return data
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Generate Folds")
+    parser.add_argument("--root-dir", help="root directory", default="/mnt/sota/datasets/deepfake")
+    parser.add_argument("--out", type=str, default="folds02.csv", help="CSV file to save")
+    parser.add_argument("--seed", type=int, default=777, help="Seed to split, default 777")
+    parser.add_argument("--n_splits", type=int, default=16, help="Num folds, default 16")
+    args = parser.parse_args()
+
+    return args
+
+
+def main():
+    args = parse_args()
+    ori_fakes = get_original_with_fakes(args.root_dir)
+    sz = 50 // args.n_splits
+    folds = []
+    for fold in range(args.n_splits):
+        folds.append(list(range(sz * fold, sz * fold + sz if fold < args.n_splits - 1 else 50)))
+    print(folds)
+    video_fold = {}
+    for d in os.listdir(args.root_dir):
+        if "dfdc" in d:
+            part = int(d.split("_")[-1])
+            for f in os.listdir(os.path.join(args.root_dir, d)):
+                if "metadata.json" in f:
+                    with open(os.path.join(args.root_dir, d, "metadata.json")) as metadata_json:
+                        metadata = json.load(metadata_json)
+
+                    for k, v in metadata.items():
+                        fold = None
+                        for i, fold_dirs in enumerate(folds):
+                            if part in fold_dirs:
+                                fold = i
+                                break
+                        assert fold is not None
+                        video_id = k[:-4]
+                        video_fold[video_id] = fold
+    for fold in range(len(folds)):
+        holdoutset = {k for k, v in video_fold.items() if v == fold}
+        trainset = {k for k, v in video_fold.items() if v != fold}
+        assert holdoutset.isdisjoint(trainset), "Folds have leaks"
+    data = []
+    ori_ori = set([(ori, ori) for ori, fake in ori_fakes])
+    with Pool(processes=os.cpu_count()) as p:
+        with tqdm(total=len(ori_ori)) as pbar:
+            func = partial(get_paths, label=0, root_dir=args.root_dir)
+            for v in p.imap_unordered(func, ori_ori):
+                pbar.update()
+                data.extend(v)
+        with tqdm(total=len(ori_fakes)) as pbar:
+            func = partial(get_paths, label=1, root_dir=args.root_dir)
+            for v in p.imap_unordered(func, ori_fakes):
+                pbar.update()
+                data.extend(v)
+    fold_data = []
+    for img_path, label, ori_vid in data:
+        path = Path(img_path)
+        video = path.parent.name
+        file = path.name
+        assert video_fold[video] == video_fold[ori_vid], "original video and fake have leak {} {}".format(ori_vid,
+                                                                                                          video)
+        fold_data.append([video, file, label, ori_vid, int(file.split("_")[0]), video_fold[video]])
+    random.shuffle(fold_data)
+    pd.DataFrame(fold_data, columns=["video", "file", "label", "original", "frame", "fold"]).to_csv(args.out, index=False)
+
+
+if __name__ == '__main__':
+    main()
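Note (not part of the diff): the script writes one row per crop with columns video, file, label, original, frame and fold; the training dataset later selects rows by fold. A minimal sketch of consuming the CSV, assuming it was written as folds02.csv:

import pandas as pd

df = pd.read_csv("folds02.csv")
fold = 0
train_rows = df[df["fold"] != fold]
val_rows = df[df["fold"] == fold]
print(len(train_rows), len(val_rows), df.columns.tolist())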
preprocessing/generate_landmarks.py
ADDED
@@ -0,0 +1,75 @@
+import argparse
+import os
+from functools import partial
+from multiprocessing.pool import Pool
+
+
+
+os.environ["MKL_NUM_THREADS"] = "1"
+os.environ["NUMEXPR_NUM_THREADS"] = "1"
+os.environ["OMP_NUM_THREADS"] = "1"
+
+from tqdm import tqdm
+
+
+import cv2
+
+cv2.ocl.setUseOpenCL(False)
+cv2.setNumThreads(0)
+from preprocessing.utils import get_original_video_paths
+
+from PIL import Image
+from facenet_pytorch.models.mtcnn import MTCNN
+import numpy as np
+
+detector = MTCNN(margin=0, thresholds=[0.65, 0.75, 0.75], device="cpu")
+
+
+def save_landmarks(ori_id, root_dir):
+    ori_id = ori_id[:-4]
+    ori_dir = os.path.join(root_dir, "crops", ori_id)
+    landmark_dir = os.path.join(root_dir, "landmarks", ori_id)
+    os.makedirs(landmark_dir, exist_ok=True)
+    for frame in range(320):
+        if frame % 10 != 0:
+            continue
+        for actor in range(2):
+            image_id = "{}_{}.png".format(frame, actor)
+            landmarks_id = "{}_{}".format(frame, actor)
+            ori_path = os.path.join(ori_dir, image_id)
+            landmark_path = os.path.join(landmark_dir, landmarks_id)
+
+            if os.path.exists(ori_path):
+                try:
+                    image_ori = cv2.imread(ori_path, cv2.IMREAD_COLOR)[..., ::-1]
+                    frame_img = Image.fromarray(image_ori)
+                    batch_boxes, conf, landmarks = detector.detect(frame_img, landmarks=True)
+                    if landmarks is not None:
+                        landmarks = np.around(landmarks[0]).astype(np.int16)
+                        np.save(landmark_path, landmarks)
+                except Exception as e:
+                    print(e)
+                    pass
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Extract image landmarks")
+    parser.add_argument("--root-dir", help="root directory", default="/mnt/sota/datasets/deepfake")
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    ids = get_original_video_paths(args.root_dir, basename=True)
+    os.makedirs(os.path.join(args.root_dir, "landmarks"), exist_ok=True)
+    with Pool(processes=os.cpu_count()) as p:
+        with tqdm(total=len(ids)) as pbar:
+            func = partial(save_landmarks, root_dir=args.root_dir)
+            for v in p.imap_unordered(func, ids):
+                pbar.update()
+
+
+if __name__ == '__main__':
+    main()
preprocessing/utils.py
ADDED
@@ -0,0 +1,51 @@
+import json
+import os
+from glob import glob
+from pathlib import Path
+
+
+def get_original_video_paths(root_dir, basename=False):
+    originals = set()
+    originals_v = set()
+    for json_path in glob(os.path.join(root_dir, "*/metadata.json")):
+        dir = Path(json_path).parent
+        with open(json_path, "r") as f:
+            metadata = json.load(f)
+        for k, v in metadata.items():
+            original = v.get("original", None)
+            if v["label"] == "REAL":
+                original = k
+                originals_v.add(original)
+                originals.add(os.path.join(dir, original))
+    originals = list(originals)
+    originals_v = list(originals_v)
+    print(len(originals))
+    return originals_v if basename else originals
+
+
+def get_original_with_fakes(root_dir):
+    pairs = []
+    for json_path in glob(os.path.join(root_dir, "*/metadata.json")):
+        with open(json_path, "r") as f:
+            metadata = json.load(f)
+        for k, v in metadata.items():
+            original = v.get("original", None)
+            if v["label"] == "FAKE":
+                pairs.append((original[:-4], k[:-4]))
+
+    return pairs
+
+
+def get_originals_and_fakes(root_dir):
+    originals = []
+    fakes = []
+    for json_path in glob(os.path.join(root_dir, "*/metadata.json")):
+        with open(json_path, "r") as f:
+            metadata = json.load(f)
+        for k, v in metadata.items():
+            if v["label"] == "FAKE":
+                fakes.append(k[:-4])
+            else:
+                originals.append(k[:-4])
+
+    return originals, fakes
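Note (not part of the diff): these helpers assume the per-part metadata.json layout of the DFDC dataset, where each entry maps a video file to its label and, for fakes, the original it was derived from. An illustrative excerpt with made-up file names:

# Illustrative metadata.json content, as expected by the helpers above.
metadata = {
    "aaqaifqrwn.mp4": {"label": "FAKE", "split": "train", "original": "bbhtdfuqxq.mp4"},
    "bbhtdfuqxq.mp4": {"label": "REAL", "split": "train"},
}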
sample/sample1.mp4
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:37f0ff1e337e6fe2211757b8255c8098f8a4f303b333e9c5e0da27c37e095876
+size 5128324
sample/sample2.mp4
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa3a01923fdfe38abc7b12b436c0e631575b6337fcde90cb9017e8dd5b733db0
+size 16193231
training/__init__.py
ADDED
File without changes
training/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (157 Bytes)
training/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (161 Bytes)
training/datasets/__init__.py
ADDED
File without changes
training/datasets/classifier_dataset.py
ADDED
@@ -0,0 +1,378 @@
+import math
+import os
+import random
+import sys
+import traceback
+
+import cv2
+import numpy as np
+import pandas as pd
+import skimage.draw
+from albumentations import ImageCompression, OneOf, GaussianBlur, Blur
+from albumentations.augmentations.functional import image_compression, rot90
+from albumentations.pytorch.functional import img_to_tensor
+from scipy.ndimage import binary_erosion, binary_dilation
+from skimage import measure
+from torch.utils.data import Dataset
+import dlib
+
+from training.datasets.validation_set import PUBLIC_SET
+
+
+def prepare_bit_masks(mask):
+    h, w = mask.shape
+    mid_w = w // 2
+    mid_h = w // 2
+    masks = []
+    ones = np.ones_like(mask)
+    ones[:mid_h] = 0
+    masks.append(ones)
+    ones = np.ones_like(mask)
+    ones[mid_h:] = 0
+    masks.append(ones)
+    ones = np.ones_like(mask)
+    ones[:, :mid_w] = 0
+    masks.append(ones)
+    ones = np.ones_like(mask)
+    ones[:, mid_w:] = 0
+    masks.append(ones)
+    ones = np.ones_like(mask)
+    ones[:mid_h, :mid_w] = 0
+    ones[mid_h:, mid_w:] = 0
+    masks.append(ones)
+    ones = np.ones_like(mask)
+    ones[:mid_h, mid_w:] = 0
+    ones[mid_h:, :mid_w] = 0
+    masks.append(ones)
+    return masks
+
+
+detector = dlib.get_frontal_face_detector()
+predictor = dlib.shape_predictor('libs/shape_predictor_68_face_landmarks.dat')
+
+
+def blackout_convex_hull(img):
+    try:
+        rect = detector(img)[0]
+        sp = predictor(img, rect)
+        landmarks = np.array([[p.x, p.y] for p in sp.parts()])
+        outline = landmarks[[*range(17), *range(26, 16, -1)]]
+        Y, X = skimage.draw.polygon(outline[:, 1], outline[:, 0])
+        cropped_img = np.zeros(img.shape[:2], dtype=np.uint8)
+        cropped_img[Y, X] = 1
+        # if random.random() > 0.5:
+        #     img[cropped_img == 0] = 0
+        #     # leave only face
+        #     return img
+
+        y, x = measure.centroid(cropped_img)
+        y = int(y)
+        x = int(x)
+        first = random.random() > 0.5
+        if random.random() > 0.5:
+            if first:
+                cropped_img[:y, :] = 0
+            else:
+                cropped_img[y:, :] = 0
+        else:
+            if first:
+                cropped_img[:, :x] = 0
+            else:
+                cropped_img[:, x:] = 0
+
+        img[cropped_img > 0] = 0
+    except Exception as e:
+        pass
+
+
+def dist(p1, p2):
+    return math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
+
+
+def remove_eyes(image, landmarks):
+    image = image.copy()
+    (x1, y1), (x2, y2) = landmarks[:2]
+    mask = np.zeros_like(image[..., 0])
+    line = cv2.line(mask, (x1, y1), (x2, y2), color=(1), thickness=2)
+    w = dist((x1, y1), (x2, y2))
+    dilation = int(w // 4)
+    line = binary_dilation(line, iterations=dilation)
+    image[line, :] = 0
+    return image
+
+
+def remove_nose(image, landmarks):
+    image = image.copy()
+    (x1, y1), (x2, y2) = landmarks[:2]
+    x3, y3 = landmarks[2]
+    mask = np.zeros_like(image[..., 0])
+    x4 = int((x1 + x2) / 2)
+    y4 = int((y1 + y2) / 2)
+    line = cv2.line(mask, (x3, y3), (x4, y4), color=(1), thickness=2)
+    w = dist((x1, y1), (x2, y2))
+    dilation = int(w // 4)
+    line = binary_dilation(line, iterations=dilation)
+    image[line, :] = 0
+    return image
+
+
+def remove_mouth(image, landmarks):
+    image = image.copy()
+    (x1, y1), (x2, y2) = landmarks[-2:]
+    mask = np.zeros_like(image[..., 0])
+    line = cv2.line(mask, (x1, y1), (x2, y2), color=(1), thickness=2)
+    w = dist((x1, y1), (x2, y2))
+    dilation = int(w // 3)
+    line = binary_dilation(line, iterations=dilation)
+    image[line, :] = 0
+    return image
+
+
+def remove_landmark(image, landmarks):
+    if random.random() > 0.5:
+        image = remove_eyes(image, landmarks)
+    elif random.random() > 0.5:
+        image = remove_mouth(image, landmarks)
+    elif random.random() > 0.5:
+        image = remove_nose(image, landmarks)
+    return image
+
+
+def change_padding(image, part=5):
+    h, w = image.shape[:2]
+    # original padding was done with 1/3 from each side, too much
+    pad_h = int(((3 / 5) * h) / part)
+    pad_w = int(((3 / 5) * w) / part)
+    image = image[h // 5 - pad_h:-h // 5 + pad_h, w // 5 - pad_w:-w // 5 + pad_w]
+    return image
+
+
+def blackout_random(image, mask, label):
+    binary_mask = mask > 0.4 * 255
+    h, w = binary_mask.shape[:2]
+
+    tries = 50
+    current_try = 1
+    while current_try < tries:
+        first = random.random() < 0.5
+        if random.random() < 0.5:
+            pivot = random.randint(h // 2 - h // 5, h // 2 + h // 5)
+            bitmap_msk = np.ones_like(binary_mask)
+            if first:
+                bitmap_msk[:pivot, :] = 0
+            else:
+                bitmap_msk[pivot:, :] = 0
+        else:
+            pivot = random.randint(w // 2 - w // 5, w // 2 + w // 5)
+            bitmap_msk = np.ones_like(binary_mask)
+            if first:
+                bitmap_msk[:, :pivot] = 0
+            else:
+                bitmap_msk[:, pivot:] = 0
+
+        if label < 0.5 and np.count_nonzero(image * np.expand_dims(bitmap_msk, axis=-1)) / 3 > (h * w) / 5 \
+                or np.count_nonzero(binary_mask * bitmap_msk) > 40:
+            mask *= bitmap_msk
+            image *= np.expand_dims(bitmap_msk, axis=-1)
+            break
+        current_try += 1
+    return image
+
+
+def blend_original(img):
+    img = img.copy()
+    h, w = img.shape[:2]
+    rect = detector(img)
+    if len(rect) == 0:
+        return img
+    else:
+        rect = rect[0]
+    sp = predictor(img, rect)
+    landmarks = np.array([[p.x, p.y] for p in sp.parts()])
+    outline = landmarks[[*range(17), *range(26, 16, -1)]]
+    Y, X = skimage.draw.polygon(outline[:, 1], outline[:, 0])
+    raw_mask = np.zeros(img.shape[:2], dtype=np.uint8)
+    raw_mask[Y, X] = 1
+    face = img * np.expand_dims(raw_mask, -1)
+
+    # add warping
+    h1 = random.randint(h - h // 2, h + h // 2)
+    w1 = random.randint(w - w // 2, w + w // 2)
+    while abs(h1 - h) < h // 3 and abs(w1 - w) < w // 3:
+        h1 = random.randint(h - h // 2, h + h // 2)
+        w1 = random.randint(w - w // 2, w + w // 2)
+    face = cv2.resize(face, (w1, h1), interpolation=random.choice([cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC]))
+    face = cv2.resize(face, (w, h), interpolation=random.choice([cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC]))
+
+    raw_mask = binary_erosion(raw_mask, iterations=random.randint(4, 10))
+    img[raw_mask, :] = face[raw_mask, :]
+    if random.random() < 0.2:
+        img = OneOf([GaussianBlur(), Blur()], p=0.5)(image=img)["image"]
+    # image compression
+    if random.random() < 0.5:
+        img = ImageCompression(quality_lower=40, quality_upper=95)(image=img)["image"]
+    return img
+
+
+class DeepFakeClassifierDataset(Dataset):
+
+    def __init__(self,
+                 data_path="/mnt/sota/datasets/deepfake",
+                 fold=0,
+                 label_smoothing=0.01,
+                 padding_part=3,
+                 hardcore=True,
+                 crops_dir="crops",
+                 folds_csv="folds.csv",
+                 normalize={"mean": [0.485, 0.456, 0.406],
+                            "std": [0.229, 0.224, 0.225]},
+                 rotation=False,
+                 mode="train",
+                 reduce_val=True,
+                 oversample_real=True,
+                 transforms=None
+                 ):
+        super().__init__()
+        self.data_root = data_path
+        self.fold = fold
+        self.folds_csv = folds_csv
+        self.mode = mode
+        self.rotation = rotation
+        self.padding_part = padding_part
+        self.hardcore = hardcore
+        self.crops_dir = crops_dir
+        self.label_smoothing = label_smoothing
+        self.normalize = normalize
+        self.transforms = transforms
+        self.df = pd.read_csv(self.folds_csv)
+        self.oversample_real = oversample_real
+        self.reduce_val = reduce_val
+
+    def __getitem__(self, index: int):
+
+        while True:
+            video, img_file, label, ori_video, frame, fold = self.data[index]
+            try:
+                if self.mode == "train":
+                    label = np.clip(label, self.label_smoothing, 1 - self.label_smoothing)
+                img_path = os.path.join(self.data_root, self.crops_dir, video, img_file)
+                image = cv2.imread(img_path, cv2.IMREAD_COLOR)
+                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+                mask = np.zeros(image.shape[:2], dtype=np.uint8)
+                diff_path = os.path.join(self.data_root, "diffs", video, img_file[:-4] + "_diff.png")
+                try:
+                    msk = cv2.imread(diff_path, cv2.IMREAD_GRAYSCALE)
+                    if msk is not None:
+                        mask = msk
+                except:
+                    print("not found mask", diff_path)
+                    pass
+                if self.mode == "train" and self.hardcore and not self.rotation:
+                    landmark_path = os.path.join(self.data_root, "landmarks", ori_video, img_file[:-4] + ".npy")
+                    if os.path.exists(landmark_path) and random.random() < 0.7:
+                        landmarks = np.load(landmark_path)
+                        image = remove_landmark(image, landmarks)
+                    elif random.random() < 0.2:
+                        blackout_convex_hull(image)
+                    elif random.random() < 0.1:
+                        binary_mask = mask > 0.4 * 255
+                        masks = prepare_bit_masks((binary_mask * 1).astype(np.uint8))
+                        tries = 6
+                        current_try = 1
+                        while current_try < tries:
+                            bitmap_msk = random.choice(masks)
+                            if label < 0.5 or np.count_nonzero(mask * bitmap_msk) > 20:
+                                mask *= bitmap_msk
+                                image *= np.expand_dims(bitmap_msk, axis=-1)
+                                break
+                            current_try += 1
+                if self.mode == "train" and self.padding_part > 3:
+                    image = change_padding(image, self.padding_part)
+                valid_label = np.count_nonzero(mask[mask > 20]) > 32 or label < 0.5
+                valid_label = 1 if valid_label else 0
+                rotation = 0
+                if self.transforms:
+                    data = self.transforms(image=image, mask=mask)
+                    image = data["image"]
+                    mask = data["mask"]
+                if self.mode == "train" and self.hardcore and self.rotation:
+                    # landmark_path = os.path.join(self.data_root, "landmarks", ori_video, img_file[:-4] + ".npy")
+                    dropout = 0.8 if label > 0.5 else 0.6
+                    if self.rotation:
+                        dropout *= 0.7
+                    elif random.random() < dropout:
+                        blackout_random(image, mask, label)
+
+                #
+                # os.makedirs("../images", exist_ok=True)
+                # cv2.imwrite(os.path.join("../images", video+ "_" + str(1 if label > 0.5 else 0) + "_"+img_file), image[...,::-1])
+
+                if self.mode == "train" and self.rotation:
+                    rotation = random.randint(0, 3)
+                    image = rot90(image, rotation)
+
+                image = img_to_tensor(image, self.normalize)
+                return {"image": image, "labels": np.array((label,)), "img_name": os.path.join(video, img_file),
+                        "valid": valid_label, "rotations": rotation}
+            except Exception as e:
+                traceback.print_exc(file=sys.stdout)
+                print("Broken image", os.path.join(self.data_root, self.crops_dir, video, img_file))
+                index = random.randint(0, len(self.data) - 1)
+
+    def random_blackout_landmark(self, image, mask, landmarks):
+        x, y = random.choice(landmarks)
+        first = random.random() > 0.5
+        # crop half face either vertically or horizontally
+        if random.random() > 0.5:
+            # width
+            if first:
+                image[:, :x] = 0
+                mask[:, :x] = 0
+            else:
+                image[:, x:] = 0
+                mask[:, x:] = 0
+        else:
+            # height
+            if first:
+                image[:y, :] = 0
+                mask[:y, :] = 0
+            else:
+                image[y:, :] = 0
+                mask[y:, :] = 0
+
+    def reset(self, epoch, seed):
+        self.data = self._prepare_data(epoch, seed)
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    def _prepare_data(self, epoch, seed):
+        df = self.df
+        if self.mode == "train":
+            rows = df[df["fold"] != self.fold]
+        else:
+            rows = df[df["fold"] == self.fold]
+        seed = (epoch + 1) * seed
+        if self.oversample_real:
+            rows = self._oversample(rows, seed)
+        if self.mode == "val" and self.reduce_val:
+            # every 2nd frame, to speed up validation
+            rows = rows[rows["frame"] % 20 == 0]
+            # another option is to use public validation set
+            # rows = rows[rows["video"].isin(PUBLIC_SET)]
+
+        print(
+            "real {} fakes {} mode {}".format(len(rows[rows["label"] == 0]), len(rows[rows["label"] == 1]), self.mode))
+        data = rows.values
+
+        np.random.seed(seed)
+        np.random.shuffle(data)
+        return data
+
+    def _oversample(self, rows: pd.DataFrame, seed):
+        real = rows[rows["label"] == 0]
+        fakes = rows[rows["label"] == 1]
+        num_real = real["video"].count()
+        if self.mode == "train":
+            fakes = fakes.sample(n=num_real, replace=False, random_state=seed)
+        return pd.concat([real, fakes])
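Note (not part of the diff): DeepFakeClassifierDataset reads its index from the folds CSV but only materialises the per-epoch sample list when reset() is called, so reset() must run before the dataset is iterated. A minimal sketch, assuming folds.csv plus the crops/diffs/landmarks directories exist and the dlib shape predictor is present at libs/shape_predictor_68_face_landmarks.dat:

from torch.utils.data import DataLoader
from training.datasets.classifier_dataset import DeepFakeClassifierDataset

dataset = DeepFakeClassifierDataset(data_path="/mnt/sota/datasets/deepfake",
                                    fold=0, mode="train", folds_csv="folds.csv",
                                    transforms=None)
dataset.reset(epoch=0, seed=777)   # builds self.data for this epoch
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)
batch = next(iter(loader))
print(batch["image"].shape, batch["labels"].shape)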
training/datasets/validation_set.py
ADDED
@@ -0,0 +1,60 @@
+
+
+PUBLIC_SET = {'tjuihawuqm', 'prwsfljdjo', 'scrbqgpvzz', 'ziipxxchai', 'uubgqnvfdl', 'wclvkepakb', 'xjvxtuakyd',
+              'qlvsqdroqo', 'bcbqxhziqz', 'yzuestxcbq', 'hxwtsaydal', 'kqlvggiqee', 'vtunvalyji', 'mohiqoogpb',
+              'siebfpwuhu', 'cekwtyxdoo', 'hszwwswewp', 'orekjthsef', 'huvlwkxoxm', 'fmhiujydwo', 'lhvjzhjxdp',
+              'ibxfxggtqh', 'bofrwgeyjo', 'rmufsuogzn', 'zbgssotnjm', 'dpevefkefv', 'sufvvwmbha', 'ncoeewrdlo',
+              'qhsehzgxqj', 'yxadevzohx', 'aomqqjipcp', 'pcyswtgick', 'wfzjxzhdkj', 'rcjfxxhcal', 'lnjkpdviqb',
+              'xmkwsnuzyq', 'ouaowjmigq', 'bkuzquigyt', 'vwxednhlwz', 'mszblrdprw', 'blnmxntbey', 'gccnvdoknm',
+              'mkzaekkvej', 'hclsparpth', 'eryjktdexi', 'hfsvqabzfq', 'acazlolrpz', 'yoyhmxtrys', 'rerpivllud',
+              'elackxuccp', 'zgbhzkditd', 'vjljdfopjg', 'famlupsgqm', 'nymodlmxni', 'qcbkztamqc', 'qclpbcbgeq',
+              'lpkgabskbw', 'mnowxangqx', 'czfqlbcfpa', 'qyyhuvqmyf', 'toinozytsp', 'ztyvglkcsf', 'nplviymzlg',
+              'opvqdabdap', 'uxuvkrjhws', 'mxahsihabr', 'cqxxumarvp', 'ptbfnkajyi', 'njzshtfmcw', 'dcqodpzomd',
+              'ajiyrjfyzp', 'ywauoonmlr', 'gochxzemmq', 'lpgxwdgnio', 'hnfwagcxdf', 'gfcycflhbo', 'gunamloolc',
+              'yhjlnisfel', 'srfefmyjvt', 'evysmtpnrf', 'aktnlyqpah', 'gpsxfxrjrr', 'zfobicuigx', 'mnzabbkpmt',
+              'rfjuhbnlro', 'zuwwbbusgl', 'csnkohqxdv', 'bzvzpwrabw', 'yietrwuncf', 'wynotylpnm', 'ekboxwrwuv',
+              'rcecrgeotc', 'rklawjhbpv', 'ilqwcbprqa', 'jsysgmycsx', 'sqixhnilfm', 'wnlubukrki', 'nikynwcvuh',
+              'sjkfxrlxxs', 'btdxnajogv', 'wjhpisoeaj', 'dyjklprkoc', 'qlqhjcshpk', 'jyfvaequfg', 'dozjwhnedd',
+              'owaogcehvc', 'oyqgwjdwaj', 'vvfszaosiv', 'kmcdjxmnoa', 'jiswxuqzyz', 'ddtbarpcgo', 'wqysrieiqu',
+              'xcruhaccxc', 'honxqdilvv', 'nxgzmgzkfv', 'cxsvvnxpyz', 'demuhxssgl', 'hzoiotcykp', 'fwykevubzy',
+              'tejfudfgpq', 'kvmpmhdxly', 'oojxonbgow', 'vurjckblge', 'oysopgovhu', 'khpipxnsvx', 'pqthmvwonf',
+              'fddmkqjwsh', 'pcoxcmtroa', 'cnxccbjlct', 'ggzjfrirjh', 'jquevmhdvc', 'ecumyiowzs', 'esmqxszybs',
+              'mllzkpgatp', 'ryxaqpfubf', 'hbufmvbium', 'vdtsbqidjb', 'sjwywglgym', 'qxyrtwozyw', 'upmgtackuf',
+              'ucthmsajay', 'zgjosltkie', 'snlyjbnpgw', 'nswtvttxre', 'iznnzjvaxc', 'jhczqfefgw', 'htzbnroagi',
+              'pdswwyyntw', 'uvrzaczrbx', 'vbcgoyxsvn', 'hzssdinxec', 'novarhxpbj', 'vizerpsvbz', 'jawgcggquk',
+              'iorbtaarte', 'yarpxfqejd', 'vhbbwdflyh', 'rrrfjhugvb', 'fneqiqpqvs', 'jytrvwlewz', 'bfjsthfhbd',
+              'rxdoimqble', 'ekelfsnqof', 'uqvxjfpwdo', 'cjkctqqakb', 'tynfsthodx', 'yllztsrwjw', 'bktkwbcawi',
+              'wcqvzujamg', 'bcvheslzrq', 'aqrsylrzgi', 'sktpeppbkc', 'mkmgcxaztt', 'etdliwticv', 'hqzwudvhih',
+              'swsaoktwgi', 'temjefwaas', 'papagllumt', 'xrtvqhdibb', 'oelqpetgwj', 'ggdpclfcgk', 'imdmhwkkni',
+              'lebzjtusnr', 'xhtppuyqdr', 'nxzgekegsp', 'waucvvmtkq', 'rnfcjxynfa', 'adohdulfwb', 'tjywwgftmv',
+              'fjrueenjyp', 'oaguiggjyv', 'ytopzxrswu', 'yxvmusxvcz', 'rukyxomwcx', 'qdqdsaiitt', 'mxlipjhmqk',
+              'voawxrmqyl', 'kezwvsxxzj', 'oocincvedt', 'qooxnxqqjb', 'mwwploizlj', 'yaxgpxhavq', 'uhakqelqri',
+              'bvpeerislp', 'bkcyglmfci', 'jyoxdvxpza', 'gkutjglghz', 'knxltsvzyu', 'ybbrkacebd', 'apvzjkvnwn',
+              'ahjnxtiamx', 'hsbljbsgxr', 'fnxgqcvlsd', 'xphdfgmfmz', 'scbdenmaed', 'ywxpquomgt', 'yljecirelf',
+              'wcvsqnplsk', 'vmxfwxgdei', 'icbsahlivv', 'yhylappzid', 'irqzdokcws', 'petmyhjclt', 'rmlzgerevr',
+              'qarqtkvgby', 'nkhzxomani', 'viteugozpv', 'qhkzlnzruj', 'eisofhptvk', 'gqnaxievjx', 'heiyoojifp',
+              'zcxcmneefk', 'wvgviwnwob', 'gcdtglsoqj', 'yqhouqakbx', 'fopjiyxiqd', 'hierggamuo', 'ypbtpunjvm',
+              'sjinmmbipg', 'kmqkiihrmj', 'wmoqzxddkb', 'lnhkjhyhvw', 'wixbuuzygv', 'fsdrwikhge', 'sfsayjgzrh',
+              'pqdeutauqc', 'frqfsucgao', 'pdufsewrec', 'bfdopzvxbi', 'shnsajrsow', 'rvvpazsffd', 'pxcfrszlgi',
+              'itfsvvmslp', 'ayipraspbn', 'prhmixykhr', 'doniqevxeg', 'dvtpwatuja', 'jiavqbrkyk', 'ipkpxvwroe',
+              'syxobtuucp', 'syuxttuyhm', 'nwvsbmyndn', 'eqslzbqfea', 'ytddugrwph', 'vokrpfjpeb', 'bdshuoldwx',
+              'fmvvmcbdrw', 'bnuwxhfahw', 'gbnzicjyhz', 'txnmkabufs', 'gfdjzwnpyp', 'hweshqpfwe', 'dxgnpnowgk',
+              'xugmhbetrw', 'rktrpsdlci', 'nthpnwylxo', 'ihglzxzroo', 'ocgdbrgmtq', 'ruhtnngrqv', 'xljemofssi',
+              'zxacihctqp', 'ghnpsltzyn', 'lbigytrrtr', 'ndikguxzek', 'mdfndlljvt', 'lyoslorecs', 'oefukgnvel',
+              'zmxeiipnqb', 'cosghhimnd', 'alrtntfxtd', 'eywdmustbb', 'ooafcxxfrs', 'fqgypsunzr', 'hevcclcklc',
+              'uhrqlmlclw', 'ipvwtgdlre', 'wcssbghcpc', 'didzujjhtg', 'fjxovgmwnm', 'dmmvuaikkv', 'hitfycdavv',
+              'zyufpqvpyu', 'coujjnypba', 'temeqbmzxu', 'apedduehoy', 'iksxzpqxzi', 'kwfdyqofzw', 'aassnaulhq',
+              'eyguqfmgzh', 'yiykshcbaz', 'sngjsueuhs', 'okgelildpc', 'ztyuiqrhdk', 'tvhjcfnqtg', 'gfgcwxkbjd',
+              'lbfqksftuo', 'kowiwvrjht', 'dkuqbduxev', 'mwnibuujwz', 'sodvtfqbpf', 'hsbwhlolsn', 'qsjiypnjwi',
+              'blszgmxkvu', 'ystdtnetgj', 'rfwxcinshk', 'vnlzxqwthl', 'ljouzjaqqe', 'gahgyuwzbu', 'xxzefxwyku',
+              'xitgdpzbxv', 'sylnrepacf', 'igpvrfjdzc', 'nxnmkytwze', 'psesikjaxx', 'dvwpvqdflx', 'bjyaxvggle',
+              'dpmgoiwhuf', 'wadvzjhwtw', 'kcjvhgvhpt', 'eppyqpgewp', 'tyjpjpglgx', 'cekarydqba', 'dvkdfhrpph',
+              'cnpanmywno', 'ljauauuyka', 'hicjuubiau', 'cqhwesrciw', 'dnmowthjcj', 'lujvyveojc', 'wndursivcx',
+              'espkiocpxq', 'jsbpkpxwew', 'dsnxgrfdmd', 'hyjqolupxn', 'xdezcezszc', 'axfhbpkdlc', 'qqnlrngaft',
+              'coqwgzpbhx', 'ncmpqwmnzb', 'sznkemeqro', 'omphqltjdd', 'uoccaiathd', 'jzmzdispyo', 'pxjkzvqomp',
+              'udxqbhgvvx', 'dzkyxbbqkr', 'dtozwcapoa', 'qswlzfgcgj', 'tgawasvbbr', 'lmdyicksrv', 'fzvpbrzssi',
+              'dxfdovivlw', 'zzmgnglanj', 'vssmlqoiti', 'vajkicalux', 'ekvwecwltj', 'ylxwcwhjjd', 'keioymnobc',
+              'usqqvxcjmg', 'phjvutxpoi', 'nycmyuzpml', 'bwdmzwhdnw', 'fxuxxtryjn', 'orixbcfvdz', 'hefisnapds',
+              'fpevfidstw', 'halvwiltfs', 'dzojiwfvba', 'ojsxxkalat', 'esjdyghhog', 'ptbnewtvon', 'hcanfkwivl',
+              'yronlutbgm', 'llplvmcvbl', 'yxirnfyijn', 'nwvloufjty', 'rtpbawlmxr', 'aayfryxljh', 'zfrrixsimm',
+              'txmnoyiyte'}
training/losses.py
ADDED
@@ -0,0 +1,28 @@
+from typing import Any
+
+from pytorch_toolbelt.losses import BinaryFocalLoss
+from torch import nn
+from torch.nn.modules.loss import BCEWithLogitsLoss
+
+
+class WeightedLosses(nn.Module):
+    def __init__(self, losses, weights):
+        super().__init__()
+        self.losses = losses
+        self.weights = weights
+
+    def forward(self, *input: Any, **kwargs: Any):
+        cum_loss = 0
+        for loss, w in zip(self.losses, self.weights):
+            cum_loss += w * loss.forward(*input, **kwargs)
+        return cum_loss
+
+
+class BinaryCrossentropy(BCEWithLogitsLoss):
+    pass
+
+
+class FocalLoss(BinaryFocalLoss):
+    def __init__(self, alpha=None, gamma=3, ignore_index=None, reduction="mean", normalized=False,
+                 reduced_threshold=None):
+        super().__init__(alpha, gamma, ignore_index, reduction, normalized, reduced_threshold)
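Note (not part of the diff): WeightedLosses forwards its inputs to each wrapped criterion and sums the weighted results. A minimal sketch combining the two classes defined above, assuming pytorch-toolbelt is installed; the weights and tensor shapes are illustrative:

import torch
from training.losses import WeightedLosses, BinaryCrossentropy, FocalLoss

criterion = WeightedLosses([BinaryCrossentropy(), FocalLoss(gamma=3)], [0.7, 0.3])
logits = torch.randn(8, 1)                       # raw model outputs
targets = torch.randint(0, 2, (8, 1)).float()    # binary labels
loss = criterion(logits, targets)
print(loss.item())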
training/pipelines/__init__.py
ADDED
File without changes
training/pipelines/train_classifier.py
ADDED
@@ -0,0 +1,361 @@
import argparse
import json
import os
from collections import defaultdict

from sklearn.metrics import log_loss
from torch import topk

from training import losses
from training.datasets.classifier_dataset import DeepFakeClassifierDataset
from training.losses import WeightedLosses
from training.tools.config import load_config
from training.tools.utils import create_optimizer, AverageMeter
from training.transforms.albu import IsotropicResize
from training.zoo import classifiers

os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import cv2

cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)
import numpy as np
from albumentations import Compose, RandomBrightnessContrast, \
    HorizontalFlip, FancyPCA, HueSaturationValue, OneOf, ToGray, \
    ShiftScaleRotate, ImageCompression, PadIfNeeded, GaussNoise, GaussianBlur

from apex.parallel import DistributedDataParallel, convert_syncbn_model
from tensorboardX import SummaryWriter

from apex import amp

import torch
from torch.backends import cudnn
from torch.nn import DataParallel
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.distributed as dist

torch.backends.cudnn.benchmark = True


def create_train_transforms(size=300):
    return Compose([
        ImageCompression(quality_lower=60, quality_upper=100, p=0.5),
        GaussNoise(p=0.1),
        GaussianBlur(blur_limit=3, p=0.05),
        HorizontalFlip(),
        OneOf([
            IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
            IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_LINEAR),
            IsotropicResize(max_side=size, interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR),
        ], p=1),
        PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
        OneOf([RandomBrightnessContrast(), FancyPCA(), HueSaturationValue()], p=0.7),
        ToGray(p=0.2),
        ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=10, border_mode=cv2.BORDER_CONSTANT, p=0.5),
    ]
    )


def create_val_transforms(size=300):
    return Compose([
        IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
        PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
    ])


def main():
    parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
    arg = parser.add_argument
    arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
    arg('--workers', type=int, default=6, help='number of cpu threads to use')
    arg('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
    arg('--output-dir', type=str, default='weights/')
    arg('--resume', type=str, default='')
    arg('--fold', type=int, default=0)
    arg('--prefix', type=str, default='classifier_')
    arg('--data-dir', type=str, default="/mnt/sota/datasets/deepfake")
    arg('--folds-csv', type=str, default='folds.csv')
    arg('--crops-dir', type=str, default='crops')
    arg('--label-smoothing', type=float, default=0.01)
    arg('--logdir', type=str, default='logs')
    arg('--zero-score', action='store_true', default=False)
    arg('--from-zero', action='store_true', default=False)
    arg('--distributed', action='store_true', default=False)
    arg('--freeze-epochs', type=int, default=0)
    arg("--local_rank", default=0, type=int)
    arg("--seed", default=777, type=int)
    arg("--padding-part", default=3, type=int)
    arg("--opt-level", default='O1', type=str)
    arg("--test_every", type=int, default=1)
    arg("--no-oversample", action="store_true")
    arg("--no-hardcore", action="store_true")
    arg("--only-changed-frames", action="store_true")

    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    else:
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    cudnn.benchmark = True

    conf = load_config(args.config)
    model = classifiers.__dict__[conf['network']](encoder=conf['encoder'])

    model = model.cuda()
    if args.distributed:
        model = convert_syncbn_model(model)
    ohem = conf.get("ohem_samples", None)
    reduction = "mean"
    if ohem:
        reduction = "none"
    loss_fn = []
    weights = []
    for loss_name, weight in conf["losses"].items():
        loss_fn.append(losses.__dict__[loss_name](reduction=reduction).cuda())
        weights.append(weight)
    loss = WeightedLosses(loss_fn, weights)
    loss_functions = {"classifier_loss": loss}
    optimizer, scheduler = create_optimizer(conf['optimizer'], model)
    bce_best = 100
    start_epoch = 0
    batch_size = conf['optimizer']['batch_size']

    data_train = DeepFakeClassifierDataset(mode="train",
                                           oversample_real=not args.no_oversample,
                                           fold=args.fold,
                                           padding_part=args.padding_part,
                                           hardcore=not args.no_hardcore,
                                           crops_dir=args.crops_dir,
                                           data_path=args.data_dir,
                                           label_smoothing=args.label_smoothing,
                                           folds_csv=args.folds_csv,
                                           transforms=create_train_transforms(conf["size"]),
                                           normalize=conf.get("normalize", None))
    data_val = DeepFakeClassifierDataset(mode="val",
                                         fold=args.fold,
                                         padding_part=args.padding_part,
                                         crops_dir=args.crops_dir,
                                         data_path=args.data_dir,
                                         folds_csv=args.folds_csv,
                                         transforms=create_val_transforms(conf["size"]),
                                         normalize=conf.get("normalize", None))
    val_data_loader = DataLoader(data_val, batch_size=batch_size * 2, num_workers=args.workers, shuffle=False,
                                 pin_memory=False)
    os.makedirs(args.logdir, exist_ok=True)
    summary_writer = SummaryWriter(args.logdir + '/' + conf.get("prefix", args.prefix) + conf['encoder'] + "_" + str(args.fold))
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            state_dict = checkpoint['state_dict']
            state_dict = {k[7:]: w for k, w in state_dict.items()}
            model.load_state_dict(state_dict, strict=False)
            if not args.from_zero:
                start_epoch = checkpoint['epoch']
                if not args.zero_score:
                    bce_best = checkpoint.get('bce_best', 0)
            print("=> loaded checkpoint '{}' (epoch {}, bce_best {})"
                  .format(args.resume, checkpoint['epoch'], checkpoint['bce_best']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.from_zero:
        start_epoch = 0
    current_epoch = start_epoch

    if conf['fp16']:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale='dynamic')

    snapshot_name = "{}{}_{}_{}".format(conf.get("prefix", args.prefix), conf['network'], conf['encoder'], args.fold)

    if args.distributed:
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = DataParallel(model).cuda()
    data_val.reset(1, args.seed)
    max_epochs = conf['optimizer']['schedule']['epochs']
    for epoch in range(start_epoch, max_epochs):
        data_train.reset(epoch, args.seed)
        train_sampler = None
        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(data_train)
            train_sampler.set_epoch(epoch)
        if epoch < args.freeze_epochs:
            print("Freezing encoder!!!")
            model.module.encoder.eval()
            for p in model.module.encoder.parameters():
                p.requires_grad = False
        else:
            model.module.encoder.train()
            for p in model.module.encoder.parameters():
                p.requires_grad = True

        train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=args.workers,
                                       shuffle=train_sampler is None, sampler=train_sampler, pin_memory=False,
                                       drop_last=True)

        train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
                    args.local_rank, args.only_changed_frames)
        model = model.eval()

        if args.local_rank == 0:
            torch.save({
                'epoch': current_epoch + 1,
                'state_dict': model.state_dict(),
                'bce_best': bce_best,
            }, args.output_dir + '/' + snapshot_name + "_last")
            torch.save({
                'epoch': current_epoch + 1,
                'state_dict': model.state_dict(),
                'bce_best': bce_best,
            }, args.output_dir + snapshot_name + "_{}".format(current_epoch))
            if (epoch + 1) % args.test_every == 0:
                bce_best = evaluate_val(args, val_data_loader, bce_best, model,
                                        snapshot_name=snapshot_name,
                                        current_epoch=current_epoch,
                                        summary_writer=summary_writer)
        current_epoch += 1


def evaluate_val(args, data_val, bce_best, model, snapshot_name, current_epoch, summary_writer):
    print("Test phase")
    model = model.eval()

    bce, probs, targets = validate(model, data_loader=data_val)
    if args.local_rank == 0:
        summary_writer.add_scalar('val/bce', float(bce), global_step=current_epoch)
        if bce < bce_best:
            print("Epoch {} improved from {} to {}".format(current_epoch, bce_best, bce))
            if args.output_dir is not None:
                torch.save({
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'bce_best': bce,
                }, args.output_dir + snapshot_name + "_best_dice")
            bce_best = bce
            with open("predictions_{}.json".format(args.fold), "w") as f:
                json.dump({"probs": probs, "targets": targets}, f)
        torch.save({
            'epoch': current_epoch + 1,
            'state_dict': model.state_dict(),
            'bce_best': bce_best,
        }, args.output_dir + snapshot_name + "_last")
        print("Epoch: {} bce: {}, bce_best: {}".format(current_epoch, bce, bce_best))
    return bce_best


def validate(net, data_loader, prefix=""):
    probs = defaultdict(list)
    targets = defaultdict(list)

    with torch.no_grad():
        for sample in tqdm(data_loader):
            imgs = sample["image"].cuda()
            img_names = sample["img_name"]
            labels = sample["labels"].cuda().float()
            out = net(imgs)
            labels = labels.cpu().numpy()
            preds = torch.sigmoid(out).cpu().numpy()
            for i in range(out.shape[0]):
                video, img_id = img_names[i].split("/")
                probs[video].append(preds[i].tolist())
                targets[video].append(labels[i].tolist())
    data_x = []
    data_y = []
    for vid, score in probs.items():
        score = np.array(score)
        lbl = targets[vid]

        score = np.mean(score)
        lbl = np.mean(lbl)
        data_x.append(score)
        data_y.append(lbl)
    y = np.array(data_y)
    x = np.array(data_x)
    fake_idx = y > 0.1
    real_idx = y < 0.1
    fake_loss = log_loss(y[fake_idx], x[fake_idx], labels=[0, 1])
    real_loss = log_loss(y[real_idx], x[real_idx], labels=[0, 1])
    print("{}fake_loss".format(prefix), fake_loss)
    print("{}real_loss".format(prefix), real_loss)

    return (fake_loss + real_loss) / 2, probs, targets


def train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
                local_rank, only_valid):
    losses = AverageMeter()
    fake_losses = AverageMeter()
    real_losses = AverageMeter()
    max_iters = conf["batches_per_epoch"]
    print("training epoch {}".format(current_epoch))
    model.train()
    pbar = tqdm(enumerate(train_data_loader), total=max_iters, desc="Epoch {}".format(current_epoch), ncols=0)
    if conf["optimizer"]["schedule"]["mode"] == "epoch":
        scheduler.step(current_epoch)
    for i, sample in pbar:
        imgs = sample["image"].cuda()
        labels = sample["labels"].cuda().float()
        out_labels = model(imgs)
        if only_valid:
            valid_idx = sample["valid"].cuda().float() > 0
            out_labels = out_labels[valid_idx]
            labels = labels[valid_idx]
            if labels.size(0) == 0:
                continue

        fake_loss = 0
        real_loss = 0
        fake_idx = labels > 0.5
        real_idx = labels <= 0.5

        ohem = conf.get("ohem_samples", None)
        if torch.sum(fake_idx * 1) > 0:
            fake_loss = loss_functions["classifier_loss"](out_labels[fake_idx], labels[fake_idx])
        if torch.sum(real_idx * 1) > 0:
            real_loss = loss_functions["classifier_loss"](out_labels[real_idx], labels[real_idx])
        if ohem:
            fake_loss = topk(fake_loss, k=min(ohem, fake_loss.size(0)), sorted=False)[0].mean()
            real_loss = topk(real_loss, k=min(ohem, real_loss.size(0)), sorted=False)[0].mean()

        loss = (fake_loss + real_loss) / 2
        losses.update(loss.item(), imgs.size(0))
        fake_losses.update(0 if fake_loss == 0 else fake_loss.item(), imgs.size(0))
        real_losses.update(0 if real_loss == 0 else real_loss.item(), imgs.size(0))

        optimizer.zero_grad()
        pbar.set_postfix({"lr": float(scheduler.get_lr()[-1]), "epoch": current_epoch, "loss": losses.avg,
                          "fake_loss": fake_losses.avg, "real_loss": real_losses.avg})

        if conf['fp16']:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
        optimizer.step()
        torch.cuda.synchronize()
        if conf["optimizer"]["schedule"]["mode"] in ("step", "poly"):
            scheduler.step(i + current_epoch * max_iters)
        if i == max_iters - 1:
            break
    pbar.close()
    if local_rank == 0:
        for idx, param_group in enumerate(optimizer.param_groups):
            lr = param_group['lr']
            summary_writer.add_scalar('group{}/lr'.format(idx), float(lr), global_step=current_epoch)
        summary_writer.add_scalar('train/loss', float(losses.avg), global_step=current_epoch)


if __name__ == '__main__':
    main()
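Note: a small sketch (not in the commit) of how the transform factories above are applied to a single face crop. The image path and the 380px size are illustrative, and importing the module assumes the apex/timm dependencies it pulls in are installed.

import cv2
from training.pipelines.train_classifier import create_train_transforms, create_val_transforms

img = cv2.imread("crops/example_face.png")      # hypothetical crop produced by the preprocessing scripts
train_aug = create_train_transforms(size=380)   # heavy augmentation used during training
val_aug = create_val_transforms(size=380)       # deterministic resize/pad used for validation
augmented = train_aug(image=img)["image"]
resized = val_aug(image=img)["image"]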
training/tools/__init__.py
ADDED
File without changes
training/tools/config.py
ADDED
@@ -0,0 +1,43 @@
import json

DEFAULTS = {
    "network": "dpn",
    "encoder": "dpn92",
    "model_params": {},
    "optimizer": {
        "batch_size": 32,
        "type": "SGD",  # supported: SGD, Adam
        "momentum": 0.9,
        "weight_decay": 0,
        "clip": 1.,
        "learning_rate": 0.1,
        "classifier_lr": -1,
        "nesterov": True,
        "schedule": {
            "type": "constant",  # supported: constant, step, multistep, exponential, linear, poly
            "mode": "epoch",  # supported: epoch, step
            "epochs": 10,
            "params": {}
        }
    },
    "normalize": {
        "mean": [0.485, 0.456, 0.406],
        "std": [0.229, 0.224, 0.225]
    }
}


def _merge(src, dst):
    for k, v in src.items():
        if k in dst:
            if isinstance(v, dict):
                _merge(src[k], dst[k])
        else:
            dst[k] = v


def load_config(config_file, defaults=DEFAULTS):
    with open(config_file, "r") as fd:
        config = json.load(fd)
    _merge(defaults, config)
    return config
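Note: an illustrative training config (not part of the commit) in the JSON shape that load_config expects; every key is one the training pipeline actually reads, but the values here are made up. Keys missing from the file are filled in from DEFAULTS by _merge.

{
  "network": "DeepFakeClassifier",
  "encoder": "tf_efficientnet_b7",
  "size": 380,
  "fp16": true,
  "batches_per_epoch": 2500,
  "losses": {"BinaryCrossentropy": 1.0},
  "optimizer": {
    "batch_size": 12,
    "type": "SGD",
    "learning_rate": 0.01,
    "schedule": {"type": "poly", "mode": "step", "epochs": 20, "params": {"max_iter": 50000}}
  }
}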
training/tools/schedulers.py
ADDED
@@ -0,0 +1,46 @@
from bisect import bisect_right

from torch.optim.lr_scheduler import _LRScheduler


class LRStepScheduler(_LRScheduler):
    def __init__(self, optimizer, steps, last_epoch=-1):
        self.lr_steps = steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        pos = max(bisect_right([x for x, y in self.lr_steps], self.last_epoch) - 1, 0)
        return [self.lr_steps[pos][1] if self.lr_steps[pos][0] <= self.last_epoch else base_lr for base_lr in self.base_lrs]


class PolyLR(_LRScheduler):
    """Sets the learning rate of each parameter group according to poly learning rate policy
    """
    def __init__(self, optimizer, max_iter=90000, power=0.9, last_epoch=-1):
        self.max_iter = max_iter
        self.power = power
        super(PolyLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        self.last_epoch = (self.last_epoch + 1) % self.max_iter
        return [base_lr * ((1 - float(self.last_epoch) / self.max_iter) ** (self.power)) for base_lr in self.base_lrs]


class ExponentialLRScheduler(_LRScheduler):
    """Decays the learning rate of each parameter group by gamma every epoch.
    When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self, optimizer, gamma, last_epoch=-1):
        self.gamma = gamma
        super(ExponentialLRScheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch <= 0:
            return self.base_lrs
        return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs]
training/tools/utils.py
ADDED
@@ -0,0 +1,121 @@
import cv2
from apex.optimizers import FusedAdam, FusedSGD
from timm.optim import AdamW
from torch import optim
from torch.optim import lr_scheduler
from torch.optim.rmsprop import RMSprop
from torch.optim.adamw import AdamW
from torch.optim.lr_scheduler import MultiStepLR, CyclicLR

from training.tools.schedulers import ExponentialLRScheduler, PolyLR, LRStepScheduler

cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def create_optimizer(optimizer_config, model, master_params=None):
    """Creates optimizer and schedule from configuration

    Parameters
    ----------
    optimizer_config : dict
        Dictionary containing the configuration options for the optimizer.
    model : Model
        The network model.

    Returns
    -------
    optimizer : Optimizer
        The optimizer.
    scheduler : LRScheduler
        The learning rate scheduler.
    """
    if optimizer_config.get("classifier_lr", -1) != -1:
        # Separate classifier parameters from all others
        net_params = []
        classifier_params = []
        for k, v in model.named_parameters():
            if not v.requires_grad:
                continue
            if k.find("encoder") != -1:
                net_params.append(v)
            else:
                classifier_params.append(v)
        params = [
            {"params": net_params},
            {"params": classifier_params, "lr": optimizer_config["classifier_lr"]},
        ]
    else:
        if master_params:
            params = master_params
        else:
            params = model.parameters()

    if optimizer_config["type"] == "SGD":
        optimizer = optim.SGD(params,
                              lr=optimizer_config["learning_rate"],
                              momentum=optimizer_config["momentum"],
                              weight_decay=optimizer_config["weight_decay"],
                              nesterov=optimizer_config["nesterov"])
    elif optimizer_config["type"] == "FusedSGD":
        optimizer = FusedSGD(params,
                             lr=optimizer_config["learning_rate"],
                             momentum=optimizer_config["momentum"],
                             weight_decay=optimizer_config["weight_decay"],
                             nesterov=optimizer_config["nesterov"])
    elif optimizer_config["type"] == "Adam":
        optimizer = optim.Adam(params,
                               lr=optimizer_config["learning_rate"],
                               weight_decay=optimizer_config["weight_decay"])
    elif optimizer_config["type"] == "FusedAdam":
        optimizer = FusedAdam(params,
                              lr=optimizer_config["learning_rate"],
                              weight_decay=optimizer_config["weight_decay"])
    elif optimizer_config["type"] == "AdamW":
        optimizer = AdamW(params,
                          lr=optimizer_config["learning_rate"],
                          weight_decay=optimizer_config["weight_decay"])
    elif optimizer_config["type"] == "RmsProp":
        optimizer = RMSprop(params,
                            lr=optimizer_config["learning_rate"],
                            weight_decay=optimizer_config["weight_decay"])
    else:
        raise KeyError("unrecognized optimizer {}".format(optimizer_config["type"]))

    if optimizer_config["schedule"]["type"] == "step":
        scheduler = LRStepScheduler(optimizer, **optimizer_config["schedule"]["params"])
    elif optimizer_config["schedule"]["type"] == "clr":
        scheduler = CyclicLR(optimizer, **optimizer_config["schedule"]["params"])
    elif optimizer_config["schedule"]["type"] == "multistep":
        scheduler = MultiStepLR(optimizer, **optimizer_config["schedule"]["params"])
    elif optimizer_config["schedule"]["type"] == "exponential":
        scheduler = ExponentialLRScheduler(optimizer, **optimizer_config["schedule"]["params"])
    elif optimizer_config["schedule"]["type"] == "poly":
        scheduler = PolyLR(optimizer, **optimizer_config["schedule"]["params"])
    elif optimizer_config["schedule"]["type"] == "constant":
        scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 1.0)
    elif optimizer_config["schedule"]["type"] == "linear":
        def linear_lr(it):
            return it * optimizer_config["schedule"]["params"]["alpha"] + optimizer_config["schedule"]["params"]["beta"]

        scheduler = lr_scheduler.LambdaLR(optimizer, linear_lr)

    return optimizer, scheduler
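Note: a minimal sketch (not in the commit) of feeding the "optimizer" section of a config into create_optimizer above. The stand-in model and hyperparameter values are placeholders, and importing the module assumes apex is installed.

from torch import nn
from training.tools.utils import create_optimizer

model = nn.Linear(16, 1)  # stand-in for a real classifier
optimizer_config = {
    "type": "SGD", "learning_rate": 0.01, "momentum": 0.9,
    "weight_decay": 1e-4, "nesterov": True, "classifier_lr": -1,
    "schedule": {"type": "poly", "params": {"max_iter": 10000, "power": 0.9}},
}
# returns a plain SGD optimizer plus the PolyLR schedule defined in training/tools/schedulers.py
optimizer, scheduler = create_optimizer(optimizer_config, model)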
training/transforms/__init__.py
ADDED
File without changes
training/transforms/albu.py
ADDED
@@ -0,0 +1,99 @@
import random

import cv2
import numpy as np
from albumentations import DualTransform, ImageOnlyTransform
from albumentations.augmentations.functional import crop


def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC):
    h, w = img.shape[:2]
    if max(w, h) == size:
        return img
    if w > h:
        scale = size / w
        h = h * scale
        w = size
    else:
        scale = size / h
        w = w * scale
        h = size
    interpolation = interpolation_up if scale > 1 else interpolation_down
    resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
    return resized


class IsotropicResize(DualTransform):
    def __init__(self, max_side, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC,
                 always_apply=False, p=1):
        super(IsotropicResize, self).__init__(always_apply, p)
        self.max_side = max_side
        self.interpolation_down = interpolation_down
        self.interpolation_up = interpolation_up

    def apply(self, img, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC, **params):
        return isotropically_resize_image(img, size=self.max_side, interpolation_down=interpolation_down,
                                          interpolation_up=interpolation_up)

    def apply_to_mask(self, img, **params):
        return self.apply(img, interpolation_down=cv2.INTER_NEAREST, interpolation_up=cv2.INTER_NEAREST, **params)

    def get_transform_init_args_names(self):
        return ("max_side", "interpolation_down", "interpolation_up")


class Resize4xAndBack(ImageOnlyTransform):
    def __init__(self, always_apply=False, p=0.5):
        super(Resize4xAndBack, self).__init__(always_apply, p)

    def apply(self, img, **params):
        h, w = img.shape[:2]
        scale = random.choice([2, 4])
        img = cv2.resize(img, (w // scale, h // scale), interpolation=cv2.INTER_AREA)
        img = cv2.resize(img, (w, h),
                         interpolation=random.choice([cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_NEAREST]))
        return img


class RandomSizedCropNonEmptyMaskIfExists(DualTransform):

    def __init__(self, min_max_height, w2h_ratio=[0.7, 1.3], always_apply=False, p=0.5):
        super(RandomSizedCropNonEmptyMaskIfExists, self).__init__(always_apply, p)

        self.min_max_height = min_max_height
        self.w2h_ratio = w2h_ratio

    def apply(self, img, x_min=0, x_max=0, y_min=0, y_max=0, **params):
        cropped = crop(img, x_min, y_min, x_max, y_max)
        return cropped

    @property
    def targets_as_params(self):
        return ["mask"]

    def get_params_dependent_on_targets(self, params):
        mask = params["mask"]
        mask_height, mask_width = mask.shape[:2]
        crop_height = int(mask_height * random.uniform(self.min_max_height[0], self.min_max_height[1]))
        w2h_ratio = random.uniform(*self.w2h_ratio)
        crop_width = min(int(crop_height * w2h_ratio), mask_width - 1)
        if mask.sum() == 0:
            x_min = random.randint(0, mask_width - crop_width + 1)
            y_min = random.randint(0, mask_height - crop_height + 1)
        else:
            mask = mask.sum(axis=-1) if mask.ndim == 3 else mask
            non_zero_yx = np.argwhere(mask)
            y, x = random.choice(non_zero_yx)
            x_min = x - random.randint(0, crop_width - 1)
            y_min = y - random.randint(0, crop_height - 1)
            x_min = np.clip(x_min, 0, mask_width - crop_width)
            y_min = np.clip(y_min, 0, mask_height - crop_height)

        x_max = x_min + crop_height
        y_max = y_min + crop_width
        y_max = min(mask_height, y_max)
        x_max = min(mask_width, x_max)
        return {"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max}

    def get_transform_init_args_names(self):
        return "min_max_height", "height", "width", "w2h_ratio"
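Note: a short sketch (not in the commit) of using IsotropicResize inside an albumentations pipeline, mirroring how the training transforms combine it with PadIfNeeded; the 380px size and the dummy crop are illustrative.

import cv2
import numpy as np
from albumentations import Compose, PadIfNeeded
from training.transforms.albu import IsotropicResize

aug = Compose([
    IsotropicResize(max_side=380),  # keep the aspect ratio, longest side becomes 380
    PadIfNeeded(min_height=380, min_width=380, border_mode=cv2.BORDER_CONSTANT),
])
face = (np.random.rand(250, 180, 3) * 255).astype(np.uint8)  # dummy face crop
out = aug(image=face)["image"]                               # padded to (380, 380, 3)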
training/zoo/__init__.py
ADDED
File without changes
training/zoo/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (161 Bytes)
training/zoo/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (165 Bytes)
training/zoo/__pycache__/classifiers.cpython-37.pyc
ADDED
Binary file (5.69 kB)
training/zoo/__pycache__/classifiers.cpython-39.pyc
ADDED
Binary file (5.7 kB)
training/zoo/classifiers.py
ADDED
@@ -0,0 +1,172 @@
from functools import partial

import numpy as np
import torch
from timm.models.efficientnet import tf_efficientnet_b4_ns, tf_efficientnet_b3_ns, \
    tf_efficientnet_b5_ns, tf_efficientnet_b2_ns, tf_efficientnet_b6_ns, tf_efficientnet_b7
from torch import nn
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.pooling import AdaptiveAvgPool2d

encoder_params = {
    "tf_efficientnet_b3_ns": {
        "features": 1536,
        "init_op": partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2)
    },
    "tf_efficientnet_b2_ns": {
        "features": 1408,
        "init_op": partial(tf_efficientnet_b2_ns, pretrained=False, drop_path_rate=0.2)
    },
    "tf_efficientnet_b4_ns": {
        "features": 1792,
        "init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.5)
    },
    "tf_efficientnet_b5_ns": {
        "features": 2048,
        "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2)
    },
    "tf_efficientnet_b4_ns_03d": {
        "features": 1792,
        "init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.3)
    },
    "tf_efficientnet_b5_ns_03d": {
        "features": 2048,
        "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.3)
    },
    "tf_efficientnet_b5_ns_04d": {
        "features": 2048,
        "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.4)
    },
    "tf_efficientnet_b6_ns": {
        "features": 2304,
        "init_op": partial(tf_efficientnet_b6_ns, pretrained=True, drop_path_rate=0.2)
    },
    "tf_efficientnet_b7": {
        "features": 2560,
        "init_op": partial(tf_efficientnet_b7, pretrained=True, drop_path_rate=0.2)
    },
    "tf_efficientnet_b6_ns_04d": {
        "features": 2304,
        "init_op": partial(tf_efficientnet_b6_ns, pretrained=True, drop_path_rate=0.4)
    },
}


def setup_srm_weights(input_channels: int = 3) -> torch.Tensor:
    """Creates the SRM kernels for noise analysis."""
    # note: values taken from Zhou et al., "Learning Rich Features for Image Manipulation Detection", CVPR2018
    srm_kernel = torch.from_numpy(np.array([
        [  # srm 1/2 horiz
            [0., 0., 0., 0., 0.],  # noqa: E241,E201
            [0., 0., 0., 0., 0.],  # noqa: E241,E201
            [0., 1., -2., 1., 0.],  # noqa: E241,E201
            [0., 0., 0., 0., 0.],  # noqa: E241,E201
            [0., 0., 0., 0., 0.],  # noqa: E241,E201
        ], [  # srm 1/4
            [0., 0., 0., 0., 0.],  # noqa: E241,E201
            [0., -1., 2., -1., 0.],  # noqa: E241,E201
            [0., 2., -4., 2., 0.],  # noqa: E241,E201
            [0., -1., 2., -1., 0.],  # noqa: E241,E201
            [0., 0., 0., 0., 0.],  # noqa: E241,E201
        ], [  # srm 1/12
            [-1., 2., -2., 2., -1.],  # noqa: E241,E201
            [2., -6., 8., -6., 2.],  # noqa: E241,E201
            [-2., 8., -12., 8., -2.],  # noqa: E241,E201
            [2., -6., 8., -6., 2.],  # noqa: E241,E201
            [-1., 2., -2., 2., -1.],  # noqa: E241,E201
        ]
    ])).float()
    srm_kernel[0] /= 2
    srm_kernel[1] /= 4
    srm_kernel[2] /= 12
    return srm_kernel.view(3, 1, 5, 5).repeat(1, input_channels, 1, 1)


def setup_srm_layer(input_channels: int = 3) -> torch.nn.Module:
    """Creates a SRM convolution layer for noise analysis."""
    weights = setup_srm_weights(input_channels)
    conv = torch.nn.Conv2d(input_channels, out_channels=3, kernel_size=5, stride=1, padding=2, bias=False)
    with torch.no_grad():
        conv.weight = torch.nn.Parameter(weights, requires_grad=False)
    return conv


class DeepFakeClassifierSRM(nn.Module):
    def __init__(self, encoder, dropout_rate=0.5) -> None:
        super().__init__()
        self.encoder = encoder_params[encoder]["init_op"]()
        self.avg_pool = AdaptiveAvgPool2d((1, 1))
        self.srm_conv = setup_srm_layer(3)
        self.dropout = Dropout(dropout_rate)
        self.fc = Linear(encoder_params[encoder]["features"], 1)

    def forward(self, x):
        noise = self.srm_conv(x)
        x = self.encoder.forward_features(noise)
        x = self.avg_pool(x).flatten(1)
        x = self.dropout(x)
        x = self.fc(x)
        return x


class GlobalWeightedAvgPool2d(nn.Module):
    """
    Global Weighted Average Pooling from paper "Global Weighted Average
    Pooling Bridges Pixel-level Localization and Image-level Classification"
    """

    def __init__(self, features: int, flatten=False):
        super().__init__()
        self.conv = nn.Conv2d(features, 1, kernel_size=1, bias=True)
        self.flatten = flatten

    def fscore(self, x):
        m = self.conv(x)
        m = m.sigmoid().exp()
        return m

    def norm(self, x: torch.Tensor):
        return x / x.sum(dim=[2, 3], keepdim=True)

    def forward(self, x):
        input_x = x
        x = self.fscore(x)
        x = self.norm(x)
        x = x * input_x
        x = x.sum(dim=[2, 3], keepdim=not self.flatten)
        return x


class DeepFakeClassifier(nn.Module):
    def __init__(self, encoder, dropout_rate=0.0) -> None:
        super().__init__()
        self.encoder = encoder_params[encoder]["init_op"]()
        self.avg_pool = AdaptiveAvgPool2d((1, 1))
        self.dropout = Dropout(dropout_rate)
        self.fc = Linear(encoder_params[encoder]["features"], 1)

    def forward(self, x):
        x = self.encoder.forward_features(x)
        x = self.avg_pool(x).flatten(1)
        x = self.dropout(x)
        x = self.fc(x)
        return x


class DeepFakeClassifierGWAP(nn.Module):
    def __init__(self, encoder, dropout_rate=0.5) -> None:
        super().__init__()
        self.encoder = encoder_params[encoder]["init_op"]()
        self.avg_pool = GlobalWeightedAvgPool2d(encoder_params[encoder]["features"])
        self.dropout = Dropout(dropout_rate)
        self.fc = Linear(encoder_params[encoder]["features"], 1)

    def forward(self, x):
        x = self.encoder.forward_features(x)
        x = self.avg_pool(x).flatten(1)
        x = self.dropout(x)
        x = self.fc(x)
        return x
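Note: a brief sketch (not in the commit) of what setup_srm_layer produces: a fixed, non-trainable 5x5 convolution that extracts noise residuals, which DeepFakeClassifierSRM feeds to the encoder in place of the raw RGB image. The dummy input tensor is illustrative.

import torch
from training.zoo.classifiers import setup_srm_layer

srm = setup_srm_layer(input_channels=3)
img = torch.rand(1, 3, 224, 224)                   # dummy RGB batch
with torch.no_grad():
    noise = srm(img)                               # (1, 3, 224, 224) noise-residual maps
print(noise.shape, srm.weight.requires_grad)       # torch.Size([1, 3, 224, 224]) False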
training/zoo/unet.py
ADDED
@@ -0,0 +1,151 @@
from functools import partial

import torch
from timm.models.efficientnet import tf_efficientnet_b3_ns, tf_efficientnet_b5_ns
from torch import nn
from torch.nn import Dropout2d, Conv2d
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.pooling import AdaptiveAvgPool2d
from torch.nn.modules.upsampling import UpsamplingBilinear2d

encoder_params = {
    "tf_efficientnet_b3_ns": {
        "features": 1536,
        "filters": [40, 32, 48, 136, 1536],
        "decoder_filters": [64, 128, 256, 256],
        "init_op": partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2)
    },
    "tf_efficientnet_b5_ns": {
        "features": 2048,
        "filters": [48, 40, 64, 176, 2048],
        "decoder_filters": [64, 128, 256, 256],
        "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2)
    },
}


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.layer(x)


class ConcatBottleneck(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, dec, enc):
        x = torch.cat([dec, enc], dim=1)
        return self.seq(x)


class Decoder(nn.Module):
    def __init__(self, decoder_filters, filters, upsample_filters=None,
                 decoder_block=DecoderBlock, bottleneck=ConcatBottleneck, dropout=0):
        super().__init__()
        self.decoder_filters = decoder_filters
        self.filters = filters
        self.decoder_block = decoder_block
        self.decoder_stages = nn.ModuleList([self._get_decoder(idx) for idx in range(0, len(decoder_filters))])
        self.bottlenecks = nn.ModuleList([bottleneck(self.filters[-i - 2] + f, f)
                                          for i, f in enumerate(reversed(decoder_filters))])
        self.dropout = Dropout2d(dropout) if dropout > 0 else None
        self.last_block = None
        if upsample_filters:
            self.last_block = decoder_block(decoder_filters[0], out_channels=upsample_filters)
        else:
            self.last_block = UpsamplingBilinear2d(scale_factor=2)

    def forward(self, encoder_results: list):
        x = encoder_results[0]
        bottlenecks = self.bottlenecks
        for idx, bottleneck in enumerate(bottlenecks):
            rev_idx = - (idx + 1)
            x = self.decoder_stages[rev_idx](x)
            x = bottleneck(x, encoder_results[-rev_idx])
        if self.last_block:
            x = self.last_block(x)
        if self.dropout:
            x = self.dropout(x)
        return x

    def _get_decoder(self, layer):
        idx = layer + 1
        if idx == len(self.decoder_filters):
            in_channels = self.filters[idx]
        else:
            in_channels = self.decoder_filters[idx]
        return self.decoder_block(in_channels, self.decoder_filters[max(layer, 0)])


def _initialize_weights(module):
    for m in module.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear):
            m.weight.data = nn.init.kaiming_normal_(m.weight.data)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()


class EfficientUnetClassifier(nn.Module):
    def __init__(self, encoder, dropout_rate=0.5) -> None:
        super().__init__()
        self.decoder = Decoder(decoder_filters=encoder_params[encoder]["decoder_filters"],
                               filters=encoder_params[encoder]["filters"])
        self.avg_pool = AdaptiveAvgPool2d((1, 1))
        self.dropout = Dropout(dropout_rate)
        self.fc = Linear(encoder_params[encoder]["features"], 1)
        self.final = Conv2d(encoder_params[encoder]["decoder_filters"][0], out_channels=1, kernel_size=1, bias=False)
        _initialize_weights(self)
        self.encoder = encoder_params[encoder]["init_op"]()

    def get_encoder_features(self, x):
        encoder_results = []
        x = self.encoder.conv_stem(x)
        x = self.encoder.bn1(x)
        x = self.encoder.act1(x)
        encoder_results.append(x)
        x = self.encoder.blocks[:2](x)
        encoder_results.append(x)
        x = self.encoder.blocks[2:3](x)
        encoder_results.append(x)
        x = self.encoder.blocks[3:5](x)
        encoder_results.append(x)
        x = self.encoder.blocks[5:](x)
        x = self.encoder.conv_head(x)
        x = self.encoder.bn2(x)
        x = self.encoder.act2(x)
        encoder_results.append(x)
        encoder_results = list(reversed(encoder_results))
        return encoder_results

    def forward(self, x):
        encoder_results = self.get_encoder_features(x)
        seg = self.final(self.decoder(encoder_results))
        x = encoder_results[0]
        x = self.avg_pool(x).flatten(1)
        x = self.dropout(x)
        x = self.fc(x)
        return x, seg


if __name__ == '__main__':
    model = EfficientUnetClassifier("tf_efficientnet_b5_ns")
    model.eval()
    with torch.no_grad():
        input = torch.rand(4, 3, 224, 224)
        print(model(input))
utils.py
ADDED
@@ -0,0 +1,354 @@
import os

import cv2
import numpy as np
import torch
from PIL import Image
from albumentations.augmentations.functional import image_compression
from facenet_pytorch.models.mtcnn import MTCNN
from concurrent.futures import ThreadPoolExecutor

from torchvision.transforms import Normalize

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
normalize_transform = Normalize(mean, std)


class VideoReader:
    """Helper class for reading one or more frames from a video file."""

    def __init__(self, verbose=True, insets=(0, 0)):
        """Creates a new VideoReader.

        Arguments:
            verbose: whether to print warnings and error messages
            insets: amount to inset the image by, as a percentage of
                (width, height). This lets you "zoom in" to an image
                to remove unimportant content around the borders.
                Useful for face detection, which may not work if the
                faces are too small.
        """
        self.verbose = verbose
        self.insets = insets

    def read_frames(self, path, num_frames, jitter=0, seed=None):
        """Reads frames that are always evenly spaced throughout the video.

        Arguments:
            path: the video file
            num_frames: how many frames to read, -1 means the entire video
                (warning: this will take up a lot of memory!)
            jitter: if not 0, adds small random offsets to the frame indices;
                this is useful so we don't always land on even or odd frames
            seed: random seed for jittering; if you set this to a fixed value,
                you probably want to set it only on the first video
        """
        assert num_frames > 0

        capture = cv2.VideoCapture(path)
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        if frame_count <= 0: return None

        frame_idxs = np.linspace(0, frame_count - 1, num_frames, endpoint=True, dtype=np.int)
        if jitter > 0:
            np.random.seed(seed)
            jitter_offsets = np.random.randint(-jitter, jitter, len(frame_idxs))
            frame_idxs = np.clip(frame_idxs + jitter_offsets, 0, frame_count - 1)

        result = self._read_frames_at_indices(path, capture, frame_idxs)
        capture.release()
        return result

    def read_random_frames(self, path, num_frames, seed=None):
        """Picks the frame indices at random.

        Arguments:
            path: the video file
            num_frames: how many frames to read, -1 means the entire video
                (warning: this will take up a lot of memory!)
        """
        assert num_frames > 0
        np.random.seed(seed)

        capture = cv2.VideoCapture(path)
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        if frame_count <= 0: return None

        frame_idxs = sorted(np.random.choice(np.arange(0, frame_count), num_frames))
        result = self._read_frames_at_indices(path, capture, frame_idxs)

        capture.release()
        return result

    def read_frames_at_indices(self, path, frame_idxs):
        """Reads frames from a video and puts them into a NumPy array.

        Arguments:
            path: the video file
            frame_idxs: a list of frame indices. Important: should be
                sorted from low-to-high! If an index appears multiple
                times, the frame is still read only once.

        Returns:
            - a NumPy array of shape (num_frames, height, width, 3)
            - a list of the frame indices that were read

        Reading stops if loading a frame fails, in which case the first
        dimension returned may actually be less than num_frames.

        Returns None if an exception is thrown for any reason, or if no
        frames were read.
        """
        assert len(frame_idxs) > 0
        capture = cv2.VideoCapture(path)
        result = self._read_frames_at_indices(path, capture, frame_idxs)
        capture.release()
        return result

    def _read_frames_at_indices(self, path, capture, frame_idxs):
        try:
            frames = []
            idxs_read = []
            for frame_idx in range(frame_idxs[0], frame_idxs[-1] + 1):
                # Get the next frame, but don't decode if we're not using it.
                ret = capture.grab()
                if not ret:
                    if self.verbose:
                        print("Error grabbing frame %d from movie %s" % (frame_idx, path))
                    break

                # Need to look at this frame?
                current = len(idxs_read)
                if frame_idx == frame_idxs[current]:
                    ret, frame = capture.retrieve()
                    if not ret or frame is None:
                        if self.verbose:
                            print("Error retrieving frame %d from movie %s" % (frame_idx, path))
                        break

                    frame = self._postprocess_frame(frame)
                    frames.append(frame)
                    idxs_read.append(frame_idx)

            if len(frames) > 0:
                return np.stack(frames), idxs_read
            if self.verbose:
                print("No frames read from movie %s" % path)
            return None
        except:
            if self.verbose:
                print("Exception while reading movie %s" % path)
            return None

    def read_middle_frame(self, path):
        """Reads the frame from the middle of the video."""
        capture = cv2.VideoCapture(path)
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        result = self._read_frame_at_index(path, capture, frame_count // 2)
        capture.release()
        return result

    def read_frame_at_index(self, path, frame_idx):
        """Reads a single frame from a video.

        If you just want to read a single frame from the video, this is more
        efficient than scanning through the video to find the frame. However,
        for reading multiple frames it's not efficient.

        My guess is that a "streaming" approach is more efficient than a
        "random access" approach because, unless you happen to grab a keyframe,
        the decoder still needs to read all the previous frames in order to
        reconstruct the one you're asking for.

        Returns a NumPy array of shape (1, H, W, 3) and the index of the frame,
        or None if reading failed.
        """
        capture = cv2.VideoCapture(path)
        result = self._read_frame_at_index(path, capture, frame_idx)
        capture.release()
        return result

    def _read_frame_at_index(self, path, capture, frame_idx):
        capture.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = capture.read()
        if not ret or frame is None:
            if self.verbose:
                print("Error retrieving frame %d from movie %s" % (frame_idx, path))
            return None
        else:
            frame = self._postprocess_frame(frame)
            return np.expand_dims(frame, axis=0), [frame_idx]

    def _postprocess_frame(self, frame):
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        if self.insets[0] > 0:
            W = frame.shape[1]
            p = int(W * self.insets[0])
            frame = frame[:, p:-p, :]

        if self.insets[1] > 0:
            H = frame.shape[1]
            q = int(H * self.insets[1])
            frame = frame[q:-q, :, :]

        return frame


class FaceExtractor:
    def __init__(self, video_read_fn):
        self.video_read_fn = video_read_fn
        self.detector = MTCNN(margin=0, thresholds=[0.7, 0.8, 0.8], device="cpu")

    def process_videos(self, video, video_idxs):
        videos_read = []
        frames_read = []
        frames = []
        results = []
        for video_idx in video_idxs:
            # Read the full-size frames from this video.
            result = self.video_read_fn(video)
            # Error? Then skip this video.
            if result is None: continue

            videos_read.append(video_idx)

            # Keep track of the original frames (need them later).
            my_frames, my_idxs = result

            frames.append(my_frames)
            frames_read.append(my_idxs)
            for i, frame in enumerate(my_frames):
                h, w = frame.shape[:2]
                img = Image.fromarray(frame.astype(np.uint8))
                img = img.resize(size=[s // 2 for s in img.size])

                batch_boxes, probs = self.detector.detect(img, landmarks=False)

                faces = []
                scores = []
                if batch_boxes is None:
                    continue
                for bbox, score in zip(batch_boxes, probs):
                    if bbox is not None:
                        xmin, ymin, xmax, ymax = [int(b * 2) for b in bbox]
                        w = xmax - xmin
                        h = ymax - ymin
                        p_h = h // 3
                        p_w = w // 3
                        crop = frame[max(ymin - p_h, 0):ymax + p_h, max(xmin - p_w, 0):xmax + p_w]
                        faces.append(crop)
                        scores.append(score)

                frame_dict = {"video_idx": video_idx,
                              "frame_idx": my_idxs[i],
                              "frame_w": w,
                              "frame_h": h,
                              "faces": faces,
                              "scores": scores}
                results.append(frame_dict)

        return results

    def process_video(self, video):
        """Convenience method for doing face extraction on a single video."""
        return self.process_videos(video, [0])


def confident_strategy(pred, t=0.8):
    pred = np.array(pred)
    sz = len(pred)
    fakes = np.count_nonzero(pred > t)
    # 11 frames are detected as fakes with high probability
    if fakes > sz // 2.5 and fakes > 11:
        return np.mean(pred[pred > t])
    elif np.count_nonzero(pred < 0.2) > 0.9 * sz:
        return np.mean(pred[pred < 0.2])
    else:
        return np.mean(pred)


strategy = confident_strategy


def put_to_center(img, input_size):
    img = img[:input_size, :input_size]
    image = np.zeros((input_size, input_size, 3), dtype=np.uint8)
    start_w = (input_size - img.shape[1]) // 2
    start_h = (input_size - img.shape[0]) // 2
    image[start_h:start_h + img.shape[0], start_w: start_w + img.shape[1], :] = img
    return image


def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC):
    h, w = img.shape[:2]
    if max(w, h) == size:
        return img
    if w > h:
        scale = size / w
        h = h * scale
        w = size
    else:
        scale = size / h
        w = w * scale
        h = size
    interpolation = interpolation_up if scale > 1 else interpolation_down
    resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
    return resized


def predict_on_video(face_extractor, video, batch_size, input_size, models, strategy=np.mean,
                     apply_compression=False):
    batch_size *= 4
    try:
        faces = face_extractor.process_video(video)
        if len(faces) > 0:
            x = np.zeros((batch_size, input_size, input_size, 3), dtype=np.uint8)
            n = 0
            for frame_data in faces:
                for face in frame_data["faces"]:
                    resized_face = isotropically_resize_image(face, input_size)
                    resized_face = put_to_center(resized_face, input_size)
                    if apply_compression:
                        resized_face = image_compression(resized_face, quality=90, image_type=".jpg")
                    if n + 1 < batch_size:
                        x[n] = resized_face
                        n += 1
                    else:
                        pass
            if n > 0:
                x = torch.tensor(x, device="cpu").float()
                # Preprocess the images.
                x = x.permute((0, 3, 1, 2))
                for i in range(len(x)):
                    x[i] = normalize_transform(x[i] / 255.)
                # Make a prediction, then take the average.
                with torch.no_grad():
                    preds = []
                    for model in models:
                        y_pred = model(x[:n].float())
                        y_pred = torch.sigmoid(y_pred.squeeze())
                        bpred = y_pred[:n].cpu().numpy()
                        preds.append(strategy(bpred))
                    return np.mean(preds)
    except Exception as e:
        print("Prediction error on video: %s" % str(e))

    return 0.5


def predict_on_video_set(face_extractor, videos, input_size, num_workers, test_dir, frames_per_video, models,
                         strategy=np.mean,
                         apply_compression=False):
    def process_file(i):
        filename = videos[i]
        # pass the path through the `video` parameter expected by predict_on_video
        # (the original call used a non-existent `video_path` keyword)
        y_pred = predict_on_video(face_extractor=face_extractor, video=os.path.join(test_dir, filename),
                                  input_size=input_size,
                                  batch_size=frames_per_video,
                                  models=models, strategy=strategy, apply_compression=apply_compression)
        return y_pred

    with ThreadPoolExecutor(max_workers=num_workers) as ex:
        predictions = ex.map(process_file, range(len(videos)))
    return list(predictions)
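Note: a small numeric sketch (not in the commit) of confident_strategy, the per-video aggregation used at inference: if enough frames are confidently fake it averages only those scores, if almost all frames look real it averages the low scores, and otherwise it falls back to a plain mean. The frame scores below are made up.

import numpy as np
from utils import confident_strategy

frame_preds = np.concatenate([np.full(20, 0.95), np.full(12, 0.4)])  # 32 frame-level scores
print(confident_strategy(frame_preds))        # ~0.95: more than 11 frames above 0.8, so only confident fakes are averaged
print(confident_strategy(np.full(32, 0.05)))  # ~0.05: nearly every frame is below 0.2, so only the low scores are averaged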
weights/final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9db77ab9318863e2f8ab287c8eb83c2232584b82dc2fb41f1d614ddd7900cccb
size 266910617