File size: 4,124 Bytes
7980e1c
 
9a6f34b
7980e1c
 
9a6f34b
 
 
 
7980e1c
9a6f34b
 
 
 
7980e1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a6f34b
 
 
7980e1c
 
 
9a6f34b
 
 
7980e1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a6f34b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7980e1c
91d9f42
7980e1c
 
 
9a6f34b
7980e1c
 
 
 
9a6f34b
7980e1c
 
 
9a6f34b
 
7980e1c
9a6f34b
 
7980e1c
9a6f34b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from __future__ import annotations

import io
import re
from dataclasses import dataclass
from functools import lru_cache
from typing import Tuple

import gradio as gr
import requests
import xmltodict
from PyPDF2 import PdfReader
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
from transformers.pipelines.question_answering import QuestionAnsweringPipeline

QA_MODEL_NAME = "ixa-ehu/SciBERT-SQuAD-QuAC"
TEMP_PDF_PATH = "/tmp/arxiv_paper.pdf"
# Matches arXiv PDF links such as ``https://arxiv.org/pdf/2101.00001.pdf``.
# Capture groups (``PDFPaper`` reads group 3 as the arXiv id):
#   1. the scheme, ``http`` or ``https``
#   2. the literal ``arxiv.org/pdf/`` prefix (dot escaped, prefix not repeatable)
#   3. the new-style arXiv id, optionally versioned (``2101.00001v2``)
# The trailing ``.pdf`` is optional so extensionless ``/pdf/<id>`` links match too.
ARXIV_URL_PATTERN = (
    r"(http|https)://"
    r"(arxiv\.org/pdf/)"
    r"([0-9]+\.[0-9]+(?:v[0-9]+)?)"
    r"(?:\.pdf)?"
)


def is_valid_url(url: str) -> bool:
    """Return True if the whole of *url* is an arXiv PDF link.

    ``re.fullmatch`` is used, so any surrounding text or whitespace makes the
    URL invalid.
    """
    return re.fullmatch(ARXIV_URL_PATTERN, url) is not None


@dataclass
class PaperMetaData:
    """Metadata and extracted full text of a single arXiv paper."""

    arxiv_id: str  # arXiv identifier used to query the API
    title: str  # whitespace-normalised title returned by the arXiv API
    summary: str  # whitespace-normalised abstract returned by the arXiv API
    text: str  # full text supplied by the caller, stored unchanged

    @staticmethod
    def _clean_field(text: str) -> str:
        """Collapse every whitespace run (newlines included) to one space and trim the ends."""
        return re.sub(r"\s+", " ", text).strip()

    @classmethod
    def from_api(
        cls, arxiv_id: str, text: str, *, timeout: float = 30.0
    ) -> PaperMetaData:
        """Build an instance from the arXiv export API entry for *arxiv_id*.

        Args:
            arxiv_id: identifier sent to the API as ``id_list``.
            text: already-extracted paper text, stored as-is.
            timeout: seconds to wait for the API before giving up.

        Raises:
            requests.HTTPError: if the API responds with an HTTP error status.
            requests.Timeout: if the API does not respond within *timeout*.
            ValueError: if the response holds no usable entry for *arxiv_id*,
                or is an arXiv error entry.
        """
        # ``params`` lets requests URL-encode the id instead of splicing it in.
        response = requests.get(
            "https://export.arxiv.org/api/query",
            params={"id_list": arxiv_id},
            timeout=timeout,
        )
        response.raise_for_status()
        feed = xmltodict.parse(response.content)["feed"]
        entry = feed.get("entry")
        # xmltodict returns a list when <entry> repeats; a single-id query
        # only needs the first one.
        if isinstance(entry, list):
            entry = entry[0] if entry else None
        if not isinstance(entry, dict) or not isinstance(entry.get("title"), str):
            raise ValueError(f"The arXiv API returned no paper for id {arxiv_id!r}.")
        raw_summary = entry.get("summary")
        summary = cls._clean_field(raw_summary if isinstance(raw_summary, str) else "")
        title = cls._clean_field(entry["title"])
        # arXiv reports query errors as an entry titled "Error" whose summary
        # carries the message; surface it instead of returning it as a paper.
        if title == "Error":
            raise ValueError(f"The arXiv API rejected id {arxiv_id!r}: {summary}")
        return cls(arxiv_id=arxiv_id, title=title, summary=summary, text=text)


def clean_text(text: str) -> str:
    """Tidy text extracted from a PDF.

    Three regex rewrites are applied, in this order:

    1. drop the 0x02 and 0x03 control characters;
    2. drop every hyphen that is followed by whitespace, which re-joins words
       split across lines (and also merges pairs like "word- word");
    3. replace each remaining newline with a single space.
    """
    rewrites = (
        (r"\x03|\x02", ""),
        (r"-\s+", ""),
        (r"\n", " "),
    )
    cleaned = text
    for pattern, replacement in rewrites:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned


class PDFPaper:
    """An arXiv paper identified by the URL of its PDF."""

    def __init__(self, url: str):
        """Validate *url* and extract the arXiv id from it.

        Leading/trailing whitespace (common when pasting into a text box) is
        ignored.

        Raises:
            ValueError: if *url* does not fully match ``ARXIV_URL_PATTERN``.
        """
        url = url.strip()
        # Match once and reuse the result for both validation and id extraction.
        match = re.fullmatch(ARXIV_URL_PATTERN, url)
        if match is None:
            raise ValueError("The URL provided is not a valid arxiv PDF url.")
        self.url = url
        self.arxiv_id = match.group(3)

    def _download(
        self, download_path: str = TEMP_PDF_PATH, *, timeout: float = 60.0
    ) -> bytes:
        """Fetch the PDF, write it to *download_path*, and return its raw bytes.

        Raises:
            requests.HTTPError: if the server responds with an HTTP error status.
            requests.Timeout: if the server does not respond within *timeout*.
        """
        response = requests.get(self.url, timeout=timeout)
        response.raise_for_status()
        content = response.content
        with open(download_path, "wb") as pdf_file:
            pdf_file.write(content)
        return content

    def read_text(self, pdf_path: str = TEMP_PDF_PATH) -> str:
        """Download the PDF and return its cleaned text, with pages joined by spaces.

        The PDF is still saved to *pdf_path*, but it is parsed from the
        downloaded bytes. With the shared default path, another request could
        overwrite the file between this write and a later read.
        """
        content = self._download(pdf_path)
        reader = PdfReader(io.BytesIO(content))
        # ``or ""`` guards against a page yielding None instead of a string
        # (defensive; depends on the PyPDF2 version — TODO confirm).
        pdf_text = " ".join(page.extract_text() or "" for page in reader.pages)
        return clean_text(pdf_text)

    def get_paper_full_data(self) -> PaperMetaData:
        """Return the paper's arXiv API metadata together with its extracted text."""
        return PaperMetaData.from_api(arxiv_id=self.arxiv_id, text=self.read_text())


def get_paper_data(url: str) -> Tuple[str, str, str]:
    """Import the paper at *url* and return ``(title, summary, text)`` for display."""
    paper = PDFPaper(url=url)
    metadata = paper.get_paper_full_data()
    return (metadata.title, metadata.summary, metadata.text)


@lru_cache(maxsize=4)
def get_qa_pipeline(qa_model_name: str = QA_MODEL_NAME) -> QuestionAnsweringPipeline:
    """Return a question-answering pipeline for *qa_model_name*, built once and cached.

    Loading the tokenizer and model from ``from_pretrained`` is expensive, and
    ``get_answer`` requests the pipeline on every question. The built pipeline
    is therefore memoised per argument set. The cache is bounded so that
    arbitrary model names cannot grow it without limit.

    Notes:
        The same pipeline object is returned to every caller.
        ``lru_cache`` keys on the call form, so ``get_qa_pipeline()`` and
        ``get_qa_pipeline(QA_MODEL_NAME)`` are cached as separate entries.
    """
    tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
    model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name)
    return pipeline("question-answering", model=model, tokenizer=tokenizer)


def get_answer(question: str, context: str) -> str:
    """Answer *question* from *context* and return the predicted answer span as text."""
    qa = get_qa_pipeline()
    result = qa(question=question, context=context)
    answer_text = result["answer"]
    return answer_text


# Gradio Blocks UI with two sections: importing an arXiv paper, then asking
# questions about it.
demo = gr.Blocks()

with demo:
    gr.Markdown("# arXiv Paper Q&A\nImport an arXiv paper and ask questions about it!")
    gr.Markdown("⚠️ This Space is still Work in progress! ⚠️")
    # --- Section 1: import a paper from its arXiv PDF URL ---
    gr.Markdown("## 📄 Import the paper on arXiv")
    arxiv_url = gr.Textbox(
        label="arXiv Paper URL", placeholder="Insert here the URL of a paper on arXiv"
    )
    fetch_document_button = gr.Button("Import Paper")
    paper_title = gr.Textbox(label="Paper Title")
    paper_summary = gr.Textbox(label="Paper Summary")
    paper_text = gr.Textbox(label="Paper Text")
    # get_paper_data returns (title, summary, text), which fill these three
    # boxes in that order.
    fetch_document_button.click(
        fn=get_paper_data,
        inputs=arxiv_url,
        outputs=[paper_title, paper_summary, paper_text],
    )

    # --- Section 2: ask a question about the imported paper ---
    gr.Markdown("## 🤨 Ask a question about the paper")
    question = gr.Textbox(label="Ask a question about the paper:")
    ask_button = gr.Button("Ask me 🤖")
    answer = gr.Textbox(label="Answer:")
    # NOTE(review): the QA context is the paper *summary*, not the full text
    # shown in ``paper_text``. Confirm this is intentional (perhaps to keep the
    # model input short) before relying on answers about the paper body.
    ask_button.click(fn=get_answer, inputs=[question, paper_summary], outputs=answer)


# Launch the app. This call sits at module top level with no ``__main__`` guard,
# so it also runs when the module is imported.
demo.launch()