lalital committed
Commit
6c4ffba
1 Parent(s): 995cf09

Add related files

Files changed (3)
  1. app.py +127 -0
  2. pipeline.py +289 -0
  3. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,127 @@
+ import json
+ from functools import partial
+ from typing import Callable, Dict
+ import transformers
+ from transformers import (
+     AutoModelForTokenClassification,
+     AutoTokenizer
+ )
+
+ from pipeline import (
+     TokenClassificationPipeline
+ )
+ import pythainlp
+ from pprint import pprint
+ from itertools import chain
+
+ import gradio as gr
+
+
+ ner_pipeline_group = TokenClassificationPipeline(
+     model=AutoModelForTokenClassification.from_pretrained(
+         'airesearch/wangchanberta-base-att-spm-uncased',
+         revision='finetuned@thainer-ner'
+     ),
+     tokenizer=AutoTokenizer.from_pretrained(
+         'airesearch/wangchanberta-base-att-spm-uncased',
+         revision='finetuned@thainer-ner'
+     ),
+     space_token='<_>',
+     lowercase=True,
+     group_entities=True,
+     strict=False,
+ )
+
+ color_mapper = {
+     "DATE": "#f94144",
+     "EMAIL": "#f3722c",
+     "LAW": "#f8961e",
+     "LEN": "#f9844a",
+     "LOCATION": "#f9c74f",
+     "MONEY": "#ffcb77",
+     "ORGANIZATION": "#f5cac3",
+     "PERCENT": "#90be6d",
+     "PERSON": "#bfd200",
+     "PHONE": "#43aa8b",
+     "TIME": "#4d908e",
+     "URL": "#577590",
+     "ZIP": "#90e0ef",
+ }
+
+
+ # Build CSS that highlights each entity type with its colour and shows the
+ # uppercase entity label as a small superscript suffix after the span.
+ css_text = 'p{width: 700px; color: #333; border-radius: 3px; border: solid 1.5px #DDD; background-color: #FFF;\n margin: 10px;\n padding: 30px}\n'
+ for k, v in color_mapper.items():
+     css_text += "span." + f"{k.lower()}" \
+         + "{\n background-color: " \
+         + f"{v}" + "50;\n color: #333;\n border-right: 4px solid " \
+         + f"{v}" + ";" \
+         + "\n align-items: center;" \
+         + "\n margin: 0;" \
+         + "\n padding: 2px 8px;" \
+         + "\n border-radius: 3px;\n}\n" \
+         + "span." + f"{k.lower()}" + "::after {" \
+         + "\npadding: 2px 1px;" \
+         + "font-size: 9.5px;" \
+         + "font-weight: bold;" \
+         + "font-family: Monaco;" \
+         + "vertical-align: super;" \
+         + "content: \"" + k.upper() + "\";" \
+         + "}\n"
+
+
+ def modify_segment(text, tag, start, end):
+     """Wrap text[start:end] in a <span> whose CSS class is the entity tag."""
+     replaced_text = text[:start] + f'<span class="{tag}">' + text[start:end] + '</span>' + text[end:]
+     return replaced_text, len(f'<span class="{tag}">') + len('</span>')
+
+
+ def render_doc_with_label(label: Dict, doc: str):
+     attribute_items = []
+     for i, ne_span in enumerate(label):
+         if ne_span['entity_group'] != 'O':
+             attribute_name = ne_span['entity_group']
+             attribute_name = attribute_name.lower()
+
+             begin_char_idx = ne_span['begin_char_index']
+
+             tagged_text = ne_span['word']
+             end_char_idx = begin_char_idx + len(tagged_text)
+
+             attribute_items.append((attribute_name, begin_char_idx, end_char_idx))
+
+     attribute_items = sorted(attribute_items, key=lambda x: (x[1]))
+     print(f'attribute_items: {attribute_items}')
+
+     # Wrap each entity span in order; acc_n_extra_chars shifts later offsets by
+     # the number of characters the already-inserted HTML tags have added.
+     acc_n_extra_chars = 0
+     modified_segment = doc
+     for _selected_attribute_item in attribute_items:
+
+         tag, start, end = _selected_attribute_item[0], _selected_attribute_item[1], _selected_attribute_item[2]
+
+         modified_segment, n_extra_chars = modify_segment(modified_segment, tag, start + acc_n_extra_chars, end + acc_n_extra_chars)
+         acc_n_extra_chars += n_extra_chars
+
+     return f'<style>{css_text}</style><p>{modified_segment}</p>'
+
+ def ner_tagging(text: str):
+     results = ner_pipeline_group(text)
+     print(f'results:\n{results}')
+     html_text = render_doc_with_label(results, text)
+
+     return json.dumps(results, ensure_ascii=False, indent=4), html_text
+
+
+ demo = gr.Interface(fn=ner_tagging,
+                     inputs=gr.Textbox(lines=5, placeholder='Input text in Thai', label='Input text'),
+                     examples=[
+                         ["ไมโครซอฟท์ได้จัดจำหน่ายบนแพลตฟอร์มไมโครซอฟท์ วินโดวส์ ในเดือนเมษายน 2020"],
+                         ['ชัชชาติ สิทธิพันธุ์ ผู้ว่าราชการกรุงเทพมหานคร (กทม.) คนที่ 17 เตรียมเข้ารับตำแหน่งอย่างเป็นทางการและเปิดตัวทีมงานในช่วงบ่ายวันนี้ (1 มิ.ย.) หลังรับมอบหนังสือรับรองการเป็นผู้ว่าฯ กทม. ที่สำนักงานคณะกรรมการการเลือกตั้ง (กกต.)'],
+                         ["สถาบันวิทยาศาสตร์ทางทะเล มหาวิทยาลัยบูรพา เปิดให้บริการมายาวนานกว่า 30 ปี ตั้งอยู่บริเวณด้านหน้า มหาวิทยาลัยบูรพา บนเนื้อที่กว่า 30 ไร่ เป็นสถานที่ท่องเที่ยว ที่จัดแสดงเพื่อให้ความรู้เกี่ยวกับวิทยาศาสตร์ทางทะเล สิ่งมีชีวิตและความเป็นอยู่ของสัตว์ทะเลชนิดต่างๆที่อาศัยอยู่ในเขตน่านน้ำของไทย"],
+                     ],
+                     outputs=[gr.Textbox(), gr.HTML()])
+
+ print(f'\nINFO: transformers.__version__: {transformers.__version__}')
+ print(f'\nINFO: pythainlp.__version__: {pythainlp.__version__}')
+ demo.launch()
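
For reference, a minimal standalone sketch (not part of the commit) of what the span-wrapping helper in app.py produces; the English input string and the offsets are illustrative only, and the helper is copied here so the snippet runs without importing app.py:

    # Sketch only: modify_segment copied from app.py so this runs on its own.
    def modify_segment(text, tag, start, end):
        replaced_text = text[:start] + f'<span class="{tag}">' + text[start:end] + '</span>' + text[end:]
        return replaced_text, len(f'<span class="{tag}">') + len('</span>')

    wrapped, n_extra = modify_segment('Microsoft was founded in 1975', 'organization', 0, 9)
    assert wrapped == '<span class="organization">Microsoft</span> was founded in 1975'
    assert n_extra == 34  # length of the inserted tags; render_doc_with_label adds this to later offsets
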
pipeline.py ADDED
@@ -0,0 +1,289 @@
+ import torch
+
+ from typing import Callable, List, Tuple, Union
+ from functools import partial
+ import itertools
+
+ from seqeval.scheme import Tokens, IOB2, IOBES
+
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.tokenization_utils import PreTrainedTokenizerBase
+ from pythainlp.tokenize import word_tokenize as pythainlp_word_tokenize
+ newmm_word_tokenizer = partial(pythainlp_word_tokenize, keep_whitespace=True, engine='newmm')
+
+ from thai2transformers.preprocess import rm_useless_spaces
+
+ SPIECE = '▁'
+
+ class TokenClassificationPipeline:
+
+     def __init__(self,
+                  model: PreTrainedModel,
+                  tokenizer: PreTrainedTokenizerBase,
+                  pretokenizer: Callable[[str], List[str]] = newmm_word_tokenizer,
+                  lowercase=False,
+                  space_token='<_>',
+                  device: int = -1,
+                  group_entities: bool = False,
+                  strict: bool = False,
+                  tag_delimiter: str = '-',
+                  scheme: str = 'IOB',
+                  use_crf=False,
+                  remove_spiece=True):
+
+         super().__init__()
+
+         assert isinstance(tokenizer, PreTrainedTokenizerBase)
+         # assert isinstance(model, PreTrainedModel)
+
+         self.model = model
+         self.tokenizer = tokenizer
+         self.pretokenizer = pretokenizer
+         self.lowercase = lowercase
+         self.space_token = space_token
+         self.device = 'cpu' if device == -1 or not torch.cuda.is_available() else f'cuda:{device}'
+         self.group_entities = group_entities
+         self.strict = strict
+         self.tag_delimiter = tag_delimiter
+         self.scheme = scheme
+         self.id2label = self.model.config.id2label
+         self.label2id = self.model.config.label2id
+         self.use_crf = use_crf
+         self.remove_spiece = remove_spiece
+         self.model.to(self.device)
+
+     def preprocess(self, inputs: Union[str, List[str]]) -> Union[List[str], List[List[str]]]:
+
+         if self.lowercase:
+             inputs = inputs.lower() if type(inputs) == str else list(map(str.lower, inputs))
+
+         inputs = rm_useless_spaces(inputs) if type(inputs) == str else list(map(rm_useless_spaces, inputs))
+
+         tokens = self.pretokenizer(inputs) if type(inputs) == str else list(map(self.pretokenizer, inputs))
+
+         tokens = list(map(lambda x: x.replace(' ', self.space_token), tokens)) if type(inputs) == str else \
+             list(map(lambda _tokens: list(map(lambda x: x.replace(' ', self.space_token), _tokens)), tokens))
+
+         return tokens
+
+     def _inference(self, input: str):
+
+         # Pre-tokenize into words, sub-tokenize each word, and wrap with BOS/EOS.
+         tokens = [[self.tokenizer.bos_token]] + \
+                  [self.tokenizer.tokenize(tok) if tok != SPIECE else [SPIECE] for tok in self.preprocess(input)] + \
+                  [[self.tokenizer.eos_token]]
+         ids = [self.tokenizer.convert_tokens_to_ids(token) for token in tokens]
+         flatten_tokens = list(itertools.chain(*tokens))
+         flatten_ids = list(itertools.chain(*ids))
+
+         input_ids = torch.LongTensor([flatten_ids]).to(self.device)
+
+         if self.use_crf:
+             out = self.model(input_ids=input_ids)
+         else:
+             out = self.model(input_ids=input_ids, return_dict=True)
+         probs = torch.softmax(out['logits'], dim=-1)
+         vals, indices = probs.topk(1)
+         indices_np = indices.detach().cpu().numpy().reshape(-1)
+
+         list_of_token_label_tuple = list(zip(flatten_tokens, [self.id2label[idx] for idx in indices_np]))
+         merged_preds = self._merged_pred(preds=list_of_token_label_tuple, ids=ids)
+         if self.remove_spiece:
+             merged_preds = list(map(lambda x: (x[0].replace(SPIECE, ''), x[1]), merged_preds))
+
+         # remove start and end tokens
+         merged_preds_removed_bos_eos = merged_preds[1:-1]
+         # convert to a list of dict objects
+         merged_preds_return_dict = [{'word': word if word != self.space_token else ' ', 'entity': tag, 'index': idx}
+                                     for idx, (word, tag) in enumerate(merged_preds_removed_bos_eos)]
+
+         if (not self.group_entities or self.scheme == None) and self.strict == True:
+             return merged_preds_return_dict
+         elif not self.group_entities and self.strict == False:
+
+             tags = list(map(lambda x: x['entity'], merged_preds_return_dict))
+             processed_tags = self._fix_incorrect_tags(tags)
+             for i, item in enumerate(merged_preds_return_dict):
+                 merged_preds_return_dict[i]['entity'] = processed_tags[i]
+             return merged_preds_return_dict
+         elif self.group_entities:
+             return self._group_entities(merged_preds_removed_bos_eos)
+
+     def __call__(self, inputs: Union[str, List[str]]):
+         """Run NER inference on a single string or on a list of strings."""
+         if type(inputs) == str:
+             return self._inference(inputs)
+
+         if type(inputs) == list:
+             results = [self._inference(text) for text in inputs]
+             return results
+
+
+     def _merged_pred(self, preds: List[Tuple[str, str]], ids: List[List[int]]):
+
+         # token_mapping[i] is the index of the pre-tokenized word that subtoken i belongs to.
+         token_mapping = []
+         for i in range(0, len(ids)):
+             for j in range(0, len(ids[i])):
+                 token_mapping.append(i)
+
+         grouped_subtokens = []
+         _subtoken = []
+         prev_idx = 0
+
+         for i, (subtoken, label) in enumerate(preds):
+
+             current_idx = token_mapping[i]
+             if prev_idx != current_idx:
+                 grouped_subtokens.append(_subtoken)
+                 _subtoken = [(subtoken, label)]
+                 if i == len(preds) - 1:
+                     _subtoken = [(subtoken, label)]
+                     grouped_subtokens.append(_subtoken)
+             elif i == len(preds) - 1:
+                 _subtoken += [(subtoken, label)]
+                 grouped_subtokens.append(_subtoken)
+             else:
+                 _subtoken += [(subtoken, label)]
+             prev_idx = current_idx
+
+         # Merge each subtoken group back into a word, keeping the label predicted
+         # for the group's first subtoken.
+         merged_subtokens = []
+         _merged_subtoken = ''
+         for subtoken_group in grouped_subtokens:
+
+             first_token_pred = subtoken_group[0][1]
+             _merged_subtoken = ''.join(list(map(lambda x: x[0], subtoken_group)))
+             merged_subtokens.append((_merged_subtoken, first_token_pred))
+         return merged_subtokens
+
+     def _fix_incorrect_tags(self, tags: List[str]) -> List[str]:
+
+         I_PREFIX = f'I{self.tag_delimiter}'
+         E_PREFIX = f'E{self.tag_delimiter}'
+         B_PREFIX = f'B{self.tag_delimiter}'
+         O_PREFIX = 'O'
+
+         previous_tag_ne = None
+         for i, current_tag in enumerate(tags):
+
+             current_tag_ne = current_tag.split(self.tag_delimiter)[-1] if current_tag != O_PREFIX else O_PREFIX
+
+             if i == 0 and (current_tag.startswith(I_PREFIX) or
+                            current_tag.startswith(E_PREFIX)):
+                 # if an NE tag (with I- or E- prefix) occurs at the beginning of the sentence,
+                 # e.g. (I-LOC, I-LOC), (E-LOC, B-PER), (I-LOC, O, O),
+                 # then change the prefix of the current tag to B{tag_delimiter}
+                 tags[i] = B_PREFIX + tags[i][2:]
+             elif i >= 1 and tags[i-1] == O_PREFIX and (
+                     current_tag.startswith(I_PREFIX) or
+                     current_tag.startswith(E_PREFIX)):
+                 # if an NE tag (with I- or E- prefix) occurs after an O tag,
+                 # e.g. (O, I-LOC, I-LOC), (O, E-LOC, B-PER), (O, I-LOC, O, O),
+                 # then change the prefix of the current tag to B{tag_delimiter}
+                 tags[i] = B_PREFIX + tags[i][2:]
+             elif i >= 1 and (tags[i-1].startswith(I_PREFIX) or
+                              tags[i-1].startswith(E_PREFIX) or
+                              tags[i-1].startswith(B_PREFIX)) and \
+                     (current_tag.startswith(I_PREFIX) or current_tag.startswith(E_PREFIX)) and \
+                     previous_tag_ne != current_tag_ne:
+                 # if an NE tag (with I- or E- prefix) occurs after an NE tag of a different type,
+                 # e.g. (B-LOC, I-PER), (B-LOC, E-LOC, E-PER), (B-LOC, I-LOC, I-PER),
+                 # then change the prefix of the current tag to B{tag_delimiter}
+                 tags[i] = B_PREFIX + tags[i][2:]
+             elif i == len(tags) - 1 and tags[i-1] == O_PREFIX and (
+                     current_tag.startswith(I_PREFIX) or
+                     current_tag.startswith(E_PREFIX)):
+                 # if an NE tag (with I- or E- prefix) occurs at the end of the sentence,
+                 # e.g. (O, O, I-LOC), (O, O, E-LOC),
+                 # then change the prefix of the current tag to B{tag_delimiter}
+                 tags[i] = B_PREFIX + tags[i][2:]
+
+             previous_tag_ne = current_tag_ne
+
+         return tags
+
+     def _group_entities(self, ner_tags: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
+
+         if self.scheme not in ['IOB', 'IOBES', 'IOBE']:
+             raise AttributeError()
+
+         tokens, tags = zip(*ner_tags)
+         tokens, tags = list(tokens), list(tags)
+
+         if self.scheme == 'IOBE':
+             # Replace E prefix with I prefix
+             tags = list(map(lambda x: x.replace(f'E{self.tag_delimiter}', f'I{self.tag_delimiter}'), tags))
+         if self.scheme == 'IOBES':
+             # Replace E prefix with I prefix, and S prefix with B prefix
+             tags = list(map(lambda x: x.replace(f'E{self.tag_delimiter}', f'I{self.tag_delimiter}'), tags))
+             tags = list(map(lambda x: x.replace(f'S{self.tag_delimiter}', f'B{self.tag_delimiter}'), tags))
+
+         if not self.strict:
+             tags = self._fix_incorrect_tags(tags)
+
+         ent = Tokens(tokens=tags, scheme=IOB2,
+                      suffix=False, delimiter=self.tag_delimiter)
+
+         ne_position_mappings = ent.entities
+         # Character offsets of each pre-tokenized word in the space-restored text.
+         token_positions = []
+         curr_len = 0
+         tokens = list(map(lambda x: x.replace('<_>', ' ').replace('ํา', 'ำ'), tokens))
+         for i, token in enumerate(tokens):
+             token_len = len(token)
+             if i == 0:
+                 token_positions.append((0, curr_len + token_len))
+             else:
+                 token_positions.append((curr_len, curr_len + token_len))
+             curr_len += token_len
+         print(f'token_positions: {list(zip(tokens, token_positions))}')
+         begin_end_pos = []
+         begin_end_char_pos = []
+         accum_char_len = 0
+         for i, ne_position_mapping in enumerate(ne_position_mappings):
+             print(f'ne_position_mapping.start: {ne_position_mapping.start}')
+             print(f'ne_position_mapping.end: {ne_position_mapping.end}\n')
+             begin_end_pos.append((ne_position_mapping.start, ne_position_mapping.end))
+             begin_end_char_pos.append((token_positions[ne_position_mapping.start][0], token_positions[ne_position_mapping.end-1][1]))
+         print(f'begin_end_pos: {begin_end_pos}')
+         print(f'begin_end_char_pos: {begin_end_char_pos}')
+
+         # Insert filler 'O' spans for token ranges not covered by any entity.
+         j = 0
+         # print(f'tokens: {tokens}')
+         for i, pos_tuple in enumerate(begin_end_pos):
+             # print(f'j = {j}')
+             if pos_tuple[0] > 0 and i == 0:
+                 ne_position_mappings.insert(0, (None, 'O', 0, pos_tuple[0]))
+                 j += 1
+             if begin_end_pos[i-1][1] != begin_end_pos[i][0] and len(begin_end_pos) > 1 and i > 0:
+                 ne_position_mappings.insert(j, (None, 'O', begin_end_pos[i-1][1], begin_end_pos[i][0]))
+                 j += 1
+
+             j += 1
+         print('ne_position_mappings', ne_position_mappings)
+
+         groups = []
+         k = 0
+         for i, ne_position_mapping in enumerate(ne_position_mappings):
+             if type(ne_position_mapping) != tuple:
+                 ne_position_mapping = ne_position_mapping.to_tuple()
+             ne = ne_position_mapping[1]
+
+             text = ''
+             for ne_position in range(ne_position_mapping[2], ne_position_mapping[3]):
+                 _token = tokens[ne_position]
+                 text += _token if _token != self.space_token else ' '
+             if ne.lower() != 'o':
+                 groups.append({
+                     'entity_group': ne,
+                     'word': text,
+                     'begin_char_index': begin_end_char_pos[k][0]
+                 })
+                 k += 1
+             else:
+                 groups.append({
+                     'entity_group': ne,
+                     'word': text,
+                 })
+         return groups
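
A minimal usage sketch (not part of the commit), mirroring how app.py instantiates this class; the checkpoint name and revision are taken from app.py, the call assumes network access to download them, and the output in the comment only indicates the grouped format rather than an exact result:

    from transformers import AutoModelForTokenClassification, AutoTokenizer
    from pipeline import TokenClassificationPipeline

    ner = TokenClassificationPipeline(
        model=AutoModelForTokenClassification.from_pretrained(
            'airesearch/wangchanberta-base-att-spm-uncased', revision='finetuned@thainer-ner'),
        tokenizer=AutoTokenizer.from_pretrained(
            'airesearch/wangchanberta-base-att-spm-uncased', revision='finetuned@thainer-ner'),
        space_token='<_>', lowercase=True, group_entities=True, strict=False)

    entities = ner('ไมโครซอฟท์เปิดตัววินโดวส์ในปี 1985')
    # With group_entities=True, the result is a list of dicts shaped like
    # {'entity_group': 'ORGANIZATION', 'word': 'ไมโครซอฟท์', 'begin_char_index': 0},
    # with 'O' spans included (without a begin_char_index) for untagged text.
    print(entities)
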
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ gradio
+ git+https://github.com/vistec-ai/thai2transformers.git@feature/add_ner_scheme
+ pythainlp==2.2.4