abokbot commited on
Commit
d45163a
·
1 Parent(s): 3ad472b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -13
app.py CHANGED
@@ -21,6 +21,7 @@ wikipedia_embedding = load_embedding()
21
  st.success('Embedding loaded!')
22
  st_model_load.text("")
23
 
 
24
  def load_encoders():
25
  print("Loading encoders...")
26
  bi_encoder = SentenceTransformer('msmarco-MiniLM-L-6-v3')
@@ -40,20 +41,11 @@ st_text_area = st.text_area(
40
  value=st.session_state.text,
41
  height=100
42
  )
43
- """
44
-
45
 
46
- #We use the Bi-Encoder to encode all passages, so that we can use it with sematic search
47
- # cf https://www.sbert.net/docs/pretrained-models/msmarco-v3.html
48
- bi_encoder = SentenceTransformer('msmarco-MiniLM-L-6-v3')
49
- bi_encoder.max_seq_length = 256 #Truncate long passages to 256 tokens
50
- top_k = 32 #Number of passages we want to retrieve with the bi-encoder
51
 
52
- #The bi-encoder will retrieve 100 documents. We use a cross-encoder, to re-rank the results list to improve the quality
53
- cross_encoder = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-2-v2')
54
-
55
- def search(query):
56
- print("Input question:", query)
57
  ##### Sematic Search #####
58
  # Encode the query using the bi-encoder and find potentially relevant passages
59
  question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
@@ -81,4 +73,6 @@ def search(query):
81
  "link: ", dataset["url"][hit['corpus_id']],"\n")
82
 
83
 
84
- """
 
 
 
21
  st.success('Embedding loaded!')
22
  st_model_load.text("")
23
 
24
+ @st.cache_resource
25
  def load_encoders():
26
  print("Loading encoders...")
27
  bi_encoder = SentenceTransformer('msmarco-MiniLM-L-6-v3')
 
41
  value=st.session_state.text,
42
  height=100
43
  )
 
 
44
 
 
 
 
 
 
45
 
46
+ def search():
47
+ st.session_state.text = st_text_area
48
+ query = st_text_area
 
 
49
  ##### Sematic Search #####
50
  # Encode the query using the bi-encoder and find potentially relevant passages
51
  question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
 
73
  "link: ", dataset["url"][hit['corpus_id']],"\n")
74
 
75
 
76
+ # search button
77
+ st_search_button = st.button('Search', on_click=search)
78
+