jungwoonshin
committed on
Commit
•
a8ff7ce
1
Parent(s):
5cd7059
Delete training
Browse files
training/pipelines/app.py
DELETED
@@ -1,67 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
import argparse
|
3 |
-
import os
|
4 |
-
import re
|
5 |
-
import time
|
6 |
-
|
7 |
-
import torch
|
8 |
-
import pandas as pd
|
9 |
-
|
10 |
-
import os, sys
|
11 |
-
root_folder = os.path.abspath(
|
12 |
-
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
13 |
-
)
|
14 |
-
sys.path.append(root_folder)
|
15 |
-
from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video_set
|
16 |
-
from training.zoo.classifiers import DeepFakeClassifier
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
def predict(video_index):
    """Run deepfake prediction on one test video selected by its index.

    Relies on the module-level globals ``models`` and ``args`` being set by
    ``get_args_models()`` before this is called (see the ``__main__`` block).

    Args:
        video_index: Index (str or int, e.g. from the gradio text input) into
            the sorted list of ``.mp4`` files in ``args.test_dir``.

    Returns:
        The predictions produced by ``predict_on_video_set`` for that video.
    """
    video_index = int(video_index)

    frames_per_video = 32
    video_reader = VideoReader()
    video_read_fn = lambda x: video_reader.read_frames(x, num_frames=frames_per_video)
    face_extractor = FaceExtractor(video_read_fn)
    input_size = 380
    strategy = confident_strategy

    # BUG FIX: the original `sorted(...)[video_index]` yields a single
    # filename *string*; `predict_on_video_set` expects a sequence of video
    # names, so a bare string would be iterated character by character.
    # Select the filename, then wrap it in a one-element list.
    all_videos = sorted(x for x in os.listdir(args.test_dir) if x[-4:] == ".mp4")
    test_videos = [all_videos[video_index]]
    print(f"Predicting {video_index} videos")
    predictions = predict_on_video_set(face_extractor=face_extractor, input_size=input_size, models=models,
                                       strategy=strategy, frames_per_video=frames_per_video, videos=test_videos,
                                       num_workers=6, test_dir=args.test_dir)
    return predictions
|
36 |
-
|
37 |
-
def get_args_models():
    """Parse CLI arguments and load the deepfake classifier checkpoint(s).

    Returns:
        tuple: ``(args, models)`` — the parsed ``argparse.Namespace`` and a
        list of ``DeepFakeClassifier`` instances in eval mode, half precision,
        moved to CUDA.
    """
    parser = argparse.ArgumentParser("Predict test videos")
    arg = parser.add_argument
    arg('--weights-dir', type=str, default="weights", help="path to directory with checkpoints")
    arg('--models', type=str, default='classifier_DeepFakeClassifier_tf_efficientnet_b7_ns_1_best_dice', help="checkpoint files")  # nargs='+',
    arg('--test-dir', type=str, default='test_dataset', help="path to directory with videos")
    arg('--output', type=str, required=False, help="path to output csv", default="submission.csv")
    args = parser.parse_args()

    models = []
    # model_paths = [os.path.join(args.weights_dir, model) for model in args.models]
    model_paths = [os.path.join(args.weights_dir, args.models)]
    for path in model_paths:
        model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to("cuda")
        print("loading state dict {}".format(path))
        # Load on CPU first; the model itself was already moved to CUDA above.
        checkpoint = torch.load(path, map_location="cpu")
        state_dict = checkpoint.get("state_dict", checkpoint)
        # BUG FIX: escape the dot — the original pattern "^module." treats
        # `.` as "any character" and would also strip prefixes like
        # "moduleX"; we only want the DataParallel "module." prefix removed.
        model.load_state_dict({re.sub(r"^module\.", "", k): v for k, v in state_dict.items()}, strict=True)
        model.eval()
        del checkpoint  # release checkpoint memory before loading the next model
        models.append(model.half())
    return args, models
|
59 |
-
|
60 |
-
if __name__ == '__main__':
    # `models` and `args` are deliberately module-level so the gradio
    # callback `predict` can reach them. (The original `global models, args`
    # statement was a no-op at module scope and has been removed.)
    stime = time.time()
    args, models = get_args_models()
    # BUG FIX: the elapsed-time print originally ran immediately after
    # `stime` was taken, so it always reported ~0 seconds; it now measures
    # the model-loading work it was evidently meant to time.
    print("Elapsed:", time.time() - stime)
    demo = gr.Interface(fn=predict, inputs="text", outputs="text")
    demo.launch()
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
training/pipelines/train_classifier_gradio.py
DELETED
@@ -1,67 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
import argparse
|
3 |
-
import os
|
4 |
-
import re
|
5 |
-
import time
|
6 |
-
|
7 |
-
import torch
|
8 |
-
import pandas as pd
|
9 |
-
|
10 |
-
import os, sys
|
11 |
-
root_folder = os.path.abspath(
|
12 |
-
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
13 |
-
)
|
14 |
-
sys.path.append(root_folder)
|
15 |
-
from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video_set
|
16 |
-
from training.zoo.classifiers import DeepFakeClassifier
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
def predict(video_index):
    """Run deepfake prediction on one test video selected by its index.

    Relies on the module-level globals ``models`` and ``args`` being set by
    ``get_args_models()`` before this is called (see the ``__main__`` block).

    Args:
        video_index: Index (str or int, e.g. from the gradio text input) into
            the sorted list of ``.mp4`` files in ``args.test_dir``.

    Returns:
        The predictions produced by ``predict_on_video_set`` for that video.
    """
    video_index = int(video_index)

    frames_per_video = 32
    video_reader = VideoReader()
    video_read_fn = lambda x: video_reader.read_frames(x, num_frames=frames_per_video)
    face_extractor = FaceExtractor(video_read_fn)
    input_size = 380
    strategy = confident_strategy

    # BUG FIX: the original `sorted(...)[video_index]` yields a single
    # filename *string*; `predict_on_video_set` expects a sequence of video
    # names, so a bare string would be iterated character by character.
    # Select the filename, then wrap it in a one-element list.
    all_videos = sorted(x for x in os.listdir(args.test_dir) if x[-4:] == ".mp4")
    test_videos = [all_videos[video_index]]
    print(f"Predicting {video_index} videos")
    predictions = predict_on_video_set(face_extractor=face_extractor, input_size=input_size, models=models,
                                       strategy=strategy, frames_per_video=frames_per_video, videos=test_videos,
                                       num_workers=6, test_dir=args.test_dir)
    return predictions
|
36 |
-
|
37 |
-
def get_args_models():
    """Parse CLI arguments and load the deepfake classifier checkpoint(s).

    Returns:
        tuple: ``(args, models)`` — the parsed ``argparse.Namespace`` and a
        list of ``DeepFakeClassifier`` instances in eval mode, half precision,
        moved to CUDA.
    """
    parser = argparse.ArgumentParser("Predict test videos")
    arg = parser.add_argument
    arg('--weights-dir', type=str, default="weights", help="path to directory with checkpoints")
    arg('--models', type=str, default='classifier_DeepFakeClassifier_tf_efficientnet_b7_ns_1_best_dice', help="checkpoint files")  # nargs='+',
    arg('--test-dir', type=str, default='test_dataset', help="path to directory with videos")
    arg('--output', type=str, required=False, help="path to output csv", default="submission.csv")
    args = parser.parse_args()

    models = []
    # model_paths = [os.path.join(args.weights_dir, model) for model in args.models]
    model_paths = [os.path.join(args.weights_dir, args.models)]
    for path in model_paths:
        model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to("cuda")
        print("loading state dict {}".format(path))
        # Load on CPU first; the model itself was already moved to CUDA above.
        checkpoint = torch.load(path, map_location="cpu")
        state_dict = checkpoint.get("state_dict", checkpoint)
        # BUG FIX: escape the dot — the original pattern "^module." treats
        # `.` as "any character" and would also strip prefixes like
        # "moduleX"; we only want the DataParallel "module." prefix removed.
        model.load_state_dict({re.sub(r"^module\.", "", k): v for k, v in state_dict.items()}, strict=True)
        model.eval()
        del checkpoint  # release checkpoint memory before loading the next model
        models.append(model.half())
    return args, models
|
59 |
-
|
60 |
-
if __name__ == '__main__':
    # `models` and `args` are deliberately module-level so the gradio
    # callback `predict` can reach them. (The original `global models, args`
    # statement was a no-op at module scope and has been removed.)
    stime = time.time()
    args, models = get_args_models()
    # BUG FIX: the elapsed-time print originally ran immediately after
    # `stime` was taken, so it always reported ~0 seconds; it now measures
    # the model-loading work it was evidently meant to time.
    print("Elapsed:", time.time() - stime)
    demo = gr.Interface(fn=predict, inputs="text", outputs="text")
    demo.launch()
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|