File size: 4,959 Bytes
b3b8331 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import streamlit as st
from PIL import Image
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
from streamlit_extras.app_logo import add_logo
def logo():
    """Display the Vocali logo image through ``add_logo`` at a height of 300."""
    logo_path = "vocali_logo.jpeg"
    logo_height = 300
    add_logo(logo_path, height=logo_height)
def get_result_text_es_pt(list_entity, text, lang):
    """Rebuild Spanish/Portuguese text with predicted casing and punctuation.

    Args:
        list_entity: Token-classification results. Each item is a dict with
            the keys "entity" (predicted tag), "word" (token string) and
            "start"/"end" (character offsets of the token inside ``text``).
        text: The string the offsets in ``list_entity`` refer to.
        lang: "es" also recognises the opening marks "¿" and "¡" in a tag;
            any other value only recognises '?', '!', ',', '.' and ':'.

    Returns:
        The rebuilt words joined by single spaces, as an HTML fragment.
        A tag ending in "u" capitalizes the word, a tag ending in "l" or "u"
        attaches the punctuation mark found in the tag, and every word that
        received a tag other than plain "l" is wrapped in a bold green
        ``<span>``. The word text itself is HTML-escaped.
    """
    import html  # local import: keeps the function usable on its own

    if lang == "es":
        punc_tags = ['¿', '?', '¡', '!', ',', '.', ':']
    else:
        punc_tags = ['?', '!', ',', '.', ':']
    opening_marks = ("¿", "¡")
    span_open = '<span style="font-weight:bold; color:rgb(142, 208, 129);">'

    # Each word keeps its plain text, its punctuation slots and its highlight
    # flag separately. Markup is generated once at the end, so merging a
    # continuation piece can never edit (or lowercase) already-built HTML.
    words = []
    for entity in list_entity:
        tag = entity["entity"]
        token = entity["word"]
        punc_in = next((p for p in punc_tags if p in tag), "")

        # WordPiece continuation pieces start with "##"; a single "#" typed
        # by the user is an ordinary token of its own.
        is_piece = token.startswith("##")
        if is_piece and words:
            current = words[-1]
            current["core"] += text[entity["start"]:entity["end"]]
        else:
            # A continuation piece with nothing before it starts a new word
            # from the source text instead of failing on an empty list.
            core = text[entity["start"]:entity["end"]] if is_piece else token
            current = {"lead": "", "core": core, "trail": "", "marked": False}
            words.append(current)

        case = tag[-1:]
        if case == "u":
            current["core"] = current["core"].capitalize()
        if punc_in and case in ("l", "u"):
            # Opening marks go in front of the word, everything else behind.
            # A mark predicted on a later piece replaces the one stored for
            # the same edge, so it is never duplicated.
            slot = "lead" if punc_in in opening_marks else "trail"
            current[slot] = punc_in
        if tag != "l":
            current["marked"] = True

    rendered = []
    for item in words:
        # Escape so characters typed by the user cannot be read as HTML.
        plain = html.escape(item["lead"] + item["core"] + item["trail"])
        rendered.append(span_open + plain + "</span>" if item["marked"] else plain)
    return " ".join(rendered)
def get_result_text_ca(list_entity, text):
    """Rebuild Catalan text with predicted casing and punctuation.

    Args:
        list_entity: Token-classification results. Each item is a dict with
            the keys "entity" (predicted tag), "word" (token string) and
            "start"/"end" (character offsets of the token inside ``text``).
        text: The string the offsets in ``list_entity`` refer to; every word
            is rebuilt from slices of this string.

    Returns:
        The rebuilt words joined by single spaces, as an HTML fragment.
        A tag ending in "u" capitalizes the word, a tag ending in "l" or "u"
        appends the punctuation mark found in the tag, and every word that
        received a tag other than plain "l" is wrapped in a bold green
        ``<span>``. The word text itself is HTML-escaped.
    """
    import html  # local import: keeps the function usable on its own

    punc_tags = ['?', '!', ',', '.', ':']
    span_open = '<span style="font-weight:bold; color:rgb(142, 208, 129);">'

    # Each word keeps its plain text, its trailing mark and its highlight
    # flag separately. Markup is generated once at the end, so merging a
    # continuation piece can never edit (or lowercase) already-built HTML,
    # and punctuation typed by the user is never stripped from the word.
    words = []
    for entity in list_entity:
        tag = entity["entity"]
        token = entity["word"]
        piece = text[entity["start"]:entity["end"]]
        punc_in = next((p for p in punc_tags if p in tag), "")

        # Tokens that begin a new word carry the "Ġ" prefix; anything else
        # continues the previous word. A token without the prefix that has no
        # previous word (e.g. the very first token) starts a word itself.
        continues_word = not token.startswith("Ġ")
        if continues_word and words:
            current = words[-1]
            current["core"] += piece
        else:
            current = {"core": piece, "trail": "", "marked": False}
            words.append(current)

        case = tag[-1:]
        if case == "u":
            current["core"] = current["core"].capitalize()
        if punc_in and case in ("l", "u"):
            # A mark predicted on a later piece replaces the stored one, so
            # it is never duplicated.
            current["trail"] = punc_in
        if tag != "l":
            current["marked"] = True

    rendered = []
    for item in words:
        # Escape so characters typed by the user cannot be read as HTML.
        plain = html.escape(item["core"] + item["trail"])
        rendered.append(span_open + plain + "</span>" if item["marked"] else plain)
    return " ".join(rendered)
if __name__ == "__main__":
    # Streamlit entry point: pick a language, enter text, and show the text
    # with restored capitalization and punctuation.
    logo()
    st.title('Sanivert Punctuation And Capitalization Restoration')

    # Hugging Face Hub checkpoint for each selectable language.
    checkpoints = {
        "Spanish": "VOCALINLP/spanish_capitalization_punctuation_restoration_sanivert",
        "Portuguese": "VOCALINLP/portuguese_capitalization_punctuation_restoration_sanivert",
        "Catalan": "VOCALINLP/catalan_capitalization_punctuation_restoration_sanivert",
    }

    input_text = st.selectbox(
        label="Choose an language",
        options=["Spanish", "Portuguese", "Catalan"],
    )
    st.subheader("Enter the text to be analyzed.")
    text = st.text_input('Enter text')  # text is stored in this variable

    # Nothing to analyze until the user has typed something.
    if text.strip():
        # Load only the checkpoint for the selected language; each language
        # uses its own model and tokenizer.
        checkpoint = checkpoints[input_text]
        model = AutoModelForTokenClassification.from_pretrained(checkpoint)
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        pipe = pipeline("token-classification", model=model, tokenizer=tokenizer)
        result_pipe = pipe(text)

        if input_text == "Spanish":
            out = get_result_text_es_pt(result_pipe, text, "es")
        elif input_text == "Portuguese":
            out = get_result_text_es_pt(result_pipe, text, "pt")
        else:
            out = get_result_text_ca(result_pipe, text)

        # The helpers return an HTML fragment (highlight spans), so raw HTML
        # rendering has to be enabled here.
        st.markdown(out, unsafe_allow_html=True)
|