zshot / app.py
marmg's picture
Renamed GENRE to REGEN
e55a05a
raw
history blame
6.23 kB
import streamlit as st
# Page-level settings. Streamlit documents that set_page_config must be the
# first Streamlit command in a script -- presumably why this call sits above
# the remaining imports.
page_settings = dict(
    layout="wide",                    # use the full browser width
    initial_sidebar_state="auto",     # let Streamlit decide whether the sidebar starts open
    page_title='ZShot',               # browser tab title
    page_icon='./logo_zshot.png',     # browser tab icon (local image file)
)
st.set_page_config(**page_settings)
import os
import sys
import spacy
from zshot.linker import LinkerSMXM, LinkerTARS, LinkerRegen
from zshot.utils.data_models import Entity
from zshot.mentions_extractor import MentionsExtractorSpacy
from zshot.mentions_extractor.utils import ExtractorType
from zshot import PipelineConfig
from spacy import displacy
sys.path.append(os.path.abspath('./'))
import streamlit_apps_config as config
## Inject page CSS: the NER styling from streamlit_apps_config, then a rule
## that hides Streamlit's main ("hamburger") menu.
hide_menu_style = """
<style>
#MainMenu {visibility: hidden;}
</style>
"""
for css_snippet in (config.STYLE_CONFIG, hide_menu_style):
    st.markdown(css_snippet, unsafe_allow_html=True)
########## Side Bar ########
## Loading the logos (newer version: each image is wrapped in an href link)
import base64
@st.cache(allow_output_mutation=True)
def get_base64_of_bin_file(bin_file):
    """Return the raw bytes of ``bin_file`` encoded as a base64 ``str``.

    The result is memoised by ``st.cache`` so each file is read only once.
    """
    with open(bin_file, 'rb') as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode()
@st.cache(allow_output_mutation=True)
def get_img_with_href(local_img_path, target_url, size='big'):
    """Build HTML for a centred, base64-embedded image that links to a URL.

    Args:
        local_img_path: path of the image to embed; its file extension
            (without the leading dot) becomes the ``data:image/...`` subtype.
        target_url: destination of the surrounding ``<a>`` link.
        size: ``'big'`` renders the image at 90% height and width; any other
            value renders it at 45%.

    Returns:
        The HTML snippet as a string, rendered by the caller through
        ``st.sidebar.markdown(..., unsafe_allow_html=True)``.
    """
    extension = os.path.splitext(local_img_path)[1]
    img_format = extension.lstrip('.')
    bin_str = get_base64_of_bin_file(local_img_path)
    # Height and width always share the same percentage.
    height = width = '90%' if size == 'big' else '45%'
    html_code = f'''
<a href="{target_url}" style='text-align: center;'>
<img height="{height}" width="{width}" style='display: block; margin-left: auto; margin-right: auto;' src="data:image/{img_format};base64,{bin_str}" />
</a>'''
    return html_code
# Sidebar logos, top to bottom: the large IBM logo, then the smaller ZShot logo.
sidebar_logos = (
    ('./logo.png', 'https://www.ibm.com/', 'big'),
    ('./logo_zshot.png', 'https://github.com/IBM/zshot', 'small'),
)
for logo_path, logo_link, logo_size in sidebar_logos:
    logo_html = get_img_with_href(logo_path, logo_link, size=logo_size)
    st.sidebar.markdown(logo_html, unsafe_allow_html=True)
# Sidebar: choose which zshot linker the page demonstrates.
linkers = ["REGEN", "SMXM", "TARS"]
sidebar = st.sidebar
sidebar.title("Linker to test")
selected_model = sidebar.selectbox("", linkers)
######## Main Page #########
# Title and description shown for each linker that can be picked in the sidebar.
LINKER_PAGES = {
    "REGEN": (
        "REGEN Linker",
        "REGEN is a T5 implementation of GENRE. It performs retrieval generating the unique entity name conditioned on the input text using constrained beam search to only generate valid identifiers.",
    ),
    "SMXM": (
        "SMXM Linker",
        "SMXM model uses the description of the entities to give the model information about the entities.",
    ),
    "TARS": (
        "TARS Linker",
        "TARS doesn't need the descriptions of the entities, so if you can't provide the descriptions of the entities maybe this is the approach you're looking for.",
    ),
}
if selected_model in LINKER_PAGES:
    app_title, app_description = LINKER_PAGES[selected_model]
    st.title(app_title)
    st.markdown(f"<h2>{app_description}</h2>", unsafe_allow_html=True)
st.subheader("")
def _default_entities():
    """Return the entity list the demo starts with on a fresh session."""
    return [
        Entity(name="company", description="The name of a company"),
        Entity(name="location", description="A physical location"),
        Entity(name="chemical compound", description="Any substance composed of identical molecules consisting of atoms of two or more chemical elements."),
    ]


# Seed the session only once; later reruns keep the user's edits.
if 'entities' not in st.session_state:
    st.session_state['entities'] = _default_entities()
def add_ent():
    """Append the entity typed into the "Add" form, then clear the form inputs.

    Registered as the ``on_click`` callback of the form's submit button. It
    reads the submitted values through the widget keys ``"name"`` and
    ``"description"``. Input is stripped of surrounding whitespace, and a
    blank name is ignored, so pressing "Add" on an empty form does not insert
    an unnamed entity into the pipeline configuration.
    """
    name = st.session_state.get("name", "").strip()
    description = st.session_state.get("description", "").strip()
    if name:
        st.session_state['entities'].append(Entity(name=name, description=description))
    # Reset both inputs so the form is empty for the next entry.
    st.session_state['name'] = ""
    st.session_state['description'] = ''
# One row per configured entity: name | description | Remove button.
for idx, ent in enumerate(st.session_state['entities']):
    name_col, desc_col, action_col = st.columns([2, 5, 1])
    name_col.text(ent.name)
    desc_col.text(ent.description)
    if action_col.button('Remove', key=f"ent_{idx}"):
        # Remove the entity and restart the script immediately. The rerun
        # stops this loop, so it never continues over the list it just mutated.
        st.session_state['entities'].pop(idx)
        st.experimental_rerun()  # This causes the app to rerun
# Form for adding a new entity. add_ent runs as the submit callback and reads
# the two text inputs through their session-state keys "name" and "description".
with st.form(key="form"):
    name_col, description_col, add_col = st.columns([2, 5, 1])
    with name_col:
        st.text_input("Entity Name", key="name")
    with description_col:
        st.text_input("Entity Description", key="description")
    with add_col:
        st.form_submit_button('Add', on_click=add_ent)
st.markdown("________")
# Example sentence pre-filled in the input box.
example_sentence = ("CH2O2 is a chemical compound similar to Acetamide used in International Business "
                    "Machines Corporation (IBM) to create new materials that act like PAGs.")
text = st.text_input("Type here your text and press enter to run:", value=example_sentence)
def build_pipeline(model_name=selected_model):
    """Build a spaCy pipeline ending in a ``zshot`` component for one linker.

    Args:
        model_name: ``"REGEN"``, ``"TARS"`` or ``"SMXM"``. Defaults to the
            linker chosen in the sidebar when the script ran.

    Returns:
        A spaCy ``Language`` with ``zshot`` added as the last pipe, configured
        with the entities currently stored in the session.

    Raises:
        ValueError: if ``model_name`` is not one of the supported linkers
            (previously this surfaced as an UnboundLocalError on ``linker``).
    """
    mentions_extractor = None
    if model_name == "REGEN":
        linker = LinkerRegen()
        # REGEN is paired with a spaCy NER mentions extractor, so it loads the
        # trained en_core_web_sm pipeline (assumed to supply the NER component).
        nlp = spacy.load('en_core_web_sm')
        mentions_extractor = MentionsExtractorSpacy(ExtractorType.NER)
    elif model_name == "TARS":
        nlp = spacy.blank('en')
        linker = LinkerTARS()
    elif model_name == "SMXM":
        nlp = spacy.blank('en')
        linker = LinkerSMXM()
    else:
        raise ValueError(
            f"Unknown linker {model_name!r}; expected one of 'REGEN', 'TARS', 'SMXM'"
        )
    # Named pipeline_config so it does not shadow the module-level `config`
    # (streamlit_apps_config).
    pipeline_config = PipelineConfig(
        entities=st.session_state['entities'],
        mentions_extractor=mentions_extractor,
        linker=linker,
    )
    nlp.add_pipe("zshot", config=pipeline_config, last=True)
    return nlp
predict = st.button("Run ZShot")
if predict:
    # Transient status message shown while the models load and the text is processed.
    placeholder = st.empty()
    placeholder.info("Processing...")
    try:
        nlp = build_pipeline()
        doc = nlp(text)
    finally:
        # Clear the status even if building or running the pipeline raises,
        # so a stale "Processing..." is not left above the error.
        placeholder.empty()
    # Render the recognised entities as displaCy HTML and embed it in the page.
    ent_html = displacy.render(doc, style="ent", jupyter=False)
    st.markdown(ent_html, unsafe_allow_html=True)
# Sidebar footer with links to the ZShot repository and documentation.
see_more_lines = [
    "See more:",
    "- Check ZShot Github [here](https://github.com/IBM/zshot)",
    "- Check ZShot documentation [here](https://ibm.github.io/zshot/)",
]
st.sidebar.info("\n".join(see_more_lines))