Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*- | |
""" | |
Created on Sun Jan 28 18:48:07 2024 | |
@author: liewchooichin | |
""" | |
import os | |
import pathlib | |
import gradio as gr | |
import numpy as np | |
import tensorflow as tf | |
from huggingface_hub import snapshot_download | |
from huggingface_hub import from_pretrained_keras | |
# Report which TensorFlow build is in use (handy when reading Space logs).
print(f"tensorflow version: {tf.__version__}")

# Prediction placeholders for the binary-label and multilabel models.
# NOTE(review): get_prediction() binds same-named *locals*, so these module
# globals are never updated by it.
pred_binary = pred_multi = ""

# Example images and their labels (1 = fake, 0 = real); get_samples() fills
# both lists from the `data_dir` folder.
samples, labels = [], []
data_dir = "face_samples"

# Where the models are loaded from.
LOCAL_TEST = False  # keep False when deployed to HF
HF_SPACE = True     # True when running in the HF Space

# Model repositories on the Hugging Face Hub.
REPO_ID_BINARY = 'liewchooichin/fake_binary'
REPO_ID_MULTILABEL = 'liewchooichin/fake_multilabel'

# Local model layout, used when LOCAL_TEST is enabled.
local_model_dir = "fake_models"
pb_name = "saved_model.pb"
keras_binary_label = os.path.join("binary_label", "all_binary_6771.keras")
keras_multilabel = os.path.join("multi_label", "multi_7036.keras")
def get_samples():
    """Fill the global ``samples`` and ``labels`` lists with the example faces.

    Looks for ``*.jpg`` files in the ``data_dir`` folder next to this script.
    After sorting by path, the first 9 files are taken as fake faces
    (label 1) and the next 3 as real faces (label 0). Each sample and its
    label are printed for a quick visual check.

    Raises:
        FileNotFoundError: if fewer than 12 ``*.jpg`` files are found.
    """
    num_fake = 9
    num_real = 3
    samples_path = pathlib.Path(os.path.dirname(__file__), data_dir)
    # Path.glob() yields entries in arbitrary, filesystem-dependent order, so
    # sort them to make the hard-coded "first 9 fake, next 3 real" split
    # deterministic on every platform.
    # NOTE(review): assumes the sample filenames sort into that order - confirm.
    files = sorted(samples_path.glob("*.jpg"))
    if len(files) < num_fake + num_real:
        raise FileNotFoundError(
            f"Expected at least {num_fake + num_real} .jpg samples in "
            f"{samples_path}, found {len(files)}"
        )
    # Start from empty lists so a repeated call does not duplicate entries.
    samples.clear()
    labels.clear()
    samples.extend(files[:num_fake])
    labels.extend([1] * num_fake)
    samples.extend(files[num_fake:num_fake + num_real])
    labels.extend([0] * num_real)
    # Print to check the images and labels.
    for sample, label in zip(samples, labels):
        print(sample, label)
def _load_model(repo_id, relative_path):
    """Load one model from local disk or from the Hugging Face Hub.

    Args:
        repo_id: Hub repository downloaded when ``HF_SPACE`` is set.
        relative_path: model file path under ``local_model_dir``, used when
            ``LOCAL_TEST`` is set.

    Returns:
        The loaded model object (get_prediction calls it as
        ``model(batch, training=False)``).

    Raises:
        FileNotFoundError: if ``LOCAL_TEST`` is set and the file is missing.
        RuntimeError: if neither ``LOCAL_TEST`` nor ``HF_SPACE`` is set.
    """
    # LOCAL_TEST takes precedence, for both models alike, so a local run
    # never needs the network.
    if LOCAL_TEST:
        model_path = os.path.join(
            os.path.dirname(__file__), local_model_dir, relative_path
        )
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model not found: {model_path}")
        return tf.keras.models.load_model(model_path)
    if HF_SPACE:
        download_dir = snapshot_download(repo_id=repo_id)
        print(f"Download dir: {download_dir}")
        # The Keras loaders (tf.keras.models.load_model, tf.keras.saving.
        # load_model, from_pretrained_keras) did not work in HF, so the repo
        # snapshot directory is loaded as a TensorFlow SavedModel instead.
        return tf.saved_model.load(download_dir)
    raise RuntimeError(
        "Neither LOCAL_TEST nor HF_SPACE is set; cannot load the models."
    )


def download_keras_model():
    """Load the binary-label and multilabel models into module globals.

    Sets ``keras_binary_model`` and ``keras_multi_model``, which
    get_prediction uses. See ``_load_model`` for where each model comes from.
    """
    global keras_binary_model
    global keras_multi_model
    keras_binary_model = _load_model(REPO_ID_BINARY, keras_binary_label)
    keras_multi_model = _load_model(REPO_ID_MULTILABEL, keras_multilabel)
def get_img_array(img_path, target_size=(224, 224)):
    """Load an image file as a single-image batch for model prediction.

    Args:
        img_path: path of the image file.
        target_size: ``(height, width)`` to resize the image to. The default
            224x224 is what the app's models are fed.

    Returns:
        numpy array with a leading batch axis of size 1, e.g.
        ``(1, 224, 224, 3)`` for an RGB image with channels-last data format.
    """
    img = tf.keras.utils.load_img(img_path, target_size=target_size)
    img_array = tf.keras.utils.img_to_array(img)
    # The models expect a batch, so add a batch axis of size 1.
    img_array = np.expand_dims(img_array, axis=0)
    print(f"Shape of image array: {img_array.shape}")
    return img_array
def get_prediction(img_path):
    """Classify a face image with both models and format the results for the UI.

    Args:
        img_path: path of the image to classify. The Gradio image component
            uses ``type="filepath"``, so this is a file path.

    Returns:
        Tuple of three strings:
        - the image's base filename;
        - the binary-model result (probability and Fake/Real);
        - the multilabel-model result (probabilities, thresholded labels and
          a per-part Fake/Real breakdown).
    """
    # Scores above this count as "fake"; tuned below 0.5 for accuracy.
    threshold = 0.4
    print(f"Image path: {img_path}")
    img_path = os.fspath(img_path)
    # Strip both Windows ("\\") and POSIX ("/") directories so only the bare
    # filename is shown, whatever OS produced the path.
    orig_filename = os.path.basename(img_path.replace("\\", "/"))

    # Load the image once and reuse it for both models.
    img_array = get_img_array(img_path)

    # Binary label: a single fake score. Convert to numpy so the UI shows the
    # numbers rather than the tf.Tensor repr.
    pred_binary = np.asarray(keras_binary_model(img_array, training=False))
    print(f"Keras binary label: {pred_binary}")
    fake = "Fake" if pred_binary[0][0] > threshold else "Real"

    # Multi label: one score per label, cut at `threshold` into 1 (fake) / 0 (real).
    pred_multi = np.asarray(keras_multi_model(img_array, training=False))
    print(f"Keras multi label: {pred_multi}")
    fake_parts = np.where(pred_multi > threshold, 1, 0)
    print(f"Multi label: {fake_parts}")

    # Output column of each label. Column 4 is the overall prediction and is
    # listed first.
    part_columns = (
        ("overall", 4),
        ("left_eye", 0),
        ("right_eye", 1),
        ("nose", 2),
        ("mouth", 3),
    )
    parts_formatted = "".join(
        f"{part}: {'Fake' if fake_parts[0][col] == 1 else 'Real'}\n"
        for part, col in part_columns
    )

    # Build the strings explicitly: backslash-continued literals would embed
    # the source indentation into the displayed text.
    result_binary = f"Probability: {pred_binary}\nPrediction: {fake}\n"
    result_multi = (
        f"Probability: {pred_multi}\n"
        f"Prediction: {fake_parts}\n"
        f"{parts_formatted}"
    )
    return orig_filename, result_binary, result_multi
def clear_image():
    """Blank the three output text boxes once the input image is cleared.

    Returns:
        Tuple of three empty strings: filename, binary result and
        multilabel result.
    """
    blank = ""
    return (blank,) * 3
def main():
    """Run the start-up work the UI needs: collect example images, then load models."""
    for startup_step in (get_samples, download_keras_model):
        startup_step()
def _example_row(title, files, image, outputs):
    """Add a titled row of clickable example images to the current Blocks.

    Clicking an example loads it into ``image`` and runs get_prediction,
    writing the three result strings to ``outputs``.

    Returns:
        The created ``gr.Examples`` component.
    """
    with gr.Row():
        gr.Markdown(title)
        return gr.Examples(
            examples=list(files),
            inputs=[image],
            outputs=list(outputs),
            run_on_click=True,
            fn=get_prediction,
        )


with gr.Blocks() as demo:
    # main() fills `samples` and loads the models. It must run before the
    # example rows below index into `samples`.
    main()
    image_width = 256
    image_height = 256
    gr.Markdown(
        """
        # Fake or real faces detection.
        The dataset is obtained from https://www.kaggle.com/datasets/ciplab/real-and-fake-face-detection.
        Trained with EfficientNet V2 B0.
        One model is trained to do binary classification and the other \
        multilabel classification. The multilabel classification is \
        based on the last four digits provided by the filenames. \
        The last four digits are following the order of left eye, \
        right eye, nose and mouth. \
        The labels are 1 (fake) and 0 (real).
        For example: ___1010.jpg means left eye and nose are fake.
        Binary accuracy for the binary label model: 0.6771. <br>
        Binary accuracy for the multilabel model: 0.7036.
        The fake faces are also categorized into how difficult it is \
        to detect the faces as fake. The categories are easy, mid and hard.
        The top prediction and its probabilities of classes are shown.
        Try our sample faces below or upload one of your own.
        """
    )
    with gr.Row():
        with gr.Column():
            img = gr.Image(height=image_height,
                           width=image_width,
                           sources=["upload", "clipboard"],
                           interactive=True,
                           type="filepath")
        with gr.Column():
            text_1 = gr.Text(
                label="Filename",
                interactive=False, lines=1
            )
            text_2 = gr.Text(
                label="Binary label, Efficient net v2 B0",
                interactive=False, lines=2)
            # The multilabel result is computed but hidden (visible=False).
            text_3 = gr.Text(
                label="Multi label, Efficient net v2 B0",
                interactive=False, lines=7,
                visible=False)
    # Example rows of three images each, in the order get_samples() builds
    # `samples`: 9 fake faces (easy, mid, hard), then 3 real faces.
    examples_1 = _example_row("## Fake faces <br>(easy)", samples[0:3],
                              img, [text_1, text_2, text_3])
    examples_2 = _example_row("## Fake faces <br>(mid)", samples[3:6],
                              img, [text_1, text_2, text_3])
    examples_3 = _example_row("## Fake faces <br>(hard)", samples[6:9],
                              img, [text_1, text_2, text_3])
    examples_4 = _example_row("## Real faces", samples[9:12],
                              img, [text_1, text_2, text_3])
    # Run a prediction when a file is uploaded.
    img.upload(fn=get_prediction, inputs=[img],
               outputs=[text_1, text_2, text_3])
    # Blank the outputs when the image is cleared.
    img.clear(fn=clear_image, inputs=[],
              outputs=[text_1, text_2, text_3])


if __name__ == "__main__":
    demo.launch()