import base64
from pathlib import Path

import pandas as pd
import streamlit as st
from PIL import Image
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
from htbuilder import HtmlElement, div, ul, li, br, hr, a, p, img, styles, classes, fonts
from htbuilder.units import percent, px
from htbuilder.funcs import rgba, rgb

# Hugging Face model repository used for each language offered in the UI.
MODEL_NAMES = {
    "Spanish": "VOCALINLP/spanish_capitalization_punctuation_restoration_sanivert",
    "Portuguese": "VOCALINLP/portuguese_capitalization_punctuation_restoration_sanivert",
    "Catalan": "VOCALINLP/catalan_capitalization_punctuation_restoration_sanivert",
}

# Punctuation marks written *before* the word they belong to (Spanish-style
# inverted marks); every other mark is appended after the word.
OPENING_PUNCTUATION = ("¿", "¡")

# NOTE(review): the markup wrapped around every re-cased/re-punctuated word is
# empty, so highlighting is currently a no-op. The result is rendered with
# unsafe_allow_html=True, so HTML tags were presumably intended here — confirm.
HIGHLIGHT_OPEN = ""
HIGHLIGHT_CLOSE = ""

# NOTE(review): this template is empty, so the base64-encoded logos read in
# footer() are discarded and never displayed. Presumably an <img> tag with a
# "{}" placeholder for the data URI was intended — confirm.
LOGO_TEMPLATE = ""


def clear_text():
    """Move the submitted text-box value into ``st.session_state.text``.

    Used as the ``on_change`` callback of the text input keyed ``"widget"``;
    it also empties the text box so the next sentence can be typed.
    """
    st.session_state.text = st.session_state.widget
    st.session_state.widget = ""


def _highlight(word):
    """Wrap a word whose casing/punctuation was changed in the highlight markup."""
    return HIGHLIGHT_OPEN + word + HIGHLIGHT_CLOSE


def _apply_tag(word, tag, punc_in):
    """Return ``word`` with the casing and punctuation encoded by ``tag``.

    ``"l"`` leaves the word as typed and ``"u"`` capitalizes it. Any other tag
    is treated as carrying punctuation ``punc_in``; its last character is then
    read as the case marker (``"l"`` or ``"u"``). Tags matching neither case
    leave the word unchanged.
    """
    if tag == "l":
        return word
    if tag == "u":
        return word.capitalize()
    if tag[-1] == "l":
        cased = word
    elif tag[-1] == "u":
        cased = word.capitalize()
    else:
        return word
    if punc_in in OPENING_PUNCTUATION:
        return punc_in + cased
    return cased + punc_in


def _restore_text(list_entity, text, punc_tags, is_continuation):
    """Rebuild ``text`` applying the per-token tags predicted by the model.

    Args:
        list_entity: token predictions; each dict must provide ``"entity"``
            (the tag), ``"word"`` (the token string) and ``"start"``/``"end"``
            character offsets into ``text``.
        text: the raw input text the predictions refer to.
        punc_tags: punctuation marks to look for inside a tag, in priority
            order (the first one contained in the tag wins).
        is_continuation: predicate telling whether a token string continues
            the previous word (tokenizer-specific subword marker).

    Returns:
        The restored words joined by single spaces.
    """
    result_words = []
    tmp_word = ""
    for idx, entity in enumerate(list_entity):
        tag = entity["entity"]
        start, end = entity["start"], entity["end"]
        punc_in = next((mark for mark in punc_tags if mark in tag), "")

        # The first token can never continue a previous word: there is none,
        # and treating it as one would index list_entity[-1] and then assign
        # into an empty result list.
        subword = idx > 0 and is_continuation(entity["word"])
        if subword:
            if tmp_word == "":
                # Start the merged word from the previous token's raw text.
                prev = list_entity[idx - 1]
                tmp_word = text[prev["start"]:prev["end"]]
            tmp_word += text[start:end]
            word = tmp_word
        else:
            tmp_word = ""
            word = text[start:end]

        # The merged word takes the tag of its last piece.
        word = _apply_tag(word, tag, punc_in)
        if tag != "l":
            word = _highlight(word)

        if subword:
            result_words[-1] = word
        else:
            result_words.append(word)
    return " ".join(result_words)


def get_result_text_es_pt(list_entity, text, lang):
    """Restore capitalization/punctuation for Spanish or Portuguese output.

    Subwords are detected with the WordPiece continuation prefix ``"##"``.
    ``lang == "es"`` additionally recognizes the opening marks ``¿`` and ``¡``.
    """
    if lang == "es":
        punc_tags = ['¿', '?', '¡', '!', ',', '.', ':']
    else:
        punc_tags = ['?', '!', ',', '.', ':']
    return _restore_text(
        list_entity, text, punc_tags, lambda token: token.startswith("##")
    )


def get_result_text_ca(list_entity, text):
    """Restore capitalization/punctuation for Catalan output.

    Tokens that do not start with the byte-level BPE space marker ``"Ġ"`` are
    treated as continuations of the previous word.
    """
    punc_tags = ['?', '!', ',', '.', ':']
    return _restore_text(
        list_entity, text, punc_tags, lambda token: not token.startswith("Ġ")
    )


def link(link, text, **style):
    """Return an htbuilder anchor opening ``link`` in a new tab."""
    return a(_href=link, _target="_blank", style=styles(**style))(text)


def layout(*args):
    """Render ``args`` as a footer fixed to the bottom of the page.

    Strings and htbuilder elements are appended to the footer paragraph;
    arguments of any other type are ignored.
    """
    # NOTE(review): this CSS block is effectively empty — confirm whether
    # style rules were meant to be injected here.
    style = """ """
    style_div = styles(
        position="fixed",
        left=0,
        bottom=0,
        margin=px(0, 0, 0, 0),
        width=percent(100),
        color="black",
        text_align="center",
        height="auto",
        opacity=1,
    )
    style_hr = styles(
        display="block",
        margin=px(8, 8, "auto", "auto"),
        border_style="inset",
        border_width=px(2),
    )
    body = p()
    foot = div(style=style_div)(hr(style=style_hr), body)
    st.markdown(style, unsafe_allow_html=True)
    for arg in args:
        if isinstance(arg, (str, HtmlElement)):
            body(arg)
    st.markdown(str(foot), unsafe_allow_html=True)


def img_to_bytes(img_path):
    """Return the file at ``img_path`` encoded as a base64 ASCII string."""
    img_bytes = Path(img_path).read_bytes()
    return base64.b64encode(img_bytes).decode()


def footer():
    """Render the footer with the logos and the funding acknowledgement.

    The logo files are read from the directory containing this script.
    """
    logo_path = Path(__file__).with_name("vocali_logo.jpg")
    funding_path = Path(__file__).with_name("logo_funding.png")
    myargs = [
        "Made in ",
        LOGO_TEMPLATE.format(img_to_bytes(logo_path)),
        " with funding ",
        LOGO_TEMPLATE.format(img_to_bytes(funding_path)),
        br(),
        "This work was funded by the Spanish Government, the Spanish Ministry of "
        "Economy and Digital Transformation through the Digital Transformation "
        "through the 'Recovery, Transformation and Resilience Plan' and also funded "
        "by the European Union NextGenerationEU/PRTR through the research project "
        "2021/C005/0015007",
    ]
    layout(*myargs)


def _cache_resource(func):
    """Cache ``func`` across Streamlit reruns when ``st.cache_resource`` exists.

    On Streamlit versions without ``st.cache_resource`` the function is
    returned unchanged (no caching), so the app still runs.
    """
    cache = getattr(st, "cache_resource", None)
    return cache(func) if cache is not None else func


@_cache_resource
def load_pipeline(model_name):
    """Build a token-classification pipeline for the Hugging Face ``model_name``."""
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return pipeline("token-classification", model=model, tokenizer=tokenizer)


def main():
    """Render the demo page and restore the text submitted by the user."""
    if "text" not in st.session_state:
        st.session_state.text = ""

    st.title('Sanivert Punctuation And Capitalization Restoration')
    st.markdown("The model restores the following punctuation -- [? ! , . :] and also the capitalization of words.")

    st.subheader('Text examples in Spanish')
    data_spanish = [
        ['has tenido alguna enfermedad en la última semana', '¿Has tenido alguna enfermedad en la última semana?'],
        ['sufre la enfermedad de parkinson', 'Sufre la enfermedad de Parkinson'],
        ['el paciente presenta los siguientes síntomas náuseas vértigo disnea fiebre y dolor abdominal', 'El paciente presenta los siguientes síntomas: náuseas, vértigo, disnea, fiebre y dolor abdominal.'],
    ]
    st.table(pd.DataFrame(data_spanish, columns=['Input', 'Output']))

    st.subheader('Text examples in Catalan')
    data = [
        ['has tingut alguna malaltia a la darrera setmana', 'Has tingut alguna malaltia a la darrera setmana?'],
        ['pateix la malaltia de parkinson', 'Pateix la malaltia de Parkinson.'],
        ["pacient presenta els següents símptomes nàusees vertigen dispnea febre i dolor abdominal", "Pacient presenta els següents símptomes: nàusees, vertigen, dispnea, febre i dolor abdominal."],
    ]
    st.table(pd.DataFrame(data, columns=['Input', 'Output']))

    st.subheader('Text examples in Portuguese')
    data_pt = [
        ['sofre da doença de parkinson', 'Sofre da doença de parkinson?'],
        ['teve alguma doença na última semana', 'Teve alguma doença na última semana?'],
        ['o doente apresenta os seguintes sintomas náuseas vertigens dispneia febre e dor abdominal', 'O doente apresenta os seguintes sintomas: náuseas, vertigens, dispneia, febre e dor abdominal.'],
    ]
    st.table(pd.DataFrame(data_pt, columns=['Input', 'Output']))

    language = st.selectbox(
        label="Choose a language",
        options=["Spanish", "Portuguese", "Catalan"],
    )

    st.subheader("Enter the text to be analyzed.")
    st.text_input('Enter text', key='widget', on_change=clear_text)
    # clear_text() copies the submitted value here and empties the widget.
    text = st.session_state.text

    # Skip inference until something has been submitted (e.g. first page load).
    if text:
        # Only the selected language's model is loaded (and cached if possible).
        result_pipe = load_pipeline(MODEL_NAMES[language])(text)
        if language == "Catalan":
            out = get_result_text_ca(result_pipe, text)
        else:
            lang = "es" if language == "Spanish" else "pt"
            out = get_result_text_es_pt(result_pipe, text, lang)
        st.markdown(out, unsafe_allow_html=True)

    footer()


if __name__ == "__main__":
    main()