Emanuela Boros committed
Commit c549c79 · 1 Parent(s): 4efcbf3

update handler

Files changed (1):
  generic_ner.py +48 -4
generic_ner.py CHANGED
@@ -2,8 +2,9 @@ from transformers import Pipeline
 import numpy as np
 import torch
 import nltk
-nltk.download('averaged_perceptron_tagger')
-nltk.download('averaged_perceptron_tagger_eng')
+
+nltk.download("averaged_perceptron_tagger")
+nltk.download("averaged_perceptron_tagger_eng")
 from nltk.chunk import conlltags2tree
 from nltk import pos_tag
 from nltk.tree import Tree
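
The two downloads above run unconditionally at import time. A minimal sketch (not part of this commit) that skips the call when the data is already installed, assuming NLTK's standard taggers/<name> data layout:

# Sketch only: probe the local NLTK data first; nltk.data.find raises
# LookupError when a resource is missing, so the download runs at most once.
import nltk

for resource in ("averaged_perceptron_tagger", "averaged_perceptron_tagger_eng"):
    try:
        nltk.data.find(f"taggers/{resource}")
    except LookupError:
        nltk.download(resource, quiet=True)
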
@@ -107,9 +108,13 @@ def get_entities(tokens, tags, confidences, text):
         entities.append(
             {
                 "entity": original_label,
-                "score": round(np.average(confidences[idx : idx + len(subtree)]) * 100, 2),
+                "score": round(
+                    np.average(confidences[idx : idx + len(subtree)]) * 100, 2
+                ),
                 "index": (idx, idx + len(subtree)),
-                "word": text[entity_start_position:entity_end_position], #original_string,
+                "word": text[
+                    entity_start_position:entity_end_position
+                ],  # original_string,
                 "start": entity_start_position,
                 "end": entity_end_position,
             }
@@ -221,6 +226,44 @@ class MultitaskTokenClassificationPipeline(Pipeline):
         outputs = self.model(input_ids, attention_mask)
         return outputs, text_sentences, text
 
+    def is_within(self, entity1, entity2):
+        """Check if entity1 is fully within the bounds of entity2."""
+        return entity1["start"] >= entity2["start"] and entity1["end"] <= entity2["end"]
+
+    def postprocess_entities(self, ner_results):
+        # Collect all entities in one list for processing
+        all_entities = []
+        for key in ner_results:
+            all_entities.extend(ner_results[key])
+
+        # Sort entities by start position, then by end position (to handle nested structures)
+        all_entities.sort(key=lambda x: (x["start"], -x["end"]))
+
+        # Create a new list for final processed entities
+        final_entities = []
+
+        # Process each entity and check for nesting
+        for i, entity in enumerate(all_entities):
+            nested = False
+
+            # Compare the current entity with already processed entities
+            for parent_entity in final_entities:
+                if self.is_within(entity, parent_entity):
+                    # If the current entity is nested, add it as a field in the parent entity
+                    field_name = entity["entity"].split(".")[
+                        -1
+                    ]  # Last part of the label as the field
+                    if field_name not in parent_entity:
+                        parent_entity[field_name] = []
+                    parent_entity[field_name].append(entity)
+                    nested = True
+                    break
+
+            if not nested:
+                # If not nested, add the entity as a new outermost entity
+                final_entities.append(entity)
+
+        return final_entities
 
     def postprocess(self, outputs, **kwargs):
         """
@@ -249,4 +292,5 @@ class MultitaskTokenClassificationPipeline(Pipeline):
 
         entities[task] = get_entities(words_list, preds_list, confidence_list, text)
 
+        print(self.postprocess_entities(entities))
         return entities
 