import streamlit as st
import torch
import torch.nn as nn
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm
import numpy as np
import onnxruntime as ort
import warnings
warnings.filterwarnings("ignore")
# Model definition: a Siamese network whose two branches share one encoder
class SiameseNetwork(nn.Module):
    def __init__(self, model_name='resnet18', pretrained=True):
        super(SiameseNetwork, self).__init__()
        # timm backbone with the classification head removed (num_classes=0)
        self.encoder = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        # Projection head mapping backbone features to a 128-d embedding
        self.fc = nn.Sequential(
            nn.Linear(self.encoder.num_features, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128)
        )

    def forward_once(self, x):
        output = self.encoder(x)
        output = self.fc(output)
        return output

    def forward(self, img1, img2):
        output1 = self.forward_once(img1)
        output2 = self.forward_once(img2)
        return output1, output2
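
# The class above is not used directly at inference time; the app runs an ONNX session
# instead. A minimal export sketch is given below, assuming the ONNX weights were produced
# from this PyTorch definition. The checkpoint path and the input/output names are
# assumptions for illustration, not the original training pipeline.
def export_to_onnx(checkpoint_path="weights/model_efficientnet_b2.pth",  # hypothetical path
                   onnx_path="weights/model_efficientnet_b2.onnx"):
    model = SiameseNetwork(model_name='efficientnet_b2', pretrained=False)
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    model.eval()
    dummy = torch.randn(1, 3, 256, 256)  # matches the 256x256 preprocessing used below
    torch.onnx.export(
        model, (dummy, dummy), onnx_path,
        input_names=['img1', 'img2'],      # assumed names; the app looks names up at runtime
        output_names=['emb1', 'emb2'],
        dynamic_axes={'img1': {0: 'batch'}, 'img2': {0: 'batch'}},  # optional dynamic batch
    )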
# Load the model
@st.cache_resource
def load_model():
    backbone = 'efficientnet_b2'
    onnx_file_path = f"weights/model_{backbone}.onnx"
    session = ort.InferenceSession(onnx_file_path)
    return session
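
# Optional sanity check (a sketch, not part of the original app): inspect the ONNX graph
# inputs to confirm the two-branch signature and the expected NCHW shape before inference.
def describe_inputs(session):
    return [(inp.name, inp.shape, inp.type) for inp in session.get_inputs()]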
# Run inference: compare the query image against every reference and average the distances
def inference_with_real_set(session, img, real_imgs, threshold=0.5):
    distances = []
    for real_img in real_imgs:
        # Feed the query and one reference image through the two branches of the ONNX model
        input_dict = {
            session.get_inputs()[0].name: img.numpy(),
            session.get_inputs()[1].name: real_img.numpy()
        }
        outputs = session.run(None, input_dict)
        # Euclidean distance between the two embeddings
        euclidean_distance = np.linalg.norm(outputs[0] - outputs[1])
        distances.append(euclidean_distance)
    avg_distance = sum(distances) / len(distances)
    return "Real" if avg_distance < threshold else "Fake"
# Image preprocessing: resize, scale to [-1, 1], and add a batch dimension
def preprocess_image(image):
    transform = A.Compose([
        A.Resize(256, 256),
        # Normalize divides by max_pixel_value=255 internally, mapping pixels to [-1, 1]
        A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ToTensorV2(),
    ])
    # Keep float32 so the tensor matches the dtype expected by the ONNX session
    image = np.array(image, dtype=np.float32)
    image = transform(image=image)['image'].unsqueeze(0)
    return image
# Hide Streamlit warnings and footers
hide_warning = """
<style>
.stAlert {display: none;}
.css-18e3th9 {padding-top: 2rem; padding-bottom: 2rem;} /* Adjust top and bottom padding */
.css-1d391kg {max-width: 100% !important; padding-left: 1rem; padding-right: 1rem;} /* Adjust the max-width and side padding */
</style>
"""
st.markdown(hide_warning, unsafe_allow_html=True)
# Streamlit interface
st.title("Forge Signature Siamese")
# Load the model
model = load_model()
# Split the screen into two columns: inputs on the left, results on the right
col1, col2 = st.columns([1, 1])

with col1:
    # Upload the image to verify
    uploaded_image = st.file_uploader("Upload an image to verify", type=["png", "jpg", "jpeg"])
    # Upload genuine reference images
    uploaded_real_images = st.file_uploader("Upload real reference images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
    # Threshold slider
    threshold = st.slider("Threshold", 0.0, 1.0, 0.5)

# Run only when both the query image and the reference images have been uploaded
if uploaded_image is not None and uploaded_real_images:
    # Preprocess the query image and every reference image
    img = preprocess_image(Image.open(uploaded_image).convert('RGB'))
    real_imgs = [preprocess_image(Image.open(f).convert('RGB')) for f in uploaded_real_images]
    # Run inference
    result = inference_with_real_set(model, img, real_imgs, threshold)
    with col2:
        # Display the verdict as a colored badge
        if result == "Real":
            button_color = "background-color: green; color: white; font-weight: bold;"
            button_label = "Real"
        else:
            button_color = "background-color: red; color: white; font-weight: bold;"
            button_label = "Fake"
        st.markdown(
            f"""
<div style="display: flex; justify-content: center; margin-top: 20px;">
  <button style="padding: 10px 20px; font-size: 18px; {button_color}">{button_label}</button>
</div>
""",
            unsafe_allow_html=True
        )

        # Show the query image next to the reference images
        sub_col1, sub_col2 = st.columns([1, 1])

        # Left sub-column: the uploaded image
        with sub_col1:
            st.write("\n")
            st.write("**Input Image**")
            st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)

        # Right sub-column: the reference images
        with sub_col2:
            st.write("\n")
            st.write("**Real Reference Images**")
            for real_image in uploaded_real_images:
                st.image(real_image, caption="Real Image", use_column_width=True)