Update generic_ner.py
generic_ner.py (CHANGED: +13, -26)
@@ -128,39 +128,25 @@ def get_entities(tokens, tags, confidences, text):
     return entities
 
 def realign(
-    text_sentences, out_label_preds, softmax_scores, tokenizer, reverted_label_map
+    text_sentence, out_label_preds, softmax_scores, tokenizer, reverted_label_map
 ):
-    """
-    Realign predictions across multiple text chunks.
-
-    text_sentences: List of text chunks (the original text split into chunks)
-    out_label_preds: Predictions for each chunk
-    softmax_scores: Confidence scores for each chunk
-    tokenizer: The tokenizer used for encoding/decoding
-    reverted_label_map: Mapping from predicted labels to readable labels
-    """
     preds_list, words_list, confidence_list = [], [], []
-    …
-            confidence_list.append(max(softmax_scores[chunk_idx][beginning_index]))
-        except Exception as ex:  # Handle any misalignment issues
-            preds_list.append("O")
-            confidence_list.append(0.0)
-
-        words_list.append(word)
+    word_ids = tokenizer(text_sentence, is_split_into_words=True).word_ids()
+    for idx, word in enumerate(text_sentence):
+        beginning_index = word_ids.index(idx)
+        try:
+            preds_list.append(reverted_label_map[out_label_preds[beginning_index]])
+            confidence_list.append(max(softmax_scores[beginning_index]))
+        except Exception as ex:  # the sentence was longer than max_length
+            preds_list.append("O")
+            confidence_list.append(0.0)
+        words_list.append(word)
 
     return words_list, preds_list, confidence_list
 
 
 
+
 def segment_and_trim_sentences(article, language, max_length):
 
     try:
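Note on the new implementation: it relies on the fast-tokenizer word_ids() mapping, where word_ids.index(idx) returns the position of the first subword of word idx, so that subword's label and confidence stand in for the whole word. A minimal sketch of the mapping, assuming a Hugging Face fast tokenizer; the checkpoint name is an arbitrary stand-in for whatever model this pipeline actually loads:

from transformers import AutoTokenizer

# Any *fast* tokenizer exposes word_ids(); the checkpoint is illustrative.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

words = ["Geneva", "is", "unaffordable"]
encoding = tokenizer(words, is_split_into_words=True)

word_ids = encoding.word_ids()
# Maps each token position to its source word index, with None for the
# special tokens ([CLS]/[SEP]), e.g. [None, 0, 1, 2, 2, 2, None] if
# "unaffordable" splits into three subword pieces.
print(word_ids)

# list.index returns the FIRST match, i.e. the first subword of word 2;
# realign keeps the label and confidence at exactly this position for
# the whole word, which is why truncated words (no entry in word_ids)
# fall through to the except branch and get "O" with confidence 0.0.
beginning_index = word_ids.index(2)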
@@ -271,6 +257,7 @@ class MultitaskTokenClassificationPipeline(Pipeline):
         ]
         # Extract entities from the combined predictions
         entities = {}
+        print(predictions)
         for task, preds in predictions.items():
             words_list, preds_list, confidence_list = realign(
                 text_chunks,
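The second hunk cuts off after realign's first argument; a sketch of how the loop plausibly continues, with every argument after text_chunks inferred from realign's new signature rather than taken from this diff:

entities = {}
print(predictions)  # debug output added by this commit
for task, preds in predictions.items():
    # Assumption: the per-task preds feeds out_label_preds; softmax_scores,
    # tokenizer and reverted_label_map come from the pipeline's own state.
    words_list, preds_list, confidence_list = realign(
        text_chunks, preds, softmax_scores, tokenizer, reverted_label_map
    )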