Aragoner committed on
Commit b3c39b6 · verified · 1 Parent(s): bed9663

Update backend/semantic_search.py

Files changed (1)
  1. backend/semantic_search.py +10 -24
backend/semantic_search.py CHANGED
@@ -1,12 +1,9 @@
 import lancedb
-import os
 import gradio as gr
 from sentence_transformers import SentenceTransformer
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 import torch
-import time
 import os
-from pathlib import Path
 
 db = lancedb.connect(".lancedb")
 
@@ -19,39 +16,28 @@ CROSS_ENCODER = os.getenv("CROSS_ENCODER")
 retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
 cross_encoder = AutoModelForSequenceClassification.from_pretrained(CROSS_ENCODER)
 cross_encoder.eval()
+
 cross_encoder_tokenizer = AutoTokenizer.from_pretrained(CROSS_ENCODER)
 
 
-def rerank(query, documents, k):
-    """Use cross-encoder to rerank documents retrieved from the retriever."""
-    tokens = cross_encoder_tokenizer([query] * len(documents), documents, padding=True, truncation=True, return_tensors="pt")
+def reranking(query, list_of_documents, k):
+    received_tokens = cross_encoder_tokenizer([query] * len(list_of_documents), list_of_documents, padding=True, truncation=True, return_tensors="pt")
     with torch.no_grad():
-        logits = cross_encoder(**tokens).logits
-    scores = logits.reshape(-1).tolist()
-    documents = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
-    return [doc[0] for doc in documents[:k]]
-
-
-# def retrieve(query, k):
-#     query_vec = retriever.encode(query)
-#     try:
-#         documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
-#         documents = [doc[TEXT_COLUMN] for doc in documents]
-#
-#         return documents
-#
-#     except Exception as e:
-#         raise gr.Error(str(e))
+        logits_on_tokens = cross_encoder(**received_tokens).logits
+    probabilities = logits_on_tokens.reshape(-1).tolist()
+    documents = sorted(zip(list_of_documents, probabilities), key=lambda x: x[1], reverse=True)
+    result = [document[0] for document in documents[:k]]
+    return result
 
 
-def retrieve(query, top_k_retriever=25, use_reranking=True, top_k_reranker=5):
+def retrieve(query, top_k_retriever=30, use_reranking=True, top_k_reranker=5):
     query_vec = retriever.encode(query)
     try:
         documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(top_k_retriever).to_list()
         documents = [doc[TEXT_COLUMN] for doc in documents]
 
         if use_reranking:
-            documents = rerank(query, documents, top_k_reranker)
+            documents = reranking(query, documents, top_k_reranker)
 
         return documents
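For context, the reranking path touched by this commit follows the standard cross-encoder pattern: tokenize (query, document) pairs jointly, use the raw classification logits as relevance scores, and keep the top-k documents. Below is a minimal, self-contained sketch of that pattern. The model name cross-encoder/ms-marco-MiniLM-L-6-v2 and the sample documents are illustrative assumptions only; the committed code loads whatever model CROSS_ENCODER points to and pulls candidates from its LanceDB table.

# Minimal sketch of the cross-encoder reranking pattern used in this file.
# Assumption: the model name below is only an example; the committed code
# reads its model from the CROSS_ENCODER environment variable.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # illustrative choice

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval()


def rerank(query: str, documents: list[str], k: int) -> list[str]:
    """Score (query, document) pairs with the cross-encoder and keep the top k."""
    tokens = tokenizer(
        [query] * len(documents),
        documents,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        # One relevance logit per (query, document) pair.
        scores = model(**tokens).logits.reshape(-1).tolist()
    ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
    return [doc for doc, _ in ranked[:k]]


if __name__ == "__main__":
    docs = [
        "LanceDB is an embedded vector database.",
        "Gradio builds web UIs for machine learning demos.",
        "Cross-encoders score query-document pairs jointly.",
    ]
    print(rerank("What does a cross-encoder do?", docs, k=2))

Running a cross-encoder over the bi-encoder's candidates, as retrieve() does when use_reranking is set, trades some latency for better ranking precision, since the query and each document are scored together rather than compared only through their precomputed embeddings.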