Add related files
- app.py +127 -0
- pipeline.py +289 -0
- requirements.txt +3 -0
app.py
ADDED
@@ -0,0 +1,127 @@
import json
from functools import partial
from typing import Callable, Dict
import transformers
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer
)

from pipeline import (
    TokenClassificationPipeline
)
import pythainlp
from pprint import pprint
from itertools import chain

import gradio as gr


ner_pipeline_group = TokenClassificationPipeline(
    model=AutoModelForTokenClassification.from_pretrained(
        'airesearch/wangchanberta-base-att-spm-uncased',
        revision='finetuned@thainer-ner'
    ),
    tokenizer=AutoTokenizer.from_pretrained(
        'airesearch/wangchanberta-base-att-spm-uncased',
        revision='finetuned@thainer-ner'
    ),
    space_token='<_>',
    lowercase=True,
    group_entities=True,
    strict=False,
)

color_mapper = {
    "DATE": "#f94144",
    "EMAIL": "#f3722c",
    "LAW": "#f8961e",
    "LEN": "#f9844a",
    "LOCATION": "#f9c74f",
    "MONEY": "#ffcb77",
    "ORGANIZATION": "#f5cac3",
    "PERCENT": "#90be6d",
    "PERSON": "#bfd200",
    "PHONE": "#43aa8b",
    "TIME": "#4d908e",
    "URL": "#577590",
    "ZIP": "#90e0ef",
}

# Build a CSS rule for each entity type: a tinted background, a coloured right
# border, and a small superscript label with the entity name.
css_text = 'p{width: 700px; color: #333; border-radius: 3px; border: solid 1.5px #DDD; background-color: #FFF;\n margin: 10px;\n padding: 30px}\n'
for k, v in color_mapper.items():
    css_text += (
        f'span.{k.lower()}{{\n'
        f' background-color: {v}50;\n'
        f' color: #333;\n'
        f' border-right: 4px solid {v};\n'
        f' align-items: center;\n'
        f' margin: 0;\n'
        f' padding: 2px 8px;\n'
        f' border-radius: 3px;\n'
        f'}}\n'
        f'span.{k.lower()}::after {{\n'
        f'padding: 2px 1px;'
        f'font-size: 9.5px;'
        f'font-weight: bold;'
        f'font-family: Monaco;'
        f'vertical-align: super;'
        f'content: "{k.upper()}";'
        f'}}\n'
    )

def modify_segment(text, tag, start, end):
    # Wrap text[start:end] in a <span> whose class selects the CSS rule above,
    # and report how many extra characters the markup added.
    replaced_text = text[:start] + f'<span class="{tag}">' + text[start:end] + '</span>' + text[end:]
    return replaced_text, len(f'<span class="{tag}">') + len('</span>')


def render_doc_with_label(label: Dict, doc: str):
    attribute_items = []
    for i, ne_span in enumerate(label):
        if ne_span['entity_group'] != 'O':
            attribute_name = ne_span['entity_group']
            attribute_name = attribute_name.lower()

            begin_char_idx = ne_span['begin_char_index']

            tagged_text = ne_span['word']
            end_char_idx = begin_char_idx + len(tagged_text)

            attribute_items.append((attribute_name, begin_char_idx, end_char_idx))

    attribute_items = sorted(attribute_items, key=lambda x: (x[1]))
    print(f'attribute_items: {attribute_items}')

    # Insert the <span> markup left to right, shifting later offsets by the
    # number of characters already added.
    acc_n_extra_chars = 0
    modified_segment = doc
    for _selected_attribute_item in attribute_items:

        tag, start, end = _selected_attribute_item[0], _selected_attribute_item[1], _selected_attribute_item[2]

        modified_segment, n_extra_chars = modify_segment(modified_segment, tag, start + acc_n_extra_chars, end + acc_n_extra_chars)
        acc_n_extra_chars += n_extra_chars

    return f'<style>{css_text}</style><p>{modified_segment}</p>'

def ner_tagging(text: str):
    results = ner_pipeline_group(text)
    print(f'results:\n{results}')
    html_text = render_doc_with_label(results, text)

    return json.dumps(results, ensure_ascii=False, indent=4), html_text


demo = gr.Interface(
    fn=ner_tagging,
    inputs=gr.Textbox(lines=5, placeholder='Input text in Thai', label='Input text'),
    examples=[
        ["ไมโครซอฟท์ได้จัดจำหน่ายบนแพลตฟอร์มไมโครซอฟท์ วินโดวส์ ในเดือนเมษายน 2020"],
        ['ชัชชาติ สิทธิพันธุ์ ผู้ว่าราชการกรุงเทพมหานคร (กทม.) คนที่ 17 เตรียมเข้ารับตำแหน่งอย่างเป็นทางการและเปิดตัวทีมงานในช่วงบ่ายวันนี้ (1 มิ.ย.) หลังรับมอบหนังสือรับรองการเป็นผู้ว่าฯ กทม. ที่สำนักงานคณะกรรมการการเลือกตั้ง (กกต.)'],
        ["สถาบันวิทยาศาสตร์ทางทะเล มหาวิทยาลัยบูรพา เปิดให้บริการมายาวนานกว่า 30 ปี ตั้งอยู่บริเวณด้านหน้า มหาวิทยาลัยบูรพา บนเนื้อที่กว่า 30 ไร่ เป็นสถานที่ท่องเที่ยว ที่จัดแสดงเพื่อให้ความรู้เกี่ยวกับวิทยาศาสตร์ทางทะเล สิ่งมีชีวิตและความเป็นอยู่ของสัตว์ทะเลชนิดต่างๆที่อาศัยอยู่ในเขตน่านน้ำของไทย"],
    ],
    outputs=[gr.Textbox(), gr.HTML()])

print(f'\nINFO: transformers.__version__: {transformers.__version__}')
print(f'\nINFO: pythainlp.__version__: {pythainlp.__version__}')
demo.launch()
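
For reference, a sketch of the data flow through ner_tagging and render_doc_with_label, assuming the grouped output shape built by pipeline.py (entity_group, word, and begin_char_index for non-O groups). The spans below are a hypothetical illustration, not real model output, and the snippet is meant to run in the context of app.py after the definitions above:

# Hypothetical grouped-pipeline output for a short input.
text = 'ไมโครซอฟท์เปิดตัววินโดวส์ในปี 2020'
results = [
    {'entity_group': 'ORGANIZATION', 'word': 'ไมโครซอฟท์', 'begin_char_index': 0},
    {'entity_group': 'O', 'word': 'เปิดตัววินโดวส์ในปี '},
    {'entity_group': 'DATE', 'word': '2020', 'begin_char_index': 30},
]
# render_doc_with_label wraps each non-O span in a <span class="..."> element,
# shifting later offsets by the markup already inserted, and prepends the CSS.
html_text = render_doc_with_label(results, text)
# html_text is roughly '<style>...</style><p><span class="organization">ไมโครซอฟท์</span>เปิดตัว...'
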
pipeline.py
ADDED
@@ -0,0 +1,289 @@
import torch

from typing import Callable, List, Tuple, Union
from functools import partial
import itertools

from seqeval.scheme import Tokens, IOB2, IOBES

from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizerBase
from pythainlp.tokenize import word_tokenize as pythainlp_word_tokenize
newmm_word_tokenizer = partial(pythainlp_word_tokenize, keep_whitespace=True, engine='newmm')

from thai2transformers.preprocess import rm_useless_spaces

SPIECE = '▁'


class TokenClassificationPipeline:

    def __init__(self,
                 model: PreTrainedModel,
                 tokenizer: PreTrainedTokenizerBase,
                 pretokenizer: Callable[[str], List[str]] = newmm_word_tokenizer,
                 lowercase=False,
                 space_token='<_>',
                 device: int = -1,
                 group_entities: bool = False,
                 strict: bool = False,
                 tag_delimiter: str = '-',
                 scheme: str = 'IOB',
                 use_crf=False,
                 remove_spiece=True):

        super().__init__()

        assert isinstance(tokenizer, PreTrainedTokenizerBase)
        # assert isinstance(model, PreTrainedModel)

        self.model = model
        self.tokenizer = tokenizer
        self.pretokenizer = pretokenizer
        self.lowercase = lowercase
        self.space_token = space_token
        self.device = 'cpu' if device == -1 or not torch.cuda.is_available() else f'cuda:{device}'
        self.group_entities = group_entities
        self.strict = strict
        self.tag_delimiter = tag_delimiter
        self.scheme = scheme
        self.id2label = self.model.config.id2label
        self.label2id = self.model.config.label2id
        self.use_crf = use_crf
        self.remove_spiece = remove_spiece
        self.model.to(self.device)

    def preprocess(self, inputs: Union[str, List[str]]) -> Union[List[str], List[List[str]]]:

        if self.lowercase:
            inputs = inputs.lower() if type(inputs) == str else list(map(str.lower, inputs))

        inputs = rm_useless_spaces(inputs) if type(inputs) == str else list(map(rm_useless_spaces, inputs))

        tokens = self.pretokenizer(inputs) if type(inputs) == str else list(map(self.pretokenizer, inputs))

        tokens = list(map(lambda x: x.replace(' ', self.space_token), tokens)) if type(inputs) == str else \
                 list(map(lambda _tokens: list(map(lambda x: x.replace(' ', self.space_token), _tokens)), tokens))

        return tokens

    def _inference(self, input: str):

        # Pretokenize with newmm, subword-tokenize each word, and wrap the
        # sequence with the model's BOS/EOS tokens.
        tokens = [[self.tokenizer.bos_token]] + \
                 [self.tokenizer.tokenize(tok) if tok != SPIECE else [SPIECE] for tok in self.preprocess(input)] + \
                 [[self.tokenizer.eos_token]]
        ids = [self.tokenizer.convert_tokens_to_ids(token) for token in tokens]
        flatten_tokens = list(itertools.chain(*tokens))
        flatten_ids = list(itertools.chain(*ids))

        input_ids = torch.LongTensor([flatten_ids]).to(self.device)

        if self.use_crf:
            out = self.model(input_ids=input_ids)
        else:
            out = self.model(input_ids=input_ids, return_dict=True)
        probs = torch.softmax(out['logits'], dim=-1)
        vals, indices = probs.topk(1)
        indices_np = indices.detach().cpu().numpy().reshape(-1)

        list_of_token_label_tuple = list(zip(flatten_tokens, [self.id2label[idx] for idx in indices_np]))
        merged_preds = self._merged_pred(preds=list_of_token_label_tuple, ids=ids)
        if self.remove_spiece:
            merged_preds = list(map(lambda x: (x[0].replace(SPIECE, ''), x[1]), merged_preds))

        # remove start and end tokens
        merged_preds_removed_bos_eos = merged_preds[1:-1]
        # convert to list of Dict objects
        merged_preds_return_dict = [{'word': word if word != self.space_token else ' ', 'entity': tag, 'index': idx}
                                    for idx, (word, tag) in enumerate(merged_preds_removed_bos_eos)]

        if (not self.group_entities or self.scheme == None) and self.strict == True:
            return merged_preds_return_dict
        elif not self.group_entities and self.strict == False:

            tags = list(map(lambda x: x['entity'], merged_preds_return_dict))
            processed_tags = self._fix_incorrect_tags(tags)
            for i, item in enumerate(merged_preds_return_dict):
                merged_preds_return_dict[i]['entity'] = processed_tags[i]
            return merged_preds_return_dict
        elif self.group_entities:
            return self._group_entities(merged_preds_removed_bos_eos)

    def __call__(self, inputs: Union[str, List[str]]):
        """
        Run NER inference on a single string or on a list of strings.
        """
        if type(inputs) == str:
            return self._inference(inputs)

        if type(inputs) == list:
            results = [self._inference(text) for text in inputs]
            return results


    def _merged_pred(self, preds: List[Tuple[str, str]], ids: List[List[int]]):

        # Map each flattened subword position back to the index of the
        # pretokenized word it came from.
        token_mapping = []
        for i in range(0, len(ids)):
            for j in range(0, len(ids[i])):
                token_mapping.append(i)

        grouped_subtokens = []
        _subtoken = []
        prev_idx = 0

        for i, (subtoken, label) in enumerate(preds):

            current_idx = token_mapping[i]
            if prev_idx != current_idx:
                grouped_subtokens.append(_subtoken)
                _subtoken = [(subtoken, label)]
                if i == len(preds) - 1:
                    _subtoken = [(subtoken, label)]
                    grouped_subtokens.append(_subtoken)
            elif i == len(preds) - 1:
                _subtoken += [(subtoken, label)]
                grouped_subtokens.append(_subtoken)
            else:
                _subtoken += [(subtoken, label)]
            prev_idx = current_idx

        # Merge each group of subwords back into one token and keep the label
        # predicted for its first subword.
        merged_subtokens = []
        _merged_subtoken = ''
        for subtoken_group in grouped_subtokens:

            first_token_pred = subtoken_group[0][1]
            _merged_subtoken = ''.join(list(map(lambda x: x[0], subtoken_group)))
            merged_subtokens.append((_merged_subtoken, first_token_pred))
        return merged_subtokens

    def _fix_incorrect_tags(self, tags: List[str]) -> List[str]:

        I_PREFIX = f'I{self.tag_delimiter}'
        E_PREFIX = f'E{self.tag_delimiter}'
        B_PREFIX = f'B{self.tag_delimiter}'
        O_PREFIX = 'O'

        previous_tag_ne = None
        for i, current_tag in enumerate(tags):

            current_tag_ne = current_tag.split(self.tag_delimiter)[-1] if current_tag != O_PREFIX else O_PREFIX

            if i == 0 and (current_tag.startswith(I_PREFIX) or
                           current_tag.startswith(E_PREFIX)):
                # if an NE tag (with I- or E- prefix) occurs at the beginning of the sentence,
                # e.g. (I-LOC, I-LOC), (E-LOC, B-PER), (I-LOC, O, O),
                # then change the prefix of the current tag to B{tag_delimiter}
                tags[i] = B_PREFIX + tags[i][2:]
            elif i >= 1 and tags[i-1] == O_PREFIX and (
                    current_tag.startswith(I_PREFIX) or
                    current_tag.startswith(E_PREFIX)):
                # if an NE tag (with I- or E- prefix) occurs after an O tag,
                # e.g. (O, I-LOC, I-LOC), (O, E-LOC, B-PER), (O, I-LOC, O, O),
                # then change the prefix of the current tag to B{tag_delimiter}
                tags[i] = B_PREFIX + tags[i][2:]
            elif i >= 1 and (tags[i-1].startswith(I_PREFIX) or
                             tags[i-1].startswith(E_PREFIX) or
                             tags[i-1].startswith(B_PREFIX)) and \
                    (current_tag.startswith(I_PREFIX) or current_tag.startswith(E_PREFIX)) and \
                    previous_tag_ne != current_tag_ne:
                # if an NE tag (with I- or E- prefix) occurs after an NE tag of a different type,
                # e.g. (B-LOC, I-PER), (B-LOC, E-LOC, E-PER), (B-LOC, I-LOC, I-PER),
                # then change the prefix of the current tag to B{tag_delimiter}
                tags[i] = B_PREFIX + tags[i][2:]
            elif i == len(tags) - 1 and tags[i-1] == O_PREFIX and (
                    current_tag.startswith(I_PREFIX) or
                    current_tag.startswith(E_PREFIX)):
                # if an NE tag (with I- or E- prefix) occurs at the end of the sentence,
                # e.g. (O, O, I-LOC), (O, O, E-LOC),
                # then change the prefix of the current tag to B{tag_delimiter}
                tags[i] = B_PREFIX + tags[i][2:]

            previous_tag_ne = current_tag_ne

        return tags

    def _group_entities(self, ner_tags: List[Tuple[str, str]]) -> List[dict]:

        if self.scheme not in ['IOB', 'IOBES', 'IOBE']:
            raise AttributeError(f'Unsupported tagging scheme: {self.scheme}')

        tokens, tags = zip(*ner_tags)
        tokens, tags = list(tokens), list(tags)

        if self.scheme == 'IOBE':
            # Replace E prefix with I prefix
            tags = list(map(lambda x: x.replace(f'E{self.tag_delimiter}', f'I{self.tag_delimiter}'), tags))
        if self.scheme == 'IOBES':
            # Replace E prefix with I prefix and replace S prefix with B prefix
            tags = list(map(lambda x: x.replace(f'E{self.tag_delimiter}', f'I{self.tag_delimiter}'), tags))
            tags = list(map(lambda x: x.replace(f'S{self.tag_delimiter}', f'B{self.tag_delimiter}'), tags))

        if not self.strict:
            tags = self._fix_incorrect_tags(tags)

        # Let seqeval group the (now IOB2-conformant) tags into entity spans.
        ent = Tokens(tokens=tags, scheme=IOB2,
                     suffix=False, delimiter=self.tag_delimiter)

        ne_position_mappings = ent.entities

        # Character offsets of each pretokenized token in the original text.
        token_positions = []
        curr_len = 0
        tokens = list(map(lambda x: x.replace(self.space_token, ' ').replace('ํา', 'ำ'), tokens))
        for i, token in enumerate(tokens):
            token_len = len(token)
            if i == 0:
                token_positions.append((0, curr_len + token_len))
            else:
                token_positions.append((curr_len, curr_len + token_len))
            curr_len += token_len
        print(f'token_positions: {list(zip(tokens, token_positions))}')

        # Token-level and character-level (begin, end) positions of each entity.
        begin_end_pos = []
        begin_end_char_pos = []
        accum_char_len = 0
        for i, ne_position_mapping in enumerate(ne_position_mappings):
            print(f'ne_position_mapping.start: {ne_position_mapping.start}')
            print(f'ne_position_mapping.end: {ne_position_mapping.end}\n')
            begin_end_pos.append((ne_position_mapping.start, ne_position_mapping.end))
            begin_end_char_pos.append((token_positions[ne_position_mapping.start][0], token_positions[ne_position_mapping.end - 1][1]))
        print(f'begin_end_pos: {begin_end_pos}')
        print(f'begin_end_char_pos: {begin_end_char_pos}')

        # Insert filler O spans for the gaps before and between entities so the
        # grouped output covers the whole input.
        j = 0
        # print(f'tokens: {tokens}')
        for i, pos_tuple in enumerate(begin_end_pos):
            # print(f'j = {j}')
            if pos_tuple[0] > 0 and i == 0:
                ne_position_mappings.insert(0, (None, 'O', 0, pos_tuple[0]))
                j += 1
            if begin_end_pos[i-1][1] != begin_end_pos[i][0] and len(begin_end_pos) > 1 and i > 0:
                ne_position_mappings.insert(j, (None, 'O', begin_end_pos[i-1][1], begin_end_pos[i][0]))
                j += 1

            j += 1
        print('ne_position_mappings', ne_position_mappings)

        groups = []
        k = 0
        for i, ne_position_mapping in enumerate(ne_position_mappings):
            if type(ne_position_mapping) != tuple:
                ne_position_mapping = ne_position_mapping.to_tuple()
            ne = ne_position_mapping[1]

            text = ''
            for ne_position in range(ne_position_mapping[2], ne_position_mapping[3]):
                _token = tokens[ne_position]
                text += _token if _token != self.space_token else ' '
            if ne.lower() != 'o':
                groups.append({
                    'entity_group': ne,
                    'word': text,
                    'begin_char_index': begin_end_char_pos[k][0]
                })
                k += 1
            else:
                groups.append({
                    'entity_group': ne,
                    'word': text,
                })
        return groups
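
A small illustration of the strict=False tag repair above (a sketch, not part of the commit; p stands for an already constructed TokenClassificationPipeline with the default tag_delimiter='-'):

# I-/E- tags that have no valid preceding tag of the same entity type are
# promoted to B- so that seqeval's IOB2 scheme can group them into entities.
tags = ['I-LOC', 'I-LOC', 'O', 'E-PER']
fixed = p._fix_incorrect_tags(tags)
# fixed == ['B-LOC', 'I-LOC', 'O', 'B-PER']
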
requirements.txt
ADDED
@@ -0,0 +1,3 @@
gradio
git+https://github.com/vistec-ai/thai2transformers.git@feature/add_ner_scheme
pythainlp==2.2.4
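
To try the pipeline locally outside the Gradio app, a minimal smoke test (a sketch, assuming the dependencies above are installed and the model revision can be downloaded from the Hugging Face Hub; it mirrors the construction in app.py):

from transformers import AutoModelForTokenClassification, AutoTokenizer
from pipeline import TokenClassificationPipeline

pipe = TokenClassificationPipeline(
    model=AutoModelForTokenClassification.from_pretrained(
        'airesearch/wangchanberta-base-att-spm-uncased', revision='finetuned@thainer-ner'),
    tokenizer=AutoTokenizer.from_pretrained(
        'airesearch/wangchanberta-base-att-spm-uncased', revision='finetuned@thainer-ner'),
    space_token='<_>', lowercase=True, group_entities=True, strict=False,
)
# Prints a list of {'entity_group', 'word', ...} dicts for the input sentence.
print(pipe('ชัชชาติ สิทธิพันธุ์ ผู้ว่าราชการกรุงเทพมหานคร'))
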