Adir Gozlan committed on
Commit
6d5ec26
1 Parent(s): a200fe6

late commit

Browse files
Files changed (2) hide show
  1. app.py +33 -8
  2. backend/cross_encoder.py +29 -0
app.py CHANGED
@@ -11,9 +11,13 @@ from jinja2 import Environment, FileSystemLoader
11
 
12
  from backend.query_llm import generate_hf, generate_openai
13
  from backend.semantic_search import retrieve
 
 
 
14
 
15
 
16
  TOP_K = int(os.getenv("TOP_K", 4))
 
17
 
18
  proj_dir = Path(__file__).parent
19
  # Setting up the logging
@@ -34,7 +38,7 @@ def add_text(history, text):
34
  return history, gr.Textbox(value="", interactive=False)
35
 
36
 
37
- def bot(history, api_kind):
38
  query = history[-1][0]
39
 
40
  if not query:
@@ -42,12 +46,32 @@ def bot(history, api_kind):
42
 
43
  logger.info('Retrieving documents...')
44
  # Retrieve documents relevant to query
45
- document_start = perf_counter()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- documents = retrieve(query, TOP_K)
48
 
49
- document_time = perf_counter() - document_start
50
- logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
51
 
52
  # Create Prompt
53
  prompt = template.render(documents=documents, query=query)
@@ -86,19 +110,20 @@ with gr.Blocks() as demo:
86
  )
87
  txt_btn = gr.Button(value="Submit text", scale=1)
88
 
89
- api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace")
 
90
 
91
  prompt_html = gr.HTML()
92
  # Turn off interactivity while generating if you click
93
  txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
94
- bot, [chatbot, api_kind], [chatbot, prompt_html])
95
 
96
  # Turn it back on
97
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
98
 
99
  # Turn off interactivity while generating if you hit enter
100
  txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
101
- bot, [chatbot, api_kind], [chatbot, prompt_html])
102
 
103
  # Turn it back on
104
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
 
11
 
12
  from backend.query_llm import generate_hf, generate_openai
13
  from backend.semantic_search import retrieve
14
+ from backend.cross_encoder import rerank_with_cross_encoder
15
+
16
+
17
 
18
 
19
  TOP_K = int(os.getenv("TOP_K", 4))
20
+ TOP_K_RERANK = int(os.getenv("TOP_K_RERANK", 40))
21
 
22
  proj_dir = Path(__file__).parent
23
  # Setting up the logging
 
38
  return history, gr.Textbox(value="", interactive=False)
39
 
40
 
41
+ def bot(history, api_kind, cross_enc):
42
  query = history[-1][0]
43
 
44
  if not query:
 
46
 
47
  logger.info('Retrieving documents...')
48
  # Retrieve documents relevant to query
49
+ documents = []
50
+ if not cross_enc:
51
+ document_start = perf_counter()
52
+
53
+ documents = retrieve(query, TOP_K)
54
+
55
+ document_time = perf_counter() - document_start
56
+ logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
57
+
58
+ else:
59
+ document_start = perf_counter()
60
+
61
+ documents = retrieve(query, TOP_K_RERANK)
62
+
63
+ document_time = perf_counter() - document_start
64
+ logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
65
+
66
+ logger.info('Reranking documents')
67
+ document_start = perf_counter()
68
+
69
+ documents = rerank_with_cross_encoder(cross_enc, documents, query)
70
+
71
+ document_time = perf_counter() - document_start
72
 
73
+ logger.info(f'Finished Reranking documents in {round(document_time, 2)} seconds...')
74
 
 
 
75
 
76
  # Create Prompt
77
  prompt = template.render(documents=documents, query=query)
 
110
  )
111
  txt_btn = gr.Button(value="Submit text", scale=1)
112
 
113
+ api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace", label="LLM")
114
+ cross_enc = gr.Radio(choices=["None", "cross-encoder/ms-marco-MiniLM-L-6-v2", "BAAI/bge-reranker-large"], value=None, label="Cross Encoder")
115
 
116
  prompt_html = gr.HTML()
117
  # Turn off interactivity while generating if you click
118
  txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
119
+ bot, [chatbot, api_kind, cross_enc], [chatbot, prompt_html])
120
 
121
  # Turn it back on
122
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
123
 
124
  # Turn off interactivity while generating if you hit enter
125
  txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
126
+ bot, [chatbot, api_kind, cross_enc], [chatbot, prompt_html])
127
 
128
  # Turn it back on
129
  txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
backend/cross_encoder.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+
6
+ cross_encoder = None
7
+ cross_enc_tokenizer = None
8
+
9
+ TOP_K_RERANK = os.getenv("TOP_K_RERANK", 40)
10
+
11
+
12
+ @torch.no_grad()
13
+ def rerank_with_cross_encoder(cross_enc_name, documents, query):
14
+ if cross_enc_name is None or len(documents) <= 1:
15
+ return documents
16
+
17
+ global cross_encoder, cross_enc_tokenizer
18
+ if cross_encoder is None or cross_encoder.name_or_path != cross_enc_name:
19
+ cross_encoder = AutoModelForSequenceClassification.from_pretrained(cross_enc_name)
20
+ cross_encoder.eval()
21
+ cross_enc_tokenizer = AutoTokenizer.from_pretrained(cross_enc_name)
22
+
23
+ features = cross_enc_tokenizer(
24
+ [query] * len(documents), documents, padding=True, truncation=True, return_tensors="pt"
25
+ )
26
+ scores = cross_encoder(**features).logits.squeeze()
27
+ ranks = torch.argsort(scores, descending=True)
28
+ documents = [documents[i] for i in ranks[:TOP_K_RERANK]]
29
+ return documents