import numpy as np
import pandas as pd
import requests
import streamlit as st
import torch
from io import BytesIO
from PIL import Image
from transformers import AutoProcessor, VisionTextDualEncoderModel


def embed_texts(model, texts, processor):
    """Encode a list of strings into text embeddings with the dual-encoder model."""
    inputs = processor(text=texts, padding="longest")
    input_ids = torch.tensor(inputs["input_ids"])
    attention_mask = torch.tensor(inputs["attention_mask"])

    with torch.no_grad():
        embeddings = model.get_text_features(
            input_ids=input_ids, attention_mask=attention_mask
        )
    return embeddings


@st.cache_resource
def load_embeddings(embeddings_path):
    """Load the precomputed image embeddings (an (N, D) NumPy array)."""
    print("loading embeddings")
    return np.load(embeddings_path)


@st.cache_resource
def load_path_clip():
    """Load the vision-text dual-encoder model and its processor."""
    model = VisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian")
    processor = AutoProcessor.from_pretrained("clip-italian/clip-italian")
    return model, processor


st.title("PLIP Image Search")

# Dataset with one image URL per row, aligned row-for-row with the embedding matrix.
plip_dataset = pd.read_csv("tweet_eval_retrieval.tsv", sep="\t")

model, processor = load_path_clip()
image_embedding = load_embeddings("tweet_eval_embeddings.npy")

query = st.text_input("Search Query", "")

if query:
    # Embed the query, normalise it, and retrieve the image whose precomputed
    # embedding has the highest dot-product score with the query embedding.
    text_embedding = embed_texts(model, [query], processor)[0].detach().cpu().numpy()
    text_embedding = text_embedding / np.linalg.norm(text_embedding)

    best_id = np.argmax(text_embedding.dot(image_embedding.T))
    url = plip_dataset.iloc[best_id]["imageURL"]

    response = requests.get(url)
    img = Image.open(BytesIO(response.content))
    st.image(img)
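

# --- Hedged sketch, not part of the app above ---
# "tweet_eval_embeddings.npy" is loaded ready-made; one plausible way it could be
# produced offline, assuming the same dual-encoder model and that each row of
# tweet_eval_retrieval.tsv carries an "imageURL" column. The helper below is a
# hypothetical name, defined here for reference only and never called by the app.
def build_image_embeddings(model, processor, dataset, out_path="tweet_eval_embeddings.npy"):
    """Embed every image in the dataset and save an (N, D) array of unit-norm vectors."""
    vectors = []
    for url in dataset["imageURL"]:
        img = Image.open(BytesIO(requests.get(url).content)).convert("RGB")
        inputs = processor(images=img, return_tensors="pt")
        with torch.no_grad():
            feat = model.get_image_features(pixel_values=inputs["pixel_values"])
        vec = feat[0].cpu().numpy()
        # Unit-normalise so the app's dot product behaves like cosine similarity.
        vectors.append(vec / np.linalg.norm(vec))
    np.save(out_path, np.stack(vectors))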