Spaces:
Running
Running
Update sroie_inference.py
Browse files- sroie_inference.py +3 -4
sroie_inference.py
CHANGED
@@ -5,15 +5,14 @@ from PIL import Image, ImageDraw, ImageFont
|
|
5 |
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
|
6 |
from utils import OCR, unnormalize_box
|
7 |
|
8 |
-
# [B-COMPANY", "I-COMPANY", "B-DATE", "I-DATE", "B-ADDRESS", "I-ADDRESS", "B-TOTAL", "I-TOTAL", "O"]
|
9 |
-
labels = ["COMPANY", "COMPANY", "DATE", "DATE", "ADDRESS", "ADDRESS", "TOTAL", "TOTAL", "O"]
|
10 |
-
id2label = {v: k for v, k in enumerate(labels)}
|
11 |
-
label2id = {k: v for v, k in enumerate(labels)}
|
12 |
|
13 |
tokenizer = LayoutLMv3TokenizerFast.from_pretrained("mp-02/layoutlmv3-finetuned-sroie", apply_ocr=False)
|
14 |
processor = LayoutLMv3Processor.from_pretrained("mp-02/layoutlmv3-finetuned-sroie", apply_ocr=False)
|
15 |
model = LayoutLMv3ForTokenClassification.from_pretrained("mp-02/layoutlmv3-finetuned-sroie")
|
16 |
|
|
|
|
|
|
|
17 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
18 |
model.to(device)
|
19 |
|
|
|
5 |
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
|
6 |
from utils import OCR, unnormalize_box
|
7 |
|
|
|
|
|
|
|
|
|
8 |
|
9 |
tokenizer = LayoutLMv3TokenizerFast.from_pretrained("mp-02/layoutlmv3-finetuned-sroie", apply_ocr=False)
|
10 |
processor = LayoutLMv3Processor.from_pretrained("mp-02/layoutlmv3-finetuned-sroie", apply_ocr=False)
|
11 |
model = LayoutLMv3ForTokenClassification.from_pretrained("mp-02/layoutlmv3-finetuned-sroie")
|
12 |
|
13 |
+
id2label = model.config.id2label
|
14 |
+
label2id = model.config.label2id
|
15 |
+
|
16 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
17 |
model.to(device)
|
18 |
|