Spaces:
Runtime error
Runtime error
import json
import os
from io import BytesIO

import faiss
import numpy as np
import streamlit as st
import torch
import wget
from datasets import load_dataset
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel
# dataset = load_dataset("imagefolder", data_files="https://huggingface.co/datasets/nlphuji/flickr30k/blob/main/flickr30k-images.zip")


def _download_if_missing(url, filename):
    """Download *url* to *filename* unless that file is already on disk.

    Streamlit re-executes the whole script on every interaction, so an
    unconditional download would hit the network on each rerun. Returns
    *filename* for convenience.
    """
    if not os.path.exists(filename):
        wget.download(url, filename)
    return filename


# Load the pre-trained sentence encoder.
model_name = "sentence-transformers/all-distilroberta-v1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = SentenceTransformer(model_name)

# Load the FAISS index.
# The Hub serves raw file bytes under "/resolve/main/"; "/blob/main/" returns
# the HTML file-viewer page, which faiss.read_index cannot parse.
index_name = 'index.faiss'
index_url = 'https://huggingface.co/spaces/shivangibithel/Text2ImageRetrieval/resolve/main/faiss_flickr8k.index'
_download_if_missing(index_url, index_name)
index = faiss.read_index(index_name)

# Map the image ids to the corresponding captions.
image_map_name = 'captions.json'
image_map_url = 'https://huggingface.co/spaces/shivangibithel/Text2ImageRetrieval/resolve/main/captions.json'
_download_if_missing(image_map_url, image_map_name)
with open(image_map_name, 'r', encoding='utf-8') as f:
    caption_dict = json.load(f)

# search() indexes image_list with the row ids returned by the FAISS index, so
# the key order of captions.json is assumed to match the index row order
# (TODO confirm against the script that built the index).
image_list = list(caption_dict)
caption_list = list(caption_dict.values())
def search(query, k=5):
    """Return the URLs of the images that best match a text query.

    Args:
        query: Free-text search string.
        k: Number of nearest neighbours to request from the FAISS index.

    Returns:
        A list of at most ``k`` image URLs, ordered as returned by the index.
    """
    # SentenceTransformer.encode tokenizes raw text itself and returns a NumPy
    # array by default, so neither a separate tokenizer pass nor .detach() is
    # needed. FAISS expects a 2-D float32 matrix, hence the one-element batch
    # and the explicit dtype.
    query_embedding = np.asarray(model.encode([query]), dtype="float32")

    # Search for the nearest neighbours in the FAISS index.
    _, neighbor_ids = index.search(query_embedding, k)

    # "/resolve/main/" serves the raw image bytes; "/blob/main/" would return
    # the HTML file-viewer page, which st.image cannot render.
    base_url = "https://huggingface.co/spaces/shivangibithel/Text2ImageRetrieval/resolve/main/Images/"

    # FAISS pads the result with -1 when fewer than k neighbours exist; skip
    # those so they do not silently wrap around to the last image.
    return [
        base_url + str(image_list[i])
        for i in neighbor_ids[0]
        if 0 <= i < len(image_list)
    ]
def run_app():
    """Render the search page: a query box, a button, and the result images."""
    # set_page_config must be the first Streamlit command executed in the
    # script, so it is issued before any other st.* call.
    st.set_page_config(page_title='Image Search App', layout='wide')

    st.title("Image Search App")
    query = st.text_input("Enter your search query here:")

    # Only search once the button is pressed and the query is non-empty.
    if st.button("Search") and query:
        image_urls = search(query)
        # Display the images.
        st.image(image_urls, width=200)


if __name__ == '__main__':
    run_app()