mp-02 commited on
Commit
35e30e2
1 Parent(s): a369b59

Update sroie_inference.py

Browse files
Files changed (1) hide show
  1. sroie_inference.py +3 -4
sroie_inference.py CHANGED
@@ -5,15 +5,14 @@ from PIL import Image, ImageDraw, ImageFont
5
  from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
6
  from utils import OCR, unnormalize_box
7
 
8
- # [B-COMPANY", "I-COMPANY", "B-DATE", "I-DATE", "B-ADDRESS", "I-ADDRESS", "B-TOTAL", "I-TOTAL", "O"]
9
- labels = ["COMPANY", "COMPANY", "DATE", "DATE", "ADDRESS", "ADDRESS", "TOTAL", "TOTAL", "O"]
10
- id2label = {v: k for v, k in enumerate(labels)}
11
- label2id = {k: v for v, k in enumerate(labels)}
12
 
13
  tokenizer = LayoutLMv3TokenizerFast.from_pretrained("mp-02/layoutlmv3-finetuned-sroie", apply_ocr=False)
14
  processor = LayoutLMv3Processor.from_pretrained("mp-02/layoutlmv3-finetuned-sroie", apply_ocr=False)
15
  model = LayoutLMv3ForTokenClassification.from_pretrained("mp-02/layoutlmv3-finetuned-sroie")
16
 
 
 
 
17
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
  model.to(device)
19
 
 
5
  from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
6
  from utils import OCR, unnormalize_box
7
 
 
 
 
 
8
 
9
  tokenizer = LayoutLMv3TokenizerFast.from_pretrained("mp-02/layoutlmv3-finetuned-sroie", apply_ocr=False)
10
  processor = LayoutLMv3Processor.from_pretrained("mp-02/layoutlmv3-finetuned-sroie", apply_ocr=False)
11
  model = LayoutLMv3ForTokenClassification.from_pretrained("mp-02/layoutlmv3-finetuned-sroie")
12
 
13
+ id2label = model.config.id2label
14
+ label2id = model.config.label2id
15
+
16
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
  model.to(device)
18