# Source: uploaded by russel0719 ("Upload 38 files", commit 55da56b).
import functools
import os
import re

import gradio as gr
import torch

from training.zoo.classifiers import DeepFakeClassifier
from utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video
@functools.lru_cache(maxsize=1)
def _load_models():
    """Build the EfficientNet-B7 deepfake classifier once and cache it.

    Constructing the network and reading its checkpoint from disk on every
    request is wasteful, so this runs on first use only and the result is
    reused by every later call to ``detect``.

    Returns:
        A tuple holding the single model, in eval mode and cast to float32.
        A tuple is returned so the cached value cannot be mutated by callers.
    """
    model = DeepFakeClassifier(encoder="tf_efficientnet_b7")
    path = os.path.join('weights', 'final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36')
    checkpoint = torch.load(path, map_location="cpu")
    # The checkpoint is either a dict with a "state_dict" entry or the raw
    # state dict itself; accept both.
    state_dict = checkpoint.get("state_dict", checkpoint)
    # Strip a leading "module." (the prefix typically added when a model is
    # saved from a DataParallel wrapper). The dot is escaped so that only a
    # literal "module." prefix is removed; the previous unescaped pattern
    # also matched keys starting with e.g. "modules" or "moduleX".
    model.load_state_dict(
        {re.sub(r"^module\.", "", k): v for k, v in state_dict.items()},
        strict=True,
    )
    model.eval()
    return (model.float(),)


def detect(video):
    """Score *video* as fake vs. real with the cached deepfake classifier.

    Args:
        video: The value produced by the Gradio ``gr.Video`` input. This is
            presumably a path to the uploaded MP4; confirm against the
            Gradio version in use.

    Returns:
        ``{'Fake': p, 'Real': 1 - p}`` for ``gr.Label``, where ``p`` is the
        value returned by ``predict_on_video`` converted to ``float``.
    """
    # Fresh list on each call; the model objects themselves are shared.
    models = list(_load_models())

    # Number of frames requested from each video; also used as the batch size.
    frames_per_video = 32
    video_reader = VideoReader()

    def read_frames(path):
        # FaceExtractor pulls frames through this callback.
        return video_reader.read_frames(path, num_frames=frames_per_video)

    face_extractor = FaceExtractor(read_frames)
    # Presumably the square face-crop size the B7 classifier expects; confirm.
    input_size = 380

    pred = predict_on_video(
        face_extractor=face_extractor,
        video=video,
        batch_size=frames_per_video,
        input_size=input_size,
        models=models,
        strategy=confident_strategy,
    )
    fake = float(pred)
    return {'Fake': fake, 'Real': 1.0 - fake}
# Gradio front end: an uploaded MP4 goes in, a Fake/Real label comes out.
gr_inputs = gr.Video(format='mp4', source='upload')
gr_outputs = gr.Label()

# Example clips resolved relative to this script's directory. Each inner
# list is one example row holding the single input (a video path).
gr_ex = [
    [os.path.join(os.path.dirname(__file__), rel_path)]
    for rel_path in ("sample/sample1.mp4", "sample/sample2.mp4")
]

iface = gr.Interface(
    fn=detect,
    inputs=gr_inputs,
    outputs=gr_outputs,
    examples=gr_ex,
)
# Starts the web app as soon as this module is executed.
iface.launch()