Update generic_ner.py
generic_ner.py  CHANGED  (+28 -14)
@@ -127,25 +127,40 @@ def get_entities(tokens, tags, confidences, text):

    return entities

-
def realign(
-
+    text_sentences, out_label_preds, softmax_scores, tokenizer, reverted_label_map
):
+    """
+    Realign predictions across multiple text chunks.
+
+    text_sentences: List of text chunks (the original text split into chunks)
+    out_label_preds: Predictions for each chunk
+    softmax_scores: Confidence scores for each chunk
+    tokenizer: The tokenizer used for encoding/decoding
+    reverted_label_map: Mapping from predicted labels to readable labels
+    """
    preds_list, words_list, confidence_list = [], [], []
-
-
-
-
-
-
-
-
-
-
+
+    # Process each chunk individually
+    for chunk_idx, text_sentence in enumerate(text_sentences):
+        word_ids = tokenizer(text_sentence, is_split_into_words=True).word_ids()
+
+        for idx, word in enumerate(text_sentence):
+            try:
+                # Align based on word indices within the current chunk
+                beginning_index = word_ids.index(idx)
+                preds_list.append(reverted_label_map[out_label_preds[chunk_idx][beginning_index]])
+                confidence_list.append(max(softmax_scores[chunk_idx][beginning_index]))
+            except Exception as ex:  # Handle any misalignment issues
+                preds_list.append("O")
+                confidence_list.append(0.0)
+
+            words_list.append(word)

    return words_list, preds_list, confidence_list


+
def segment_and_trim_sentences(article, language, max_length):

    try:
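The new realign() body leans on the word_ids() mapping that HuggingFace fast tokenizers attach to an encoding: it keeps the prediction at the first subword position of each original word. As a minimal sketch of that mapping (the checkpoint name and example words are illustrative, not from this repo):

from transformers import AutoTokenizer

# Illustrative checkpoint; any fast tokenizer with word_ids() support behaves the same.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

words = ["Transformers", "tokenize", "subwords"]
encoding = tokenizer(words, is_split_into_words=True)

# word_ids() maps each subword position back to the index of the source word,
# with None for special tokens such as [CLS] and [SEP].
word_ids = encoding.word_ids()
print(encoding.tokens())
print(word_ids)  # e.g. [None, 0, 0, 1, 1, 2, 2, None]

# realign() keeps the prediction at the first subword position of each word:
for idx, word in enumerate(words):
    print(word, "->", word_ids.index(idx))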
@@ -248,14 +263,12 @@ class MultitaskTokenClassificationPipeline(Pipeline):

        for task, logits in chunk_result.logits.items():
            predictions[task].extend(torch.argmax(logits, dim=-1).tolist())
            confidence_scores[task].extend(F.softmax(logits, dim=-1).tolist())
-        print(predictions)
        # Decode and process the predictions
        decoded_predictions = {}
        for task, preds in predictions.items():
            decoded_predictions[task] = [
                [self.id2label[task][label] for label in seq] for seq in preds
            ]
-        print(decoded_predictions)
        # Extract entities from the combined predictions
        entities = {}
        for task, preds in predictions.items():
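For reference, the decoding step kept above pairs an argmax over the logits (label ids) with a softmax over the same logits (per-label probabilities), then maps ids to tag names through id2label. A self-contained sketch with made-up shapes and an illustrative label set, not the repo's actual one:

import torch
import torch.nn.functional as F

id2label = {0: "O", 1: "B-PER", 2: "I-PER"}       # illustrative mapping
logits = torch.randn(1, 4, len(id2label))          # (batch, seq_len, num_labels)

pred_ids = torch.argmax(logits, dim=-1).tolist()   # label ids per token position
confidences = F.softmax(logits, dim=-1).tolist()   # per-label probabilities

decoded = [[id2label[i] for i in seq] for seq in pred_ids]
print(decoded)                 # e.g. [['O', 'B-PER', 'I-PER', 'O']]
print(max(confidences[0][0]))  # confidence of the top label for the first position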
@@ -266,6 +279,7 @@ class MultitaskTokenClassificationPipeline(Pipeline):
                self.tokenizer,
                self.id2label[task],
            )
+            print(words_list, preds_list, confidence_list)
            entities[task] = get_entities(words_list, preds_list, confidence_list, text)

        return entities
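To see how the new realign() signature is meant to be driven before its output is passed to get_entities(), here is a hedged usage sketch with dummy chunk data. It assumes generic_ner.py is importable from the working directory; the checkpoint, label map, and scores are made up for illustration:

from transformers import AutoTokenizer
from generic_ner import realign  # the function added in this commit

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # illustrative checkpoint

text_sentences = [["Berlin", "is", "nice"]]   # a single pre-tokenized chunk
reverted_label_map = {0: "O", 1: "B-LOC"}     # illustrative id -> label map

encoding = tokenizer(text_sentences[0], is_split_into_words=True)
word_ids = encoding.word_ids()
n_positions = len(word_ids)

out_label_preds = [[0] * n_positions]
out_label_preds[0][word_ids.index(0)] = 1      # pretend the model tagged "Berlin" as B-LOC
softmax_scores = [[[0.1, 0.9]] * n_positions]  # dummy per-position probabilities

words, preds, confs = realign(
    text_sentences, out_label_preds, softmax_scores, tokenizer, reverted_label_map
)
print(list(zip(words, preds, confs)))
# e.g. [('Berlin', 'B-LOC', 0.9), ('is', 'O', 0.9), ('nice', 'O', 0.9)]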
|