jcg00v commited on
Commit
9e95735
·
verified ·
1 Parent(s): 67037c0

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -34
app.py CHANGED
@@ -3,10 +3,9 @@ from PIL import Image
3
  from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
4
  from streamlit_extras.app_logo import add_logo
5
 
6
-
7
  def logo():
8
- add_logo("vocali_logo.jpg", height=300)
9
-
10
 
11
  def get_result_text_es_pt (list_entity, text, lang):
12
  result_words = []
@@ -45,7 +44,7 @@ def get_result_text_es_pt (list_entity, text, lang):
45
  word = (punc_in + word.capitalize()) if punc_in in ["¿", "¡"] else (word.capitalize() + punc_in)
46
 
47
  if tag != "l":
48
- word = '<span style="font-weight:bold; color:rgb(142, 208, 129);">' + word + '</span>''
49
 
50
  if subword == True:
51
  result_words[-1] = word
@@ -90,6 +89,7 @@ def get_result_text_ca (list_entity, text):
90
  word = (punc_in + word) if punc_in in ["¿", "¡"] else (word + punc_in)
91
  elif tag[-1] == "u":
92
  word = (punc_in + word.capitalize()) if punc_in in ["¿", "¡"] else (word.capitalize() + punc_in)
 
93
  if tag != "l":
94
  word = '<span style="font-weight:bold; color:rgb(142, 208, 129);">' + word + '</span>'
95
 
@@ -99,44 +99,41 @@ def get_result_text_ca (list_entity, text):
99
  result_words.append(word)
100
 
101
  return " ".join(result_words)
102
-
103
 
104
  if __name__ == "__main__":
105
- logo()
106
- st.title('Sanivert Punctuation And Capitalization Restoration')
107
-
108
- model_es = AutoModelForTokenClassification.from_pretrained("VOCALINLP/spanish_capitalization_punctuation_restoration_sanivert")
109
  tokenizer_es = AutoTokenizer.from_pretrained("VOCALINLP/spanish_capitalization_punctuation_restoration_sanivert")
110
- pipe_es = pipeline("token-classification", model=model_es, tokenizer=tokenizer_es)
111
 
112
- model_ca = ModelForTokenClassification.from_pretrained("VOCALINLP/catalan_capitalization_punctuation_restoration_sanivert")
113
  tokenizer_ca = AutoTokenizer.from_pretrained("VOCALINLP/catalan_capitalization_punctuation_restoration_sanivert")
114
- pipe_ca = pipeline("token-classification", model=model_ca, tokenizer=tokenizer_ca)
115
 
116
- model_pt = AutoModelForTokenClassification.from_pretrained("VOCALINLP/portuguese_capitalization_punctuation_restoration_sanivert")
117
  tokenizer_pt = AutoTokenizer.from_pretrained("VOCALINLP/portuguese_capitalization_punctuation_restoration_sanivert")
118
- pipe_pt = pipeline("token-classification", model=model_ca, tokenizer=tokenizer_ca)
119
 
120
- input_text = st.selectbox(
121
  label = "Choose an language",
122
  options = ["Spanish", "Portuguese", "Catalan"]
123
  )
124
-
125
- st.subheader("Enter the text to be analyzed.")
126
- text = st.text_input('Enter text') #text is stored in this variable
127
- if input_text == "Spanish":
128
- result_pipe = pipe_es(text)
129
- out = get_result_text_es_pt(result_pipe, text, "es")
130
- elif input_text == "Portuguese":
131
- result_pipe = pipe_pt(text)
132
- out = get_result_text_es_pt(result_pipe, text, "pt")
133
- elif input_text == "Catalan":
134
- result_pipe = pipe_ca(text)
135
- out = get_result_text_ca(result_pipe, text)
136
-
137
- out = get_prediction(text, input_text)
138
- st.markdown(out, unsafe_allow_html=True)
139
- text = ""
140
-
141
-
142
-
 
3
  from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
4
  from streamlit_extras.app_logo import add_logo
5
 
 
6
  def logo():
7
+ add_logo("vocali_logo.jpeg", height=300)
8
+
9
 
10
  def get_result_text_es_pt (list_entity, text, lang):
11
  result_words = []
 
44
  word = (punc_in + word.capitalize()) if punc_in in ["¿", "¡"] else (word.capitalize() + punc_in)
45
 
46
  if tag != "l":
47
+ word = '<span style="font-weight:bold; color:rgb(142, 208, 129);">' + word + '</span>'
48
 
49
  if subword == True:
50
  result_words[-1] = word
 
89
  word = (punc_in + word) if punc_in in ["¿", "¡"] else (word + punc_in)
90
  elif tag[-1] == "u":
91
  word = (punc_in + word.capitalize()) if punc_in in ["¿", "¡"] else (word.capitalize() + punc_in)
92
+
93
  if tag != "l":
94
  word = '<span style="font-weight:bold; color:rgb(142, 208, 129);">' + word + '</span>'
95
 
 
99
  result_words.append(word)
100
 
101
  return " ".join(result_words)
102
+
103
 
104
  if __name__ == "__main__":
105
+
106
+ logo()
107
+ st.title('Sanivert Punctuation And Capitalization Restoration')
108
+ model_es = AutoModelForTokenClassification.from_pretrained("VOCALINLP/spanish_capitalization_punctuation_restoration_sanivert")
109
  tokenizer_es = AutoTokenizer.from_pretrained("VOCALINLP/spanish_capitalization_punctuation_restoration_sanivert")
110
+ pipe_es = pipeline("token-classification", model=model_es, tokenizer=tokenizer_es)
111
 
112
+ model_ca = AutoModelForTokenClassification.from_pretrained("VOCALINLP/catalan_capitalization_punctuation_restoration_sanivert")
113
  tokenizer_ca = AutoTokenizer.from_pretrained("VOCALINLP/catalan_capitalization_punctuation_restoration_sanivert")
114
+ pipe_ca = pipeline("token-classification", model=model_ca, tokenizer=tokenizer_ca)
115
 
116
+ model_pt = AutoModelForTokenClassification.from_pretrained("VOCALINLP/portuguese_capitalization_punctuation_restoration_sanivert")
117
  tokenizer_pt = AutoTokenizer.from_pretrained("VOCALINLP/portuguese_capitalization_punctuation_restoration_sanivert")
118
+ pipe_pt = pipeline("token-classification", model=model_ca, tokenizer=tokenizer_ca)
119
 
120
+ input_text = st.selectbox(
121
  label = "Choose an language",
122
  options = ["Spanish", "Portuguese", "Catalan"]
123
  )
124
+
125
+ st.subheader("Enter the text to be analyzed.")
126
+ text = st.text_input('Enter text') #text is stored in this variable
127
+ if input_text == "Spanish":
128
+ result_pipe = pipe_es(text)
129
+ out = get_result_text_es_pt(result_pipe, text, "es")
130
+ elif input_text == "Portuguese":
131
+ result_pipe = pipe_pt(text)
132
+ out = get_result_text_es_pt(result_pipe, text, "pt")
133
+ elif input_text == "Catalan":
134
+ result_pipe = pipe_ca(text)
135
+ out = get_result_text_ca(result_pipe, text)
136
+
137
+ out = get_prediction(text, input_text)
138
+ st.markdown(out, unsafe_allow_html=True)
139
+ text = ""