jcg00v's picture
Update app.py
e7e958a verified
raw
history blame
8.15 kB
import streamlit as st
from PIL import Image
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
from htbuilder import HtmlElement, div, ul, li, br, hr, a, p, img, styles, classes, fonts
from htbuilder.units import percent, px
from htbuilder.funcs import rgba, rgb
from pathlib import Path
import base64
def clear_text():
    """Move the submitted input into ``session_state.text`` and blank the box.

    Registered as the ``on_change`` callback of the text input whose widget
    key is ``widget``.
    """
    state = st.session_state
    # The right-hand side is evaluated first, so ``text`` receives the value
    # the widget held before it is reset to the empty string.
    state.text, state.widget = state.widget, ""
# Opening tag used to highlight every word whose tag is not plain "l".
_HIGHLIGHT_OPEN = '<span style="font-weight:bold; color:rgb(142, 208, 129);">'
_HIGHLIGHT_CLOSE = '</span>'

# Punctuation marks that are written before the word instead of after it.
_OPENING_MARKS = ("¿", "¡")


def _restore_text(list_entity, text, punc_tags, is_continuation):
    """Rebuild ``text`` as HTML with predicted casing and punctuation applied.

    Args:
        list_entity: Sequence of dicts, each providing the keys ``"entity"``
            (the predicted tag), ``"word"`` (the token string) and
            ``"start"``/``"end"`` (character offsets into ``text``).
        text: The string the offsets in ``list_entity`` refer to.
        punc_tags: Punctuation marks to look for inside a tag. The first mark
            found in the tag is the one that gets attached to the word.
        is_continuation: Callable taking a token string and returning True
            when the token continues the previous word rather than starting
            a new one.

    Returns:
        The words joined by single spaces. Every word whose tag is not
        exactly ``"l"`` is wrapped in a bold, coloured ``<span>``.
    """
    import html

    result_words = []
    pending = ""  # source text of the word currently assembled from pieces

    for idx, entity in enumerate(list_entity):
        tag = entity["entity"]
        start, end = entity["start"], entity["end"]
        punc_in = next((mark for mark in punc_tags if mark in tag), "")

        # A continuation piece needs an earlier word to attach to. When it is
        # the very first token there is none, so it starts a word itself
        # instead of indexing into an empty result list.
        subword = bool(result_words) and is_continuation(entity["word"])
        if subword:
            if not pending:
                previous = list_entity[idx - 1]
                pending = text[previous["start"]:previous["end"]]
            pending += text[start:end]
            word = pending
        else:
            pending = ""
            word = text[start:end]

        # The last character of the tag carries the casing decision; slicing
        # keeps an empty tag from raising.
        casing = tag[-1:]
        if casing == "u":
            word = word.capitalize()
        if casing in ("l", "u") and punc_in:
            if punc_in in _OPENING_MARKS:
                word = punc_in + word
            else:
                word = word + punc_in

        # The result is rendered as HTML, so markup characters coming from the
        # user's text must not be interpreted as tags or entities.
        word = html.escape(word, quote=False)

        if tag != "l":
            word = _HIGHLIGHT_OPEN + word + _HIGHLIGHT_CLOSE

        if subword:
            # The tag of the latest piece decides how the whole word looks.
            result_words[-1] = word
        else:
            result_words.append(word)

    return " ".join(result_words)


def get_result_text_es_pt(list_entity, text, lang):
    """Format token-classification output for Spanish or Portuguese text.

    Tokens whose string starts with ``"##"`` are treated as continuations of
    the previous word.

    Args:
        list_entity: Token dicts with ``"entity"``, ``"word"``, ``"start"``
            and ``"end"`` keys.
        text: The text that was classified.
        lang: ``"es"`` additionally enables the opening marks ``¿`` and ``¡``;
            any other value uses the closing marks only.

    Returns:
        An HTML string with restored capitalization and punctuation.
    """
    if lang == "es":
        punc_tags = ("¿", "?", "¡", "!", ",", ".", ":")
    else:
        punc_tags = ("?", "!", ",", ".", ":")
    return _restore_text(
        list_entity, text, punc_tags, lambda piece: piece.startswith("##")
    )


def get_result_text_ca(list_entity, text):
    """Format token-classification output for Catalan text.

    Tokens whose string does not start with ``"Ġ"`` are treated as
    continuations of the previous word; the first token always starts a word.

    Args:
        list_entity: Token dicts with ``"entity"``, ``"word"``, ``"start"``
            and ``"end"`` keys.
        text: The text that was classified.

    Returns:
        An HTML string with restored capitalization and punctuation.
    """
    punc_tags = ("?", "!", ",", ".", ":")
    return _restore_text(
        list_entity, text, punc_tags, lambda piece: not piece.startswith("Ġ")
    )
def link(link, text, **style):
    """Build an anchor element pointing at ``link`` that opens in a new tab.

    Args:
        link: Value for the anchor's ``href`` attribute.
        text: Content placed inside the anchor.
        **style: Keyword arguments forwarded to ``styles`` to build the
            anchor's inline style.

    Returns:
        The anchor element with ``text`` as its child.
    """
    inline_style = styles(**style)
    anchor = a(_href=link, _target="_blank", style=inline_style)
    return anchor(text)
def layout(*args):
    """Render a fixed footer at the bottom of the page.

    Args:
        *args: Footer content in display order. Strings and ``HtmlElement``
            instances are appended to the footer paragraph; values of any
            other type are ignored.
    """
    # Built line by line so no source indentation leaks into the markup.
    # The ID selector must be written without a space ("#MainMenu"): with a
    # space it is not a valid selector and the rule is discarded.
    style = "\n".join(
        [
            "<style>",
            "#MainMenu {visibility: hidden;}",
            "footer {visibility: hidden;}",
            ".stApp { bottom: 105px; }",
            "</style>",
        ]
    )

    style_div = styles(
        position="fixed",
        left=0,
        bottom=0,
        margin=px(0, 0, 0, 0),
        width=percent(100),
        color="black",
        text_align="center",
        height="auto",
        opacity=1,
    )

    style_hr = styles(
        display="block",
        margin=px(8, 8, "auto", "auto"),
        border_style="inset",
        border_width=px(2),
    )

    body = p()
    foot = div(style=style_div)(hr(style=style_hr), body)

    st.markdown(style, unsafe_allow_html=True)

    for arg in args:
        if isinstance(arg, (str, HtmlElement)):
            body(arg)

    st.markdown(str(foot), unsafe_allow_html=True)
def img_to_bytes(img_path):
    """Return the contents of the file at ``img_path`` as base64 text.

    Args:
        img_path: Path of the file to read, as a string or ``Path``.

    Returns:
        The file's bytes, base64-encoded and decoded to ``str``.
    """
    return base64.b64encode(Path(img_path).read_bytes()).decode()
def footer():
    """Assemble the footer content (logos, link, funding note) and render it.

    Both images are read from the directory that contains this file and are
    embedded in the page as base64 data URIs.
    """
    base_dir = Path(__file__).parent.absolute()
    logo_b64 = img_to_bytes(base_dir / "vocali_logo.jpg")
    funding_b64 = img_to_bytes(base_dir / "logo_funding.png")

    myargs = [
        "Made in ",
        # "image/jpeg" is the registered media type for JPEG data.
        f"<img src='data:image/jpeg;base64,{logo_b64}' class='img-fluid' width='50' height='50'>",
        link("https://vocali.net/", "VÓCALI"),
        " with funding ",
        f"<img src='data:image/png;base64,{funding_b64}' class='img-fluid' width='250' height='50'>",
        br(),
        "This work was funded by the Spanish Government, the Spanish Ministry of Economy and Digital Transformation through the Digital Transformation through the 'Recovery, Transformation and Resilience Plan' and also funded by the European Union NextGenerationEU/PRTR through the research project 2021/C005/0015007",
    ]
    layout(*myargs)
# Selector label -> (Hugging Face model id, language code used for formatting).
_LANGUAGES = {
    "Spanish": ("VOCALINLP/spanish_capitalization_punctuation_restoration_sanivert", "es"),
    "Portuguese": ("VOCALINLP/portuguese_capitalization_punctuation_restoration_sanivert", "pt"),
    "Catalan": ("VOCALINLP/catalan_capitalization_punctuation_restoration_sanivert", "ca"),
}


def _load_pipeline(model_id):
    """Build the token-classification pipeline for ``model_id``."""
    model = AutoModelForTokenClassification.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return pipeline("token-classification", model=model, tokenizer=tokenizer)


# Streamlit re-runs the script on every interaction; caching keeps each model
# from being rebuilt on every rerun. Streamlit versions without
# ``cache_resource`` fall back to loading the selected model on each run.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is not None:
    _load_pipeline = _cache_resource(_load_pipeline)


if __name__ == "__main__":
    if "text" not in st.session_state:
        st.session_state.text = ""

    st.title('Sanivert Punctuation And Capitalization Restoration')

    language = st.selectbox(
        label="Choose a language",
        options=list(_LANGUAGES),
    )

    st.subheader("Enter the text to be analyzed.")
    # ``clear_text`` copies the submitted value into ``session_state.text``
    # and empties the input box.
    st.text_input('Enter text', key='widget', on_change=clear_text)
    text = st.session_state.text

    out = ""
    if text.strip():
        # Only the selected language's model is loaded, and only once there
        # is something to analyse.
        model_id, lang_code = _LANGUAGES[language]
        result_pipe = _load_pipeline(model_id)(text)
        if lang_code == "ca":
            out = get_result_text_ca(result_pipe, text)
        else:
            out = get_result_text_es_pt(result_pipe, text, lang_code)

    st.markdown(out, unsafe_allow_html=True)
    footer()