import streamlit as st
import torch
import torch.nn as nn
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm
import numpy as np
import onnxruntime as ort
import warnings

warnings.filterwarnings("ignore")


# Model definition (PyTorch reference architecture).
# NOTE(review): the Streamlit app below never instantiates this class; it loads
# an ONNX model instead, presumably exported from this architecture — confirm.
class SiameseNetwork(nn.Module):
    """Twin network that embeds two images with one shared encoder.

    A timm backbone created with ``num_classes=0`` (classifier head removed)
    feeds a small MLP that projects its features to a 128-dimensional
    embedding.

    Args:
        model_name: Name of the timm backbone to create.
        pretrained: Whether timm should load pretrained backbone weights.
    """

    def __init__(self, model_name='resnet18', pretrained=True):
        super().__init__()
        self.encoder = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        self.fc = nn.Sequential(
            nn.Linear(self.encoder.num_features, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128),
        )

    def forward_once(self, x):
        """Return the 128-d embedding of one image batch ``x``."""
        return self.fc(self.encoder(x))

    def forward(self, img1, img2):
        """Embed both inputs with the shared weights; return ``(emb1, emb2)``."""
        return self.forward_once(img1), self.forward_once(img2)


@st.cache_resource
def load_model(backbone='efficientnet_b2'):
    """Create an ONNX Runtime session for the exported Siamese model.

    ``st.cache_resource`` caches the session across Streamlit reruns, so the
    model file is only loaded once per distinct ``backbone``.

    Args:
        backbone: Backbone name embedded in the weight file name
            ``weights/model_{backbone}.onnx``.

    Returns:
        An ``onnxruntime.InferenceSession``.
    """
    onnx_file_path = f"weights/model_{backbone}.onnx"
    return ort.InferenceSession(onnx_file_path)


def inference_with_real_set(session, img, real_imgs, threshold=0.5):
    """Classify ``img`` as "Real" or "Fake" against reference images.

    Each (query, reference) pair is fed to the two-input ONNX model. The
    Euclidean distance between the first two outputs (the two embeddings) is
    computed, and the mean distance over all references is compared with
    ``threshold``.

    Args:
        session: ONNX Runtime session whose first two inputs take the query
            and the reference tensor respectively.
        img: Preprocessed query tensor (as returned by ``preprocess_image``).
        real_imgs: Non-empty iterable of preprocessed reference tensors.
        threshold: Mean distance strictly below which the query is "Real".

    Returns:
        "Real" if the mean distance is below ``threshold``, otherwise "Fake".

    Raises:
        ValueError: If ``real_imgs`` yields no reference image.
    """
    # Input names and the query array do not change between references, so
    # look them up / convert them once instead of on every iteration.
    model_inputs = session.get_inputs()
    query_name, ref_name = model_inputs[0].name, model_inputs[1].name
    query = img.numpy()

    distances = []
    for real_img in real_imgs:
        outputs = session.run(None, {query_name: query, ref_name: real_img.numpy()})
        distances.append(np.linalg.norm(outputs[0] - outputs[1]))

    if not distances:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise ValueError("real_imgs must contain at least one reference image")

    avg_distance = sum(distances) / len(distances)
    return "Real" if avg_distance < threshold else "Fake"


# The preprocessing pipeline holds only deterministic transforms, so it is
# built once at import time rather than on every call.
_TRANSFORM = A.Compose([
    A.Resize(256, 256),
    A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ToTensorV2(),
])


def preprocess_image(image):
    """Turn a PIL image into a batched model-input tensor.

    Resizes to 256x256, normalizes, converts HWC -> CHW and adds a leading
    batch dimension. For an RGB input (the caller converts uploads with
    ``.convert('RGB')``) the result has shape (1, 3, 256, 256).

    Args:
        image: PIL image to preprocess.

    Returns:
        A torch tensor with a leading batch dimension of size 1.
    """
    # NOTE(review): the pixels are scaled to [0, 1] here, but A.Normalize also
    # multiplies mean/std by its default max_pixel_value=255.0. Outputs
    # therefore land in about [-1.0, -0.992] rather than [-1, 1]. This is left
    # unchanged because the ONNX weights may have been trained with the same
    # preprocessing. Confirm against the training pipeline before fixing only
    # one side.
    pixels = np.array(image) / 255.0
    return _TRANSFORM(image=pixels)['image'].unsqueeze(0)


# Hide Streamlit warnings and footers.
# NOTE(review): this markup string is empty (its CSS/HTML appears to have been
# lost), so the call below currently renders nothing — restore or remove it.
hide_warning = """ """
st.markdown(hide_warning, unsafe_allow_html=True)

# Streamlit interface
st.title("Forge Signature Siamese")

# Load the (cached) ONNX inference session.
model = load_model()
# Split the screen into 2 columns: inputs on the left, results on the right.
col1, col2 = st.columns([1, 1])

with col1:
    # Upload the image to verify
    uploaded_image = st.file_uploader("Upload an image to verify", type=["png", "jpg", "jpeg"])
    # Upload the genuine reference images
    uploaded_real_images = st.file_uploader(
        "Upload real reference images",
        type=["png", "jpg", "jpeg"],
        accept_multiple_files=True,
    )
    # Decision threshold on the mean embedding distance
    threshold = st.slider("Threshold", 0.0, 1.0, 0.5)

# Only run once both the query image and at least one reference are uploaded.
if uploaded_image is not None and uploaded_real_images:
    # Preprocess the query and every reference image
    img = preprocess_image(Image.open(uploaded_image).convert('RGB'))
    real_imgs = [
        preprocess_image(Image.open(ref_file).convert('RGB'))
        for ref_file in uploaded_real_images
    ]

    # Run inference
    result = inference_with_real_set(model, img, real_imgs, threshold)

    with col2:
        # Show the verdict as a coloured badge (green = Real, red = Fake)
        if result == "Real":
            button_color = "background-color: green; color: white; font-weight: bold;"
            button_label = "Real"
        else:
            button_color = "background-color: red; color: white; font-weight: bold;"
            button_label = "Fake"
        # The original f-string body was empty, so the verdict was never shown.
        # The HTML is kept on one line so Markdown cannot misread indentation
        # as a code block.
        st.markdown(
            f'<div style="{button_color} padding: 0.5em 1em; '
            f'border-radius: 0.5em; text-align: center;">{button_label}</div>',
            unsafe_allow_html=True,
        )

        # Show the uploaded image next to the reference images
        sub_col1, sub_col2 = st.columns([1, 1])

        # Left sub-column: the uploaded image
        # NOTE(review): newer Streamlit releases deprecate use_column_width —
        # check against the pinned Streamlit version.
        with sub_col1:
            st.write("\n")
            st.write("**Input Image**")
            st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)

        # Right sub-column: the reference images
        with sub_col2:
            st.write("\n")
            st.write("**Real Reference Images**")
            for real_image in uploaded_real_images:
                st.image(real_image, caption="Real Image", use_column_width=True)