import torch
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from utils import OCR, unnormalize_box

# Model label set (BIO tags): ["B-COMPANY", "I-COMPANY", "B-DATE", "I-DATE", "B-ADDRESS", "I-ADDRESS", "B-TOTAL", "I-TOTAL", "O"]
# The B-/I- prefixes are collapsed here so both tags of an entity map to the same display label.
labels = ["COMPANY", "COMPANY", "DATE", "DATE", "ADDRESS", "ADDRESS", "TOTAL", "TOTAL", "O"]
id2label = {idx: label for idx, label in enumerate(labels)}
label2id = {label: idx for idx, label in enumerate(labels)}

tokenizer = LayoutLMv3TokenizerFast.from_pretrained("mp-02/layoutlmv3-finetuned-sroie", apply_ocr=False)
processor = LayoutLMv3Processor.from_pretrained("mp-02/layoutlmv3-finetuned-sroie", apply_ocr=False)
model = LayoutLMv3ForTokenClassification.from_pretrained("mp-02/layoutlmv3-finetuned-sroie")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)


def blur(image, boxes):
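    # Blur each region given as an (x0, y0, x1, y1) pixel box and return a new RGB PIL image.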
    image = np.array(image)
    for box in boxes:

        blur_x = int(box[0])
        blur_y = int(box[1])
        blur_width = int(box[2]-box[0])
        blur_height = int(box[3]-box[1])

        roi = image[blur_y:blur_y + blur_height, blur_x:blur_x + blur_width]
        blur_image = cv2.GaussianBlur(roi, (201, 201), 0)
        image[blur_y:blur_y + blur_height, blur_x:blur_x + blur_width] = blur_image

    return Image.fromarray(image, 'RGB')


def prediction(image):
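    # Run OCR on the receipt, classify each token with LayoutLMv3, blur the boxes of
    # the sensitive entities (everything except O and TOTAL), and return the extracted
    # fields as a dict together with the annotated image.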
    boxes, words = OCR(image)
    encoding = processor(image, words, boxes=boxes, return_offsets_mapping=True, return_tensors="pt", truncation=True)
    offset_mapping = encoding.pop('offset_mapping')

    for k, v in encoding.items():
        encoding[k] = v.to(device)

    with torch.no_grad():
        outputs = model(**encoding)

    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    token_boxes = encoding.bbox.squeeze().tolist()
    probabilities = torch.softmax(outputs.logits, dim=-1)
    confidence_scores = probabilities.max(-1).values.squeeze().tolist()

    inp_ids = encoding.input_ids.squeeze().tolist()
    inp_words = [tokenizer.decode(i) for i in inp_ids]

    width, height = image.size
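    # A token whose offset mapping does not start at character 0 is a subword
    # continuation; only the first token of each word is kept below.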
    is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0

    true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
    true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
    true_confidence_scores = [conf for idx, conf in enumerate(confidence_scores) if not is_subword[idx]]
    true_words = []

    # Stitch subword pieces back together into full words.
    for idx, word in enumerate(inp_words):
        if not is_subword[idx]:
            true_words.append(word)
        else:
            true_words[-1] = true_words[-1] + word

    # Drop the <s> and </s> special tokens at the start and end of the sequence.
    true_predictions = true_predictions[1:-1]
    true_boxes = true_boxes[1:-1]
    true_words = true_words[1:-1]
    true_confidence_scores = true_confidence_scores[1:-1]

    # Optional confidence threshold (disabled):
    # for i, j in enumerate(true_confidence_scores):
    #     if j < 0.8:
    #         true_predictions[i] = "O"

    # Collect the recognised words under their predicted entity label.
    d = {}
    for idx, label in enumerate(true_predictions):
        if label not in d:
            d[label] = true_words[idx]
        else:
            d[label] = d[label] + ", " + true_words[idx]
    d = {k: v.strip() for (k, v) in d.items()}

    # Non-entity tokens and the total are not returned, and their boxes are not blurred.
    if "O" in d:
        d.pop("O")
    if "TOTAL" in d:
        d.pop("TOTAL")

    blur_boxes = []
    for pred, box in zip(true_predictions, true_boxes):
        if pred != 'O' and pred != 'TOTAL':
            blur_boxes.append(box)

    image = blur(image, blur_boxes)

    draw = ImageDraw.Draw(image, "RGBA")
    font = ImageFont.load_default()

    for pred, box in zip(true_predictions, true_boxes):
        # An outline colour is needed for the rectangle to actually be visible.
        draw.rectangle(box, outline="red")
        draw.text((box[0] + 10, box[1] - 10), text=pred, font=font, fill="black")

    return d, image
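

# Minimal usage sketch (illustrative only): "receipt.jpg" is a placeholder path, and
# utils.OCR is assumed to return (boxes, words) with boxes normalised to the 0-1000
# range expected by the LayoutLMv3 processor (they are unnormalised again for display).
if __name__ == "__main__":
    img = Image.open("receipt.jpg").convert("RGB")
    fields, annotated = prediction(img)
    print(fields)
    annotated.save("receipt_annotated.jpg")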