emanuelaboros committed on
Commit 7d92279 · verified · 1 Parent(s): 51adc94

Update generic_ner.py

Files changed (1)
  1. generic_ner.py +28 -14
generic_ner.py CHANGED
@@ -127,25 +127,40 @@ def get_entities(tokens, tags, confidences, text):

    return entities

-
def realign(
-    text_sentence, out_label_preds, softmax_scores, tokenizer, reverted_label_map
+    text_sentences, out_label_preds, softmax_scores, tokenizer, reverted_label_map
):
+    """
+    Realign predictions across multiple text chunks.
+
+    text_sentences: List of text chunks (the original text split into chunks)
+    out_label_preds: Predictions for each chunk
+    softmax_scores: Confidence scores for each chunk
+    tokenizer: The tokenizer used for encoding/decoding
+    reverted_label_map: Mapping from predicted labels to readable labels
+    """
    preds_list, words_list, confidence_list = [], [], []
-    word_ids = tokenizer(text_sentence, is_split_into_words=True).word_ids()
-    for idx, word in enumerate(text_sentence):
-        beginning_index = word_ids.index(idx)
-        try:
-            preds_list.append(reverted_label_map[out_label_preds[beginning_index]])
-            confidence_list.append(max(softmax_scores[beginning_index]))
-        except Exception as ex:  # the sentence was longer then max_length
-            preds_list.append("O")
-            confidence_list.append(0.0)
-        words_list.append(word)
+
+    # Process each chunk individually
+    for chunk_idx, text_sentence in enumerate(text_sentences):
+        word_ids = tokenizer(text_sentence, is_split_into_words=True).word_ids()
+
+        for idx, word in enumerate(text_sentence):
+            try:
+                # Align based on word indices within the current chunk
+                beginning_index = word_ids.index(idx)
+                preds_list.append(reverted_label_map[out_label_preds[chunk_idx][beginning_index]])
+                confidence_list.append(max(softmax_scores[chunk_idx][beginning_index]))
+            except Exception as ex:  # Handle any misalignment issues
+                preds_list.append("O")
+                confidence_list.append(0.0)
+
+            words_list.append(word)

    return words_list, preds_list, confidence_list


+
def segment_and_trim_sentences(article, language, max_length):

    try:
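The reworked realign() now walks over a list of chunks and indexes predictions per chunk, keeping for each word the prediction attached to its first sub-token (hence word_ids().index(idx)). Below is a minimal sketch of how it is meant to be called, assuming the updated realign() above is in scope; the _StubEncoding/_StubTokenizer classes and the sample chunks, label ids and scores are illustrative stand-ins mimicking a Hugging Face fast tokenizer's word_ids() behaviour, not part of the repository.

class _StubEncoding:
    def __init__(self, ids):
        self._ids = ids

    def word_ids(self):
        # Same contract as a fast-tokenizer encoding: one entry per sub-token,
        # None for special tokens, otherwise the index of the source word.
        return self._ids


class _StubTokenizer:
    # Hypothetical stand-in: pretends every word becomes exactly one sub-token,
    # wrapped in special tokens, so word_ids() is [None, 0, 1, ..., None].
    def __call__(self, words, is_split_into_words=True):
        return _StubEncoding([None] + list(range(len(words))) + [None])


text_sentences = [["Albert", "Einstein"], ["born", "in", "Ulm"]]   # two chunks of pre-split words
out_label_preds = [[0, 1, 2, 0], [0, 0, 0, 0, 0]]                  # one label id per sub-token, per chunk
softmax_scores = [
    [[0.9, 0.1, 0.0], [0.1, 0.8, 0.1], [0.2, 0.1, 0.7], [0.9, 0.1, 0.0]],
    [[0.9, 0.05, 0.05]] * 5,
]
reverted_label_map = {0: "O", 1: "B-pers", 2: "I-pers"}            # hypothetical label set

words, preds, confs = realign(
    text_sentences, out_label_preds, softmax_scores, _StubTokenizer(), reverted_label_map
)
print(list(zip(words, preds, confs)))
# [('Albert', 'B-pers', 0.8), ('Einstein', 'I-pers', 0.7),
#  ('born', 'O', 0.9), ('in', 'O', 0.9), ('Ulm', 'O', 0.9)]
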
@@ -248,14 +263,12 @@ class MultitaskTokenClassificationPipeline(Pipeline):
            for task, logits in chunk_result.logits.items():
                predictions[task].extend(torch.argmax(logits, dim=-1).tolist())
                confidence_scores[task].extend(F.softmax(logits, dim=-1).tolist())
-        print(predictions)
        # Decode and process the predictions
        decoded_predictions = {}
        for task, preds in predictions.items():
            decoded_predictions[task] = [
                [self.id2label[task][label] for label in seq] for seq in preds
            ]
-        print(decoded_predictions)
        # Extract entities from the combined predictions
        entities = {}
        for task, preds in predictions.items():
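For context, a rough, self-contained sketch (not repo code) of the shapes this decoding step assumes: per task, logits of shape (num_chunks, seq_len, num_labels), with argmax giving one label id per sub-token and softmax one score per label. The task name "ner_coarse" and the id2label mapping are hypothetical.

import torch
import torch.nn.functional as F

logits = {"ner_coarse": torch.randn(2, 6, 3)}                  # hypothetical task: 2 chunks x 6 sub-tokens x 3 labels
id2label = {"ner_coarse": {0: "O", 1: "B-pers", 2: "I-pers"}}  # hypothetical label inventory

# Same pattern as in the hunk above: label ids and per-label scores per sub-token
predictions = {task: torch.argmax(t, dim=-1).tolist() for task, t in logits.items()}
confidence_scores = {task: F.softmax(t, dim=-1).tolist() for task, t in logits.items()}

# Decoding mirrors the loop above: map each label id back to its tag string
decoded_predictions = {
    task: [[id2label[task][label] for label in seq] for seq in preds]
    for task, preds in predictions.items()
}

print(decoded_predictions["ner_coarse"])           # two sequences of six tag strings
print(len(confidence_scores["ner_coarse"][0][0]))  # 3: one softmax score per label
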
@@ -266,6 +279,7 @@ class MultitaskTokenClassificationPipeline(Pipeline):
                self.tokenizer,
                self.id2label[task],
            )
+            print(words_list, preds_list, confidence_list)
            entities[task] = get_entities(words_list, preds_list, confidence_list, text)

        return entities
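The new debug print exposes what get_entities() consumes: three word-aligned lists of equal length (realign() appends exactly one prediction and one confidence per word, even in its fallback branch) plus the original text. A hedged illustration with made-up values, not repo code:

words_list = ["Albert", "Einstein", "was", "born", "in", "Ulm"]
preds_list = ["B-pers", "I-pers", "O", "O", "O", "B-loc"]   # hypothetical tags
confidence_list = [0.98, 0.97, 0.99, 0.99, 0.99, 0.95]      # hypothetical scores

assert len(words_list) == len(preds_list) == len(confidence_list)
for word, tag, conf in zip(words_list, preds_list, confidence_list):
    print(f"{word}\t{tag}\t{conf:.2f}")
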
 