Spaces:
Runtime error
Runtime error
jungwoonshin
commited on
Commit
•
d8e6c94
1
Parent(s):
4de0f77
second commit
Browse files- training/pipelines/app.py +67 -0
training/pipelines/app.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import argparse
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import time
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import pandas as pd
|
9 |
+
|
10 |
+
import os, sys
|
11 |
+
root_folder = os.path.abspath(
|
12 |
+
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
13 |
+
)
|
14 |
+
sys.path.append(root_folder)
|
15 |
+
from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video_set
|
16 |
+
from training.zoo.classifiers import DeepFakeClassifier
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
def predict(video_index):
|
21 |
+
video_index = int(video_index)
|
22 |
+
|
23 |
+
frames_per_video = 32
|
24 |
+
video_reader = VideoReader()
|
25 |
+
video_read_fn = lambda x: video_reader.read_frames(x, num_frames=frames_per_video)
|
26 |
+
face_extractor = FaceExtractor(video_read_fn)
|
27 |
+
input_size = 380
|
28 |
+
strategy = confident_strategy
|
29 |
+
|
30 |
+
test_videos = sorted([x for x in os.listdir(args.test_dir) if x[-4:] == ".mp4"])[video_index]
|
31 |
+
print(f"Predicting {video_index} videos")
|
32 |
+
predictions = predict_on_video_set(face_extractor=face_extractor, input_size=input_size, models=models,
|
33 |
+
strategy=strategy, frames_per_video=frames_per_video, videos=test_videos,
|
34 |
+
num_workers=6, test_dir=args.test_dir)
|
35 |
+
return predictions
|
36 |
+
|
37 |
+
def get_args_models():
|
38 |
+
parser = argparse.ArgumentParser("Predict test videos")
|
39 |
+
arg = parser.add_argument
|
40 |
+
arg('--weights-dir', type=str, default="weights", help="path to directory with checkpoints")
|
41 |
+
arg('--models', type=str, default='classifier_DeepFakeClassifier_tf_efficientnet_b7_ns_1_best_dice', help="checkpoint files") # nargs='+',
|
42 |
+
arg('--test-dir', type=str, default='test_dataset', help="path to directory with videos")
|
43 |
+
arg('--output', type=str, required=False, help="path to output csv", default="submission.csv")
|
44 |
+
args = parser.parse_args()
|
45 |
+
|
46 |
+
models = []
|
47 |
+
# model_paths = [os.path.join(args.weights_dir, model) for model in args.models]
|
48 |
+
model_paths = [os.path.join(args.weights_dir, args.models)]
|
49 |
+
for path in model_paths:
|
50 |
+
model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to("cuda")
|
51 |
+
print("loading state dict {}".format(path))
|
52 |
+
checkpoint = torch.load(path, map_location="cpu")
|
53 |
+
state_dict = checkpoint.get("state_dict", checkpoint)
|
54 |
+
model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=True)
|
55 |
+
model.eval()
|
56 |
+
del checkpoint
|
57 |
+
models.append(model.half())
|
58 |
+
return args, models
|
59 |
+
|
60 |
+
if __name__ == '__main__':
|
61 |
+
global models, args
|
62 |
+
stime = time.time()
|
63 |
+
print("Elapsed:", time.time() - stime)
|
64 |
+
args, models = get_args_models()
|
65 |
+
demo = gr.Interface(fn=predict, inputs="text", outputs="text")
|
66 |
+
demo.launch()
|
67 |
+
|