|
import streamlit as st |
|
|
|
# Page-level settings for the demo: browser tab title/icon, wide layout and
# automatic sidebar state.
# NOTE(review): this call sits ahead of the remaining imports, presumably so
# that it is the first Streamlit command executed -- confirm before moving it.
st.set_page_config(
    page_title='ZShot',
    page_icon='./logo_zshot.png',
    layout="wide",
    initial_sidebar_state="auto",
)
|
|
|
import os |
|
import sys |
|
import warnings |
|
|
|
import spacy |
|
from zshot.linker import LinkerSMXM, LinkerTARS, LinkerRegen |
|
from zshot.utils.data_models import Entity |
|
from zshot.mentions_extractor import MentionsExtractorSpacy |
|
from zshot.mentions_extractor.utils import ExtractorType |
|
from zshot import PipelineConfig, displacy |
|
|
|
sys.path.append(os.path.abspath('./')) |
|
import streamlit_apps_config as config |
|
|
|
# Silence every Python warning for the lifetime of the app.
warnings.simplefilter('ignore')

# Inject the shared style block provided by streamlit_apps_config.
st.markdown(config.STYLE_CONFIG, unsafe_allow_html=True)

# Extra CSS rule that hides the element with id "MainMenu".
hide_menu_style = """
<style>
#MainMenu {visibility: hidden;}
</style>
"""
st.markdown(hide_menu_style, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
import base64 |
|
|
|
|
|
@st.cache_data()
def get_base64_of_bin_file(bin_file):
    """Return the contents of ``bin_file`` as base64 text.

    The file is opened in binary mode and read in full; the base64 bytes are
    then decoded into a ``str``. The function is wrapped in ``st.cache_data``.

    Args:
        bin_file: Path of the file to read.

    Returns:
        The base64 representation of the file contents as a string.
    """
    with open(bin_file, 'rb') as handle:
        encoded = base64.b64encode(handle.read())
    return encoded.decode()
|
|
|
|
|
@st.cache_data()
def get_img_with_href(local_img_path, target_url, size='big'):
    """Build an HTML snippet showing a local image as a centred link.

    The image file is embedded inline as a base64 ``data:`` URI whose image
    subtype is the file extension without its dot.

    Args:
        local_img_path: Path of the image file to embed.
        target_url: URL placed in the anchor's ``href``.
        size: ``'big'`` renders the image at 90% height/width; any other
            value renders it at 45%.

    Returns:
        The HTML markup (an ``<a>`` wrapping an ``<img>``) as a string.
    """
    extension = os.path.splitext(local_img_path)[-1]
    img_format = extension.replace('.', '')
    bin_str = get_base64_of_bin_file(local_img_path)

    # Height and width always share the same percentage.
    if size == 'big':
        height = width = '90%'
    else:
        height = width = '45%'

    html_code = f'''
<a href="{target_url}" style='text-align: center;'>
<img height="{height}" width="{width}" style='display: block; margin-left: auto; margin-right: auto;' src="data:image/{img_format};base64,{bin_str}" />
</a>'''
    return html_code
|
|
|
|
|
# Sidebar header: the IBM logo (large) followed by the ZShot logo (small),
# each rendered as a clickable image.
logo_html = get_img_with_href('./logo.png', 'https://www.ibm.com/')
st.sidebar.markdown(logo_html, unsafe_allow_html=True)
logo_html = get_img_with_href('./logo_zshot.png', 'https://github.com/IBM/zshot', size='small')
st.sidebar.markdown(logo_html, unsafe_allow_html=True)

# Linker selector. The sidebar title above the widget serves as the visible
# heading, so the widget's own label is collapsed rather than passed as an
# empty string (Streamlit discourages empty labels for accessibility reasons).
linkers = ["REGEN", "SMXM", "TARS"]
st.sidebar.title("Linker to test")
selected_model = st.sidebar.selectbox(
    "Linker to test",
    linkers,
    label_visibility="collapsed",
)
|
|
|
|
|
|
|
# Page heading for each selectable linker: name -> (title, description).
_LINKER_HEADINGS = {
    "REGEN": (
        "REGEN Linker",
        "REGEN is a T5 implementation of GENRE. It performs retrieval generating the unique entity name conditioned on the input text using constrained beam search to only generate valid identifiers.",
    ),
    "SMXM": (
        "SMXM Linker",
        "SMXM model uses the description of the entities to give the model information about the entities.",
    ),
    "TARS": (
        "TARS Linker",
        "TARS doesn't need the descriptions of the entities, so if you can't provide the descriptions of the entities maybe this is the approach you're looking for.",
    ),
}

# Show the title and description of the selected linker; nothing is rendered
# for a name that is not in the table.
if selected_model in _LINKER_HEADINGS:
    app_title, app_description = _LINKER_HEADINGS[selected_model]
    st.title(app_title)
    st.markdown("<h2>" + app_description + "</h2>", unsafe_allow_html=True)

st.subheader("")
|
|
|
# Seed the session with three example entities when no entity list has been
# stored in the session state yet.
if 'entities' not in st.session_state:
    default_entity_specs = (
        ("company", "The name of a company"),
        ("location", "A physical location"),
        ("chemical compound", "Any substance composed of identical molecules consisting of atoms of two or more chemical elements."),
    )
    st.session_state['entities'] = [
        Entity(name=entity_name, description=entity_description)
        for entity_name, entity_description in default_entity_specs
    ]
|
|
|
def add_ent():
    """Append the entity typed into the form to the session's entity list.

    Used as the ``on_click`` callback of the form's "Add" button. Reads the
    ``"name"`` and ``"description"`` values from ``st.session_state``, strips
    surrounding whitespace, and appends a new ``Entity`` to
    ``st.session_state['entities']``. Both input fields are then cleared.

    A submission whose name is empty (or whitespace only) is ignored so that
    no nameless entity ends up in the list; the description may be empty.
    """
    name = st.session_state.get("name", "").strip()
    description = st.session_state.get("description", "").strip()

    if not name:
        # Nothing meaningful to add; leave the form fields untouched.
        return

    st.session_state['entities'].append(Entity(name=name, description=description))

    # Reset the form fields for the next entry.
    st.session_state['name'] = ""
    st.session_state['description'] = ''
|
|
|
# ``st.experimental_rerun`` was superseded by ``st.rerun`` in newer Streamlit
# releases; use whichever one the installed version provides.
_rerun = getattr(st, "rerun", None) or st.experimental_rerun

# One row per configured entity: name, description and a "Remove" button.
for i, entity in enumerate(st.session_state['entities']):
    col1, col2, col3 = st.columns([2, 5, 1])
    with col1:
        st.text(entity.name)
    with col2:
        st.text(entity.description)
    with col3:
        b = st.button('Remove', key=f"ent_{i}")
    if b:
        # The list is modified while it is being iterated; this relies on the
        # rerun call halting the current script run (per the Streamlit docs),
        # so the loop never continues over the shortened list.
        st.session_state['entities'].pop(i)
        _rerun()
|
|
|
# Form for adding a new entity; column widths match the entity rows (2/5/1).
# Submitting runs add_ent, which reads the "name" and "description" keys.
with st.form(key="form"):
    name_col, description_col, submit_col = st.columns([2, 5, 1])
    name_col.text_input("Entity Name", key="name")
    description_col.text_input("Entity Description", key="description")
    submit_col.form_submit_button('Add', on_click=add_ent)

st.markdown("________")

# Text to annotate, pre-filled with an example sentence.
default_text = ("CH2O2 is a chemical compound similar to Acetamide used in International Business "
                "Machines Corporation (IBM) to create new materials that act like PAGs.")
text = st.text_input("Type here your text and press enter to run:", value=default_text)
|
|
|
def build_pipeline(model_name=selected_model):
    """Create a spaCy pipeline with a ``zshot`` component for one linker.

    Args:
        model_name: One of ``"REGEN"``, ``"TARS"`` or ``"SMXM"``. Defaults to
            the linker selected in the sidebar at the time this function is
            defined.

    Returns:
        The spaCy pipeline with the ``zshot`` component added last. The
        component is configured with the entities currently stored in
        ``st.session_state['entities']``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported linkers.
    """
    mentions_extractor = None

    if model_name == "REGEN":
        linker = LinkerRegen()
        # REGEN is the only linker given a mentions extractor here: a
        # spaCy-based extractor of type NER, on top of the en_core_web_sm
        # pipeline instead of a blank one.
        nlp = spacy.load('en_core_web_sm')
        mentions_extractor = MentionsExtractorSpacy(ExtractorType.NER)
    elif model_name == "TARS":
        linker = LinkerTARS()
        nlp = spacy.blank('en')
    elif model_name == "SMXM":
        linker = LinkerSMXM()
        nlp = spacy.blank('en')
    else:
        # Fail with a clear message instead of an unbound ``linker`` below.
        raise ValueError(
            f"Unknown linker {model_name!r}; expected 'REGEN', 'TARS' or 'SMXM'."
        )

    # Named ``pipeline_config`` so it does not shadow the imported ``config``
    # module (streamlit_apps_config).
    pipeline_config = PipelineConfig(
        entities=st.session_state['entities'],
        mentions_extractor=mentions_extractor,
        linker=linker,
    )
    nlp.add_pipe("zshot", config=pipeline_config, last=True)

    return nlp
|
|
|
# Run the selected linker over the input text when the button is pressed and
# render the annotated entities as HTML.
predict = st.button("Run ZShot")
if predict:
    # Temporary status message shown while the pipeline is built and run.
    status_slot = st.empty()
    status_slot.info("Processing...")

    nlp = build_pipeline()
    doc = nlp(text)
    status_slot.empty()

    st.markdown(
        displacy.render(doc, style="ent", jupyter=False),
        unsafe_allow_html=True,
    )

# Sidebar footer with links to the project repository and documentation.
see_more_text = """See more:
- Check ZShot Github [here](https://github.com/IBM/zshot)
- Check ZShot documentation [here](https://ibm.github.io/zshot/)"""
st.sidebar.info(see_more_text)
|
|