|
import streamlit as st |
|
import torch |
|
import torch.nn as nn |
|
from PIL import Image |
|
import albumentations as A |
|
from albumentations.pytorch import ToTensorV2 |
|
import timm |
|
import numpy as np |
|
import onnxruntime as ort |
|
|
|
import warnings |
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
class SiameseNetwork(nn.Module):
    """Twin-tower embedding network: one shared encoder applied to both inputs.

    Architecture: a timm backbone (created with ``num_classes=0`` so it returns
    pooled features instead of class logits) followed by a small MLP head
    ``num_features -> 256 -> ReLU -> 128``.

    NOTE(review): nothing in this file instantiates this class -- the app runs
    an ONNX export via ``load_model`` -- so it presumably documents the exported
    architecture; confirm before removing.
    """

    def __init__(self, model_name='resnet18', pretrained=True):
        """Build the shared backbone and the projection head.

        Args:
            model_name: timm model identifier for the shared encoder.
            pretrained: Whether timm should load pretrained backbone weights.
        """
        super().__init__()
        self.encoder = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        feature_width = self.encoder.num_features
        # The attribute names (`encoder`, `fc`) and the Sequential's layer order
        # determine state_dict keys, so they must stay stable for checkpoints.
        self.fc = nn.Sequential(
            nn.Linear(feature_width, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128),
        )

    def forward_once(self, x):
        """Embed one batch ``x``: backbone features projected to 128 dims."""
        return self.fc(self.encoder(x))

    def forward(self, img1, img2):
        """Embed both inputs with the *same* weights; return ``(emb1, emb2)``."""
        emb1 = self.forward_once(img1)
        emb2 = self.forward_once(img2)
        return emb1, emb2
|
|
|
|
|
@st.cache_resource
def load_model(backbone='efficientnet_b2'):
    """Load the exported Siamese ONNX model and return an inference session.

    ``st.cache_resource`` caches the returned session, so the model file is
    read once and reused across Streamlit reruns instead of on every rerun.

    Args:
        backbone: Encoder name the model was exported with; selects the file
            ``weights/model_{backbone}.onnx`` (path relative to the working
            directory). Defaults to the previously hard-coded backbone.

    Returns:
        An ``onnxruntime.InferenceSession`` for that file.
    """
    onnx_file_path = f"weights/model_{backbone}.onnx"
    # Pass execution providers explicitly: since onnxruntime 1.9, builds that
    # ship more than one provider (e.g. onnxruntime-gpu) raise a ValueError when
    # `providers` is omitted. get_available_providers() is ordered by priority,
    # and on a CPU-only build it yields the same default the original used.
    session = ort.InferenceSession(
        onnx_file_path,
        providers=ort.get_available_providers(),
    )
    return session
|
|
|
|
|
def _as_numpy(x):
    """Return ``x`` as a NumPy array; CPU torch tensors go through ``.numpy()``."""
    return x.numpy() if hasattr(x, "numpy") else np.asarray(x)


def inference_with_real_set(session, img, real_imgs, threshold=0.5):
    """Classify a query image as "Real" or "Fake" against a set of references.

    The query is paired with each reference and fed to the model's first two
    inputs. The Euclidean distance between the model's first two outputs is
    recorded for each pair. The query is "Real" when the mean distance is
    strictly below ``threshold``.

    Args:
        session: ``onnxruntime.InferenceSession``, or any object with the same
            ``get_inputs()`` / ``run()`` interface, taking (query, reference).
        img: Preprocessed query image (CPU ``torch.Tensor`` or array-like).
        real_imgs: Non-empty iterable of preprocessed reference images, in the
            same form as ``img``.
        threshold: Mean-distance cut-off; strictly below it means "Real".

    Returns:
        ``"Real"`` or ``"Fake"``.

    Raises:
        ValueError: If ``real_imgs`` is empty (the mean is undefined).
    """
    real_imgs = list(real_imgs)
    if not real_imgs:
        raise ValueError("inference_with_real_set requires at least one reference image")

    # Resolve input names once rather than querying the session twice per pair.
    model_inputs = session.get_inputs()
    query_name = model_inputs[0].name
    ref_name = model_inputs[1].name
    # The query is identical for every pair, so convert it only once.
    query = _as_numpy(img)

    distances = []
    for real_img in real_imgs:
        outputs = session.run(None, {query_name: query, ref_name: _as_numpy(real_img)})
        distances.append(np.linalg.norm(outputs[0] - outputs[1]))

    avg_distance = sum(distances) / len(distances)
    return "Real" if avg_distance < threshold else "Fake"
|
|
|
|
|
def preprocess_image(image):
    """Turn an RGB image into a batched model input tensor.

    Steps: convert to a NumPy array scaled by 1/255, resize to 256x256, apply
    ``A.Normalize`` with mean=std=0.5 per channel, convert to a CHW tensor with
    ``ToTensorV2``, then prepend a batch axis of size 1.

    Args:
        image: A PIL image (the caller passes ``.convert('RGB')`` output).

    Returns:
        A tensor of shape (1, C, 256, 256).
    """
    pipeline = A.Compose(
        [
            A.Resize(256, 256),
            # NOTE(review): albumentations' Normalize documents a default
            # max_pixel_value=255 that scales mean/std, yet the input below is
            # already divided by 255. If so, pixels collapse to about [-1, -0.99].
            # Only keep this if training used the identical pipeline; confirm.
            A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            ToTensorV2(),
        ]
    )
    scaled = np.array(image) / 255.0
    tensor = pipeline(image=scaled)["image"]
    return tensor.unsqueeze(0)
|
|
|
|
|
# Page-level CSS, injected once below via st.markdown(..., unsafe_allow_html=True).
# - `.stAlert {display: none;}` is presumably meant (per the name `hide_warning`)
#   to hide Streamlit's alert boxes, such as deprecation banners.
#   NOTE(review): if this selector matches Streamlit's alert container, it
#   also hides any st.error / st.warning / st.info / st.success the app shows.
#   Confirm that is acceptable.
# - `.css-18e3th9` / `.css-1d391kg` look like auto-generated class names.
#   NOTE(review): such generated names typically change between Streamlit
#   releases, so these two rules may match nothing on the installed version.
hide_warning = """

<style>

.stAlert {display: none;}

.css-18e3th9 {padding-top: 2rem; padding-bottom: 2rem;} /* Adjust top and bottom padding */

.css-1d391kg {max-width: 100% !important; padding-left: 1rem; padding-right: 1rem;} /* Adjust the max-width and side padding */

</style>

"""

# unsafe_allow_html is needed so the raw <style> tag is not escaped as text.
st.markdown(hide_warning, unsafe_allow_html=True)
|
|
|
|
|
# --- Page header and model ----------------------------------------------------
st.title("Forge Signature Siamese")

# load_model is wrapped in @st.cache_resource, so after the first run this
# returns the already-created ONNX session.
model = load_model()

# Two equal-width columns: inputs go in col1, the verdict is rendered in col2.
col1, col2 = st.columns([1, 1])

with col1:
    # The signature to be classified.
    uploaded_image = st.file_uploader(
        "Upload an image to verify",
        type=["png", "jpg", "jpeg"],
    )

    # One or more known-genuine signatures to compare against.
    uploaded_real_images = st.file_uploader(
        "Upload real reference images",
        type=["png", "jpg", "jpeg"],
        accept_multiple_files=True,
    )

    # Mean embedding distance strictly below this value is classified "Real".
    threshold = st.slider("Threshold", min_value=0.0, max_value=1.0, value=0.5)
|
|
|
|
|
# --- Verification and results ---------------------------------------------------
# Streamlit reruns this script on every interaction; only run the model once a
# query image and at least one reference image have been uploaded.
if uploaded_image is not None and uploaded_real_images:
    img = preprocess_image(Image.open(uploaded_image).convert('RGB'))
    # Use a distinct loop name (`ref_file`) instead of reusing `img`, so the
    # query tensor above is not visually shadowed by an uploaded file object.
    real_imgs = [
        preprocess_image(Image.open(ref_file).convert('RGB'))
        for ref_file in uploaded_real_images
    ]

    result = inference_with_real_set(model, img, real_imgs, threshold)

    with col2:
        # Green badge for "Real", red badge for anything else.
        is_real = result == "Real"
        button_label = "Real" if is_real else "Fake"
        background = "green" if is_real else "red"
        button_color = f"background-color: {background}; color: white; font-weight: bold;"

        # Build the badge HTML as a single line. In Markdown, a raw-HTML <div>
        # block ends at the first blank line, and a following line indented by
        # 4+ spaces renders as a code block. A multi-line, indented literal could
        # therefore show up as text instead of a button.
        badge_html = (
            '<div style="display: flex; justify-content: center; margin-top: 20px;">'
            f'<button style="padding: 10px 20px; font-size: 18px; {button_color}">'
            f'{button_label}</button>'
            '</div>'
        )
        st.markdown(badge_html, unsafe_allow_html=True)

        # Side-by-side previews: the query on the left, references on the right.
        # NOTE(review): these columns are nested one level inside col2; confirm
        # this is the intended layout.
        sub_col1, sub_col2 = st.columns([1, 1])

        with sub_col1:
            st.write("\n")
            st.write("**Input Image**")
            # NOTE(review): recent Streamlit releases deprecate
            # `use_column_width` in favour of `use_container_width`. It is kept
            # here for compatibility with older Streamlit versions.
            st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)

        with sub_col2:
            st.write("\n")
            st.write("**Real Reference Images**")
            for real_image in uploaded_real_images:
                st.image(real_image, caption="Real Image", use_column_width=True)
|
|
|
|
|
|
|
|
|
|