A-Roucher commited on
Commit
beda96a
1 Parent(s): b19e82a

feat: new faiss index

Browse files
Files changed (1) hide show
  1. app.py +14 -31
app.py CHANGED
@@ -2,50 +2,34 @@ import streamlit as st
2
  from sentence_transformers import SentenceTransformer
3
  import datasets
4
  import faiss
5
- import torch
6
- from sentence_transformers.util import semantic_search
7
  import time
 
8
 
9
 
10
  if "initialized" not in st.session_state:
11
  st.session_state.dataset = datasets.load_dataset('A-Roucher/english_historical_quotes', download_mode="force_redownload")['train']
12
  st.session_state.all_authors = list(set(st.session_state.dataset['author']))
13
- model_name = "sentence-transformers/all-MiniLM-L6-v2" # BAAI/bge-small-en-v1.5" # "Cohere/Cohere-embed-english-light-v3.0" # "sentence-transformers/all-MiniLM-L6-v2"
14
  st.session_state.encoder = SentenceTransformer(model_name)
15
- st.session_state.embeddings = st.session_state.encoder.encode(
16
- st.session_state.dataset["quote"],
17
- batch_size=4,
18
- show_progress_bar=True,
19
- convert_to_numpy=True,
20
- normalize_embeddings=True,
21
- )
22
  st.session_state.initialized=True
23
 
24
- dataset_embeddings_tensor = torch.Tensor(st.session_state.embeddings)
25
-
26
- sentence = "Knowledge of history is power."
27
-
28
- def search(query, selected_authors):
29
  start = time.time()
30
  if len(query.strip()) == 0:
31
  return ""
32
 
33
  query_embedding = st.session_state.encoder.encode([query])
34
- sentence_embedding_tensor = torch.Tensor(query_embedding)
35
-
36
- if len(selected_authors) == 0:
37
- author_indexes = [i for i in range(len(st.session_state.dataset))]
38
- else:
39
- author_indexes = [i for i in range(len(st.session_state.dataset)) if st.session_state.dataset['author'][i] in selected_authors]
40
- hits = semantic_search(sentence_embedding_tensor, dataset_embeddings_tensor[author_indexes, :], top_k=5)
41
 
42
- indices = [author_indexes[i['corpus_id']] for i in hits[0]]
43
-
44
- if len(indices) == 0:
45
- return ""
 
46
  result = "\n\n"
47
- for i in indices:
48
- result += f"###### {st.session_state.dataset['author'][i]}\n> {st.session_state.dataset['quote'][i]}\n----\n"
 
49
  delay = "%.3f" % (time.time() - start)
50
  return f"_Computation time: **{delay} seconds**_{result}"
51
 
@@ -61,10 +45,9 @@ st.markdown(
61
  """,unsafe_allow_html=True
62
  )
63
  col1, col2 = st.columns([8, 2])
64
- text_input = col1.text_input("Type your idea here:")
65
  submit_button = col2.button("_Search quotes!_")
66
- selected_authors = st.multiselect("(Optional) - Restrict search to these authors:", st.session_state.all_authors)
67
 
68
  if submit_button:
69
- st.markdown(search(text_input, selected_authors))
70
 
 
2
  from sentence_transformers import SentenceTransformer
3
  import datasets
4
  import faiss
 
 
5
  import time
6
+ import faiss
7
 
8
 
9
  if "initialized" not in st.session_state:
10
  st.session_state.dataset = datasets.load_dataset('A-Roucher/english_historical_quotes', download_mode="force_redownload")['train']
11
  st.session_state.all_authors = list(set(st.session_state.dataset['author']))
12
+ model_name = "BAAI/bge-small-en-v1.5" # "sentence-transformers/all-MiniLM-L6-v2" # # "Cohere/Cohere-embed-english-light-v3.0" # "sentence-transformers/all-MiniLM-L6-v2"
13
  st.session_state.encoder = SentenceTransformer(model_name)
14
+ st.session_state.index = faiss.read_index('index_alone.faiss')
 
 
 
 
 
 
15
  st.session_state.initialized=True
16
 
17
+ def search(query):
 
 
 
 
18
  start = time.time()
19
  if len(query.strip()) == 0:
20
  return ""
21
 
22
  query_embedding = st.session_state.encoder.encode([query])
 
 
 
 
 
 
 
23
 
24
+ _, samples = st.session_state.index.search(
25
+ query_embedding, k=10
26
+ )
27
+ quotes = st.session_state.dataset.select(samples[0])
28
+
29
  result = "\n\n"
30
+ for i in range(len(quotes)):
31
+ result += f"###### {quotes['author'][i]}\n> {quotes['quote'][i]}\n----\n"
32
+
33
  delay = "%.3f" % (time.time() - start)
34
  return f"_Computation time: **{delay} seconds**_{result}"
35
 
 
45
  """,unsafe_allow_html=True
46
  )
47
  col1, col2 = st.columns([8, 2])
48
+ text_input = col1.text_input("Type your idea here:", placeholder="Knowledge of history is power.")
49
  submit_button = col2.button("_Search quotes!_")
 
50
 
51
  if submit_button:
52
+ st.markdown(search(text_input))
53