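"""Streamlit demo app for DF-Net (Digital Forensics Network).

Loads two pretrained Keras models, fuses their predictions, and displays a
mask highlighting potentially manipulated regions for the bundled example
images or a user-uploaded file.
"""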
from PIL import Image
import streamlit as st
import cv2
import numpy as np
import os
import tensorflow as tf
from IMVIP_Supplementary_Material.scripts import dfutils  # helper methods used by DF-Net
DESCRIPTION = """# DF-Net
The Digital Forensics Network is designed and trained to detect and locate image manipulations.
More information can be found in this [publication](https://zenodo.org/record/8214996).
#### Select example image or upload your own image:
"""
IMG_SIZE = 256
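# Enable NumPy-style methods on TensorFlow tensors; this is likely needed so that
# tensors returned by the dfutils helpers can be reshaped and resized like NumPy arrays below.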
tf.experimental.numpy.experimental_enable_numpy_behavior()
# Load both DF-Net models once and cache them for the lifetime of the app.
@st.cache_resource
def load_models():
    model_path1 = "IMVIP_Supplementary_Material/models/model1/"
    model_path2 = "IMVIP_Supplementary_Material/models/model2/"
    model_M1 = tf.keras.models.load_model(model_path1)
    model_M2 = tf.keras.models.load_model(model_path2)
    return model_M1, model_M2
model_M1, model_M2 = load_models()
def check_forgery_df(img):
    # Resize to the network input size, normalize to [0, 1], and run both models.
    shape_original = img.shape
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    x = np.expand_dims(img.astype('float32') / 255., axis=0)
    pred1 = model_M1.predict(x, verbose=0)
    pred2 = model_M2.predict(x, verbose=0)
    # Fuse the two predictions with a pixel-wise maximum, convert them into a mask,
    # and scale the mask back to the original image resolution.
    pred = np.max([pred1, pred2], axis=0)
    pred = dfutils.create_mask(pred)
    pred = pred.reshape(pred.shape[-3:-1])
    resized_image = cv2.resize(pred, (shape_original[1], shape_original[0]), interpolation=cv2.INTER_LINEAR)
    return resized_image
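# Minimal usage sketch outside the Streamlit flow (the path below is one of the
# bundled example images; the BGR -> RGB reversal mirrors start_evaluation()):
#   test_img = cv2.imread("example_images/Sp_D_NRD_A_nat0095_art0058_0582.jpg")
#   mask = check_forgery_df(test_img[:, :, ::-1])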
def evaluate(img):
    # Run DF-Net on the image and display the predicted manipulation mask.
    pre_t = check_forgery_df(img)
    st.image(pre_t, caption="White area indicates potential image manipulations.")
def start_evaluation(uploaded_file):
    # Decode the uploaded file into an OpenCV (BGR) image.
    file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
    opencv_image = cv2.imdecode(file_bytes, 1)
    # Reverse the channel order to RGB for display and evaluation.
    reversed_image = opencv_image[:, :, ::-1]
    st.image(reversed_image, caption="Input Image")
    evaluate(reversed_image)
def start_evaluation_pil_img(pil_image):
    # PIL images are already RGB, so they can be displayed and evaluated directly.
    rgb_image = np.array(pil_image)
    st.image(rgb_image, caption="Input Image")
    evaluate(rgb_image)
st.markdown(DESCRIPTION)
img_path1 = "example_images/Sp_D_NRD_A_nat0095_art0058_0582"
img_path2 = "example_images/Sp_D_NRN_A_nat0083_arc0080_0445"
#img_path3 = "example_images/Sp_D_NRN_A_ani0088_cha0044_0441"
image_paths = [img_path1+".jpg", img_path2+".jpg"] #, img_path3+".jpg"]
gt_paths = [img_path1+"_gt.png", img_path2+"_gt.png"] #, img_path3+"_gt.png"]
# Show each example image next to its ground-truth mask, with a selection button.
img = None
for idx, image_path in enumerate(image_paths):
    cols = st.columns([2, 2, 2, 2])  # Define column widths
    # Place the selection button in the first column
    if cols[0].button(f"Select Image {idx+1}", key=idx):
        img = Image.open(image_path)
    # Place the example image in the second column
    with cols[1]:
        st.image(image_path, use_column_width=True, caption="Example Image " + str(idx + 1))
    # Place the ground-truth mask in the third column
    with cols[2]:
        st.image(gt_paths[idx], use_column_width=True, caption="Ground Truth")

if img is not None:
    start_evaluation_pil_img(img)
def reset_image_select():
    # Streamlit reruns the whole script when a new file is uploaded, so the
    # example-image selection made via the buttons above is cleared on its own;
    # there is no persistent state to reset here.
    pass
uploaded_file = st.file_uploader("Please upload an image", type=["jpeg", "jpg", "png"], on_change=reset_image_select)
if (uploaded_file is not None) and (img is None):
    start_evaluation(uploaded_file)