Obai33 commited on
Commit
d199483
·
verified ·
1 Parent(s): 6ba62be

Upload app.ipynb

Browse files
Files changed (1) hide show
  1. app.ipynb +570 -0
app.ipynb ADDED
@@ -0,0 +1,570 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 3,
6
+ "id": "1c550b9b-ab70-46a5-a584-29a0d3ae31ee",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "Defaulting to user installation because normal site-packages is not writeable\n",
14
+ "Collecting gradio\n",
15
+ " Using cached gradio-4.41.0-py3-none-any.whl.metadata (15 kB)\n",
16
+ "Collecting aiofiles<24.0,>=22.0 (from gradio)\n",
17
+ " Using cached aiofiles-23.2.1-py3-none-any.whl.metadata (9.7 kB)\n",
18
+ "Requirement already satisfied: anyio<5.0,>=3.0 in /home/obai33/.local/lib/python3.10/site-packages (from gradio) (4.4.0)\n",
19
+ "Collecting fastapi (from gradio)\n",
20
+ " Using cached fastapi-0.112.1-py3-none-any.whl.metadata (27 kB)\n",
21
+ "Collecting ffmpy (from gradio)\n",
22
+ " Using cached ffmpy-0.4.0-py3-none-any.whl.metadata (2.9 kB)\n",
23
+ "Collecting gradio-client==1.3.0 (from gradio)\n",
24
+ " Using cached gradio_client-1.3.0-py3-none-any.whl.metadata (7.1 kB)\n",
25
+ "Requirement already satisfied: httpx>=0.24.1 in /home/obai33/.local/lib/python3.10/site-packages (from gradio) (0.27.0)\n",
26
+ "Collecting huggingface-hub>=0.19.3 (from gradio)\n",
27
+ " Using cached huggingface_hub-0.24.5-py3-none-any.whl.metadata (13 kB)\n",
28
+ "Collecting importlib-resources<7.0,>=1.3 (from gradio)\n",
29
+ " Using cached importlib_resources-6.4.3-py3-none-any.whl.metadata (3.9 kB)\n",
30
+ "Requirement already satisfied: jinja2<4.0 in /home/obai33/.local/lib/python3.10/site-packages (from gradio) (3.1.4)\n",
31
+ "Requirement already satisfied: markupsafe~=2.0 in /home/obai33/.local/lib/python3.10/site-packages (from gradio) (2.1.5)\n",
32
+ "Requirement already satisfied: matplotlib~=3.0 in /usr/lib/python3/dist-packages (from gradio) (3.5.1)\n",
33
+ "Requirement already satisfied: numpy<3.0,>=1.0 in /home/obai33/.local/lib/python3.10/site-packages (from gradio) (1.26.4)\n",
34
+ "Collecting orjson~=3.0 (from gradio)\n",
35
+ " Using cached orjson-3.10.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (50 kB)\n",
36
+ "Requirement already satisfied: packaging in /home/obai33/.local/lib/python3.10/site-packages (from gradio) (24.0)\n",
37
+ "Requirement already satisfied: pandas<3.0,>=1.0 in /home/obai33/.local/lib/python3.10/site-packages (from gradio) (2.2.2)\n",
38
+ "Requirement already satisfied: pillow<11.0,>=8.0 in /home/obai33/.local/lib/python3.10/site-packages (from gradio) (10.3.0)\n",
39
+ "Collecting pydantic>=2.0 (from gradio)\n",
40
+ " Using cached pydantic-2.8.2-py3-none-any.whl.metadata (125 kB)\n",
41
+ "Collecting pydub (from gradio)\n",
42
+ " Using cached pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)\n",
43
+ "Collecting python-multipart>=0.0.9 (from gradio)\n",
44
+ " Using cached python_multipart-0.0.9-py3-none-any.whl.metadata (2.5 kB)\n",
45
+ "Requirement already satisfied: pyyaml<7.0,>=5.0 in /usr/lib/python3/dist-packages (from gradio) (5.4.1)\n",
46
+ "Collecting ruff>=0.2.2 (from gradio)\n",
47
+ " Using cached ruff-0.6.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (25 kB)\n",
48
+ "Collecting semantic-version~=2.0 (from gradio)\n",
49
+ " Using cached semantic_version-2.10.0-py2.py3-none-any.whl.metadata (9.7 kB)\n",
50
+ "Collecting tomlkit==0.12.0 (from gradio)\n",
51
+ " Using cached tomlkit-0.12.0-py3-none-any.whl.metadata (2.7 kB)\n",
52
+ "Collecting typer<1.0,>=0.12 (from gradio)\n",
53
+ " Using cached typer-0.12.4-py3-none-any.whl.metadata (15 kB)\n",
54
+ "Requirement already satisfied: typing-extensions~=4.0 in /home/obai33/.local/lib/python3.10/site-packages (from gradio) (4.12.1)\n",
55
+ "Requirement already satisfied: urllib3~=2.0 in /home/obai33/.local/lib/python3.10/site-packages (from gradio) (2.2.1)\n",
56
+ "Collecting uvicorn>=0.14.0 (from gradio)\n",
57
+ " Using cached uvicorn-0.30.6-py3-none-any.whl.metadata (6.6 kB)\n",
58
+ "Requirement already satisfied: fsspec in /home/obai33/.local/lib/python3.10/site-packages (from gradio-client==1.3.0->gradio) (2024.5.0)\n",
59
+ "Collecting websockets<13.0,>=10.0 (from gradio-client==1.3.0->gradio)\n",
60
+ " Using cached websockets-12.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)\n",
61
+ "Requirement already satisfied: idna>=2.8 in /home/obai33/.local/lib/python3.10/site-packages (from anyio<5.0,>=3.0->gradio) (3.7)\n",
62
+ "Requirement already satisfied: sniffio>=1.1 in /home/obai33/.local/lib/python3.10/site-packages (from anyio<5.0,>=3.0->gradio) (1.3.1)\n",
63
+ "Requirement already satisfied: exceptiongroup>=1.0.2 in /home/obai33/.local/lib/python3.10/site-packages (from anyio<5.0,>=3.0->gradio) (1.2.1)\n",
64
+ "Requirement already satisfied: certifi in /home/obai33/.local/lib/python3.10/site-packages (from httpx>=0.24.1->gradio) (2024.6.2)\n",
65
+ "Requirement already satisfied: httpcore==1.* in /home/obai33/.local/lib/python3.10/site-packages (from httpx>=0.24.1->gradio) (1.0.5)\n",
66
+ "Requirement already satisfied: h11<0.15,>=0.13 in /home/obai33/.local/lib/python3.10/site-packages (from httpcore==1.*->httpx>=0.24.1->gradio) (0.14.0)\n",
67
+ "Requirement already satisfied: filelock in /home/obai33/.local/lib/python3.10/site-packages (from huggingface-hub>=0.19.3->gradio) (3.14.0)\n",
68
+ "Requirement already satisfied: requests in /home/obai33/.local/lib/python3.10/site-packages (from huggingface-hub>=0.19.3->gradio) (2.32.3)\n",
69
+ "Requirement already satisfied: tqdm>=4.42.1 in /home/obai33/.local/lib/python3.10/site-packages (from huggingface-hub>=0.19.3->gradio) (4.66.4)\n",
70
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /home/obai33/.local/lib/python3.10/site-packages (from pandas<3.0,>=1.0->gradio) (2.9.0.post0)\n",
71
+ "Requirement already satisfied: pytz>=2020.1 in /usr/lib/python3/dist-packages (from pandas<3.0,>=1.0->gradio) (2022.1)\n",
72
+ "Requirement already satisfied: tzdata>=2022.7 in /home/obai33/.local/lib/python3.10/site-packages (from pandas<3.0,>=1.0->gradio) (2024.1)\n",
73
+ "Collecting annotated-types>=0.4.0 (from pydantic>=2.0->gradio)\n",
74
+ " Using cached annotated_types-0.7.0-py3-none-any.whl.metadata (15 kB)\n",
75
+ "Collecting pydantic-core==2.20.1 (from pydantic>=2.0->gradio)\n",
76
+ " Using cached pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)\n",
77
+ "Collecting click>=8.0.0 (from typer<1.0,>=0.12->gradio)\n",
78
+ " Using cached click-8.1.7-py3-none-any.whl.metadata (3.0 kB)\n",
79
+ "Collecting shellingham>=1.3.0 (from typer<1.0,>=0.12->gradio)\n",
80
+ " Using cached shellingham-1.5.4-py2.py3-none-any.whl.metadata (3.5 kB)\n",
81
+ "Requirement already satisfied: rich>=10.11.0 in /home/obai33/.local/lib/python3.10/site-packages (from typer<1.0,>=0.12->gradio) (13.7.1)\n",
82
+ "Collecting starlette<0.39.0,>=0.37.2 (from fastapi->gradio)\n",
83
+ " Using cached starlette-0.38.2-py3-none-any.whl.metadata (5.9 kB)\n",
84
+ "Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas<3.0,>=1.0->gradio) (1.16.0)\n",
85
+ "Requirement already satisfied: markdown-it-py>=2.2.0 in /home/obai33/.local/lib/python3.10/site-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio) (3.0.0)\n",
86
+ "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /home/obai33/.local/lib/python3.10/site-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio) (2.18.0)\n",
87
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /home/obai33/.local/lib/python3.10/site-packages (from requests->huggingface-hub>=0.19.3->gradio) (3.3.2)\n",
88
+ "Requirement already satisfied: mdurl~=0.1 in /home/obai33/.local/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich>=10.11.0->typer<1.0,>=0.12->gradio) (0.1.2)\n",
89
+ "Downloading gradio-4.41.0-py3-none-any.whl (12.6 MB)\n",
90
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.6/12.6 MB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
91
+ "\u001b[?25hDownloading gradio_client-1.3.0-py3-none-any.whl (318 kB)\n",
92
+ "Downloading tomlkit-0.12.0-py3-none-any.whl (37 kB)\n",
93
+ "Downloading aiofiles-23.2.1-py3-none-any.whl (15 kB)\n",
94
+ "Downloading huggingface_hub-0.24.5-py3-none-any.whl (417 kB)\n",
95
+ "Downloading importlib_resources-6.4.3-py3-none-any.whl (35 kB)\n",
96
+ "Downloading orjson-3.10.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (141 kB)\n",
97
+ "Downloading pydantic-2.8.2-py3-none-any.whl (423 kB)\n",
98
+ "Downloading pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.1 MB)\n",
99
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.1/2.1 MB\u001b[0m \u001b[31m8.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
100
+ "\u001b[?25hDownloading python_multipart-0.0.9-py3-none-any.whl (22 kB)\n",
101
+ "Downloading ruff-0.6.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (10.2 MB)\n",
102
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.2/10.2 MB\u001b[0m \u001b[31m7.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m \u001b[36m0:00:01\u001b[0m\n",
103
+ "\u001b[?25hDownloading semantic_version-2.10.0-py2.py3-none-any.whl (15 kB)\n",
104
+ "Downloading typer-0.12.4-py3-none-any.whl (47 kB)\n",
105
+ "Downloading uvicorn-0.30.6-py3-none-any.whl (62 kB)\n",
106
+ "Downloading fastapi-0.112.1-py3-none-any.whl (93 kB)\n",
107
+ "Downloading ffmpy-0.4.0-py3-none-any.whl (5.8 kB)\n",
108
+ "Downloading pydub-0.25.1-py2.py3-none-any.whl (32 kB)\n",
109
+ "Downloading annotated_types-0.7.0-py3-none-any.whl (13 kB)\n",
110
+ "Downloading click-8.1.7-py3-none-any.whl (97 kB)\n",
111
+ "Downloading shellingham-1.5.4-py2.py3-none-any.whl (9.8 kB)\n",
112
+ "Downloading starlette-0.38.2-py3-none-any.whl (72 kB)\n",
113
+ "Downloading websockets-12.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (130 kB)\n",
114
+ "Installing collected packages: pydub, websockets, tomlkit, shellingham, semantic-version, ruff, python-multipart, pydantic-core, orjson, importlib-resources, ffmpy, click, annotated-types, aiofiles, uvicorn, starlette, pydantic, huggingface-hub, typer, gradio-client, fastapi, gradio\n",
115
+ "Successfully installed aiofiles-23.2.1 annotated-types-0.7.0 click-8.1.7 fastapi-0.112.1 ffmpy-0.4.0 gradio-4.41.0 gradio-client-1.3.0 huggingface-hub-0.24.5 importlib-resources-6.4.3 orjson-3.10.7 pydantic-2.8.2 pydantic-core-2.20.1 pydub-0.25.1 python-multipart-0.0.9 ruff-0.6.1 semantic-version-2.10.0 shellingham-1.5.4 starlette-0.38.2 tomlkit-0.12.0 typer-0.12.4 uvicorn-0.30.6 websockets-12.0\n"
116
+ ]
117
+ }
118
+ ],
119
+ "source": [
120
+ "!pip install gradio"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": 23,
126
+ "id": "c850d5ae-5d43-45fb-91cc-084427440a97",
127
+ "metadata": {},
128
+ "outputs": [],
129
+ "source": [
130
+ "import torch\n",
131
+ "import torch.nn.functional as F\n",
132
+ "import torchvision\n",
133
+ "import matplotlib.pyplot as plt\n",
134
+ "import zipfile\n",
135
+ "import os\n",
136
+ "import gradio as gr\n",
137
+ "from PIL import Image\n"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "execution_count": 70,
143
+ "id": "c3d9ca4d-fb14-495e-b405-8a964ecc9a51",
144
+ "metadata": {},
145
+ "outputs": [],
146
+ "source": [
147
+ "CHARS = \"~=\" + \" abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789,.'-!?:;\\\"\"\n",
148
+ "BLANK = 0\n",
149
+ "PAD = 1\n",
150
+ "CHARS_DICT = {c: i for i, c in enumerate(CHARS)}\n",
151
+ "TEXTLEN = 30\n",
152
+ "\n",
153
+ "tokens_list = list(CHARS_DICT.keys())\n",
154
+ "silence_token = '|'\n",
155
+ "\n",
156
+ "if silence_token not in tokens_list:\n",
157
+ " tokens_list.append(silence_token)\n",
158
+ "\n",
159
+ "\n",
160
def fit_picture(img):
    """Convert a PIL image to a normalized 1x32x400 grayscale tensor.

    The image is scaled (aspect ratio preserved) to fit inside the fixed
    32x400 model input box, then right/bottom zero-padded to exactly
    that size.

    Args:
        img: PIL.Image with valid .width/.height.

    Returns:
        torch.Tensor of shape (1, 32, 400), normalized to roughly [-1, 1].
    """
    target_height = 32
    target_width = 400

    # Choose resize dimensions that fit inside the target box while
    # preserving the input's aspect ratio.
    aspect_ratio = img.width / img.height
    if aspect_ratio > (target_width / target_height):
        resize_width = target_width
        resize_height = int(target_width / aspect_ratio)
    else:
        resize_height = target_height
        resize_width = int(target_height * aspect_ratio)

    # BUG FIX: the original used a bare `transforms` name that is never
    # imported (only `torchvision` is) -> NameError; qualify explicitly.
    resize_transform = torchvision.transforms.Resize((resize_height, resize_width))

    # Pad only on the right/bottom up to the fixed model input size.
    padding_height = (target_height - resize_height) if target_height > resize_height else 0
    padding_width = (target_width - resize_width) if target_width > resize_width else 0
    pad_transform = torchvision.transforms.Pad((0, 0, padding_width, padding_height),
                                               fill=0, padding_mode='constant')

    # NOTE(review): Resize/Pad run after ToTensor/Normalize, so the pad
    # fill value 0 is in *normalized* units (mid-gray) — presumably
    # intentional; confirm against how the model was trained.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Grayscale(num_output_channels=1),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(0.5, 0.5),
        resize_transform,
        pad_transform,
    ])

    fin_img = transform(img)
    return fin_img
191
+ "\n",
192
def load_model(filename):
    """Restore recognizer and optimizer state from a checkpoint file.

    Expects a dict saved with keys "recognizer" and "optimizer".

    BUG FIX: pass map_location so a checkpoint saved on a CUDA machine
    can still be loaded on a CPU-only host (plain torch.load would raise
    when CUDA is unavailable).
    """
    data = torch.load(filename, map_location=DEVICE)
    recognizer.load_state_dict(data["recognizer"])
    optimizer.load_state_dict(data["optimizer"])
196
+ "\n",
197
def ctc_decode_sequence(seq, blank=0):
    """Collapse a raw CTC label sequence (best-path decoding).

    Removes repeated labels and blank tokens, per standard CTC decoding.

    BUG FIX: the original emitted a label only when the *next* differing
    label arrived, so the final run of the sequence was silently dropped
    unless it happened to be the magic label 66; it also crashed on an
    empty sequence via seq[-1]. The trailing non-blank label is now
    always emitted, and empty input yields an empty list. The blank
    index is parameterized (default 0, matching the module's BLANK).

    Args:
        seq: iterable of integer label indices (argmax per time step).
        blank: index of the CTC blank token.

    Returns:
        List of collapsed label indices.
    """
    ret = []
    prev = blank
    for x in seq:
        if prev != blank and prev != x:
            ret.append(prev)
        prev = x
    # Emit the final pending label; the original's `== 66` special case
    # covered only one specific trailing character.
    if prev != blank:
        ret.append(prev)
    return ret
208
+ "\n",
209
def ctc_decode(codes):
    """Decode a batch of CTC outputs: one collapsed sequence per sample.

    `codes` is (time, batch); iterating the transpose walks one
    per-sample label sequence at a time.
    """
    return [ctc_decode_sequence(column) for column in codes.T]
215
+ "\n",
216
+ "\n",
217
def decode_text(codes):
    """Map a sequence of label indices back to their characters."""
    return ''.join(CHARS[code] for code in codes)
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "code",
224
+ "execution_count": 65,
225
+ "id": "6722e370-e7df-4efe-aa9d-e9436d3cc08e",
226
+ "metadata": {},
227
+ "outputs": [
228
+ {
229
+ "name": "stdout",
230
+ "output_type": "stream",
231
+ "text": [
232
+ "Device: cuda\n"
233
+ ]
234
+ }
235
+ ],
236
+ "source": [
237
+ "class Residual(torch.nn.Module):\n",
238
+ " def __init__(self, in_channels, out_channels, stride, pdrop = 0.2):\n",
239
+ " super().__init__()\n",
240
+ " self.conv1 = torch.nn.Conv2d(in_channels, out_channels, 3, stride, 1)\n",
241
+ " self.bn1 = torch.nn.BatchNorm2d(out_channels)\n",
242
+ " self.conv2 = torch.nn.Conv2d(out_channels, out_channels, 3, 1, 1)\n",
243
+ " self.bn2 = torch.nn.BatchNorm2d(out_channels)\n",
244
+ " if in_channels != out_channels or stride != 1:\n",
245
+ " self.skip = torch.nn.Conv2d(in_channels, out_channels, 1, stride, 0)\n",
246
+ " else:\n",
247
+ " self.skip = torch.nn.Identity()\n",
248
+ " self.dropout = torch.nn.Dropout2d(pdrop)\n",
249
+ "\n",
250
+ " def forward(self, x):\n",
251
+ " y = torch.nn.functional.relu(self.bn1(self.conv1(x)))\n",
252
+ " y = torch.nn.functional.relu(self.bn2(self.conv2(y)) + self.skip(x))\n",
253
+ " y = self.dropout(y)\n",
254
+ " return y\n",
255
+ " \n",
256
class TextRecognizer(torch.nn.Module):
    """CRNN-style recognizer: residual CNN -> BiLSTM -> per-step logits.

    The convolutional stack reduces a (N, 1, 32, W) input to a
    128-channel map of height 1; the width dimension then serves as the
    time axis for the bidirectional LSTM. The output has shape
    (time, batch, labels), the layout expected by CTC loss.
    """

    def __init__(self, labels):
        super().__init__()
        # Height is halved five times (32 -> 1); width only twice, so
        # enough time steps survive for CTC.
        self.feature_extractor = torch.nn.Sequential(
            Residual(1, 32, 1),
            Residual(32, 32, 2),
            Residual(32, 32, 1),
            Residual(32, 64, 2),
            Residual(64, 64, 1),
            Residual(64, 128, (2, 1)),
            Residual(128, 128, 1),
            Residual(128, 128, (2, 1)),
            Residual(128, 128, (2, 1)),
        )
        self.recurrent = torch.nn.LSTM(128, 128, 1, bidirectional=True)
        self.output = torch.nn.Linear(256, labels)

    def forward(self, x):
        features = self.feature_extractor(x)   # (N, 128, 1, W')
        features = features.squeeze(2)         # drop unit height -> (N, 128, W')
        features = features.permute(2, 0, 1)   # time-major: (W', N, 128)
        recurrent_out, _ = self.recurrent(features)
        return self.output(recurrent_out)
280
+ "\n",
281
# Pick the fastest available device, then build the model and optimizer.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)
LR = 1e-3

recognizer = TextRecognizer(len(CHARS)).to(DEVICE)
optimizer = torch.optim.Adam(recognizer.parameters(), lr=LR)
288
+ ]
289
+ },
290
+ {
291
+ "cell_type": "code",
292
+ "execution_count": 75,
293
+ "id": "e61f1d87-4a82-4714-b4e1-33719064a735",
294
+ "metadata": {},
295
+ "outputs": [
296
+ {
297
+ "name": "stdout",
298
+ "output_type": "stream",
299
+ "text": [
300
+ "Running on local URL: http://127.0.0.1:7889\n",
301
+ "Running on public URL: https://e1090d81e4ea8bf190.gradio.live\n",
302
+ "\n",
303
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
304
+ ]
305
+ },
306
+ {
307
+ "data": {
308
+ "text/html": [
309
+ "<div><iframe src=\"https://e1090d81e4ea8bf190.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
310
+ ],
311
+ "text/plain": [
312
+ "<IPython.core.display.HTML object>"
313
+ ]
314
+ },
315
+ "metadata": {},
316
+ "output_type": "display_data"
317
+ },
318
+ {
319
+ "data": {
320
+ "text/plain": []
321
+ },
322
+ "execution_count": 75,
323
+ "metadata": {},
324
+ "output_type": "execute_result"
325
+ },
326
+ {
327
+ "name": "stderr",
328
+ "output_type": "stream",
329
+ "text": [
330
+ "Traceback (most recent call last):\n",
331
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/gradio/queueing.py\", line 536, in process_events\n",
332
+ " response = await route_utils.call_process_api(\n",
333
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/gradio/route_utils.py\", line 288, in call_process_api\n",
334
+ " output = await app.get_blocks().process_api(\n",
335
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/gradio/blocks.py\", line 1931, in process_api\n",
336
+ " result = await self.call_function(\n",
337
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/gradio/blocks.py\", line 1516, in call_function\n",
338
+ " prediction = await anyio.to_thread.run_sync( # type: ignore\n",
339
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/anyio/to_thread.py\", line 56, in run_sync\n",
340
+ " return await get_async_backend().run_sync_in_worker_thread(\n",
341
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py\", line 2177, in run_sync_in_worker_thread\n",
342
+ " return await future\n",
343
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py\", line 859, in run\n",
344
+ " result = context.run(func, *args)\n",
345
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/gradio/utils.py\", line 826, in wrapper\n",
346
+ " response = f(*args, **kwargs)\n",
347
+ " File \"/tmp/ipykernel_848/2152623987.py\", line 5, in ctc_read\n",
348
+ " imagefin = fit_picture(image)\n",
349
+ " File \"/tmp/ipykernel_848/2382022948.py\", line 19, in fit_picture\n",
350
+ " aspect_ratio = img.width / img.height\n",
351
+ "AttributeError: 'NoneType' object has no attribute 'width'\n",
352
+ "Traceback (most recent call last):\n",
353
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/gradio/queueing.py\", line 536, in process_events\n",
354
+ " response = await route_utils.call_process_api(\n",
355
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/gradio/route_utils.py\", line 288, in call_process_api\n",
356
+ " output = await app.get_blocks().process_api(\n",
357
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/gradio/blocks.py\", line 1931, in process_api\n",
358
+ " result = await self.call_function(\n",
359
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/gradio/blocks.py\", line 1516, in call_function\n",
360
+ " prediction = await anyio.to_thread.run_sync( # type: ignore\n",
361
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/anyio/to_thread.py\", line 56, in run_sync\n",
362
+ " return await get_async_backend().run_sync_in_worker_thread(\n",
363
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py\", line 2177, in run_sync_in_worker_thread\n",
364
+ " return await future\n",
365
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py\", line 859, in run\n",
366
+ " result = context.run(func, *args)\n",
367
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/gradio/utils.py\", line 826, in wrapper\n",
368
+ " response = f(*args, **kwargs)\n",
369
+ " File \"/tmp/ipykernel_848/2152623987.py\", line 5, in ctc_read\n",
370
+ " imagefin = fit_picture(image)\n",
371
+ " File \"/tmp/ipykernel_848/2382022948.py\", line 19, in fit_picture\n",
372
+ " aspect_ratio = img.width / img.height\n",
373
+ "AttributeError: 'NoneType' object has no attribute 'width'\n"
374
+ ]
375
+ },
376
+ {
377
+ "name": "stdout",
378
+ "output_type": "stream",
379
+ "text": [
380
+ "torch.Size([1, 1, 32, 400])\n"
381
+ ]
382
+ },
383
+ {
384
+ "name": "stderr",
385
+ "output_type": "stream",
386
+ "text": [
387
+ "/home/obai33/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
388
+ " return F.conv2d(input, weight, bias, self.stride,\n"
389
+ ]
390
+ },
391
+ {
392
+ "name": "stdout",
393
+ "output_type": "stream",
394
+ "text": [
395
+ "torch.Size([1, 1, 32, 400])\n",
396
+ "torch.Size([1, 1, 32, 400])\n"
397
+ ]
398
+ },
399
+ {
400
+ "name": "stderr",
401
+ "output_type": "stream",
402
+ "text": [
403
+ "/home/obai33/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
404
+ " return F.conv2d(input, weight, bias, self.stride,\n"
405
+ ]
406
+ },
407
+ {
408
+ "name": "stdout",
409
+ "output_type": "stream",
410
+ "text": [
411
+ "torch.Size([1, 1, 32, 400])\n",
412
+ "torch.Size([1, 1, 32, 400])\n"
413
+ ]
414
+ },
415
+ {
416
+ "name": "stderr",
417
+ "output_type": "stream",
418
+ "text": [
419
+ "/home/obai33/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
420
+ " return F.conv2d(input, weight, bias, self.stride,\n"
421
+ ]
422
+ },
423
+ {
424
+ "name": "stdout",
425
+ "output_type": "stream",
426
+ "text": [
427
+ "torch.Size([1, 1, 32, 400])\n",
428
+ "torch.Size([1, 1, 32, 400])\n",
429
+ "torch.Size([1, 1, 32, 400])\n"
430
+ ]
431
+ },
432
+ {
433
+ "name": "stderr",
434
+ "output_type": "stream",
435
+ "text": [
436
+ "Traceback (most recent call last):\n",
437
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/gradio/queueing.py\", line 536, in process_events\n",
438
+ " response = await route_utils.call_process_api(\n",
439
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/gradio/route_utils.py\", line 288, in call_process_api\n",
440
+ " output = await app.get_blocks().process_api(\n",
441
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/gradio/blocks.py\", line 1931, in process_api\n",
442
+ " result = await self.call_function(\n",
443
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/gradio/blocks.py\", line 1516, in call_function\n",
444
+ " prediction = await anyio.to_thread.run_sync( # type: ignore\n",
445
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/anyio/to_thread.py\", line 56, in run_sync\n",
446
+ " return await get_async_backend().run_sync_in_worker_thread(\n",
447
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py\", line 2177, in run_sync_in_worker_thread\n",
448
+ " return await future\n",
449
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py\", line 859, in run\n",
450
+ " result = context.run(func, *args)\n",
451
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/gradio/utils.py\", line 826, in wrapper\n",
452
+ " response = f(*args, **kwargs)\n",
453
+ " File \"/tmp/ipykernel_848/2152623987.py\", line 5, in ctc_read\n",
454
+ " imagefin = fit_picture(image)\n",
455
+ " File \"/tmp/ipykernel_848/2382022948.py\", line 19, in fit_picture\n",
456
+ " aspect_ratio = img.width / img.height\n",
457
+ "AttributeError: 'NoneType' object has no attribute 'width'\n"
458
+ ]
459
+ },
460
+ {
461
+ "name": "stdout",
462
+ "output_type": "stream",
463
+ "text": [
464
+ "torch.Size([1, 1, 32, 400])\n"
465
+ ]
466
+ },
467
+ {
468
+ "name": "stderr",
469
+ "output_type": "stream",
470
+ "text": [
471
+ "/home/obai33/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
472
+ " return F.conv2d(input, weight, bias, self.stride,\n",
473
+ "Traceback (most recent call last):\n",
474
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/gradio/queueing.py\", line 536, in process_events\n",
475
+ " response = await route_utils.call_process_api(\n",
476
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/gradio/route_utils.py\", line 288, in call_process_api\n",
477
+ " output = await app.get_blocks().process_api(\n",
478
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/gradio/blocks.py\", line 1931, in process_api\n",
479
+ " result = await self.call_function(\n",
480
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/gradio/blocks.py\", line 1516, in call_function\n",
481
+ " prediction = await anyio.to_thread.run_sync( # type: ignore\n",
482
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/anyio/to_thread.py\", line 56, in run_sync\n",
483
+ " return await get_async_backend().run_sync_in_worker_thread(\n",
484
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py\", line 2177, in run_sync_in_worker_thread\n",
485
+ " return await future\n",
486
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py\", line 859, in run\n",
487
+ " result = context.run(func, *args)\n",
488
+ " File \"/home/obai33/.local/lib/python3.10/site-packages/gradio/utils.py\", line 826, in wrapper\n",
489
+ " response = f(*args, **kwargs)\n",
490
+ " File \"/tmp/ipykernel_848/2152623987.py\", line 5, in ctc_read\n",
491
+ " imagefin = fit_picture(image)\n",
492
+ " File \"/tmp/ipykernel_848/2382022948.py\", line 19, in fit_picture\n",
493
+ " aspect_ratio = img.width / img.height\n",
494
+ "AttributeError: 'NoneType' object has no attribute 'width'\n"
495
+ ]
496
+ },
497
+ {
498
+ "name": "stdout",
499
+ "output_type": "stream",
500
+ "text": [
501
+ "torch.Size([1, 1, 32, 400])\n"
502
+ ]
503
+ },
504
+ {
505
+ "name": "stderr",
506
+ "output_type": "stream",
507
+ "text": [
508
+ "/home/obai33/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
509
+ " return F.conv2d(input, weight, bias, self.stride,\n"
510
+ ]
511
+ }
512
+ ],
513
+ "source": [
514
# Restore the trained weights and switch the network to inference mode
# (disables dropout and freezes batch-norm statistics).
load_model('model.pt')
recognizer.eval()
516
+ "\n",
517
def ctc_read(image):
    """Run the recognizer on an uploaded image and return decoded text.

    Args:
        image: PIL image from the Gradio input, or None when the user
            submits without uploading anything.

    Returns:
        The recognized text, or an empty string when no image was given.
    """
    # BUG FIX: Gradio passes None when no image was uploaded; the
    # original crashed inside fit_picture with
    # AttributeError: 'NoneType' object has no attribute 'width'
    # (visible in the cell's stderr output).
    if image is None:
        return ""

    imagefin = fit_picture(image)
    image_tensor = imagefin.unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        scores = recognizer(image_tensor)

    # Greedy (best-path) decoding: argmax per time step, then CTC collapse.
    predictions = scores.argmax(2).cpu().numpy()
    decoded_sequences = ctc_decode(predictions)

    # BUG FIX: the original loop kept only the *last* decoded sequence
    # and raised NameError on an empty batch; join all sequences instead
    # (identical output for the usual batch size of 1).
    return "\n".join(decode_text(seq) for seq in decoded_sequences)
534
+ "\n",
535
+ "\n",
536
# Gradio Interface: single image in, recognized text out.
iface = gr.Interface(
    fn=ctc_read,
    inputs=gr.Image(type="pil"),  # PIL Image input
    outputs="text",  # Text output
    title="Handwritten Text Recognition",
    # BUG FIX: user-facing typo "custome" -> "custom".
    description="Upload an image, and the custom AI will extract the text.",
)

iface.launch(share=True)
546
+ ]
547
+ }
548
+ ],
549
+ "metadata": {
550
+ "kernelspec": {
551
+ "display_name": "Python 3 (ipykernel)",
552
+ "language": "python",
553
+ "name": "python3"
554
+ },
555
+ "language_info": {
556
+ "codemirror_mode": {
557
+ "name": "ipython",
558
+ "version": 3
559
+ },
560
+ "file_extension": ".py",
561
+ "mimetype": "text/x-python",
562
+ "name": "python",
563
+ "nbconvert_exporter": "python",
564
+ "pygments_lexer": "ipython3",
565
+ "version": "3.10.12"
566
+ }
567
+ },
568
+ "nbformat": 4,
569
+ "nbformat_minor": 5
570
+ }