File size: 7,222 Bytes
5d8b19c
 
 
 
 
 
 
 
 
8b61459
 
 
5d8b19c
 
 
 
 
 
 
 
 
 
 
02acc16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38bc3de
 
 
5d8b19c
38bc3de
 
 
 
5d8b19c
38bc3de
 
 
 
5d8b19c
38bc3de
 
 
5d8b19c
38bc3de
5d8b19c
38bc3de
 
 
5d8b19c
38bc3de
5d8b19c
 
38bc3de
 
9d7c11b
a74e62b
5d8b19c
c8770f9
5d8b19c
38bc3de
 
9d7c11b
 
fa08b06
 
5d8b19c
c8770f9
5d8b19c
 
5901e15
 
 
eb6e91b
 
5901e15
 
 
 
72e98b3
 
5901e15
50caa97
5901e15
72e98b3
 
5901e15
50caa97
5901e15
72e98b3
 
5901e15
50caa97
fa08b06
72e98b3
 
fa08b06
 
4beb64f
fa08b06
4beb64f
fa08b06
4beb64f
 
 
7d058f6
 
 
 
 
9170818
 
7d058f6
 
 
02acc16
 
 
9170818
ea63eb3
680c7ae
ea63eb3
5d8b19c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import gradio as gr
from transformers import pipeline
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import text_to_word_sequence
import pickle
import re
from tensorflow.keras.models import load_model

# Load long model
# Both artifacts are read from the local 'lstm-qa-long-answers-model'
# directory when this module is imported; predict_correct_answer() uses them.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file,
# so tokenizer.pickle must only ever come from a trusted source.
with open('lstm-qa-long-answers-model/tokenizer.pickle', 'rb') as handle:
    tokenizer = pickle.load(handle)
long_answer_model = load_model('lstm-qa-long-answers-model/model.h5')

def clean_text(text):
    """Return ``text`` with markup-like noise removed.

    Three substitutions run in order, each deleting its matches:

    1. ``<...>`` spans (non-greedy, so each tag-like span is removed on its own),
    2. bracketed digit runs such as ``[12]``,
    3. every character that is not an ASCII letter, a digit, whitespace,
       a parenthesis, a period or a comma.
    """
    for pattern in (r'<.*?>', r'\[\d+\]', r'[^a-zA-Z0-9\s().,]'):
        text = re.sub(pattern, '', text)
    return text

def remove_parentheses(text):
    """Return ``text`` with every ``(...)`` group deleted.

    A group runs from an opening parenthesis to the first closing one, so
    nested parentheses are only partially removed and an unmatched ``(`` is
    left untouched.
    """
    return re.sub(r'\([^)]*\)', '', text)

def predict_correct_answer(question, answer1, answer2):
    """Pick whichever of two candidate answers the long-answer model scores higher.

    Each candidate is cleaned with ``clean_text``, tokenized with the
    module-level ``tokenizer``, padded/truncated to 300 tokens and scored by
    ``long_answer_model`` together with the tokenized question.

    Args:
        question: The question text.
        answer1: First candidate answer.
        answer2: Second candidate answer.

    Returns:
        A ``(answer, score)`` tuple holding the cleaned text of the
        best-scoring candidate and its score as a plain ``float``. On a tie
        the first candidate wins. If no candidate scores above 0 the result
        is ``(None, 0)``.
    """
    # The question encoding does not depend on the candidate, so build it once
    # instead of re-tokenizing and re-padding it on every loop iteration.
    question_seq = tokenizer.texts_to_sequences([question])
    padded_question = pad_sequences(question_seq, padding='post')

    correct_answer = None
    best_score = 0
    for answer in (answer1, answer2):
        clean_answer = clean_text(answer)
        answer_seq = tokenizer.texts_to_sequences([clean_answer])
        padded_answer = pad_sequences(answer_seq, maxlen=300, padding='post', truncating='post')
        # The model takes [answer, question] in that order; the score is the
        # first value of the first (only) row of the prediction.
        score = long_answer_model.predict([padded_answer, padded_question])[0][0]
        if score > best_score:
            # Store a builtin float rather than the array scalar so the value
            # handed to the UI is a plain Python number.
            best_score = float(score)
            correct_answer = clean_answer

    return correct_answer, best_score

# def split_into_sentences(text):
#     sentences = re.split(r'\.\s*', text)
#     return [s.strip() for s in sentences if s]

# def predict_answer(context, question):
#     sentences = split_into_sentences(context)
#     best_sentence = None
#     best_score = 0

#     for sentence in sentences:
#         clean_sentence = clean_text(sentence)
#         question_seq = tokenizer.texts_to_sequences([question])
#         sentence_seq = tokenizer.texts_to_sequences([clean_sentence])

#         max_sentence_length = 300
#         padded_question = pad_sequences(question_seq, padding='post')
#         padded_sentence = pad_sequences(sentence_seq, maxlen=max_sentence_length, padding='post', truncating='post')

#         score = long_answer_model.predict([padded_sentence, padded_question])[0]

#         if score > best_score:
#             best_score = score
#             best_sentence = clean_sentence

#     return best_score, best_sentence

# Load short model
# Four pipelines are created when this module is imported. answer_questions()
# calls each one with question=/context= keywords and reads the 'answer' and
# 'score' keys of the result.
# The three "Nighter/..." checkpoints are loaded with from_tf=True; the
# longformer checkpoint is loaded without it.
# NOTE(review): no task is passed to pipeline(), so it is presumably inferred
# from each checkpoint's configuration — confirm.
distilbert_base_uncased = pipeline(model="Nighter/QA_wiki_data_short_answer", from_tf=True)
bert_base_uncased  = pipeline(model="Nighter/QA_bert_base_uncased_wiki_data_short_answer", from_tf=True)
roberta_base = pipeline(model="Nighter/QA_wiki_data_roberta_base_short_answer", from_tf=True)
longformer_base = pipeline(model="aware-ai/longformer-squadv2")

# Function to answer on all models
def answer_questions(context, question):
    """Run the question against every short-answer model.

    Parenthesised spans are stripped from ``context`` once, then each pipeline
    (DistilBERT, BERT, RoBERTa, Longformer — in that order) is queried with
    the same question and cleaned context.

    Args:
        context: Passage the answer should be extracted from.
        question: Question to ask about the passage.

    Returns:
        An 8-tuple of ``(answer, score)`` pairs flattened in model order:
        ``distilbert answer, distilbert score, bert answer, bert score,
        roberta answer, roberta score, longformer answer, longformer score``.
        This order is what the UI's ``outputs=[...]`` lists expect; the
        previous return statement swapped the RoBERTa score with the
        Longformer answer/score, so those fields showed the wrong values.
    """
    # Strip the parentheses once rather than once per model.
    cleaned_context = remove_parentheses(context)

    flattened = []
    for qa_model in (distilbert_base_uncased, bert_base_uncased, roberta_base, longformer_base):
        result = qa_model(question=question, context=cleaned_context)
        flattened.extend((result['answer'], result['score']))
    return tuple(flattened)

# App Interface
# Two tabs: "QA Short Answer" runs answer_questions() over a context/question
# pair, "Long Answer Prediction" runs predict_correct_answer() over a question
# and two candidate answers.
with gr.Blocks() as app:
    gr.Markdown("<center> <h1>Question Answering with Short and Long Answer Models </h1> </center><hr>")
    with gr.Tab("QA Short Answer"):
        with gr.Row():
            # Left column: inputs and action buttons.
            with gr.Column():
                context_input = gr.Textbox(lines=8, label="Context", placeholder="Input Context here...") 
                question_input = gr.Textbox(lines=3, label="Question", placeholder="Input Question here...")
                submit_btn = gr.Button("Submit")
                gr.ClearButton([context_input,question_input])
            # Right column: one row per model, answer textbox (scale 6) next
            # to its score field (scale 2).
            with gr.Column():
                with gr.Row():
                    with gr.Column(scale=6):
                        distilbert_base_uncased_output = gr.Textbox(lines=2, label="Distil BERT Base Uncased")
                    with gr.Column(scale=2):
                        distilbert_base_uncased_score = gr.Number(label="Distil BERT Base Uncased Score")
                with gr.Row():
                    with gr.Column(scale=6):
                        bert_base_uncased_output = gr.Textbox(lines=2, label="BERT Base Uncased")
                    with gr.Column(scale=2):
                        bert_base_uncased_score = gr.Number(label="BERT Base Uncased Score")
                with gr.Row():
                    with gr.Column(scale=6):
                        roberta_base_output = gr.Textbox(lines=2, label="RoBERTa Base")
                    with gr.Column(scale=2):
                        roberta_base_score = gr.Number(label="RoBERTa Base Score")
                with gr.Row():
                    with gr.Column(scale=6):
                        longformer_base_output = gr.Textbox(lines=2, label="Longformer Base")
                    with gr.Column(scale=2):
                        longformer_base_score = gr.Number(label="Longformer Base Score")

        # The outputs list is filled positionally from the tuple returned by
        # answer_questions, so its order must match that function's return order.
        submit_btn.click(fn=answer_questions, inputs=[context_input, question_input], outputs=[distilbert_base_uncased_output, distilbert_base_uncased_score, bert_base_uncased_output, bert_base_uncased_score, roberta_base_output, roberta_base_score, longformer_base_output, longformer_base_score])
        # NOTE(review): a string passed as examples is presumably treated by
        # Gradio as a directory path — confirm an 'examples' directory ships
        # with the app.
        examples='examples'
        gr.Examples(examples,[context_input, question_input],[distilbert_base_uncased_output, distilbert_base_uncased_score, bert_base_uncased_output, bert_base_uncased_score, roberta_base_output, roberta_base_score, longformer_base_output, longformer_base_score],answer_questions)
        
    with gr.Tab("Long Answer Prediction"):
        with gr.Row():
            # Left column: question plus the two candidate answers.
            with gr.Column():
                long_question_input = gr.Textbox(lines=3,label="Question", placeholder="Enter the question")
                answer1_input = gr.Textbox(lines=3,label="Answer 1", placeholder="Enter answer 1")
                answer2_input = gr.Textbox(lines=3,label="Answer 2", placeholder="Enter answer 2")
                submit_btn_long = gr.Button("Submit")
                gr.ClearButton([long_question_input, answer1_input, answer2_input])

            # Right column: the chosen answer and its score.
            with gr.Column():
                correct_answer_output = gr.Textbox(lines=3,label="Correct Answer")
                score_output = gr.Number(label="Score")
    
        # predict_correct_answer returns (answer, score), matching the two outputs.
        submit_btn_long.click(fn=predict_correct_answer, inputs=[long_question_input, answer1_input, answer2_input], 
                         outputs=[correct_answer_output, score_output])
        
        # Only inputs are given here (no outputs/fn), unlike the short-answer tab.
        # NOTE(review): 'long_examples' is presumably a directory path — confirm it exists.
        long_examples = 'long_examples'
        gr.Examples(long_examples,[long_question_input, answer1_input, answer2_input])
                    
# Launch the Gradio app only when this file is run as a script, not on import.
if __name__ == "__main__":
    app.launch()