"""Streamlit app that searches a prebuilt FAISS image index with a text query.

The query is embedded with a multilingual SentenceTransformer model, the
nearest neighbours are looked up in a FAISS index downloaded at startup, and
the returned ids are mapped to image URLs through a downloaded JSON file.

Run with ``streamlit run <this file>``.
"""
import json
import os
from io import BytesIO

import faiss
import numpy as np
import streamlit as st
import torch
import wget
from datasets import load_dataset
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer

# NOTE(review): BytesIO, torch, load_dataset, Image, AutoModel and
# AutoTokenizer are currently unused; they are kept in case other code relies
# on them. The commented-out Flickr30k / image-model loading code was removed.

# Call this before any other Streamlit command. Many Streamlit versions raise
# StreamlitAPIException if it follows st.title etc., as it previously did.
st.set_page_config(page_title="Image Search App", layout="wide")

# Pre-trained multilingual sentence encoder used to embed queries.
model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"

# NOTE(review): these URLs are unverified. Confirm that they resolve, and that
# the index was built with the same encoder as model_name.
index_name = "index.faiss"
index_url = "https://huggingface.co/models/flax-community/deit-tiny-random/faiss_files/faiss.index"
image_map_name = "image_map.json"
image_map_url = "https://huggingface.co/models/flax-community/deit-tiny-random/faiss_files/image_map.json"


def _cache_resource(func):
    """Cache *func*'s return value across Streamlit script reruns.

    Streamlit re-executes the whole script on every interaction. Without
    caching, the model and the downloaded files would be reloaded each time.
    This uses ``st.cache_resource`` when it exists and falls back to the
    legacy ``st.cache(allow_output_mutation=True)`` on older Streamlit.
    """
    cache = getattr(st, "cache_resource", None)
    if cache is not None:
        return cache(func)
    return st.cache(allow_output_mutation=True)(func)


def _download(url, path):
    """Download *url* to *path* unless *path* already exists; return *path*."""
    # wget.download does not overwrite an existing target; it saves under a
    # new name instead. Re-downloading would therefore leave a stale *path*.
    if not os.path.exists(path):
        wget.download(url, path)
    return path


@_cache_resource
def load_model():
    """Return the SentenceTransformer used to embed queries."""
    return SentenceTransformer(model_name)


@_cache_resource
def load_index():
    """Download the FAISS index if it is missing, then read it."""
    return faiss.read_index(_download(index_url, index_name))


@_cache_resource
def load_image_map():
    """Download the id -> image-URL JSON map if it is missing, then parse it."""
    with open(_download(image_map_url, image_map_name), "r", encoding="utf-8") as f:
        return json.load(f)


model = load_model()
index = load_index()
image_map = load_image_map()


def search(query, k=5):
    """Return the image URLs of the *k* nearest index entries to *query*.

    Args:
        query: Free-text search string.
        k: Number of neighbours to request from the FAISS index.

    Returns:
        A list of at most *k* image URLs, in the order FAISS ranks them.

    Raises:
        KeyError: If FAISS returns an id that has no entry in image_map.
    """
    # SentenceTransformer.encode takes raw strings (not token tensors) and
    # returns a NumPy array. FAISS expects a 2-D float32 matrix of queries.
    query_embedding = np.asarray(model.encode([query]), dtype="float32")
    _, neighbor_ids = index.search(query_embedding, k)
    # FAISS pads results with -1 when the index holds fewer than k vectors.
    return [image_map[str(i)] for i in neighbor_ids[0] if i >= 0]


def run_app():
    """Render the search form and show the images for a submitted query."""
    st.title("Image Search App")
    query = st.text_input("Enter your search query here:")
    if st.button("Search") and query:
        image_urls = search(query)
        if image_urls:
            st.image(image_urls, width=200)
        else:
            st.info("No matching images found.")


if __name__ == "__main__":
    run_app()