mp-02 commited on
Commit
bf30ca5
1 Parent(s): 45ab6df

Update sroie_inference.py

Browse files
Files changed (1) hide show
  1. sroie_inference.py +4 -14
sroie_inference.py CHANGED
@@ -5,8 +5,8 @@ from PIL import Image, ImageDraw, ImageFont
5
  from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
6
  from utils import OCR, unnormalize_box
7
 
8
-
9
- labels = ["O", "B-COMPANY", "I-COMPANY", "B-DATE", "I-DATE", "B-ADDRESS", "I-ADDRESS", "B-TOTAL", "I-TOTAL"]
10
  id2label = {v: k for v, k in enumerate(labels)}
11
  label2id = {k: v for v, k in enumerate(labels)}
12
 
@@ -83,19 +83,9 @@ def prediction(image):
83
  d[i] = d[i] + ", " + true_words[id]
84
  d = {k: v.strip() for (k, v) in d.items()}
85
 
86
- keys_to_pop = []
87
- for k, v in d.items():
88
- if k[:2] == "I-":
89
- d["B-" + k[2:]] = d["B-" + k[2:]] + ", " + v
90
- keys_to_pop.append(k)
91
-
92
  if "O" in d: d.pop("O")
93
- if "B-TOTAL" in d: d.pop("B-TOTAL")
94
- for k in keys_to_pop: d.pop(k)
95
-
96
- for k in d.keys():
97
- k = k[2:]
98
-
99
  blur_boxes = []
100
  for prediction, box in zip(true_predictions, true_boxes):
101
  if prediction != 'O' and prediction[2:] != 'TOTAL':
 
5
  from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
6
  from utils import OCR, unnormalize_box
7
 
8
+ # ["O", "B-COMPANY", "I-COMPANY", "B-DATE", "I-DATE", "B-ADDRESS", "I-ADDRESS", "B-TOTAL", "I-TOTAL"]
9
+ labels = ["O", "COMPANY", "COMPANY", "DATE", "DATE", "ADDRESS", "ADDRESS", "TOTAL", "TOTAL"]
10
  id2label = {v: k for v, k in enumerate(labels)}
11
  label2id = {k: v for v, k in enumerate(labels)}
12
 
 
83
  d[i] = d[i] + ", " + true_words[id]
84
  d = {k: v.strip() for (k, v) in d.items()}
85
 
 
 
 
 
 
 
86
  if "O" in d: d.pop("O")
87
+ if "TOTAL" in d: d.pop("TOTAL")
88
+
 
 
 
 
89
  blur_boxes = []
90
  for prediction, box in zip(true_predictions, true_boxes):
91
  if prediction != 'O' and prediction[2:] != 'TOTAL':