binhduong2310
commited on
Commit
•
67c6064
1
Parent(s):
dfbd5be
Upload 4 files
Browse files- app.py +148 -0
- requirements.txt +6 -0
- weights/best_efficientnet_b2.pth +3 -0
- weights/model_efficientnet_b2.onnx +3 -0
app.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from PIL import Image
|
5 |
+
import albumentations as A
|
6 |
+
from albumentations.pytorch import ToTensorV2
|
7 |
+
import timm
|
8 |
+
import numpy as np
|
9 |
+
import onnxruntime as ort
|
10 |
+
|
11 |
+
import warnings
|
12 |
+
warnings.filterwarnings("ignore")
|
13 |
+
|
14 |
+
# Khởi tạo mô hình
|
15 |
+
class SiameseNetwork(nn.Module):
|
16 |
+
def __init__(self, model_name='resnet18', pretrained=True):
|
17 |
+
super(SiameseNetwork, self).__init__()
|
18 |
+
self.encoder = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
|
19 |
+
|
20 |
+
self.fc = nn.Sequential(
|
21 |
+
nn.Linear(self.encoder.num_features, 256),
|
22 |
+
nn.ReLU(inplace=True),
|
23 |
+
nn.Linear(256, 128)
|
24 |
+
)
|
25 |
+
|
26 |
+
def forward_once(self, x):
|
27 |
+
output = self.encoder(x)
|
28 |
+
output = self.fc(output)
|
29 |
+
return output
|
30 |
+
|
31 |
+
def forward(self, img1, img2):
|
32 |
+
output1 = self.forward_once(img1)
|
33 |
+
output2 = self.forward_once(img2)
|
34 |
+
return output1, output2
|
35 |
+
|
36 |
+
# Load the model
|
37 |
+
@st.cache_resource
|
38 |
+
def load_model():
|
39 |
+
backbone = 'efficientnet_b2'
|
40 |
+
onnx_file_path = f"weights/model_{backbone}.onnx"
|
41 |
+
session = ort.InferenceSession(onnx_file_path)
|
42 |
+
return session
|
43 |
+
|
44 |
+
# Function to run inference
|
45 |
+
def inference_with_real_set(session, img, real_imgs, threshold=0.5):
|
46 |
+
distances = []
|
47 |
+
for real_img in real_imgs:
|
48 |
+
input_dict = {
|
49 |
+
session.get_inputs()[0].name: img.numpy(),
|
50 |
+
session.get_inputs()[1].name: real_img.numpy()
|
51 |
+
}
|
52 |
+
outputs = session.run(None, input_dict)
|
53 |
+
|
54 |
+
euclidean_distance = np.linalg.norm(outputs[0] - outputs[1])
|
55 |
+
|
56 |
+
distances.append(euclidean_distance)
|
57 |
+
|
58 |
+
avg_distance = sum(distances) / len(distances)
|
59 |
+
|
60 |
+
return "Real" if avg_distance < threshold else "Fake"
|
61 |
+
|
62 |
+
# Image preprocessing function
|
63 |
+
def preprocess_image(image):
|
64 |
+
transform = A.Compose([
|
65 |
+
A.Resize(256, 256),
|
66 |
+
A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
67 |
+
ToTensorV2(),
|
68 |
+
])
|
69 |
+
image = np.array(image)/255.0
|
70 |
+
image = transform(image=image)['image'].unsqueeze(0)
|
71 |
+
return image
|
72 |
+
|
73 |
+
# Hide Streamlit warnings and footers
|
74 |
+
hide_warning = """
|
75 |
+
<style>
|
76 |
+
.stAlert {display: none;}
|
77 |
+
.css-18e3th9 {padding-top: 2rem; padding-bottom: 2rem;} /* Adjust top and bottom padding */
|
78 |
+
.css-1d391kg {max-width: 100% !important; padding-left: 1rem; padding-right: 1rem;} /* Adjust the max-width and side padding */
|
79 |
+
</style>
|
80 |
+
"""
|
81 |
+
|
82 |
+
st.markdown(hide_warning, unsafe_allow_html=True)
|
83 |
+
|
84 |
+
# Streamlit interface
|
85 |
+
st.title("Forge Signature Siamese")
|
86 |
+
|
87 |
+
# Load the model
|
88 |
+
model = load_model()
|
89 |
+
|
90 |
+
# Chia màn hình thành 2 cột
|
91 |
+
col1, col2 = st.columns([1, 1]) # Cột bên trái để input, cột bên phải để hiển thị kết quả
|
92 |
+
|
93 |
+
with col1:
|
94 |
+
# Upload image to compare
|
95 |
+
uploaded_image = st.file_uploader("Upload an image to verify", type=["png", "jpg", "jpeg"])
|
96 |
+
|
97 |
+
# Upload real reference images
|
98 |
+
uploaded_real_images = st.file_uploader("Upload real reference images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
|
99 |
+
|
100 |
+
# Threshold slider
|
101 |
+
threshold = st.slider("Threshold", 0.0, 1.0, 0.5)
|
102 |
+
|
103 |
+
# Nếu cả hình ảnh tải lên và hình ảnh thật đã được tải lên
|
104 |
+
if uploaded_image is not None and uploaded_real_images:
|
105 |
+
# Xử lý ảnh
|
106 |
+
img = preprocess_image(Image.open(uploaded_image).convert('RGB'))
|
107 |
+
real_imgs = [preprocess_image(Image.open(img).convert('RGB')) for img in uploaded_real_images]
|
108 |
+
|
109 |
+
# Chạy inference
|
110 |
+
result = inference_with_real_set(model, img, real_imgs, threshold)
|
111 |
+
|
112 |
+
with col2:
|
113 |
+
# Hiển thị nút kết quả với màu sắc tùy chỉnh
|
114 |
+
if result == "Real":
|
115 |
+
button_color = "background-color: green; color: white; font-weight: bold;"
|
116 |
+
button_label = "Real"
|
117 |
+
else:
|
118 |
+
button_color = "background-color: red; color: white; font-weight: bold;"
|
119 |
+
button_label = "Fake"
|
120 |
+
|
121 |
+
st.markdown(
|
122 |
+
f"""
|
123 |
+
<div style="display: flex; justify-content: center; margin-top: 20px;">
|
124 |
+
<button style="padding: 10px 20px; font-size: 18px; {button_color}">{button_label}</button>
|
125 |
+
</div>
|
126 |
+
""",
|
127 |
+
unsafe_allow_html=True
|
128 |
+
)
|
129 |
+
|
130 |
+
# Hiển thị hình ảnh và tham chiếu
|
131 |
+
sub_col1, sub_col2 = st.columns([1, 1])
|
132 |
+
|
133 |
+
# Cột nhỏ bên trái: Hình ảnh tải lên
|
134 |
+
with sub_col1:
|
135 |
+
st.write("\n")
|
136 |
+
st.write("**Input Image**")
|
137 |
+
st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
|
138 |
+
|
139 |
+
# Cột nhỏ bên phải: Hình ảnh tham chiếu
|
140 |
+
with sub_col2:
|
141 |
+
st.write("\n")
|
142 |
+
st.write("**Real Reference Images**")
|
143 |
+
for real_image in uploaded_real_images:
|
144 |
+
st.image(real_image, caption="Real Image", use_column_width=True)
|
145 |
+
|
146 |
+
|
147 |
+
|
148 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
torch
|
3 |
+
albumentations
|
4 |
+
timm
|
5 |
+
numpy
|
6 |
+
onnxruntime
|
weights/best_efficientnet_b2.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e43e5d65c1ac1c3aeb99d2366a7a3bfd81c103f5c3fd96a6e837f7c1e5e9ec82
|
3 |
+
size 32841296
|
weights/model_efficientnet_b2.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b91a9fc692af66d06cddad847ad68fe9b5983c883f4057b04315c198ac26d85d
|
3 |
+
size 32433908
|