Nighter committed on
Commit
38bc3de
1 Parent(s): c8770f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -43
app.py CHANGED
@@ -7,9 +7,9 @@ import re
7
  from tensorflow.keras.models import load_model
8
 
9
  # Load long model
10
- with open('lstm-qa-long-answers-model/tokenizer.pickle', 'rb') as handle:
11
- tokenizer = pickle.load(handle)
12
- long_answer_model = load_model('lstm-qa-long-answers-model/model.h5')
13
 
14
  def clean_text(text):
15
  text = re.sub(r'<.*?>', '', text)
@@ -21,49 +21,42 @@ def remove_parentheses(text):
21
  pattern = r'\([^)]*\)'
22
  return re.sub(pattern, '', text)
23
 
24
- def split_into_sentences(text):
25
- sentences = re.split(r'\.\s*', text)
26
- return [s.strip() for s in sentences if s]
27
 
28
- def predict_answer(context, question):
29
- sentences = split_into_sentences(context)
30
- best_sentence = None
31
- best_score = 0
32
 
33
- for sentence in sentences:
34
- clean_sentence = clean_text(sentence)
35
- question_seq = tokenizer.texts_to_sequences([question])
36
- sentence_seq = tokenizer.texts_to_sequences([clean_sentence])
37
 
38
- max_sentence_length = 300
39
- padded_question = pad_sequences(question_seq, padding='post')
40
- padded_sentence = pad_sequences(sentence_seq, maxlen=max_sentence_length, padding='post', truncating='post')
41
 
42
- score = long_answer_model.predict([padded_sentence, padded_question])[0]
43
 
44
- if score > best_score:
45
- best_score = score
46
- best_sentence = clean_sentence
47
 
48
- return best_score, best_sentence
49
 
50
  # Load short model
51
- short_answer_model = pipeline(model="Nighter/QA_wiki_data_short_answer", from_tf=True)
 
52
 
53
  # Function to answer on all models
54
  def answer_questions(context, question):
55
- long_score, long_answer = predict_answer(context, question)
56
- # # Check if the original context is longer than 512 tokens
57
- # if len(tokenizer.texts_to_sequences([context])[0]) > 512:
58
- # # If yes, use the long answer as the context for the short answer model
59
- # short_context = long_answer
60
- # else:
61
- # # If no, use the original context
62
- # short_context = remove_parentheses(context)
63
-
64
- # short_answer_result = short_answer_model(question=question, context=short_context)
65
- short_answer_result = short_answer_model(question=question, context=remove_parentheses(context))
66
- return short_answer_result['answer'], short_answer_result['score'], long_answer, long_score
67
 
68
  # App Interface
69
  with gr.Blocks() as app:
@@ -76,15 +69,15 @@ with gr.Blocks() as app:
76
  gr.ClearButton([context_input,question_input])
77
  with gr.Column():
78
  with gr.Row():
79
- with gr.Column(scale=4):
80
- short_answer_output = gr.Textbox(lines=5, label="Distil BERT Short Answer")
81
- with gr.Column(scale=1):
82
- short_score_output = gr.Number(label="Short Answer Score")
83
  with gr.Row():
84
- with gr.Column(scale=4):
85
- long_answer_output = gr.Textbox(lines=5, label="LSTM Long Answer")
86
- with gr.Column(scale=1):
87
- long_score_output = gr.Number(label="Long Answer Score")
88
 
89
  submit_btn.click(fn=answer_questions, inputs=[context_input, question_input], outputs=[short_answer_output, short_score_output, long_answer_output, long_score_output])
90
  examples='examples'
 
7
  from tensorflow.keras.models import load_model
8
 
9
  # Load long model
10
+ # with open('lstm-qa-long-answers-model/tokenizer.pickle', 'rb') as handle:
11
+ # tokenizer = pickle.load(handle)
12
+ # long_answer_model = load_model('lstm-qa-long-answers-model/model.h5')
13
 
14
  def clean_text(text):
15
  text = re.sub(r'<.*?>', '', text)
 
21
  pattern = r'\([^)]*\)'
22
  return re.sub(pattern, '', text)
23
 
24
+ # def split_into_sentences(text):
25
+ # sentences = re.split(r'\.\s*', text)
26
+ # return [s.strip() for s in sentences if s]
27
 
28
+ # def predict_answer(context, question):
29
+ # sentences = split_into_sentences(context)
30
+ # best_sentence = None
31
+ # best_score = 0
32
 
33
+ # for sentence in sentences:
34
+ # clean_sentence = clean_text(sentence)
35
+ # question_seq = tokenizer.texts_to_sequences([question])
36
+ # sentence_seq = tokenizer.texts_to_sequences([clean_sentence])
37
 
38
+ # max_sentence_length = 300
39
+ # padded_question = pad_sequences(question_seq, padding='post')
40
+ # padded_sentence = pad_sequences(sentence_seq, maxlen=max_sentence_length, padding='post', truncating='post')
41
 
42
+ # score = long_answer_model.predict([padded_sentence, padded_question])[0]
43
 
44
+ # if score > best_score:
45
+ # best_score = score
46
+ # best_sentence = clean_sentence
47
 
48
+ # return best_score, best_sentence
49
 
50
  # Load short model
51
+ distilbert_base_uncased = pipeline(model="Nighter/QA_wiki_data_short_answer", from_tf=True)
52
+ bert_base_uncased = pipeline(model="Nighter/QA_bert_base_uncased_wiki_data_short_answer", from_tf=True)
53
 
54
  # Function to answer on all models
55
  def answer_questions(context, question):
56
+ # long_score, long_answer = predict_answer(context, question)
57
+ distilbert_base_uncased_result = distilbert_base_uncased(question=question, context=remove_parentheses(context))
58
+ bert_base_uncased_result =bert_base_uncased(question=question, context=remove_parentheses(context))
59
+ return distilbert_base_uncased_result['answer'], distilbert_base_uncased_result['score'], bert_base_uncased_result['answer'], bert_base_uncased_result['score'] #, long_answer, long_score
 
 
 
 
 
 
 
 
60
 
61
  # App Interface
62
  with gr.Blocks() as app:
 
69
  gr.ClearButton([context_input,question_input])
70
  with gr.Column():
71
  with gr.Row():
72
+ with gr.Column(scale=6):
73
+ short_answer_output = gr.Textbox(lines=5, label="Distil BERT Base Uncased")
74
+ with gr.Column(scale=2):
75
+ short_score_output = gr.Number(label="Distil BERT Base Uncased Score")
76
  with gr.Row():
77
+ with gr.Column(scale=6):
78
+ long_answer_output = gr.Textbox(lines=5, label="BERT Base Uncased")
79
+ with gr.Column(scale=2):
80
+ long_score_output = gr.Number(label="BERT Base Uncased Score")
81
 
82
  submit_btn.click(fn=answer_questions, inputs=[context_input, question_input], outputs=[short_answer_output, short_score_output, long_answer_output, long_score_output])
83
  examples='examples'