"""Streamlit front end for the Product Retrieval demo.

The user uploads an image, which is embedded with ``Ranker``. The embedding
is then searched against a FAISS index of catalogue embeddings. The matching
catalogue images are downloaded from S3 and shown in a grid.
"""
import os
from io import BytesIO

import boto3
import cv2
import faiss
import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image

from model_prediction import Ranker

# Size (width, height) that every result thumbnail is resized to before display.
target_size = (224, 224)
# Number of nearest neighbours requested from the FAISS index.
NUM_RESULTS = 12
# Number of columns in the results grid.
NUM_COLUMNS = 4
# S3 bucket holding the catalogue images; overridable without a code change.
bucket_name = os.environ.get("PRODUCT_RETRIEVAL_BUCKET", "product-retrieval")


def _resource_cache(func):
    """Cache *func*'s return value across Streamlit reruns.

    Use ``st.cache_resource`` when this Streamlit version provides it.
    Otherwise fall back to the legacy ``st.cache`` with
    ``allow_output_mutation=True``, so that the returned model, index or
    DataFrame is not re-hashed on every access.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_resource_cache
def load_model():
    """Construct the ``Ranker`` used to embed query images."""
    return Ranker()


@_resource_cache
def load_faiss_index():
    """Read the FAISS index stored in ``embeddings.index``."""
    return faiss.read_index('embeddings.index')


@_resource_cache
def load_labels():
    """Load ``labels.csv``; its ``path`` column is looked up by FAISS result id."""
    return pd.read_csv("labels.csv")


class ModelLoader:
    """Lazily fetch the model, FAISS index and labels, memoising them per run.

    Streamlit re-executes this whole script on every interaction. That
    redefines this class and resets its attributes, so the memoisation here
    only lasts for one run. Persistence across reruns comes from the
    ``_resource_cache``-decorated loaders that these accessors call.
    """

    model = None
    index = None
    labels = None

    @classmethod
    def get_model(cls):
        """Return the embedding model, loading it on first use."""
        if cls.model is None:
            cls.model = load_model()
        return cls.model

    @classmethod
    def get_index(cls):
        """Return the FAISS index, loading it on first use."""
        if cls.index is None:
            cls.index = load_faiss_index()
        return cls.index

    @classmethod
    def get_labels(cls):
        """Return the labels DataFrame, loading it on first use."""
        if cls.labels is None:
            cls.labels = load_labels()
        return cls.labels


def _make_s3_client():
    """Create the S3 client used to download result images.

    Credentials are resolved by boto3's standard provider chain: environment
    variables, shared credentials/config files, or an attached IAM role.
    They are no longer embedded in source code.
    """
    # SECURITY: an access-key pair used to be hard-coded here. It is still
    # present in version-control history and must be revoked/rotated in IAM.
    return boto3.client(
        's3',
        region_name=os.environ.get("AWS_DEFAULT_REGION", "eu-west-1"),
    )


def _to_rgb(uploaded):
    """Open an uploaded file as a 3-channel RGB PIL image.

    ``convert("RGB")`` drops an alpha channel, as the previous BGRA->BGR
    conversion did. It also normalises grayscale, palette and CMYK uploads,
    which used to be passed to the model unchanged.
    """
    return Image.open(uploaded).convert("RGB")


def _search(image):
    """Embed *image* and return the catalogue paths of its nearest neighbours."""
    model = ModelLoader.get_model()
    index = ModelLoader.get_index()
    labels = ModelLoader.get_labels()
    image_embedding = model.predict(image)
    _, indices = index.search(image_embedding, NUM_RESULTS)
    # FAISS pads with -1 when the index holds fewer than NUM_RESULTS vectors.
    hits = indices[0][indices[0] >= 0]
    # FAISS ids are row positions, so index positionally rather than by label.
    return labels["path"].iloc[hits].to_list()


def _fetch_result_image(img_path):
    """Download one result image from S3 and resize it to ``target_size``.

    The S3 key is the final ``/``-separated component of *img_path*.
    Returns ``None`` if the bucket has no object under that key.
    """
    key = img_path.split("/")[-1]
    try:
        response = s3.get_object(Bucket=bucket_name, Key=key)
    except s3.exceptions.NoSuchKey:
        return None
    image_data = response['Body'].read()
    return Image.open(BytesIO(image_data)).resize(target_size)


def _render_results(predicted_images):
    """Lay the result images out left-to-right in a NUM_COLUMNS-wide grid."""
    columns = st.columns(NUM_COLUMNS)
    for i, img_path in enumerate(predicted_images):
        img = _fetch_result_image(img_path)
        with columns[i % NUM_COLUMNS]:
            if img is None:
                st.warning(f"Predicted image {i+1} is missing from storage.")
            else:
                st.image(img, caption=f"Predicted image {i+1}", use_column_width=True)


st.set_page_config(page_title="Product Retrieval App")
st.title("Product Retrieval App")
# NOTE(review): the numbered steps below sit on one line, so Markdown renders
# them as a single paragraph rather than a list. Confirm this is intended.
st.markdown(
    "The Product Retrieval App is a demonstration of a computer vision model "
    "created by Intelliarts. It can analyze and interpret visual data , i.e., "
    "shapes, colors, and textures from uploaded digital images. The data is "
    "then used to conduct a search on the web. The output of the computer "
    "vision model is a set of images that are predicted to be most similar to "
    "the input image. To use the Product Retrieval App, you need to: 1. Select "
    "an image that depicts the item of interest. Acceptable formats are JPG, "
    "JPEG, and PNG. 2. Upload the image by either dragging and dropping the "
    "file into the search field or selecting a file from your computer using "
    "the “browse files” button. 3. \n"
    "Scroll to the bottom of the page to review the output results.",
    unsafe_allow_html=True,
)

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
loading_text = st.empty()
s3 = _make_s3_client()

if uploaded_file is not None:
    image = _to_rgb(uploaded_file)
    st.image(image, caption="Uploaded image", use_column_width=True)

    loading_text.text("Loading predictions...")
    try:
        predicted_images = _search(image)
    finally:
        # Clear the placeholder even if loading or search raises.
        loading_text.empty()

    _render_results(predicted_images)