LayoutLMv3_for_recepits / cord_inference.py
mp-02's picture
Upload folder using huggingface_hub
6d1caf6 verified
raw
history blame
4.07 kB
import torch
import numpy as np
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from PIL import Image, ImageDraw, ImageFont
from utils import OCR, unnormalize_box
# Label vocabulary used by this script; a label's position in the list is its
# class id. NOTE(review): presumably mirrors the label order stored in the
# fine-tuned checkpoint's config -- confirm before editing the order.
labels = [
    "O",
    # "B-" tags
    "B-MENU.NM", "B-MENU.NUM", "B-MENU.UNITPRICE", "B-MENU.CNT",
    "B-MENU.DISCOUNTPRICE", "B-MENU.PRICE", "B-MENU.ITEMSUBTOTAL",
    "B-MENU.VATYN", "B-MENU.ETC",
    "B-MENU.SUB.NM", "B-MENU.SUB.UNITPRICE", "B-MENU.SUB.CNT",
    "B-MENU.SUB.PRICE", "B-MENU.SUB.ETC",
    "B-VOID_MENU.NM", "B-VOID_MENU.PRICE",
    "B-SUB_TOTAL.SUBTOTAL_PRICE", "B-SUB_TOTAL.DISCOUNT_PRICE",
    "B-SUB_TOTAL.SERVICE_PRICE", "B-SUB_TOTAL.OTHERSVC_PRICE",
    "B-SUB_TOTAL.TAX_PRICE", "B-SUB_TOTAL.ETC",
    "B-TOTAL.TOTAL_PRICE", "B-TOTAL.TOTAL_ETC", "B-TOTAL.CASHPRICE",
    "B-TOTAL.CHANGEPRICE", "B-TOTAL.CREDITCARDPRICE", "B-TOTAL.EMONEYPRICE",
    "B-TOTAL.MENUTYPE_CNT", "B-TOTAL.MENUQTY_CNT",
    # "I-" tags, in the same field order as the "B-" tags
    "I-MENU.NM", "I-MENU.NUM", "I-MENU.UNITPRICE", "I-MENU.CNT",
    "I-MENU.DISCOUNTPRICE", "I-MENU.PRICE", "I-MENU.ITEMSUBTOTAL",
    "I-MENU.VATYN", "I-MENU.ETC",
    "I-MENU.SUB.NM", "I-MENU.SUB.UNITPRICE", "I-MENU.SUB.CNT",
    "I-MENU.SUB.PRICE", "I-MENU.SUB.ETC",
    "I-VOID_MENU.NM", "I-VOID_MENU.PRICE",
    "I-SUB_TOTAL.SUBTOTAL_PRICE", "I-SUB_TOTAL.DISCOUNT_PRICE",
    "I-SUB_TOTAL.SERVICE_PRICE", "I-SUB_TOTAL.OTHERSVC_PRICE",
    "I-SUB_TOTAL.TAX_PRICE", "I-SUB_TOTAL.ETC",
    "I-TOTAL.TOTAL_PRICE", "I-TOTAL.TOTAL_ETC", "I-TOTAL.CASHPRICE",
    "I-TOTAL.CHANGEPRICE", "I-TOTAL.CREDITCARDPRICE", "I-TOTAL.EMONEYPRICE",
    "I-TOTAL.MENUTYPE_CNT", "I-TOTAL.MENUQTY_CNT",
]

# Two-way lookup between class ids and label strings.
id2label = dict(enumerate(labels))
label2id = {name: index for index, name in enumerate(labels)}

# Use the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Tokenizer, processor and model are all loaded from the same checkpoint.
# apply_ocr=False because words and boxes are supplied by the caller's own OCR step.
tokenizer = LayoutLMv3TokenizerFast.from_pretrained(
    "mp-02/layoutlmv3-finetuned-cord",
    apply_ocr=False,
)
processor = LayoutLMv3Processor.from_pretrained(
    "mp-02/layoutlmv3-finetuned-cord",
    apply_ocr=False,
)
model = LayoutLMv3ForTokenClassification.from_pretrained(
    "mp-02/layoutlmv3-finetuned-cord"
)
model.to(device)
def prediction(image_path: str) -> "tuple[dict[str, str], Image.Image]":
    """Classify the words of one receipt image and draw the detected fields.

    The image is OCR'd with ``OCR``, encoded with the module-level
    ``processor``, and classified by the module-level ``model``. Sub-word
    tokens are merged back into the word they belong to, the two special
    tokens at the ends of the sequence are dropped, and every word whose
    predicted label is not ``"O"`` is kept.

    Args:
        image_path: Path of the image file to analyse. It is opened with PIL
            and also handed to ``OCR``.

    Returns:
        A ``(fields, image)`` tuple. ``fields`` maps each predicted label
        string (including its ``B-``/``I-`` prefix) to the text of all words
        carrying that label, joined with ``", "`` in reading order of the
        token sequence. ``image`` is the RGB copy of the input with a
        rectangle and the label text drawn for every kept word.
    """
    # Convert inside a context manager so the file handle opened by PIL is
    # released as soon as the RGB copy exists.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    # NOTE(review): OCR presumably returns boxes already normalised to the
    # coordinate range the processor expects -- confirm in utils.OCR.
    boxes, words = OCR(image_path)

    encoding = processor(
        image,
        words,
        boxes=boxes,
        return_offsets_mapping=True,
        return_tensors="pt",
        truncation=True,
    )
    # The offsets are only needed for sub-word detection; the model does not
    # accept them as an input.
    offset_mapping = encoding.pop("offset_mapping")
    inputs = {name: tensor.to(device) for name, tensor in encoding.items()}

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # A single image is processed, so index 0 selects the only batch entry.
    predicted_ids = outputs.logits.argmax(-1)[0].tolist()
    token_boxes = inputs["bbox"][0].tolist()
    token_texts = [tokenizer.decode(token_id) for token_id in inputs["input_ids"][0].tolist()]
    token_offsets = offset_mapping[0].tolist()
    width, height = image.size

    # Collapse tokens into words. A token whose character offset does not
    # start at 0 continues the previous word: its text is appended to that
    # word, and its own label and box are ignored.
    word_labels = []
    word_boxes = []
    word_texts = []
    for label_id, box, text, (start, _end) in zip(
        predicted_ids, token_boxes, token_texts, token_offsets
    ):
        if start != 0:
            word_texts[-1] += text
            continue
        word_labels.append(id2label[label_id])
        word_boxes.append(unnormalize_box(box, width, height))
        word_texts.append(text)

    # Drop the first and last entries, which belong to the special tokens
    # that wrap the sequence.
    word_labels = word_labels[1:-1]
    word_boxes = word_boxes[1:-1]
    word_texts = word_texts[1:-1]

    # Keep only words tagged as a field. "O" is the background label in
    # `labels`; the vocabulary has no 'others' entry, so comparing against
    # that string would never filter anything.
    kept = [
        (label, text, box)
        for label, text, box in zip(word_labels, word_texts, word_boxes)
        if label != "O"
    ]

    # Group the kept words by label, preserving first-seen label order.
    grouped = {}
    for label, text, _box in kept:
        grouped.setdefault(label, []).append(text)
    d = {label: ", ".join(parts).strip() for label, parts in grouped.items()}

    # TODO: process the json

    draw = ImageDraw.Draw(image, "RGBA")
    font = ImageFont.load_default()
    for label, _text, box in kept:
        draw.rectangle(box)
        # Offset the caption slightly so it sits at the box's top-left corner.
        draw.text((box[0] + 10, box[1] - 10), text=label, font=font, fill="black")

    return d, image