# Spaces: Runtime error
import os
import re
from functools import lru_cache

import gradio as gr
import torch

from utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video
from training.zoo.classifiers import DeepFakeClassifier
# Model / preprocessing configuration used by the helpers below.
# The weights path is relative to the current working directory, as before.
_WEIGHTS_PATH = os.path.join(
    "weights", "final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36"
)
_ENCODER = "tf_efficientnet_b7"
_FRAMES_PER_VIDEO = 32
_INPUT_SIZE = 380


def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from keys.

    Wrappers such as ``torch.nn.DataParallel`` save parameter names with a
    ``module.`` prefix; the bare model's ``load_state_dict`` expects them
    without it. The dot is escaped so only a literal ``module.`` is stripped.
    """
    return {re.sub(r"^module\.", "", key): value for key, value in state_dict.items()}


@lru_cache(maxsize=1)
def _load_models():
    """Load the deepfake classifier once per process and cache it.

    Returns:
        A one-element list holding the model (eval mode, float32, weights
        mapped to CPU), in the form ``predict_on_video(models=...)`` receives.
    """
    model = DeepFakeClassifier(encoder=_ENCODER)
    # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True; if
    # this checkpoint stores non-tensor objects, loading may fail — confirm.
    checkpoint = torch.load(_WEIGHTS_PATH, map_location="cpu")
    # Accept both {"state_dict": ...} checkpoints and a raw state dict.
    state_dict = checkpoint.get("state_dict", checkpoint)
    model.load_state_dict(_strip_module_prefix(state_dict), strict=True)
    model.eval()
    # Drop the raw checkpoint references so only the model's tensors stay alive.
    del checkpoint, state_dict
    return [model.float()]


def _build_face_extractor():
    """Create the frame-reading + face-extraction pipeline for one request."""
    video_reader = VideoReader()

    def video_read_fn(path):
        return video_reader.read_frames(path, num_frames=_FRAMES_PER_VIDEO)

    return FaceExtractor(video_read_fn)


def detect(video):
    """Classify a video as fake or real.

    Args:
        video: The value produced by the Gradio ``Video`` input (presumably a
            path to the uploaded file — confirm for the deployed Gradio
            version), or ``None`` when nothing was uploaded.

    Returns:
        ``{"Fake": p, "Real": 1 - p}`` where ``p`` is the score returned by
        ``predict_on_video``; ``gr.Label`` renders these as confidences.

    Raises:
        gr.Error: if no video was provided.
    """
    if video is None:
        # Submitting without an upload would otherwise fail deep in the pipeline.
        raise gr.Error("Please upload a video first.")

    # The model is loaded once and reused across requests.
    models = _load_models()
    pred = predict_on_video(
        face_extractor=_build_face_extractor(),
        video=video,
        batch_size=_FRAMES_PER_VIDEO,
        input_size=_INPUT_SIZE,
        models=models,
        strategy=confident_strategy,
    )
    return {'Fake': float(pred), 'Real': float(1 - pred)}
# Directory containing this script; example clips live in ./sample next to it.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))


def _make_video_input():
    """Create an upload-only MP4 video input that works across Gradio versions.

    Gradio 4 replaced the ``source="upload"`` keyword with
    ``sources=["upload"]``, and passing the old keyword raises ``TypeError``
    when the component is constructed. Try the new form first and fall back
    to the old one for Gradio 3.x.
    """
    try:
        return gr.Video(format="mp4", sources=["upload"])
    except TypeError:
        # Older Gradio only understands the singular keyword.
        return gr.Video(format="mp4", source="upload")


gr_inputs = _make_video_input()
gr_outputs = gr.Label()
# Each example is a one-element row: one value for the single input component.
gr_ex = [
    [os.path.join(_APP_DIR, "sample", "sample1.mp4")],
    [os.path.join(_APP_DIR, "sample", "sample2.mp4")],
]
iface = gr.Interface(
    fn=detect,
    inputs=gr_inputs,
    outputs=gr_outputs,
    examples=gr_ex,
)
iface.launch()