File size: 3,592 Bytes
1e57a2c 4cbcb94 1e57a2c 4c37b25 1e57a2c e7ded97 1e57a2c 7a3f7ed 84b4358 1e57a2c 8fc12bf d45163a 3ad472b 3af85d8 3ad472b c5ea378 988f448 3af85d8 c5ea378 555cafd c5ea378 3ad472b e674289 3ad472b d17e4ec 3ad472b e674289 3ad472b 1e57a2c d45163a 3af85d8 1e57a2c 3af85d8 1e57a2c fdde008 1e57a2c c5ea378 1e57a2c 3af85d8 94bb2b9 1e57a2c d51ec0e 1e57a2c 4c37b25 e6f2d47 4c37b25 e6f2d47 1e57a2c d45163a e6f2d47 ec7a970 5119cdd e674289 667649c 555cafd 667649c ec7a970 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import streamlit as st
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch
from huggingface_hub import hf_hub_download
# Hugging Face Hub repository id for the precomputed passage embeddings.
embedding_path = "abokbot/wikipedia-embedding"

# Page title, followed by a status message rendered before the heavy
# resources below start loading.
st.header("Wikipedia Search Engine app")
st_model_load = st.text("Loading embeddings, encoders and dataset (takes about 5min)")
@st.cache_resource
def load_embedding(repo_id=embedding_path, filename=""):
    """Download the precomputed Wikipedia passage embeddings and load them on CPU.

    Wrapped in ``st.cache_resource`` because Streamlit re-executes the whole
    script on every widget interaction; without caching, each click would
    re-download and re-deserialize the embedding file.

    Args:
        repo_id: Hugging Face Hub repository to download from. Defaults to
            the module-level ``embedding_path``.
        filename: Name of the saved embedding file inside ``repo_id``.

    Returns:
        The object deserialized by ``torch.load``; it is later passed to
        ``util.semantic_search`` as the corpus embeddings.
    """
    print("Loading embedding...")
    # NOTE(review): an empty filename cannot identify a file in the repo;
    # confirm the real embedding file name and pass it here.
    path = hf_hub_download(repo_id=repo_id, filename=filename)
    # SECURITY NOTE(review): torch.load unpickles the downloaded file; only
    # load from a repository you trust.
    wikipedia_embedding = torch.load(path, map_location=torch.device("cpu"))
    print("Embedding loaded!")
    return wikipedia_embedding


wikipedia_embedding = load_embedding()
@st.cache_resource
def load_encoders():
    """Load the retrieval bi-encoder and the re-ranking cross-encoder.

    Wrapped in ``st.cache_resource`` so Streamlit's script reruns reuse the
    loaded models instead of instantiating them again on every interaction.

    Returns:
        A ``(bi_encoder, cross_encoder)`` tuple.
    """
    print("Loading encoders...")
    bi_encoder = SentenceTransformer("msmarco-MiniLM-L-6-v3")
    # Truncate long passages to 256 tokens.
    bi_encoder.max_seq_length = 256
    cross_encoder = CrossEncoder("cross-encoder/ms-marco-TinyBERT-L-2-v2")
    print("Encoders loaded!")
    return bi_encoder, cross_encoder


bi_encoder, cross_encoder = load_encoders()
@st.cache_resource
def load_wikipedia_dataset():
    """Load the ``train`` split of the Wikipedia first-paragraph dataset.

    Wrapped in ``st.cache_resource`` so Streamlit's script reruns reuse the
    loaded dataset instead of reloading it on every interaction.

    Returns:
        The ``train`` split; rows are read by index in ``search`` for their
        ``"title"``, ``"text"`` and ``"url"`` fields.
    """
    print("Loading wikipedia dataset...")
    dataset = load_dataset("abokbot/wikipedia-first-paragraph")["train"]
    print("Dataset loaded!")
    return dataset


dataset = load_wikipedia_dataset()
# Everything is loaded at this point: remove the loading notice and report
# readiness.
st_model_load.empty()
st.success('Search engine ready')

# Keep the last searched query across reruns; `search` writes it back and
# the text area below is pre-filled from it.
if 'text' not in st.session_state:
    st.session_state.text = ""

st.markdown("Enter query")
st_text_area = st.text_area(
    'E.g. What is the hashing trick? or Largest city in Morocco',
    value=st.session_state.text,
)
def search(top_k=32, n_results=3):
    """Retrieve-then-re-rank search for the query typed in the text area.

    Stage 1 encodes the query with ``bi_encoder`` and takes the ``top_k``
    nearest passages from ``wikipedia_embedding``. Stage 2 scores each
    (query, passage) pair with ``cross_encoder`` and keeps the
    ``n_results`` highest-scoring hits.

    Reads the module-level globals ``st_text_area``, ``bi_encoder``,
    ``cross_encoder``, ``wikipedia_embedding`` and ``dataset``, and stores
    the query in ``st.session_state.text``.

    Args:
        top_k: Number of candidates retrieved by the bi-encoder.
        n_results: Number of re-ranked hits returned.

    Returns:
        A list of dicts with keys ``"score"`` (cross-encoder score rounded
        to 3 decimals), ``"title"``, ``"abstract"`` (passage text with
        newlines replaced by spaces) and ``"link"``, best first. Empty when
        the query is blank or nothing was retrieved.
    """
    st.session_state.text = st_text_area
    query = st_text_area
    print("Input question:", query)
    # A blank query has no meaningful neighbours; don't return arbitrary hits.
    if not query or not query.strip():
        return []

    ##### Semantic Search #####
    print("Semantic Search")
    # Encode the query with the bi-encoder and find potentially relevant passages.
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    # semantic_search returns one hit list per query; we sent a single query.
    hits = util.semantic_search(question_embedding, wikipedia_embedding, top_k=top_k)[0]
    if not hits:
        return []

    ##### Re-Ranking #####
    # Score every retrieved passage against the query with the cross-encoder.
    cross_inp = [[query, dataset[hit['corpus_id']]["text"]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, score in zip(hits, cross_scores):
        hit['cross-score'] = score
    hits.sort(key=lambda h: h['cross-score'], reverse=True)

    print(f"Top-{n_results} Cross-Encoder Re-ranker hits")
    results = []
    for hit in hits[:n_results]:
        row = dataset[hit['corpus_id']]  # fetch the row once, not per field
        results.append({
            "score": round(hit['cross-score'], 3),
            "title": row["title"],
            "abstract": row["text"].replace("\n", " "),
            "link": row["url"],
        })
    return results
# Search button: run the query and render each returned hit.
st_search_button = st.button('Search')
if st_search_button:
    results = search()
    st.subheader("Top-3 Search results")
    if not results:
        st.warning("No results. Please enter a non-empty query.")
    for i, result in enumerate(results, start=1):
        st.markdown(f"#### Result {i}")
        st.markdown(f"**Wikipedia article:** {result['title']}")
        st.markdown(f"**Link:** {result['link']}")
        st.markdown(f"**First paragraph:** {result['abstract']}")
        st.text("")  # blank spacer between results