# Streamlit app: semantic search over a dataset of English historical quotes.
#
# On the first run of a session the dataset is loaded, a sentence-embedding
# model is instantiated and every quote is embedded; the results are kept in
# ``st.session_state`` so that later reruns of the script (Streamlit re-executes
# the file on every widget interaction) skip that work.
import time

import streamlit as st
from sentence_transformers import SentenceTransformer
import datasets
import faiss
import torch
from sentence_transformers.util import semantic_search

if "initialized" not in st.session_state:
    # NOTE(review): "force_redownload" re-fetches the dataset for every new
    # session; kept as in the original -- confirm that this is intentional.
    st.session_state.dataset = datasets.load_dataset(
        'A-Roucher/english_historical_quotes',
        download_mode="force_redownload",
    )['train']
    st.session_state.all_authors = list(set(st.session_state.dataset['author']))

    # Other candidate models noted by the original author:
    #   "BAAI/bge-small-en-v1.5", "Cohere/Cohere-embed-english-light-v3.0"
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    st.session_state.encoder = SentenceTransformer(model_name)
    st.session_state.embeddings = st.session_state.encoder.encode(
        st.session_state.dataset["quote"],
        batch_size=4,
        show_progress_bar=True,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    st.session_state.initialized = True

# Rebuilt on every rerun from the embeddings cached in the session state.
dataset_embeddings_tensor = torch.Tensor(st.session_state.embeddings)


def search(query: str, selected_authors: list) -> str:
    """Return a Markdown report of the quotes closest to *query*.

    Args:
        query: Free-text idea typed by the user.
        selected_authors: Authors to restrict the search to; an empty
            collection means "search every quote".

    Returns:
        A Markdown string starting with the computation time, followed by up
        to five quotes (author as a heading, quote as a block quote).  An
        empty string is returned when the query is blank or nothing matches.
    """
    start = time.time()
    if not query.strip():
        return ""

    dataset = st.session_state.dataset
    query_tensor = torch.Tensor(st.session_state.encoder.encode([query]))

    if not selected_authors:
        # No filter: search the full corpus directly instead of copying the
        # whole embedding tensor through a list of every index.
        hits = semantic_search(query_tensor, dataset_embeddings_tensor, top_k=5)
        indices = [hit['corpus_id'] for hit in hits[0]]
    else:
        # Read the author column once and test membership against a set,
        # rather than re-reading the column for every row.
        allowed = set(selected_authors)
        author_indexes = [
            i for i, author in enumerate(dataset['author']) if author in allowed
        ]
        if not author_indexes:
            return ""
        hits = semantic_search(
            query_tensor,
            dataset_embeddings_tensor[author_indexes, :],
            top_k=5,
        )
        # corpus_id is relative to the filtered sub-corpus; map it back to the
        # row index in the full dataset.
        indices = [author_indexes[hit['corpus_id']] for hit in hits[0]]

    if not indices:
        return ""

    # Fetch each hit as a single row instead of materialising whole columns.
    rows = [dataset[i] for i in indices]
    result = "\n\n" + "".join(
        f"###### {row['author']}\n> {row['quote']}\n----\n" for row in rows
    )
    delay = f"{time.time() - start:.3f}"
    return f"_Computation time: **{delay} seconds**_{result}"


st.markdown(
    """ """, unsafe_allow_html=True
)

col1, col2 = st.columns([8, 2])
text_input = col1.text_input("Type your idea here:")
submit_button = col2.button("_Search quotes!_")
selected_authors = st.multiselect(
    "(Optional) - Restrict search to these authors:",
    st.session_state.all_authors,
)

if submit_button:
    st.markdown(search(text_input, selected_authors))