Emanuela Boros
commited on
Commit
·
c549c79
1
Parent(s):
4efcbf3
update handler
Browse files- generic_ner.py +48 -4
generic_ner.py
CHANGED
@@ -2,8 +2,9 @@ from transformers import Pipeline
|
|
2 |
import numpy as np
|
3 |
import torch
|
4 |
import nltk
|
5 |
-
|
6 |
-
nltk.download(
|
|
|
7 |
from nltk.chunk import conlltags2tree
|
8 |
from nltk import pos_tag
|
9 |
from nltk.tree import Tree
|
@@ -107,9 +108,13 @@ def get_entities(tokens, tags, confidences, text):
|
|
107 |
entities.append(
|
108 |
{
|
109 |
"entity": original_label,
|
110 |
-
"score": round(
|
|
|
|
|
111 |
"index": (idx, idx + len(subtree)),
|
112 |
-
"word": text[
|
|
|
|
|
113 |
"start": entity_start_position,
|
114 |
"end": entity_end_position,
|
115 |
}
|
@@ -221,6 +226,44 @@ class MultitaskTokenClassificationPipeline(Pipeline):
|
|
221 |
outputs = self.model(input_ids, attention_mask)
|
222 |
return outputs, text_sentences, text
|
223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
|
225 |
def postprocess(self, outputs, **kwargs):
|
226 |
"""
|
@@ -249,4 +292,5 @@ class MultitaskTokenClassificationPipeline(Pipeline):
|
|
249 |
|
250 |
entities[task] = get_entities(words_list, preds_list, confidence_list, text)
|
251 |
|
|
|
252 |
return entities
|
|
|
2 |
import numpy as np
|
3 |
import torch
|
4 |
import nltk
|
5 |
+
|
6 |
+
nltk.download("averaged_perceptron_tagger")
|
7 |
+
nltk.download("averaged_perceptron_tagger_eng")
|
8 |
from nltk.chunk import conlltags2tree
|
9 |
from nltk import pos_tag
|
10 |
from nltk.tree import Tree
|
|
|
108 |
entities.append(
|
109 |
{
|
110 |
"entity": original_label,
|
111 |
+
"score": round(
|
112 |
+
np.average(confidences[idx : idx + len(subtree)]) * 100, 2
|
113 |
+
),
|
114 |
"index": (idx, idx + len(subtree)),
|
115 |
+
"word": text[
|
116 |
+
entity_start_position:entity_end_position
|
117 |
+
], # original_string,
|
118 |
"start": entity_start_position,
|
119 |
"end": entity_end_position,
|
120 |
}
|
|
|
226 |
outputs = self.model(input_ids, attention_mask)
|
227 |
return outputs, text_sentences, text
|
228 |
|
229 |
+
def is_within(self, entity1, entity2):
|
230 |
+
"""Check if entity1 is fully within the bounds of entity2."""
|
231 |
+
return entity1["start"] >= entity2["start"] and entity1["end"] <= entity2["end"]
|
232 |
+
|
233 |
+
def postprocess_entities(self, ner_results):
|
234 |
+
# Collect all entities in one list for processing
|
235 |
+
all_entities = []
|
236 |
+
for key in ner_results:
|
237 |
+
all_entities.extend(ner_results[key])
|
238 |
+
|
239 |
+
# Sort entities by start position, then by end position (to handle nested structures)
|
240 |
+
all_entities.sort(key=lambda x: (x["start"], -x["end"]))
|
241 |
+
|
242 |
+
# Create a new list for final processed entities
|
243 |
+
final_entities = []
|
244 |
+
|
245 |
+
# Process each entity and check for nesting
|
246 |
+
for i, entity in enumerate(all_entities):
|
247 |
+
nested = False
|
248 |
+
|
249 |
+
# Compare the current entity with already processed entities
|
250 |
+
for parent_entity in final_entities:
|
251 |
+
if self.is_within(entity, parent_entity):
|
252 |
+
# If the current entity is nested, add it as a field in the parent entity
|
253 |
+
field_name = entity["entity"].split(".")[
|
254 |
+
-1
|
255 |
+
] # Last part of the label as the field
|
256 |
+
if field_name not in parent_entity:
|
257 |
+
parent_entity[field_name] = []
|
258 |
+
parent_entity[field_name].append(entity)
|
259 |
+
nested = True
|
260 |
+
break
|
261 |
+
|
262 |
+
if not nested:
|
263 |
+
# If not nested, add the entity as a new outermost entity
|
264 |
+
final_entities.append(entity)
|
265 |
+
|
266 |
+
return final_entities
|
267 |
|
268 |
def postprocess(self, outputs, **kwargs):
|
269 |
"""
|
|
|
292 |
|
293 |
entities[task] = get_entities(words_list, preds_list, confidence_list, text)
|
294 |
|
295 |
+
print(self.postprocess_entities(entities))
|
296 |
return entities
|