signforge / app.py
tensorkelechi's picture
Update app.py
8e7b01a verified
raw
history blame
5.44 kB
import streamlit as st
import torch
from PIL import Image
import matplotlib.pyplot as plt
from safetensors.torch import load_model
from transformers import pipeline
import torch
from torch import nn
from torch.nn import functional as func_nn
from einops import rearrange
from huggingface_hub import PyTorchModelHubMixin
from torchvision import models
# main model network
class SiameseNetwork(nn.Module, PyTorchModelHubMixin):
    """Siamese network that runs two images through one shared set of weights.

    Each image passes through a MobileNetV2 backbone whose final classifier layer
    is replaced by a 512-unit linear projection, then through a small MLP that
    produces a 2-value output per image. ``forward`` returns the outputs for both
    inputs so the caller can compare them.
    """

    def __init__(self):
        super().__init__()
        # Shared convolutional backbone with ImageNet weights.
        # NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in
        # favour of `weights=models.MobileNet_V2_Weights.IMAGENET1K_V1`; kept as-is
        # for compatibility with older torchvision installs.
        self.convnet = models.mobilenet_v2(pretrained=True)
        # Number of input features of the backbone's last linear layer (classifier[1]).
        num_ftrs = self.convnet.classifier[1].in_features
        # Swap the ImageNet classification layer for a 512-d projection.
        self.convnet.classifier[1] = nn.Linear(num_ftrs, 512)
        # Fully connected head applied to the 512-d backbone output.
        self.fc_linear = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),  # activation layer
            nn.Linear(128, 2),
        )

    def single_pass(self, x: torch.Tensor) -> torch.Tensor:
        """Run one batch of images through the backbone and the head.

        Args:
            x: image batch laid out channels-last as (batch, height, width, channels);
               it is rearranged to (batch, channels, height, width) for the backbone.

        Returns:
            Tensor of shape (batch, 2).
        """
        x = rearrange(x, 'b h w c -> b c h w')
        output = self.convnet(x)
        output = self.fc_linear(output)
        return output

    def forward(
        self, input_1: torch.Tensor, input_2: torch.Tensor
    ) -> "tuple[torch.Tensor, torch.Tensor]":
        """Process both images with the shared weights.

        Args:
            input_1: first (reference) image batch, channels-last.
            input_2: second (contrast) image batch, channels-last.

        Returns:
            ``(output_1, output_2)``, one (batch, 2) tensor per input.
        """
        output_1 = self.single_pass(input_1)
        output_2 = self.single_pass(input_2)
        return output_1, output_2
# Checkpoint file holding the Siamese network's weights (safetensors format);
# mirrors config.safetensor_file.
model_file = "best_signature_mobilenet.safetensors"
# Function to compute similarity
def compute_similarity(output1, output2):
    """Return the cosine similarity of two outputs as a Python float.

    Similarity is computed along dim 1, which is torch's default. ``.item()``
    then requires the result to contain exactly one value, i.e. both inputs
    must have a batch size of 1.
    """
    cosine = func_nn.cosine_similarity(output1, output2, dim=1)
    return cosine.item()
# Function to visualize feature heatmaps
def visualize_heatmap(model, image):
    """Render the backbone's spatial feature activations as a heatmap.

    Args:
        model: a ``SiameseNetwork``-like module whose ``convnet`` is a MobileNetV2,
            i.e. it exposes a convolutional ``features`` stage.
        image: channels-last image tensor, either (height, width, channels) or
            already batched as (1, height, width, channels). This is the same
            layout ``SiameseNetwork.single_pass`` expects.

    Returns:
        A new matplotlib ``Figure`` holding the heatmap of the first image,
        suitable for ``st.pyplot``.
    """
    # Local import: a standalone Figure is not registered with pyplot, so figures
    # created on every Streamlit rerun are not accumulated in global state.
    from matplotlib.figure import Figure

    model.eval()
    # Accept both a single image and a batch of one. Unconditionally adding a
    # batch axis turned an already-batched input into a 5-D tensor.
    x = image.unsqueeze(0) if image.dim() == 3 else image
    # The backbone expects channels-first input, mirroring single_pass().
    x = rearrange(x, 'b h w c -> b c h w')
    with torch.no_grad():
        # Use the convolutional stage, not the whole convnet. The whole convnet
        # ends in a pooled linear head and yields a flat vector, which cannot be
        # drawn as an image.
        features = model.convnet.features(x)
    # Average over channels to get a 2-D activation map for the first image.
    heatmap = features.mean(dim=1)[0].cpu().numpy()
    fig = Figure()
    ax = fig.subplots()
    ax.imshow(heatmap, cmap="hot")
    ax.axis("off")
    return fig
# Load the pre-trained model from a safetensors file
def load_pipeline(model_file='best_signature_mobilenet.safetensors'):
    """Build a SiameseNetwork, load its trained weights, and prepare it for CPU inference.

    Args:
        model_file: path to the ``.safetensors`` checkpoint with the model's weights.

    Returns:
        The ``SiameseNetwork`` instance in eval mode, on the CPU.
    """
    model = SiameseNetwork()  # model class/skeleton
    # safetensors' load_model fills `model` in place and returns the
    # (missing, unexpected) key lists, not the model. Assigning its result to
    # `model` made the eval() call below fail on a tuple.
    load_model(model, model_file)
    model.eval()
    return model.to('cpu')
# Streamlit app UI template
st.title("Signature Forgery Detection")
st.write('Application to run/test signature forgery detection model')
st.subheader('Compare signatures')

# Image formats accepted by both uploaders.
_SIGNATURE_IMAGE_TYPES = ["png", "jpg", "jpeg"]

# File uploaders for the two images (each is None until a file is uploaded)
original_image = st.file_uploader(
    "Upload the original signature", type=_SIGNATURE_IMAGE_TYPES
)
comparison_image = st.file_uploader(
    "Upload the signature to compare", type=_SIGNATURE_IMAGE_TYPES
)
def _prepare_image(pil_image, size=(224, 224)):
    """Convert a PIL RGB image into a float32 (height, width, 3) array scaled to [0, 1].

    NOTE(review): the 224x224 resize and plain [0, 1] scaling are assumptions.
    Confirm they match the preprocessing used when the checkpoint was trained.
    """
    import numpy as np  # local import: numpy is not imported at module level

    resized = pil_image.resize(size)
    return np.asarray(resized, dtype=np.float32) / 255.0


def run_model_pipeline(model, original_image, comparison_image, threshold=0.5):
    """Compare two uploaded signatures and render the results in the Streamlit page.

    Args:
        model: Siamese model called as ``model(batch_1, batch_2)`` and returning
            one output per input (e.g. the network returned by ``load_pipeline``).
        original_image: uploaded file for the reference signature, or None.
        comparison_image: uploaded file for the signature to check, or None.
        threshold: cosine similarity below which the pair is labelled 'Forgery'.
    """
    if original_image is None or comparison_image is None:
        st.write("Please upload both the original and comparison signatures.")
        return

    # Decode the uploads and convert them to channels-last float arrays.
    img1 = _prepare_image(Image.open(original_image).convert("RGB"))
    img2 = _prepare_image(Image.open(comparison_image).convert("RGB"))

    # Add a batch dimension: (1, height, width, 3), the layout single_pass expects.
    img1_tensor = torch.as_tensor(img1).unsqueeze(0)
    img2_tensor = torch.as_tensor(img2).unsqueeze(0)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output1, output2 = model(img1_tensor, img2_tensor)
    st.success('outputs extracted')

    similarity = compute_similarity(output1, output2)
    is_forgery = similarity < threshold

    st.subheader("Results")
    st.write(f"Similarity: {similarity:.2f}")
    st.write(f"Classification: {'Forgery' if is_forgery else 'Genuine'}")

    # Show the (preprocessed) input images side by side.
    col1, col2 = st.columns(2)
    with col1:
        st.image(img1, caption="Original Signature", use_column_width=True)
    with col2:
        st.image(img2, caption="Comparison Signature", use_column_width=True)

    # Visualize heatmaps from extracted model features
    st.subheader("Feature Heatmaps")
    col3, col4 = st.columns(2)
    with col3:
        fig1 = visualize_heatmap(model, img1_tensor)
        st.pyplot(fig1)
    with col4:
        fig2 = visualize_heatmap(model, img2_tensor)
        st.pyplot(fig2)
# Streamlit re-executes this whole script on every widget interaction, so a model
# assigned in one run no longer exists in the next. Clicking "Process Images"
# used to hit an undefined `model`. Keep the model in session_state so it
# survives across reruns.
if st.button("Run Model Pipeline"):
    st.session_state["model"] = load_pipeline()
# button click to process images
if st.button("Process Images"):
    # Load on demand if the model has not been loaded in this session yet.
    if "model" not in st.session_state:
        st.session_state["model"] = load_pipeline()
    run_model_pipeline(st.session_state["model"], original_image, comparison_image)