# Hugging Face Spaces page header captured with this file (Space status: Runtime error).
import gradio as gr | |
import tensorflow as tf | |
import tensorflow.keras.backend as K | |
import numpy as np | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
import cv2 | |
import os | |
import shutil | |
from skimage.metrics import structural_similarity as ssim | |
# Weight applied to the secret-image reconstruction term (used by rev_loss,
# and therefore by full_loss as well).
beta = 1.0


def rev_loss(s_true, s_pred):
    """Reveal-network loss: ``beta * sum((S - S')**2)``.

    Note that despite the |S-S'| shorthand sometimes used for this loss, the
    penalty is a *squared* error summed over every element of the tensors.
    """
    squared_error = K.square(s_true - s_pred)
    return beta * K.sum(squared_error)
def full_loss(y_true, y_pred):
    """Loss for the full model (used for the preparation and hiding networks).

    The last axis of both tensors packs the secret image in channels 0-2 and
    the cover image in channels 3-5. The loss is
    ``sum((C - C')**2) + beta * sum((S - S')**2)``: squared errors summed over
    all elements, with the secret term delegated to ``rev_loss``.
    """
    secret_true, cover_true = y_true[..., 0:3], y_true[..., 3:6]
    secret_pred, cover_pred = y_pred[..., 0:3], y_pred[..., 3:6]
    cover_term = K.sum(K.square(cover_true - cover_pred))
    return rev_loss(secret_true, secret_pred) + cover_term
# Load the trained model from disk. ``full_loss`` is a project-defined
# function, so Keras is told how to resolve that name while deserialising
# (presumably the model was compiled with it — confirm against training code).
_custom_objects = {'full_loss': full_loss}
model = tf.keras.models.load_model("model.h5", custom_objects=_custom_objects)
def preprocess_image(img, size=(124, 124)):
    """Resize an image and scale its pixel values to the range [0, 1].

    Args:
        img: A PIL image or a NumPy array (e.g. a Gradio upload or an
            OpenCV frame). Channel order is passed through unchanged.
        size: Target ``(width, height)``; defaults to the 124x124 resolution
            used throughout this file.

    Returns:
        A float64 array of shape ``(height, width, 3)`` with values in [0, 1].

    Raises:
        ValueError: If ``img`` is ``None`` (e.g. no image was uploaded).
    """
    if img is None:
        raise ValueError("No image provided")
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    # Force three channels so grayscale / RGBA inputs yield the 3 channels the
    # rest of this file slices on; a no-op for images that are already RGB.
    img = img.convert("RGB")
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    img = img.resize(size, Image.LANCZOS)
    return np.array(img) / 255.0
def steganography_image(imageO, imageF):
    """Run the model on one image pair and visualise the reconstruction error.

    Args:
        imageO: Image fed to the model as its first (secret, ``S``) input.
        imageF: Image fed to the model as its second (cover, ``C``) input.

    Returns:
        ``(decoded_S, decoded_C, fig)``: the model's two uint8 outputs
        (channels 0-2 and 3-5 of the prediction) and a matplotlib figure of
        the per-pixel absolute differences against the preprocessed inputs.
    """
    # Preprocess and add a batch dimension of 1.
    img_S = np.expand_dims(preprocess_image(imageO), axis=0)
    img_C = np.expand_dims(preprocess_image(imageF), axis=0)

    decoded = model.predict([img_S, img_C])
    # Take the single batch item; channels 0-2 are the S output, 3-5 the C
    # output. Clip before casting: values slightly outside [0, 1] would
    # otherwise wrap around (e.g. 256 -> 0) in the uint8 conversion.
    decoded_S = np.clip(decoded[0, ..., 0:3] * 255, 0, 255).astype(np.uint8)
    decoded_C = np.clip(decoded[0, ..., 3:6] * 255, 0, 255).astype(np.uint8)

    # uint8 minus float64 promotes to float64, so the subtraction cannot wrap.
    diff_S = np.abs(decoded_S - img_S[0] * 255).astype(np.uint8)
    diff_C = np.abs(decoded_C - img_C[0] * 255).astype(np.uint8)

    fig = _plot_difference_pair(diff_S, diff_C)
    return decoded_S, decoded_C, fig


def _plot_difference_pair(diff_S, diff_C):
    """Return a side-by-side figure of the secret and cover difference maps."""
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        (diff_S, 'Difference in Secret Image'),
        (diff_C, 'Difference in Cover Image'),
    )
    for ax, (diff, title) in zip(axes, panels):
        ax.imshow(diff)
        ax.set_title(title)
        ax.axis('off')
    fig.tight_layout()
    # Drop the figure from pyplot's global registry so repeated requests do
    # not accumulate open figures; the returned object can still be rendered.
    plt.close(fig)
    return fig
# Function to clear a folder
def clear_folder(path):
    """Make ``path`` an empty directory.

    Anything already at ``path`` is removed recursively with ``shutil.rmtree``
    (so it is expected to be a directory), then the directory is recreated,
    along with any missing parent directories.
    """
    stale = os.path.exists(path)
    if stale:
        shutil.rmtree(path)
    os.makedirs(path)
# Function to extract every frame of a video and save them in a folder
def extractImages(pathIn, pathOut, size=(124, 124)):
    """Decode every frame of a video, resize it, and save it as a JPEG.

    Frames are written to ``pathOut`` as ``frame0.jpg``, ``frame1.jpg``, ...
    in decode order, starting with the very first frame. ``pathOut`` is
    emptied first.

    Args:
        pathIn: Path of the input video.
        pathOut: Directory that receives the frames (recreated empty).
        size: Target ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        The number of frames written.

    Raises:
        ValueError: If the video cannot be opened.
    """
    clear_folder(pathOut)
    os.makedirs(pathOut, exist_ok=True)

    vidcap = cv2.VideoCapture(pathIn)
    try:
        if not vidcap.isOpened():
            raise ValueError(f"Could not open video: {pathIn!r}")
        count = 0
        while True:
            # Read first, then write: the previous loop read a second frame
            # before writing, silently dropping the first frame of the video.
            success, image = vidcap.read()
            if not success:
                break
            frame_path = os.path.join(pathOut, f"frame{count}.jpg")
            resized_image = cv2.resize(image, size, interpolation=cv2.INTER_AREA)
            cv2.imwrite(frame_path, resized_image)
            print(f'Saved frame {count} to {frame_path}')
            count += 1
    finally:
        vidcap.release()
    return count
# Function to create a new video based on a folder of frames
def rebuildVideo(framesPath, outputPath, fps=30):
    """Encode the ``frameN.jpg`` files in ``framesPath`` into an MP4 video.

    Frames are ordered by their numeric index ``N`` (so ``frame10`` follows
    ``frame9``); the first frame fixes the output resolution. Frames that
    OpenCV cannot read are skipped.

    Args:
        framesPath: Directory containing ``frame<N>.jpg`` files.
        outputPath: Destination video file (``mp4v`` FourCC).
        fps: Frame rate of the output video.

    Raises:
        ValueError: If there are no ``.jpg`` frames or the first is unreadable.
    """
    frame_files = sorted(
        (f for f in os.listdir(framesPath) if f.endswith('.jpg')),
        key=lambda name: int(name[len("frame"):-len(".jpg")]),
    )
    if not frame_files:
        raise ValueError(f"No .jpg frames found in {framesPath!r}")

    first_frame = cv2.imread(os.path.join(framesPath, frame_files[0]))
    if first_frame is None:
        raise ValueError(f"Could not read first frame {frame_files[0]!r}")
    height, width = first_frame.shape[:2]

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(outputPath, fourcc, fps, (width, height))
    try:
        for file in frame_files:
            frame = cv2.imread(os.path.join(framesPath, file))
            # cv2.imread returns None instead of raising on unreadable files.
            if frame is None:
                continue
            out.write(frame)
    finally:
        # Always finalize the container, even if a write fails part-way.
        out.release()
def calculate_ssim(img1, img2):
    """Return the mean structural similarity (SSIM) of two images in grayscale.

    Args:
        img1, img2: Same-sized images. 3-channel inputs are treated as BGR
            (OpenCV order) and converted to grayscale; 2-D inputs are assumed
            to be grayscale already and used as-is.

    Returns:
        The mean SSIM score.
    """
    def _to_gray(img):
        return img if img.ndim == 2 else cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Only the mean score is used, so don't request ``full=True``: that also
    # builds the per-pixel SSIM map, which was computed and then discarded.
    return ssim(_to_gray(img1), _to_gray(img2))
def plot_metrics(metrics):
    """Plot a per-frame SSIM curve and return the matplotlib figure.

    Args:
        metrics: Sequence of SSIM scores; the x axis is the list index.
    """
    figure, axes = plt.subplots()
    axes.plot(metrics, label="SSIM")
    axes.set(xlabel="Frame", ylabel="SSIM", title="SSIM over Frames")
    axes.legend()
    axes.grid(True)
    return figure
def process_frame(imageO, imageF):
    """Run the model on one frame pair and return its two outputs as uint8.

    Args:
        imageO: Frame fed to the model as its first (secret, ``S``) input.
        imageF: Frame fed to the model as its second (cover, ``C``) input.

    Returns:
        ``(decoded_S, decoded_C)``: channels 0-2 and 3-5 of the prediction,
        scaled to uint8.
    """
    img_S = np.expand_dims(preprocess_image(imageO), axis=0)
    img_C = np.expand_dims(preprocess_image(imageF), axis=0)
    decoded = model.predict([img_S, img_C])
    # Clip before casting so outputs slightly outside [0, 1] saturate instead
    # of wrapping around in the uint8 conversion.
    decoded_S = np.clip(decoded[0, ..., 0:3] * 255, 0, 255).astype(np.uint8)
    decoded_C = np.clip(decoded[0, ..., 3:6] * 255, 0, 255).astype(np.uint8)
    return decoded_S, decoded_C
def steganography_video(video_path1, video_path2, fps=20):
    """Run the model frame by frame over two videos and rebuild two videos.

    Frame ``i`` of ``video_path1`` is fed as the model's first (S) input and
    frame ``i`` of ``video_path2`` as its second (C) input. Processing stops
    at the end of the shorter video.

    Args:
        video_path1: Video whose frames are the first model input.
        video_path2: Video whose frames are the second model input.
        fps: Frame rate of the two rebuilt output videos.

    Returns:
        ``(output_video_path, output_video_path2, ssim_scores, ssim_scores2)``
        with per-frame SSIM of each output against its corresponding input.

    Raises:
        ValueError: If no frame pair could be processed.
    """
    input_frames_path = "Frames1"
    input_frames_path2 = "Frames2"
    output_frames_path = "Frames3"
    output_frames_path2 = "Frames4"
    output_video_path = "output_video.mp4"
    output_video_path2 = "output_video2.mp4"

    extractImages(video_path1, input_frames_path)
    extractImages(video_path2, input_frames_path2)
    input_frame_files = sorted(
        (f for f in os.listdir(input_frames_path) if f.endswith('.jpg')),
        key=lambda x: int(x[5:-4]),
    )
    clear_folder(output_frames_path)
    clear_folder(output_frames_path2)

    # NOTE(review): cv2 frames are BGR while Gradio image uploads are RGB, so
    # the model sees different channel orders per tab — confirm which one it
    # was trained on. Frames also pass through lossy JPEG/mp4v compression.
    ssim_scores = []
    ssim_scores2 = []
    for i, file in enumerate(input_frame_files):
        frame = cv2.imread(os.path.join(input_frames_path, file))
        # cv2.imread returns None (it does not raise) when the file is
        # missing, which is how a shorter second video is detected.
        frame2 = cv2.imread(os.path.join(input_frames_path2, f"frame{i}.jpg"))
        if frame2 is None:
            print("Second video is too short, will be cut up to the length of the first one")
            break
        decoded_S, decoded_C = process_frame(frame, frame2)
        cv2.imwrite(os.path.join(output_frames_path, file), decoded_S)
        cv2.imwrite(os.path.join(output_frames_path2, file), decoded_C)
        ssim_scores.append(calculate_ssim(frame, decoded_S))
        ssim_scores2.append(calculate_ssim(frame2, decoded_C))

    if not ssim_scores:
        raise ValueError("No frames could be processed; check that both videos are readable.")

    rebuildVideo(output_frames_path, output_video_path, fps=fps)
    rebuildVideo(output_frames_path2, output_video_path2, fps=fps)
    return output_video_path, output_video_path2, ssim_scores, ssim_scores2
# Paths of bundled example media, relative to the working directory.
# NOTE(review): none of these names is referenced elsewhere in the visible
# code (no gr.Examples uses them) — wire them up or remove them.
example_cover_image = "Examples/cover.jpg"
example_secret_image = "Examples/secret.jpg"
example_cover_video = "Examples/cover.mp4"
example_secret_video = "Examples/secret.mp4"
with gr.Blocks() as demo:
    with gr.Tab("Image Processing"):
        # NOTE(review): steganography_image feeds its FIRST input to the
        # model's S (secret) slot and returns the S output first, while these
        # components are labelled "Cover" — confirm the labels are intended.
        image_input1 = gr.Image(label="Cover Image")
        image_input2 = gr.Image(label="Secret Image")
        image_output1 = gr.Image(label="Decoded Cover Image")
        image_output2 = gr.Image(label="Decoded Secret Image")
        plot = gr.Plot(label = "Noise behind each image")
        btn_image = gr.Button("Process Images")
        btn_image.click(
            fn=steganography_image,
            inputs=[image_input1, image_input2],
            outputs=[image_output1, image_output2, plot]
        )

    with gr.Tab("Video Processing"):
        video_input = gr.Video(label="Input Cover Video")
        video_input2 = gr.Video(label="Input Secret Video")
        video_output = gr.Video(label="Output Cover Video")
        video_output2 = gr.Video(label="Output Secret Video")
        plot_output = gr.Plot(label="SSIM over Frames for Cover")
        plot_output2 = gr.Plot(label="SSIM over Frames for Secret")
        btn_video = gr.Button("Process Video")

        def process_video_and_plot(video_path, video_path2):
            """Process both videos and return the outputs plus two SSIM plots."""
            out_path, out_path2, ssim_scores, ssim_scores2 = steganography_video(video_path, video_path2)
            # The figures are handed to Gradio for rendering; calling
            # .show() on a headless server does nothing useful (it only warns).
            return out_path, out_path2, plot_metrics(ssim_scores), plot_metrics(ssim_scores2)

        btn_video.click(
            fn=process_video_and_plot,
            inputs=[video_input, video_input2],
            outputs=[video_output, video_output2, plot_output, plot_output2]
        )

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch(debug=True, share = True)