@@ -0,0 +1,146 @@
1 |
from typing import Dict, List, Any
2 |
from transformers import LayoutLMForTokenClassification, LayoutLMv2Processor
3 |
import torch
4 |
from subprocess import run
5 |
6 |
# install tesseract-ocr and pytesseract
7 |
# run("apt install -y tesseract-ocr", shell=True, check=True)
8 |
# Argument list instead of a shell string: pip is executed directly, no shell is
# spawned, and nothing in the command is subject to shell parsing.
run(["pip", "install", "pytesseract"], check=True)
9 |
10 |
# helper function to unnormalize bboxes for drawing onto the image
11 |
def unnormalize_box(bbox, width, height):
    """Scale a box from the 0-1000 grid to ``width`` x ``height`` units.

    NOTE: a second ``unnormalize_box`` defined later in this file truncates
    the coordinates to ints and replaces this definition once the module has
    finished importing.

    Args:
        bbox: Sequence of at least four numbers on a 0-1000 scale. Items 0
            and 2 are scaled by ``width``, items 1 and 3 by ``height``.
        width: Scale applied to items 0 and 2.
        height: Scale applied to items 1 and 3.

    Returns:
        A list of four floats.
    """
    return [
        width * (bbox[0] / 1000),
        height * (bbox[1] / 1000),
        width * (bbox[2] / 1000),
        height * (bbox[3] / 1000),
    ]
17 |
18 |
19 |
def predict(Image, processor, model):
    """Encode one document image, run the model, and post-process the output.

    Args:
        Image: The document image, passed to ``processor`` as ``images=``.
            Post-processing reads its ``.size``, so a PIL-style image is
            expected.
        processor: Callable that turns the image into model inputs (the
            handler in this file passes a ``LayoutLMv2Processor``).
        model: Token-classification model invoked as ``model(**encoding)``.

    Returns:
        A ``(results, encoding)`` tuple: whatever ``process_outputs`` returns
        and the encoding that was fed to the model.
    """
    encoding = processor(
        images=Image,
        return_tensors="pt",  # model(**encoding) needs tensors, not lists
        truncation=True,  # keep long documents within the model's max length
    )
    # LayoutLM doesn't require the image; tolerate processors that omit the key.
    encoding.pop("image", None)
    # The handler moves the model to `device`, so the inputs must follow it.
    encoding = encoding.to(model.device)
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.inference_mode():
        outputs = model(**encoding)
    results = process_outputs(
        outputs,
        encoding=encoding,
        images=Image,
        model=model,
        processor=processor,
        threshold=0.75,
    )
    return results, encoding
45 |
def get_uniqueLabelList(labels):
    """Return the distinct label names of the first sequence, in first-seen order.

    A label containing ``-`` (for example ``"B-HEADER"``) is reduced to the
    piece after its first dash (``"HEADER"``); a label without a dash (for
    example ``"O"``) is kept unchanged.

    Args:
        labels: Batch of label sequences. Only ``labels[0]`` is read.

    Returns:
        A list of unique label names; empty when ``labels`` is empty.
    """
    if not labels:
        return []
    unique_labels = []
    for label in labels[0]:
        pieces = label.split("-")
        name = pieces[1] if len(pieces) > 1 else label
        if name not in unique_labels:
            unique_labels.append(name)
    return unique_labels
58 |
59 |
def process_outputs(outputs, encoding, images, model, processor, threshold):
    """Convert raw model logits into labels and scores, then build the results.

    Args:
        outputs: Model output exposing ``.logits``.
        encoding: The processor output that was fed to the model.
        images: The input image, forwarded to ``_process_outputs``.
        model: Model whose ``config.id2label`` maps class ids to label names.
        processor: Processor forwarded to ``_process_outputs``; its
            ``tokenizer`` attribute is forwarded as well.
        threshold: Minimum score, forwarded to ``_process_outputs``.

    Returns:
        Whatever ``_process_outputs`` returns.
    """
    logits = outputs.logits
    # Highest class probability per token, as plain Python floats.
    scores = logits.softmax(dim=-1).max(dim=-1).values.tolist()
    predictions = logits.argmax(-1)
    id2label = model.config.id2label
    labels = [
        [id2label[class_id] for class_id in prediction.tolist()]
        for prediction in predictions
    ]
    return _process_outputs(
        encoding=encoding,
        tokenizer=processor.tokenizer,
        labels=labels,
        scores=scores,
        images=images,
        processor=processor,
        threshold=threshold,
    )
74 |
75 |
# Turn per-token labels and scores into per-label result entries.
#
# What the visible lines establish:
#   - `images.size` is unpacked into width/height, so `images` is one object
#     with a two-item `.size`, not a list of images.
#   - the distinct labels come from get_uniqueLabelList(labels), and
#     `label != "O"` is tested before the token scan.
#   - only index 0 of `labels`, `scores` and `encoding.input_ids` is read.
#   - `tokenizer` is referenced only in commented-out code.
# NOTE(review): the line-number gutter shows gaps inside this function (77, 93-96,
# 98, 101-102, 106-109, 111-113), so statements are missing from this view,
# including whatever fills `entite_wordsidx` and `results`. Confirm against the
# real file before relying on these notes.
def _process_outputs(encoding, tokenizer, labels, scores, images, processor, threshold):
76 |
results = []
77 |
78 |
# width/height, entities and previous_word_idx are not read by any visible line below
width, height = images.size
79 |
entities = []
80 |
previous_word_idx = 0
81 |
unique_lables = get_uniqueLabelList(labels)
82 |
# tokens = tokenizer.convert_ids_to_tokens(input_ids)
83 |
# word_ids = encoding.word_ids(batch_index=batch_idx)
84 |
# word = ""
85 |
entite_wordsidx = []
86 |
for idx, label in enumerate(unique_lables):
87 |
score_sum = float(0)
88 |
if label != "O":
89 |
for ix, pred in enumerate(labels[0]):
90 |
if scores[0][ix] > threshold:
91 |
# substring test: a short label also matches any longer label that contains it
if label in pred:
92 |
score_sum += scores[0][ix]
93 |
94 |
95 |
96 |
97 |
# a formatted string here, the float 0.0 just below: the "score" type differs by branch
score_mean = f'{score_sum/len(entite_wordsidx):.2f}'
98 |
99 |
score_mean = 0.0
100 |
# entite_wordsidx.append(entite_wordsidx[-1] + 1)
101 |
102 |
103 |
# decodes the input ids at the collected token positions back into text
"word": processor.decode(encoding.input_ids[0][entite_wordsidx]),
104 |
"label": unique_lables[idx],
105 |
"score": score_mean
106 |
107 |
108 |
109 |
110 |
entite_wordsidx = []
111 |
112 |
113 |
114 |
return results
115 |
116 |
def unnormalize_box(bbox, width, height):
    """Scale a box from the 0-1000 grid to integer ``width`` x ``height`` units.

    Args:
        bbox: Sequence of at least four numbers on a 0-1000 scale. Items 0
            and 2 are scaled by ``width``, items 1 and 3 by ``height``.
        width: Scale applied to items 0 and 2.
        height: Scale applied to items 1 and 3.

    Returns:
        A list of four ints; ``int()`` truncates each scaled value toward zero.
    """
    return [
        int(width * (bbox[0] / 1000)),
        int(height * (bbox[1] / 1000)),
        int(width * (bbox[2] / 1000)),
        int(height * (bbox[3] / 1000)),
    ]
122 |
123 |
# Meant to hand back an RGB image (per the trailing comment on the return line).
# NOTE(review): the return statement is truncated in this view; the expression
# before `"RGB")` is missing, so what is opened or converted cannot be confirmed here.
# No visible code calls this helper; it appears only in a commented-out line in predict().
def get_image_from_url(Image):
124 |
return"RGB") # LayoutLMv2Processor requires RGB format
125 |
# set device
126 |
# Use the GPU when torch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
127 |
128 |
129 |
class EndpointHandler:
    """Inference entry point: loads the model once, then serves predictions."""

    def __init__(self, path=""):
        """Load the token-classification model and its processor from ``path``.

        Args:
            path: Location handed to both ``from_pretrained`` calls.
        """
        # load model and processor from path
        self.model = LayoutLMForTokenClassification.from_pretrained(path).to(device)
        self.processor = LayoutLMv2Processor.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Any]]:
        """Run prediction on one request payload.

        Args:
            data (:obj:):
                includes the deserialized image file as PIL.Image under the
                ``"inputs"`` key. The key is popped from ``data``; when it is
                absent, ``data`` itself is used as the input.

        Returns:
            ``{"predictions": result}`` where ``result`` is the first item of
            the tuple returned by ``predict``.
        """
        # process input
        image = data.pop("inputs", data)
        # predict() also returns the encoding; only the results are served.
        result, _ = predict(image, self.processor, self.model)
        return {"predictions": result}
Binary file (6.83 kB). View file
@@ -0,0 +1,145 @@
1 |
from typing import Dict, List, Any
2 |
from transformers import LayoutLMForTokenClassification, LayoutLMv2Processor
3 |
import torch
4 |
from subprocess import run
5 |
6 |
# install tesseract-ocr and pytesseract
7 |
# Argument lists instead of shell strings: each tool is executed directly, no
# shell is spawned, and nothing in the command is subject to shell parsing.
run(["apt", "install", "-y", "tesseract-ocr"], check=True)
run(["pip", "install", "pytesseract"], check=True)
9 |
10 |
# helper function to unnormalize bboxes for drawing onto the image
11 |
def unnormalize_box(bbox, width, height):
    """Scale a box from the 0-1000 grid to ``width`` x ``height`` units.

    NOTE: a second ``unnormalize_box`` defined later in this file truncates
    the coordinates to ints and replaces this definition once the module has
    finished importing.

    Args:
        bbox: Sequence of at least four numbers on a 0-1000 scale. Items 0
            and 2 are scaled by ``width``, items 1 and 3 by ``height``.
        width: Scale applied to items 0 and 2.
        height: Scale applied to items 1 and 3.

    Returns:
        A list of four floats.
    """
    return [
        width * (bbox[0] / 1000),
        height * (bbox[1] / 1000),
        width * (bbox[2] / 1000),
        height * (bbox[3] / 1000),
    ]
17 |
18 |
19 |
def predict(Image, processor, model):
    """Encode one document image, run the model, and post-process the output.

    Args:
        Image: The document image, passed to ``processor`` as ``images=``.
            Post-processing reads its ``.size``, so a PIL-style image is
            expected.
        processor: Callable that turns the image into model inputs (the
            handler in this file passes a ``LayoutLMv2Processor``).
        model: Token-classification model invoked as ``model(**encoding)``.

    Returns:
        A ``(results, encoding)`` tuple: whatever ``process_outputs`` returns
        and the encoding that was fed to the model.
    """
    encoding = processor(
        images=Image,
        return_tensors="pt",  # model(**encoding) needs tensors, not lists
        truncation=True,  # keep long documents within the model's max length
    )
    # LayoutLM doesn't require the image; tolerate processors that omit the key.
    encoding.pop("image", None)
    # The handler moves the model to `device`, so the inputs must follow it.
    encoding = encoding.to(model.device)
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.inference_mode():
        outputs = model(**encoding)
    results = process_outputs(
        outputs,
        encoding=encoding,
        images=Image,
        model=model,
        processor=processor,
        threshold=0.75,
    )
    return results, encoding
45 |
def get_uniqueLabelList(labels):
    """Return the distinct label names of the first sequence, in first-seen order.

    A label containing ``-`` (for example ``"B-HEADER"``) is reduced to the
    piece after its first dash (``"HEADER"``); a label without a dash (for
    example ``"O"``) is kept unchanged.

    Args:
        labels: Batch of label sequences. Only ``labels[0]`` is read.

    Returns:
        A list of unique label names; empty when ``labels`` is empty.
    """
    if not labels:
        return []
    unique_labels = []
    for label in labels[0]:
        pieces = label.split("-")
        name = pieces[1] if len(pieces) > 1 else label
        if name not in unique_labels:
            unique_labels.append(name)
    return unique_labels
58 |
59 |
def process_outputs(outputs, encoding, images, model, processor, threshold):
    """Convert raw model logits into labels and scores, then build the results.

    Args:
        outputs: Model output exposing ``.logits``.
        encoding: The processor output that was fed to the model.
        images: The input image, forwarded to ``_process_outputs``.
        model: Model whose ``config.id2label`` maps class ids to label names.
        processor: Processor forwarded to ``_process_outputs``; its
            ``tokenizer`` attribute is forwarded as well.
        threshold: Minimum score, forwarded to ``_process_outputs``.

    Returns:
        Whatever ``_process_outputs`` returns.
    """
    logits = outputs.logits
    # Highest class probability per token, as plain Python floats.
    scores = logits.softmax(dim=-1).max(dim=-1).values.tolist()
    predictions = logits.argmax(-1)
    id2label = model.config.id2label
    labels = [
        [id2label[class_id] for class_id in prediction.tolist()]
        for prediction in predictions
    ]
    return _process_outputs(
        encoding=encoding,
        tokenizer=processor.tokenizer,
        labels=labels,
        scores=scores,
        images=images,
        processor=processor,
        threshold=threshold,
    )
74 |
75 |
# Turn per-token labels and scores into per-label result entries.
#
# What the visible lines establish:
#   - `images.size` is unpacked into width/height, so `images` is one object
#     with a two-item `.size`, not a list of images.
#   - the distinct labels come from get_uniqueLabelList(labels), and
#     `label != "O"` is tested before the token scan.
#   - only index 0 of `labels`, `scores` and `encoding.input_ids` is read.
#   - `tokenizer` is referenced only in commented-out code.
# NOTE(review): the line-number gutter shows gaps inside this function (77, 93-96,
# 98, 101-102, 106-109, 111-113), so statements are missing from this view,
# including whatever fills `entite_wordsidx` and `results`. Confirm against the
# real file before relying on these notes.
def _process_outputs(encoding, tokenizer, labels, scores, images, processor, threshold):
76 |
results = []
77 |
78 |
# width/height, entities and previous_word_idx are not read by any visible line below
width, height = images.size
79 |
entities = []
80 |
previous_word_idx = 0
81 |
unique_lables = get_uniqueLabelList(labels)
82 |
# tokens = tokenizer.convert_ids_to_tokens(input_ids)
83 |
# word_ids = encoding.word_ids(batch_index=batch_idx)
84 |
# word = ""
85 |
entite_wordsidx = []
86 |
for idx, label in enumerate(unique_lables):
87 |
score_sum = float(0)
88 |
if label != "O":
89 |
for ix, pred in enumerate(labels[0]):
90 |
if scores[0][ix] > threshold:
91 |
# substring test: a short label also matches any longer label that contains it
if label in pred:
92 |
score_sum += scores[0][ix]
93 |
94 |
95 |
96 |
97 |
# a formatted string here, the float 0.0 just below: the "score" type differs by branch
score_mean = f'{score_sum/len(entite_wordsidx):.2f}'
98 |
99 |
score_mean = 0.0
100 |
# entite_wordsidx.append(entite_wordsidx[-1] + 1)
101 |
102 |
103 |
# decodes the input ids at the collected token positions back into text
"word": processor.decode(encoding.input_ids[0][entite_wordsidx]),
104 |
"label": unique_lables[idx],
105 |
"score": score_mean
106 |
107 |
108 |
109 |
110 |
entite_wordsidx = []
111 |
112 |
113 |
114 |
return results
115 |
116 |
def unnormalize_box(bbox, width, height):
    """Scale a box from the 0-1000 grid to integer ``width`` x ``height`` units.

    Args:
        bbox: Sequence of at least four numbers on a 0-1000 scale. Items 0
            and 2 are scaled by ``width``, items 1 and 3 by ``height``.
        width: Scale applied to items 0 and 2.
        height: Scale applied to items 1 and 3.

    Returns:
        A list of four ints; ``int()`` truncates each scaled value toward zero.
    """
    return [
        int(width * (bbox[0] / 1000)),
        int(height * (bbox[1] / 1000)),
        int(width * (bbox[2] / 1000)),
        int(height * (bbox[3] / 1000)),
    ]
122 |
123 |
# Meant to hand back an RGB image (per the trailing comment on the return line).
# NOTE(review): the return statement is truncated in this view; the expression
# before `"RGB")` is missing, so what is opened or converted cannot be confirmed here.
# No visible code calls this helper; it appears only in a commented-out line in predict().
def get_image_from_url(Image):
124 |
return"RGB") # LayoutLMv2Processor requires RGB format
125 |
# set device
126 |
# Use the GPU when torch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
127 |
128 |
129 |
class EndpointHandler:
    """Inference entry point: loads the model once, then serves predictions."""

    def __init__(self, path=""):
        """Load the token-classification model and its processor from ``path``.

        Args:
            path: Location handed to both ``from_pretrained`` calls.
        """
        # load model and processor from path
        self.model = LayoutLMForTokenClassification.from_pretrained(path).to(device)
        self.processor = LayoutLMv2Processor.from_pretrained(path, apply_ocr=True)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Any]]:
        """Run prediction on one request payload.

        Args:
            data (:obj:):
                includes the deserialized image file as PIL.Image under the
                ``"inputs"`` key. The key is popped from ``data``; when it is
                absent, ``data`` itself is used as the input.

        Returns:
            ``{"predictions": result}`` where ``result`` is the first item of
            the tuple returned by ``predict``.
        """
        # process input
        image = data.pop("inputs", data)
        # predict() also returns the encoding; only the results are served.
        result, _ = predict(image, self.processor, self.model)
        return {"predictions": result}