Spaces:
Sleeping
Sleeping
File size: 3,832 Bytes
e41ca05 a369b59 e41ca05 53d7087 e41ca05 97095a3 e41ca05 97095a3 e41ca05 97095a3 e41ca05 f83550f e41ca05 4064ff6 e41ca05 ff7a14a e41ca05 4064ff6 e41ca05 bf30ca5 e41ca05 cbee1cf 4f4c048 e41ca05 d68e9cd f83550f e41ca05 b11d3ce f83550f e41ca05 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import torch
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from utils import OCR, unnormalize_box
# Per the original annotation, the checkpoint's tags in class-index order are:
#   ["B-COMPANY", "I-COMPANY", "B-DATE", "I-DATE", "B-ADDRESS", "I-ADDRESS", "B-TOTAL", "I-TOTAL", "O"]
# The B-/I- prefixes are dropped here so both tags of an entity share one display name.
labels = ["COMPANY", "COMPANY", "DATE", "DATE", "ADDRESS", "ADDRESS", "TOTAL", "TOTAL", "O"]
# Class index -> entity name.
id2label = dict(enumerate(labels))
# Entity name -> class index. Names repeat, so each name keeps its LAST index.
label2id = {name: index for index, name in enumerate(labels)}

_CHECKPOINT = "mp-02/layoutlmv3-finetuned-sroie"
# Words and boxes come from the project's own OCR helper, hence apply_ocr=False.
tokenizer = LayoutLMv3TokenizerFast.from_pretrained(_CHECKPOINT, apply_ocr=False)
processor = LayoutLMv3Processor.from_pretrained(_CHECKPOINT, apply_ocr=False)
model = LayoutLMv3ForTokenClassification.from_pretrained(_CHECKPOINT)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
def blur(image, boxes, kernel_size=201):
    """Return a copy of *image* with a Gaussian blur applied inside each box.

    Args:
        image: a PIL image, or anything ``np.array`` converts into an
            ``(H, W)`` or ``(H, W, C)`` pixel array.
        boxes: iterable of ``(x0, y0, x1, y1)`` pixel boxes. Float coordinates
            are truncated with ``int``. Each box is clipped to the image
            bounds, and boxes with no area inside the image are skipped.
        kernel_size: side length of the square Gaussian kernel. OpenCV
            requires a positive odd value. The default 201 keeps the previous
            behaviour.

    Returns:
        A new ``PIL.Image.Image``. The input image is left unmodified.
    """
    # np.array copies, so writing into `pixels` never touches the caller's image.
    pixels = np.array(image)
    height, width = pixels.shape[:2]
    for box in boxes:
        left = int(box[0])
        top = int(box[1])
        # Same end-coordinate arithmetic as before: origin + truncated extent.
        right = left + int(box[2] - box[0])
        bottom = top + int(box[3] - box[1])
        # Clip to the image. Negative indices would otherwise wrap around to
        # the opposite edge, and an empty region makes cv2.GaussianBlur raise.
        left, top = max(left, 0), max(top, 0)
        right, bottom = min(right, width), min(bottom, height)
        if right <= left or bottom <= top:
            continue
        region = pixels[top:bottom, left:right]
        pixels[top:bottom, left:right] = cv2.GaussianBlur(region, (kernel_size, kernel_size), 0)
    # Let Pillow infer the mode from the array. This gives "RGB" for HxWx3
    # uint8 data (the previous forced mode), and RGBA/greyscale input is
    # reconstructed correctly instead of being misread as RGB.
    return Image.fromarray(pixels)
def _group_entities(word_labels, words):
    """Join the words of each predicted entity label into one string.

    Args:
        word_labels: one label per word (e.g. "COMPANY", "O").
        words: decoded word texts, aligned with *word_labels*.

    Returns:
        dict mapping each label except "O" and "TOTAL" to its words. Each word
        is stripped of surrounding whitespace, and the words are joined with
        ", ". Labels appear in first-occurrence order.
    """
    grouped = {}
    for label, word in zip(word_labels, words):
        # Strip per word: decoded tokens keep a leading space, which would
        # otherwise produce "A,  B" (double space) after joining.
        grouped.setdefault(label, []).append(word.strip())
    # "O" is background. TOTAL is excluded here, as it is from blurring.
    grouped.pop("O", None)
    grouped.pop("TOTAL", None)
    return {label: ", ".join(parts) for label, parts in grouped.items()}


def _annotate(image, word_labels, word_boxes):
    """Blur every entity box except TOTAL, then outline and caption every word box.

    Returns the annotated image produced by ``blur`` (a new image).
    """
    to_blur = [box for label, box in zip(word_labels, word_boxes) if label not in ("O", "TOTAL")]
    annotated = blur(image, to_blur)
    draw = ImageDraw.Draw(annotated, "RGBA")
    font = ImageFont.load_default()
    for label, box in zip(word_labels, word_boxes):
        draw.rectangle(box)
        # The caption sits just right of and above the box's top-left corner.
        # The former font_size="8" argument was dropped: Pillow ignores
        # font_size whenever an explicit font is passed.
        draw.text((box[0] + 10, box[1] - 10), text=label, font=font, fill="black")
    return annotated


def prediction(image, min_confidence=None):
    """Extract receipt fields from *image*, blur them, and draw the word predictions.

    Args:
        image: a PIL image (``image.size`` is unpacked as ``(width, height)``).
            It is passed to ``OCR`` and to the LayoutLMv3 processor.
        min_confidence: optional probability threshold. A word whose top-class
            softmax probability is below it is relabelled "O", so it is neither
            reported nor blurred. ``None`` (default) keeps every prediction.

    Returns:
        ``(entities, annotated)``:
        - ``entities`` maps each predicted label other than "O" and "TOTAL" to
          its words joined with ", ".
        - ``annotated`` is the image with those words' boxes blurred, and every
          word box outlined and captioned with its label.

    Note:
        The processor is called with ``truncation=True``, so tokens beyond the
        model's maximum sequence length are dropped without warning.
    """
    boxes, words = OCR(image)
    encoding = processor(
        image, words, boxes=boxes,
        return_offsets_mapping=True, return_tensors="pt", truncation=True,
    )
    # Popped so it is not forwarded to the model; it stays on the CPU for sub-word detection.
    offset_mapping = encoding.pop('offset_mapping')
    for key, value in encoding.items():
        encoding[key] = value.to(device)

    # Inference only: no_grad avoids building an autograd graph that is never used.
    with torch.no_grad():
        logits = model(**encoding).logits[0]  # the batch holds a single image
    predictions = logits.argmax(-1).tolist()
    confidences = torch.softmax(logits, dim=-1).max(-1).values.tolist()
    token_boxes = encoding["bbox"][0].tolist()
    token_texts = [tokenizer.decode(token_id) for token_id in encoding["input_ids"][0].tolist()]
    # A token whose character offset does not start at 0 continues the previous word.
    is_subword = (offset_mapping[0, :, 0] != 0).tolist()

    width, height = image.size
    word_labels, word_boxes, word_texts, word_confidences = [], [], [], []
    for text, pred, box, conf, subword in zip(token_texts, predictions, token_boxes, confidences, is_subword):
        if subword:
            word_texts[-1] += text
            continue
        # The first token of a word supplies the word's label, box and confidence.
        word_labels.append(id2label[pred])
        word_boxes.append(unnormalize_box(box, width, height))
        word_texts.append(text)
        word_confidences.append(conf)

    # Drop the first and last entries: the tokenizer's start/end special
    # tokens (presumably <s> and </s>).
    word_labels = word_labels[1:-1]
    word_boxes = word_boxes[1:-1]
    word_texts = word_texts[1:-1]
    word_confidences = word_confidences[1:-1]

    if min_confidence is not None:
        word_labels = [
            label if conf >= min_confidence else "O"
            for label, conf in zip(word_labels, word_confidences)
        ]

    entities = _group_entities(word_labels, word_texts)
    return entities, _annotate(image, word_labels, word_boxes)
|