Leotis committed on
Commit
6779ad6
1 Parent(s): 092dcc1

add inference class

Browse files
Files changed (3) hide show
  1. app.py +8 -2
  2. data/fire-and-blood.docx.txt +0 -0
  3. inference.py +64 -0
app.py CHANGED
@@ -1,7 +1,13 @@
1
  import gradio as gr
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
5
 
6
  iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
  iface.launch()
 
1
  import gradio as gr
2
+ from inference import Question_and_Answer_System
3
 
4
+ qa = Question_and_Answer_System()
5
+
6
+
7
+
8
+ def greet(question):
9
+ prediction = qa.answer_question(question)
10
+ return prediction
11
 
12
  iface = gr.Interface(fn=greet, inputs="text", outputs="text")
13
  iface.launch()
data/fire-and-blood.docx.txt ADDED
The diff for this file is too large to render. See raw diff
 
inference.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from haystack.pipelines.standard_pipelines import TextIndexingPipeline
4
+ from haystack.document_stores import InMemoryDocumentStore
5
+ from haystack.nodes import BM25Retriever
6
+ from haystack.nodes import FARMReader
7
+ from haystack.pipelines import ExtractiveQAPipeline
8
+ from haystack.utils import print_answers
9
+
10
+
11
class Question_and_Answer_System:
    """Extractive question answering over a directory of local text files.

    On construction this indexes every file under *doc_dir* into an
    in-memory BM25 document store, loads a RoBERTa extractive reader
    (CPU only), and wires both into a Haystack ``ExtractiveQAPipeline``.
    """

    def __init__(self, doc_dir="data"):
        # Generalized (backward compatible): the document directory was a
        # hard-coded "data"; it is now a parameter with that same default.
        document_store = self.prepare_documents(doc_dir)
        self.reader = self.create_reader(document_store)
        self.retriever = self.create_retriever(document_store)
        self.pipe_line = self.create_pipeline(self.reader, self.retriever)

    def setup_logging(self):
        """Configure warning-level root logging, with INFO for Haystack.

        NOTE(review): this is never invoked by ``__init__``; call it
        explicitly if Haystack progress logging is wanted.
        """
        logging.basicConfig(
            format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING
        )
        logging.getLogger("haystack").setLevel(logging.INFO)

    def prepare_documents(self, doc_dir):
        """Index every file under *doc_dir* into an in-memory BM25 store.

        Bug fix: the original re-assigned ``doc_dir = "data"`` inside the
        body, silently ignoring the argument passed by the caller.
        Returns the populated ``InMemoryDocumentStore``.
        """
        document_store = InMemoryDocumentStore(use_bm25=True)
        # os.path.join instead of manual "/" concatenation (portable).
        files_to_index = [os.path.join(doc_dir, f) for f in os.listdir(doc_dir)]
        indexing_pipeline = TextIndexingPipeline(document_store)
        indexing_pipeline.run_batch(file_paths=files_to_index)
        return document_store

    def create_retriever(self, document_store):
        """Return a BM25 retriever bound to *document_store*."""
        return BM25Retriever(document_store=document_store)

    def create_reader(self, document_store):
        """Load the extractive reader model (CPU only).

        *document_store* is unused but kept for signature compatibility
        with existing callers.
        """
        return FARMReader(
            model_name_or_path="deepset/roberta-base-squad2", use_gpu=False
        )

    def create_pipeline(self, reader, retriever):
        """Assemble retriever + reader into an extractive QA pipeline."""
        self.pipe_line = ExtractiveQAPipeline(reader, retriever)
        return self.pipe_line

    def answer_question(self, question: str):
        """Run the pipeline on *question* and return the raw prediction.

        The prediction is a dict whose ``"answers"`` key holds the top
        reader answers; retrieval considers 10 documents, the reader
        keeps the best 5 spans.
        """
        return self.pipe_line.run(
            query=question,
            params={"Retriever": {"top_k": 10}, "Reader": {"top_k": 5}},
        )

    def format_answers(self, prediction):
        """Pretty-print *prediction* to stdout via Haystack's helper."""
        print_answers(
            prediction, details="minimum"  # choose from `minimum`, `medium`, `all`
        )
60
+
61
+
62
+
63
+
64
+