Update generic_ner.py
Browse files- generic_ner.py +2 -2
generic_ner.py
CHANGED
@@ -234,8 +234,8 @@ class MultitaskTokenClassificationPipeline(Pipeline):
|
|
234 |
predictions = {}
|
235 |
confidence_scores = {}
|
236 |
for task, logits in tokens_result.logits.items():
|
237 |
-
predictions[task] = torch.argmax(logits, dim=-1).tolist()
|
238 |
-
confidence_scores[task] = F.softmax(logits, dim=-1).tolist()
|
239 |
|
240 |
decoded_predictions = {}
|
241 |
for task, preds in predictions.items():
|
|
|
234 |
predictions = {}
|
235 |
confidence_scores = {}
|
236 |
for task, logits in tokens_result.logits.items():
|
237 |
+
predictions[task] = torch.argmax(logits, dim=-1).tolist()[0]
|
238 |
+
confidence_scores[task] = F.softmax(logits, dim=-1).tolist()[0]
|
239 |
|
240 |
decoded_predictions = {}
|
241 |
for task, preds in predictions.items():
|