"""Streamlit demo: semantic search over a sample of English historical quotes.

The script embeds the first ``SAMPLE_SIZE`` quotes of the dataset with a
SentenceTransformer model, embeds the query typed in the sidebar, and shows
the ``TOP_K`` most similar quotes among a fixed candidate subset.
"""

import streamlit as st

from sentence_transformers import SentenceTransformer
from sentence_transformers.util import semantic_search

import datasets

# NOTE(review): faiss is not used anywhere in this script as shown; the import
# is kept in case it is relied on elsewhere -- confirm before removing.
import faiss  # noqa: F401

import torch


DATASET_NAME = "A-Roucher/english_historical_quotes"
model_name = "sentence-transformers/all-MiniLM-L6-v2"

SAMPLE_SIZE = 100      # number of quotes kept from the "train" split
CANDIDATE_COUNT = 10   # search is restricted to the first N sampled quotes
TOP_K = 5              # number of results displayed
ENCODE_BATCH_SIZE = 4
DEFAULT_QUERY = "Knowledge of history is power."


@st.cache_resource
def load_search_assets():
    """Load the quotes, the encoder and the quote embeddings once per process.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the dataset would be re-downloaded, the model re-loaded
    and every quote re-encoded on each keystroke/submit.

    Returns:
        A ``(dataset, encoder, embeddings)`` tuple: the sampled
        ``datasets.Dataset``, the ``SentenceTransformer`` model, and the
        array returned by ``encoder.encode`` for the ``"quote"`` column
        (requested as NumPy, L2-normalised).
    """
    # "force_redownload" is preserved from the original script; thanks to the
    # cache it now triggers once per server process instead of on every rerun.
    full_dataset = datasets.load_dataset(
        DATASET_NAME, download_mode="force_redownload"
    )
    sample = datasets.Dataset.from_dict(full_dataset["train"][:SAMPLE_SIZE])

    model = SentenceTransformer(model_name)
    quote_embeddings = model.encode(
        sample["quote"],
        batch_size=ENCODE_BATCH_SIZE,
        show_progress_bar=True,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    return sample, model, quote_embeddings


# Use what the user typed as the query; fall back to the original hard-coded
# sentence while the box is empty so the first render still shows results.
user_query = st.sidebar.text_input("Type your quote here")
sentence = user_query.strip() or DEFAULT_QUERY

dataset, encoder, embeddings = load_search_assets()

sentence_embedding = encoder.encode([sentence])

sentence_embedding_tensor = torch.Tensor(sentence_embedding)
dataset_embeddings_tensor = torch.Tensor(embeddings)

# Only these rows of the embedding matrix are searched.
# NOTE(review): the name suggests a future per-author filter -- confirm intent.
author_indexes = list(range(CANDIDATE_COUNT))

hits = semantic_search(
    sentence_embedding_tensor,
    dataset_embeddings_tensor[author_indexes, :],
    top_k=TOP_K,
)

# "corpus_id" is a position inside the candidate subset passed above, so map
# it back through author_indexes to get the row index in `dataset`.
list_hits = [author_indexes[hit["corpus_id"]] for hit in hits[0]]

# Debug output on the server's stdout (not shown in the page).
print(list_hits)
print(dataset)

st.write(dataset.select(list_hits))