# Spaces: Running
import logging

import gradio as gr
import nltk
import numpy as np
from nltk.tokenize import sent_tokenize
from pdfminer.high_level import extract_text
from rank_bm25 import BM25Okapi
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
# Sentence-splitting data used by ``sent_tokenize``. Recent NLTK releases
# (3.9+) load the "punkt_tab" resource instead of the older pickled "punkt"
# models, so request both to keep sentence splitting working on either.
nltk.download('punkt')
nltk.download('punkt_tab')

# QA model: extractive question answering over a short context passage.
qa_model_name = "deepset/roberta-large-squad2"
qa_model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name)
qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
qa_pipeline = pipeline("question-answering", model=qa_model, tokenizer=qa_tokenizer)

# Summarization model: used for "summarize" requests and to shorten answers
# that exceed the requested word count.
summarization_model_name = "facebook/bart-large-cnn"
summarizer = pipeline("summarization", model=summarization_model_name)
logger = logging.getLogger(__name__)


def _extract_pdf_text(file):
    """Return the text extracted from ``file``, raising on any failure.

    Unlike ``read_pdf``, success and failure are distinguishable: a returned
    string is always document text.

    Args:
        file: Path (or file object) handed to ``extract_text``.

    Raises:
        ValueError: If no file was given or extraction yielded no usable text.
        Exception: Anything ``extract_text`` itself raises.
    """
    if not file:
        raise ValueError("No PDF file was provided.")
    text = extract_text(file)
    # Whitespace-only output carries nothing to answer from, so treat it
    # exactly like an empty result.
    if not text or not text.strip():
        raise ValueError("PDF extraction failed.")
    return text


def read_pdf(file):
    """Return the text of a PDF, or an error message if extraction fails.

    This function never raises. NOTE: both outcomes are plain strings, so the
    return value alone cannot tell document text from an error message; code
    that needs that distinction should call ``_extract_pdf_text`` instead.

    Args:
        file: Path (or file object) handed to ``extract_text``.

    Returns:
        The extracted text, or ``str(exception)`` describing the failure.
    """
    try:
        return _extract_pdf_text(file)
    except Exception as e:
        return str(e)


def retrieve_relevant_text_bm25(question, sentences, top_n=3):
    """Return the ``top_n`` sentences that best match ``question`` under BM25.

    Tokenization is a plain whitespace split of both the question and the
    sentences.

    Args:
        question: The query string.
        sentences: Candidate sentences (strings).
        top_n: Maximum number of sentences to return.

    Returns:
        The selected sentences joined by single spaces, best match first.
        An empty string is returned when there is nothing to rank (no
        non-blank sentences, a blank question, ``top_n <= 0``) or when scoring
        fails, so an error message is never mistaken for document text.
    """
    # Blank sentences contribute no tokens and would only add noise.
    candidates = [sent for sent in (sentences or []) if sent.split()]
    query_tokens = (question or "").split()
    if not candidates or not query_tokens or top_n <= 0:
        return ""

    try:
        bm25 = BM25Okapi([sent.split() for sent in candidates])
        scores = bm25.get_scores(query_tokens)
    except Exception:
        # Best effort: report "nothing relevant" to the caller, but keep the
        # traceback in the logs instead of returning it as if it were text.
        logger.exception("BM25 retrieval failed")
        return ""

    best_first = np.argsort(scores)[::-1][:top_n]
    return " ".join(candidates[i] for i in best_first)


def _is_input_too_long_error(error):
    """Return True if ``error`` is the 'input longer than max_length' failure."""
    message = str(error)
    return (
        "Input length of input_ids is" in message
        and "but `max_length` is set to" in message
    )


def _summarize(text, max_length):
    """Run the summarizer on ``text`` and return the raw summary string."""
    return summarizer(text, max_length=max_length, min_length=1)[0]['summary_text']


def answer_question(pdf, question, num_words):
    """Answer ``question`` about a PDF, sized to roughly ``num_words`` words.

    If the question contains "summarize", the whole document text is passed
    to the summarizer. Otherwise the best-matching sentences are retrieved
    with BM25 and given to the QA pipeline as context. The extracted answer
    is then shortened (summarize, then truncate) when it has more than
    ``num_words`` words, or padded with leading words of the retrieved
    context when it has fewer.

    Args:
        pdf: The uploaded PDF (the UI passes a file path).
        question: The user's question.
        num_words: Requested answer length in words; converted with ``int``
            and clamped to at least 1.

    Returns:
        The answer or summary, or a human-readable message describing why
        none could be produced. This function never raises.
    """
    if not question or not question.strip():
        return "Please enter a question."

    try:
        # UI sliders may deliver floats; slicing below needs an int.
        word_limit = max(1, int(num_words))
        text = _extract_pdf_text(pdf)

        if "summarize" in question.lower():
            try:
                return _summarize(text, word_limit).strip()
            except RuntimeError as e:
                if _is_input_too_long_error(e):
                    return (
                        "PDF is too long for summarization. "
                        "Please provide a shorter PDF or ask a more specific question."
                    )
                return f"Summarization Error: {e}"
            except Exception as e:
                return f"Summarization Error: {e}"

        context = retrieve_relevant_text_bm25(question, sent_tokenize(text))
        if not context:
            return "Could not find relevant information in the PDF."

        response = qa_pipeline(question=question, context=context)
        # Splitting on whitespace both normalizes spacing and gives the count.
        answer_words = (response.get('answer') or "").split()
        if not answer_words:
            return "Could not find an answer in the relevant text."

        if len(answer_words) > word_limit:
            try:
                answer_words = _summarize(" ".join(answer_words), word_limit + 10).split()
            except RuntimeError as e:
                if not _is_input_too_long_error(e):
                    return f"Summarization Error: {e}"
                # Too long to summarize: fall through and simply truncate the
                # extracted answer below.
            except Exception as e:
                return f"Summarization Error: {e}"
        elif len(answer_words) < word_limit:
            # Pad with the leading words of the retrieved context.
            answer_words += context.split()[:word_limit - len(answer_words)]

        return " ".join(answer_words[:word_limit])
    except ValueError as e:
        # Expected input problems (missing/unreadable PDF, bad word count).
        return str(e)
    except Exception as e:
        logger.exception("answer_question failed")
        return str(e)
# Gradio UI: question box and submit button on the left, PDF upload and
# word-count slider on the right, answer box below.
# NOTE(review): the nesting was reconstructed from a flattened source; confirm
# the column/row placement matches the intended layout.
with gr.Blocks() as iface:
    gr.Markdown("PDF Q&A with RoBERTa | made by NP")
    with gr.Row():
        with gr.Column(scale=2):
            question_input = gr.Textbox(lines=2, placeholder="Ask a question", label="Question")
            btn = gr.Button("Submit")
        with gr.Column(scale=1):
            # Only PDFs can be processed, so restrict the upload dialog to them.
            pdf_input = gr.File(type="filepath", file_types=[".pdf"], label="Upload PDF")
            num_words_slider = gr.Slider(minimum=1, maximum=500, value=100, step=1, label="Number of Words")
    answer_output = gr.Textbox(label="Answer", lines=5)
    btn.click(fn=answer_question, inputs=[pdf_input, question_input, num_words_slider], outputs=answer_output)

if __name__ == "__main__":
    iface.launch()