Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
from safetensors.torch import load_model | |
from transformers import pipeline | |
import torch | |
from torch import nn | |
from torch.nn import functional as func_nn | |
from einops import rearrange | |
from huggingface_hub import PyTorchModelHubMixin | |
from torchvision import models | |
# main model network
class SiameseNetwork(nn.Module, PyTorchModelHubMixin):
    """Siamese network that runs two images through one shared backbone + head.

    The backbone is torchvision's MobileNetV2 (``pretrained=True``) with its final
    classifier Linear layer replaced by a 512-unit projection; a small MLP then
    maps that 512-d vector to 2 output values per image. Mixing in
    ``PyTorchModelHubMixin`` gives the class ``from_pretrained`` /
    ``save_pretrained`` / ``push_to_hub``.
    """

    def __init__(self):
        super().__init__()
        # convolutional layer/block: pretrained MobileNetV2 backbone
        self.convnet = models.mobilenet_v2(pretrained=True)
        # classifier[1] is the backbone's final Linear layer; read its input width
        num_ftrs = self.convnet.classifier[1].in_features
        # swap the backbone's linear head for a 512-dimensional projection
        self.convnet.classifier[1] = nn.Linear(num_ftrs, 512)
        # fully connected layers mapping the 512-d backbone output to 2 values
        self.fc_linear = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),  # activation layer
            nn.Linear(128, 2),
        )

    def single_pass(self, x) -> torch.Tensor:
        """Run one batch of images through the shared backbone and head.

        Args:
            x: image batch in channels-last layout ``(batch, height, width, channels)``.

        Returns:
            Tensor with one row of 2 values per image in the batch.
        """
        # torchvision models expect channels-first input: (batch, channels, height, width)
        x = rearrange(x, 'b h w c -> b c h w')
        output = self.convnet(x)
        output = self.fc_linear(output)
        return output

    def forward(
        self, input_1: torch.Tensor, input_2: torch.Tensor
    ) -> "tuple[torch.Tensor, torch.Tensor]":
        """Embed both images with the same weights.

        Args:
            input_1: first (reference) image batch, channels-last.
            input_2: second (contrast) image batch, channels-last.

        Returns:
            ``(output_1, output_2)`` — the ``single_pass`` output for each input.
        """
        # forward pass of first image
        output_1 = self.single_pass(input_1)
        # forward pass of second contrast image
        output_2 = self.single_pass(input_2)
        return output_1, output_2
# Path of the local safetensors checkpoint expected to hold the trained
# SiameseNetwork weights (an earlier version read this from config.safetensor_file).
model_file = "best_signature_mobilenet.safetensors"
# Function to compute similarity
def compute_similarity(output1, output2):
    """Return the cosine similarity between two model outputs as a Python float.

    Similarity is taken along dim 1. ``.item()`` requires the result to hold a
    single element, i.e. both inputs are expected to contain one row each.
    """
    cosine = torch.nn.functional.cosine_similarity(output1, output2, dim=1)
    return cosine.item()
# Function to visualize feature heatmaps
def visualize_heatmap(model, image):
    """Plot a spatial activation heatmap from the model's convolutional backbone.

    Args:
        model: a SiameseNetwork-style module whose ``convnet`` is a torchvision
            MobileNetV2, so that ``model.convnet.features`` is the conv feature
            extractor.
        image: image tensor in channels-last layout — either
            ``(height, width, channels)`` or ``(1, height, width, channels)``,
            the same layout ``SiameseNetwork.single_pass`` consumes.

    Returns:
        A new matplotlib Figure holding the heatmap (suitable for ``st.pyplot``).
    """
    model.eval()
    # ensure a batch dimension without adding a second one to already-batched input
    x = image if image.dim() == 4 else image.unsqueeze(0)
    # channels-last -> channels-first: (b, h, w, c) -> (b, c, h, w)
    x = x.permute(0, 3, 1, 2).float()
    with torch.no_grad():
        # Use only the convolutional feature extractor: the full ``convnet`` ends in
        # the linear head and returns a flat vector with no spatial layout to plot.
        features = model.convnet.features(x)
    # average over channels and take the first image -> 2-D (h', w') map
    heatmap = features.mean(dim=1)[0].cpu().numpy()
    # draw on a fresh figure so successive calls do not overdraw one another
    fig, ax = plt.subplots()
    ax.imshow(heatmap, cmap="hot")
    ax.axis("off")
    return fig
# Load the pre-trained model from a safetensors file (or the Hugging Face Hub)
def load_pipeline(model_id='tensorkelechi/signature_mobilenet', weights_file=None):
    """Build a SiameseNetwork with trained weights, switched to eval mode.

    The rest of the app calls the returned object as ``model(img1, img2)`` and
    reads ``model.convnet``, so this must return a SiameseNetwork — not a
    transformers ``pipeline``, which supports neither (nor ``.eval()``).

    Args:
        model_id: Hugging Face Hub repo id, used when no local checkpoint exists.
        weights_file: path to a local safetensors checkpoint; defaults to the
            module-level ``model_file``.

    Returns:
        The loaded SiameseNetwork in evaluation mode.
    """
    import os

    if weights_file is None:
        weights_file = model_file
    if os.path.isfile(weights_file):
        model = SiameseNetwork()  # model class/skeleton
        load_model(model, weights_file)  # copy safetensors weights into the skeleton
    else:
        # NOTE(review): assumes the Hub repo was written by
        # PyTorchModelHubMixin.save_pretrained/push_to_hub — confirm.
        model = SiameseNetwork.from_pretrained(model_id)
    model.eval()  # inference mode (e.g. dropout disabled)
    return model
# Streamlit app UI template
st.title("Signature Forgery Detection")
st.write('Application to run/test signature forgery detection model')
st.subheader('Compare signatures')
# File uploaders for the two images; each returns None until a file is uploaded
original_image = st.file_uploader(
    "Upload the original signature", type=["png", "jpg", "jpeg"]
)
comparison_image = st.file_uploader(
    "Upload the signature to compare", type=["png", "jpg", "jpeg"]
)
def _read_image(img, size=(224, 224)):
    """Convert a PIL RGB image to a float32 ``(height, width, 3)`` array in [0, 1].

    NOTE(review): the 224x224 size and plain /255 scaling are assumptions; confirm
    they match the preprocessing used when the checkpoint was trained.
    """
    import numpy as np

    resized = img.resize(size)
    return np.asarray(resized, dtype=np.float32) / 255.0


def run_model_pipeline(model, original_image, comparison_image, threshold=0.5):
    """Compare two uploaded signatures and render the verdict in the Streamlit page.

    Args:
        model: SiameseNetwork (or compatible module) called as ``model(a, b)`` and
            returning one output per input; it must also expose ``convnet`` for
            the heatmaps.
        original_image: uploaded file (or path) of the reference signature, or None.
        comparison_image: uploaded file (or path) of the signature to check, or None.
        threshold: cosine-similarity cutoff; pairs scoring below it are labeled
            "Forgery".
    """
    # ensure both images are uploaded before doing any work
    if original_image is None or comparison_image is None:
        st.write("Please upload both the original and comparison signatures.")
        return

    # load uploads as RGB PIL images; PIL raises an OSError subclass on
    # unreadable or non-image data
    try:
        pil1 = Image.open(original_image).convert("RGB")
        pil2 = Image.open(comparison_image).convert("RGB")
    except OSError as err:
        st.error(f"Could not read one of the uploaded images: {err}")
        return

    # resize/normalize, then add a batch dimension: (1, height, width, 3)
    img1_tensor = torch.as_tensor(_read_image(pil1)).unsqueeze(0)
    img2_tensor = torch.as_tensor(_read_image(pil2)).unsqueeze(0)

    # Get model embeddings/probabilities without tracking gradients
    model.eval()
    with torch.no_grad():
        output1, output2 = model(img1_tensor, img2_tensor)
    st.success('outputs extracted')

    # Compute similarity and decide based on the threshold
    similarity = compute_similarity(output1, output2)
    is_forgery = similarity < threshold

    # Display results
    st.subheader("Results")
    st.write(f"Similarity: {similarity:.2f}")
    st.write(f"Classification: {'Forgery' if is_forgery else 'Genuine'}")

    # Display the uploaded images as-is (not the resized/normalized arrays)
    col1, col2 = st.columns(2)  # GUI columns
    with col1:
        st.image(pil1, caption="Original Signature", use_column_width=True)
    with col2:
        st.image(pil2, caption="Comparison Signature", use_column_width=True)

    # Visualize heatmaps from extracted model features
    st.subheader("Feature Heatmaps")
    col3, col4 = st.columns(2)
    with col3:
        fig1 = visualize_heatmap(model, img1_tensor)
        st.pyplot(fig1)
    with col4:
        fig2 = visualize_heatmap(model, img2_tensor)
        st.pyplot(fig2)
# Streamlit re-executes this whole script on every widget interaction, and a
# button reads True only on the rerun triggered by its own click. A plain local
# `model` assigned under "Run Model Pipeline" is therefore gone (NameError) by
# the time "Process Images" is clicked, so the loaded model lives in session_state.
if st.button("Run Model Pipeline"):
    st.session_state["model"] = load_pipeline()
# button click to process images
if st.button("Process Images"):
    if "model" not in st.session_state:
        # the user skipped "Run Model Pipeline": load the model on demand
        st.session_state["model"] = load_pipeline()
    run_model_pipeline(st.session_state["model"], original_image, comparison_image)