NCTCMumbai committed on
Commit 0d0cb38
1 Parent(s): c4dff66

Update app.py

Files changed (1)
  1. app.py +11 -16
app.py CHANGED
@@ -35,7 +35,7 @@ template = env.get_template('template.j2')
 template_html = env.get_template('template_html.j2')
 
 # crossEncoder
-cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
+#cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
 #cross_encoder = CrossEncoder('BAAI/bge-reranker-base')
 # Examples
 examples = ['What is the 4 digit classification heading for Gold jewellery?',
@@ -49,7 +49,7 @@ def add_text(history, text):
     return history, gr.Textbox(value="", interactive=False)
 
 
-def bot(history, api_kind):
+def bot(history, cross_encoder):
     top_rerank = 15
     top_k_rank = 10
     query = history[-1][0]
@@ -66,7 +66,7 @@ def bot(history, api_kind):
     logger.warning(f'Finished query vec')
     doc1 = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_k_rank)
 
-
+
 
     logger.warning(f'Finished search')
     documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_rerank).to_list()
@@ -74,6 +74,10 @@ def bot(history, api_kind):
     logger.warning(f'start cross encoder {len(documents)}')
     # Retrieve documents relevant to query
     query_doc_pair = [[query, doc] for doc in documents]
+    if cross_encoder == 'MiniLM-L6v2':
+        cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
+    else:
+        cross_encoder = CrossEncoder('BAAI/bge-reranker-base')
     cross_scores = cross_encoder.predict(query_doc_pair)
     sim_scores_argsort = list(reversed(np.argsort(cross_scores)))
     logger.warning(f'Finished cross encoder {len(documents)}')
@@ -88,16 +92,7 @@ def bot(history, api_kind):
     prompt = template.render(documents=documents, query=query)
     prompt_html = template_html.render(documents=documents, query=query)
 
-    if api_kind == "HuggingFace":
-        generate_fn = generate_hf
-    elif api_kind == "OpenAI":
-        generate_fn = generate_openai
-    elif api_kind is None:
-        gr.Warning("API name was not provided")
-        raise ValueError("API name was not provided")
-    else:
-        gr.Warning(f"API {api_kind} is not supported")
-        raise ValueError(f"API {api_kind} is not supported")
+    generate_fn = generate_hf
 
     history[-1][1] = ""
     for character in generate_fn(prompt, history[:-1]):
@@ -125,19 +120,19 @@ with gr.Blocks() as demo:
     )
     txt_btn = gr.Button(value="Submit text", scale=1)
 
-    api_kind = gr.Radio(choices=["HuggingFace"], value="HuggingFace")
+    cross_encoder = gr.Radio(choices=['MiniLM-L6v2', 'BGE reranker'], value='MiniLM-L6v2')
 
     prompt_html = gr.HTML()
     # Turn off interactivity while generating if you click
     txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
-        bot, [chatbot, api_kind], [chatbot, prompt_html])
+        bot, [chatbot, cross_encoder], [chatbot, prompt_html])
 
     # Turn it back on
     txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
 
     # Turn off interactivity while generating if you hit enter
     txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
-        bot, [chatbot, api_kind], [chatbot, prompt_html])
+        bot, [chatbot, cross_encoder], [chatbot, prompt_html])
 
     # Turn it back on
     txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
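
For context, a minimal sketch of the reranking step this commit moves into bot(): the value of the new Gradio Radio selects the cross-encoder, the chosen model scores (query, document) pairs, and the documents are reordered by score. The rerank() helper and the sample documents below are illustrative stand-ins, not part of app.py; in the app the documents come from the LanceDB table search and the reordered results feed the prompt templates and generate_hf.

import numpy as np
from sentence_transformers import CrossEncoder

def rerank(query, documents, cross_encoder_choice='MiniLM-L6v2', top_k=10):
    """Score (query, doc) pairs with the selected cross-encoder and return the best docs."""
    # Same model selection the committed bot() performs from the Radio value
    if cross_encoder_choice == 'MiniLM-L6v2':
        model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    else:
        model = CrossEncoder('BAAI/bge-reranker-base')

    query_doc_pairs = [[query, doc] for doc in documents]
    cross_scores = model.predict(query_doc_pairs)      # one relevance score per pair
    order = list(reversed(np.argsort(cross_scores)))   # highest score first
    return [documents[i] for i in order[:top_k]]

if __name__ == '__main__':
    # Hypothetical retrieved snippets; in app.py these come from the vector search
    docs = ['Articles of jewellery of precious metal fall under heading 7113.',
            'Chapter 71 covers natural or cultured pearls and precious metals.']
    print(rerank('What is the 4 digit classification heading for Gold jewellery?', docs))

Note that, as committed, the CrossEncoder is constructed inside bot() on every request; caching one instance per Radio choice at module scope would avoid reloading the model on each turn.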