tensorkelechi commited on
Commit
46a90f0
·
verified ·
1 Parent(s): 8db0202

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -0
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from PIL import Image
4
+ import matplotlib.pyplot as plt
5
+ from safetensors.torch import load_model
6
+ from transformers import pipeline
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as func_nn
10
+ from einops import rearrange
11
+ from huggingface_hub import PyTorchModelHubMixin
12
+ from torchvision import models
13
+
14
+
15
+ # main model network
16
+ class SiameseNetwork(nn.Module, PyTorchModelHubMixin):
17
+ def __init__(self):
18
+ super().__init__()
19
+
20
+ # convolutional layer/block
21
+ # self.convnet = MobileNet()
22
+ self.convnet = models.mobilenet_v2(pretrained=True) # pretrained backbone
23
+ num_ftrs = self.convnet.classifier[1].in_features # get the first deimnesion of model head
24
+
25
+ self.convnet.classifier[1] = nn.Linear(num_ftrs, 512) # change/switch backbone linear head
26
+
27
+ # fully connected layer for classification
28
+ self.fc_linear = nn.Sequential(
29
+ nn.Linear(512, 128),
30
+ nn.ReLU(inplace=True), # actvation layer
31
+ nn.Linear(128, 2)
32
+ )
33
+
34
+ def single_pass(self, x) -> torch.Tensor:
35
+ # sinlge Forward pass for each image
36
+ x = rearrange(x, 'b h w c -> b c h w') # rearrange to (batch, channels, height, width) to match model input
37
+ output = self.convnet(x)
38
+ output = self.fc_linear(output)
39
+
40
+ return output
41
+
42
+ def forward(self, input_1: torch.Tensor, input_2: torch.Tensor) -> torch.Tensor:
43
+
44
+ # forward pass of first image
45
+ output_1 = self.single_pass(input_1)
46
+
47
+ # forward pass of second contrast image
48
+ output_2 = self.single_pass(input_2)
49
+
50
+ return output_1, output_2
51
+
52
+
53
+ # pretrained model file
54
+ model_file = 'best_signature_mobilenet.safetensors' #config.safetensor_file
55
+
56
+
57
+ # Function to compute similarity
58
+ def compute_similarity(output1, output2):
59
+ return torch.nn.functional.cosine_similarity(output1, output2).item()
60
+
61
+ # Function to visualize feature heatmaps
62
+ def visualize_heatmap(model, image):
63
+ model.eval()
64
+ x = image.unsqueeze(0) # remove batch dimension
65
+ features = model.convnet(x) # feature heatmap learnt by model
66
+ heatmap = torch.mean(features, dim=1).squeeze().detach().numpy() # normalize heatmap to ndarray
67
+ plt.imshow(heatmap, cmap="hot") # display heatmap as plot
68
+ plt.axis("off")
69
+
70
+ return plt
71
+
72
+ # Load the pre-trained model from safeetesor file
73
+ def load_pipeline(model_id=):
74
+ model_id = 'tensorkelechi/signature_mobilenet'
75
+ # model = SiameseNetwork() # model class/skeleton
76
+
77
+ # model.load_state_dict(torch.load(model_file))
78
+ model = pipeline('image-classification', model=model_id, device='cpu')
79
+ model.eval()
80
+
81
+ return model
82
+
83
+ # Streamlit app UI template
84
+ st.title("Signature Forgery Detection")
85
+ st.write('Application to run/test signature forgery detecton model')
86
+
87
+ st.subheader('Compare signatures')
88
+ # File uploaders for the two images
89
+ original_image = st.file_uploader(
90
+ "Upload the original signature", type=["png", "jpg", "jpeg"]
91
+ )
92
+ comparison_image = st.file_uploader(
93
+ "Upload the signature to compare", type=["png", "jpg", "jpeg"]
94
+ )
95
+
96
+ def run_model_pipeline(model, original_image, comparison_image, threshold=0.5):
97
+ if original_image is not None and comparison_image is not None: # ensure both images are uploaded
98
+
99
+ # Preprocess images
100
+ img1 = Image.open(original_image).convert("RGB") # load images from file paths to PIL Image
101
+ img2 = Image.open(comparison_image).convert("RGB")
102
+
103
+ # read/reshape and normalize as numpy array
104
+ img1 = read_image(img1)
105
+ img2 = read_image(img2)
106
+
107
+ # convert to tensors and add batch dimensions to match model input shape
108
+ img1_tensor = torch.unsqueeze(torch.as_tensor(img1), 0)
109
+ img2_tensor = torch.unsqueeze(torch.as_tensor(img2), 0)
110
+
111
+ # Get model embeddings/probabilites
112
+ output1, output2 = model(img1_tensor, img2_tensor)
113
+ st.success('outputs extracted')
114
+
115
+ # Compute similarity
116
+ similarity = compute_similarity(output1, output2)
117
+
118
+ # Determine if it's a forgery based on determined threshold
119
+ is_forgery = similarity < threshold
120
+
121
+ # Display results
122
+ st.subheader("Results")
123
+ st.write(f"Similarity: {similarity:.2f}")
124
+ st.write(f"Classification: {'Forgery' if is_forgery else 'Genuine'}")
125
+
126
+ # Display images
127
+ col1, col2 = st.columns(2) # GUI columns
128
+
129
+ with col1:
130
+ st.image(img1, caption="Original Signature", use_column_width=True)
131
+ with col2:
132
+ st.image(img2, caption="Comparison Signature", use_column_width=True)
133
+
134
+ # Visualize heatmaps from extracted model features
135
+ st.subheader("Feature Heatmaps")
136
+ col3, col4 = st.columns(2)
137
+ with col3:
138
+ fig1 = visualize_heatmap(model, img1_tensor)
139
+ st.pyplot(fig1)
140
+ with col4:
141
+ fig2 = visualize_heatmap(model, img2_tensor)
142
+ st.pyplot(fig2)
143
+
144
+ else:
145
+ st.write("Please upload both the original and comparison signatures.")
146
+
147
+ # Run the model pipeline if a button is clicked
148
+ if st.button("Run Model Pipeline"):
149
+ model = load_pipeline()
150
+
151
+ # button click to process images
152
+ if st.button("Process Images"):
153
+ run_model_pipeline(model, original_image, comparison_image)