# signforge / app.py
# tensorkelechi: "Update app.py" (commit 0d0f527, verified)
import streamlit as st
import torch
from PIL import Image
import matplotlib.pyplot as plt
from safetensors.torch import load_model
from transformers import pipeline
import torch, cv2
from torch import nn
import numpy as np
from torch.nn import functional as func_nn
from einops import rearrange
from huggingface_hub import PyTorchModelHubMixin
from torchvision import models
def read_image(img, img_size=100, bgr_input=False):
    """Resize an image to a square and scale it to floats for the model.

    Args:
        img: A PIL image or array-like accepted by ``np.array``. PIL images
            (e.g. from ``Image.open(...).convert("RGB")``) are RGB; arrays read
            with OpenCV (``cv2.imread``) are BGR and need ``bgr_input=True``.
        img_size: Side length, in pixels, of the square output.
        bgr_input: Whether ``img``'s colour channels are in BGR(A) order.

    Returns:
        An ``(img_size, img_size, 3)`` array in RGB order for grayscale, RGB
        or RGBA input, divided by 255 (so in [0, 1] for 8-bit input).
    """
    img = np.array(img)
    # cv2.resize takes (width, height); both are img_size here.
    img = cv2.resize(img, (img_size, img_size))
    if img.ndim == 2:
        # Grayscale: replicate to 3 channels so the (h, w, c) layout matches.
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[2] == 4:
        # Drop the alpha channel, fixing the channel order if needed.
        img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB if bgr_input else cv2.COLOR_RGBA2RGB)
    elif bgr_input:
        # Only swap channels when they really are BGR; swapping RGB input
        # (as the original did unconditionally) would produce BGR.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img / 255.0
# main model network
class SiameseNetwork(nn.Module, PyTorchModelHubMixin):
    """Siamese network that runs two images through one shared MobileNetV2 tower.

    Each image is passed through the same backbone (its classifier head
    replaced by a 512-unit linear layer) and then a small MLP, giving a
    2-dimensional output per image. The app compares the two outputs with
    cosine similarity.
    """

    def __init__(self):
        super().__init__()
        # Convolutional backbone. ``pretrained=True`` starts from torchvision's
        # ImageNet weights; load_pipeline later loads the trained weights from
        # a safetensors file on top of them.
        self.convnet = models.mobilenet_v2(pretrained=True)
        # Input width of the backbone's final linear layer (classifier[1]).
        num_ftrs = self.convnet.classifier[1].in_features
        # Replace the ImageNet classification head with a 512-d projection.
        self.convnet.classifier[1] = nn.Linear(num_ftrs, 512)
        # Fully connected head applied to the backbone output.
        self.fc_linear = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2),
        )

    def single_pass(self, x: torch.Tensor) -> torch.Tensor:
        """Run one batch of images through the shared tower.

        Args:
            x: Channels-last batch ``(batch, height, width, channels)``.

        Returns:
            The ``(batch, 2)`` output of ``fc_linear``.
        """
        # The torchvision backbone expects channels-first input.
        x = rearrange(x, 'b h w c -> b c h w')
        output = self.convnet(x)
        return self.fc_linear(output)

    def forward(
        self, input_1: torch.Tensor, input_2: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run both images through the same weights and return both outputs.

        Args:
            input_1: Channels-last batch of reference images.
            input_2: Channels-last batch of images to compare against them.

        Returns:
            ``(output_1, output_2)``, one ``single_pass`` output per input.
        """
        output_1 = self.single_pass(input_1)
        output_2 = self.single_pass(input_2)
        return output_1, output_2
# Name of the safetensors file that stores the trained model weights.
model_file = "model.safetensors"
# Function to compute similarity
def compute_similarity(output1, output2):
    """Return the cosine similarity of two model outputs as a Python float.

    The similarity is taken along dim 1 (the feature dimension). Because the
    result is converted with ``.item()``, the inputs must describe a single
    pair (batch size 1); larger batches make ``.item()`` fail.
    """
    cosine = torch.nn.functional.cosine_similarity(output1, output2, dim=1)
    return cosine.item()
# Function to visualize feature heatmaps
def visualize_heatmap(model, image):
    """Draw a channel-averaged heatmap of the backbone's convolutional features.

    Args:
        model: Module exposing ``convnet.features`` (e.g. SiameseNetwork).
            It is switched to eval mode as a side effect.
        image: Channels-first tensor, either ``(channels, height, width)`` or
            a batch ``(batch, channels, height, width)``; only the first
            sample of a batch is plotted.

    Returns:
        The ``matplotlib.pyplot`` module, whose current figure holds the heatmap.
    """
    model.eval()
    # Accept a single unbatched image as well as a batch.
    x = image.unsqueeze(0) if image.dim() == 3 else image
    with torch.no_grad():  # visualization only; no autograd graph needed
        features = model.convnet.features(x)
    # The conv feature maps are (batch, channels, h, w) and imshow needs a
    # 2-D array, so average the first sample's channels into one spatial map.
    heatmap = features[0].mean(dim=0).cpu().numpy()
    plt.clf()  # start from a clean figure so repeated calls don't stack images
    plt.imshow(heatmap, cmap="hot")
    plt.axis("off")
    return plt
# Load the pre-trained model from a safetensors file
def load_pipeline(model_file='model.safetensors'):
    """Build a SiameseNetwork and load its trained weights.

    Args:
        model_file: Path to the safetensors file with the model weights.

    Returns:
        The model in evaluation mode, ready for inference.
    """
    model = SiameseNetwork()
    # Copy the tensors stored in the file into the model's parameters.
    load_model(model, model_file)
    # The MobileNetV2 backbone contains dropout and batch-norm layers. In
    # training mode dropout randomly zeroes activations and batch-norm uses
    # per-batch statistics, so similarity scores would be unstable.
    model.eval()
    return model
# Streamlit app UI: title, description and the two signature uploaders.
st.title("Signature Forgery Detection")
st.write('Application to run/test signature forgery detection model')
st.subheader('Compare signatures')
# File uploaders for the two images; each yields None until a file is uploaded.
original_image = st.file_uploader(
    "Upload the original signature", type=["png", "jpg", "jpeg"]
)
comparison_image = st.file_uploader(
    "Upload the signature to compare", type=["png", "jpg", "jpeg"]
)
def run_model_pipeline(model, original_image, comparison_image, threshold=0.3):
    """Compare two signature images with the model and render the verdict.

    Args:
        model: Siamese model called with two channels-last float32 batches
            and returning one output per image.
        original_image: Uploaded file (or path) of the reference signature,
            or None if nothing was uploaded.
        comparison_image: Uploaded file (or path) of the signature to check,
            or None if nothing was uploaded.
        threshold: Cosine-similarity cut-off; pairs scoring below it are
            reported as a forgery.
    """
    # Both images are required; otherwise just prompt the user.
    if original_image is None or comparison_image is None:
        st.write("Please upload both the original and comparison signatures.")
        return

    # Load the uploads as 3-channel PIL images.
    pil1 = Image.open(original_image).convert("RGB")
    pil2 = Image.open(comparison_image).convert("RGB")
    # Resize and scale to floats as numpy arrays.
    img1 = read_image(pil1)
    img2 = read_image(pil2)
    # Add a batch dimension -> (1, h, w, c) float32, the model's input layout.
    img1_tensor = torch.as_tensor(img1, dtype=torch.float32).unsqueeze(0)
    img2_tensor = torch.as_tensor(img2, dtype=torch.float32).unsqueeze(0)
    st.success('images loaded')

    # Inference only: eval mode disables dropout and batch-norm updates, and
    # no_grad avoids building an autograd graph.
    model.eval()
    with torch.no_grad():
        output1, output2 = model(img1_tensor, img2_tensor)
    st.success('outputs extracted')

    similarity = compute_similarity(output1, output2)
    # Below the threshold the two signatures are considered too dissimilar.
    is_forgery = similarity < threshold

    st.subheader("Results")
    st.write(f"Similarity: {similarity:.2f}")
    st.write(f"Classification: {'Forgery' if is_forgery else 'Genuine'}")

    # Show the uploaded RGB images rather than the resized model inputs.
    col1, col2 = st.columns(2)
    with col1:
        st.image(pil1, caption="Original Signature", use_column_width=True)
    with col2:
        st.image(pil2, caption="Comparison Signature", use_column_width=True)

    # # Visualize heatmaps from extracted model features
    # st.subheader("Feature Heatmaps")
    # col3, col4 = st.columns(2)
    # with col3:
    #     x1 = rearrange(img1_tensor.float(), 'b h w c -> b c h w')
    #     fig1 = visualize_heatmap(model, x1)
    #     st.pyplot(fig1)
    # with col4:
    #     x2 = rearrange(img2_tensor.float(), 'b h w c -> b c h w')
    #     fig2 = visualize_heatmap(model, x2)
    #     st.pyplot(fig2)
# Run the model pipeline if a button is clicked.
if st.button("Run Model Pipeline for images"):
    if original_image is None or comparison_image is None:
        # Skip building the network and reading the weights file until
        # there is something to compare.
        st.write("Please upload both the original and comparison signatures.")
    else:
        model = load_pipeline()
        run_model_pipeline(model, original_image, comparison_image)