|
import streamlit as st |
|
from sentence_transformers import SentenceTransformer |
|
import datasets |
|
import faiss |
|
import torch |
|
from sentence_transformers.util import semantic_search |
|
import time |
|
|
|
|
|
# One-time setup per browser session: load the quotes dataset, build the author
# list for the picker, load the sentence encoder and embed every quote. The
# "initialized" flag keeps Streamlit reruns from repeating this work.
if "initialized" not in st.session_state:
    # NOTE(review): force_redownload asks `datasets` to fetch the data again
    # instead of reusing its local cache -- confirm this is intentional.
    st.session_state.dataset = datasets.load_dataset(
        'A-Roucher/english_historical_quotes',
        download_mode="force_redownload",
    )['train']
    # Sorted so the author picker shows a stable, scannable order instead of
    # arbitrary set order; key=str keeps the sort safe if an author value is
    # not a string.
    st.session_state.all_authors = sorted(
        set(st.session_state.dataset['author']), key=str
    )
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    st.session_state.encoder = SentenceTransformer(model_name)
    # One embedding per quote, in dataset row order.
    st.session_state.embeddings = st.session_state.encoder.encode(
        st.session_state.dataset["quote"],
        batch_size=4,
        show_progress_bar=True,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    st.session_state.initialized = True

# Rebuilt on every rerun because search() reads it as a module-level global.
# as_tensor reuses the NumPy buffer when the dtype already matches, so reruns
# do not copy the whole embedding matrix the way torch.Tensor(...) does.
dataset_embeddings_tensor = torch.as_tensor(
    st.session_state.embeddings, dtype=torch.float32
)

sentence = "Knowledge of history is power."
|
|
|
def search(query, selected_authors):
    """Return the quotes closest to ``query`` as a Markdown string.

    The query is embedded with the session's encoder and compared, via
    ``semantic_search``, against the precomputed quote embeddings in the
    module-level ``dataset_embeddings_tensor`` (row ``i`` is assumed to
    correspond to dataset row ``i``).

    Args:
        query: Free-text idea typed by the user.
        selected_authors: Authors to restrict the search to; an empty
            collection means "search every quote".

    Returns:
        A Markdown string with the computation time followed by up to five
        ``author`` / ``quote`` entries, or ``""`` when the query is blank or
        nothing matches.
    """
    started = time.perf_counter()
    if not query.strip():
        return ""

    query_embedding = st.session_state.encoder.encode([query])
    sentence_embedding_tensor = torch.Tensor(query_embedding)

    dataset = st.session_state.dataset
    # Read the column once: indexing dataset['author'] inside the loop would
    # fetch the whole column again for every row.
    authors = dataset['author']

    if selected_authors:
        wanted = set(selected_authors)
        author_indexes = [
            row for row, author in enumerate(authors) if author in wanted
        ]
    else:
        author_indexes = list(range(len(authors)))

    # Nothing to rank against; avoid calling semantic_search on an empty corpus.
    if not author_indexes:
        return ""

    hits = semantic_search(
        sentence_embedding_tensor,
        dataset_embeddings_tensor[author_indexes, :],
        top_k=5,
    )

    # corpus_id indexes into the filtered subset; map it back to dataset rows.
    indices = [author_indexes[hit['corpus_id']] for hit in hits[0]]
    if not indices:
        return ""

    quotes = dataset['quote']
    entries = "".join(
        f"###### {authors[i]}\n> {quotes[i]}\n----\n" for i in indices
    )
    delay = "%.3f" % (time.perf_counter() - started)
    return f"_Computation time: **{delay} seconds**_\n\n{entries}"
|
|
|
|
|
# Page layout: a wide text box next to a narrow search button, an optional
# author filter underneath, and the search results rendered on submit.

# CSS injected into the page; sets align-self:flex-end on Streamlit's column
# containers (presumably so the button lines up with the text box).
_COLUMN_ALIGN_CSS = """
<style>
div[data-testid="column"]
{
align-self:flex-end;
}
</style>
"""
st.markdown(_COLUMN_ALIGN_CSS, unsafe_allow_html=True)

col1, col2 = st.columns([8, 2])
with col1:
    text_input = st.text_input("Type your idea here:")
with col2:
    submit_button = st.button("_Search quotes!_")

selected_authors = st.multiselect(
    "(Optional) - Restrict search to these authors:",
    st.session_state.all_authors,
)

# Only run the search when the button was clicked on this rerun.
if submit_button:
    results_markdown = search(text_input, selected_authors)
    st.markdown(results_markdown)
|
|
|
|