Spaces:
Running
Running
anakin87
committed on
Commit
·
4c2a969
1
Parent(s):
d6bdb02
various improvements
Browse files- README.md +1 -1
- Rock_fact_checker.py +117 -0
- app_utils/backend_utils.py +58 -0
- app_utils/config.py +8 -0
- app_utils/frontend_utils.py +15 -0
- data/statements.txt +5 -0
- pages/Info.py +1 -1
README.md
CHANGED
@@ -5,7 +5,7 @@ colorFrom: purple
|
|
5 |
colorTo: blue
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.10.0
|
8 |
-
app_file:
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
11 |
---
|
|
|
5 |
colorTo: blue
|
6 |
sdk: streamlit
|
7 |
sdk_version: 1.10.0
|
8 |
+
app_file: Rock_fact_checker.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
11 |
---
|
Rock_fact_checker.py
CHANGED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
import time
|
4 |
+
import streamlit as st
|
5 |
+
import logging
|
6 |
+
from json import JSONDecodeError
|
7 |
+
# from markdown import markdown
|
8 |
+
# from annotated_text import annotation
|
9 |
+
# from urllib.parse import unquote
|
10 |
+
import random
|
11 |
+
|
12 |
+
from app_utils.backend_utils import load_questions, query
|
13 |
+
from app_utils.frontend_utils import set_state_if_absent, reset_results
|
14 |
+
from app_utils.config import RETRIEVER_TOP_K
|
15 |
+
|
16 |
+
|
17 |
+
def main():
    """Render the fact-checking page and run a query when one is requested.

    Initialises the session state, draws the statement text box together
    with the "Run" and "Random question" buttons and, when a query is
    requested, stores the pipeline output in ``st.session_state.results``.
    Errors raised while querying are logged and reported with ``st.error``.
    """
    questions = load_questions()

    # Persistent state
    set_state_if_absent('question', "Elvis Presley is alive")
    set_state_if_absent('answer', '')
    set_state_if_absent('results', None)
    set_state_if_absent('raw_json', None)
    set_state_if_absent('random_question_requested', False)

    # MAIN CONTAINER
    st.write("# Fact checking 🎸 Rocks!")
    st.write()
    st.markdown("""
##### Enter a factual statement about [Rock music](https://en.wikipedia.org/wiki/List_of_mainstream_rock_performers) and let the AI check it out for you...
""")
    # Search bar
    question = st.text_input("", value=st.session_state.question,
                             max_chars=100, on_change=reset_results)
    col1, col2 = st.columns(2)
    col1.markdown(
        "<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True)
    col2.markdown(
        "<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True)
    # Run button
    run_pressed = col1.button("Run")
    # Random question button
    if col2.button("Random question"):
        reset_results()
        # Avoid picking the same question twice (the change is not visible on the UI).
        # Drawing only from the statements that differ from the current one
        # cannot loop forever when no other statement exists, and an empty
        # pool simply leaves the current question in place.
        candidates = [q for q in questions if q != st.session_state.question]
        if candidates:
            question = random.choice(candidates)
            st.session_state.question = question
        st.session_state.random_question_requested = True
        # Re-runs the script setting the random question as the textbox value
        # Unfortunately necessary as the Random Question button is _below_ the textbox
        # raise st.script_runner.RerunException(
        #     st.script_request_queue.RerunData(None))
    else:
        st.session_state.random_question_requested = False

    # A query runs when "Run" was pressed or the text changed, but never in
    # the same pass in which a random question was requested.
    run_query = (run_pressed or question != st.session_state.question) \
        and not st.session_state.random_question_requested

    # Get results for query
    if run_query and question:
        time_start = time.time()
        reset_results()
        st.session_state.question = question
        with st.spinner("🧠 Performing neural search on documents..."):
            try:
                st.session_state.results = query(
                    question, RETRIEVER_TOP_K)
                time_end = time.time()
                print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
                print(f'elapsed time: {time_end - time_start}')
            except JSONDecodeError as je:
                # Log the decoding failure as well, like the generic handler below.
                logging.exception(je)
                st.error(
                    "👓 An error occurred reading the results. Is the document store working?")
                return
            except Exception as e:
                logging.exception(e)
                st.error("🐞 An error occurred during the request.")
                return
|
84 |
+
|
85 |
+
# # Display results
|
86 |
+
# if st.session_state.results:
|
87 |
+
# st.write("## Results:")
|
88 |
+
# alert_irrelevance = True
|
89 |
+
# if len(st.session_state.results['answers']) == 0:
|
90 |
+
# st.info("""🤔 Haystack is unsure whether any of
|
91 |
+
# the documents contain an answer to your question. Try to reformulate it!""")
|
92 |
+
|
93 |
+
# for result in st.session_state.results['answers']:
|
94 |
+
# result = result.to_dict()
|
95 |
+
# if result["answer"]:
|
96 |
+
# if alert_irrelevance and result['score'] < LOW_RELEVANCE_THRESHOLD:
|
97 |
+
# alert_irrelevance = False
|
98 |
+
# st.write("""
|
99 |
+
# <h4 style='color: darkred'>Attention, the
|
100 |
+
# following answers have low relevance:</h4>""",
|
101 |
+
# unsafe_allow_html=True)
|
102 |
+
|
103 |
+
# answer, context = result["answer"], result["context"]
|
104 |
+
# start_idx = context.find(answer)
|
105 |
+
# end_idx = start_idx + len(answer)
|
106 |
+
# # Hack due to this bug: https://github.com/streamlit/streamlit/issues/3190
|
107 |
+
# st.write(markdown("- ..."+context[:start_idx] +
|
108 |
+
# str(annotation(answer, "ANSWER", "#3e1c21", "white")) +
|
109 |
+
# context[end_idx:]+"..."), unsafe_allow_html=True)
|
110 |
+
# source = ""
|
111 |
+
# name = unquote(result['meta']['name']).replace('_', ' ')
|
112 |
+
# url = result['meta']['url']
|
113 |
+
# source = f"[{name}]({url})"
|
114 |
+
# st.markdown(
|
115 |
+
# f"**Score:** {result['score']:.2f} - **Source:** {source}")
|
116 |
+
|
117 |
+
main()
|
app_utils/backend_utils.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import shutil
|
2 |
+
from haystack.document_stores import FAISSDocumentStore
|
3 |
+
from haystack.nodes import EmbeddingRetriever
|
4 |
+
from haystack.pipelines import Pipeline
|
5 |
+
|
6 |
+
import streamlit as st
|
7 |
+
|
8 |
+
from app_utils.entailment_checker import EntailmentChecker
|
9 |
+
|
10 |
+
from app_utils.config import STATEMENTS_PATH, INDEX_DIR, RETRIEVER_MODEL, RETRIEVER_MODEL_FORMAT, NLI_MODEL
|
11 |
+
|
12 |
+
# cached to make index and models load only at start
|
13 |
+
@st.cache(hash_funcs={"builtins.SwigPyObject": lambda _: None}, allow_output_mutation=True)
def start_haystack():
    """
    Build the query pipeline: FAISS document store -> embedding retriever
    -> entailment checker.

    Returns the assembled ``Pipeline``; the nodes are registered under the
    names "retriever" and "ec".
    """
    # Copy the document store database next to the app before opening the index
    # (presumably where the saved index configuration expects it -- TODO confirm).
    shutil.copy(f'{INDEX_DIR}/faiss_document_store.db', '.')

    store = FAISSDocumentStore(
        faiss_index_path=f'{INDEX_DIR}/my_faiss_index.faiss',
        faiss_config_path=f'{INDEX_DIR}/my_faiss_index.json',
    )
    print(f'Index size: {store.get_document_count()}')

    embedding_retriever = EmbeddingRetriever(
        document_store=store,
        embedding_model=RETRIEVER_MODEL,
        model_format=RETRIEVER_MODEL_FORMAT,
    )
    checker = EntailmentChecker(model_name_or_path=NLI_MODEL, use_gpu=False)

    # Wire the nodes: the query feeds the retriever, whose output feeds the checker.
    pipeline = Pipeline()
    pipeline.add_node(component=embedding_retriever, name="retriever", inputs=["Query"])
    pipeline.add_node(component=checker, name="ec", inputs=["retriever"])
    return pipeline
|
38 |
+
|
39 |
+
pipe = start_haystack()
|
40 |
+
# the pipeline is not included as parameter of the following function,
|
41 |
+
# because it is difficult to cache
|
42 |
+
@st.cache(persist=True, allow_output_mutation=True)
def query(question: str, retriever_top_k: int = 5):
    """Run the module-level pipeline on ``question`` and return its output.

    ``retriever_top_k`` is passed to the "retriever" node as its ``top_k``
    parameter. The output is also printed to stdout.
    """
    output = pipe.run(
        question,
        params={"retriever": {"top_k": retriever_top_k}},
    )
    print(output)
    return output
|
49 |
+
|
50 |
+
@st.cache()
def load_questions():
    """Load the example statements from ``STATEMENTS_PATH``.

    Returns a list with one whitespace-stripped statement per line of the
    file. Blank lines and lines whose first non-blank character is ``#``
    are skipped, so an empty string can never be offered as a statement.
    """
    # Explicit encoding keeps decoding independent of the platform default.
    with open(STATEMENTS_PATH, encoding="utf-8") as fin:
        stripped_lines = (line.strip() for line in fin)
        questions = [line for line in stripped_lines
                     if line and not line.startswith('#')]
    return questions
|
57 |
+
|
58 |
+
|
app_utils/config.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
INDEX_DIR = 'data/index'
|
3 |
+
STATEMENTS_PATH = 'data/statements.txt'
|
4 |
+
|
5 |
+
RETRIEVER_MODEL = "sentence-transformers/msmarco-distilbert-base-tas-b"
|
6 |
+
RETRIEVER_MODEL_FORMAT = "sentence_transformers"
|
7 |
+
RETRIEVER_TOP_K = 5
|
8 |
+
NLI_MODEL = "valhalla/distilbart-mnli-12-1"
|
app_utils/frontend_utils.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
|
4 |
+
def set_state_if_absent(key, value):
    """Store ``value`` under ``key`` in ``st.session_state`` unless ``key`` is already present."""
    if key in st.session_state:
        return
    st.session_state[key] = value
|
7 |
+
|
8 |
+
# Small callback to reset the interface in case the text of the question changes
|
9 |
+
def reset_results(*args):
    """Set the ``answer``, ``results`` and ``raw_json`` session entries to ``None``.

    Positional arguments are accepted and ignored, so the function can be
    used directly as a widget callback.
    """
    for field in ("answer", "results", "raw_json"):
        setattr(st.session_state, field, None)
|
13 |
+
|
14 |
+
|
15 |
+
|
data/statements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Kurt Cobain died in 1994
|
2 |
+
Kurt Cobain died in 2008
|
3 |
+
Green Day are a heavy metal band
|
4 |
+
Green Day are a punk rock band
|
5 |
+
The Beatles' first album was released in 1985
|
pages/Info.py
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
import streamlit as st
|
2 |
|
3 |
-
|
|
|
1 |
import streamlit as st
|
2 |
|
3 |
+
|