Spaces:
Runtime error
Runtime error
import contextlib | |
import streamlit as st | |
import streamlit.components.v1 as components | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
import utils | |
from kb import KB | |
import wikipedia | |
# Cap used both when splitting the search box into subjects and when
# deciding whether another wiki summary may still be added.
MAX_TOPICS = 5
# Largest number of suggestion buttons laid out side by side in one row.
BUTTON_COLUMS = 4

st.header("Extracting a Knowledge Graph from text")
# Loading the model | |
def load_model():
    """Load the tokenizer and seq2seq model for ``Babelscape/rebel-large``.

    Returns:
        A ``(tokenizer, model)`` tuple, in that order.
    """
    checkpoint = "Babelscape/rebel-large"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )
def generate_kb():
    """Build a knowledge base from the collected wiki texts.

    Loads the model, joins every stored wiki summary with single spaces,
    feeds the result to ``utils.from_text_to_kb``, writes the network as
    HTML and records the chart path and textual representation in
    ``st.session_state`` (``kb_chart``, ``kb_text``). ``error_url`` is
    cleared at the end.
    """
    network_path = "networks/network.html"

    # Temporary status line; blanked once the model is available.
    status = st.text('Loading NER model... It may take a while.')
    tokenizer, model = load_model()
    st.success('Model loaded!')
    status.text("")

    joined_text = ' '.join(st.session_state['wiki_text'])
    kb = utils.from_text_to_kb(joined_text, model, tokenizer, "", verbose=True)
    utils.save_network_html(kb, filename=network_path)

    st.session_state.kb_chart = network_path
    st.session_state.kb_text = kb.get_textual_representation()
    st.session_state.error_url = None
def show_textbox():
    """Render each stored wiki summary inside its own expander.

    The expander label is the first 30 characters of the summary followed
    by ``...``; only the first expander starts expanded. Nothing is drawn
    when no summaries are stored.
    """
    for position, summary in enumerate(st.session_state['wiki_text']):
        starts_open = position == 0
        with st.expander(label=f"{summary[:30]}...", expanded=starts_open):
            st.markdown(summary)
def wiki_show_text(page_title):
    """Fetch the Wikipedia page ``page_title`` and store its summary.

    On success the summary is appended to ``wiki_text``, the lower-cased
    title to ``topics``, the title is removed from ``wiki_suggestions`` and
    the text boxes are redrawn. If the title is ambiguous, it is replaced in
    the suggestions by up to three of the disambiguation options
    (de-duplicated). Any other ``wikipedia`` error just drops the title from
    the suggestions.

    Args:
        page_title: Exact page title to look up (auto-suggest is disabled).
    """
    state = st.session_state
    with st.spinner(text="Fetching wiki page..."):
        try:
            wiki_page = wikipedia.page(title=page_title, auto_suggest=False)
            state['wiki_text'].append(wiki_page.summary)
            state['topics'].append(page_title.lower())
            state['wiki_suggestions'].remove(page_title)
            show_textbox()
        except wikipedia.DisambiguationError as ambiguous:
            with st.spinner(text="Woops, ambigious term, recalculating options..."):
                state['wiki_suggestions'].remove(page_title)
                # Going through a set removes duplicates; ordering is not preserved.
                merged = state['wiki_suggestions'] + ambiguous.options[:3]
                state['wiki_suggestions'] = list(set(merged))
                show_textbox()
        except wikipedia.WikipediaException:
            state['wiki_suggestions'].remove(page_title)
def wiki_add_text(term):
    """Fetch the Wikipedia summary for ``term`` and add it to the session.

    Does nothing once ``MAX_TOPICS`` summaries are already stored. On
    success the summary is appended to ``wiki_text``, the lower-cased term
    to ``topics`` and the term is dropped from ``nodes``. If the term is
    ambiguous, it is replaced in ``nodes`` by up to three disambiguation
    options (de-duplicated). Any other ``wikipedia`` error just drops the
    term from ``nodes``.

    Args:
        term: Exact page title to look up (auto-suggest is disabled).
    """
    state = st.session_state

    # `>=` so that at most MAX_TOPICS summaries are ever stored; a strict
    # `>` would still admit one more when the list is already full.
    if len(state['wiki_text']) >= MAX_TOPICS:
        return

    def discard_node():
        # list.remove raises ValueError for a missing item; the term is not
        # guaranteed to be in `nodes`, so check before removing.
        if term in state['nodes']:
            state['nodes'].remove(term)

    try:
        page = wikipedia.page(title=term, auto_suggest=False)
        extra_text = page.summary
    except wikipedia.DisambiguationError as ambiguous:
        with st.spinner(text="Woops, ambigious term, recalculating options..."):
            discard_node()
            # Going through a set removes duplicates; ordering is not preserved.
            merged = state['nodes'] + ambiguous.options[:3]
            state['nodes'] = list(set(merged))
    except wikipedia.WikipediaException:
        discard_node()
    else:
        state['wiki_text'].append(extra_text)
        state['topics'].append(term.lower())
        discard_node()
def reset_thread():
    """Reset every wiki-related session-state entry to its empty value."""
    cleared_values = {
        'wiki_text': [],
        'topics': [],
        'nodes': [],
        'has_run_wiki': False,
        'wiki_suggestions': [],
        'html_wiki': '',
    }
    for key, value in cleared_values.items():
        st.session_state[key] = value
def show_wiki_hub_page():
    """Draw the search row and the "Generate KB" / "Reset" button row.

    The search box and the Search button both trigger
    ``wiki_show_suggestion``; the other two buttons trigger ``generate_kb``
    and ``reset_thread``.
    """
    search_col, search_button_col = st.columns([7, 1])
    generate_col, reset_col, _spacer = st.columns([2, 1.2, 8])

    search_col.text_input(
        "Search",
        on_change=wiki_show_suggestion,
        key="text",
        value="graphs, are, awesome",
    )

    # Two empty text elements sit above the button in its column.
    search_button_col.text('')
    search_button_col.text('')
    search_button_col.button(
        "Search",
        on_click=wiki_show_suggestion,
        key="show_suggestion_key",
    )

    generate_col.button("Generate KB", on_click=generate_kb)
    reset_col.button("Reset", on_click=reset_thread)
def wiki_show_suggestion():
    """Search Wikipedia for each comma-separated subject in the search box.

    Reads ``st.session_state.text``, splits it on commas, and for at most
    ``MAX_TOPICS`` non-blank subjects appends up to three search results
    each to ``wiki_suggestions``. The suggestion buttons are then drawn.
    Nothing happens when the search box is empty.
    """
    with st.spinner(text="Fetching wiki topics..."):
        text = st.session_state.text
        if not text:
            return

        # Trim whitespace and skip blank entries (e.g. "a,,b" or a trailing
        # comma) before applying the cap, so blanks neither use up a slot
        # nor get sent as an empty query, which the search call is expected
        # to reject with an exception.
        stripped = (part.strip() for part in text.split(","))
        subjects = [subject for subject in stripped if subject][:MAX_TOPICS]

        for subj in subjects:
            st.session_state['wiki_suggestions'] += wikipedia.search(subj, results=3)
        show_wiki_suggestions_buttons()
def show_wiki_suggestions_buttons():
    """Draw one button per stored wiki suggestion in a grid.

    The grid has at most ``BUTTON_COLUMS`` columns (fewer when there are
    fewer suggestions) and is filled row by row. Clicking a button calls
    ``wiki_show_text`` with that suggestion. Nothing is drawn when there are
    no suggestions.
    """
    suggestions = st.session_state['wiki_suggestions']
    if not suggestions:
        return

    num_cols = min(len(suggestions), BUTTON_COLUMS)
    columns = st.columns([1] * num_cols)

    for index, suggestion in enumerate(suggestions):
        with columns[index % num_cols]:
            # Best effort: a button that cannot be created is skipped
            # rather than aborting the whole grid.
            with contextlib.suppress(Exception):
                st.button(
                    suggestion,
                    on_click=wiki_show_text,
                    args=(suggestion,),
                    # The key uses the absolute position, not just the
                    # column index, so a suggestion repeated in the same
                    # column on another row still gets a unique key.
                    key=f"{index}_{suggestion}_wiki_suggestion",
                )
def init_variables():
    """Seed the wiki-related session-state entries on the first run.

    ``wiki_suggestions`` acts as the marker: if it is already present, the
    state is assumed to be initialised and nothing is touched.
    """
    if 'wiki_suggestions' in st.session_state:
        return

    initial_values = {
        'wiki_text': [],
        'topics': [],
        'nodes': [],
        'has_run_wiki': True,
        'wiki_suggestions': [],
        'html_wiki': '',
    }
    for key, value in initial_values.items():
        st.session_state[key] = value
init_variables()
show_wiki_hub_page()

# kb chart session state: make sure the keys read below exist on the first run
for _state_key in ('kb_chart', 'kb_text', 'error_url'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None

# show graph: an error message takes precedence over a generated chart
if st.session_state.error_url:
    st.markdown(st.session_state.error_url)
elif st.session_state.kb_chart:
    with st.container():
        st.subheader("Generated KB")
        st.markdown("*You can interact with the graph and zoom.*")
        # Read through a context manager so the file handle is closed on
        # every script run instead of being left to the garbage collector.
        with open(st.session_state.kb_chart, 'r', encoding='utf-8') as html_file:
            html_source_code = html_file.read()
        components.html(html_source_code, width=700, height=700)
        st.markdown(st.session_state.kb_text)