import streamlit as st
import pandas as pd
import numpy as np
import requests
import torch
from io import BytesIO
from PIL import Image
from transformers import VisionTextDualEncoderModel, AutoProcessor


def embed_texts(model, texts, processor):
    # Tokenize the queries and embed them with the text tower of the dual encoder.
    inputs = processor(text=texts, padding="longest")
    input_ids = torch.tensor(inputs["input_ids"])
    attention_mask = torch.tensor(inputs["attention_mask"])

    with torch.no_grad():
        embeddings = model.get_text_features(
            input_ids=input_ids, attention_mask=attention_mask
        )
    return embeddings


@st.cache_data
def load_embeddings(embeddings_path):
    # Precomputed image embeddings, one row per image in the TSV.
    print("loading embeddings")
    return np.load(embeddings_path)


@st.cache_resource
def load_path_clip():
    # Load the CLIP-style dual encoder and its processor once per session.
    model = VisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian")
    processor = AutoProcessor.from_pretrained("clip-italian/clip-italian")
    return model, processor


st.title("PLIP Image Search")

plip_dataset = pd.read_csv("tweet_eval_retrieval.tsv", sep="\t")
model, processor = load_path_clip()
image_embedding = load_embeddings("tweet_eval_embeddings.npy")

query = st.text_input("Search Query", "")

if query:
    # Embed the query, normalize it, and rank images by cosine similarity
    # (a plain dot product, since both sides are unit-normalized).
    text_embedding = embed_texts(model, [query], processor)[0].detach().cpu().numpy()
    text_embedding = text_embedding / np.linalg.norm(text_embedding)
    best_id = np.argmax(text_embedding.dot(image_embedding.T))

    # Fetch and display the best-matching image.
    url = plip_dataset.iloc[best_id]["imageURL"]
    response = requests.get(url)
    img = Image.open(BytesIO(response.content))
    st.image(img)
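
The app assumes tweet_eval_embeddings.npy already holds one unit-normalized image embedding per row of tweet_eval_retrieval.tsv; otherwise the dot product above is not a cosine similarity. A minimal offline sketch of how that file could be produced with the same model follows. The script itself and its one-image-at-a-time loop are illustrative assumptions, not part of the Space:

import numpy as np
import pandas as pd
import requests
import torch
from io import BytesIO
from PIL import Image
from transformers import VisionTextDualEncoderModel, AutoProcessor

model = VisionTextDualEncoderModel.from_pretrained("clip-italian/clip-italian")
processor = AutoProcessor.from_pretrained("clip-italian/clip-italian")

dataset = pd.read_csv("tweet_eval_retrieval.tsv", sep="\t")
embeddings = []
for url in dataset["imageURL"]:
    # Download each image and run it through the vision tower.
    img = Image.open(BytesIO(requests.get(url).content)).convert("RGB")
    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        emb = model.get_image_features(pixel_values=inputs["pixel_values"])[0]
    # L2-normalize each row so the app's dot product is a cosine similarity.
    embeddings.append((emb / emb.norm()).cpu().numpy())

np.save("tweet_eval_embeddings.npy", np.stack(embeddings))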