# VideoDetection / app.py
# malmukhtar's picture
# Update app.py
# b164597
# raw
# history blame
# 2.28 kB
# !git clone https://github.com/polimi-ispl/icpr2020dfdc
# !pip install efficientnet-pytorch
# !pip install -U git+https://github.com/albu/albumentations > /dev/null
# %cd icpr2020dfdc/notebook
import torch
from torch.utils.model_zoo import load_url
from PIL import Image
from scipy.special import expit
import sys
sys.path.append('./icpr2020dfdc/')
from blazeface import FaceExtractor, BlazeFace, VideoReader
from architectures import fornet,weights
from isplutils import utils
import gradio as gr
# Network architecture to instantiate. Choose one of:
#   - EfficientNetB4
#   - EfficientNetB4ST
#   - EfficientNetAutoAttB4
#   - EfficientNetAutoAttB4ST
#   - Xception
net_model = 'EfficientNetAutoAttB4'

# Training dataset of the pretrained weights. Choose one of:
#   - DFDC
#   - FFPP
train_db = 'DFDC'

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Face preprocessing settings passed to the transformer factory below.
face_policy = 'scale'
face_size = 224

# Number of frames requested per video by the video-reading code.
frames_per_video = 32

# Weight URLs are keyed by "<architecture>_<dataset>".
model_url = weights.weight_url[f'{net_model}_{train_db}']

# Build the selected architecture, switch it to eval mode and move it to the
# chosen device before loading the pretrained weights.
net = getattr(fornet, net_model)()
net = net.eval().to(device)
net.load_state_dict(
    load_url(model_url, map_location=device, check_hash=True)
)

# Inference-time (train=False) transform using the network's own normalizer.
transf = utils.get_transformer(
    face_policy,
    face_size,
    net.get_normalizer(),
    train=False,
)
# Face detector: placed on the selected device, then given its weights and
# anchors from the files shipped in the icpr2020dfdc checkout.
facedet = BlazeFace().to(device)
facedet.load_weights("./icpr2020dfdc/blazeface/blazeface.pth")
facedet.load_anchors("./icpr2020dfdc/blazeface/anchors.npy")

videoreader = VideoReader(verbose=False)


def video_read_fn(x):
    """Read ``frames_per_video`` frames from the video ``x``."""
    return videoreader.read_frames(x, num_frames=frames_per_video)


# The extractor reads frames through ``video_read_fn`` and detects faces
# with ``facedet``.
face_extractor = FaceExtractor(
    video_read_fn=video_read_fn,
    facedet=facedet,
)

# NOTE(review): ``title`` is not referenced anywhere in the visible code.
title = "FaceForensics++"
def inference(vid):
    """Classify a video as real or fake from the faces found in its frames.

    The first detected face of every frame that contains at least one face
    is transformed and fed to the network. The raw network outputs are
    averaged and passed through the logistic sigmoid to obtain the score.

    Args:
        vid: The video to analyse; handed unchanged to
            ``face_extractor.process_video``. Presumably a file path
            supplied by Gradio -- confirm against the installed version.

    Returns:
        A ``(label_image_path, score_text)`` tuple. The path points to the
        "Fake" label image when the score is at least 0.5 and to the "Real"
        label image otherwise; the text is the score as a percentage with
        two decimals.

    Raises:
        ValueError: If no video was supplied, or if no face was detected in
            any of the extracted frames.
    """
    if vid is None:
        raise ValueError("No video was provided.")

    frames = face_extractor.process_video(vid)

    # Keep only the first face of each frame that has at least one face.
    face_tensors = [
        transf(image=frame['faces'][0])['image']
        for frame in frames
        if len(frame['faces'])
    ]
    if not face_tensors:
        # torch.stack() rejects an empty list with an opaque RuntimeError;
        # fail early with a message that names the actual problem.
        raise ValueError("No faces were detected in the video.")

    faces_t = torch.stack(face_tensors)
    with torch.no_grad():
        raw_scores = net(faces_t.to(device)).cpu().numpy().flatten()

    # Average the per-face outputs first, then squash to the (0, 1) range.
    res = expit(raw_scores.mean())

    label_path = "./Labels/Fake.png" if res >= 0.5 else "./Labels/Real.jpg"
    return label_path, f"{res*100:.2f}%"
# Gradio UI: a single video input; an image and a text field as outputs,
# matching the two values returned by ``inference``.
interface = gr.Interface(
    fn=inference,
    inputs=[gr.inputs.Video(type="mp4", label="In")],
    outputs=[gr.outputs.Image(type="pil"), "text"],
)

# NOTE(review): ``demo`` is bound to the return value of ``launch()``, not to
# the Interface object itself.
demo = interface.launch(debug=True)