import contextlib

import streamlit as st
import streamlit.components.v1 as components
import wikipedia
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

import utils
from kb import KB  # noqa: F401  (kept: not referenced in this file's visible code)

MAX_TOPICS = 5
BUTTON_COLUMS = 4

st.header("Extracting a Knowledge Graph from text")


def _unique(items):
    """Return *items* as a list without duplicates, keeping first-seen order."""
    return list(dict.fromkeys(items))


def _discard(items, value):
    """Remove the first occurrence of *value* from *items*, if it is present."""
    with contextlib.suppress(ValueError):
        items.remove(value)


def _set_default_state(has_run_wiki):
    """Write the empty per-session wiki state into ``st.session_state``.

    Args:
        has_run_wiki: value stored under the ``'has_run_wiki'`` key.
    """
    st.session_state['wiki_text'] = []
    st.session_state['topics'] = []
    st.session_state['nodes'] = []
    st.session_state['has_run_wiki'] = has_run_wiki
    st.session_state['wiki_suggestions'] = []
    st.session_state['html_wiki'] = ''


# NOTE(review): the tokenizer and model are loaded again on every call (i.e.
# on every "Generate KB" click). Consider wrapping this in Streamlit's caching
# decorator if the installed Streamlit version provides a suitable one.
def load_model():
    """Load the pretrained "Babelscape/rebel-large" tokenizer and model.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
    model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
    return tokenizer, model


def generate_kb():
    """Build a knowledge base from the collected wiki texts.

    Joins every entry of ``st.session_state['wiki_text']`` with single spaces,
    passes the result to ``utils.from_text_to_kb``, saves the network as
    "networks/network.html" and records the chart path and the textual
    representation of the knowledge base in the session state.
    """
    st_model_load = st.text('Loading NER model... It may take a while.')
    tokenizer, model = load_model()
    st.success('Model loaded!')
    st_model_load.text("")  # blank out the "loading" placeholder

    full_text = ' '.join(st.session_state['wiki_text'])
    kb = utils.from_text_to_kb(full_text, model, tokenizer, "", verbose=True)
    utils.save_network_html(kb, filename="networks/network.html")

    st.session_state.kb_chart = "networks/network.html"
    st.session_state.kb_text = kb.get_textual_representation()
    st.session_state.error_url = None


def show_textbox():
    """Render one expander per collected wiki text; only the first is open."""
    for i, t in enumerate(st.session_state['wiki_text']):
        with st.expander(label=f"{t[:30]}...", expanded=(i == 0)):
            st.markdown(t)


def wiki_show_text(page_title):
    """Fetch the summary of a suggested wiki page and add it to the session.

    On success the summary is appended to ``'wiki_text'``, the lower-cased
    title to ``'topics'``, and the title is dropped from the suggestions.
    On a disambiguation error the title is replaced by the first three
    disambiguation options. On any other ``wikipedia`` error the title is
    simply dropped from the suggestions.

    Args:
        page_title: title of the page to fetch (looked up verbatim, without
            auto-suggest).
    """
    with st.spinner(text="Fetching wiki page..."):
        try:
            page = wikipedia.page(title=page_title, auto_suggest=False)
            st.session_state['wiki_text'].append(page.summary)
            st.session_state['topics'].append(page_title.lower())
            _discard(st.session_state['wiki_suggestions'], page_title)
            show_textbox()
        except wikipedia.DisambiguationError as e:
            with st.spinner(text="Woops, ambigious term, recalculating options..."):
                _discard(st.session_state['wiki_suggestions'], page_title)
                merged = st.session_state['wiki_suggestions'] + e.options[:3]
                # Order-preserving de-duplication keeps button order stable.
                st.session_state['wiki_suggestions'] = _unique(merged)
                show_textbox()
        except wikipedia.WikipediaException:
            _discard(st.session_state['wiki_suggestions'], page_title)


def wiki_add_text(term):
    """Fetch the summary for *term* and add it to the collected texts.

    Does nothing once ``MAX_TOPICS`` texts have been collected. On success
    the summary is appended to ``'wiki_text'``, the lower-cased term to
    ``'topics'``, and the term is dropped from ``'nodes'``. On a
    disambiguation error the term is replaced in ``'nodes'`` by the first
    three disambiguation options. On any other ``wikipedia`` error the term
    is simply dropped from ``'nodes'``.

    Args:
        term: page title to fetch (looked up verbatim, without auto-suggest).
    """
    # ">=" so that at most MAX_TOPICS texts are ever stored.
    if len(st.session_state['wiki_text']) >= MAX_TOPICS:
        return
    try:
        page = wikipedia.page(title=term, auto_suggest=False)
        extra_text = page.summary
        st.session_state['wiki_text'].append(extra_text)
        st.session_state['topics'].append(term.lower())
        _discard(st.session_state['nodes'], term)
    except wikipedia.DisambiguationError as e:
        with st.spinner(text="Woops, ambigious term, recalculating options..."):
            _discard(st.session_state['nodes'], term)
            merged = st.session_state['nodes'] + e.options[:3]
            st.session_state['nodes'] = _unique(merged)
    except wikipedia.WikipediaException:
        _discard(st.session_state['nodes'], term)


def reset_thread():
    """Clear all collected wiki state for the current session."""
    _set_default_state(has_run_wiki=False)


def show_wiki_hub_page():
    """Render the search box and the Search / Generate KB / Reset buttons."""
    cols = st.columns([7, 1])
    b_cols = st.columns([2, 1.2, 8])

    with cols[0]:
        st.text_input("Search", on_change=wiki_show_suggestion, key="text",
                      value="graphs, are, awesome")
    with cols[1]:
        # Two empty text elements push the button down next to the input box.
        st.text('')
        st.text('')
        st.button("Search", on_click=wiki_show_suggestion, key="show_suggestion_key")
    with b_cols[0]:
        st.button("Generate KB", on_click=generate_kb)
    with b_cols[1]:
        st.button("Reset", on_click=reset_thread)


def wiki_show_suggestion():
    """Search Wikipedia for each comma-separated subject in the search box.

    Reads ``st.session_state.text``, splits it on commas, and for at most
    ``MAX_TOPICS`` non-empty subjects adds up to three search results each to
    ``'wiki_suggestions'`` (without duplicates), then renders the suggestion
    buttons.
    """
    with st.spinner(text="Fetching wiki topics..."):
        text = st.session_state.text
        if not text:
            return

        # Strip whitespace around each subject and skip empty ones (e.g. from
        # a trailing comma) so no blank query is sent to the search API.
        subjects = [part.strip() for part in text.split(",")]
        subjects = [subj for subj in subjects if subj][:MAX_TOPICS]

        suggestions = list(st.session_state['wiki_suggestions'])
        for subj in subjects:
            suggestions += wikipedia.search(subj, results=3)
        # Duplicates would yield identical widget keys when rendered.
        st.session_state['wiki_suggestions'] = _unique(suggestions)

        show_wiki_suggestions_buttons()


def show_wiki_suggestions_buttons():
    """Render one button per wiki suggestion, ``BUTTON_COLUMS`` per row."""
    suggestions = st.session_state['wiki_suggestions']
    num_buttons = len(suggestions)
    if num_buttons == 0:
        return

    num_cols = min(num_buttons, BUTTON_COLUMS)
    columns = st.columns([1] * num_cols)

    for start in range(0, num_buttons, num_cols):
        row = suggestions[start:start + num_cols]
        for i, (c, s) in enumerate(zip(columns, row)):
            with c:
                # Best effort: a button that fails to render is skipped
                # instead of aborting the whole page.
                with contextlib.suppress(Exception):
                    st.button(s, on_click=wiki_show_text, args=(s,),
                              key=str(i) + s + "wiki_suggestion")


def init_variables():
    """Create the session-state entries on the first run of a session."""
    if 'wiki_suggestions' not in st.session_state:
        # NOTE(review): 'has_run_wiki' starts as True here but reset_thread()
        # sets it to False; nothing in this file reads it - confirm intent.
        _set_default_state(has_run_wiki=True)


init_variables()
show_wiki_hub_page()

# kb chart session state
for _state_key in ('kb_chart', 'kb_text', 'error_url'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None

# show graph
if st.session_state.error_url:
    st.markdown(st.session_state.error_url)
elif st.session_state.kb_chart:
    with st.container():
        st.subheader("Generated KB")
        st.markdown("*You can interact with the graph and zoom.*")
        # Context manager so the file handle is closed deterministically.
        with open(st.session_state.kb_chart, 'r', encoding='utf-8') as html_file:
            html_source_code = html_file.read()
        components.html(html_source_code, width=700, height=700)
        st.markdown(st.session_state.kb_text)