File size: 3,592 Bytes
1e57a2c
4cbcb94
1e57a2c
 
 
 
 
 
 
 
4c37b25
1e57a2c
 
e7ded97
1e57a2c
7a3f7ed
84b4358
1e57a2c
 
 
8fc12bf
 
d45163a
3ad472b
 
 
 
 
 
3af85d8
3ad472b
 
 
c5ea378
 
 
 
988f448
3af85d8
c5ea378
 
 
555cafd
c5ea378
3ad472b
 
 
e674289
3ad472b
d17e4ec
3ad472b
e674289
3ad472b
1e57a2c
 
d45163a
 
 
3af85d8
 
1e57a2c
3af85d8
1e57a2c
fdde008
1e57a2c
c5ea378
1e57a2c
 
 
 
3af85d8
94bb2b9
1e57a2c
 
 
 
 
d51ec0e
 
1e57a2c
 
 
4c37b25
 
e6f2d47
4c37b25
 
 
 
 
 
 
e6f2d47
1e57a2c
 
d45163a
e6f2d47
 
 
ec7a970
5119cdd
e674289
667649c
555cafd
667649c
ec7a970
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import streamlit as st
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch
from huggingface_hub import hf_hub_download

# Hugging Face Hub repository that hosts the precomputed Wikipedia embeddings.
embedding_path = "abokbot/wikipedia-embedding"

# Page title.
st.header("Wikipedia Search Engine app")

# Placeholder text element shown while the heavy resources load; the handle is
# kept so the message can be overwritten later in the script.
st_model_load = st.text("Loading embeddings, encoders and dataset (takes about 5min)")

@st.cache_resource
def load_embedding():
    """Download the precomputed Wikipedia embeddings and load them on the CPU.

    The file ``wikipedia_en_embedding.pt`` is fetched from the Hugging Face Hub
    repository named by the module-level ``embedding_path`` constant.

    Returns:
        The object stored in the downloaded file, with tensors mapped to CPU.
    """
    print("Loading embedding...")
    # Use the shared repo constant instead of repeating the literal here.
    path = hf_hub_download(repo_id=embedding_path, filename="wikipedia_en_embedding.pt")
    # NOTE(review): torch.load unpickles the file, so it must only be pointed
    # at a trusted repository. Consider weights_only=True if the installed
    # torch version supports it and the file holds plain tensors.
    wikipedia_embedding = torch.load(path, map_location=torch.device('cpu'))
    print("Embedding loaded!")
    return wikipedia_embedding

# Corpus embeddings searched by search(). load_embedding is decorated with
# st.cache_resource, so Streamlit can reuse the loaded object across reruns
# instead of repeating the download.
wikipedia_embedding = load_embedding()

@st.cache_resource
def load_encoders():
    """Load the two models used by the retrieve-and-re-rank pipeline.

    Returns:
        A ``(bi_encoder, cross_encoder)`` tuple: the ``SentenceTransformer``
        used to embed queries and the ``CrossEncoder`` used to re-score
        (query, passage) pairs.
    """
    print("Loading encoders...")
    bi_encoder = SentenceTransformer('msmarco-MiniLM-L-6-v3')
    bi_encoder.max_seq_length = 256  # Truncate long passages to 256 tokens
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-2-v2')
    print("Encoders loaded!")
    return bi_encoder, cross_encoder

# Models used by search(): the bi-encoder embeds the query, the cross-encoder
# re-scores the retrieved passages.
bi_encoder, cross_encoder = load_encoders()

@st.cache_resource
def load_wikipedia_dataset():
    """Load the ``abokbot/wikipedia-first-paragraph`` dataset.

    Returns:
        The ``"train"`` split of the dataset.
    """
    print("Loading wikipedia dataset...")
    all_splits = load_dataset("abokbot/wikipedia-first-paragraph")
    train_split = all_splits["train"]
    print("Dataset loaded!")
    return train_split
    
# Rows looked up by corpus id in search().
dataset = load_wikipedia_dataset()
st.success('Search engine ready')
# Everything is loaded: blank out the "Loading..." placeholder.
st_model_load.text("")

# Keep the last submitted query in session state so the text area can be
# pre-filled with it on the next rerun.
if 'text' not in st.session_state:
    st.session_state.text = ""
st.markdown("Enter query")
st_text_area = st.text_area(
    'E.g. What is the hashing trick? or Largest city in Morocco',
    value=st.session_state.text,
    # st.text_area documents a minimum height of 68 pixels; recent Streamlit
    # releases raise an exception for smaller values such as the former 25.
    height=68,
)


def search():
    """Run a retrieve-and-re-rank search for the query in the text area.

    The current value of ``st_text_area`` is stored in
    ``st.session_state.text``, embedded with ``bi_encoder`` and matched against
    ``wikipedia_embedding`` to retrieve candidate passages, which are then
    re-scored with ``cross_encoder``.

    Returns:
        A list of at most three dicts, best cross-encoder score first, each
        with the keys ``"score"`` (rounded to 3 decimals), ``"title"``,
        ``"abstract"`` (passage text with newlines replaced by spaces) and
        ``"link"``. The list is empty when the query is blank.
    """
    st.session_state.text = st_text_area
    query = st_text_area
    print("Input question:", query)

    # A blank query has nothing meaningful to match, so skip both encoders.
    if not query or not query.strip():
        return []

    ##### Semantic Search #####
    print("Semantic Search")
    # Encode the query using the bi-encoder and find potentially relevant passages
    top_k = 32
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(question_embedding, wikipedia_embedding, top_k=top_k)
    hits = hits[0]  # Get the hits for the first query

    ##### Re-Ranking #####
    # Now, score all retrieved passages with the cross_encoder
    print("Re-Ranking")
    # Fetch each candidate row once and reuse it for scoring and for the output,
    # rather than looking the same row up again for every field.
    rows = [dataset[hit['corpus_id']] for hit in hits]
    cross_scores = cross_encoder.predict([[query, row["text"]] for row in rows])

    for hit, cross_score in zip(hits, cross_scores):
        hit['cross-score'] = cross_score

    # Sort results by the cross-encoder scores, keeping each hit paired with
    # its row. sorted() is stable, so ties keep the bi-encoder order.
    ranked = sorted(
        zip(hits, rows),
        key=lambda pair: pair[0]['cross-score'],
        reverse=True,
    )

    # Output of top-3 hits from re-ranker
    print("\n-------------------------\n")
    print("Top-3 Cross-Encoder Re-ranker hits")
    return [
        {
            "score": round(hit['cross-score'], 3),
            "title": row["title"],
            "abstract": row["text"].replace("\n", " "),
            "link": row["url"],
        }
        for hit, row in ranked[:3]
    ]


# Search button: when clicked, run the query and render the returned hits,
# numbered from 1 in the order search() returned them.
st_search_button = st.button("Search")
if st_search_button:
    results = search()
    st.subheader("Top-3 Search results")
    for rank, entry in enumerate(results, start=1):
        st.markdown(f"#### Result {rank}")
        article_line = "**Wikipedia article:** " + entry["title"]
        link_line = "**Link:** " + entry["link"]
        paragraph_line = "**First paragraph:** " + entry["abstract"]
        st.markdown(article_line)
        st.markdown(link_line)
        st.markdown(paragraph_line)
        # Empty text element used as vertical spacing between results.
        st.text("")