Szczotar93 commited on
Commit
327def7
1 Parent(s): cc75d9a

Upload create_handler.ipynb

Browse files
Files changed (1) hide show
  1. create_handler.ipynb +223 -0
create_handler.ipynb ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "## 1. Setup & Installation"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "!apt install -y tesseract-ocr\n",
17
+ "pip install pytesseract"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "markdown",
22
+ "metadata": {},
23
+ "source": [
24
+ "## 2. Create Custom Handler for Inference Endpoints\n"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": 20,
30
+ "metadata": {},
31
+ "outputs": [
32
+ {
33
+ "name": "stdout",
34
+ "output_type": "stream",
35
+ "text": [
36
+ "Overwriting handler.py\n"
37
+ ]
38
+ }
39
+ ],
40
+ "source": [
41
+ "%%writefile handler.py\n",
42
+ "from typing import Dict, List, Any\n",
43
+ "from transformers import LayoutLMForTokenClassification, LayoutLMv2Processor\n",
44
+ "import torch\n",
45
+ "from subprocess import run\n",
46
+ "\n",
47
+ "# install tesseract-ocr and pytesseract\n",
48
+ "run(\"apt install -y tesseract-ocr\", shell=True, check=True)\n",
49
+ "run(\"pip install pytesseract\", shell=True, check=True)\n",
50
+ "\n",
51
+ "# helper function to unnormalize bboxes for drawing onto the image\n",
52
+ "def unnormalize_box(bbox, width, height):\n",
53
+ " return [\n",
54
+ " width * (bbox[0] / 1000),\n",
55
+ " height * (bbox[1] / 1000),\n",
56
+ " width * (bbox[2] / 1000),\n",
57
+ " height * (bbox[3] / 1000),\n",
58
+ " ]\n",
59
+ "\n",
60
+ "\n",
61
+ "# set device\n",
62
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
63
+ "\n",
64
+ "\n",
65
+ "class EndpointHandler:\n",
66
+ " def __init__(self, path=\"\"):\n",
67
+ " # load model and processor from path\n",
68
+ " self.model = LayoutLMForTokenClassification.from_pretrained(\"philschmid/layoutlm-funsd\").to(device)\n",
69
+ " self.processor = LayoutLMv2Processor.from_pretrained(\"philschmid/layoutlm-funsd\")\n",
70
+ "\n",
71
+ " def __call__(self, data: Dict[str, bytes]) -> Dict[str, List[Any]]:\n",
72
+ " \"\"\"\n",
73
+ " Args:\n",
74
+ " data (:obj:):\n",
75
+ " includes the deserialized image file as PIL.Image\n",
76
+ " \"\"\"\n",
77
+ " # process input\n",
78
+ " image = data.pop(\"inputs\", data)\n",
79
+ "\n",
80
+ " # process image\n",
81
+ " encoding = self.processor(image, return_tensors=\"pt\")\n",
82
+ "\n",
83
+ " # run prediction\n",
84
+ " with torch.inference_mode():\n",
85
+ " outputs = self.model(\n",
86
+ " input_ids=encoding.input_ids.to(device),\n",
87
+ " bbox=encoding.bbox.to(device),\n",
88
+ " attention_mask=encoding.attention_mask.to(device),\n",
89
+ " token_type_ids=encoding.token_type_ids.to(device),\n",
90
+ " )\n",
91
+ " predictions = outputs.logits.softmax(-1)\n",
92
+ "\n",
93
+ " # post process output\n",
94
+ " result = []\n",
95
+ " for item, inp_ids, bbox in zip(\n",
96
+ " predictions.squeeze(0).cpu(), encoding.input_ids.squeeze(0).cpu(), encoding.bbox.squeeze(0).cpu()\n",
97
+ " ):\n",
98
+ " label = self.model.config.id2label[int(item.argmax().cpu())]\n",
99
+ " if label == \"O\":\n",
100
+ " continue\n",
101
+ " score = item.max().item()\n",
102
+ " text = self.processor.tokenizer.decode(inp_ids)\n",
103
+ " bbox = unnormalize_box(bbox.tolist(), image.width, image.height)\n",
104
+ " result.append({\"label\": label, \"score\": score, \"text\": text, \"bbox\": bbox})\n",
105
+ " return {\"predictions\": result}\n"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "markdown",
110
+ "metadata": {},
111
+ "source": [
112
+ "test custom pipeline"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": 2,
118
+ "metadata": {},
119
+ "outputs": [],
120
+ "source": [
121
+ "from handler import EndpointHandler\n",
122
+ "\n",
123
+ "my_handler = EndpointHandler(\".\")"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": 13,
129
+ "metadata": {},
130
+ "outputs": [
131
+ {
132
+ "name": "stdout",
133
+ "output_type": "stream",
134
+ "text": [
135
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
136
+ "To disable this warning, you can either:\n",
137
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
138
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
139
+ ]
140
+ }
141
+ ],
142
+ "source": [
143
+ "import base64\n",
144
+ "from PIL import Image\n",
145
+ "from io import BytesIO\n",
146
+ "import json\n",
147
+ "\n",
148
+ "# read image from disk\n",
149
+ "image = Image.open(\"invoice_example.png\")\n",
150
+ "request = {\"inputs\":image }\n",
151
+ "\n",
152
+ "# test the handler\n",
153
+ "pred = my_handler(request)"
154
+ ]
155
+ },
156
+ {
157
+ "cell_type": "code",
158
+ "execution_count": 16,
159
+ "metadata": {},
160
+ "outputs": [],
161
+ "source": [
162
+ "from PIL import Image, ImageDraw, ImageFont\n",
163
+ "\n",
164
+ "\n",
165
+ "def draw_result(image,result):\n",
166
+ " label2color = {\n",
167
+ " \"B-HEADER\": \"blue\",\n",
168
+ " \"B-QUESTION\": \"red\",\n",
169
+ " \"B-ANSWER\": \"green\",\n",
170
+ " \"I-HEADER\": \"blue\",\n",
171
+ " \"I-QUESTION\": \"red\",\n",
172
+ " \"I-ANSWER\": \"green\",\n",
173
+ " }\n",
174
+ "\n",
175
+ "\n",
176
+ " # draw predictions over the image\n",
177
+ " draw = ImageDraw.Draw(image)\n",
178
+ " font = ImageFont.load_default()\n",
179
+ " for res in result:\n",
180
+ " draw.rectangle(res[\"bbox\"], outline=\"black\")\n",
181
+ " draw.rectangle(res[\"bbox\"], outline=label2color[res[\"label\"]])\n",
182
+ " draw.text((res[\"bbox\"][0] + 10, res[\"bbox\"][1] - 10), text=res[\"label\"], fill=label2color[res[\"label\"]], font=font)\n",
183
+ " return image\n",
184
+ "\n",
185
+ "draw_result(image,pred[\"predictions\"])"
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": null,
191
+ "metadata": {},
192
+ "outputs": [],
193
+ "source": []
194
+ }
195
+ ],
196
+ "metadata": {
197
+ "kernelspec": {
198
+ "display_name": "Python 3.9.13 ('dev': conda)",
199
+ "language": "python",
200
+ "name": "python3"
201
+ },
202
+ "language_info": {
203
+ "codemirror_mode": {
204
+ "name": "ipython",
205
+ "version": 3
206
+ },
207
+ "file_extension": ".py",
208
+ "mimetype": "text/x-python",
209
+ "name": "python",
210
+ "nbconvert_exporter": "python",
211
+ "pygments_lexer": "ipython3",
212
+ "version": "3.9.13"
213
+ },
214
+ "orig_nbformat": 4,
215
+ "vscode": {
216
+ "interpreter": {
217
+ "hash": "f6dd96c16031089903d5a31ec148b80aeb0d39c32affb1a1080393235fbfa2fc"
218
+ }
219
+ }
220
+ },
221
+ "nbformat": 4,
222
+ "nbformat_minor": 2
223
+ }