KonradSzafer commited on
Commit
b7068fd
·
1 Parent(s): cf57696

question and answer postprocessing

Browse files
Files changed (2) hide show
  1. benchmark/__main__.py +1 -0
  2. qa_engine/qa_engine.py +30 -1
benchmark/__main__.py CHANGED
@@ -33,6 +33,7 @@ def main():
33
 
34
  wandb.init(
35
  project='HF-Docs-QA',
 
36
  name=f'{config.question_answering_model_id} - {config.embedding_model_id} - {config.index_repo_id}',
37
  mode='run', # run/disabled
38
  config=filtered_config
 
33
 
34
  wandb.init(
35
  project='HF-Docs-QA',
36
+ entity='hf-qa-bot',
37
  name=f'{config.question_answering_model_id} - {config.embedding_model_id} - {config.index_repo_id}',
38
  mode='run', # run/disabled
39
  config=filtered_config
qa_engine/qa_engine.py CHANGED
@@ -228,6 +228,33 @@ class QAEngine():
228
  self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
229
 
230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  def get_response(self, question: str, messages_context: str = '') -> Response:
232
  """
233
  Generate an answer to the specified question.
@@ -271,7 +298,9 @@ class QAEngine():
271
  response.set_sources(sources=[str(m['source']) for m in metadata])
272
 
273
  logger.info('Running LLM chain')
274
- answer = self.llm_chain.run(question=question, context=context)
 
 
275
  response.set_answer(answer)
276
  logger.info('Received answer')
277
 
 
228
  self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
229
 
230
 
231
+ @staticmethod
232
+ def _preprocess_question(question: str) -> str:
233
+ if question[-1] != '?':
234
+ question += '?'
235
+ return question
236
+
237
+
238
+ @staticmethod
239
+ def _postprocess_answer(answer: str) -> str:
240
+ '''
241
+ Preprocess the answer by removing unnecessary sequences and stop sequences.
242
+ '''
243
+ REMOVE_SEQUENCES = [
244
+ 'Factually: ', 'Answer: ', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]'
245
+ ]
246
+ STOP_SEQUENCES = [
247
+ '\nUser:', '\nYou:'
248
+ ]
249
+ for seq in REMOVE_SEQUENCES:
250
+ answer = answer.replace(seq, '')
251
+ for seq in STOP_SEQUENCES:
252
+ if seq in answer:
253
+ answer = answer[:answer.index(seq)]
254
+ answer = answer.strip()
255
+ return answer
256
+
257
+
258
  def get_response(self, question: str, messages_context: str = '') -> Response:
259
  """
260
  Generate an answer to the specified question.
 
298
  response.set_sources(sources=[str(m['source']) for m in metadata])
299
 
300
  logger.info('Running LLM chain')
301
+ question_processed = QAEngine._preprocess_question(question)
302
+ answer = self.llm_chain.run(question=question_processed, context=context)
303
+ answer = QAEngine._postprocess_answer(answer)
304
  response.set_answer(answer)
305
  logger.info('Received answer')
306