yeshpanovrustem committed
Commit ab2661e · 1 Parent(s): 2da2475

Update app.py

Files changed (1)
  1. app.py +207 -102

app.py CHANGED
@@ -1,106 +1,211 @@
 
  from nltk.tokenize import word_tokenize
  import streamlit as st
  import torch
- from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
 
- # use @st.cache decorator to cache model because it is too large, we do not want to reload it every time
- # use allow_output_mutation = True to tell streamlit that model should be treated as immutable object — singleton
- @st.cache(allow_output_mutation = True)
-
- # load model and tokenizer
- tokenizer = AutoTokenizer.from_pretrained("yeshpanovrustem/xlm-roberta-large-ner-kazakh")
- model = AutoModelForTokenClassification.from_pretrained("yeshpanovrustem/xlm-roberta-large-ner-kazakh")
-
- labels_dict = {0: 'O',
-                1: 'B-ADAGE',
-                2: 'I-ADAGE',
-                3: 'B-ART',
-                4: 'I-ART',
-                5: 'B-CARDINAL',
-                6: 'I-CARDINAL',
-                7: 'B-CONTACT',
-                8: 'I-CONTACT',
-                9: 'B-DATE',
-                10: 'I-DATE',
-                11: 'B-DISEASE',
-                12: 'I-DISEASE',
-                13: 'B-EVENT',
-                14: 'I-EVENT',
-                15: 'B-FACILITY',
-                16: 'I-FACILITY',
-                17: 'B-GPE',
-                18: 'I-GPE',
-                19: 'B-LANGUAGE',
-                20: 'I-LANGUAGE',
-                21: 'B-LAW',
-                22: 'I-LAW',
-                23: 'B-LOCATION',
-                24: 'I-LOCATION',
-                25: 'B-MISCELLANEOUS',
-                26: 'I-MISCELLANEOUS',
-                27: 'B-MONEY',
-                28: 'I-MONEY',
-                29: 'B-NON_HUMAN',
-                30: 'I-NON_HUMAN',
-                31: 'B-NORP',
-                32: 'I-NORP',
-                33: 'B-ORDINAL',
-                34: 'I-ORDINAL',
-                35: 'B-ORGANISATION',
-                36: 'I-ORGANISATION',
-                37: 'B-PERSON',
-                38: 'I-PERSON',
-                39: 'B-PERCENTAGE',
-                40: 'I-PERCENTAGE',
-                41: 'B-POSITION',
-                42: 'I-POSITION',
-                43: 'B-PRODUCT',
-                44: 'I-PRODUCT',
-                45: 'B-PROJECT',
-                46: 'I-PROJECT',
-                47: 'B-QUANTITY',
-                48: 'I-QUANTITY',
-                49: 'B-TIME',
-                50: 'I-TIME'}
-
- # # define function for ner
- # def label_sentence(text):
- # load pipeline
- nlp = pipeline("ner", model = model, tokenizer = tokenizer)
- example = "Қазақстан Республикасы — Шығыс Еуропа мен Орталық Азияда орналасқан мемлекет."
-
- single_sentence_tokens = word_tokenize(example)
- tokenized_input = tokenizer(single_sentence_tokens, is_split_into_words = True, return_tensors = "pt")
- tokens = tokenized_input.tokens()
- output = model(**tokenized_input).logits
- predictions = torch.argmax(output, dim = 2)
-
- # convert label IDs to label names
- word_ids = tokenized_input.word_ids(batch_index = 0)
- # print(count, word_ids)
- previous_word_id = None
- labels = []
- for token, word_id, prediction in zip(tokens, word_ids, predictions[0].numpy()):
-     # # Special tokens have a word id that is None. We set the label to -100 so they are
-     # # automatically ignored in the loss function.
-     # print(token, word_id, prediction)
-     if word_id is None or word_id == previous_word_id:
-         continue
-     elif word_id != previous_word_id:
-         labels.append(labels_dict[prediction])
-     previous_word_id = word_id
- # print(len(sentence_tokens), sentence_tokens)
- # print(len(labels), labels)
- assert len(single_sentence_tokens) == len(labels), "Mismatch between input token and label sizes!"
-
- for token, label in zip(single_sentence_tokens, labels):
-     print(token, label)
-
-
-
- # st.markdown("# Hello")
- # # st.set_page_config(page_title = "Kazakh Named Entity Recognition", page_icon = "🔍")
- # # st.title("🔍 Kazakh Named Entity Recognition")
-
- # x = st.slider('Select a value')
- # st.write(x, 'squared is', x * x)
+ from annotated_text import annotated_text, parameters, annotation
  from nltk.tokenize import word_tokenize
+ from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
  import streamlit as st
  import torch
 
+ # add the caching decorator and use custom text for spinner
+ @st.cache_resource(show_spinner = "Loading the model...")
+
+ def label_text(text):
+     if text != "":
+         tokenizer = AutoTokenizer.from_pretrained("yeshpanovrustem/xlm-roberta-large-ner-kazakh")
+         model = AutoModelForTokenClassification.from_pretrained("yeshpanovrustem/xlm-roberta-large-ner-kazakh")
+         nlp = pipeline("ner", model = model, tokenizer = tokenizer)
+
+         labels_dict = {0: 'O',
+                        1: 'B-ADAGE',
+                        2: 'I-ADAGE',
+                        3: 'B-ART',
+                        4: 'I-ART',
+                        5: 'B-CARDINAL',
+                        6: 'I-CARDINAL',
+                        7: 'B-CONTACT',
+                        8: 'I-CONTACT',
+                        9: 'B-DATE',
+                        10: 'I-DATE',
+                        11: 'B-DISEASE',
+                        12: 'I-DISEASE',
+                        13: 'B-EVENT',
+                        14: 'I-EVENT',
+                        15: 'B-FACILITY',
+                        16: 'I-FACILITY',
+                        17: 'B-GPE',
+                        18: 'I-GPE',
+                        19: 'B-LANGUAGE',
+                        20: 'I-LANGUAGE',
+                        21: 'B-LAW',
+                        22: 'I-LAW',
+                        23: 'B-LOCATION',
+                        24: 'I-LOCATION',
+                        25: 'B-MISCELLANEOUS',
+                        26: 'I-MISCELLANEOUS',
+                        27: 'B-MONEY',
+                        28: 'I-MONEY',
+                        29: 'B-NON_HUMAN',
+                        30: 'I-NON_HUMAN',
+                        31: 'B-NORP',
+                        32: 'I-NORP',
+                        33: 'B-ORDINAL',
+                        34: 'I-ORDINAL',
+                        35: 'B-ORGANISATION',
+                        36: 'I-ORGANISATION',
+                        37: 'B-PERSON',
+                        38: 'I-PERSON',
+                        39: 'B-PERCENTAGE',
+                        40: 'I-PERCENTAGE',
+                        41: 'B-POSITION',
+                        42: 'I-POSITION',
+                        43: 'B-PRODUCT',
+                        44: 'I-PRODUCT',
+                        45: 'B-PROJECT',
+                        46: 'I-PROJECT',
+                        47: 'B-QUANTITY',
+                        48: 'I-QUANTITY',
+                        49: 'B-TIME',
+                        50: 'I-TIME'}
+
+         single_sentence_tokens = word_tokenize(text)
+         tokenized_input = tokenizer(single_sentence_tokens, is_split_into_words = True, return_tensors = "pt")
+         tokens = tokenized_input.tokens()
+         output = model(**tokenized_input).logits
+         predictions = torch.argmax(output, dim = 2)
+
+         # convert label IDs to label names
+         word_ids = tokenized_input.word_ids(batch_index = 0)
+         previous_word_id = None
+         labels = []
+         for token, word_id, prediction in zip(tokens, word_ids, predictions[0].numpy()):
+             # # Special tokens have a word id that is None. We set the label to -100 so they are
+             # # automatically ignored in the loss function.
+             if word_id is None or word_id == previous_word_id:
+                 continue
+             elif word_id != previous_word_id:
+                 labels.append(labels_dict[prediction])
+             previous_word_id = word_id
+         assert len(single_sentence_tokens) == len(labels), "Mismatch between input token and label sizes!"
+
+         sentence_tokens = []
+         sentence_labels = []
+
+         token_list = []
+         label_list = []
+
+         previous_token = ""
+         previous_label = ""
+
+         for token, label in zip(single_sentence_tokens, labels):
+             current_token = token
+             current_label = label
+
+             # starting loop
+             if previous_label == "":
+                 previous_token = current_token
+                 previous_label = current_label
+
+             # collecting compound named entities
+             elif (previous_label.startswith("B-")) and (current_label.startswith("I-")):
+                 token_list.append(previous_token)
+                 label_list.append(previous_label)
+             elif (previous_label.startswith("I-")) and (current_label.startswith("I-")):
+                 token_list.append(previous_token)
+                 label_list.append(previous_label)
+             elif (previous_label.startswith("I-")) and (not current_label.startswith("I-")):
+                 token_list.append(previous_token)
+                 label_list.append(previous_label)
+                 sentence_tokens.append(token_list)
+                 sentence_labels.append(label_list)
+                 token_list = []
+                 label_list = []
+             # collecting single named entities:
+             elif (not previous_label.startswith("I-")) and (not current_label.startswith("I-")):
+                 token_list.append(previous_token)
+                 label_list.append(previous_label)
+                 sentence_tokens.append(token_list)
+                 sentence_labels.append(label_list)
+                 token_list = []
+                 label_list = []
+             previous_token = current_token
+             previous_label = current_label
+         token_list.append(previous_token)
+         label_list.append(previous_label)
+         sentence_tokens.append(token_list)
+         sentence_labels.append(label_list)
+
+         output = []
+         for sentence_token, sentence_label in zip(sentence_tokens, sentence_labels):
+             if len(sentence_label[0]) > 1:
+                 if len(sentence_label) > 1:
+                     output.append((" ".join(sentence_token), sentence_label[0].split("-")[1]))
+                 else:
+                     output.append((sentence_token[0], sentence_label[0].split("-")[1]))
+             else:
+                 # output.append((sentence_token[0], sentence_label[0]))
+                 output.append(sentence_token[0])
+
+         modified_output = []
+         for element in output:
+             if not isinstance(element, tuple):
+                 if element.isalnum():
+                     modified_output.append(' ' + element + ' ')
+                 else:
+                     modified_output.append(' ' + element + ' ')
+             else:
+                 tuple_first = f" {element[0]} "
+                 tuple_second = element[1]
+                 new_tuple = (tuple_first, tuple_second)
+                 modified_output.append(new_tuple)
+     else:
+         return st.markdown("<p id = 'warning'>PLEASE INSERT YOUR TEXT</p>", unsafe_allow_html = True)
+     return modified_output
+
+ #########################
+ #### CREATE SIDEBAR #####
+ #########################
+
+ with open("style.css") as f:
+     css = f.read()
+
+ st.sidebar.markdown(f'<style>{css}</style>', unsafe_allow_html = True)
+
+ st.sidebar.markdown("<h1>Kazakh NER</h1>", unsafe_allow_html = True)
+ st.sidebar.markdown("<h2>Named entity classes</h2>", unsafe_allow_html = True)
+
+ with st.sidebar.expander("ADAGE"): st.write("Well-known Kazakh proverbs and sayings")
+ with st.sidebar.expander("ART"): st.write("Titles of books, songs, television programmes, etc.")
+ with st.sidebar.expander("CARDINAL"): st.write("Cardinal numbers, including whole numbers, fractions, and decimals")
+ with st.sidebar.expander("CONTACT"): st.write("Addresses, emails, phone numbers, URLs")
+ with st.sidebar.expander("DATE"): st.write("Dates or periods of 24 hours or more")
+ with st.sidebar.expander("DISEASE"): st.write("Diseases or medical conditions")
+ with st.sidebar.expander("EVENT"): st.write("Named events and phenomena")
+ with st.sidebar.expander("FACILITY"): st.write("Names of man-made structures")
+ with st.sidebar.expander("GPE"): st.write("Names of geopolitical entities")
+ with st.sidebar.expander("LANGUAGE"): st.write("Named languages")
+ with st.sidebar.expander("LAW"): st.write("Named legal documents")
+ with st.sidebar.expander("LOCATION"): st.write("Names of geographical locations other than GPEs")
+ with st.sidebar.expander("MISCELLANEOUS"): st.write("Entities of interest but hard to assign a proper tag to")
+ with st.sidebar.expander("MONEY"): st.write("Monetary values")
+ with st.sidebar.expander("NON_HUMAN"): st.write("Names of pets, animals or non-human creatures")
+ with st.sidebar.expander("NORP"): st.write("Adjectival forms of GPE and LOCATION; named religions, etc.")
+ with st.sidebar.expander("ORDINAL"): st.write("Ordinal numbers, including adverbials")
+ with st.sidebar.expander("ORGANISATION"): st.write("Names of companies, government agencies, etc.")
+ with st.sidebar.expander("PERCENTAGE"): st.write("Percentages")
+ with st.sidebar.expander("PERSON"): st.write("Names of persons")
+ with st.sidebar.expander("POSITION"): st.write("Names of posts and job titles")
+ with st.sidebar.expander("PRODUCT"): st.write("Names of products")
+ with st.sidebar.expander("PROJECT"): st.write("Names of projects, policies, plans, etc.")
+ with st.sidebar.expander("QUANTITY"): st.write("Length, distance, etc. measurements")
+ with st.sidebar.expander("TIME"): st.write("Times of day and time duration less than 24 hours")
+
+ ######################
+ #### CREATE FORM #####
+ ######################
+
+ text_field = st.form(key = 'text_field')
+ form_text = text_field.text_input('Insert your text here')
+ submit = text_field.form_submit_button('Submit')
+
+ st.markdown('Press **Submit** to have your text labelled')
+
+ if submit:
+     annotated_text(label_text(form_text))
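The new version hangs `st.cache_resource` on `label_text` itself, which memoises the function per distinct `text` argument: every previously unseen input re-runs the whole body, including both `from_pretrained` calls. A common alternative is to scope the cache to the model construction alone. The sketch below is not part of the commit; the `load_model` helper and `MODEL_ID` constant are illustrative names, and it assumes the same model ID and a Streamlit version that ships `st.cache_resource` (1.18 or later).

import streamlit as st
from transformers import AutoTokenizer, AutoModelForTokenClassification

MODEL_ID = "yeshpanovrustem/xlm-roberta-large-ner-kazakh"

# cache only the expensive construction: the no-argument function runs once
# per process, and every later call returns the same tokenizer and model
@st.cache_resource(show_spinner = "Loading the model...")
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForTokenClassification.from_pretrained(MODEL_ID)
    return tokenizer, model

tokenizer, model = load_model()

With that split, `label_text` would carry no decorator of its own and simply use the cached objects. Its return value already has the shape the final `annotated_text(label_text(form_text))` call renders: a list mixing plain strings with `(entity, class)` tuples.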