Spaces:
Running
Running
Update sroie_inference.py
Browse files- sroie_inference.py +9 -4
sroie_inference.py
CHANGED
@@ -46,6 +46,8 @@ def prediction(image):
|
|
46 |
|
47 |
predictions = outputs.logits.argmax(-1).squeeze().tolist()
|
48 |
token_boxes = encoding.bbox.squeeze().tolist()
|
|
|
|
|
49 |
|
50 |
inp_ids = encoding.input_ids.squeeze().tolist()
|
51 |
inp_words = [tokenizer.decode(i) for i in inp_ids]
|
@@ -55,6 +57,7 @@ def prediction(image):
|
|
55 |
|
56 |
true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
|
57 |
true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
|
|
|
58 |
true_words = []
|
59 |
|
60 |
for id, i in enumerate(inp_words):
|
@@ -66,16 +69,18 @@ def prediction(image):
|
|
66 |
true_predictions = true_predictions[1:-1]
|
67 |
true_boxes = true_boxes[1:-1]
|
68 |
true_words = true_words[1:-1]
|
|
|
69 |
|
70 |
preds = []
|
71 |
l_words = []
|
72 |
bboxes = []
|
73 |
|
74 |
for i, j in enumerate(true_predictions):
|
75 |
-
if
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
79 |
|
80 |
d = {}
|
81 |
for id, i in enumerate(preds):
|
|
|
46 |
|
47 |
predictions = outputs.logits.argmax(-1).squeeze().tolist()
|
48 |
token_boxes = encoding.bbox.squeeze().tolist()
|
49 |
+
probabilities = torch.softmax(outputs.logits, dim=-1)
|
50 |
+
confidence_scores = probabilities.max(-1).values.squeeze().tolist()
|
51 |
|
52 |
inp_ids = encoding.input_ids.squeeze().tolist()
|
53 |
inp_words = [tokenizer.decode(i) for i in inp_ids]
|
|
|
57 |
|
58 |
true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
|
59 |
true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
|
60 |
+
true_confidence_scores = [confidence_scores[idx] for idx, conf in enumerate(confidence_scores) if not is_subword[idx]]
|
61 |
true_words = []
|
62 |
|
63 |
for id, i in enumerate(inp_words):
|
|
|
69 |
true_predictions = true_predictions[1:-1]
|
70 |
true_boxes = true_boxes[1:-1]
|
71 |
true_words = true_words[1:-1]
|
72 |
+
true_confidence_scores = true_confidence_scores[1:-1]
|
73 |
|
74 |
preds = []
|
75 |
l_words = []
|
76 |
bboxes = []
|
77 |
|
78 |
for i, j in enumerate(true_predictions):
|
79 |
+
if true_confidence_scores[i] < 0.9: #####################################àà
|
80 |
+
true_predictions[i] = "O"
|
81 |
+
preds.append(true_predictions[i])
|
82 |
+
l_words.append(true_words[i])
|
83 |
+
bboxes.append(true_boxes[i])
|
84 |
|
85 |
d = {}
|
86 |
for id, i in enumerate(preds):
|