|
import streamlit as st |
|
from PIL import Image |
|
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer |
|
from htbuilder import HtmlElement, div, ul, li, br, hr, a, p, img, styles, classes, fonts |
|
from htbuilder.units import percent, px |
|
from htbuilder.funcs import rgba, rgb |
|
from pathlib import Path |
|
import base64 |
|
|
|
def clear_text():
    """Store the submitted input in ``session_state.text`` and blank the widget.

    Used as the ``on_change`` callback of the text input keyed ``"widget"``.
    Takes no arguments and returns nothing; it only mutates
    ``st.session_state``.
    """
    state = st.session_state
    # Copy before clearing: the widget value is overwritten on the next line.
    state.text = state.widget
    state.widget = ""
|
|
|
|
|
def get_result_text_es_pt(list_entity, text, lang):
    """Rebuild ``text`` with the predicted capitalization and punctuation.

    Args:
        list_entity: Prediction dicts, one per token, each with the keys
            ``"entity"`` (the tag), ``"word"`` (the token string) and
            ``"start"`` / ``"end"`` (character offsets of the token in
            ``text``).
        text: The string the offsets in ``list_entity`` refer to.
        lang: ``"es"`` additionally recognises the opening marks ``¿`` and
            ``¡``; any other value uses closing punctuation only.

    Returns:
        The rebuilt words joined by single spaces, as an HTML fragment. Every
        word whose tag is not exactly ``"l"`` is wrapped in a bold green
        ``<span>``. An empty ``list_entity`` yields ``""``.
    """
    # Local import so this function does not depend on the module header.
    from html import escape

    span_open = '<span style="font-weight:bold; color:rgb(142, 208, 129);">'
    opening_marks = ("¿", "¡")
    if lang == "es":
        punc_tags = ("¿", "?", "¡", "!", ",", ".", ":")
    else:
        punc_tags = ("?", "!", ",", ".", ":")

    result_words = []
    raw_word = ""  # source text of the word currently being assembled

    for entity in list_entity:
        tag = entity["entity"]
        piece = text[entity["start"]:entity["end"]]

        # First punctuation mark contained in the tag, or "" when none is.
        punc_in = next((mark for mark in punc_tags if mark in tag), "")

        # A "#"-prefixed token continues the previous word. A continuation
        # piece with no previous word (it is the first token) starts a word
        # instead of indexing into the still-empty result list.
        is_continuation = entity["word"].startswith("#") and bool(result_words)
        raw_word = raw_word + piece if is_continuation else piece

        # The last character of the tag selects the case: "u" capitalizes,
        # "l" keeps the text as typed. Other tags leave the word unchanged.
        cased = raw_word.capitalize() if tag.endswith("u") else raw_word

        # The result is rendered as HTML, so markup characters coming from
        # the input text are escaped before any tags are added around them.
        word = escape(cased, quote=False)

        if tag.endswith(("l", "u")):
            # Opening marks go in front of the word, everything else after.
            if punc_in in opening_marks:
                word = punc_in + word
            else:
                word = word + punc_in

        if tag != "l":
            word = span_open + word + "</span>"

        if is_continuation:
            # The merged word replaces the partial one emitted earlier; the
            # tag of the latest piece decides the final formatting.
            result_words[-1] = word
        else:
            result_words.append(word)

    return " ".join(result_words)
|
|
|
|
|
|
|
def get_result_text_ca(list_entity, text):
    """Rebuild ``text`` with the predicted capitalization and punctuation.

    Variant for tokens where a leading ``"Ġ"`` marks the start of a word and
    every other token continues the previous word.

    Args:
        list_entity: Prediction dicts, one per token, each with the keys
            ``"entity"`` (the tag), ``"word"`` (the token string) and
            ``"start"`` / ``"end"`` (character offsets of the token in
            ``text``).
        text: The string the offsets in ``list_entity`` refer to.

    Returns:
        The rebuilt words joined by single spaces, as an HTML fragment. Every
        word whose tag is not exactly ``"l"`` is wrapped in a bold green
        ``<span>``. An empty ``list_entity`` yields ``""``.
    """
    # Local import so this function does not depend on the module header.
    from html import escape

    span_open = '<span style="font-weight:bold; color:rgb(142, 208, 129);">'
    punc_tags = ("?", "!", ",", ".", ":")

    result_words = []
    raw_word = ""  # source text of the word currently being assembled

    for entity in list_entity:
        tag = entity["entity"]
        piece = text[entity["start"]:entity["end"]]

        # First punctuation mark contained in the tag, or "" when none is.
        punc_in = next((mark for mark in punc_tags if mark in tag), "")

        # Tokens without the "Ġ" prefix continue the previous word. When there
        # is no previous word yet (first token), the piece starts a word
        # instead of indexing into the still-empty result list.
        is_continuation = (
            not entity["word"].startswith("Ġ") and bool(result_words)
        )
        raw_word = raw_word + piece if is_continuation else piece

        # The last character of the tag selects the case: "u" capitalizes,
        # "l" keeps the text as typed. Other tags leave the word unchanged.
        cased = raw_word.capitalize() if tag.endswith("u") else raw_word

        # The result is rendered as HTML, so markup characters coming from
        # the input text are escaped before any tags are added around them.
        word = escape(cased, quote=False)

        if tag.endswith(("l", "u")):
            # Only closing marks are recognised here, so they always trail.
            word = word + punc_in

        if tag != "l":
            word = span_open + word + "</span>"

        if is_continuation:
            # The merged word replaces the partial one emitted earlier; the
            # tag of the latest piece decides the final formatting.
            result_words[-1] = word
        else:
            result_words.append(word)

    return " ".join(result_words)
|
|
|
|
|
def link(link, text, **style):
    """Build an anchor element that opens ``link`` in a new tab.

    Args:
        link: Value for the anchor's ``href`` attribute. Inside this function
            the parameter shadows the function's own name.
        text: Child content placed inside the anchor.
        **style: Keyword arguments forwarded to ``styles`` to form the
            element's ``style`` attribute.

    Returns:
        The anchor element returned by calling the ``a`` builder with ``text``.
    """
    inline_style = styles(**style)
    anchor = a(_href=link, _target="_blank", style=inline_style)
    return anchor(text)
|
|
|
def layout(*args):
    """Render a footer pinned to the bottom of the page.

    Emits two ``st.markdown`` calls with ``unsafe_allow_html=True``: first a
    small stylesheet, then a fixed-position ``<div>`` holding a horizontal
    rule followed by a paragraph with the given items.

    Args:
        *args: Footer content in display order. Strings and ``HtmlElement``
            instances are appended to the paragraph; any other value is
            ignored.
    """
    # "#MainMenu" is an ID selector and must not have a space after "#";
    # written as "# MainMenu" the rule is not a valid selector and the menu
    # is never hidden.
    style = """
    <style>
      #MainMenu {visibility: hidden;}
      footer {visibility: hidden;}
      .stApp { bottom: 105px; }
    </style>
    """

    style_div = styles(
        position="fixed",
        left=0,
        bottom=0,
        margin=px(0, 0, 0, 0),
        width=percent(100),
        color="black",
        text_align="center",
        height="auto",
        opacity=1,
    )

    style_hr = styles(
        display="block",
        margin=px(8, 8, "auto", "auto"),
        border_style="inset",
        border_width=px(2),
    )

    # ``body`` is nested inside ``foot`` and filled in below, before ``foot``
    # is serialised.
    body = p()
    foot = div(style=style_div)(hr(style=style_hr), body)

    st.markdown(style, unsafe_allow_html=True)

    for arg in args:
        if isinstance(arg, (str, HtmlElement)):
            body(arg)

    st.markdown(str(foot), unsafe_allow_html=True)
|
|
|
|
|
def img_to_bytes(img_path):
    """Return the contents of the file at ``img_path`` as base64 text.

    Args:
        img_path: Location of the file to read; passed to ``pathlib.Path``.

    Returns:
        The base64 encoding of the file's bytes, decoded to ``str``.
    """
    source = Path(img_path)
    return base64.b64encode(source.read_bytes()).decode()
|
|
|
|
|
def footer():
    """Assemble the footer content and hand it to ``layout``.

    Reads ``vocali_logo.jpg`` and ``logo_funding.png`` from the directory of
    this file, embeds them as base64 ``data:`` images, and combines them with
    a link to the VÓCALI site and the funding statement.
    """
    # Both images sit next to this file, so the directory is resolved once.
    base_dir = str(Path(__file__).parent.absolute())
    logo_b64 = img_to_bytes(base_dir + "/vocali_logo.jpg")
    funding_b64 = img_to_bytes(base_dir + "/logo_funding.png")

    logo_tag = "<img src='data:image/jpg;base64,{}' class='img-fluid' width='50' height='50'>".format(
        logo_b64
    )
    funding_tag = "<img src='data:image/png;base64,{}' class='img-fluid' width='250' height='50'>".format(
        funding_b64
    )

    layout(
        "Made in ",
        logo_tag,
        link("https://vocali.net/", "VÓCALI"),
        " with funding ",
        funding_tag,
        br(),
        "This work was funded by the Spanish Government, the Spanish Ministry of Economy and Digital Transformation through the Digital Transformation through the 'Recovery, Transformation and Resilience Plan' and also funded by the European Union NextGenerationEU/PRTR through the research project 2021/C005/0015007",
    )
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: builds the page, runs the model for the selected
    # language on the last submitted text and renders the highlighted result.

    if "text" not in st.session_state:
        st.session_state.text = ""

    st.title('Sanivert Punctuation And Capitalization Restoration')

    # Display name -> (Hugging Face model id, formatter for its predictions).
    # The key order is the order of the options shown in the select box.
    languages = {
        "Spanish": (
            "VOCALINLP/spanish_capitalization_punctuation_restoration_sanivert",
            lambda entities, raw: get_result_text_es_pt(entities, raw, "es"),
        ),
        "Portuguese": (
            "VOCALINLP/portuguese_capitalization_punctuation_restoration_sanivert",
            lambda entities, raw: get_result_text_es_pt(entities, raw, "pt"),
        ),
        "Catalan": (
            "VOCALINLP/catalan_capitalization_punctuation_restoration_sanivert",
            lambda entities, raw: get_result_text_ca(entities, raw),
        ),
    }

    # Streamlit re-executes this script on every interaction. Caching keeps a
    # loaded pipeline alive across reruns; on Streamlit versions without
    # ``st.cache_resource`` the loader simply runs uncached.
    cache_resource = getattr(st, "cache_resource", lambda func: func)

    @cache_resource
    def load_pipeline(model_id):
        """Build the token-classification pipeline for ``model_id``."""
        model = AutoModelForTokenClassification.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        return pipeline("token-classification", model=model, tokenizer=tokenizer)

    input_text = st.selectbox(
        label="Choose a language",
        options=list(languages),
    )

    st.subheader("Enter the text to be analyzed.")
    st.text_input('Enter text', key='widget', on_change=clear_text)
    text = st.session_state.text

    # Only the model for the selected language is loaded, not all three.
    model_id, render = languages[input_text]
    pipe = load_pipeline(model_id)

    # Nothing has been submitted yet when ``text`` is empty; skip inference.
    result_pipe = pipe(text) if text else []
    out = render(result_pipe, text)

    st.markdown(out, unsafe_allow_html=True)
    footer()