File size: 6,223 Bytes
9f45711 89b2c4c 9f45711 f7ee444 9f45711 89b2c4c 9f45711 c064ff6 9f45711 c064ff6 9f45711 e55a05a 9f45711 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
import streamlit as st
# Page-level settings. This is issued before any other st.* command, as
# Streamlit requires for page configuration.
st.set_page_config(
    page_title='ZShot',              # Browser tab title (Streamlit appends its own suffix).
    page_icon='./logo_zshot.png',    # Favicon; any image source accepted by st.image works.
    layout="wide",                   # Use the full browser width instead of a centred column.
    initial_sidebar_state="auto",    # Let Streamlit decide whether the sidebar starts open.
)
import os
import sys
import warnings
import spacy
from zshot.linker import LinkerSMXM, LinkerTARS, LinkerRegen
from zshot.utils.data_models import Entity
from zshot.mentions_extractor import MentionsExtractorSpacy
from zshot.mentions_extractor.utils import ExtractorType
from zshot import PipelineConfig, displacy
# Make modules in the current working directory importable; this must run
# before the local `streamlit_apps_config` import below.
sys.path.append(os.path.abspath('./'))
# App-specific settings (e.g. STYLE_CONFIG, used for the NER styling further down).
import streamlit_apps_config as config
# Silence every Python warning for the whole process.
# NOTE(review): presumably to keep library noise out of the app log; this also
# hides deprecation warnings, so consider narrowing the filter.
warnings.simplefilter('ignore')
## Inject the CSS that styles the NER entity highlights.
st.markdown(config.STYLE_CONFIG, unsafe_allow_html=True)

########## Hide Streamlit's "hamburger" main menu ########
# A one-rule stylesheet that makes the #MainMenu element invisible.
hide_menu_style = "\n<style>\n#MainMenu {visibility: hidden;}\n</style>\n"
st.markdown(hide_menu_style, unsafe_allow_html=True)
########## Side Bar ########
## Loading the logos (newer version: each logo is wrapped in a clickable link)
import base64
@st.cache_data()
def get_base64_of_bin_file(bin_file):
    """Return the raw bytes of ``bin_file`` encoded as a base64 text string.

    Memoised with ``st.cache_data``, keyed on the ``bin_file`` argument.

    :param bin_file: path of the file to read (opened in binary mode).
    :return: the base64 encoding of the file contents, decoded to ``str``.
    """
    with open(bin_file, mode='rb') as handle:
        return base64.b64encode(handle.read()).decode()
@st.cache_data()
def get_img_with_href(local_img_path, target_url, size='big'):
    """Build an HTML snippet showing a local image inlined as a data URI, wrapped in a link.

    :param local_img_path: path of the image file to embed.
    :param target_url: URL the image links to.
    :param size: ``'big'`` renders the image at 90% width/height; any other
        value renders it at 45%.
    :return: an ``<a><img/></a>`` HTML string for ``st.markdown(..., unsafe_allow_html=True)``.
    """
    import mimetypes

    # Derive a proper MIME type for the data URI: the raw extension is wrong for
    # e.g. ".svg" (needs "image/svg+xml") and ".jpg" (needs "image/jpeg").
    mime_type, _ = mimetypes.guess_type(local_img_path)
    if mime_type is None or not mime_type.startswith('image/'):
        # Unknown type: fall back to "image/<extension>" as before.
        mime_type = 'image/' + os.path.splitext(local_img_path)[-1].replace('.', '')
    bin_str = get_base64_of_bin_file(local_img_path)
    # Height and width are always scaled together.
    dimension = '90%' if size == 'big' else '45%'
    html_code = f'''
<a href="{target_url}" style='text-align: center;'>
<img height="{dimension}" width="{dimension}" style='display: block; margin-left: auto; margin-right: auto;' src="data:{mime_type};base64,{bin_str}" />
</a>'''
    return html_code
# Sidebar logos, in order: IBM (larger) then ZShot (smaller), each linking to its site.
for logo_path, logo_link, logo_size in (
    ('./logo.png', 'https://www.ibm.com/', 'big'),
    ('./logo_zshot.png', 'https://github.com/IBM/zshot', 'small'),
):
    logo_html = get_img_with_href(logo_path, logo_link, size=logo_size)
    st.sidebar.markdown(logo_html, unsafe_allow_html=True)
# Sidebar: choose which zero-shot linker to run.
linkers = ["REGEN", "SMXM", "TARS"]
st.sidebar.title("Linker to test")
# The sidebar title above already labels this widget visually. Give the widget a
# real label and collapse it: Streamlit warns on empty labels for accessibility
# reasons and may disallow them in the future.
selected_model = st.sidebar.selectbox("Linker to test", linkers, label_visibility="collapsed")
######## Main Page #########
# Title and short description shown for each linker that can be picked in the sidebar.
_LINKER_INFO = {
    "REGEN": (
        "REGEN Linker",
        "REGEN is a T5 implementation of GENRE. It performs retrieval generating the unique entity name conditioned on the input text using constrained beam search to only generate valid identifiers.",
    ),
    "SMXM": (
        "SMXM Linker",
        "SMXM model uses the description of the entities to give the model information about the entities.",
    ),
    "TARS": (
        "TARS Linker",
        "TARS doesn't need the descriptions of the entities, so if you can't provide the descriptions of the entities maybe this is the approach you're looking for.",
    ),
}
if selected_model in _LINKER_INFO:
    app_title, app_description = _LINKER_INFO[selected_model]
    st.title(app_title)
    st.markdown(f"<h2>{app_description}</h2>", unsafe_allow_html=True)
# Empty subheader used purely as vertical spacing.
st.subheader("")
# Example entities seeded into the session on its first run only; later reruns
# keep whatever the user has added or removed.
_DEFAULT_ENTITIES = (
    ("company", "The name of a company"),
    ("location", "A physical location"),
    ("chemical compound", "Any substance composed of identical molecules consisting of atoms of two or more chemical elements."),
)
if "entities" not in st.session_state:
    st.session_state["entities"] = [
        Entity(name=ent_name, description=ent_desc)
        for ent_name, ent_desc in _DEFAULT_ENTITIES
    ]
def add_ent():
    """Append the entity typed into the "Add" form to the session's entity list.

    Used as the form-submit callback (``on_click``), so it runs before the
    script reruns. The form inputs are cleared afterwards by resetting their
    widget keys. A submission whose name is blank (after stripping whitespace)
    is rejected with a warning, and the inputs are left as typed. This keeps
    an unnamed entity from ever reaching the zero-shot pipeline.
    """
    name = st.session_state.get("name", "").strip()
    description = st.session_state.get("description", "").strip()
    if not name:
        st.warning("Entity name cannot be empty.")
        return
    st.session_state['entities'].append(Entity(name=name, description=description))
    # Clear the inputs so the form is ready for the next entity.
    st.session_state['name'] = ""
    st.session_state['description'] = ''
def _remove_entity(index):
    """Drop the entity at ``index`` from the session's entity list.

    Runs as the "Remove" button callback, i.e. before the script reruns, so the
    list is already updated when it is rendered again. Out-of-range indices are
    ignored.
    """
    entities = st.session_state['entities']
    if 0 <= index < len(entities):
        entities.pop(index)


# One row per entity: name, description and a "Remove" button. Removal goes
# through an on_click callback instead of mutating the list mid-iteration and
# calling the deprecated st.experimental_rerun().
for i, entity in enumerate(st.session_state['entities']):
    col1, col2, col3 = st.columns([2, 5, 1])
    with col1:
        st.text(entity.name)
    with col2:
        st.text(entity.description)
    with col3:
        st.button('Remove', key=f"ent_{i}", on_click=_remove_entity, args=(i,))
# "Add entity" form: name and description inputs plus a submit button that
# triggers add_ent. The widget keys "name"/"description" are what add_ent reads.
with st.form(key="form"):
    name_col, desc_col, submit_col = st.columns([2, 5, 1])
    name_col.text_input("Entity Name", key="name")
    desc_col.text_input("Entity Description", key="description")
    submit_col.form_submit_button('Add', on_click=add_ent)
# Horizontal rule separating the entity editor from the text to analyse.
st.markdown("________")
# Sample sentence pre-filled in the input box.
_DEFAULT_TEXT = (
    "CH2O2 is a chemical compound similar to Acetamide used in International Business "
    "Machines Corporation (IBM) to create new materials that act like PAGs."
)
text = st.text_input("Type here your text and press enter to run:", value=_DEFAULT_TEXT)
def build_pipeline(model_name=selected_model):
    """Build a spaCy pipeline ending with the ``zshot`` component for the given linker.

    :param model_name: one of ``"REGEN"``, ``"SMXM"`` or ``"TARS"``. The default
        is the linker picked in the sidebar (bound when this function is
        defined, which happens on every Streamlit rerun).
    :return: the spaCy ``nlp`` object with ``zshot`` added as the last pipe,
        configured with the entities currently in ``st.session_state['entities']``.
    :raises ValueError: if ``model_name`` is not a supported linker.
    """
    linker_classes = {"REGEN": LinkerRegen, "TARS": LinkerTARS, "SMXM": LinkerSMXM}
    try:
        linker_cls = linker_classes[model_name]
    except KeyError:
        # Without this, an unknown name would surface as an UnboundLocalError.
        raise ValueError(
            f"Unsupported linker {model_name!r}; expected one of {sorted(linker_classes)}"
        ) from None
    linker = linker_cls()

    if model_name == "REGEN":
        # REGEN is paired with spaCy's NER-based mentions extractor, so it needs
        # the trained English model. The other linkers run on a blank pipeline
        # without a mentions extractor.
        nlp = spacy.load('en_core_web_sm')
        mentions_extractor = MentionsExtractorSpacy(ExtractorType.NER)
    else:
        nlp = spacy.blank('en')
        mentions_extractor = None

    # Named to avoid shadowing the module-level `config` (streamlit_apps_config).
    pipeline_config = PipelineConfig(
        entities=st.session_state['entities'],
        mentions_extractor=mentions_extractor,
        linker=linker,
    )
    nlp.add_pipe("zshot", config=pipeline_config, last=True)
    return nlp
predict = st.button("Run ZShot")
if predict:
    if not text.strip():
        st.warning("Please type some text to analyse.")
    elif not st.session_state['entities']:
        st.warning("Please add at least one entity before running ZShot.")
    else:
        # Transient status message. It is cleared in `finally` so it cannot stay
        # on screen if building the pipeline or running inference raises.
        placeholder = st.empty()
        placeholder.info("Processing...")
        try:
            nlp = build_pipeline()
            doc = nlp(text)
        finally:
            placeholder.empty()
        # Render the linked entities as HTML and display it in the page.
        ent_html = displacy.render(doc, style="ent", jupyter=False)
        st.markdown(ent_html, unsafe_allow_html=True)
# Sidebar footer with links to the ZShot repository and documentation.
st.sidebar.info("""See more:
- Check ZShot Github [here](https://github.com/IBM/zshot)
- Check ZShot documentation [here](https://ibm.github.io/zshot/)""")
|