import gradio as gr
from transformers import pipeline
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import text_to_word_sequence
import pickle
import re
from tensorflow.keras.models import load_model
# Load the long-answer LSTM model and its tokenizer
with open('lstm-qa-long-answers-model/tokenizer.pickle', 'rb') as handle:
    tokenizer = pickle.load(handle)
long_answer_model = load_model('lstm-qa-long-answers-model/model.h5')
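# Note (assumption inferred from predict_correct_answer below, not from the model file itself):
# the LSTM model is expected to take two padded integer sequences, [answer, question],
# and return a single relevance score, presumably in [0, 1] from a sigmoid output.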
def clean_text(text):
    # Strip HTML tags, citation markers like [1], and most punctuation
    text = re.sub(r'<.*?>', '', text)
    text = re.sub(r'\[\d+\]', '', text)
    text = re.sub(r'[^a-zA-Z0-9\s().,]', '', text)
    return text
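# Illustrative example of the cleaning above (hypothetical input):
# clean_text('<p>Paris [1] is the capital!</p>') -> 'Paris  is the capital'
# (tags, the citation marker, and '!' are removed; the double space is left as-is)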
def remove_parentheses(text):
    # Drop parenthesised asides from the context before passing it to the pipelines
    pattern = r'\([^)]*\)'
    return re.sub(pattern, '', text)
def predict_correct_answer(question, answer1, answer2):
    """Score two candidate long answers with the LSTM model and return the better one."""
    answers = [answer1, answer2]
    correct_answer = None
    best_score = 0
    for answer in answers:
        clean_answer = clean_text(answer)
        question_seq = tokenizer.texts_to_sequences([question])
        answer_seq = tokenizer.texts_to_sequences([clean_answer])
        padded_question = pad_sequences(question_seq, padding='post')
        padded_answer = pad_sequences(answer_seq, maxlen=300, padding='post', truncating='post')
        score = long_answer_model.predict([padded_answer, padded_question])[0][0]
        if score > best_score:
            best_score = score
            correct_answer = clean_answer
    return correct_answer, best_score
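# Minimal usage sketch with hypothetical inputs, assuming the tokenizer and model loaded above:
# answer, score = predict_correct_answer(
#     "Who wrote Hamlet?",
#     "Hamlet was written by William Shakespeare.",
#     "Hamlet is a town in North Carolina.",
# )
# print(answer, score)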
# def split_into_sentences(text):
#     sentences = re.split(r'\.\s*', text)
#     return [s.strip() for s in sentences if s]
# def predict_answer(context, question):
#     sentences = split_into_sentences(context)
#     best_sentence = None
#     best_score = 0
#     for sentence in sentences:
#         clean_sentence = clean_text(sentence)
#         question_seq = tokenizer.texts_to_sequences([question])
#         sentence_seq = tokenizer.texts_to_sequences([clean_sentence])
#         max_sentence_length = 300
#         padded_question = pad_sequences(question_seq, padding='post')
#         padded_sentence = pad_sequences(sentence_seq, maxlen=max_sentence_length, padding='post', truncating='post')
#         score = long_answer_model.predict([padded_sentence, padded_question])[0]
#         if score > best_score:
#             best_score = score
#             best_sentence = clean_sentence
#     return best_score, best_sentence
# Load the short-answer extractive QA pipelines
distilbert_base_uncased = pipeline(model="Nighter/QA_wiki_data_short_answer", from_tf=True)
bert_base_uncased = pipeline(model="Nighter/QA_bert_base_uncased_wiki_data_short_answer", from_tf=True)
roberta_base = pipeline(model="Nighter/QA_wiki_data_roberta_base_short_answer", from_tf=True)
longformer_base = pipeline(model="aware-ai/longformer-squadv2")
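# Each extractive QA pipeline returns a dict with 'answer', 'score', 'start', and 'end';
# only 'answer' and 'score' are surfaced in the UI below.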
# Run the question and context through every short-answer model
def answer_questions(context, question):
    # long_score, long_answer = predict_answer(context, question)
    cleaned_context = remove_parentheses(context)
    distilbert_base_uncased_result = distilbert_base_uncased(question=question, context=cleaned_context)
    bert_base_uncased_result = bert_base_uncased(question=question, context=cleaned_context)
    roberta_base_result = roberta_base(question=question, context=cleaned_context)
    longformer_base_result = longformer_base(question=question, context=cleaned_context)
    # Return values must match the Gradio outputs list: answer then score for each model
    return (distilbert_base_uncased_result['answer'], distilbert_base_uncased_result['score'],
            bert_base_uncased_result['answer'], bert_base_uncased_result['score'],
            roberta_base_result['answer'], roberta_base_result['score'],
            longformer_base_result['answer'], longformer_base_result['score'])  # , long_answer, long_score
# App Interface
with gr.Blocks() as app:
    gr.Markdown("<center> <h1>Question Answering with Short and Long Answer Models</h1> </center><hr>")
    with gr.Tab("QA Short Answer"):
        with gr.Row():
            with gr.Column():
                context_input = gr.Textbox(lines=8, label="Context", placeholder="Input Context here...")
                question_input = gr.Textbox(lines=3, label="Question", placeholder="Input Question here...")
                submit_btn = gr.Button("Submit")
                gr.ClearButton([context_input, question_input])
            with gr.Column():
                with gr.Row():
                    with gr.Column(scale=6):
                        distilbert_base_uncased_output = gr.Textbox(lines=2, label="Distil BERT Base Uncased")
                    with gr.Column(scale=2):
                        distilbert_base_uncased_score = gr.Number(label="Distil BERT Base Uncased Score")
                with gr.Row():
                    with gr.Column(scale=6):
                        bert_base_uncased_output = gr.Textbox(lines=2, label="BERT Base Uncased")
                    with gr.Column(scale=2):
                        bert_base_uncased_score = gr.Number(label="BERT Base Uncased Score")
                with gr.Row():
                    with gr.Column(scale=6):
                        roberta_base_output = gr.Textbox(lines=2, label="RoBERTa Base")
                    with gr.Column(scale=2):
                        roberta_base_score = gr.Number(label="RoBERTa Base Score")
                with gr.Row():
                    with gr.Column(scale=6):
                        longformer_base_output = gr.Textbox(lines=2, label="Longformer Base")
                    with gr.Column(scale=2):
                        longformer_base_score = gr.Number(label="Longformer Base Score")

        submit_btn.click(fn=answer_questions, inputs=[context_input, question_input],
                         outputs=[distilbert_base_uncased_output, distilbert_base_uncased_score,
                                  bert_base_uncased_output, bert_base_uncased_score,
                                  roberta_base_output, roberta_base_score,
                                  longformer_base_output, longformer_base_score])
        # 'examples' points to a directory of cached example inputs for gr.Examples
        examples = 'examples'
        gr.Examples(examples, [context_input, question_input],
                    [distilbert_base_uncased_output, distilbert_base_uncased_score,
                     bert_base_uncased_output, bert_base_uncased_score,
                     roberta_base_output, roberta_base_score,
                     longformer_base_output, longformer_base_score],
                    answer_questions)
with gr.Tab("Long Answer Prediction"): | |
with gr.Row(): | |
with gr.Column(): | |
long_question_input = gr.Textbox(lines=3,label="Question", placeholder="Enter the question") | |
answer1_input = gr.Textbox(lines=3,label="Answer 1", placeholder="Enter answer 1") | |
answer2_input = gr.Textbox(lines=3,label="Answer 2", placeholder="Enter answer 2") | |
submit_btn_long = gr.Button("Submit") | |
gr.ClearButton([long_question_input, answer1_input, answer2_input]) | |
with gr.Column(): | |
correct_answer_output = gr.Textbox(lines=3,label="Correct Answer") | |
score_output = gr.Number(label="Score") | |
submit_btn_long.click(fn=predict_correct_answer, inputs=[long_question_input, answer1_input, answer2_input], | |
outputs=[correct_answer_output, score_output]) | |
long_examples = 'long_examples' | |
gr.Examples(long_examples,[long_question_input, answer1_input, answer2_input]) | |
if __name__ == "__main__":
    app.launch()