binhduong2310 commited on
Commit
67c6064
1 Parent(s): dfbd5be

Upload 4 files

Browse files
app.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ from PIL import Image
5
+ import albumentations as A
6
+ from albumentations.pytorch import ToTensorV2
7
+ import timm
8
+ import numpy as np
9
+ import onnxruntime as ort
10
+
11
+ import warnings
12
+ warnings.filterwarnings("ignore")
13
+
14
+ # Khởi tạo mô hình
15
+ class SiameseNetwork(nn.Module):
16
+ def __init__(self, model_name='resnet18', pretrained=True):
17
+ super(SiameseNetwork, self).__init__()
18
+ self.encoder = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
19
+
20
+ self.fc = nn.Sequential(
21
+ nn.Linear(self.encoder.num_features, 256),
22
+ nn.ReLU(inplace=True),
23
+ nn.Linear(256, 128)
24
+ )
25
+
26
+ def forward_once(self, x):
27
+ output = self.encoder(x)
28
+ output = self.fc(output)
29
+ return output
30
+
31
+ def forward(self, img1, img2):
32
+ output1 = self.forward_once(img1)
33
+ output2 = self.forward_once(img2)
34
+ return output1, output2
35
+
36
+ # Load the model
37
+ @st.cache_resource
38
+ def load_model():
39
+ backbone = 'efficientnet_b2'
40
+ onnx_file_path = f"weights/model_{backbone}.onnx"
41
+ session = ort.InferenceSession(onnx_file_path)
42
+ return session
43
+
44
+ # Function to run inference
45
+ def inference_with_real_set(session, img, real_imgs, threshold=0.5):
46
+ distances = []
47
+ for real_img in real_imgs:
48
+ input_dict = {
49
+ session.get_inputs()[0].name: img.numpy(),
50
+ session.get_inputs()[1].name: real_img.numpy()
51
+ }
52
+ outputs = session.run(None, input_dict)
53
+
54
+ euclidean_distance = np.linalg.norm(outputs[0] - outputs[1])
55
+
56
+ distances.append(euclidean_distance)
57
+
58
+ avg_distance = sum(distances) / len(distances)
59
+
60
+ return "Real" if avg_distance < threshold else "Fake"
61
+
62
+ # Image preprocessing function
63
+ def preprocess_image(image):
64
+ transform = A.Compose([
65
+ A.Resize(256, 256),
66
+ A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
67
+ ToTensorV2(),
68
+ ])
69
+ image = np.array(image)/255.0
70
+ image = transform(image=image)['image'].unsqueeze(0)
71
+ return image
72
+
73
+ # Hide Streamlit warnings and footers
74
+ hide_warning = """
75
+ <style>
76
+ .stAlert {display: none;}
77
+ .css-18e3th9 {padding-top: 2rem; padding-bottom: 2rem;} /* Adjust top and bottom padding */
78
+ .css-1d391kg {max-width: 100% !important; padding-left: 1rem; padding-right: 1rem;} /* Adjust the max-width and side padding */
79
+ </style>
80
+ """
81
+
82
+ st.markdown(hide_warning, unsafe_allow_html=True)
83
+
84
+ # Streamlit interface
85
+ st.title("Forge Signature Siamese")
86
+
87
+ # Load the model
88
+ model = load_model()
89
+
90
+ # Chia màn hình thành 2 cột
91
+ col1, col2 = st.columns([1, 1]) # Cột bên trái để input, cột bên phải để hiển thị kết quả
92
+
93
+ with col1:
94
+ # Upload image to compare
95
+ uploaded_image = st.file_uploader("Upload an image to verify", type=["png", "jpg", "jpeg"])
96
+
97
+ # Upload real reference images
98
+ uploaded_real_images = st.file_uploader("Upload real reference images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
99
+
100
+ # Threshold slider
101
+ threshold = st.slider("Threshold", 0.0, 1.0, 0.5)
102
+
103
+ # Nếu cả hình ảnh tải lên và hình ảnh thật đã được tải lên
104
+ if uploaded_image is not None and uploaded_real_images:
105
+ # Xử lý ảnh
106
+ img = preprocess_image(Image.open(uploaded_image).convert('RGB'))
107
+ real_imgs = [preprocess_image(Image.open(img).convert('RGB')) for img in uploaded_real_images]
108
+
109
+ # Chạy inference
110
+ result = inference_with_real_set(model, img, real_imgs, threshold)
111
+
112
+ with col2:
113
+ # Hiển thị nút kết quả với màu sắc tùy chỉnh
114
+ if result == "Real":
115
+ button_color = "background-color: green; color: white; font-weight: bold;"
116
+ button_label = "Real"
117
+ else:
118
+ button_color = "background-color: red; color: white; font-weight: bold;"
119
+ button_label = "Fake"
120
+
121
+ st.markdown(
122
+ f"""
123
+ <div style="display: flex; justify-content: center; margin-top: 20px;">
124
+ <button style="padding: 10px 20px; font-size: 18px; {button_color}">{button_label}</button>
125
+ </div>
126
+ """,
127
+ unsafe_allow_html=True
128
+ )
129
+
130
+ # Hiển thị hình ảnh và tham chiếu
131
+ sub_col1, sub_col2 = st.columns([1, 1])
132
+
133
+ # Cột nhỏ bên trái: Hình ảnh tải lên
134
+ with sub_col1:
135
+ st.write("\n")
136
+ st.write("**Input Image**")
137
+ st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
138
+
139
+ # Cột nhỏ bên phải: Hình ảnh tham chiếu
140
+ with sub_col2:
141
+ st.write("\n")
142
+ st.write("**Real Reference Images**")
143
+ for real_image in uploaded_real_images:
144
+ st.image(real_image, caption="Real Image", use_column_width=True)
145
+
146
+
147
+
148
+
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ albumentations
4
+ timm
5
+ numpy
6
+ onnxruntime
weights/best_efficientnet_b2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e43e5d65c1ac1c3aeb99d2366a7a3bfd81c103f5c3fd96a6e837f7c1e5e9ec82
3
+ size 32841296
weights/model_efficientnet_b2.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b91a9fc692af66d06cddad847ad68fe9b5983c883f4057b04315c198ac26d85d
3
+ size 32433908